From 1867e2caa52c05eb7f5a8fea984c838cc874018d Mon Sep 17 00:00:00 2001 From: April Kyle Nassi Date: Wed, 28 Jul 2021 14:45:19 -0700 Subject: [PATCH 01/82] Update MAINTAINERS.md (#8352) moved 2 to emeritus list --- MAINTAINERS.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/MAINTAINERS.md b/MAINTAINERS.md index 5ff2c5157b5..5426f83f90b 100644 --- a/MAINTAINERS.md +++ b/MAINTAINERS.md @@ -8,11 +8,10 @@ See [CONTRIBUTING.md](https://github.com/grpc/grpc-community/blob/master/CONTRIB for general contribution guidelines. ## Maintainers (in alphabetical order) -- [creamsoup](https://github.com/creamsoup), Google LLC + - [dapengzhang0](https://github.com/dapengzhang0), Google LLC - [ejona86](https://github.com/ejona86), Google LLC - [ericgribkoff](https://github.com/ericgribkoff), Google LLC -- [jiangtaoli2016](https://github.com/jiangtaoli2016), Google LLC - [ran-su](https://github.com/ran-su), Google LLC - [sanjaypujare](https://github.com/sanjaypujare), Google LLC - [sergiitk](https://github.com/sergiitk), Google LLC @@ -22,6 +21,8 @@ for general contribution guidelines. ## Emeritus Maintainers (in alphabetical order) - [carl-mastrangelo](https://github.com/carl-mastrangelo), Google LLC +- [creamsoup](https://github.com/creamsoup), Google LLC +- [jiangtaoli2016](https://github.com/jiangtaoli2016), Google LLC - [jtattermusch](https://github.com/jtattermusch), Google LLC - [louiscryan](https://github.com/louiscryan), Google LLC - [nicolasnoble](https://github.com/nicolasnoble), Google LLC From 343eed1c04ef80f455ed9c19c3ebced64156f1ad Mon Sep 17 00:00:00 2001 From: yifeizhuang Date: Wed, 28 Jul 2021 15:27:55 -0700 Subject: [PATCH 02/82] Start 1.41.0 development cycle (#8351) --- build.gradle | 2 +- .../src/test/golden/TestDeprecatedService.java.txt | 2 +- compiler/src/test/golden/TestService.java.txt | 2 +- .../src/testLite/golden/TestDeprecatedService.java.txt | 2 +- compiler/src/testLite/golden/TestService.java.txt | 2 +- core/src/main/java/io/grpc/internal/GrpcUtil.java | 2 +- examples/android/clientcache/app/build.gradle | 10 +++++----- examples/android/helloworld/app/build.gradle | 8 ++++---- examples/android/routeguide/app/build.gradle | 8 ++++---- examples/android/strictmode/app/build.gradle | 8 ++++---- examples/build.gradle | 2 +- examples/example-alts/build.gradle | 2 +- examples/example-gauth/build.gradle | 2 +- examples/example-gauth/pom.xml | 4 ++-- examples/example-hostname/build.gradle | 2 +- examples/example-hostname/pom.xml | 4 ++-- examples/example-jwt-auth/build.gradle | 2 +- examples/example-jwt-auth/pom.xml | 4 ++-- examples/example-tls/build.gradle | 2 +- examples/example-tls/pom.xml | 4 ++-- examples/example-xds/build.gradle | 2 +- examples/pom.xml | 4 ++-- 22 files changed, 40 insertions(+), 40 deletions(-) diff --git a/build.gradle b/build.gradle index 94590cfa97c..a77f3801f0e 100644 --- a/build.gradle +++ b/build.gradle @@ -18,7 +18,7 @@ subprojects { apply plugin: "net.ltgt.errorprone" group = "io.grpc" - version = "1.40.0-SNAPSHOT" // CURRENT_GRPC_VERSION + version = "1.41.0-SNAPSHOT" // CURRENT_GRPC_VERSION repositories { maven { // The google mirror is less flaky than mavenCentral() diff --git a/compiler/src/test/golden/TestDeprecatedService.java.txt b/compiler/src/test/golden/TestDeprecatedService.java.txt index ecf5e3889dd..018849586de 100644 --- a/compiler/src/test/golden/TestDeprecatedService.java.txt +++ b/compiler/src/test/golden/TestDeprecatedService.java.txt @@ -8,7 +8,7 @@ import static io.grpc.MethodDescriptor.generateFullMethodName; * */ @javax.annotation.Generated( - value = "by gRPC proto compiler (version 1.40.0-SNAPSHOT)", + value = "by gRPC proto compiler (version 1.41.0-SNAPSHOT)", comments = "Source: grpc/testing/compiler/test.proto") @io.grpc.stub.annotations.GrpcGenerated @java.lang.Deprecated diff --git a/compiler/src/test/golden/TestService.java.txt b/compiler/src/test/golden/TestService.java.txt index 6abbd4732fc..18af83a9119 100644 --- a/compiler/src/test/golden/TestService.java.txt +++ b/compiler/src/test/golden/TestService.java.txt @@ -8,7 +8,7 @@ import static io.grpc.MethodDescriptor.generateFullMethodName; * */ @javax.annotation.Generated( - value = "by gRPC proto compiler (version 1.40.0-SNAPSHOT)", + value = "by gRPC proto compiler (version 1.41.0-SNAPSHOT)", comments = "Source: grpc/testing/compiler/test.proto") @io.grpc.stub.annotations.GrpcGenerated public final class TestServiceGrpc { diff --git a/compiler/src/testLite/golden/TestDeprecatedService.java.txt b/compiler/src/testLite/golden/TestDeprecatedService.java.txt index 4de0f949c59..30f22366765 100644 --- a/compiler/src/testLite/golden/TestDeprecatedService.java.txt +++ b/compiler/src/testLite/golden/TestDeprecatedService.java.txt @@ -8,7 +8,7 @@ import static io.grpc.MethodDescriptor.generateFullMethodName; * */ @javax.annotation.Generated( - value = "by gRPC proto compiler (version 1.40.0-SNAPSHOT)", + value = "by gRPC proto compiler (version 1.41.0-SNAPSHOT)", comments = "Source: grpc/testing/compiler/test.proto") @io.grpc.stub.annotations.GrpcGenerated @java.lang.Deprecated diff --git a/compiler/src/testLite/golden/TestService.java.txt b/compiler/src/testLite/golden/TestService.java.txt index b2481063ca6..38626900571 100644 --- a/compiler/src/testLite/golden/TestService.java.txt +++ b/compiler/src/testLite/golden/TestService.java.txt @@ -8,7 +8,7 @@ import static io.grpc.MethodDescriptor.generateFullMethodName; * */ @javax.annotation.Generated( - value = "by gRPC proto compiler (version 1.40.0-SNAPSHOT)", + value = "by gRPC proto compiler (version 1.41.0-SNAPSHOT)", comments = "Source: grpc/testing/compiler/test.proto") @io.grpc.stub.annotations.GrpcGenerated public final class TestServiceGrpc { diff --git a/core/src/main/java/io/grpc/internal/GrpcUtil.java b/core/src/main/java/io/grpc/internal/GrpcUtil.java index 45c0fce7122..4f8568c60aa 100644 --- a/core/src/main/java/io/grpc/internal/GrpcUtil.java +++ b/core/src/main/java/io/grpc/internal/GrpcUtil.java @@ -197,7 +197,7 @@ public byte[] parseAsciiString(byte[] serialized) { public static final Splitter ACCEPT_ENCODING_SPLITTER = Splitter.on(',').trimResults(); - private static final String IMPLEMENTATION_VERSION = "1.40.0-SNAPSHOT"; // CURRENT_GRPC_VERSION + private static final String IMPLEMENTATION_VERSION = "1.41.0-SNAPSHOT"; // CURRENT_GRPC_VERSION /** * The default timeout in nanos for a keepalive ping request. diff --git a/examples/android/clientcache/app/build.gradle b/examples/android/clientcache/app/build.gradle index 2d6cb8e0097..fab0934405e 100644 --- a/examples/android/clientcache/app/build.gradle +++ b/examples/android/clientcache/app/build.gradle @@ -34,7 +34,7 @@ android { protobuf { protoc { artifact = 'com.google.protobuf:protoc:3.17.2' } plugins { - grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.40.0-SNAPSHOT' // CURRENT_GRPC_VERSION + grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.41.0-SNAPSHOT' // CURRENT_GRPC_VERSION } } generateProtoTasks { @@ -54,12 +54,12 @@ dependencies { implementation 'com.android.support:appcompat-v7:27.0.2' // You need to build grpc-java to obtain these libraries below. - implementation 'io.grpc:grpc-okhttp:1.40.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-protobuf-lite:1.40.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-stub:1.40.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-okhttp:1.41.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-protobuf-lite:1.41.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-stub:1.41.0-SNAPSHOT' // CURRENT_GRPC_VERSION implementation 'org.apache.tomcat:annotations-api:6.0.53' testImplementation 'junit:junit:4.12' testImplementation 'com.google.truth:truth:1.0.1' - testImplementation 'io.grpc:grpc-testing:1.40.0-SNAPSHOT' // CURRENT_GRPC_VERSION + testImplementation 'io.grpc:grpc-testing:1.41.0-SNAPSHOT' // CURRENT_GRPC_VERSION } diff --git a/examples/android/helloworld/app/build.gradle b/examples/android/helloworld/app/build.gradle index 57490bc84ff..1fddbd6d481 100644 --- a/examples/android/helloworld/app/build.gradle +++ b/examples/android/helloworld/app/build.gradle @@ -32,7 +32,7 @@ android { protobuf { protoc { artifact = 'com.google.protobuf:protoc:3.17.2' } plugins { - grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.40.0-SNAPSHOT' // CURRENT_GRPC_VERSION + grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.41.0-SNAPSHOT' // CURRENT_GRPC_VERSION } } generateProtoTasks { @@ -52,8 +52,8 @@ dependencies { implementation 'com.android.support:appcompat-v7:27.0.2' // You need to build grpc-java to obtain these libraries below. - implementation 'io.grpc:grpc-okhttp:1.40.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-protobuf-lite:1.40.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-stub:1.40.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-okhttp:1.41.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-protobuf-lite:1.41.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-stub:1.41.0-SNAPSHOT' // CURRENT_GRPC_VERSION implementation 'org.apache.tomcat:annotations-api:6.0.53' } diff --git a/examples/android/routeguide/app/build.gradle b/examples/android/routeguide/app/build.gradle index ed7ae228853..250b10c3653 100644 --- a/examples/android/routeguide/app/build.gradle +++ b/examples/android/routeguide/app/build.gradle @@ -32,7 +32,7 @@ android { protobuf { protoc { artifact = 'com.google.protobuf:protoc:3.17.2' } plugins { - grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.40.0-SNAPSHOT' // CURRENT_GRPC_VERSION + grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.41.0-SNAPSHOT' // CURRENT_GRPC_VERSION } } generateProtoTasks { @@ -52,8 +52,8 @@ dependencies { implementation 'com.android.support:appcompat-v7:27.0.2' // You need to build grpc-java to obtain these libraries below. - implementation 'io.grpc:grpc-okhttp:1.40.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-protobuf-lite:1.40.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-stub:1.40.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-okhttp:1.41.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-protobuf-lite:1.41.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-stub:1.41.0-SNAPSHOT' // CURRENT_GRPC_VERSION implementation 'org.apache.tomcat:annotations-api:6.0.53' } diff --git a/examples/android/strictmode/app/build.gradle b/examples/android/strictmode/app/build.gradle index b117eccc12a..f68e8584ef0 100644 --- a/examples/android/strictmode/app/build.gradle +++ b/examples/android/strictmode/app/build.gradle @@ -33,7 +33,7 @@ android { protobuf { protoc { artifact = 'com.google.protobuf:protoc:3.17.2' } plugins { - grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.40.0-SNAPSHOT' // CURRENT_GRPC_VERSION + grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.41.0-SNAPSHOT' // CURRENT_GRPC_VERSION } } generateProtoTasks { @@ -53,8 +53,8 @@ dependencies { implementation 'com.android.support:appcompat-v7:28.0.0' // You need to build grpc-java to obtain these libraries below. - implementation 'io.grpc:grpc-okhttp:1.40.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-protobuf-lite:1.40.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-stub:1.40.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-okhttp:1.41.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-protobuf-lite:1.41.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-stub:1.41.0-SNAPSHOT' // CURRENT_GRPC_VERSION implementation 'org.apache.tomcat:annotations-api:6.0.53' } diff --git a/examples/build.gradle b/examples/build.gradle index 9b2226e38c6..03967a41c0d 100644 --- a/examples/build.gradle +++ b/examples/build.gradle @@ -22,7 +22,7 @@ targetCompatibility = 1.7 // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.40.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.41.0-SNAPSHOT' // CURRENT_GRPC_VERSION def protobufVersion = '3.17.2' def protocVersion = protobufVersion diff --git a/examples/example-alts/build.gradle b/examples/example-alts/build.gradle index e3f38649708..b1265e89440 100644 --- a/examples/example-alts/build.gradle +++ b/examples/example-alts/build.gradle @@ -23,7 +23,7 @@ targetCompatibility = 1.7 // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.40.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.41.0-SNAPSHOT' // CURRENT_GRPC_VERSION def protocVersion = '3.17.2' dependencies { diff --git a/examples/example-gauth/build.gradle b/examples/example-gauth/build.gradle index 708365e8973..f7f332e6962 100644 --- a/examples/example-gauth/build.gradle +++ b/examples/example-gauth/build.gradle @@ -23,7 +23,7 @@ targetCompatibility = 1.7 // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.40.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.41.0-SNAPSHOT' // CURRENT_GRPC_VERSION def protobufVersion = '3.17.2' def protocVersion = protobufVersion diff --git a/examples/example-gauth/pom.xml b/examples/example-gauth/pom.xml index cf901838401..1cc4c3ba6a7 100644 --- a/examples/example-gauth/pom.xml +++ b/examples/example-gauth/pom.xml @@ -6,13 +6,13 @@ jar - 1.40.0-SNAPSHOT + 1.41.0-SNAPSHOT example-gauth https://github.com/grpc/grpc-java UTF-8 - 1.40.0-SNAPSHOT + 1.41.0-SNAPSHOT 3.17.2 1.7 diff --git a/examples/example-hostname/build.gradle b/examples/example-hostname/build.gradle index 9ff9210ab7d..44048a78e34 100644 --- a/examples/example-hostname/build.gradle +++ b/examples/example-hostname/build.gradle @@ -21,7 +21,7 @@ targetCompatibility = 1.7 // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.40.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.41.0-SNAPSHOT' // CURRENT_GRPC_VERSION def protobufVersion = '3.17.2' dependencies { diff --git a/examples/example-hostname/pom.xml b/examples/example-hostname/pom.xml index 17e954bf67e..9af512c3952 100644 --- a/examples/example-hostname/pom.xml +++ b/examples/example-hostname/pom.xml @@ -6,13 +6,13 @@ jar - 1.40.0-SNAPSHOT + 1.41.0-SNAPSHOT example-hostname https://github.com/grpc/grpc-java UTF-8 - 1.40.0-SNAPSHOT + 1.41.0-SNAPSHOT 3.17.2 1.7 diff --git a/examples/example-jwt-auth/build.gradle b/examples/example-jwt-auth/build.gradle index c3593a2dd08..cf59da51d4a 100644 --- a/examples/example-jwt-auth/build.gradle +++ b/examples/example-jwt-auth/build.gradle @@ -22,7 +22,7 @@ targetCompatibility = 1.7 // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.40.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.41.0-SNAPSHOT' // CURRENT_GRPC_VERSION def protobufVersion = '3.17.2' def protocVersion = protobufVersion diff --git a/examples/example-jwt-auth/pom.xml b/examples/example-jwt-auth/pom.xml index 6e76f72b1e2..c4ae09cda90 100644 --- a/examples/example-jwt-auth/pom.xml +++ b/examples/example-jwt-auth/pom.xml @@ -7,13 +7,13 @@ jar - 1.40.0-SNAPSHOT + 1.41.0-SNAPSHOT example-jwt-auth https://github.com/grpc/grpc-java UTF-8 - 1.40.0-SNAPSHOT + 1.41.0-SNAPSHOT 3.17.2 3.17.2 diff --git a/examples/example-tls/build.gradle b/examples/example-tls/build.gradle index a272fc30e6a..61f13e050de 100644 --- a/examples/example-tls/build.gradle +++ b/examples/example-tls/build.gradle @@ -23,7 +23,7 @@ targetCompatibility = 1.7 // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.40.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.41.0-SNAPSHOT' // CURRENT_GRPC_VERSION def protocVersion = '3.17.2' dependencies { diff --git a/examples/example-tls/pom.xml b/examples/example-tls/pom.xml index ec49708c2b0..d83a0937725 100644 --- a/examples/example-tls/pom.xml +++ b/examples/example-tls/pom.xml @@ -6,13 +6,13 @@ jar - 1.40.0-SNAPSHOT + 1.41.0-SNAPSHOT example-tls https://github.com/grpc/grpc-java UTF-8 - 1.40.0-SNAPSHOT + 1.41.0-SNAPSHOT 3.17.2 2.0.34.Final diff --git a/examples/example-xds/build.gradle b/examples/example-xds/build.gradle index dfa860ec0d0..01ef4ba9266 100644 --- a/examples/example-xds/build.gradle +++ b/examples/example-xds/build.gradle @@ -22,7 +22,7 @@ targetCompatibility = 1.7 // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.40.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.41.0-SNAPSHOT' // CURRENT_GRPC_VERSION def nettyTcNativeVersion = '2.0.31.Final' def protocVersion = '3.17.2' diff --git a/examples/pom.xml b/examples/pom.xml index 3ba96b6c197..156b11fb7ac 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -6,13 +6,13 @@ jar - 1.40.0-SNAPSHOT + 1.41.0-SNAPSHOT examples https://github.com/grpc/grpc-java UTF-8 - 1.40.0-SNAPSHOT + 1.41.0-SNAPSHOT 3.17.2 3.17.2 From d836f38979c92ba2dcbce96c46ec14238208f8e2 Mon Sep 17 00:00:00 2001 From: ZHANG Dapeng Date: Wed, 28 Jul 2021 18:59:50 -0700 Subject: [PATCH 03/82] core: add real transport test for retry buffer limit (#8354) The unit tests in RetriableStreamTest do not really test buffer limit from end to end, because the buffer limit is implemented using ClientStreamTracer.Factory, and the tracer callback outboundMessageSize() is only triggered in AbstractClientStream after message serialization. In fact, it was broken without failing any existing tests (#8343 (comment)) This PR adds a retry buffer limit test that runs through the AbstractClientStream code path. --- interop-testing/build.gradle | 1 + .../grpc/testing/integration/RetryTest.java | 160 ++++++++++++++++++ 2 files changed, 161 insertions(+) create mode 100644 interop-testing/src/test/java/io/grpc/testing/integration/RetryTest.java diff --git a/interop-testing/build.gradle b/interop-testing/build.gradle index 79aa5356ecd..14c92a9fd1d 100644 --- a/interop-testing/build.gradle +++ b/interop-testing/build.gradle @@ -43,6 +43,7 @@ dependencies { libraries.netty_tcnative, project(':grpc-grpclb') testImplementation project(':grpc-context').sourceSets.test.output, + project(':grpc-api').sourceSets.test.output, libraries.mockito alpnagent libraries.jetty_alpn_agent } diff --git a/interop-testing/src/test/java/io/grpc/testing/integration/RetryTest.java b/interop-testing/src/test/java/io/grpc/testing/integration/RetryTest.java new file mode 100644 index 00000000000..4824f05313b --- /dev/null +++ b/interop-testing/src/test/java/io/grpc/testing/integration/RetryTest.java @@ -0,0 +1,160 @@ +/* + * Copyright 2021 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.testing.integration; + +import static com.google.common.truth.Truth.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.timeout; +import static org.mockito.Mockito.verify; + +import io.grpc.CallOptions; +import io.grpc.ClientCall; +import io.grpc.IntegerMarshaller; +import io.grpc.ManagedChannel; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.MethodDescriptor.MethodType; +import io.grpc.Server; +import io.grpc.ServerCall; +import io.grpc.ServerCall.Listener; +import io.grpc.ServerCallHandler; +import io.grpc.ServerMethodDefinition; +import io.grpc.ServerServiceDefinition; +import io.grpc.Status; +import io.grpc.StringMarshaller; +import io.grpc.netty.NettyChannelBuilder; +import io.grpc.netty.NettyServerBuilder; +import io.grpc.testing.GrpcCleanupRule; +import io.netty.channel.DefaultEventLoopGroup; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.local.LocalAddress; +import io.netty.channel.local.LocalChannel; +import io.netty.channel.local.LocalServerChannel; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +@RunWith(JUnit4.class) +public class RetryTest { + @Rule + public final MockitoRule mocks = MockitoJUnit.rule(); + @Rule + public final GrpcCleanupRule cleanupRule = new GrpcCleanupRule(); + @Mock + private ClientCall.Listener mockCallListener; + + @Test + public void retryUntilBufferLimitExceeded() throws Exception { + String message = "String of length 20."; + int bufferLimit = message.length() * 2 - 1; // Can buffer no more than 1 message. + + MethodDescriptor clientStreamingMethod = + MethodDescriptor.newBuilder() + .setType(MethodType.CLIENT_STREAMING) + .setFullMethodName("service/method") + .setRequestMarshaller(new StringMarshaller()) + .setResponseMarshaller(new IntegerMarshaller()) + .build(); + final LinkedBlockingQueue> serverCalls = + new LinkedBlockingQueue<>(); + ServerMethodDefinition methodDefinition = ServerMethodDefinition.create( + clientStreamingMethod, + new ServerCallHandler() { + @Override + public Listener startCall(ServerCall call, Metadata headers) { + serverCalls.offer(call); + return new Listener() {}; + } + } + ); + ServerServiceDefinition serviceDefinition = + ServerServiceDefinition.builder(clientStreamingMethod.getServiceName()) + .addMethod(methodDefinition) + .build(); + EventLoopGroup group = new DefaultEventLoopGroup(); + LocalAddress localAddress = new LocalAddress("RetryTest.retryUntilBufferLimitExceeded"); + Server localServer = cleanupRule.register(NettyServerBuilder.forAddress(localAddress) + .channelType(LocalServerChannel.class) + .bossEventLoopGroup(group) + .workerEventLoopGroup(group) + .addService(serviceDefinition) + .build()); + localServer.start(); + + Map retryPolicy = new HashMap<>(); + retryPolicy.put("maxAttempts", 4D); + retryPolicy.put("initialBackoff", "10s"); + retryPolicy.put("maxBackoff", "10s"); + retryPolicy.put("backoffMultiplier", 1D); + retryPolicy.put("retryableStatusCodes", Arrays.asList("UNAVAILABLE")); + Map methodConfig = new HashMap<>(); + Map name = new HashMap<>(); + name.put("service", "service"); + methodConfig.put("name", Arrays.asList(name)); + methodConfig.put("retryPolicy", retryPolicy); + Map rawServiceConfig = new HashMap<>(); + rawServiceConfig.put("methodConfig", Arrays.asList(methodConfig)); + ManagedChannel channel = cleanupRule.register( + NettyChannelBuilder.forAddress(localAddress) + .channelType(LocalChannel.class) + .eventLoopGroup(group) + .usePlaintext() + .enableRetry() + .perRpcBufferLimit(bufferLimit) + .defaultServiceConfig(rawServiceConfig) + .build()); + ClientCall call = channel.newCall(clientStreamingMethod, CallOptions.DEFAULT); + call.start(mockCallListener, new Metadata()); + call.sendMessage(message); + + ServerCall serverCall = serverCalls.poll(5, TimeUnit.SECONDS); + serverCall.request(2); + // trigger retry + Metadata pushBackMetadata = new Metadata(); + pushBackMetadata.put( + Metadata.Key.of("grpc-retry-pushback-ms", Metadata.ASCII_STRING_MARSHALLER), + "0"); // retry immediately + serverCall.close( + Status.UNAVAILABLE.withDescription("original attempt failed"), + pushBackMetadata); + // 2nd attempt received + serverCall = serverCalls.poll(5, TimeUnit.SECONDS); + serverCall.request(2); + verify(mockCallListener, never()).onClose(any(Status.class), any(Metadata.class)); + // send one more message, should exceed buffer limit + call.sendMessage(message); + // let attempt fail + serverCall.close( + Status.UNAVAILABLE.withDescription("2nd attempt failed"), + new Metadata()); + // no more retry + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(null); + verify(mockCallListener, timeout(5000)).onClose(statusCaptor.capture(), any(Metadata.class)); + assertThat(statusCaptor.getValue().getDescription()).contains("2nd attempt failed"); + } +} From b2764595e6f5e49ba9dc7ebb39cdb5d266c71ccf Mon Sep 17 00:00:00 2001 From: Eric Anderson Date: Fri, 30 Jul 2021 13:24:05 -0700 Subject: [PATCH 04/82] netty: Refine workaround for Netty header processing for transparent retries Nginx and C core don't do graceful GOAWAY and retries have matured such that transparent retries may soon be on by default. Refining the workaround thus can reduces error rate for users. Fixes #8310 --- .../main/java/io/grpc/netty/NettyClientHandler.java | 10 ++++++---- .../java/io/grpc/netty/NettyClientHandlerTest.java | 11 ++++++++++- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/netty/src/main/java/io/grpc/netty/NettyClientHandler.java b/netty/src/main/java/io/grpc/netty/NettyClientHandler.java index d263356204e..22d8fcadb75 100644 --- a/netty/src/main/java/io/grpc/netty/NettyClientHandler.java +++ b/netty/src/main/java/io/grpc/netty/NettyClientHandler.java @@ -822,6 +822,7 @@ private void goingAway(long errorCode, byte[] debugData) { // UNAVAILABLE. https://github.com/netty/netty/issues/10670 final Status abruptGoAwayStatusConservative = statusFromH2Error( null, "Abrupt GOAWAY closed sent stream", errorCode, debugData); + final boolean mayBeHittingNettyBug = errorCode != Http2Error.NO_ERROR.code(); // Try to allocate as many in-flight streams as possible, to reduce race window of // https://github.com/grpc/grpc-java/issues/2562 . To be of any help, the server has to // gracefully shut down the connection with two GOAWAYs. gRPC servers generally send a PING @@ -848,11 +849,12 @@ public boolean visit(Http2Stream stream) throws Http2Exception { if (clientStream != null) { // RpcProgress _should_ be REFUSED, but are being conservative. See comment for // abruptGoAwayStatusConservative. This does reduce our ability to perform transparent - // retries, but our main goal of transporent retries is to resolve the local race. We - // still hope/expect servers to use the graceful double-GOAWAY when closing - // connections. + // retries, but only if something else caused a connection failure. + RpcProgress progress = mayBeHittingNettyBug + ? RpcProgress.PROCESSED + : RpcProgress.REFUSED; clientStream.transportReportStatus( - abruptGoAwayStatusConservative, RpcProgress.PROCESSED, false, new Metadata()); + abruptGoAwayStatusConservative, progress, false, new Metadata()); } stream.close(); } diff --git a/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java b/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java index 5f1d27c37e2..d0d48fe9b48 100644 --- a/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java @@ -331,9 +331,18 @@ public void inboundShouldForwardToStream() throws Exception { } @Test - public void receivedGoAwayShouldRefuseLaterStreamId() throws Exception { + public void receivedGoAwayNoErrorShouldRefuseLaterStreamId() throws Exception { ChannelFuture future = enqueue(newCreateStreamCommand(grpcHeaders, streamTransportState)); channelRead(goAwayFrame(streamId - 1)); + verify(streamListener).closed(any(Status.class), eq(REFUSED), any(Metadata.class)); + assertTrue(future.isDone()); + } + + @Test + public void receivedGoAwayErrorShouldRefuseLaterStreamId() throws Exception { + ChannelFuture future = enqueue(newCreateStreamCommand(grpcHeaders, streamTransportState)); + channelRead( + goAwayFrame(streamId - 1, (int) Http2Error.PROTOCOL_ERROR.code(), Unpooled.EMPTY_BUFFER)); // This _should_ be REFUSED, but we purposefully use PROCESSED. See comment for // abruptGoAwayStatusConservative in NettyClientHandler verify(streamListener).closed(any(Status.class), eq(PROCESSED), any(Metadata.class)); From 860e97d12ab46e184ac4448eeca8a45671de6462 Mon Sep 17 00:00:00 2001 From: ZHANG Dapeng Date: Sat, 31 Jul 2021 18:33:02 -0700 Subject: [PATCH 05/82] all: API refactoring in preparation to support retry stats (#8355) Rebased PR #8343 into the first commit of this PR, then (the 2nd commit) reverted the part for metric recording of retry attempts. The PR as a whole is mechanical refactoring. No behavior change (except that some of the old code path when tracer is created is moved into the new method `streamCreated()`). The API change is documented in go/grpc-stats-api-change-for-retry-java --- .../main/java/io/grpc/ClientStreamTracer.java | 82 ++++-- .../test/java/io/grpc/CallOptionsTest.java | 6 +- .../grpc/binder/internal/BinderTransport.java | 21 +- .../io/grpc/census/CensusStatsModule.java | 57 +++-- .../io/grpc/census/CensusTracingModule.java | 21 +- .../io/grpc/census/CensusModulesTest.java | 71 ++++-- .../internal/StatsTraceContextBenchmark.java | 5 +- .../io/grpc/inprocess/InProcessTransport.java | 24 +- ...llCredentialsApplyingTransportFactory.java | 12 +- .../java/io/grpc/internal/ClientCallImpl.java | 5 +- .../io/grpc/internal/ClientTransport.java | 8 +- .../grpc/internal/DelayedClientTransport.java | 35 ++- .../java/io/grpc/internal/DelayedStream.java | 4 + .../io/grpc/internal/FailingClientStream.java | 13 +- .../grpc/internal/FailingClientTransport.java | 6 +- .../ForwardingClientStreamTracer.java | 101 ++++++++ .../ForwardingConnectionClientTransport.java | 6 +- .../main/java/io/grpc/internal/GrpcUtil.java | 77 +++++- .../io/grpc/internal/InternalSubchannel.java | 6 +- .../io/grpc/internal/ManagedChannelImpl.java | 13 +- .../io/grpc/internal/MetadataApplierImpl.java | 10 +- .../java/io/grpc/internal/OobChannel.java | 5 +- .../io/grpc/internal/RetriableStream.java | 20 +- .../io/grpc/internal/StatsTraceContext.java | 20 +- .../io/grpc/internal/SubchannelChannel.java | 5 +- .../util/ForwardingClientStreamTracer.java | 6 + .../java/io/grpc/ClientStreamTracerTest.java | 4 + .../grpc/internal/AbstractTransportTest.java | 181 +++++++------ .../CallCredentials2ApplyingTest.java | 67 +++-- .../internal/CallCredentialsApplyingTest.java | 102 +++++--- .../io/grpc/internal/ClientCallImplTest.java | 7 +- .../internal/DelayedClientTransportTest.java | 116 ++++++--- .../internal/FailingClientStreamTest.java | 8 +- .../internal/FailingClientTransportTest.java | 6 +- .../ForwardingClientStreamTracerTest.java | 49 ++++ .../java/io/grpc/internal/GrpcUtilTest.java | 51 +++- .../grpc/internal/ManagedChannelImplTest.java | 239 ++++++++++++------ .../io/grpc/internal/RetriableStreamTest.java | 5 +- .../test/java/io/grpc/internal/TestUtils.java | 5 +- .../ForwardingClientStreamTracerTest.java | 1 + .../io/grpc/cronet/CronetClientTransport.java | 5 +- .../grpc/cronet/CronetChannelBuilderTest.java | 9 +- .../cronet/CronetClientTransportTest.java | 9 +- .../grpc/grpclb/GrpclbClientLoadRecorder.java | 2 +- .../grpclb/TokenAttachingTracerFactory.java | 36 ++- .../grpc/grpclb/GrpclbLoadBalancerTest.java | 3 + .../TokenAttachingTracerFactoryTest.java | 47 ++-- .../integration/AbstractInteropTest.java | 14 +- .../io/grpc/netty/NettyClientTransport.java | 8 +- .../grpc/netty/NettyClientTransportTest.java | 5 +- .../io/grpc/okhttp/OkHttpClientTransport.java | 12 +- .../okhttp/OkHttpClientTransportTest.java | 123 ++++----- .../io/grpc/xds/ClusterImplLoadBalancer.java | 5 +- .../java/io/grpc/xds/OrcaPerRequestUtil.java | 7 +- .../grpc/xds/ClusterImplLoadBalancerTest.java | 8 +- 55 files changed, 1210 insertions(+), 563 deletions(-) create mode 100644 core/src/main/java/io/grpc/internal/ForwardingClientStreamTracer.java create mode 100644 core/src/test/java/io/grpc/internal/ForwardingClientStreamTracerTest.java diff --git a/api/src/main/java/io/grpc/ClientStreamTracer.java b/api/src/main/java/io/grpc/ClientStreamTracer.java index 6259522487a..6a5d3cc3397 100644 --- a/api/src/main/java/io/grpc/ClientStreamTracer.java +++ b/api/src/main/java/io/grpc/ClientStreamTracer.java @@ -19,7 +19,6 @@ import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.base.MoreObjects; -import io.grpc.Grpc; import javax.annotation.concurrent.ThreadSafe; /** @@ -28,6 +27,18 @@ @ExperimentalApi("https://github.com/grpc/grpc-java/issues/2861") @ThreadSafe public abstract class ClientStreamTracer extends StreamTracer { + + /** + * The stream is being created on a ready transport. + * + * @param headers the mutable initial metadata. Modifications to it will be sent to the socket but + * not be seen by client interceptors and the application. + * + * @since 1.40.0 + */ + public void streamCreated(@Grpc.TransportAttr Attributes transportAttrs, Metadata headers) { + } + /** * Headers has been sent to the socket. */ @@ -54,22 +65,6 @@ public void inboundTrailers(Metadata trailers) { * Factory class for {@link ClientStreamTracer}. */ public abstract static class Factory { - /** - * Creates a {@link ClientStreamTracer} for a new client stream. - * - * @param callOptions the effective CallOptions of the call - * @param headers the mutable headers of the stream. It can be safely mutated within this - * method. It should not be saved because it is not safe for read or write after the - * method returns. - * - * @deprecated use {@link - * #newClientStreamTracer(io.grpc.ClientStreamTracer.StreamInfo, io.grpc.Metadata)} instead. - */ - @Deprecated - public ClientStreamTracer newClientStreamTracer(CallOptions callOptions, Metadata headers) { - throw new UnsupportedOperationException("Not implemented"); - } - /** * Creates a {@link ClientStreamTracer} for a new client stream. This is called inside the * transport when it's creating the stream. @@ -81,12 +76,15 @@ public ClientStreamTracer newClientStreamTracer(CallOptions callOptions, Metadat * * @since 1.20.0 */ - @SuppressWarnings("deprecation") public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata headers) { - return newClientStreamTracer(info.getCallOptions(), headers); + throw new UnsupportedOperationException("Not implemented"); } } + /** An abstract class for internal use only. */ + @Internal + public abstract static class InternalLimitedInfoFactory extends Factory {} + /** * Information about a stream. * @@ -99,15 +97,21 @@ public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata header public static final class StreamInfo { private final Attributes transportAttrs; private final CallOptions callOptions; + private final boolean isTransparentRetry; - StreamInfo(Attributes transportAttrs, CallOptions callOptions) { + StreamInfo(Attributes transportAttrs, CallOptions callOptions, boolean isTransparentRetry) { this.transportAttrs = checkNotNull(transportAttrs, "transportAttrs"); this.callOptions = checkNotNull(callOptions, "callOptions"); + this.isTransparentRetry = isTransparentRetry; } /** * Returns the attributes of the transport that this stream was created on. + * + * @deprecated Use {@link ClientStreamTracer#streamCreated(Attributes, Metadata)} to handle + * the transport Attributes instead. */ + @Deprecated @Grpc.TransportAttr public Attributes getTransportAttrs() { return transportAttrs; @@ -120,16 +124,25 @@ public CallOptions getCallOptions() { return callOptions; } + /** + * Whether the stream is a transparent retry. + * + * @since 1.40.0 + */ + public boolean isTransparentRetry() { + return isTransparentRetry; + } + /** * Converts this StreamInfo into a new Builder. * * @since 1.21.0 */ public Builder toBuilder() { - Builder builder = new Builder(); - builder.setTransportAttrs(transportAttrs); - builder.setCallOptions(callOptions); - return builder; + return new Builder() + .setCallOptions(callOptions) + .setTransportAttrs(transportAttrs) + .setIsTransparentRetry(isTransparentRetry); } /** @@ -146,6 +159,7 @@ public String toString() { return MoreObjects.toStringHelper(this) .add("transportAttrs", transportAttrs) .add("callOptions", callOptions) + .add("isTransparentRetry", isTransparentRetry) .toString(); } @@ -157,6 +171,7 @@ public String toString() { public static final class Builder { private Attributes transportAttrs = Attributes.EMPTY; private CallOptions callOptions = CallOptions.DEFAULT; + private boolean isTransparentRetry; Builder() { } @@ -164,9 +179,12 @@ public static final class Builder { /** * Sets the attributes of the transport that this stream was created on. This field is * optional. + * + * @deprecated Use {@link ClientStreamTracer#streamCreated(Attributes, Metadata)} to handle + * the transport Attributes instead. */ - @Grpc.TransportAttr - public Builder setTransportAttrs(Attributes transportAttrs) { + @Deprecated + public Builder setTransportAttrs(@Grpc.TransportAttr Attributes transportAttrs) { this.transportAttrs = checkNotNull(transportAttrs, "transportAttrs cannot be null"); return this; } @@ -179,11 +197,21 @@ public Builder setCallOptions(CallOptions callOptions) { return this; } + /** + * Sets whether the stream is a transparent retry. + * + * @since 1.40.0 + */ + public Builder setIsTransparentRetry(boolean isTransparentRetry) { + this.isTransparentRetry = isTransparentRetry; + return this; + } + /** * Builds a new StreamInfo. */ public StreamInfo build() { - return new StreamInfo(transportAttrs, callOptions); + return new StreamInfo(transportAttrs, callOptions, isTransparentRetry); } } } diff --git a/api/src/test/java/io/grpc/CallOptionsTest.java b/api/src/test/java/io/grpc/CallOptionsTest.java index 31861306891..0bc0d357358 100644 --- a/api/src/test/java/io/grpc/CallOptionsTest.java +++ b/api/src/test/java/io/grpc/CallOptionsTest.java @@ -30,6 +30,7 @@ import static org.mockito.Mockito.mock; import com.google.common.base.Objects; +import io.grpc.ClientStreamTracer.StreamInfo; import io.grpc.internal.SerializingExecutor; import java.util.concurrent.Executor; import java.util.concurrent.TimeUnit; @@ -271,7 +272,7 @@ public void increment(long period, TimeUnit unit) { } } - private static class FakeTracerFactory extends ClientStreamTracer.Factory { + private static class FakeTracerFactory extends ClientStreamTracer.InternalLimitedInfoFactory { final String name; FakeTracerFactory(String name) { @@ -279,8 +280,7 @@ private static class FakeTracerFactory extends ClientStreamTracer.Factory { } @Override - public ClientStreamTracer newClientStreamTracer( - ClientStreamTracer.StreamInfo info, Metadata headers) { + public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata headers) { return new ClientStreamTracer() {}; } diff --git a/binder/src/main/java/io/grpc/binder/internal/BinderTransport.java b/binder/src/main/java/io/grpc/binder/internal/BinderTransport.java index 04070ddfcef..b132844069c 100644 --- a/binder/src/main/java/io/grpc/binder/internal/BinderTransport.java +++ b/binder/src/main/java/io/grpc/binder/internal/BinderTransport.java @@ -32,6 +32,7 @@ import com.google.common.util.concurrent.ListenableFuture; import io.grpc.Attributes; import io.grpc.CallOptions; +import io.grpc.ClientStreamTracer; import io.grpc.Grpc; import io.grpc.Internal; import io.grpc.InternalChannelz.SocketStats; @@ -632,28 +633,28 @@ public synchronized Runnable start(ManagedClientTransport.Listener clientTranspo public synchronized ClientStream newStream( final MethodDescriptor method, final Metadata headers, - final CallOptions callOptions) { + final CallOptions callOptions, + ClientStreamTracer[] tracers) { if (isShutdown()) { - return newFailingClientStream(shutdownStatus, callOptions, attributes, headers); + return newFailingClientStream(shutdownStatus, attributes, headers, tracers); } else { int callId = latestCallId++; if (latestCallId == LAST_CALL_ID) { latestCallId = FIRST_CALL_ID; } + StatsTraceContext statsTraceContext = + StatsTraceContext.newClientContext(tracers, attributes, headers); Inbound.ClientInbound inbound = new Inbound.ClientInbound( this, attributes, callId, GrpcUtil.shouldBeCountedForInUse(callOptions)); if (ongoingCalls.putIfAbsent(callId, inbound) != null) { Status failure = Status.INTERNAL.withDescription("Clashing call IDs"); shutdownInternal(failure, true); - return newFailingClientStream(failure, callOptions, attributes, headers); + return newFailingClientStream(failure, attributes, headers, tracers); } else { if (inbound.countsForInUse() && numInUseStreams.getAndIncrement() == 0) { clientTransportListener.transportInUse(true); } - StatsTraceContext statsTraceContext = - StatsTraceContext.newClientContext(callOptions, attributes, headers); - Outbound.ClientOutbound outbound = new Outbound.ClientOutbound(this, callId, method, headers, statsTraceContext); if (method.getType().clientSendsOneMessage()) { @@ -763,12 +764,12 @@ protected void handlePingResponse(Parcel parcel) { } private static ClientStream newFailingClientStream( - Status failure, CallOptions callOptions, Attributes attributes, Metadata headers) { + Status failure, Attributes attributes, Metadata headers, + ClientStreamTracer[] tracers) { StatsTraceContext statsTraceContext = - StatsTraceContext.newClientContext(callOptions, attributes, headers); + StatsTraceContext.newClientContext(tracers, attributes, headers); statsTraceContext.clientOutboundHeaders(); - statsTraceContext.streamClosed(failure); - return new FailingClientStream(failure); + return new FailingClientStream(failure, tracers); } private static InternalLogId buildLogId( diff --git a/census/src/main/java/io/grpc/census/CensusStatsModule.java b/census/src/main/java/io/grpc/census/CensusStatsModule.java index d625a6f5c6f..ac5f4e705e3 100644 --- a/census/src/main/java/io/grpc/census/CensusStatsModule.java +++ b/census/src/main/java/io/grpc/census/CensusStatsModule.java @@ -22,6 +22,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Stopwatch; import com.google.common.base.Supplier; +import io.grpc.Attributes; import io.grpc.CallOptions; import io.grpc.Channel; import io.grpc.ClientCall; @@ -138,15 +139,6 @@ public TagContext parseBytes(byte[] serialized) { }); } - /** - * Creates a {@link ClientCallTracer} for a new call. - */ - @VisibleForTesting - ClientCallTracer newClientCallTracer( - TagContext parentCtx, String fullMethodName) { - return new ClientCallTracer(this, parentCtx, fullMethodName); - } - /** * Returns the server tracer factory. */ @@ -231,6 +223,7 @@ private static final class ClientTracer extends ClientStreamTracer { } private final CensusStatsModule module; + final TagContext parentCtx; private final TagContext startCtx; volatile long outboundMessageCount; @@ -240,11 +233,22 @@ private static final class ClientTracer extends ClientStreamTracer { volatile long outboundUncompressedSize; volatile long inboundUncompressedSize; - ClientTracer(CensusStatsModule module, TagContext startCtx) { + ClientTracer(CensusStatsModule module, TagContext parentCtx, TagContext startCtx) { this.module = checkNotNull(module, "module"); + this.parentCtx = parentCtx; this.startCtx = checkNotNull(startCtx, "startCtx"); } + @Override + public void streamCreated(Attributes transportAttrs, Metadata headers) { + if (module.propagateTags) { + headers.discardAll(module.statsHeader); + if (!module.tagger.empty().equals(parentCtx)) { + headers.put(module.statsHeader, parentCtx); + } + } + } + @Override @SuppressWarnings("NonAtomicVolatileUpdate") public void outboundWireSize(long bytes) { @@ -315,12 +319,14 @@ public void outboundMessage(int seqNo) { } @VisibleForTesting - static final class ClientCallTracer extends ClientStreamTracer.Factory { + static final class CallAttemptsTracerFactory extends + ClientStreamTracer.InternalLimitedInfoFactory { @Nullable - private static final AtomicReferenceFieldUpdater + private static final AtomicReferenceFieldUpdater streamTracerUpdater; - @Nullable private static final AtomicIntegerFieldUpdater callEndedUpdater; + @Nullable + private static final AtomicIntegerFieldUpdater callEndedUpdater; /** * When using Atomic*FieldUpdater, some Samsung Android 5.0.x devices encounter a bug in their @@ -328,14 +334,14 @@ static final class ClientCallTracer extends ClientStreamTracer.Factory { * (potentially racy) direct updates of the volatile variables. */ static { - AtomicReferenceFieldUpdater tmpStreamTracerUpdater; - AtomicIntegerFieldUpdater tmpCallEndedUpdater; + AtomicReferenceFieldUpdater tmpStreamTracerUpdater; + AtomicIntegerFieldUpdater tmpCallEndedUpdater; try { tmpStreamTracerUpdater = AtomicReferenceFieldUpdater.newUpdater( - ClientCallTracer.class, ClientTracer.class, "streamTracer"); + CallAttemptsTracerFactory.class, ClientTracer.class, "streamTracer"); tmpCallEndedUpdater = - AtomicIntegerFieldUpdater.newUpdater(ClientCallTracer.class, "callEnded"); + AtomicIntegerFieldUpdater.newUpdater(CallAttemptsTracerFactory.class, "callEnded"); } catch (Throwable t) { logger.log(Level.SEVERE, "Creating atomic field updaters failed", t); tmpStreamTracerUpdater = null; @@ -352,7 +358,8 @@ static final class ClientCallTracer extends ClientStreamTracer.Factory { private final TagContext parentCtx; private final TagContext startCtx; - ClientCallTracer(CensusStatsModule module, TagContext parentCtx, String fullMethodName) { + CallAttemptsTracerFactory( + CensusStatsModule module, TagContext parentCtx, String fullMethodName) { this.module = checkNotNull(module); this.parentCtx = checkNotNull(parentCtx); TagValue methodTag = TagValue.create(fullMethodName); @@ -370,7 +377,7 @@ static final class ClientCallTracer extends ClientStreamTracer.Factory { @Override public ClientStreamTracer newClientStreamTracer( ClientStreamTracer.StreamInfo info, Metadata headers) { - ClientTracer tracer = new ClientTracer(module, startCtx); + ClientTracer tracer = new ClientTracer(module, parentCtx, startCtx); // TODO(zhangkun83): Once retry or hedging is implemented, a ClientCall may start more than // one streams. We will need to update this file to support them. if (streamTracerUpdater != null) { @@ -383,12 +390,6 @@ public ClientStreamTracer newClientStreamTracer( "Are you creating multiple streams per call? This class doesn't yet support this case"); streamTracer = tracer; } - if (module.propagateTags) { - headers.discardAll(module.statsHeader); - if (!module.tagger.empty().equals(parentCtx)) { - headers.put(module.statsHeader, parentCtx); - } - } return tracer; } @@ -416,7 +417,7 @@ void callEnded(Status status) { long roundtripNanos = stopwatch.elapsed(TimeUnit.NANOSECONDS); ClientTracer tracer = streamTracer; if (tracer == null) { - tracer = new ClientTracer(module, startCtx); + tracer = new ClientTracer(module, parentCtx, startCtx); } MeasureMap measureMap = module.statsRecorder.newMeasureMap() // TODO(songya): remove the deprecated measure constants once they are completed removed. @@ -686,8 +687,8 @@ public ClientCall interceptCall( MethodDescriptor method, CallOptions callOptions, Channel next) { // New RPCs on client-side inherit the tag context from the current Context. TagContext parentCtx = tagger.getCurrentTagContext(); - final ClientCallTracer tracerFactory = - newClientCallTracer(parentCtx, method.getFullMethodName()); + final CallAttemptsTracerFactory tracerFactory = new CallAttemptsTracerFactory( + CensusStatsModule.this, parentCtx, method.getFullMethodName()); ClientCall call = next.newCall(method, callOptions.withStreamTracerFactory(tracerFactory)); return new SimpleForwardingClientCall(call) { diff --git a/census/src/main/java/io/grpc/census/CensusTracingModule.java b/census/src/main/java/io/grpc/census/CensusTracingModule.java index fc35d89db55..dac62206fd2 100644 --- a/census/src/main/java/io/grpc/census/CensusTracingModule.java +++ b/census/src/main/java/io/grpc/census/CensusTracingModule.java @@ -19,6 +19,7 @@ import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.annotations.VisibleForTesting; +import io.grpc.Attributes; import io.grpc.CallOptions; import io.grpc.Channel; import io.grpc.ClientCall; @@ -222,7 +223,7 @@ private static void recordMessageEvent( } @VisibleForTesting - final class ClientCallTracer extends ClientStreamTracer.Factory { + final class ClientCallTracer extends ClientStreamTracer.InternalLimitedInfoFactory { volatile int callEnded; private final boolean isSampledToLocalTracing; @@ -243,11 +244,7 @@ final class ClientCallTracer extends ClientStreamTracer.Factory { @Override public ClientStreamTracer newClientStreamTracer( ClientStreamTracer.StreamInfo info, Metadata headers) { - if (span != BlankSpan.INSTANCE) { - headers.discardAll(tracingHeader); - headers.put(tracingHeader, span.getContext()); - } - return new ClientTracer(span); + return new ClientTracer(span, tracingHeader); } /** @@ -273,9 +270,19 @@ void callEnded(io.grpc.Status status) { private static final class ClientTracer extends ClientStreamTracer { private final Span span; + final Metadata.Key tracingHeader; - ClientTracer(Span span) { + ClientTracer(Span span, Metadata.Key tracingHeader) { this.span = checkNotNull(span, "span"); + this.tracingHeader = tracingHeader; + } + + @Override + public void streamCreated(Attributes transportAtts, Metadata headers) { + if (span != BlankSpan.INSTANCE) { + headers.discardAll(tracingHeader); + headers.put(tracingHeader, span.getContext()); + } } @Override diff --git a/census/src/test/java/io/grpc/census/CensusModulesTest.java b/census/src/test/java/io/grpc/census/CensusModulesTest.java index fbbcd44150c..fd3a049f7a4 100644 --- a/census/src/test/java/io/grpc/census/CensusModulesTest.java +++ b/census/src/test/java/io/grpc/census/CensusModulesTest.java @@ -295,7 +295,7 @@ public ClientCall interceptCall( instanceof CensusTracingModule.ClientCallTracer); assertTrue( capturedCallOptions.get().getStreamTracerFactories().get(1) - instanceof CensusStatsModule.ClientCallTracer); + instanceof CensusStatsModule.CallAttemptsTracerFactory); // Make the call Metadata headers = new Metadata(); @@ -388,11 +388,12 @@ private void subtestClientBasicStatsDefaultContext( new CensusStatsModule( tagger, tagCtxSerializer, statsRecorder, fakeClock.getStopwatchSupplier(), true, recordStarts, recordFinishes, recordRealTime); - CensusStatsModule.ClientCallTracer callTracer = - localCensusStats.newClientCallTracer( - tagger.empty(), method.getFullMethodName()); + CensusStatsModule.CallAttemptsTracerFactory callAttemptsTracerFactory = + new CensusStatsModule.CallAttemptsTracerFactory( + localCensusStats, tagger.empty(), method.getFullMethodName()); Metadata headers = new Metadata(); - ClientStreamTracer tracer = callTracer.newClientStreamTracer(STREAM_INFO, headers); + ClientStreamTracer tracer = + callAttemptsTracerFactory.newClientStreamTracer(STREAM_INFO, headers); if (recordStarts) { StatsTestUtils.MetricsRecord record = statsRecorder.pollRecord(); @@ -455,7 +456,7 @@ private void subtestClientBasicStatsDefaultContext( tracer.inboundUncompressedSize(552); tracer.streamClosed(Status.OK); - callTracer.callEnded(Status.OK); + callAttemptsTracerFactory.callEnded(Status.OK); if (recordFinishes) { StatsTestUtils.MetricsRecord record = statsRecorder.pollRecord(); @@ -522,6 +523,7 @@ public void clientBasicTracingDefaultSpan() { censusTracing.newClientCallTracer(null, method); Metadata headers = new Metadata(); ClientStreamTracer clientStreamTracer = callTracer.newClientStreamTracer(STREAM_INFO, headers); + clientStreamTracer.streamCreated(Attributes.EMPTY, headers); verify(tracer).spanBuilderWithExplicitParent( eq("Sent.package1.service2.method3"), ArgumentMatchers.isNull()); verify(spyClientSpan, never()).end(any(EndSpanOptions.class)); @@ -575,11 +577,15 @@ public void clientTracingSampledToLocalSpanStore() { @Test public void clientStreamNeverCreatedStillRecordStats() { - CensusStatsModule.ClientCallTracer callTracer = - censusStats.newClientCallTracer(tagger.empty(), method.getFullMethodName()); - + CensusStatsModule.CallAttemptsTracerFactory callAttemptsTracerFactory = + new CensusStatsModule.CallAttemptsTracerFactory( + censusStats, tagger.empty(), method.getFullMethodName()); + ClientStreamTracer streamTracer = + callAttemptsTracerFactory.newClientStreamTracer(STREAM_INFO, new Metadata()); fakeClock.forwardTime(3000, MILLISECONDS); - callTracer.callEnded(Status.DEADLINE_EXCEEDED.withDescription("3 seconds")); + Status status = Status.DEADLINE_EXCEEDED.withDescription("3 seconds"); + streamTracer.streamClosed(status); + callAttemptsTracerFactory.callEnded(status); // Upstart record StatsTestUtils.MetricsRecord record = statsRecorder.pollRecord(); @@ -680,10 +686,13 @@ private void subtestStatsHeadersPropagateTags(boolean propagate, boolean recordS fakeClock.getStopwatchSupplier(), propagate, recordStats, recordStats, recordStats); Metadata headers = new Metadata(); - CensusStatsModule.ClientCallTracer callTracer = - census.newClientCallTracer(clientCtx, method.getFullMethodName()); + CensusStatsModule.CallAttemptsTracerFactory callAttemptsTracerFactory = + new CensusStatsModule.CallAttemptsTracerFactory( + census, clientCtx, method.getFullMethodName()); // This propagates clientCtx to headers if propagates==true - callTracer.newClientStreamTracer(STREAM_INFO, headers); + ClientStreamTracer streamTracer = + callAttemptsTracerFactory.newClientStreamTracer(STREAM_INFO, headers); + streamTracer.streamCreated(Attributes.EMPTY, headers); if (recordStats) { // Client upstart record StatsTestUtils.MetricsRecord clientRecord = statsRecorder.pollRecord(); @@ -746,7 +755,8 @@ private void subtestStatsHeadersPropagateTags(boolean propagate, boolean recordS // Verifies that the client tracer factory uses clientCtx, which includes the custom tags, to // record stats. - callTracer.callEnded(Status.OK); + streamTracer.streamClosed(Status.OK); + callAttemptsTracerFactory.callEnded(Status.OK); if (recordStats) { // Client completion record @@ -769,10 +779,12 @@ private void subtestStatsHeadersPropagateTags(boolean propagate, boolean recordS @Test public void statsHeadersNotPropagateDefaultContext() { - CensusStatsModule.ClientCallTracer callTracer = - censusStats.newClientCallTracer(tagger.empty(), method.getFullMethodName()); + CensusStatsModule.CallAttemptsTracerFactory callAttemptsTracerFactory = + new CensusStatsModule.CallAttemptsTracerFactory( + censusStats, tagger.empty(), method.getFullMethodName()); Metadata headers = new Metadata(); - callTracer.newClientStreamTracer(STREAM_INFO, headers); + callAttemptsTracerFactory.newClientStreamTracer(STREAM_INFO, headers) + .streamCreated(Attributes.EMPTY, headers); assertFalse(headers.containsKey(censusStats.statsHeader)); // Clear recorded stats to satisfy the assertions in wrapUp() statsRecorder.rolloverRecords(); @@ -803,7 +815,8 @@ public void traceHeadersPropagateSpanContext() throws Exception { CensusTracingModule.ClientCallTracer callTracer = censusTracing.newClientCallTracer(fakeClientParentSpan, method); Metadata headers = new Metadata(); - callTracer.newClientStreamTracer(STREAM_INFO, headers); + ClientStreamTracer streamTracer = callTracer.newClientStreamTracer(STREAM_INFO, headers); + streamTracer.streamCreated(Attributes.EMPTY, headers); verify(mockTracingPropagationHandler).toByteArray(same(fakeClientSpanContext)); verifyNoMoreInteractions(mockTracingPropagationHandler); @@ -831,7 +844,8 @@ public void traceHeaders_propagateSpanContext() throws Exception { censusTracing.newClientCallTracer(fakeClientParentSpan, method); Metadata headers = new Metadata(); - callTracer.newClientStreamTracer(STREAM_INFO, headers); + ClientStreamTracer streamTracer = callTracer.newClientStreamTracer(STREAM_INFO, headers); + streamTracer.streamCreated(Attributes.EMPTY, headers); assertThat(headers.keys()).isNotEmpty(); } @@ -845,7 +859,7 @@ public void traceHeaders_missingCensusImpl_notPropagateSpanContext() CensusTracingModule.ClientCallTracer callTracer = censusTracing.newClientCallTracer(BlankSpan.INSTANCE, method); - callTracer.newClientStreamTracer(STREAM_INFO, headers); + callTracer.newClientStreamTracer(STREAM_INFO, headers).streamCreated(Attributes.EMPTY, headers); assertThat(headers.keys()).isEmpty(); } @@ -862,7 +876,7 @@ public void traceHeaders_clientMissingCensusImpl_preservingHeaders() throws Exce CensusTracingModule.ClientCallTracer callTracer = censusTracing.newClientCallTracer(BlankSpan.INSTANCE, method); - callTracer.newClientStreamTracer(STREAM_INFO, headers); + callTracer.newClientStreamTracer(STREAM_INFO, headers).streamCreated(Attributes.EMPTY, headers); assertThat(headers.keys()).containsExactlyElementsIn(originalHeaderKeys); } @@ -1186,13 +1200,18 @@ public void newTagsPopulateOldViews() throws InterruptedException { tagger, tagCtxSerializer, localStats.getStatsRecorder(), fakeClock.getStopwatchSupplier(), false, false, true, false /* real-time */); - CensusStatsModule.ClientCallTracer callTracer = - localCensusStats.newClientCallTracer( - tagger.empty(), method.getFullMethodName()); + CensusStatsModule.CallAttemptsTracerFactory callAttemptsTracerFactory = + new CensusStatsModule.CallAttemptsTracerFactory( + localCensusStats, tagger.empty(), method.getFullMethodName()); - callTracer.newClientStreamTracer(STREAM_INFO, new Metadata()); + Metadata headers = new Metadata(); + ClientStreamTracer tracer = + callAttemptsTracerFactory.newClientStreamTracer(STREAM_INFO, headers); + tracer.streamCreated(Attributes.EMPTY, headers); fakeClock.forwardTime(30, MILLISECONDS); - callTracer.callEnded(Status.PERMISSION_DENIED.withDescription("No you don't")); + Status status = Status.PERMISSION_DENIED.withDescription("No you don't"); + tracer.streamClosed(status); + callAttemptsTracerFactory.callEnded(status); // Give OpenCensus a chance to update the views asynchronously. Thread.sleep(100); diff --git a/core/src/jmh/java/io/grpc/internal/StatsTraceContextBenchmark.java b/core/src/jmh/java/io/grpc/internal/StatsTraceContextBenchmark.java index aec2659f024..4d4349eef1b 100644 --- a/core/src/jmh/java/io/grpc/internal/StatsTraceContextBenchmark.java +++ b/core/src/jmh/java/io/grpc/internal/StatsTraceContextBenchmark.java @@ -17,7 +17,7 @@ package io.grpc.internal; import io.grpc.Attributes; -import io.grpc.CallOptions; +import io.grpc.ClientStreamTracer; import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.ServerStreamTracer; @@ -50,7 +50,8 @@ public class StatsTraceContextBenchmark { @BenchmarkMode(Mode.SampleTime) @OutputTimeUnit(TimeUnit.NANOSECONDS) public StatsTraceContext newClientContext() { - return StatsTraceContext.newClientContext(CallOptions.DEFAULT, Attributes.EMPTY, emptyMetadata); + return StatsTraceContext.newClientContext( + new ClientStreamTracer[] { new ClientStreamTracer() {} }, Attributes.EMPTY, emptyMetadata); } /** diff --git a/core/src/main/java/io/grpc/inprocess/InProcessTransport.java b/core/src/main/java/io/grpc/inprocess/InProcessTransport.java index 58df4371e72..895b709559b 100644 --- a/core/src/main/java/io/grpc/inprocess/InProcessTransport.java +++ b/core/src/main/java/io/grpc/inprocess/InProcessTransport.java @@ -26,6 +26,7 @@ import com.google.common.util.concurrent.SettableFuture; import io.grpc.Attributes; import io.grpc.CallOptions; +import io.grpc.ClientStreamTracer; import io.grpc.Compressor; import io.grpc.Deadline; import io.grpc.Decompressor; @@ -205,10 +206,12 @@ public void run() { @Override public synchronized ClientStream newStream( - final MethodDescriptor method, final Metadata headers, final CallOptions callOptions) { + MethodDescriptor method, Metadata headers, CallOptions callOptions, + ClientStreamTracer[] tracers) { + StatsTraceContext statsTraceContext = + StatsTraceContext.newClientContext(tracers, getAttributes(), headers); if (shutdownStatus != null) { - return failedClientStream( - StatsTraceContext.newClientContext(callOptions, attributes, headers), shutdownStatus); + return failedClientStream(statsTraceContext, shutdownStatus); } headers.put(GrpcUtil.USER_AGENT_KEY, userAgent); @@ -226,12 +229,12 @@ public synchronized ClientStream newStream( "Request metadata larger than %d: %d", serverMaxInboundMetadataSize, metadataSize)); - return failedClientStream( - StatsTraceContext.newClientContext(callOptions, attributes, headers), status); + return failedClientStream(statsTraceContext, status); } } - return new InProcessStream(method, headers, callOptions, authority).clientStream; + return new InProcessStream(method, headers, callOptions, authority, statsTraceContext) + .clientStream; } private ClientStream failedClientStream( @@ -377,12 +380,12 @@ private class InProcessStream { private InProcessStream( MethodDescriptor method, Metadata headers, CallOptions callOptions, - String authority) { + String authority , StatsTraceContext statsTraceContext) { this.method = checkNotNull(method, "method"); this.headers = checkNotNull(headers, "headers"); this.callOptions = checkNotNull(callOptions, "callOptions"); this.authority = authority; - this.clientStream = new InProcessClientStream(callOptions, headers); + this.clientStream = new InProcessClientStream(callOptions, statsTraceContext); this.serverStream = new InProcessServerStream(method, headers); } @@ -673,9 +676,10 @@ private class InProcessClientStream implements ClientStream { @GuardedBy("this") private int outboundSeqNo; - InProcessClientStream(CallOptions callOptions, Metadata headers) { + InProcessClientStream( + CallOptions callOptions, StatsTraceContext statsTraceContext) { this.callOptions = callOptions; - statsTraceCtx = StatsTraceContext.newClientContext(callOptions, attributes, headers); + statsTraceCtx = statsTraceContext; } private synchronized void setListener(ServerStreamListener listener) { diff --git a/core/src/main/java/io/grpc/internal/CallCredentialsApplyingTransportFactory.java b/core/src/main/java/io/grpc/internal/CallCredentialsApplyingTransportFactory.java index 0b1ce3514a2..6b6472825d2 100644 --- a/core/src/main/java/io/grpc/internal/CallCredentialsApplyingTransportFactory.java +++ b/core/src/main/java/io/grpc/internal/CallCredentialsApplyingTransportFactory.java @@ -25,6 +25,7 @@ import io.grpc.CallOptions; import io.grpc.ChannelCredentials; import io.grpc.ChannelLogger; +import io.grpc.ClientStreamTracer; import io.grpc.CompositeCallCredentials; import io.grpc.Metadata; import io.grpc.MethodDescriptor; @@ -104,7 +105,8 @@ protected ConnectionClientTransport delegate() { @Override @SuppressWarnings("deprecation") public ClientStream newStream( - final MethodDescriptor method, Metadata headers, final CallOptions callOptions) { + final MethodDescriptor method, Metadata headers, final CallOptions callOptions, + ClientStreamTracer[] tracers) { CallCredentials creds = callOptions.getCredentials(); if (creds == null) { creds = channelCallCredentials; @@ -113,10 +115,10 @@ public ClientStream newStream( } if (creds != null) { MetadataApplierImpl applier = new MetadataApplierImpl( - delegate, method, headers, callOptions, applierListener); + delegate, method, headers, callOptions, applierListener, tracers); if (pendingApplier.incrementAndGet() > 0) { applierListener.onComplete(); - return new FailingClientStream(shutdownStatus); + return new FailingClientStream(shutdownStatus, tracers); } RequestInfo requestInfo = new RequestInfo() { @Override @@ -152,9 +154,9 @@ public Attributes getTransportAttrs() { return applier.returnStream(); } else { if (pendingApplier.get() >= 0) { - return new FailingClientStream(shutdownStatus); + return new FailingClientStream(shutdownStatus, tracers); } - return delegate.newStream(method, headers, callOptions); + return delegate.newStream(method, headers, callOptions, tracers); } } diff --git a/core/src/main/java/io/grpc/internal/ClientCallImpl.java b/core/src/main/java/io/grpc/internal/ClientCallImpl.java index c2e1bd2b1f2..28cd3351203 100644 --- a/core/src/main/java/io/grpc/internal/ClientCallImpl.java +++ b/core/src/main/java/io/grpc/internal/ClientCallImpl.java @@ -33,6 +33,7 @@ import io.grpc.Attributes; import io.grpc.CallOptions; import io.grpc.ClientCall; +import io.grpc.ClientStreamTracer; import io.grpc.Codec; import io.grpc.Compressor; import io.grpc.CompressorRegistry; @@ -254,9 +255,11 @@ public void runInContext() { effectiveDeadline, context.getDeadline(), callOptions.getDeadline()); stream = clientStreamProvider.newStream(method, callOptions, headers, context); } else { + ClientStreamTracer[] tracers = GrpcUtil.getClientStreamTracers(callOptions, headers, false); stream = new FailingClientStream( DEADLINE_EXCEEDED.withDescription( - "ClientCall started after deadline exceeded: " + effectiveDeadline)); + "ClientCall started after deadline exceeded: " + effectiveDeadline), + tracers); } if (callExecutorIsDirect) { diff --git a/core/src/main/java/io/grpc/internal/ClientTransport.java b/core/src/main/java/io/grpc/internal/ClientTransport.java index cc8471ab6a3..a569a7922df 100644 --- a/core/src/main/java/io/grpc/internal/ClientTransport.java +++ b/core/src/main/java/io/grpc/internal/ClientTransport.java @@ -17,6 +17,7 @@ package io.grpc.internal; import io.grpc.CallOptions; +import io.grpc.ClientStreamTracer; import io.grpc.InternalChannelz.SocketStats; import io.grpc.InternalInstrumented; import io.grpc.Metadata; @@ -46,10 +47,15 @@ public interface ClientTransport extends InternalInstrumented { * @param method the descriptor of the remote method to be called for this stream. * @param headers to send at the beginning of the call * @param callOptions runtime options of the call + * @param tracers a non-empty array of tracers. The last element in it is reserved to be set by + * the load balancer's pick result and otherwise is a no-op tracer. * @return the newly created stream. */ // TODO(nmittler): Consider also throwing for stopping. - ClientStream newStream(MethodDescriptor method, Metadata headers, CallOptions callOptions); + ClientStream newStream( + MethodDescriptor method, Metadata headers, CallOptions callOptions, + // Using array for tracers instead of a list or composition for better performance. + ClientStreamTracer[] tracers); /** * Pings a remote endpoint. When an acknowledgement is received, the given callback will be diff --git a/core/src/main/java/io/grpc/internal/DelayedClientTransport.java b/core/src/main/java/io/grpc/internal/DelayedClientTransport.java index 6a72eb7c21e..2b1145d1c4b 100644 --- a/core/src/main/java/io/grpc/internal/DelayedClientTransport.java +++ b/core/src/main/java/io/grpc/internal/DelayedClientTransport.java @@ -20,6 +20,7 @@ import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; import io.grpc.CallOptions; +import io.grpc.ClientStreamTracer; import io.grpc.Context; import io.grpc.InternalChannelz.SocketStats; import io.grpc.InternalLogId; @@ -133,7 +134,8 @@ public void run() { */ @Override public final ClientStream newStream( - MethodDescriptor method, Metadata headers, CallOptions callOptions) { + MethodDescriptor method, Metadata headers, CallOptions callOptions, + ClientStreamTracer[] tracers) { try { PickSubchannelArgs args = new PickSubchannelArgsImpl(method, headers, callOptions); SubchannelPicker picker = null; @@ -141,14 +143,14 @@ public final ClientStream newStream( while (true) { synchronized (lock) { if (shutdownStatus != null) { - return new FailingClientStream(shutdownStatus); + return new FailingClientStream(shutdownStatus, tracers); } if (lastPicker == null) { - return createPendingStream(args); + return createPendingStream(args, tracers); } // Check for second time through the loop, and whether anything changed if (picker != null && pickerVersion == lastPickerVersion) { - return createPendingStream(args); + return createPendingStream(args, tracers); } picker = lastPicker; pickerVersion = lastPickerVersion; @@ -158,7 +160,8 @@ public final ClientStream newStream( callOptions.isWaitForReady()); if (transport != null) { return transport.newStream( - args.getMethodDescriptor(), args.getHeaders(), args.getCallOptions()); + args.getMethodDescriptor(), args.getHeaders(), args.getCallOptions(), + tracers); } // This picker's conclusion is "buffer". If there hasn't been a newer picker set (possible // race with reprocess()), we will buffer it. Otherwise, will try with the new picker. @@ -173,8 +176,9 @@ public final ClientStream newStream( * schedule tasks on syncContext. */ @GuardedBy("lock") - private PendingStream createPendingStream(PickSubchannelArgs args) { - PendingStream pendingStream = new PendingStream(args); + private PendingStream createPendingStream( + PickSubchannelArgs args, ClientStreamTracer[] tracers) { + PendingStream pendingStream = new PendingStream(args, tracers); pendingStreams.add(pendingStream); if (getPendingStreamsCount() == 1) { syncContext.executeLater(reportTransportInUse); @@ -239,7 +243,8 @@ public final void shutdownNow(Status status) { } if (savedReportTransportTerminated != null) { for (PendingStream stream : savedPendingStreams) { - Runnable runnable = stream.setStream(new FailingClientStream(status, RpcProgress.REFUSED)); + Runnable runnable = stream.setStream( + new FailingClientStream(status, RpcProgress.REFUSED, stream.tracers)); if (runnable != null) { // Drain in-line instead of using an executor as failing stream just throws everything // away. This is essentially the same behavior as DelayedStream.cancel() but can be done @@ -346,9 +351,11 @@ public InternalLogId getLogId() { private class PendingStream extends DelayedStream { private final PickSubchannelArgs args; private final Context context = Context.current(); + private final ClientStreamTracer[] tracers; - private PendingStream(PickSubchannelArgs args) { + private PendingStream(PickSubchannelArgs args, ClientStreamTracer[] tracers) { this.args = args; + this.tracers = tracers; } /** Runnable may be null. */ @@ -357,7 +364,8 @@ private Runnable createRealStream(ClientTransport transport) { Context origContext = context.attach(); try { realStream = transport.newStream( - args.getMethodDescriptor(), args.getHeaders(), args.getCallOptions()); + args.getMethodDescriptor(), args.getHeaders(), args.getCallOptions(), + tracers); } finally { context.detach(origContext); } @@ -382,6 +390,13 @@ public void cancel(Status reason) { syncContext.drain(); } + @Override + protected void onEarlyCancellation(Status reason) { + for (ClientStreamTracer tracer : tracers) { + tracer.streamClosed(reason); + } + } + @Override public void appendTimeoutInsight(InsightBuilder insight) { if (args.getCallOptions().isWaitForReady()) { diff --git a/core/src/main/java/io/grpc/internal/DelayedStream.java b/core/src/main/java/io/grpc/internal/DelayedStream.java index f0a378e8124..28ce2764c75 100644 --- a/core/src/main/java/io/grpc/internal/DelayedStream.java +++ b/core/src/main/java/io/grpc/internal/DelayedStream.java @@ -324,11 +324,15 @@ public void run() { }); } else { drainPendingCalls(); + onEarlyCancellation(reason); // Note that listener is a DelayedStreamListener listener.closed(reason, RpcProgress.PROCESSED, new Metadata()); } } + protected void onEarlyCancellation(Status reason) { + } + @GuardedBy("this") private void setRealStream(ClientStream realStream) { checkState(this.realStream == null, "realStream already set to %s", this.realStream); diff --git a/core/src/main/java/io/grpc/internal/FailingClientStream.java b/core/src/main/java/io/grpc/internal/FailingClientStream.java index 6d368b6975f..6388ef8b6ee 100644 --- a/core/src/main/java/io/grpc/internal/FailingClientStream.java +++ b/core/src/main/java/io/grpc/internal/FailingClientStream.java @@ -18,6 +18,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; +import io.grpc.ClientStreamTracer; import io.grpc.Metadata; import io.grpc.Status; import io.grpc.internal.ClientStreamListener.RpcProgress; @@ -30,27 +31,33 @@ public final class FailingClientStream extends NoopClientStream { private boolean started; private final Status error; private final RpcProgress rpcProgress; + private final ClientStreamTracer[] tracers; /** * Creates a {@code FailingClientStream} that would fail with the given error. */ - public FailingClientStream(Status error) { - this(error, RpcProgress.PROCESSED); + public FailingClientStream(Status error, ClientStreamTracer[] tracers) { + this(error, RpcProgress.PROCESSED, tracers); } /** * Creates a {@code FailingClientStream} that would fail with the given error. */ - public FailingClientStream(Status error, RpcProgress rpcProgress) { + public FailingClientStream( + Status error, RpcProgress rpcProgress, ClientStreamTracer[] tracers) { Preconditions.checkArgument(!error.isOk(), "error must not be OK"); this.error = error; this.rpcProgress = rpcProgress; + this.tracers = tracers; } @Override public void start(ClientStreamListener listener) { Preconditions.checkState(!started, "already started"); started = true; + for (ClientStreamTracer tracer : tracers) { + tracer.streamClosed(error); + } listener.closed(error, rpcProgress, new Metadata()); } diff --git a/core/src/main/java/io/grpc/internal/FailingClientTransport.java b/core/src/main/java/io/grpc/internal/FailingClientTransport.java index 25d20017c92..5b31e6e5073 100644 --- a/core/src/main/java/io/grpc/internal/FailingClientTransport.java +++ b/core/src/main/java/io/grpc/internal/FailingClientTransport.java @@ -21,6 +21,7 @@ import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; import io.grpc.CallOptions; +import io.grpc.ClientStreamTracer; import io.grpc.InternalChannelz.SocketStats; import io.grpc.InternalLogId; import io.grpc.Metadata; @@ -45,8 +46,9 @@ class FailingClientTransport implements ClientTransport { @Override public ClientStream newStream( - MethodDescriptor method, Metadata headers, CallOptions callOptions) { - return new FailingClientStream(error, rpcProgress); + MethodDescriptor method, Metadata headers, CallOptions callOptions, + ClientStreamTracer[] tracers) { + return new FailingClientStream(error, rpcProgress, tracers); } @Override diff --git a/core/src/main/java/io/grpc/internal/ForwardingClientStreamTracer.java b/core/src/main/java/io/grpc/internal/ForwardingClientStreamTracer.java new file mode 100644 index 00000000000..fd03564d396 --- /dev/null +++ b/core/src/main/java/io/grpc/internal/ForwardingClientStreamTracer.java @@ -0,0 +1,101 @@ +/* + * Copyright 2021 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +import com.google.common.base.MoreObjects; +import io.grpc.Attributes; +import io.grpc.ClientStreamTracer; +import io.grpc.Metadata; +import io.grpc.Status; + +public abstract class ForwardingClientStreamTracer extends ClientStreamTracer { + + /** + * Returns the underlying {@code ClientStreamTracer}. + */ + protected abstract ClientStreamTracer delegate(); + + @Override + public void streamCreated(Attributes transportAttrs, Metadata headers) { + delegate().streamCreated(transportAttrs, headers); + } + + @Override + public void outboundHeaders() { + delegate().outboundHeaders(); + } + + @Override + public void inboundHeaders() { + delegate().inboundHeaders(); + } + + @Override + public void inboundTrailers(Metadata trailers) { + delegate().inboundTrailers(trailers); + } + + @Override + public void streamClosed(Status status) { + delegate().streamClosed(status); + } + + @Override + public void outboundMessage(int seqNo) { + delegate().outboundMessage(seqNo); + } + + @Override + public void inboundMessage(int seqNo) { + delegate().inboundMessage(seqNo); + } + + @Override + public void outboundMessageSent(int seqNo, long optionalWireSize, long optionalUncompressedSize) { + delegate().outboundMessageSent(seqNo, optionalWireSize, optionalUncompressedSize); + } + + @Override + public void inboundMessageRead(int seqNo, long optionalWireSize, long optionalUncompressedSize) { + delegate().inboundMessageRead(seqNo, optionalWireSize, optionalUncompressedSize); + } + + @Override + public void outboundWireSize(long bytes) { + delegate().outboundWireSize(bytes); + } + + @Override + public void outboundUncompressedSize(long bytes) { + delegate().outboundUncompressedSize(bytes); + } + + @Override + public void inboundWireSize(long bytes) { + delegate().inboundWireSize(bytes); + } + + @Override + public void inboundUncompressedSize(long bytes) { + delegate().inboundUncompressedSize(bytes); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this).add("delegate", delegate()).toString(); + } +} diff --git a/core/src/main/java/io/grpc/internal/ForwardingConnectionClientTransport.java b/core/src/main/java/io/grpc/internal/ForwardingConnectionClientTransport.java index e54f8b169d6..bfdccbe5d6a 100644 --- a/core/src/main/java/io/grpc/internal/ForwardingConnectionClientTransport.java +++ b/core/src/main/java/io/grpc/internal/ForwardingConnectionClientTransport.java @@ -20,6 +20,7 @@ import com.google.common.util.concurrent.ListenableFuture; import io.grpc.Attributes; import io.grpc.CallOptions; +import io.grpc.ClientStreamTracer; import io.grpc.InternalChannelz.SocketStats; import io.grpc.InternalLogId; import io.grpc.Metadata; @@ -45,8 +46,9 @@ public void shutdownNow(Status status) { @Override public ClientStream newStream( - MethodDescriptor method, Metadata headers, CallOptions callOptions) { - return delegate().newStream(method, headers, callOptions); + MethodDescriptor method, Metadata headers, CallOptions callOptions, + ClientStreamTracer[] tracers) { + return delegate().newStream(method, headers, callOptions, tracers); } @Override diff --git a/core/src/main/java/io/grpc/internal/GrpcUtil.java b/core/src/main/java/io/grpc/internal/GrpcUtil.java index 4f8568c60aa..0e11e5ece39 100644 --- a/core/src/main/java/io/grpc/internal/GrpcUtil.java +++ b/core/src/main/java/io/grpc/internal/GrpcUtil.java @@ -17,6 +17,7 @@ package io.grpc.internal; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Objects; @@ -26,8 +27,11 @@ import com.google.common.base.Supplier; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.ThreadFactoryBuilder; +import io.grpc.Attributes; import io.grpc.CallOptions; import io.grpc.ClientStreamTracer; +import io.grpc.ClientStreamTracer.InternalLimitedInfoFactory; +import io.grpc.ClientStreamTracer.StreamInfo; import io.grpc.InternalChannelz.SocketStats; import io.grpc.InternalLogId; import io.grpc.InternalMetadata; @@ -54,12 +58,14 @@ import java.net.URISyntaxException; import java.nio.charset.Charset; import java.util.Collection; +import java.util.List; import java.util.concurrent.Executor; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ThreadFactory; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; @@ -253,6 +259,8 @@ public ProxiedSocketAddress proxyFor(SocketAddress targetServerAddress) { public static final CallOptions.Key CALL_OPTIONS_RPC_OWNED_BY_BALANCER = CallOptions.Key.create("io.grpc.internal.CALL_OPTIONS_RPC_OWNED_BY_BALANCER"); + private static final ClientStreamTracer NOOP_TRACER = new ClientStreamTracer() {}; + /** * Returns true if an RPC with the given properties should be counted when calculating the * in-use state of a transport. @@ -711,9 +719,14 @@ static ClientTransport getTransportFromPickResult(PickResult result, boolean isW return new ClientTransport() { @Override public ClientStream newStream( - MethodDescriptor method, Metadata headers, CallOptions callOptions) { - return transport.newStream( - method, headers, callOptions.withStreamTracerFactory(streamTracerFactory)); + MethodDescriptor method, Metadata headers, CallOptions callOptions, + ClientStreamTracer[] tracers) { + StreamInfo info = StreamInfo.newBuilder().setCallOptions(callOptions).build(); + ClientStreamTracer streamTracer = + newClientStreamTracer(streamTracerFactory, info, headers); + checkState(tracers[tracers.length - 1] == NOOP_TRACER, "lb tracer already assigned"); + tracers[tracers.length - 1] = streamTracer; + return transport.newStream(method, headers, callOptions, tracers); } @Override @@ -743,6 +756,64 @@ public ListenableFuture getStats() { return null; } + /** Gets stream tracers based on CallOptions. */ + public static ClientStreamTracer[] getClientStreamTracers( + CallOptions callOptions, Metadata headers, boolean isTransparentRetry) { + List factories = callOptions.getStreamTracerFactories(); + ClientStreamTracer[] tracers = new ClientStreamTracer[factories.size() + 1]; + StreamInfo streamInfo = StreamInfo.newBuilder() + .setCallOptions(callOptions) + .setIsTransparentRetry(isTransparentRetry) + .build(); + for (int i = 0; i < factories.size(); i++) { + tracers[i] = newClientStreamTracer(factories.get(i), streamInfo, headers); + } + // Reserved to be set later by the lb as per the API contract of ClientTransport.newStream(). + // See also GrpcUtil.getTransportFromPickResult() + tracers[tracers.length - 1] = NOOP_TRACER; + return tracers; + } + + // A util function for backward compatibility to support deprecated StreamInfo.getAttributes(). + @VisibleForTesting + static ClientStreamTracer newClientStreamTracer( + final ClientStreamTracer.Factory streamTracerFactory, final StreamInfo info, + final Metadata headers) { + ClientStreamTracer streamTracer; + if (streamTracerFactory instanceof InternalLimitedInfoFactory) { + streamTracer = streamTracerFactory.newClientStreamTracer(info, headers); + } else { + streamTracer = new ForwardingClientStreamTracer() { + final ClientStreamTracer noop = new ClientStreamTracer() {}; + AtomicReference delegate = new AtomicReference<>(noop); + + void maybeInit(StreamInfo info, Metadata headers) { + delegate.compareAndSet(noop, streamTracerFactory.newClientStreamTracer(info, headers)); + } + + @Override + protected ClientStreamTracer delegate() { + return delegate.get(); + } + + @SuppressWarnings("deprecation") + @Override + public void streamCreated(Attributes transportAttrs, Metadata headers) { + StreamInfo streamInfo = info.toBuilder().setTransportAttrs(transportAttrs).build(); + maybeInit(streamInfo, headers); + delegate().streamCreated(transportAttrs, headers); + } + + @Override + public void streamClosed(Status status) { + maybeInit(info, headers); + delegate().streamClosed(status); + } + }; + } + return streamTracer; + } + /** Quietly closes all messages in MessageProducer. */ static void closeQuietly(MessageProducer producer) { InputStream message; diff --git a/core/src/main/java/io/grpc/internal/InternalSubchannel.java b/core/src/main/java/io/grpc/internal/InternalSubchannel.java index 331add6c8c4..fa2bf2e46bc 100644 --- a/core/src/main/java/io/grpc/internal/InternalSubchannel.java +++ b/core/src/main/java/io/grpc/internal/InternalSubchannel.java @@ -34,6 +34,7 @@ import io.grpc.CallOptions; import io.grpc.ChannelLogger; import io.grpc.ChannelLogger.ChannelLogLevel; +import io.grpc.ClientStreamTracer; import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; import io.grpc.EquivalentAddressGroup; @@ -667,8 +668,9 @@ protected ConnectionClientTransport delegate() { @Override public ClientStream newStream( - MethodDescriptor method, Metadata headers, CallOptions callOptions) { - final ClientStream streamDelegate = super.newStream(method, headers, callOptions); + MethodDescriptor method, Metadata headers, CallOptions callOptions, + ClientStreamTracer[] tracers) { + final ClientStream streamDelegate = super.newStream(method, headers, callOptions, tracers); return new ForwardingClientStream() { @Override protected ClientStream delegate() { diff --git a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java index a9d24cd247a..87162d9aba2 100644 --- a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java +++ b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java @@ -532,8 +532,10 @@ public ClientStream newStream( ClientTransport transport = getTransport(new PickSubchannelArgsImpl(method, headers, callOptions)); Context origContext = context.attach(); + ClientStreamTracer[] tracers = GrpcUtil.getClientStreamTracers( + callOptions, headers, /* isTransparentRetry= */ false); try { - return transport.newStream(method, headers, callOptions); + return transport.newStream(method, headers, callOptions, tracers); } finally { context.detach(origContext); } @@ -569,13 +571,16 @@ void postCommit() { } @Override - ClientStream newSubstream(ClientStreamTracer.Factory tracerFactory, Metadata newHeaders) { - CallOptions newOptions = callOptions.withStreamTracerFactory(tracerFactory); + ClientStream newSubstream( + Metadata newHeaders, ClientStreamTracer.Factory factory, boolean isTransparentRetry) { + CallOptions newOptions = callOptions.withStreamTracerFactory(factory); + ClientStreamTracer[] tracers = + GrpcUtil.getClientStreamTracers(newOptions, newHeaders, isTransparentRetry); ClientTransport transport = getTransport(new PickSubchannelArgsImpl(method, newHeaders, newOptions)); Context origContext = context.attach(); try { - return transport.newStream(method, newHeaders, newOptions); + return transport.newStream(method, newHeaders, newOptions, tracers); } finally { context.detach(origContext); } diff --git a/core/src/main/java/io/grpc/internal/MetadataApplierImpl.java b/core/src/main/java/io/grpc/internal/MetadataApplierImpl.java index 76d280b2d00..6893713c1d2 100644 --- a/core/src/main/java/io/grpc/internal/MetadataApplierImpl.java +++ b/core/src/main/java/io/grpc/internal/MetadataApplierImpl.java @@ -22,6 +22,7 @@ import io.grpc.CallCredentials.MetadataApplier; import io.grpc.CallOptions; +import io.grpc.ClientStreamTracer; import io.grpc.Context; import io.grpc.Metadata; import io.grpc.MethodDescriptor; @@ -36,7 +37,7 @@ final class MetadataApplierImpl extends MetadataApplier { private final CallOptions callOptions; private final Context ctx; private final MetadataApplierListener listener; - + private final ClientStreamTracer[] tracers; private final Object lock = new Object(); // null if neither apply() or returnStream() are called. @@ -52,13 +53,14 @@ final class MetadataApplierImpl extends MetadataApplier { MetadataApplierImpl( ClientTransport transport, MethodDescriptor method, Metadata origHeaders, - CallOptions callOptions, MetadataApplierListener listener) { + CallOptions callOptions, MetadataApplierListener listener, ClientStreamTracer[] tracers) { this.transport = transport; this.method = method; this.origHeaders = origHeaders; this.callOptions = callOptions; this.ctx = Context.current(); this.listener = listener; + this.tracers = tracers; } @Override @@ -69,7 +71,7 @@ public void apply(Metadata headers) { ClientStream realStream; Context origCtx = ctx.attach(); try { - realStream = transport.newStream(method, origHeaders, callOptions); + realStream = transport.newStream(method, origHeaders, callOptions, tracers); } finally { ctx.detach(origCtx); } @@ -80,7 +82,7 @@ public void apply(Metadata headers) { public void fail(Status status) { checkArgument(!status.isOk(), "Cannot fail with OK status"); checkState(!finalized, "apply() or fail() already called"); - finalizeWith(new FailingClientStream(status)); + finalizeWith(new FailingClientStream(status, tracers)); } private void finalizeWith(ClientStream stream) { diff --git a/core/src/main/java/io/grpc/internal/OobChannel.java b/core/src/main/java/io/grpc/internal/OobChannel.java index f69fd17e5c4..b628842efe4 100644 --- a/core/src/main/java/io/grpc/internal/OobChannel.java +++ b/core/src/main/java/io/grpc/internal/OobChannel.java @@ -26,6 +26,7 @@ import io.grpc.Attributes; import io.grpc.CallOptions; import io.grpc.ClientCall; +import io.grpc.ClientStreamTracer; import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; import io.grpc.Context; @@ -86,12 +87,14 @@ final class OobChannel extends ManagedChannel implements InternalInstrumented method, CallOptions callOptions, Metadata headers, Context context) { + ClientStreamTracer[] tracers = GrpcUtil.getClientStreamTracers( + callOptions, headers, /* isTransparentRetry= */ false); Context origContext = context.attach(); // delayed transport's newStream() always acquires a lock, but concurrent performance doesn't // matter here because OOB communication should be sparse, and it's not on application RPC's // critical path. try { - return delayedTransport.newStream(method, headers, callOptions); + return delayedTransport.newStream(method, headers, callOptions, tracers); } finally { context.detach(origContext); } diff --git a/core/src/main/java/io/grpc/internal/RetriableStream.java b/core/src/main/java/io/grpc/internal/RetriableStream.java index 9d752b86576..23725788466 100644 --- a/core/src/main/java/io/grpc/internal/RetriableStream.java +++ b/core/src/main/java/io/grpc/internal/RetriableStream.java @@ -203,11 +203,11 @@ private void commitAndRun(Substream winningSubstream) { } } - private Substream createSubstream(int previousAttemptCount) { + private Substream createSubstream(int previousAttemptCount, boolean isTransparentRetry) { Substream sub = new Substream(previousAttemptCount); // one tracer per substream final ClientStreamTracer bufferSizeTracer = new BufferSizeTracer(sub); - ClientStreamTracer.Factory tracerFactory = new ClientStreamTracer.Factory() { + ClientStreamTracer.Factory tracerFactory = new ClientStreamTracer.InternalLimitedInfoFactory() { @Override public ClientStreamTracer newClientStreamTracer( ClientStreamTracer.StreamInfo info, Metadata headers) { @@ -217,7 +217,7 @@ public ClientStreamTracer newClientStreamTracer( Metadata newHeaders = updateHeaders(headers, previousAttemptCount); // NOTICE: This set _must_ be done before stream.start() and it actually is. - sub.stream = newSubstream(tracerFactory, newHeaders); + sub.stream = newSubstream(newHeaders, tracerFactory, isTransparentRetry); return sub; } @@ -226,7 +226,7 @@ public ClientStreamTracer newClientStreamTracer( * Client stream is not yet started. */ abstract ClientStream newSubstream( - ClientStreamTracer.Factory tracerFactory, Metadata headers); + Metadata headers, ClientStreamTracer.Factory tracerFactory, boolean isTransparentRetry); /** Adds grpc-previous-rpc-attempts in the headers of a retry/hedging RPC. */ @VisibleForTesting @@ -322,7 +322,7 @@ public void runWith(Substream substream) { state.buffer.add(new StartEntry()); } - Substream substream = createSubstream(0); + Substream substream = createSubstream(0, false); if (isHedging) { FutureCanceller scheduledHedgingRef = null; @@ -399,7 +399,7 @@ public void run() { // If this run is not cancelled, the value of state.hedgingAttemptCount won't change // until state.addActiveHedge() is called subsequently, even the state could possibly // change. - Substream newSubstream = createSubstream(state.hedgingAttemptCount); + Substream newSubstream = createSubstream(state.hedgingAttemptCount, false); boolean cancelled = false; FutureCanceller future = null; @@ -784,8 +784,7 @@ public void closed(Status status, RpcProgress rpcProgress, Metadata trailers) { if (rpcProgress == RpcProgress.REFUSED && noMoreTransparentRetry.compareAndSet(false, true)) { // transparent retry - final Substream newSubstream = createSubstream( - substream.previousAttemptCount); + final Substream newSubstream = createSubstream(substream.previousAttemptCount, true); if (isHedging) { boolean commit = false; synchronized (lock) { @@ -863,8 +862,9 @@ public void run() { @Override public void run() { // retry - Substream newSubstream = - createSubstream(substream.previousAttemptCount + 1); + Substream newSubstream = createSubstream( + substream.previousAttemptCount + 1, + false); drain(newSubstream); } }); diff --git a/core/src/main/java/io/grpc/internal/StatsTraceContext.java b/core/src/main/java/io/grpc/internal/StatsTraceContext.java index adb0b63ec8a..33e84e5a0b8 100644 --- a/core/src/main/java/io/grpc/internal/StatsTraceContext.java +++ b/core/src/main/java/io/grpc/internal/StatsTraceContext.java @@ -20,7 +20,6 @@ import com.google.common.annotations.VisibleForTesting; import io.grpc.Attributes; -import io.grpc.CallOptions; import io.grpc.ClientStreamTracer; import io.grpc.Context; import io.grpc.Metadata; @@ -48,21 +47,12 @@ public final class StatsTraceContext { * Factory method for the client-side. */ public static StatsTraceContext newClientContext( - final CallOptions callOptions, final Attributes transportAttrs, Metadata headers) { - List factories = callOptions.getStreamTracerFactories(); - if (factories.isEmpty()) { - return NOOP; + ClientStreamTracer[] tracers, Attributes transportAtts, Metadata headers) { + StatsTraceContext ctx = new StatsTraceContext(tracers); + for (ClientStreamTracer tracer : tracers) { + tracer.streamCreated(transportAtts, headers); } - ClientStreamTracer.StreamInfo info = - ClientStreamTracer.StreamInfo.newBuilder() - .setTransportAttrs(transportAttrs).setCallOptions(callOptions).build(); - // This array will be iterated multiple times per RPC. Use primitive array instead of Collection - // so that for-each doesn't create an Iterator every time. - StreamTracer[] tracers = new StreamTracer[factories.size()]; - for (int i = 0; i < tracers.length; i++) { - tracers[i] = factories.get(i).newClientStreamTracer(info, headers); - } - return new StatsTraceContext(tracers); + return ctx; } /** diff --git a/core/src/main/java/io/grpc/internal/SubchannelChannel.java b/core/src/main/java/io/grpc/internal/SubchannelChannel.java index 6c316e4f185..1380a6bc716 100644 --- a/core/src/main/java/io/grpc/internal/SubchannelChannel.java +++ b/core/src/main/java/io/grpc/internal/SubchannelChannel.java @@ -22,6 +22,7 @@ import io.grpc.CallOptions; import io.grpc.Channel; import io.grpc.ClientCall; +import io.grpc.ClientStreamTracer; import io.grpc.Context; import io.grpc.InternalConfigSelector; import io.grpc.Metadata; @@ -57,9 +58,11 @@ public ClientStream newStream(MethodDescriptor method, if (transport == null) { transport = notReadyTransport; } + ClientStreamTracer[] tracers = GrpcUtil.getClientStreamTracers( + callOptions, headers, /* isTransparentRetry= */ false); Context origContext = context.attach(); try { - return transport.newStream(method, headers, callOptions); + return transport.newStream(method, headers, callOptions, tracers); } finally { context.detach(origContext); } diff --git a/core/src/main/java/io/grpc/util/ForwardingClientStreamTracer.java b/core/src/main/java/io/grpc/util/ForwardingClientStreamTracer.java index de7d12e397c..7bb9d8cf71a 100644 --- a/core/src/main/java/io/grpc/util/ForwardingClientStreamTracer.java +++ b/core/src/main/java/io/grpc/util/ForwardingClientStreamTracer.java @@ -17,6 +17,7 @@ package io.grpc.util; import com.google.common.base.MoreObjects; +import io.grpc.Attributes; import io.grpc.ClientStreamTracer; import io.grpc.ExperimentalApi; import io.grpc.Metadata; @@ -27,6 +28,11 @@ public abstract class ForwardingClientStreamTracer extends ClientStreamTracer { /** Returns the underlying {@code ClientStreamTracer}. */ protected abstract ClientStreamTracer delegate(); + @Override + public void streamCreated(Attributes transportAttrs, Metadata headers) { + delegate().streamCreated(transportAttrs, headers); + } + @Override public void outboundHeaders() { delegate().outboundHeaders(); diff --git a/core/src/test/java/io/grpc/ClientStreamTracerTest.java b/core/src/test/java/io/grpc/ClientStreamTracerTest.java index 2008a3de5c7..df450adc630 100644 --- a/core/src/test/java/io/grpc/ClientStreamTracerTest.java +++ b/core/src/test/java/io/grpc/ClientStreamTracerTest.java @@ -34,6 +34,7 @@ public class ClientStreamTracerTest { Attributes.newBuilder().set(TRANSPORT_ATTR_KEY, "value").build(); @Test + @SuppressWarnings("deprecation") // info.getTransportAttrs() public void streamInfo_empty() { StreamInfo info = StreamInfo.newBuilder().build(); assertThat(info.getCallOptions()).isSameInstanceAs(CallOptions.DEFAULT); @@ -41,6 +42,7 @@ public void streamInfo_empty() { } @Test + @SuppressWarnings("deprecation") // info.getTransportAttrs() public void streamInfo_withInfo() { StreamInfo info = StreamInfo.newBuilder() .setCallOptions(callOptions).setTransportAttrs(transportAttrs).build(); @@ -49,6 +51,7 @@ public void streamInfo_withInfo() { } @Test + @SuppressWarnings("deprecation") // info.setTransportAttrs() public void streamInfo_noEquality() { StreamInfo info1 = StreamInfo.newBuilder() .setCallOptions(callOptions).setTransportAttrs(transportAttrs).build(); @@ -60,6 +63,7 @@ public void streamInfo_noEquality() { } @Test + @SuppressWarnings("deprecation") // info.getTransportAttrs() public void streamInfo_toBuilder() { StreamInfo info1 = StreamInfo.newBuilder() .setCallOptions(callOptions).setTransportAttrs(transportAttrs).build(); diff --git a/core/src/test/java/io/grpc/internal/AbstractTransportTest.java b/core/src/test/java/io/grpc/internal/AbstractTransportTest.java index 091415efadc..cd522181311 100644 --- a/core/src/test/java/io/grpc/internal/AbstractTransportTest.java +++ b/core/src/test/java/io/grpc/internal/AbstractTransportTest.java @@ -48,7 +48,6 @@ import io.grpc.CallOptions; import io.grpc.ChannelLogger; import io.grpc.ClientStreamTracer; -import io.grpc.ClientStreamTracer.StreamInfo; import io.grpc.Grpc; import io.grpc.InternalChannelz.SocketStats; import io.grpc.InternalChannelz.TransportStats; @@ -172,7 +171,7 @@ public void log(ChannelLogLevel level, String messageFormat, Object... args) {} .setRequestMarshaller(StringMarshaller.INSTANCE) .setResponseMarshaller(StringMarshaller.INSTANCE) .build(); - private CallOptions callOptions; + private final CallOptions callOptions = CallOptions.DEFAULT; private Metadata.Key asciiKey = Metadata.Key.of( "ascii-key", Metadata.ASCII_STRING_MARSHALLER); @@ -186,24 +185,14 @@ public void log(ChannelLogLevel level, String messageFormat, Object... args) {} = mock(ManagedClientTransport.Listener.class); private MockServerListener serverListener = new MockServerListener(); private ArgumentCaptor throwableCaptor = ArgumentCaptor.forClass(Throwable.class); - private final TestClientStreamTracer clientStreamTracer1 = new TestClientStreamTracer(); - private final TestClientStreamTracer clientStreamTracer2 = new TestClientStreamTracer(); - private final ClientStreamTracer.Factory clientStreamTracerFactory = mock( - ClientStreamTracer.Factory.class, - delegatesTo(new ClientStreamTracer.Factory() { - final ArrayDeque tracers = - new ArrayDeque<>(Arrays.asList(clientStreamTracer1, clientStreamTracer2)); - - @Override - public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata metadata) { - metadata.put(tracerHeaderKey, tracerKeyValue); - TestClientStreamTracer tracer = tracers.poll(); - if (tracer != null) { - return tracer; - } - return new TestClientStreamTracer(); - } - })); + private final TestClientStreamTracer clientStreamTracer1 = new TestHeaderClientStreamTracer(); + private final TestClientStreamTracer clientStreamTracer2 = new TestHeaderClientStreamTracer(); + private final ClientStreamTracer[] tracers = new ClientStreamTracer[] { + clientStreamTracer1, clientStreamTracer2 + }; + private final ClientStreamTracer[] noopTracers = new ClientStreamTracer[] { + new ClientStreamTracer() {} + }; private final TestServerStreamTracer serverStreamTracer1 = new TestServerStreamTracer(); private final TestServerStreamTracer serverStreamTracer2 = new TestServerStreamTracer(); @@ -230,7 +219,6 @@ public ServerStreamTracer newServerStreamTracer(String fullMethodName, Metadata @Before public void setUp() { server = newServer(Arrays.asList(serverStreamTracerFactory)); - callOptions = CallOptions.DEFAULT.withStreamTracerFactory(clientStreamTracerFactory); } @After @@ -291,7 +279,8 @@ public void frameAfterRstStreamShouldNotBreakClientChannel() throws Exception { // after having sent a RST_STREAM to the server. Previously, this would have broken the // Netty channel. - ClientStream stream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream stream = client.newStream( + methodDescriptor, new Metadata(), callOptions, noopTracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); stream.start(clientStreamListener); StreamCreation serverStreamCreation @@ -314,7 +303,8 @@ public void frameAfterRstStreamShouldNotBreakClientChannel() throws Exception { // Test that the channel is still usable i.e. we can receive headers from the server on a // new stream. - stream = client.newStream(methodDescriptor, new Metadata(), callOptions); + stream = client.newStream( + methodDescriptor, new Metadata(), callOptions, noopTracers); stream.start(mockClientStreamListener2); serverStreamCreation = serverTransportListener.takeStreamOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); @@ -449,7 +439,8 @@ public void openStreamPreventsTermination() throws Exception { = serverListener.takeListenerOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); serverTransport = serverTransportListener.transport; - ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportInUse(true); @@ -501,7 +492,8 @@ public void shutdownNowKillsClientStream() throws Exception { = serverListener.takeListenerOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); serverTransport = serverTransportListener.transport; - ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportInUse(true); @@ -539,7 +531,8 @@ public void shutdownNowKillsServerStream() throws Exception { = serverListener.takeListenerOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); serverTransport = serverTransportListener.transport; - ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportInUse(true); @@ -594,7 +587,8 @@ public void ping_duringShutdown() throws Exception { client = newClientTransport(server); startTransport(client, mockClientTransportListener); // Stream prevents termination - ClientStream stream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream stream = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); stream.start(clientStreamListener); client.shutdown(Status.UNAVAILABLE); @@ -633,22 +627,19 @@ public void ping_afterTermination() throws Exception { @Test public void newStream_duringShutdown() throws Exception { - InOrder inOrder = inOrder(clientStreamTracerFactory); server.start(serverListener); client = newClientTransport(server); startTransport(client, mockClientTransportListener); // Stream prevents termination - ClientStream stream = client.newStream(methodDescriptor, new Metadata(), callOptions); - inOrder.verify(clientStreamTracerFactory).newClientStreamTracer( - any(ClientStreamTracer.StreamInfo.class), any(Metadata.class)); + ClientStream stream = client.newStream( + methodDescriptor, new Metadata(), callOptions, noopTracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); stream.start(clientStreamListener); client.shutdown(Status.UNAVAILABLE); verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportShutdown(any(Status.class)); - ClientStream stream2 = client.newStream(methodDescriptor, new Metadata(), callOptions); - inOrder.verify(clientStreamTracerFactory).newClientStreamTracer( - any(ClientStreamTracer.StreamInfo.class), any(Metadata.class)); + ClientStream stream2 = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener2 = new ClientStreamListenerBase(); stream2.start(clientStreamListener2); Status clientStreamStatus2 = @@ -683,15 +674,14 @@ public void newStream_afterTermination() throws Exception { client.shutdown(shutdownReason); verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportTerminated(); Thread.sleep(100); - ClientStream stream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream stream = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); stream.start(clientStreamListener); assertEquals( shutdownReason, clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); assertNotNull(clientStreamListener.trailers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); verify(mockClientTransportListener, never()).transportInUse(anyBoolean()); - verify(clientStreamTracerFactory).newClientStreamTracer( - any(ClientStreamTracer.StreamInfo.class), any(Metadata.class)); assertNull(clientStreamTracer1.getInboundTrailers()); assertSame(shutdownReason, clientStreamTracer1.getStatus()); // Assert no interactions @@ -708,7 +698,8 @@ public void transportInUse_balancerRpcsNotCounted() throws Exception { // CALL_OPTIONS_RPC_OWNED_BY_BALANCER in CallOptions. It won't be counted for in-use signal. ClientStream stream1 = client.newStream( methodDescriptor, new Metadata(), - callOptions.withOption(GrpcUtil.CALL_OPTIONS_RPC_OWNED_BY_BALANCER, Boolean.TRUE)); + callOptions.withOption(GrpcUtil.CALL_OPTIONS_RPC_OWNED_BY_BALANCER, Boolean.TRUE), + noopTracers); ClientStreamListenerBase clientStreamListener1 = new ClientStreamListenerBase(); stream1.start(clientStreamListener1); MockServerTransportListener serverTransportListener @@ -717,7 +708,8 @@ methodDescriptor, new Metadata(), = serverTransportListener.takeStreamOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); // stream2 is the normal RPC, and will be counted for in-use - ClientStream stream2 = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream stream2 = client.newStream( + methodDescriptor, new Metadata(), callOptions, noopTracers); ClientStreamListenerBase clientStreamListener2 = new ClientStreamListenerBase(); stream2.start(clientStreamListener2); verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportInUse(true); @@ -743,7 +735,8 @@ public void transportInUse_normalClose() throws Exception { server.start(serverListener); client = newClientTransport(server); startTransport(client, mockClientTransportListener); - ClientStream stream1 = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream stream1 = client.newStream( + methodDescriptor, new Metadata(), callOptions, noopTracers); ClientStreamListenerBase clientStreamListener1 = new ClientStreamListenerBase(); stream1.start(clientStreamListener1); verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportInUse(true); @@ -751,7 +744,8 @@ public void transportInUse_normalClose() throws Exception { = serverListener.takeListenerOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); StreamCreation serverStreamCreation1 = serverTransportListener.takeStreamOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); - ClientStream stream2 = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream stream2 = client.newStream( + methodDescriptor, new Metadata(), callOptions, noopTracers); ClientStreamListenerBase clientStreamListener2 = new ClientStreamListenerBase(); stream2.start(clientStreamListener2); StreamCreation serverStreamCreation2 @@ -773,11 +767,13 @@ public void transportInUse_clientCancel() throws Exception { server.start(serverListener); client = newClientTransport(server); startTransport(client, mockClientTransportListener); - ClientStream stream1 = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream stream1 = client.newStream( + methodDescriptor, new Metadata(), callOptions, noopTracers); ClientStreamListenerBase clientStreamListener1 = new ClientStreamListenerBase(); stream1.start(clientStreamListener1); verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportInUse(true); - ClientStream stream2 = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream stream2 = client.newStream( + methodDescriptor, new Metadata(), callOptions, noopTracers); ClientStreamListenerBase clientStreamListener2 = new ClientStreamListenerBase(); stream2.start(clientStreamListener2); @@ -792,7 +788,6 @@ public void transportInUse_clientCancel() throws Exception { @Test public void basicStream() throws Exception { - InOrder clientInOrder = inOrder(clientStreamTracerFactory); InOrder serverInOrder = inOrder(serverStreamTracerFactory); server.start(serverListener); client = newClientTransport(server); @@ -816,14 +811,10 @@ public void basicStream() throws Exception { Metadata clientHeadersCopy = new Metadata(); clientHeadersCopy.merge(clientHeaders); - ClientStream clientStream = client.newStream(methodDescriptor, clientHeaders, callOptions); - ArgumentCaptor streamInfoCaptor = ArgumentCaptor.forClass(null); - clientInOrder.verify(clientStreamTracerFactory).newClientStreamTracer( - streamInfoCaptor.capture(), same(clientHeaders)); - ClientStreamTracer.StreamInfo streamInfo = streamInfoCaptor.getValue(); - assertThat(streamInfo.getTransportAttrs()).isSameInstanceAs( - ((ConnectionClientTransport) client).getAttributes()); - assertThat(streamInfo.getCallOptions()).isSameInstanceAs(callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, clientHeaders, callOptions, tracers); + assertThat(((TestHeaderClientStreamTracer) clientStreamTracer1).transportAttrs) + .isSameInstanceAs(((ConnectionClientTransport) client).getAttributes()); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); @@ -974,7 +965,8 @@ public void authorityPropagation() throws Exception { = serverListener.takeListenerOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); Metadata clientHeaders = new Metadata(); - ClientStream clientStream = client.newStream(methodDescriptor, clientHeaders, callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, clientHeaders, callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); StreamCreation serverStreamCreation @@ -1005,7 +997,8 @@ public void zeroMessageStream() throws Exception { = serverListener.takeListenerOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); serverTransport = serverTransportListener.transport; - ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); StreamCreation serverStreamCreation @@ -1044,7 +1037,8 @@ public void earlyServerClose_withServerHeaders() throws Exception { = serverListener.takeListenerOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); serverTransport = serverTransportListener.transport; - ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); StreamCreation serverStreamCreation @@ -1080,7 +1074,8 @@ public void earlyServerClose_noServerHeaders() throws Exception { = serverListener.takeListenerOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); serverTransport = serverTransportListener.transport; - ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); StreamCreation serverStreamCreation @@ -1122,7 +1117,8 @@ public void earlyServerClose_serverFailure() throws Exception { = serverListener.takeListenerOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); serverTransport = serverTransportListener.transport; - ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); StreamCreation serverStreamCreation @@ -1155,7 +1151,8 @@ public void earlyServerClose_serverFailure_withClientCancelOnListenerClosed() th serverTransport = serverTransportListener.transport; final ClientStream clientStream = - client.newStream(methodDescriptor, new Metadata(), callOptions); + client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase() { @Override @@ -1196,7 +1193,8 @@ public void clientCancel() throws Exception { = serverListener.takeListenerOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); serverTransport = serverTransportListener.transport; - ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); StreamCreation serverStreamCreation @@ -1230,7 +1228,8 @@ public void clientCancelFromWithinMessageRead() throws Exception { final SettableFuture closedCalled = SettableFuture.create(); final ClientStream clientStream = - client.newStream(methodDescriptor, new Metadata(), callOptions); + client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); final Status status = Status.CANCELLED.withDescription("nevermind"); clientStream.start(new ClientStreamListener() { private boolean messageReceived = false; @@ -1311,7 +1310,8 @@ public void serverCancel() throws Exception { = serverListener.takeListenerOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); serverTransport = serverTransportListener.transport; - ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); StreamCreation serverStreamCreation @@ -1331,8 +1331,6 @@ public void serverCancel() throws Exception { // Cause should not be transmitted between server and client assertNull(clientStreamStatus.getCause()); - verify(clientStreamTracerFactory).newClientStreamTracer( - any(ClientStreamTracer.StreamInfo.class), any(Metadata.class)); assertTrue(clientStreamTracer1.getOutboundHeaders()); assertNull(clientStreamTracer1.getInboundTrailers()); assertSame(clientStreamStatus, clientStreamTracer1.getStatus()); @@ -1353,7 +1351,8 @@ public void flowControlPushBack() throws Exception { serverListener.takeListenerOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); serverTransport = serverTransportListener.transport; - ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); StreamCreation serverStreamCreation = @@ -1515,7 +1514,8 @@ public void interactionsAfterServerStreamCloseAreNoops() throws Exception { // boilerplate ClientStream clientStream = - client.newStream(methodDescriptor, new Metadata(), callOptions); + client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); StreamCreation server @@ -1547,7 +1547,8 @@ public void interactionsAfterClientStreamCancelAreNoops() throws Exception { // boilerplate ClientStream clientStream = - client.newStream(methodDescriptor, new Metadata(), callOptions); + client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListener clientListener = mock(ClientStreamListener.class); clientStream.start(clientListener); StreamCreation server @@ -1594,7 +1595,8 @@ public void transportTracer_streamStarted() throws Exception { assertEquals(0, clientBefore.streamsStarted); assertEquals(0, clientBefore.lastRemoteStreamCreatedTimeNanos); - ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); StreamCreation serverStreamCreation = serverTransportListener @@ -1624,7 +1626,8 @@ public void transportTracer_streamStarted() throws Exception { TransportStats clientBefore = getTransportStats(client); assertEquals(1, clientBefore.streamsStarted); - ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, new Metadata(), callOptions, noopTracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); StreamCreation serverStreamCreation = serverTransportListener @@ -1654,7 +1657,8 @@ public void transportTracer_server_streamEnded_ok() throws Exception { server.start(serverListener); client = newClientTransport(server); startTransport(client, mockClientTransportListener); - ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); MockServerTransportListener serverTransportListener @@ -1693,7 +1697,8 @@ public void transportTracer_server_streamEnded_nonOk() throws Exception { server.start(serverListener); client = newClientTransport(server); startTransport(client, mockClientTransportListener); - ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); MockServerTransportListener serverTransportListener @@ -1733,7 +1738,8 @@ public void transportTracer_client_streamEnded_nonOk() throws Exception { server.start(serverListener); client = newClientTransport(server); startTransport(client, mockClientTransportListener); - ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); MockServerTransportListener serverTransportListener = @@ -1768,7 +1774,8 @@ public void transportTracer_server_receive_msg() throws Exception { server.start(serverListener); client = newClientTransport(server); startTransport(client, mockClientTransportListener); - ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); MockServerTransportListener serverTransportListener @@ -1809,7 +1816,8 @@ public void transportTracer_server_send_msg() throws Exception { server.start(serverListener); client = newClientTransport(server); startTransport(client, mockClientTransportListener); - ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); MockServerTransportListener serverTransportListener @@ -1849,7 +1857,8 @@ public void socketStats() throws Exception { server.start(serverListener); ManagedClientTransport client = newClientTransport(server); startTransport(client, mockClientTransportListener); - ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); @@ -1896,8 +1905,8 @@ public void serverChecksInboundMetadataSize() throws Exception { Metadata.Key.of("foo-bin", Metadata.BINARY_BYTE_MARSHALLER), new byte[GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE]); - ClientStream clientStream = - client.newStream(methodDescriptor, tooLargeMetadata, callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, tooLargeMetadata, callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); @@ -1931,7 +1940,8 @@ public void clientChecksInboundMetadataSize_header() throws Exception { new byte[GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE]); ClientStream clientStream = - client.newStream(methodDescriptor, new Metadata(), callOptions); + client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); @@ -1975,7 +1985,8 @@ public void clientChecksInboundMetadataSize_trailer() throws Exception { new byte[GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE]); ClientStream clientStream = - client.newStream(methodDescriptor, new Metadata(), callOptions); + client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); @@ -2011,7 +2022,9 @@ private void doPingPong(MockServerListener serverListener) throws Exception { ManagedClientTransport client = newClientTransport(server); ManagedClientTransport.Listener listener = mock(ManagedClientTransport.Listener.class); startTransport(client, listener); - ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, new Metadata(), callOptions, + new ClientStreamTracer[] { new ClientStreamTracer() {} }); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); @@ -2092,6 +2105,16 @@ private static void startTransport( verify(listener, timeout(TIMEOUT_MS)).transportReady(); } + private final class TestHeaderClientStreamTracer extends TestClientStreamTracer { + Attributes transportAttrs; + + @Override + public void streamCreated(Attributes transportAttrs, Metadata metadata) { + this.transportAttrs = transportAttrs; + metadata.put(tracerHeaderKey, tracerKeyValue); + } + } + private static class MockServerListener implements ServerListener { public final BlockingQueue listeners = new LinkedBlockingQueue<>(); diff --git a/core/src/test/java/io/grpc/internal/CallCredentials2ApplyingTest.java b/core/src/test/java/io/grpc/internal/CallCredentials2ApplyingTest.java index 7725c46726b..963a586319b 100644 --- a/core/src/test/java/io/grpc/internal/CallCredentials2ApplyingTest.java +++ b/core/src/test/java/io/grpc/internal/CallCredentials2ApplyingTest.java @@ -34,6 +34,7 @@ import io.grpc.CallCredentials.RequestInfo; import io.grpc.CallOptions; import io.grpc.ChannelLogger; +import io.grpc.ClientStreamTracer; import io.grpc.IntegerMarshaller; import io.grpc.Metadata; import io.grpc.MethodDescriptor; @@ -48,6 +49,7 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; +import org.mockito.ArgumentMatchers; import org.mockito.Mock; import org.mockito.invocation.InvocationOnMock; import org.mockito.junit.MockitoJUnit; @@ -103,6 +105,9 @@ public class CallCredentials2ApplyingTest { private static final Metadata.Key CREDS_KEY = Metadata.Key.of("test-creds", Metadata.ASCII_STRING_MARSHALLER); private static final String CREDS_VALUE = "some credentials"; + private static final ClientStreamTracer[] tracers = new ClientStreamTracer[] { + new ClientStreamTracer() {} + }; private final Metadata origHeaders = new Metadata(); private ForwardingConnectionClientTransport transport; @@ -118,7 +123,9 @@ public void setUp() { origHeaders.put(ORIG_HEADER_KEY, ORIG_HEADER_VALUE); when(mockTransportFactory.newClientTransport(address, clientTransportOptions, channelLogger)) .thenReturn(mockTransport); - when(mockTransport.newStream(same(method), any(Metadata.class), any(CallOptions.class))) + when(mockTransport.newStream( + same(method), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any())) .thenReturn(mockStream); ClientTransportFactory transportFactory = new CallCredentialsApplyingTransportFactory( mockTransportFactory, null, mockExecutor); @@ -134,7 +141,7 @@ public void parameterPropagation_base() { Attributes transportAttrs = Attributes.newBuilder().set(ATTR_KEY, ATTR_VALUE).build(); when(mockTransport.getAttributes()).thenReturn(transportAttrs); - transport.newStream(method, origHeaders, callOptions); + transport.newStream(method, origHeaders, callOptions, tracers); ArgumentCaptor infoCaptor = ArgumentCaptor.forClass(null); verify(mockCreds).applyRequestMetadata( @@ -155,7 +162,7 @@ public void parameterPropagation_transportSetSecurityLevel() { .build(); when(mockTransport.getAttributes()).thenReturn(transportAttrs); - transport.newStream(method, origHeaders, callOptions); + transport.newStream(method, origHeaders, callOptions, tracers); ArgumentCaptor infoCaptor = ArgumentCaptor.forClass(null); verify(mockCreds).applyRequestMetadata( @@ -176,8 +183,10 @@ public void parameterPropagation_callOptionsSetAuthority() { when(mockTransport.getAttributes()).thenReturn(transportAttrs); Executor anotherExecutor = mock(Executor.class); - transport.newStream(method, origHeaders, - callOptions.withAuthority("calloptions-authority").withExecutor(anotherExecutor)); + transport.newStream( + method, origHeaders, + callOptions.withAuthority("calloptions-authority").withExecutor(anotherExecutor), + tracers); ArgumentCaptor infoCaptor = ArgumentCaptor.forClass(null); verify(mockCreds).applyRequestMetadata( @@ -199,13 +208,15 @@ public void credentialThrows() { any(io.grpc.CallCredentials2.MetadataApplier.class)); FailingClientStream stream = - (FailingClientStream) transport.newStream(method, origHeaders, callOptions); + (FailingClientStream) transport.newStream(method, origHeaders, callOptions, tracers); - verify(mockTransport, never()).newStream(method, origHeaders, callOptions); + verify(mockTransport, never()).newStream( + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); assertEquals(Status.Code.UNAUTHENTICATED, stream.getError().getCode()); assertSame(ex, stream.getError().getCause()); transport.shutdown(Status.UNAVAILABLE); - assertTrue(transport.newStream(method, origHeaders, callOptions) + assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) instanceof FailingClientStream); verify(mockTransport).shutdown(Status.UNAVAILABLE); } @@ -226,14 +237,14 @@ public Void answer(InvocationOnMock invocation) throws Throwable { any(RequestInfo.class), same(mockExecutor), any(io.grpc.CallCredentials2.MetadataApplier.class)); - ClientStream stream = transport.newStream(method, origHeaders, callOptions); + ClientStream stream = transport.newStream(method, origHeaders, callOptions, tracers); - verify(mockTransport).newStream(method, origHeaders, callOptions); + verify(mockTransport).newStream(method, origHeaders, callOptions, tracers); assertSame(mockStream, stream); assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY)); assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY)); transport.shutdown(Status.UNAVAILABLE); - assertTrue(transport.newStream(method, origHeaders, callOptions) + assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) instanceof FailingClientStream); verify(mockTransport).shutdown(Status.UNAVAILABLE); } @@ -254,12 +265,14 @@ public Void answer(InvocationOnMock invocation) throws Throwable { any(io.grpc.CallCredentials2.MetadataApplier.class)); FailingClientStream stream = - (FailingClientStream) transport.newStream(method, origHeaders, callOptions); + (FailingClientStream) transport.newStream(method, origHeaders, callOptions, tracers); - verify(mockTransport, never()).newStream(method, origHeaders, callOptions); + verify(mockTransport, never()).newStream( + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); assertSame(error, stream.getError()); transport.shutdownNow(Status.UNAVAILABLE); - assertTrue(transport.newStream(method, origHeaders, callOptions) + assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) instanceof FailingClientStream); verify(mockTransport).shutdownNow(Status.UNAVAILABLE); } @@ -269,12 +282,15 @@ public void applyMetadata_delayed() { when(mockTransport.getAttributes()).thenReturn(Attributes.EMPTY); // Will call applyRequestMetadata(), which is no-op. - DelayedStream stream = (DelayedStream) transport.newStream(method, origHeaders, callOptions); + DelayedStream stream = (DelayedStream) transport.newStream( + method, origHeaders, callOptions, tracers); ArgumentCaptor applierCaptor = ArgumentCaptor.forClass(null); verify(mockCreds).applyRequestMetadata( any(RequestInfo.class), same(mockExecutor), applierCaptor.capture()); - verify(mockTransport, never()).newStream(method, origHeaders, callOptions); + verify(mockTransport, never()).newStream( + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); transport.shutdown(Status.UNAVAILABLE); verify(mockTransport, never()).shutdown(Status.UNAVAILABLE); @@ -283,11 +299,11 @@ public void applyMetadata_delayed() { headers.put(CREDS_KEY, CREDS_VALUE); applierCaptor.getValue().apply(headers); - verify(mockTransport).newStream(method, origHeaders, callOptions); + verify(mockTransport).newStream(method, origHeaders, callOptions, tracers); assertSame(mockStream, stream.getRealStream()); assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY)); assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY)); - assertTrue(transport.newStream(method, origHeaders, callOptions) + assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) instanceof FailingClientStream); verify(mockTransport).shutdown(Status.UNAVAILABLE); } @@ -297,7 +313,8 @@ public void fail_delayed() { when(mockTransport.getAttributes()).thenReturn(Attributes.EMPTY); // Will call applyRequestMetadata(), which is no-op. - DelayedStream stream = (DelayedStream) transport.newStream(method, origHeaders, callOptions); + DelayedStream stream = (DelayedStream) transport.newStream( + method, origHeaders, callOptions, tracers); ArgumentCaptor applierCaptor = ArgumentCaptor.forClass(null); verify(mockCreds).applyRequestMetadata( @@ -306,11 +323,13 @@ public void fail_delayed() { Status error = Status.FAILED_PRECONDITION.withDescription("channel not secure for creds"); applierCaptor.getValue().fail(error); - verify(mockTransport, never()).newStream(method, origHeaders, callOptions); + verify(mockTransport, never()).newStream( + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); FailingClientStream failingStream = (FailingClientStream) stream.getRealStream(); assertSame(error, failingStream.getError()); transport.shutdown(Status.UNAVAILABLE); - assertTrue(transport.newStream(method, origHeaders, callOptions) + assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) instanceof FailingClientStream); verify(mockTransport).shutdown(Status.UNAVAILABLE); } @@ -318,14 +337,14 @@ public void fail_delayed() { @Test public void noCreds() { callOptions = callOptions.withCallCredentials(null); - ClientStream stream = transport.newStream(method, origHeaders, callOptions); + ClientStream stream = transport.newStream(method, origHeaders, callOptions, tracers); - verify(mockTransport).newStream(method, origHeaders, callOptions); + verify(mockTransport).newStream(method, origHeaders, callOptions, tracers); assertSame(mockStream, stream); assertNull(origHeaders.get(CREDS_KEY)); assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY)); transport.shutdown(Status.UNAVAILABLE); - assertTrue(transport.newStream(method, origHeaders, callOptions) + assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) instanceof FailingClientStream); verify(mockTransport).shutdown(Status.UNAVAILABLE); } diff --git a/core/src/test/java/io/grpc/internal/CallCredentialsApplyingTest.java b/core/src/test/java/io/grpc/internal/CallCredentialsApplyingTest.java index 61a221f73de..ef49e66bf2d 100644 --- a/core/src/test/java/io/grpc/internal/CallCredentialsApplyingTest.java +++ b/core/src/test/java/io/grpc/internal/CallCredentialsApplyingTest.java @@ -35,6 +35,7 @@ import io.grpc.CallCredentials.RequestInfo; import io.grpc.CallOptions; import io.grpc.ChannelLogger; +import io.grpc.ClientStreamTracer; import io.grpc.IntegerMarshaller; import io.grpc.Metadata; import io.grpc.MethodDescriptor; @@ -49,6 +50,7 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; +import org.mockito.ArgumentMatchers; import org.mockito.Mock; import org.mockito.invocation.InvocationOnMock; import org.mockito.junit.MockitoJUnit; @@ -86,6 +88,9 @@ public class CallCredentialsApplyingTest { @Mock private ChannelLogger channelLogger; + private static final ClientStreamTracer[] tracers = new ClientStreamTracer[] { + new ClientStreamTracer() {} + }; private static final String AUTHORITY = "testauthority"; private static final String USER_AGENT = "testuseragent"; private static final Attributes.Key ATTR_KEY = Attributes.Key.create("somekey"); @@ -117,7 +122,9 @@ public void setUp() { origHeaders.put(ORIG_HEADER_KEY, ORIG_HEADER_VALUE); when(mockTransportFactory.newClientTransport(address, clientTransportOptions, channelLogger)) .thenReturn(mockTransport); - when(mockTransport.newStream(same(method), any(Metadata.class), any(CallOptions.class))) + when(mockTransport.newStream( + same(method), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any())) .thenReturn(mockStream); ClientTransportFactory transportFactory = new CallCredentialsApplyingTransportFactory( mockTransportFactory, null, mockExecutor); @@ -133,7 +140,7 @@ public void parameterPropagation_base() { Attributes transportAttrs = Attributes.newBuilder().set(ATTR_KEY, ATTR_VALUE).build(); when(mockTransport.getAttributes()).thenReturn(transportAttrs); - transport.newStream(method, origHeaders, callOptions); + transport.newStream(method, origHeaders, callOptions, tracers); ArgumentCaptor infoCaptor = ArgumentCaptor.forClass(null); verify(mockCreds).applyRequestMetadata(infoCaptor.capture(), same(mockExecutor), @@ -154,8 +161,10 @@ public void parameterPropagation_overrideByCallOptions() { when(mockTransport.getAttributes()).thenReturn(transportAttrs); Executor anotherExecutor = mock(Executor.class); - transport.newStream(method, origHeaders, - callOptions.withAuthority("calloptions-authority").withExecutor(anotherExecutor)); + transport.newStream( + method, origHeaders, + callOptions.withAuthority("calloptions-authority").withExecutor(anotherExecutor), + tracers); ArgumentCaptor infoCaptor = ArgumentCaptor.forClass(null); verify(mockCreds).applyRequestMetadata(infoCaptor.capture(), @@ -175,15 +184,17 @@ public void credentialThrows() { any(RequestInfo.class), same(mockExecutor), any(CallCredentials.MetadataApplier.class)); - FailingClientStream stream = - (FailingClientStream) transport.newStream(method, origHeaders, callOptions); + FailingClientStream stream = (FailingClientStream) transport.newStream( + method, origHeaders, callOptions, tracers); - verify(mockTransport, never()).newStream(method, origHeaders, callOptions); + verify(mockTransport, never()).newStream( + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); assertEquals(Status.Code.UNAUTHENTICATED, stream.getError().getCode()); assertSame(ex, stream.getError().getCause()); transport.shutdown(Status.UNAVAILABLE); - assertTrue(transport.newStream(method, origHeaders, callOptions) + assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) instanceof FailingClientStream); verify(mockTransport).shutdown(Status.UNAVAILABLE); } @@ -193,14 +204,15 @@ public void applyMetadata_inline() { when(mockTransport.getAttributes()).thenReturn(Attributes.EMPTY); callOptions = callOptions.withCallCredentials(new FakeCallCredentials(CREDS_KEY, CREDS_VALUE)); - ClientStream stream = transport.newStream(method, origHeaders, callOptions); + ClientStream stream = transport.newStream( + method, origHeaders, callOptions, tracers); - verify(mockTransport).newStream(method, origHeaders, callOptions); + verify(mockTransport).newStream(method, origHeaders, callOptions, tracers); assertSame(mockStream, stream); assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY)); assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY)); transport.shutdown(Status.UNAVAILABLE); - assertTrue(transport.newStream(method, origHeaders, callOptions) + assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) instanceof FailingClientStream); verify(mockTransport).shutdown(Status.UNAVAILABLE); } @@ -220,13 +232,15 @@ public Void answer(InvocationOnMock invocation) throws Throwable { }).when(mockCreds).applyRequestMetadata(any(RequestInfo.class), same(mockExecutor), any(CallCredentials.MetadataApplier.class)); - FailingClientStream stream = - (FailingClientStream) transport.newStream(method, origHeaders, callOptions); + FailingClientStream stream = (FailingClientStream) transport.newStream( + method, origHeaders, callOptions, tracers); - verify(mockTransport, never()).newStream(method, origHeaders, callOptions); + verify(mockTransport, never()).newStream( + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); assertSame(error, stream.getError()); transport.shutdownNow(Status.UNAVAILABLE); - assertTrue(transport.newStream(method, origHeaders, callOptions) + assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) instanceof FailingClientStream); verify(mockTransport).shutdownNow(Status.UNAVAILABLE); } @@ -236,23 +250,26 @@ public void applyMetadata_delayed() { when(mockTransport.getAttributes()).thenReturn(Attributes.EMPTY); // Will call applyRequestMetadata(), which is no-op. - DelayedStream stream = (DelayedStream) transport.newStream(method, origHeaders, callOptions); + DelayedStream stream = (DelayedStream) transport.newStream( + method, origHeaders, callOptions, tracers); ArgumentCaptor applierCaptor = ArgumentCaptor.forClass(null); verify(mockCreds).applyRequestMetadata(any(RequestInfo.class), same(mockExecutor), applierCaptor.capture()); - verify(mockTransport, never()).newStream(method, origHeaders, callOptions); + verify(mockTransport, never()).newStream( + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); transport.shutdown(Status.UNAVAILABLE); verify(mockTransport, never()).shutdown(Status.UNAVAILABLE); - assertTrue(transport.newStream(method, origHeaders, callOptions) + assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) instanceof FailingClientStream); Metadata headers = new Metadata(); headers.put(CREDS_KEY, CREDS_VALUE); applierCaptor.getValue().apply(headers); - verify(mockTransport).newStream(method, origHeaders, callOptions); + verify(mockTransport).newStream(method, origHeaders, callOptions, tracers); assertSame(mockStream, stream.getRealStream()); assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY)); assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY)); @@ -261,20 +278,20 @@ public void applyMetadata_delayed() { @Test public void delayedShutdown_shutdownShutdownNowThenApply() { - transport.newStream(method, origHeaders, callOptions); + transport.newStream(method, origHeaders, callOptions, tracers); ArgumentCaptor applierCaptor = ArgumentCaptor.forClass(null); verify(mockCreds).applyRequestMetadata(any(RequestInfo.class), same(mockExecutor), applierCaptor.capture()); transport.shutdown(Status.UNAVAILABLE); transport.shutdownNow(Status.ABORTED); - assertTrue(transport.newStream(method, origHeaders, callOptions) + assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) instanceof FailingClientStream); verify(mockTransport, never()).shutdown(any(Status.class)); verify(mockTransport, never()).shutdownNow(any(Status.class)); Metadata headers = new Metadata(); headers.put(CREDS_KEY, CREDS_VALUE); applierCaptor.getValue().apply(headers); - assertTrue(transport.newStream(method, origHeaders, callOptions) + assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) instanceof FailingClientStream); verify(mockTransport).shutdown(Status.UNAVAILABLE); verify(mockTransport).shutdownNow(Status.ABORTED); @@ -282,12 +299,12 @@ public void delayedShutdown_shutdownShutdownNowThenApply() { @Test public void delayedShutdown_shutdownThenApplyThenShutdownNow() { - transport.newStream(method, origHeaders, callOptions); + transport.newStream(method, origHeaders, callOptions, tracers); ArgumentCaptor applierCaptor = ArgumentCaptor.forClass(null); verify(mockCreds).applyRequestMetadata(any(RequestInfo.class), same(mockExecutor), applierCaptor.capture()); transport.shutdown(Status.UNAVAILABLE); - assertTrue(transport.newStream(method, origHeaders, callOptions) + assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) instanceof FailingClientStream); verify(mockTransport, never()).shutdown(any(Status.class)); Metadata headers = new Metadata(); @@ -308,25 +325,25 @@ public void delayedShutdown_shutdownMulti() { Metadata headers = new Metadata(); headers.put(CREDS_KEY, CREDS_VALUE); - transport.newStream(method, origHeaders, callOptions); - transport.newStream(method, origHeaders, callOptions); - transport.newStream(method, origHeaders, callOptions); + transport.newStream(method, origHeaders, callOptions, tracers); + transport.newStream(method, origHeaders, callOptions, tracers); + transport.newStream(method, origHeaders, callOptions, tracers); ArgumentCaptor applierCaptor = ArgumentCaptor.forClass(null); verify(mockCreds, times(3)).applyRequestMetadata(any(RequestInfo.class), same(mockExecutor), applierCaptor.capture()); applierCaptor.getAllValues().get(1).apply(headers); transport.shutdown(Status.UNAVAILABLE); - assertTrue(transport.newStream(method, origHeaders, callOptions) + assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) instanceof FailingClientStream); verify(mockTransport, never()).shutdown(Status.UNAVAILABLE); applierCaptor.getAllValues().get(0).apply(headers); - assertTrue(transport.newStream(method, origHeaders, callOptions) + assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) instanceof FailingClientStream); verify(mockTransport, never()).shutdown(Status.UNAVAILABLE); applierCaptor.getAllValues().get(2).apply(headers); - assertTrue(transport.newStream(method, origHeaders, callOptions) + assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) instanceof FailingClientStream); verify(mockTransport).shutdown(Status.UNAVAILABLE); } @@ -336,7 +353,8 @@ public void fail_delayed() { when(mockTransport.getAttributes()).thenReturn(Attributes.EMPTY); // Will call applyRequestMetadata(), which is no-op. - DelayedStream stream = (DelayedStream) transport.newStream(method, origHeaders, callOptions); + DelayedStream stream = (DelayedStream) transport.newStream( + method, origHeaders, callOptions, tracers); ArgumentCaptor applierCaptor = ArgumentCaptor.forClass(null); verify(mockCreds).applyRequestMetadata(any(RequestInfo.class), @@ -345,11 +363,13 @@ public void fail_delayed() { Status error = Status.FAILED_PRECONDITION.withDescription("channel not secure for creds"); applierCaptor.getValue().fail(error); - verify(mockTransport, never()).newStream(method, origHeaders, callOptions); + verify(mockTransport, never()).newStream( + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); FailingClientStream failingStream = (FailingClientStream) stream.getRealStream(); assertSame(error, failingStream.getError()); transport.shutdown(Status.UNAVAILABLE); - assertTrue(transport.newStream(method, origHeaders, callOptions) + assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) instanceof FailingClientStream); verify(mockTransport).shutdown(Status.UNAVAILABLE); } @@ -357,14 +377,15 @@ public void fail_delayed() { @Test public void noCreds() { callOptions = callOptions.withCallCredentials(null); - ClientStream stream = transport.newStream(method, origHeaders, callOptions); + ClientStream stream = transport.newStream( + method, origHeaders, callOptions, tracers); - verify(mockTransport).newStream(method, origHeaders, callOptions); + verify(mockTransport).newStream(method, origHeaders, callOptions, tracers); assertSame(mockStream, stream); assertNull(origHeaders.get(CREDS_KEY)); assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY)); transport.shutdown(Status.UNAVAILABLE); - assertTrue(transport.newStream(method, origHeaders, callOptions) + assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) instanceof FailingClientStream); verify(mockTransport).shutdown(Status.UNAVAILABLE); } @@ -373,7 +394,8 @@ public void noCreds() { public void justCallOptionCreds() { callOptions = callOptions.withCallCredentials(new FakeCallCredentials(CREDS_KEY, CREDS_VALUE)); - ClientStream stream = transport.newStream(method, origHeaders, callOptions); + ClientStream stream = transport.newStream( + method, origHeaders, callOptions, tracers); assertSame(mockStream, stream); assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY)); @@ -388,7 +410,8 @@ public void justChannelCreds() { transportFactory.newClientTransport(address, clientTransportOptions, channelLogger); callOptions = callOptions.withCallCredentials(null); - ClientStream stream = transport.newStream(method, origHeaders, callOptions); + ClientStream stream = transport.newStream( + method, origHeaders, callOptions, tracers); assertSame(mockStream, stream); assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY)); @@ -406,7 +429,8 @@ public void callOptionAndChanelCreds() { String creds2Value = "some more credentials"; callOptions = callOptions.withCallCredentials(new FakeCallCredentials(creds2Key, creds2Value)); - ClientStream stream = transport.newStream(method, origHeaders, callOptions); + ClientStream stream = transport.newStream( + method, origHeaders, callOptions, tracers); assertSame(mockStream, stream); assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY)); diff --git a/core/src/test/java/io/grpc/internal/ClientCallImplTest.java b/core/src/test/java/io/grpc/internal/ClientCallImplTest.java index 1808a4bd478..0e5e5f50599 100644 --- a/core/src/test/java/io/grpc/internal/ClientCallImplTest.java +++ b/core/src/test/java/io/grpc/internal/ClientCallImplTest.java @@ -37,6 +37,7 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; import com.google.common.collect.ImmutableMap; @@ -47,6 +48,7 @@ import io.grpc.CallOptions; import io.grpc.ClientCall; import io.grpc.ClientStreamTracer; +import io.grpc.ClientStreamTracer.StreamInfo; import io.grpc.Codec; import io.grpc.Context; import io.grpc.Deadline; @@ -143,6 +145,8 @@ public void setUp() { any(Metadata.class), any(Context.class))) .thenReturn(stream); + when(streamTracerFactory.newClientStreamTracer(any(StreamInfo.class), any(Metadata.class))) + .thenReturn(new ClientStreamTracer() {}); doAnswer(new Answer() { @Override public Void answer(InvocationOnMock in) { @@ -156,7 +160,7 @@ public Void answer(InvocationOnMock in) { @After public void tearDown() { - verifyNoInteractions(streamTracerFactory); + verifyNoMoreInteractions(streamTracerFactory); } @Test @@ -763,6 +767,7 @@ public void deadlineExceededBeforeCallStarted() { channelCallTracer, configSelector) .setDecompressorRegistry(decompressorRegistry); call.start(callListener, new Metadata()); + verify(streamTracerFactory).newClientStreamTracer(any(StreamInfo.class), any(Metadata.class)); verify(clientStreamProvider, never()) .newStream( (MethodDescriptor) any(MethodDescriptor.class), diff --git a/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java b/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java index 9f48b8987d1..4cae565a19e 100644 --- a/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java +++ b/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java @@ -36,6 +36,7 @@ import static org.mockito.Mockito.when; import io.grpc.CallOptions; +import io.grpc.ClientStreamTracer; import io.grpc.IntegerMarshaller; import io.grpc.LoadBalancer.PickResult; import io.grpc.LoadBalancer.PickSubchannelArgs; @@ -57,6 +58,7 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; +import org.mockito.ArgumentMatchers; import org.mockito.Captor; import org.mockito.InOrder; import org.mockito.Mock; @@ -89,6 +91,9 @@ public class DelayedClientTransportTest { = CallOptions.Key.createWithDefault("shard-id", -1); private static final Status SHUTDOWN_STATUS = Status.UNAVAILABLE.withDescription("shutdown called"); + private static final ClientStreamTracer[] tracers = new ClientStreamTracer[] { + new ClientStreamTracer() {} + }; private final MethodDescriptor method = MethodDescriptor.newBuilder() @@ -122,9 +127,13 @@ public void uncaughtException(Thread t, Throwable e) { .thenReturn(PickResult.withSubchannel(mockSubchannel)); when(mockSubchannel.getInternalSubchannel()).thenReturn(mockInternalSubchannel); when(mockInternalSubchannel.obtainActiveTransport()).thenReturn(mockRealTransport); - when(mockRealTransport.newStream(same(method), same(headers), same(callOptions))) + when(mockRealTransport.newStream( + same(method), same(headers), same(callOptions), + ArgumentMatchers.any())) .thenReturn(mockRealStream); - when(mockRealTransport2.newStream(same(method2), same(headers2), same(callOptions2))) + when(mockRealTransport2.newStream( + same(method2), same(headers2), same(callOptions2), + ArgumentMatchers.any())) .thenReturn(mockRealStream2); delayedTransport.start(transportListener); } @@ -135,7 +144,8 @@ public void uncaughtException(Thread t, Throwable e) { @Test public void streamStartThenAssignTransport() { assertFalse(delayedTransport.hasPendingStreams()); - ClientStream stream = delayedTransport.newStream(method, headers, callOptions); + ClientStream stream = delayedTransport.newStream( + method, headers, callOptions, tracers); stream.start(streamListener); assertEquals(1, delayedTransport.getPendingStreamsCount()); assertTrue(delayedTransport.hasPendingStreams()); @@ -145,7 +155,9 @@ public void uncaughtException(Thread t, Throwable e) { assertEquals(0, delayedTransport.getPendingStreamsCount()); assertFalse(delayedTransport.hasPendingStreams()); assertEquals(1, fakeExecutor.runDueTasks()); - verify(mockRealTransport).newStream(same(method), same(headers), same(callOptions)); + verify(mockRealTransport).newStream( + same(method), same(headers), same(callOptions), + ArgumentMatchers.any()); verify(mockRealStream).start(listenerCaptor.capture()); verifyNoMoreInteractions(streamListener); listenerCaptor.getValue().onReady(); @@ -154,7 +166,7 @@ public void uncaughtException(Thread t, Throwable e) { } @Test public void newStreamThenAssignTransportThenShutdown() { - ClientStream stream = delayedTransport.newStream(method, headers, callOptions); + ClientStream stream = delayedTransport.newStream(method, headers, callOptions, tracers); assertEquals(1, delayedTransport.getPendingStreamsCount()); assertTrue(stream instanceof DelayedStream); delayedTransport.reprocess(mockPicker); @@ -163,7 +175,9 @@ public void uncaughtException(Thread t, Throwable e) { verify(transportListener).transportShutdown(same(SHUTDOWN_STATUS)); verify(transportListener).transportTerminated(); assertEquals(0, fakeExecutor.runDueTasks()); - verify(mockRealTransport).newStream(same(method), same(headers), same(callOptions)); + verify(mockRealTransport).newStream( + same(method), same(headers), same(callOptions), + ArgumentMatchers.any()); stream.start(streamListener); verify(mockRealStream).start(same(streamListener)); } @@ -181,11 +195,13 @@ public void uncaughtException(Thread t, Throwable e) { delayedTransport.shutdown(SHUTDOWN_STATUS); verify(transportListener).transportShutdown(same(SHUTDOWN_STATUS)); verify(transportListener).transportTerminated(); - ClientStream stream = delayedTransport.newStream(method, headers, callOptions); + ClientStream stream = delayedTransport.newStream( + method, headers, callOptions, tracers); assertEquals(0, delayedTransport.getPendingStreamsCount()); assertTrue(stream instanceof FailingClientStream); verify(mockRealTransport, never()).newStream( - any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class)); + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); } @Test public void assignTransportThenShutdownNowThenNewStream() { @@ -193,15 +209,18 @@ public void uncaughtException(Thread t, Throwable e) { delayedTransport.shutdownNow(Status.UNAVAILABLE); verify(transportListener).transportShutdown(any(Status.class)); verify(transportListener).transportTerminated(); - ClientStream stream = delayedTransport.newStream(method, headers, callOptions); + ClientStream stream = delayedTransport.newStream( + method, headers, callOptions, tracers); assertEquals(0, delayedTransport.getPendingStreamsCount()); assertTrue(stream instanceof FailingClientStream); verify(mockRealTransport, never()).newStream( - any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class)); + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); } @Test public void startThenCancelStreamWithoutSetTransport() { - ClientStream stream = delayedTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + ClientStream stream = delayedTransport.newStream( + method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(streamListener); assertEquals(1, delayedTransport.getPendingStreamsCount()); stream.cancel(Status.CANCELLED); @@ -213,7 +232,8 @@ public void uncaughtException(Thread t, Throwable e) { } @Test public void newStreamThenShutdownTransportThenAssignTransport() { - ClientStream stream = delayedTransport.newStream(method, headers, callOptions); + ClientStream stream = delayedTransport.newStream( + method, headers, callOptions, tracers); stream.start(streamListener); delayedTransport.shutdown(SHUTDOWN_STATUS); @@ -225,7 +245,8 @@ public void uncaughtException(Thread t, Throwable e) { // ... and will proceed if a real transport is available delayedTransport.reprocess(mockPicker); fakeExecutor.runDueTasks(); - verify(mockRealTransport).newStream(method, headers, callOptions); + verify(mockRealTransport).newStream( + method, headers, callOptions, tracers); verify(mockRealStream).start(any(ClientStreamListener.class)); // Since no more streams are pending, delayed transport is now terminated @@ -233,7 +254,8 @@ public void uncaughtException(Thread t, Throwable e) { verify(transportListener).transportTerminated(); // Further newStream() will return a failing stream - stream = delayedTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + stream = delayedTransport.newStream( + method, new Metadata(), CallOptions.DEFAULT, tracers); verify(streamListener, never()).closed( any(Status.class), any(RpcProgress.class), any(Metadata.class)); stream.start(streamListener); @@ -247,7 +269,8 @@ public void uncaughtException(Thread t, Throwable e) { } @Test public void newStreamThenShutdownTransportThenCancelStream() { - ClientStream stream = delayedTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + ClientStream stream = delayedTransport.newStream( + method, new Metadata(), CallOptions.DEFAULT, tracers); delayedTransport.shutdown(SHUTDOWN_STATUS); verify(transportListener).transportShutdown(same(SHUTDOWN_STATUS)); verify(transportListener, times(0)).transportTerminated(); @@ -264,7 +287,8 @@ public void uncaughtException(Thread t, Throwable e) { delayedTransport.shutdown(SHUTDOWN_STATUS); verify(transportListener).transportShutdown(same(SHUTDOWN_STATUS)); verify(transportListener).transportTerminated(); - ClientStream stream = delayedTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + ClientStream stream = delayedTransport.newStream( + method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(streamListener); verify(streamListener).closed( statusCaptor.capture(), any(RpcProgress.class), any(Metadata.class)); @@ -272,7 +296,8 @@ public void uncaughtException(Thread t, Throwable e) { } @Test public void startStreamThenShutdownNow() { - ClientStream stream = delayedTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + ClientStream stream = delayedTransport.newStream( + method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(streamListener); delayedTransport.shutdownNow(Status.UNAVAILABLE); verify(transportListener).transportShutdown(any(Status.class)); @@ -286,7 +311,8 @@ public void uncaughtException(Thread t, Throwable e) { delayedTransport.shutdownNow(Status.UNAVAILABLE); verify(transportListener).transportShutdown(any(Status.class)); verify(transportListener).transportTerminated(); - ClientStream stream = delayedTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + ClientStream stream = delayedTransport.newStream( + method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(streamListener); verify(streamListener).closed( statusCaptor.capture(), any(RpcProgress.class), any(Metadata.class)); @@ -301,55 +327,59 @@ public void uncaughtException(Thread t, Throwable e) { AbstractSubchannel subchannel1 = mock(AbstractSubchannel.class); AbstractSubchannel subchannel2 = mock(AbstractSubchannel.class); AbstractSubchannel subchannel3 = mock(AbstractSubchannel.class); - when(mockRealTransport.newStream(any(MethodDescriptor.class), any(Metadata.class), - any(CallOptions.class))).thenReturn(mockRealStream); - when(mockRealTransport2.newStream(any(MethodDescriptor.class), any(Metadata.class), - any(CallOptions.class))).thenReturn(mockRealStream2); + when(mockRealTransport.newStream( + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any())) + .thenReturn(mockRealStream); + when(mockRealTransport2.newStream( + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any())) + .thenReturn(mockRealStream2); when(subchannel1.getInternalSubchannel()).thenReturn(newTransportProvider(mockRealTransport)); when(subchannel2.getInternalSubchannel()).thenReturn(newTransportProvider(mockRealTransport2)); when(subchannel3.getInternalSubchannel()).thenReturn(newTransportProvider(null)); // Fail-fast streams DelayedStream ff1 = (DelayedStream) delayedTransport.newStream( - method, headers, failFastCallOptions); + method, headers, failFastCallOptions, tracers); ff1.start(mock(ClientStreamListener.class)); ff1.halfClose(); PickSubchannelArgsImpl ff1args = new PickSubchannelArgsImpl(method, headers, failFastCallOptions); verify(transportListener).transportInUse(true); DelayedStream ff2 = (DelayedStream) delayedTransport.newStream( - method2, headers2, failFastCallOptions); + method2, headers2, failFastCallOptions, tracers); PickSubchannelArgsImpl ff2args = new PickSubchannelArgsImpl(method2, headers2, failFastCallOptions); DelayedStream ff3 = (DelayedStream) delayedTransport.newStream( - method, headers, failFastCallOptions); + method, headers, failFastCallOptions, tracers); PickSubchannelArgsImpl ff3args = new PickSubchannelArgsImpl(method, headers, failFastCallOptions); DelayedStream ff4 = (DelayedStream) delayedTransport.newStream( - method2, headers2, failFastCallOptions); + method2, headers2, failFastCallOptions, tracers); PickSubchannelArgsImpl ff4args = new PickSubchannelArgsImpl(method2, headers2, failFastCallOptions); // Wait-for-ready streams FakeClock wfr3Executor = new FakeClock(); DelayedStream wfr1 = (DelayedStream) delayedTransport.newStream( - method, headers, waitForReadyCallOptions); + method, headers, waitForReadyCallOptions, tracers); PickSubchannelArgsImpl wfr1args = new PickSubchannelArgsImpl(method, headers, waitForReadyCallOptions); DelayedStream wfr2 = (DelayedStream) delayedTransport.newStream( - method2, headers2, waitForReadyCallOptions); + method2, headers2, waitForReadyCallOptions, tracers); PickSubchannelArgsImpl wfr2args = new PickSubchannelArgsImpl(method2, headers2, waitForReadyCallOptions); CallOptions wfr3callOptions = waitForReadyCallOptions.withExecutor( wfr3Executor.getScheduledExecutorService()); DelayedStream wfr3 = (DelayedStream) delayedTransport.newStream( - method, headers, wfr3callOptions); + method, headers, wfr3callOptions, tracers); wfr3.start(mock(ClientStreamListener.class)); wfr3.halfClose(); PickSubchannelArgsImpl wfr3args = new PickSubchannelArgsImpl(method, headers, wfr3callOptions); DelayedStream wfr4 = (DelayedStream) delayedTransport.newStream( - method2, headers2, waitForReadyCallOptions); + method2, headers2, waitForReadyCallOptions, tracers); PickSubchannelArgsImpl wfr4args = new PickSubchannelArgsImpl(method2, headers2, waitForReadyCallOptions); @@ -386,8 +416,10 @@ public void uncaughtException(Thread t, Throwable e) { // streams are now owned by a real transport (which should prevent the Channel from // terminating). // ff1 and wfr1 went through - verify(mockRealTransport).newStream(method, headers, failFastCallOptions); - verify(mockRealTransport2).newStream(method, headers, waitForReadyCallOptions); + verify(mockRealTransport).newStream( + method, headers, failFastCallOptions, tracers); + verify(mockRealTransport2).newStream( + method, headers, waitForReadyCallOptions, tracers); assertSame(mockRealStream, ff1.getRealStream()); assertSame(mockRealStream2, wfr1.getRealStream()); verify(mockRealStream).start(any(ClientStreamListener.class)); @@ -443,7 +475,7 @@ public void uncaughtException(Thread t, Throwable e) { // New streams will use the last picker DelayedStream wfr5 = (DelayedStream) delayedTransport.newStream( - method, headers, waitForReadyCallOptions); + method, headers, waitForReadyCallOptions, tracers); assertNull(wfr5.getRealStream()); inOrder.verify(picker).pickSubchannel( new PickSubchannelArgsImpl(method, headers, waitForReadyCallOptions)); @@ -474,14 +506,17 @@ public void reprocess_NoPendingStream() { when(subchannel.getInternalSubchannel()).thenReturn(mockInternalSubchannel); when(picker.pickSubchannel(any(PickSubchannelArgs.class))).thenReturn( PickResult.withSubchannel(subchannel)); - when(mockRealTransport.newStream(any(MethodDescriptor.class), any(Metadata.class), - any(CallOptions.class))).thenReturn(mockRealStream); + when(mockRealTransport.newStream( + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any())) + .thenReturn(mockRealStream); delayedTransport.reprocess(picker); verifyNoMoreInteractions(picker); verifyNoMoreInteractions(transportListener); // Though picker was not originally used, it will be saved and serve future streams. - ClientStream stream = delayedTransport.newStream(method, headers, CallOptions.DEFAULT); + ClientStream stream = delayedTransport.newStream( + method, headers, CallOptions.DEFAULT, tracers); verify(picker).pickSubchannel(new PickSubchannelArgsImpl(method, headers, CallOptions.DEFAULT)); verify(mockInternalSubchannel).obtainActiveTransport(); assertSame(mockRealStream, stream); @@ -519,7 +554,7 @@ public PickResult answer(InvocationOnMock invocation) throws Throwable { @Override public void run() { // Will call pickSubchannel and wait on barrier - delayedTransport.newStream(method, headers, callOptions); + delayedTransport.newStream(method, headers, callOptions, tracers); } }; sideThread.start(); @@ -552,7 +587,7 @@ public void run() { @Override public void run() { // Will call pickSubchannel and wait on barrier - delayedTransport.newStream(method, headers2, callOptions); + delayedTransport.newStream(method, headers2, callOptions, tracers); } }; sideThread2.start(); @@ -600,7 +635,8 @@ public void newStream_racesWithReprocessIdleMode() throws Exception { // Because there is no pending stream yet, it will do nothing but save the picker. delayedTransport.reprocess(picker); - ClientStream stream = delayedTransport.newStream(method, headers, callOptions); + ClientStream stream = delayedTransport.newStream( + method, headers, callOptions, tracers); stream.start(streamListener); assertTrue(delayedTransport.hasPendingStreams()); verify(transportListener).transportInUse(true); @@ -609,7 +645,7 @@ public void newStream_racesWithReprocessIdleMode() throws Exception { @Test public void pendingStream_appendTimeoutInsight_waitForReady() { ClientStream stream = delayedTransport.newStream( - method, headers, callOptions.withWaitForReady()); + method, headers, callOptions.withWaitForReady(), tracers); stream.start(streamListener); InsightBuilder insight = new InsightBuilder(); stream.appendTimeoutInsight(insight); diff --git a/core/src/test/java/io/grpc/internal/FailingClientStreamTest.java b/core/src/test/java/io/grpc/internal/FailingClientStreamTest.java index dad82902395..c07812577d5 100644 --- a/core/src/test/java/io/grpc/internal/FailingClientStreamTest.java +++ b/core/src/test/java/io/grpc/internal/FailingClientStreamTest.java @@ -21,6 +21,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; +import io.grpc.ClientStreamTracer; import io.grpc.Metadata; import io.grpc.Status; import io.grpc.internal.ClientStreamListener.RpcProgress; @@ -33,13 +34,16 @@ */ @RunWith(JUnit4.class) public class FailingClientStreamTest { + private static final ClientStreamTracer[] tracers = new ClientStreamTracer[] { + new ClientStreamTracer() {} + }; @Test public void processedRpcProgressPopulatedToListener() { ClientStreamListener listener = mock(ClientStreamListener.class); Status status = Status.UNAVAILABLE; - ClientStream stream = new FailingClientStream(status); + ClientStream stream = new FailingClientStream(status, RpcProgress.PROCESSED, tracers); stream.start(listener); verify(listener).closed(eq(status), eq(RpcProgress.PROCESSED), any(Metadata.class)); } @@ -49,7 +53,7 @@ public void droppedRpcProgressPopulatedToListener() { ClientStreamListener listener = mock(ClientStreamListener.class); Status status = Status.UNAVAILABLE; - ClientStream stream = new FailingClientStream(status, RpcProgress.DROPPED); + ClientStream stream = new FailingClientStream(status, RpcProgress.DROPPED, tracers); stream.start(listener); verify(listener).closed(eq(status), eq(RpcProgress.DROPPED), any(Metadata.class)); } diff --git a/core/src/test/java/io/grpc/internal/FailingClientTransportTest.java b/core/src/test/java/io/grpc/internal/FailingClientTransportTest.java index ff15ef7ff02..98749d74910 100644 --- a/core/src/test/java/io/grpc/internal/FailingClientTransportTest.java +++ b/core/src/test/java/io/grpc/internal/FailingClientTransportTest.java @@ -22,6 +22,7 @@ import static org.mockito.Mockito.verify; import io.grpc.CallOptions; +import io.grpc.ClientStreamTracer; import io.grpc.Metadata; import io.grpc.Status; import io.grpc.internal.ClientStreamListener.RpcProgress; @@ -41,8 +42,9 @@ public void newStreamStart() { Status error = Status.UNAVAILABLE; RpcProgress rpcProgress = RpcProgress.DROPPED; FailingClientTransport transport = new FailingClientTransport(error, rpcProgress); - ClientStream stream = transport - .newStream(TestMethodDescriptors.voidMethod(), new Metadata(), CallOptions.DEFAULT); + ClientStream stream = transport.newStream( + TestMethodDescriptors.voidMethod(), new Metadata(), CallOptions.DEFAULT, + new ClientStreamTracer[] { new ClientStreamTracer() {} }); ClientStreamListener listener = mock(ClientStreamListener.class); stream.start(listener); diff --git a/core/src/test/java/io/grpc/internal/ForwardingClientStreamTracerTest.java b/core/src/test/java/io/grpc/internal/ForwardingClientStreamTracerTest.java new file mode 100644 index 00000000000..5eb5b49fa19 --- /dev/null +++ b/core/src/test/java/io/grpc/internal/ForwardingClientStreamTracerTest.java @@ -0,0 +1,49 @@ +/* + * Copyright 2021 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +import static org.mockito.Mockito.mock; + +import io.grpc.ClientStreamTracer; +import io.grpc.ForwardingTestUtil; +import java.lang.reflect.Method; +import java.util.Collections; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link ForwardingClientStreamTracer}. */ +@RunWith(JUnit4.class) +public class ForwardingClientStreamTracerTest { + private final ClientStreamTracer mockDelegate = mock(ClientStreamTracer.class); + + @Test + public void allMethodsForwarded() throws Exception { + ForwardingTestUtil.testMethodsForwarded( + ClientStreamTracer.class, + mockDelegate, + new ForwardingClientStreamTracerTest.TestClientStreamTracer(), + Collections.emptyList()); + } + + private final class TestClientStreamTracer extends ForwardingClientStreamTracer { + @Override + protected ClientStreamTracer delegate() { + return mockDelegate; + } + } +} diff --git a/core/src/test/java/io/grpc/internal/GrpcUtilTest.java b/core/src/test/java/io/grpc/internal/GrpcUtilTest.java index 7a3808de6e3..95d1c448f4f 100644 --- a/core/src/test/java/io/grpc/internal/GrpcUtilTest.java +++ b/core/src/test/java/io/grpc/internal/GrpcUtilTest.java @@ -16,6 +16,7 @@ package io.grpc.internal; +import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; @@ -27,13 +28,17 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; +import io.grpc.Attributes; import io.grpc.CallOptions; +import io.grpc.ClientStreamTracer; +import io.grpc.ClientStreamTracer.StreamInfo; import io.grpc.LoadBalancer.PickResult; import io.grpc.Metadata; import io.grpc.Status; import io.grpc.internal.ClientStreamListener.RpcProgress; import io.grpc.internal.GrpcUtil.Http2Error; import io.grpc.testing.TestMethodDescriptors; +import java.util.concurrent.atomic.AtomicReference; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; @@ -44,6 +49,10 @@ @RunWith(JUnit4.class) public class GrpcUtilTest { + private static final ClientStreamTracer[] tracers = new ClientStreamTracer[] { + new ClientStreamTracer() {} + }; + @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 @Rule public final ExpectedException thrown = ExpectedException.none(); @@ -244,8 +253,9 @@ public void getTransportFromPickResult_errorPickResult_failFast() { assertNotNull(transport); - ClientStream stream = transport - .newStream(TestMethodDescriptors.voidMethod(), new Metadata(), CallOptions.DEFAULT); + ClientStream stream = transport.newStream( + TestMethodDescriptors.voidMethod(), new Metadata(), CallOptions.DEFAULT, + tracers); ClientStreamListener listener = mock(ClientStreamListener.class); stream.start(listener); @@ -260,8 +270,9 @@ public void getTransportFromPickResult_dropPickResult_waitForReady() { assertNotNull(transport); - ClientStream stream = transport - .newStream(TestMethodDescriptors.voidMethod(), new Metadata(), CallOptions.DEFAULT); + ClientStream stream = transport.newStream( + TestMethodDescriptors.voidMethod(), new Metadata(), CallOptions.DEFAULT, + tracers); ClientStreamListener listener = mock(ClientStreamListener.class); stream.start(listener); @@ -276,11 +287,39 @@ public void getTransportFromPickResult_dropPickResult_failFast() { assertNotNull(transport); - ClientStream stream = transport - .newStream(TestMethodDescriptors.voidMethod(), new Metadata(), CallOptions.DEFAULT); + ClientStream stream = transport.newStream( + TestMethodDescriptors.voidMethod(), new Metadata(), CallOptions.DEFAULT, + tracers); ClientStreamListener listener = mock(ClientStreamListener.class); stream.start(listener); verify(listener).closed(eq(status), eq(RpcProgress.DROPPED), any(Metadata.class)); } + + @Test + public void clientStreamTracerFactoryBackwardCompatibility() { + final AtomicReference transportAttrsRef = new AtomicReference<>(); + final ClientStreamTracer mockTracer = mock(ClientStreamTracer.class); + final Metadata.Key key = Metadata.Key.of("fake-key", Metadata.ASCII_STRING_MARSHALLER); + ClientStreamTracer.Factory oldFactoryImpl = new ClientStreamTracer.Factory() { + @SuppressWarnings("deprecation") + @Override + public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata headers) { + transportAttrsRef.set(info.getTransportAttrs()); + headers.put(key, "fake-value"); + return mockTracer; + } + }; + + StreamInfo info = + StreamInfo.newBuilder().setCallOptions(CallOptions.DEFAULT.withWaitForReady()).build(); + Metadata metadata = new Metadata(); + Attributes transAttrs = + Attributes.newBuilder().set(Attributes.Key.create("foo"), "bar").build(); + ClientStreamTracer tracer = GrpcUtil.newClientStreamTracer(oldFactoryImpl, info, metadata); + tracer.streamCreated(transAttrs, metadata); + + assertThat(transportAttrsRef.get()).isEqualTo(transAttrs); + assertThat(metadata.get(key)).isEqualTo("fake-value"); + } } diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java index e5e017d756a..ccfb5f074c5 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java @@ -71,6 +71,7 @@ import io.grpc.ClientInterceptor; import io.grpc.ClientInterceptors; import io.grpc.ClientStreamTracer; +import io.grpc.ClientStreamTracer.StreamInfo; import io.grpc.CompositeChannelCredentials; import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; @@ -151,6 +152,7 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; +import org.mockito.ArgumentMatchers; import org.mockito.Captor; import org.mockito.InOrder; import org.mockito.Mock; @@ -225,6 +227,8 @@ public boolean shouldAccept(Runnable command) { private ArgumentCaptor statusCaptor; @Captor private ArgumentCaptor callOptionsCaptor; + @Captor + private ArgumentCaptor tracersCaptor; @Mock private LoadBalancer mockLoadBalancer; @Mock @@ -525,7 +529,9 @@ channelBuilder, mockTransportFactory, new FakeBackoffPolicyProvider(), MockClientTransportInfo transportInfo = transports.poll(); ConnectionClientTransport mockTransport = transportInfo.transport; ManagedClientTransport.Listener transportListener = transportInfo.listener; - when(mockTransport.newStream(same(method), same(headers), any(CallOptions.class))) + when(mockTransport.newStream( + same(method), same(headers), any(CallOptions.class), + ArgumentMatchers.any())) .thenReturn(mockStream); transportListener.transportReady(); when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))) @@ -534,7 +540,9 @@ channelBuilder, mockTransportFactory, new FakeBackoffPolicyProvider(), executor.runDueTasks(); ArgumentCaptor callOptionsCaptor = ArgumentCaptor.forClass(null); - verify(mockTransport).newStream(same(method), same(headers), callOptionsCaptor.capture()); + verify(mockTransport).newStream( + same(method), same(headers), callOptionsCaptor.capture(), + ArgumentMatchers.any()); assertThat(callOptionsCaptor.getValue().isWaitForReady()).isTrue(); verify(mockStream).start(streamListenerCaptor.capture()); @@ -600,7 +608,9 @@ public ClientCall interceptCall( MockClientTransportInfo transportInfo = transports.poll(); ConnectionClientTransport mockTransport = transportInfo.transport; ManagedClientTransport.Listener transportListener = transportInfo.listener; - when(mockTransport.newStream(same(method), same(headers), any(CallOptions.class))) + when(mockTransport.newStream( + same(method), same(headers), any(CallOptions.class), + ArgumentMatchers.any())) .thenReturn(mockStream); transportListener.transportReady(); when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))) @@ -609,7 +619,9 @@ public ClientCall interceptCall( executor.runDueTasks(); ArgumentCaptor callOptionsCaptor = ArgumentCaptor.forClass(null); - verify(mockTransport).newStream(same(method), same(headers), callOptionsCaptor.capture()); + verify(mockTransport).newStream( + same(method), same(headers), callOptionsCaptor.capture(), + ArgumentMatchers.any()); assertThat(callOptionsCaptor.getValue().getOption(callOptionsKey)).isEqualTo("fooValue"); verify(mockStream).start(streamListenerCaptor.capture()); @@ -800,9 +812,13 @@ private void subtestCallsAndShutdown(boolean shutdownNow, boolean shutdownNowAft ConnectionClientTransport mockTransport = transportInfo.transport; verify(mockTransport).start(any(ManagedClientTransport.Listener.class)); ManagedClientTransport.Listener transportListener = transportInfo.listener; - when(mockTransport.newStream(same(method), same(headers), same(CallOptions.DEFAULT))) + when(mockTransport.newStream( + same(method), same(headers), same(CallOptions.DEFAULT), + ArgumentMatchers.any())) .thenReturn(mockStream); - when(mockTransport.newStream(same(method), same(headers2), same(CallOptions.DEFAULT))) + when(mockTransport.newStream( + same(method), same(headers2), same(CallOptions.DEFAULT), + ArgumentMatchers.any())) .thenReturn(mockStream2); transportListener.transportReady(); when(mockPicker.pickSubchannel( @@ -820,14 +836,19 @@ private void subtestCallsAndShutdown(boolean shutdownNow, boolean shutdownNowAft any(SocketAddress.class), any(ClientTransportOptions.class), any(ChannelLogger.class)); call.start(mockCallListener, headers); - verify(mockTransport, never()) - .newStream(same(method), same(headers), same(CallOptions.DEFAULT)); + verify(mockTransport, never()).newStream( + same(method), same(headers), same(CallOptions.DEFAULT), + ArgumentMatchers.any()); // Second RPC, will be assigned to the real transport ClientCall call2 = channel.newCall(method, CallOptions.DEFAULT); call2.start(mockCallListener2, headers2); - verify(mockTransport).newStream(same(method), same(headers2), same(CallOptions.DEFAULT)); - verify(mockTransport).newStream(same(method), same(headers2), same(CallOptions.DEFAULT)); + verify(mockTransport).newStream( + same(method), same(headers2), same(CallOptions.DEFAULT), + ArgumentMatchers.any()); + verify(mockTransport).newStream( + same(method), same(headers2), same(CallOptions.DEFAULT), + ArgumentMatchers.any()); verify(mockStream2).start(any(ClientStreamListener.class)); // Shutdown @@ -872,7 +893,9 @@ private void subtestCallsAndShutdown(boolean shutdownNow, boolean shutdownNowAft .thenReturn(PickResult.withSubchannel(subchannel)); updateBalancingStateSafely(helper, READY, picker2); executor.runDueTasks(); - verify(mockTransport).newStream(same(method), same(headers), same(CallOptions.DEFAULT)); + verify(mockTransport).newStream( + same(method), same(headers), same(CallOptions.DEFAULT), + ArgumentMatchers.any()); verify(mockStream).start(any(ClientStreamListener.class)); } @@ -1021,7 +1044,9 @@ public void callOptionsExecutor() { MockClientTransportInfo transportInfo = transports.poll(); ConnectionClientTransport mockTransport = transportInfo.transport; ManagedClientTransport.Listener transportListener = transportInfo.listener; - when(mockTransport.newStream(same(method), same(headers), any(CallOptions.class))) + when(mockTransport.newStream( + same(method), same(headers), any(CallOptions.class), + ArgumentMatchers.any())) .thenReturn(mockStream); transportListener.transportReady(); when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))) @@ -1031,7 +1056,8 @@ public void callOptionsExecutor() { // Real streams are started in the call executor if they were previously buffered. assertEquals(1, callExecutor.runDueTasks()); - verify(mockTransport).newStream(same(method), same(headers), same(options)); + verify(mockTransport).newStream( + same(method), same(headers), same(options), ArgumentMatchers.any()); verify(mockStream).start(streamListenerCaptor.capture()); // Call listener callbacks are also run in the call executor @@ -1298,7 +1324,8 @@ public void firstResolvedServerFailedToConnect() throws Exception { same(goodAddress), any(ClientTransportOptions.class), any(ChannelLogger.class)); MockClientTransportInfo goodTransportInfo = transports.poll(); when(goodTransportInfo.transport.newStream( - any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class))) + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any())) .thenReturn(mock(ClientStream.class)); goodTransportInfo.listener.transportReady(); @@ -1310,11 +1337,13 @@ public void firstResolvedServerFailedToConnect() throws Exception { // Delayed transport uses the app executor to create real streams. executor.runDueTasks(); - verify(goodTransportInfo.transport).newStream(same(method), same(headers), - same(CallOptions.DEFAULT)); + verify(goodTransportInfo.transport).newStream( + same(method), same(headers), same(CallOptions.DEFAULT), + ArgumentMatchers.any()); // The bad transport was never used. - verify(badTransportInfo.transport, times(0)).newStream(any(MethodDescriptor.class), - any(Metadata.class), any(CallOptions.class)); + verify(badTransportInfo.transport, times(0)).newStream( + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); } @Test @@ -1464,10 +1493,12 @@ public void allServersFailedToConnect() throws Exception { // ... while the wait-for-ready call stays verifyNoMoreInteractions(mockCallListener); // No real stream was ever created - verify(transportInfo1.transport, times(0)) - .newStream(any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class)); - verify(transportInfo2.transport, times(0)) - .newStream(any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class)); + verify(transportInfo1.transport, times(0)).newStream( + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); + verify(transportInfo2.transport, times(0)).newStream( + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); } @Test @@ -1763,8 +1794,9 @@ public void oobchannels() { assertEquals(0, balancerRpcExecutor.numPendingTasks()); transportInfo.listener.transportReady(); assertEquals(1, balancerRpcExecutor.runDueTasks()); - verify(transportInfo.transport).newStream(same(method), same(headers), - same(CallOptions.DEFAULT)); + verify(transportInfo.transport).newStream( + same(method), same(headers), same(CallOptions.DEFAULT), + ArgumentMatchers.any()); // The transport goes away transportInfo.listener.transportShutdown(Status.UNAVAILABLE); @@ -1870,7 +1902,9 @@ public void oobChannelHasNoChannelCallCredentials() { ClientCall call = channel.newCall(method, callOptions); call.start(mockCallListener, headers); - verify(transportInfo.transport).newStream(same(method), same(headers), same(callOptions)); + verify(transportInfo.transport).newStream( + same(method), same(headers), same(callOptions), + ArgumentMatchers.any()); assertThat(headers.getAll(metadataKey)) .containsExactly(channelCredValue, callCredValue).inOrder(); @@ -1887,7 +1921,9 @@ public void oobChannelHasNoChannelCallCredentials() { transportInfo.listener.transportReady(); balancerRpcExecutor.runDueTasks(); - verify(transportInfo.transport).newStream(same(method), same(headers), same(callOptions)); + verify(transportInfo.transport).newStream( + same(method), same(headers), same(callOptions), + ArgumentMatchers.any()); assertThat(headers.getAll(metadataKey)).containsExactly(callCredValue); oob.shutdownNow(); @@ -1919,7 +1955,9 @@ public void oobChannelHasNoChannelCallCredentials() { call.start(mockCallListener2, headers); // CallOptions may contain StreamTracerFactory for census that is added by default. - verify(transportInfo.transport).newStream(same(method), same(headers), any(CallOptions.class)); + verify(transportInfo.transport).newStream( + same(method), same(headers), any(CallOptions.class), + ArgumentMatchers.any()); assertThat(headers.getAll(metadataKey)).containsExactly(callCredValue); oob.shutdownNow(); } @@ -1962,7 +2000,9 @@ public SwapChannelCredentialsResult answer(InvocationOnMock invocation) { ClientCall call = channel.newCall(method, callOptions); call.start(mockCallListener, headers); - verify(transportInfo.transport).newStream(same(method), same(headers), same(callOptions)); + verify(transportInfo.transport).newStream( + same(method), same(headers), same(callOptions), + ArgumentMatchers.any()); assertThat(headers.getAll(metadataKey)) .containsExactly(channelCredValue, callCredValue).inOrder(); @@ -1998,7 +2038,9 @@ public SwapChannelCredentialsResult answer(InvocationOnMock invocation) { call.start(mockCallListener2, headers); // CallOptions may contain StreamTracerFactory for census that is added by default. - verify(transportInfo.transport).newStream(same(method), same(headers), any(CallOptions.class)); + verify(transportInfo.transport).newStream( + same(method), same(headers), any(CallOptions.class), + ArgumentMatchers.any()); assertThat(headers.getAll(metadataKey)) .containsExactly(oobChannelCredValue, callCredValue).inOrder(); oob.shutdownNow(); @@ -2097,7 +2139,9 @@ public void subchannelChannel_normalUsage() { ClientCall call = sChannel.newCall(method, callOptions); call.start(mockCallListener, headers); - verify(mockTransport).newStream(same(method), same(headers), callOptionsCaptor.capture()); + verify(mockTransport).newStream( + same(method), same(headers), callOptionsCaptor.capture(), + ArgumentMatchers.any()); CallOptions capturedCallOption = callOptionsCaptor.getValue(); assertThat(capturedCallOption.getDeadline()).isSameInstanceAs(callOptions.getDeadline()); @@ -2125,7 +2169,8 @@ public void subchannelChannel_failWhenNotReady() { ClientCall call = sChannel.newCall(method, CallOptions.DEFAULT); call.start(mockCallListener, headers); verify(mockTransport, never()).newStream( - any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class)); + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); verifyNoInteractions(mockCallListener); assertEquals(1, balancerRpcExecutor.runDueTasks()); @@ -2157,7 +2202,8 @@ public void subchannelChannel_failWaitForReady() { sChannel.newCall(method, CallOptions.DEFAULT.withWaitForReady()); call.start(mockCallListener, headers); verify(mockTransport, never()).newStream( - any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class)); + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); verifyNoInteractions(mockCallListener); assertEquals(1, balancerRpcExecutor.runDueTasks()); @@ -2332,7 +2378,8 @@ public ClientStream answer(InvocationOnMock in) throws Throwable { return mock(ClientStream.class); } }).when(transport).newStream( - any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class)); + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); verify(creds, never()).applyRequestMetadata( any(RequestInfo.class), any(Executor.class), any(CallCredentials.MetadataApplier.class)); @@ -2351,11 +2398,14 @@ public ClientStream answer(InvocationOnMock in) throws Throwable { assertEquals(AUTHORITY, infoCaptor.getValue().getAuthority()); assertEquals(SecurityLevel.NONE, infoCaptor.getValue().getSecurityLevel()); verify(transport, never()).newStream( - any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class)); + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); // newStream() is called after apply() is called applierCaptor.getValue().apply(new Metadata()); - verify(transport).newStream(same(method), any(Metadata.class), same(callOptions)); + verify(transport).newStream( + same(method), any(Metadata.class), same(callOptions), + ArgumentMatchers.any()); assertEquals("testValue", testKey.get(newStreamContexts.poll())); // The context should not live beyond the scope of newStream() and applyRequestMetadata() assertNull(testKey.get()); @@ -2374,11 +2424,14 @@ public ClientStream answer(InvocationOnMock in) throws Throwable { assertEquals(SecurityLevel.NONE, infoCaptor.getValue().getSecurityLevel()); // This is from the first call verify(transport).newStream( - any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class)); + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); // Still, newStream() is called after apply() is called applierCaptor.getValue().apply(new Metadata()); - verify(transport, times(2)).newStream(same(method), any(Metadata.class), same(callOptions)); + verify(transport, times(2)).newStream( + same(method), any(Metadata.class), same(callOptions), + ArgumentMatchers.any()); assertEquals("testValue", testKey.get(newStreamContexts.poll())); assertNull(testKey.get()); @@ -2387,8 +2440,20 @@ public ClientStream answer(InvocationOnMock in) throws Throwable { @Test public void pickerReturnsStreamTracer_noDelay() { ClientStream mockStream = mock(ClientStream.class); - ClientStreamTracer.Factory factory1 = mock(ClientStreamTracer.Factory.class); - ClientStreamTracer.Factory factory2 = mock(ClientStreamTracer.Factory.class); + final ClientStreamTracer tracer1 = new ClientStreamTracer() {}; + final ClientStreamTracer tracer2 = new ClientStreamTracer() {}; + ClientStreamTracer.Factory factory1 = new ClientStreamTracer.InternalLimitedInfoFactory() { + @Override + public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata headers) { + return tracer1; + } + }; + ClientStreamTracer.Factory factory2 = new ClientStreamTracer.InternalLimitedInfoFactory() { + @Override + public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata headers) { + return tracer2; + } + }; createChannel(); Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY, subchannelStateListener); @@ -2397,7 +2462,8 @@ public void pickerReturnsStreamTracer_noDelay() { transportInfo.listener.transportReady(); ClientTransport mockTransport = transportInfo.transport; when(mockTransport.newStream( - any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class))) + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any())) .thenReturn(mockStream); when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))).thenReturn( @@ -2409,20 +2475,29 @@ public void pickerReturnsStreamTracer_noDelay() { call.start(mockCallListener, new Metadata()); verify(mockPicker).pickSubchannel(any(PickSubchannelArgs.class)); - verify(mockTransport).newStream(same(method), any(Metadata.class), callOptionsCaptor.capture()); - assertEquals( - Arrays.asList(factory1, factory2), - callOptionsCaptor.getValue().getStreamTracerFactories()); - // The factories are safely not stubbed because we do not expect any usage of them. - verifyNoInteractions(factory1); - verifyNoInteractions(factory2); + verify(mockTransport).newStream( + same(method), any(Metadata.class), callOptionsCaptor.capture(), + tracersCaptor.capture()); + assertThat(tracersCaptor.getValue()).isEqualTo(new ClientStreamTracer[] {tracer1, tracer2}); } @Test public void pickerReturnsStreamTracer_delayed() { ClientStream mockStream = mock(ClientStream.class); - ClientStreamTracer.Factory factory1 = mock(ClientStreamTracer.Factory.class); - ClientStreamTracer.Factory factory2 = mock(ClientStreamTracer.Factory.class); + final ClientStreamTracer tracer1 = new ClientStreamTracer() {}; + final ClientStreamTracer tracer2 = new ClientStreamTracer() {}; + ClientStreamTracer.Factory factory1 = new ClientStreamTracer.InternalLimitedInfoFactory() { + @Override + public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata headers) { + return tracer1; + } + }; + ClientStreamTracer.Factory factory2 = new ClientStreamTracer.InternalLimitedInfoFactory() { + @Override + public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata headers) { + return tracer2; + } + }; createChannel(); CallOptions callOptions = CallOptions.DEFAULT.withStreamTracerFactory(factory1); @@ -2436,7 +2511,8 @@ public void pickerReturnsStreamTracer_delayed() { transportInfo.listener.transportReady(); ClientTransport mockTransport = transportInfo.transport; when(mockTransport.newStream( - any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class))) + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any())) .thenReturn(mockStream); when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))).thenReturn( PickResult.withSubchannel(subchannel, factory2)); @@ -2445,13 +2521,10 @@ public void pickerReturnsStreamTracer_delayed() { assertEquals(1, executor.runDueTasks()); verify(mockPicker).pickSubchannel(any(PickSubchannelArgs.class)); - verify(mockTransport).newStream(same(method), any(Metadata.class), callOptionsCaptor.capture()); - assertEquals( - Arrays.asList(factory1, factory2), - callOptionsCaptor.getValue().getStreamTracerFactories()); - // The factories are safely not stubbed because we do not expect any usage of them. - verifyNoInteractions(factory1); - verifyNoInteractions(factory2); + verify(mockTransport).newStream( + same(method), any(Metadata.class), callOptionsCaptor.capture(), + tracersCaptor.capture()); + assertThat(tracersCaptor.getValue()).isEqualTo(new ClientStreamTracer[] {tracer1, tracer2}); } @Test @@ -2818,7 +2891,9 @@ public void idleMode_resetsDelayedTransportPicker() { MockClientTransportInfo transportInfo = transports.poll(); ConnectionClientTransport mockTransport = transportInfo.transport; ManagedClientTransport.Listener transportListener = transportInfo.listener; - when(mockTransport.newStream(same(method), any(Metadata.class), any(CallOptions.class))) + when(mockTransport.newStream( + same(method), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any())) .thenReturn(mockStream); transportListener.transportReady(); @@ -2829,7 +2904,9 @@ public void idleMode_resetsDelayedTransportPicker() { executor.runDueTasks(); // Verify the buffered call was drained - verify(mockTransport).newStream(same(method), any(Metadata.class), any(CallOptions.class)); + verify(mockTransport).newStream( + same(method), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); verify(mockStream).start(any(ClientStreamListener.class)); } @@ -2888,7 +2965,9 @@ public void enterIdle_exitsIdleIfDelayedStreamPending() { MockClientTransportInfo transportInfo = transports.poll(); ConnectionClientTransport mockTransport = transportInfo.transport; ManagedClientTransport.Listener transportListener = transportInfo.listener; - when(mockTransport.newStream(same(method), any(Metadata.class), any(CallOptions.class))) + when(mockTransport.newStream( + same(method), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any())) .thenReturn(mockStream); transportListener.transportReady(); when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))) @@ -2898,7 +2977,9 @@ public void enterIdle_exitsIdleIfDelayedStreamPending() { // Verify the original call was drained executor.runDueTasks(); - verify(mockTransport).newStream(same(method), any(Metadata.class), any(CallOptions.class)); + verify(mockTransport).newStream( + same(method), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); verify(mockStream).start(any(ClientStreamListener.class)); } @@ -2920,7 +3001,9 @@ public void updateBalancingStateDoesUpdatePicker() { MockClientTransportInfo transportInfo = transports.poll(); ConnectionClientTransport mockTransport = transportInfo.transport; ManagedClientTransport.Listener transportListener = transportInfo.listener; - when(mockTransport.newStream(same(method), any(Metadata.class), any(CallOptions.class))) + when(mockTransport.newStream( + same(method), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any())) .thenReturn(mockStream); transportListener.transportReady(); @@ -2929,8 +3012,9 @@ public void updateBalancingStateDoesUpdatePicker() { updateBalancingStateSafely(helper, READY, mockPicker); executor.runDueTasks(); - verify(mockTransport, never()) - .newStream(any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class)); + verify(mockTransport, never()).newStream( + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); verify(mockStream, never()).start(any(ClientStreamListener.class)); @@ -2939,7 +3023,9 @@ public void updateBalancingStateDoesUpdatePicker() { updateBalancingStateSafely(helper, READY, mockPicker); executor.runDueTasks(); - verify(mockTransport).newStream(same(method), any(Metadata.class), any(CallOptions.class)); + verify(mockTransport).newStream( + same(method), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); verify(mockStream).start(any(ClientStreamListener.class)); } @@ -2958,7 +3044,9 @@ public void updateBalancingState_withWrappedSubchannel() { MockClientTransportInfo transportInfo = transports.poll(); ConnectionClientTransport mockTransport = transportInfo.transport; ManagedClientTransport.Listener transportListener = transportInfo.listener; - when(mockTransport.newStream(same(method), any(Metadata.class), any(CallOptions.class))) + when(mockTransport.newStream( + same(method), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any())) .thenReturn(mockStream); transportListener.transportReady(); @@ -2973,7 +3061,9 @@ protected Subchannel delegate() { updateBalancingStateSafely(helper, READY, mockPicker); executor.runDueTasks(); - verify(mockTransport).newStream(same(method), any(Metadata.class), any(CallOptions.class)); + verify(mockTransport).newStream( + same(method), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); verify(mockStream).start(any(ClientStreamListener.class)); } @@ -3405,7 +3495,8 @@ private void channelsAndSubchannels_instrumented0(boolean success) throws Except transportInfo.listener.transportReady(); ClientTransport mockTransport = transportInfo.transport; when(mockTransport.newStream( - any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class))) + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any())) .thenReturn(mockStream); when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))).thenReturn( PickResult.withSubchannel(subchannel, factory)); @@ -3478,7 +3569,9 @@ private void channelsAndSubchannels_oob_instrumented0(boolean success) throws Ex MockClientTransportInfo transportInfo = transports.poll(); ConnectionClientTransport mockTransport = transportInfo.transport; ManagedClientTransport.Listener transportListener = transportInfo.listener; - when(mockTransport.newStream(same(method), same(headers), any(CallOptions.class))) + when(mockTransport.newStream( + same(method), same(headers), any(CallOptions.class), + ArgumentMatchers.any())) .thenReturn(mockStream); // subchannel stat bumped when call gets assigned to it @@ -3650,7 +3743,9 @@ public double nextDouble() { ConnectionClientTransport mockTransport = transportInfo.transport; ClientStream mockStream = mock(ClientStream.class); ClientStream mockStream2 = mock(ClientStream.class); - when(mockTransport.newStream(same(method), any(Metadata.class), any(CallOptions.class))) + when(mockTransport.newStream( + same(method), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any())) .thenReturn(mockStream).thenReturn(mockStream2); transportInfo.listener.transportReady(); updateBalancingStateSafely(helper, READY, mockPicker); @@ -3754,7 +3849,9 @@ public void hedgingScheduledThenChannelShutdown_hedgeShouldStillHappen_newCallSh ConnectionClientTransport mockTransport = transportInfo.transport; ClientStream mockStream = mock(ClientStream.class); ClientStream mockStream2 = mock(ClientStream.class); - when(mockTransport.newStream(same(method), any(Metadata.class), any(CallOptions.class))) + when(mockTransport.newStream( + same(method), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any())) .thenReturn(mockStream).thenReturn(mockStream2); transportInfo.listener.transportReady(); updateBalancingStateSafely(helper, READY, mockPicker); diff --git a/core/src/test/java/io/grpc/internal/RetriableStreamTest.java b/core/src/test/java/io/grpc/internal/RetriableStreamTest.java index a83964b5e91..26c6fcf9b4e 100644 --- a/core/src/test/java/io/grpc/internal/RetriableStreamTest.java +++ b/core/src/test/java/io/grpc/internal/RetriableStreamTest.java @@ -163,9 +163,10 @@ void postCommit() { } @Override - ClientStream newSubstream(ClientStreamTracer.Factory tracerFactory, Metadata metadata) { + ClientStream newSubstream( + Metadata metadata, ClientStreamTracer.Factory tracerFactory, boolean isTransparentRetry) { bufferSizeTracer = - tracerFactory.newClientStreamTracer(STREAM_INFO, new Metadata()); + tracerFactory.newClientStreamTracer(STREAM_INFO, metadata); int actualPreviousRpcAttemptsInHeader = metadata.get(GRPC_PREVIOUS_RPC_ATTEMPTS) == null ? 0 : Integer.valueOf(metadata.get(GRPC_PREVIOUS_RPC_ATTEMPTS)); return retriableStreamRecorder.newSubstream(actualPreviousRpcAttemptsInHeader); diff --git a/core/src/test/java/io/grpc/internal/TestUtils.java b/core/src/test/java/io/grpc/internal/TestUtils.java index d5b4ce4949e..974f36e595c 100644 --- a/core/src/test/java/io/grpc/internal/TestUtils.java +++ b/core/src/test/java/io/grpc/internal/TestUtils.java @@ -23,6 +23,7 @@ import io.grpc.CallOptions; import io.grpc.ChannelLogger; +import io.grpc.ClientStreamTracer; import io.grpc.InternalLogId; import io.grpc.LoadBalancer.PickResult; import io.grpc.LoadBalancer.PickSubchannelArgs; @@ -35,6 +36,7 @@ import java.util.concurrent.BlockingQueue; import java.util.concurrent.LinkedBlockingQueue; import javax.annotation.Nullable; +import org.mockito.ArgumentMatchers; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; @@ -118,7 +120,8 @@ public ConnectionClientTransport answer(InvocationOnMock invocation) throws Thro when(mockTransport.getLogId()) .thenReturn(InternalLogId.allocate("mocktransport", /*details=*/ null)); when(mockTransport.newStream( - any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class))) + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any())) .thenReturn(mock(ClientStream.class)); // Save the listener doAnswer(new Answer() { diff --git a/core/src/test/java/io/grpc/util/ForwardingClientStreamTracerTest.java b/core/src/test/java/io/grpc/util/ForwardingClientStreamTracerTest.java index fcb19b69eb8..dbd7e99b29a 100644 --- a/core/src/test/java/io/grpc/util/ForwardingClientStreamTracerTest.java +++ b/core/src/test/java/io/grpc/util/ForwardingClientStreamTracerTest.java @@ -40,6 +40,7 @@ public void allMethodsForwarded() throws Exception { Collections.emptyList()); } + @SuppressWarnings("deprecation") private final class TestClientStreamTracer extends ForwardingClientStreamTracer { @Override protected ClientStreamTracer delegate() { diff --git a/cronet/src/main/java/io/grpc/cronet/CronetClientTransport.java b/cronet/src/main/java/io/grpc/cronet/CronetClientTransport.java index dc4fc45ae4e..d41ec372d4c 100644 --- a/cronet/src/main/java/io/grpc/cronet/CronetClientTransport.java +++ b/cronet/src/main/java/io/grpc/cronet/CronetClientTransport.java @@ -21,6 +21,7 @@ import com.google.common.util.concurrent.SettableFuture; import io.grpc.Attributes; import io.grpc.CallOptions; +import io.grpc.ClientStreamTracer; import io.grpc.InternalChannelz.SocketStats; import io.grpc.InternalLogId; import io.grpc.Metadata; @@ -118,7 +119,7 @@ public ListenableFuture getStats() { @Override public CronetClientStream newStream(final MethodDescriptor method, final Metadata headers, - final CallOptions callOptions) { + final CallOptions callOptions, ClientStreamTracer[] tracers) { Preconditions.checkNotNull(method, "method"); Preconditions.checkNotNull(headers, "headers"); @@ -126,7 +127,7 @@ public CronetClientStream newStream(final MethodDescriptor method, final M final String url = "https://" + authority + defaultPath; final StatsTraceContext statsTraceCtx = - StatsTraceContext.newClientContext(callOptions, attrs, headers); + StatsTraceContext.newClientContext(tracers, attrs, headers); class StartCallback implements Runnable { final CronetClientStream clientStream = new CronetClientStream( url, userAgent, executor, headers, CronetClientTransport.this, this, lock, maxMessageSize, diff --git a/cronet/src/test/java/io/grpc/cronet/CronetChannelBuilderTest.java b/cronet/src/test/java/io/grpc/cronet/CronetChannelBuilderTest.java index 39fe03991e4..c27963c6d56 100644 --- a/cronet/src/test/java/io/grpc/cronet/CronetChannelBuilderTest.java +++ b/cronet/src/test/java/io/grpc/cronet/CronetChannelBuilderTest.java @@ -25,6 +25,7 @@ import android.os.Build; import io.grpc.CallOptions; import io.grpc.ChannelLogger; +import io.grpc.ClientStreamTracer; import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.cronet.CronetChannelBuilder.CronetTransportFactory; @@ -50,6 +51,8 @@ public final class CronetChannelBuilderTest { @Mock private ExperimentalCronetEngine mockEngine; @Mock private ChannelLogger channelLogger; + private final ClientStreamTracer[] tracers = + new ClientStreamTracer[]{ new ClientStreamTracer() {} }; private MethodDescriptor method = TestMethodDescriptors.voidMethod(); @Before @@ -69,7 +72,8 @@ public void alwaysUsePutTrue_cronetStreamIsIdempotent() throws Exception { new InetSocketAddress("localhost", 443), new ClientTransportOptions(), channelLogger); - CronetClientStream stream = transport.newStream(method, new Metadata(), CallOptions.DEFAULT); + CronetClientStream stream = transport.newStream( + method, new Metadata(), CallOptions.DEFAULT, tracers); assertTrue(stream.idempotent); } @@ -85,7 +89,8 @@ public void alwaysUsePut_defaultsToFalse() throws Exception { new InetSocketAddress("localhost", 443), new ClientTransportOptions(), channelLogger); - CronetClientStream stream = transport.newStream(method, new Metadata(), CallOptions.DEFAULT); + CronetClientStream stream = transport.newStream( + method, new Metadata(), CallOptions.DEFAULT, tracers); assertFalse(stream.idempotent); } diff --git a/cronet/src/test/java/io/grpc/cronet/CronetClientTransportTest.java b/cronet/src/test/java/io/grpc/cronet/CronetClientTransportTest.java index 50017cb43f8..9503481e747 100644 --- a/cronet/src/test/java/io/grpc/cronet/CronetClientTransportTest.java +++ b/cronet/src/test/java/io/grpc/cronet/CronetClientTransportTest.java @@ -27,6 +27,7 @@ import android.os.Build; import io.grpc.Attributes; import io.grpc.CallOptions; +import io.grpc.ClientStreamTracer; import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.SecurityLevel; @@ -60,6 +61,8 @@ public final class CronetClientTransportTest { private static final Attributes EAG_ATTRS = Attributes.newBuilder().set(EAG_ATTR_KEY, "value").build(); + private final ClientStreamTracer[] tracers = + new ClientStreamTracer[]{ new ClientStreamTracer() {} }; private CronetClientTransport transport; @Mock private StreamBuilderFactory streamFactory; @Mock private Executor executor; @@ -101,9 +104,9 @@ public void transportAttributes() { @Test public void shutdownTransport() throws Exception { CronetClientStream stream1 = - transport.newStream(descriptor, new Metadata(), CallOptions.DEFAULT); + transport.newStream(descriptor, new Metadata(), CallOptions.DEFAULT, tracers); CronetClientStream stream2 = - transport.newStream(descriptor, new Metadata(), CallOptions.DEFAULT); + transport.newStream(descriptor, new Metadata(), CallOptions.DEFAULT, tracers); // Create a transport and start two streams on it. ArgumentCaptor callbackCaptor = @@ -137,7 +140,7 @@ public void shutdownTransport() throws Exception { @Test public void startStreamAfterShutdown() throws Exception { CronetClientStream stream = - transport.newStream(descriptor, new Metadata(), CallOptions.DEFAULT); + transport.newStream(descriptor, new Metadata(), CallOptions.DEFAULT, tracers); transport.shutdown(); BaseClientStreamListener listener = new BaseClientStreamListener(); stream.start(listener); diff --git a/grpclb/src/main/java/io/grpc/grpclb/GrpclbClientLoadRecorder.java b/grpclb/src/main/java/io/grpc/grpclb/GrpclbClientLoadRecorder.java index d27c485dc13..75f2481254d 100644 --- a/grpclb/src/main/java/io/grpc/grpclb/GrpclbClientLoadRecorder.java +++ b/grpclb/src/main/java/io/grpc/grpclb/GrpclbClientLoadRecorder.java @@ -37,7 +37,7 @@ * span of an LB stream with the remote load-balancer. */ @ThreadSafe -final class GrpclbClientLoadRecorder extends ClientStreamTracer.Factory { +final class GrpclbClientLoadRecorder extends ClientStreamTracer.InternalLimitedInfoFactory { private static final AtomicLongFieldUpdater callsStartedUpdater = AtomicLongFieldUpdater.newUpdater(GrpclbClientLoadRecorder.class, "callsStarted"); diff --git a/grpclb/src/main/java/io/grpc/grpclb/TokenAttachingTracerFactory.java b/grpclb/src/main/java/io/grpc/grpclb/TokenAttachingTracerFactory.java index 03b9bdf7f1b..03e1447bb2c 100644 --- a/grpclb/src/main/java/io/grpc/grpclb/TokenAttachingTracerFactory.java +++ b/grpclb/src/main/java/io/grpc/grpclb/TokenAttachingTracerFactory.java @@ -22,6 +22,7 @@ import io.grpc.Attributes; import io.grpc.ClientStreamTracer; import io.grpc.Metadata; +import io.grpc.internal.ForwardingClientStreamTracer; import io.grpc.internal.GrpcAttributes; import javax.annotation.Nullable; @@ -29,7 +30,7 @@ * Wraps a {@link ClientStreamTracer.Factory}, retrieves tokens from transport attributes and * attaches them to headers. This is only used in the PICK_FIRST mode. */ -final class TokenAttachingTracerFactory extends ClientStreamTracer.Factory { +final class TokenAttachingTracerFactory extends ClientStreamTracer.InternalLimitedInfoFactory { private static final ClientStreamTracer NOOP_TRACER = new ClientStreamTracer() {}; @Nullable @@ -42,19 +43,30 @@ final class TokenAttachingTracerFactory extends ClientStreamTracer.Factory { @Override public ClientStreamTracer newClientStreamTracer( ClientStreamTracer.StreamInfo info, Metadata headers) { - Attributes transportAttrs = checkNotNull(info.getTransportAttrs(), "transportAttrs"); - Attributes eagAttrs = - checkNotNull(transportAttrs.get(GrpcAttributes.ATTR_CLIENT_EAG_ATTRS), "eagAttrs"); - String token = eagAttrs.get(GrpclbConstants.TOKEN_ATTRIBUTE_KEY); - headers.discardAll(GrpclbConstants.TOKEN_METADATA_KEY); - if (token != null) { - headers.put(GrpclbConstants.TOKEN_METADATA_KEY, token); - } - if (delegate != null) { - return delegate.newClientStreamTracer(info, headers); - } else { + if (delegate == null) { return NOOP_TRACER; } + final ClientStreamTracer clientStreamTracer = delegate.newClientStreamTracer(info, headers); + class TokenPropagationTracer extends ForwardingClientStreamTracer { + @Override + protected ClientStreamTracer delegate() { + return clientStreamTracer; + } + + @Override + public void streamCreated(Attributes transportAttrs, Metadata headers) { + Attributes eagAttrs = + checkNotNull(transportAttrs.get(GrpcAttributes.ATTR_CLIENT_EAG_ATTRS), "eagAttrs"); + String token = eagAttrs.get(GrpclbConstants.TOKEN_ATTRIBUTE_KEY); + headers.discardAll(GrpclbConstants.TOKEN_METADATA_KEY); + if (token != null) { + headers.put(GrpclbConstants.TOKEN_METADATA_KEY, token); + } + delegate().streamCreated(transportAttrs, headers); + } + } + + return new TokenPropagationTracer(); } @Override diff --git a/grpclb/src/test/java/io/grpc/grpclb/GrpclbLoadBalancerTest.java b/grpclb/src/test/java/io/grpc/grpclb/GrpclbLoadBalancerTest.java index 39f736dbcf4..a68962ad7d9 100644 --- a/grpclb/src/test/java/io/grpc/grpclb/GrpclbLoadBalancerTest.java +++ b/grpclb/src/test/java/io/grpc/grpclb/GrpclbLoadBalancerTest.java @@ -481,6 +481,7 @@ public void loadReporting() { ClientStreamTracer tracer1 = pick1.getStreamTracerFactory().newClientStreamTracer(STREAM_INFO, new Metadata()); + tracer1.streamCreated(Attributes.EMPTY, new Metadata()); PickResult pick2 = picker.pickSubchannel(args); assertNull(pick2.getSubchannel()); @@ -504,6 +505,7 @@ public void loadReporting() { assertSame(getLoadRecorder(), pick3.getStreamTracerFactory()); ClientStreamTracer tracer3 = pick3.getStreamTracerFactory().newClientStreamTracer(STREAM_INFO, new Metadata()); + tracer3.streamCreated(Attributes.EMPTY, new Metadata()); // pick3 has sent out headers tracer3.outboundHeaders(); @@ -541,6 +543,7 @@ public void loadReporting() { assertSame(getLoadRecorder(), pick5.getStreamTracerFactory()); ClientStreamTracer tracer5 = pick5.getStreamTracerFactory().newClientStreamTracer(STREAM_INFO, new Metadata()); + tracer5.streamCreated(Attributes.EMPTY, new Metadata()); // pick3 ended without receiving response headers tracer3.streamClosed(Status.DEADLINE_EXCEEDED); diff --git a/grpclb/src/test/java/io/grpc/grpclb/TokenAttachingTracerFactoryTest.java b/grpclb/src/test/java/io/grpc/grpclb/TokenAttachingTracerFactoryTest.java index 34b0a8ea1aa..29ded18d913 100644 --- a/grpclb/src/test/java/io/grpc/grpclb/TokenAttachingTracerFactoryTest.java +++ b/grpclb/src/test/java/io/grpc/grpclb/TokenAttachingTracerFactoryTest.java @@ -33,12 +33,23 @@ /** Unit tests for {@link TokenAttachingTracerFactory}. */ @RunWith(JUnit4.class) public class TokenAttachingTracerFactoryTest { - private static final ClientStreamTracer fakeTracer = new ClientStreamTracer() {}; + private static final class FakeClientStreamTracer extends ClientStreamTracer { + Attributes transportAttrs; + Metadata headers; + + @Override + public void streamCreated(Attributes transportAttrs, Metadata headers) { + this.transportAttrs = transportAttrs; + this.headers = headers; + } + } + + private static final FakeClientStreamTracer fakeTracer = new FakeClientStreamTracer(); private final ClientStreamTracer.Factory delegate = mock( ClientStreamTracer.Factory.class, delegatesTo( - new ClientStreamTracer.Factory() { + new ClientStreamTracer.InternalLimitedInfoFactory() { @Override public ClientStreamTracer newClientStreamTracer( ClientStreamTracer.StreamInfo info, Metadata headers) { @@ -51,28 +62,25 @@ public void hasToken() { TokenAttachingTracerFactory factory = new TokenAttachingTracerFactory(delegate); Attributes eagAttrs = Attributes.newBuilder() .set(GrpclbConstants.TOKEN_ATTRIBUTE_KEY, "token0001").build(); - ClientStreamTracer.StreamInfo info = ClientStreamTracer.StreamInfo.newBuilder() - .setTransportAttrs( - Attributes.newBuilder().set(GrpcAttributes.ATTR_CLIENT_EAG_ATTRS, eagAttrs).build()) - .build(); + ClientStreamTracer.StreamInfo info = ClientStreamTracer.StreamInfo.newBuilder().build(); Metadata headers = new Metadata(); // Preexisting token should be replaced headers.put(GrpclbConstants.TOKEN_METADATA_KEY, "preexisting-token"); ClientStreamTracer tracer = factory.newClientStreamTracer(info, headers); verify(delegate).newClientStreamTracer(same(info), same(headers)); - assertThat(tracer).isSameInstanceAs(fakeTracer); + Attributes transportAttrs = + Attributes.newBuilder().set(GrpcAttributes.ATTR_CLIENT_EAG_ATTRS, eagAttrs).build(); + tracer.streamCreated(transportAttrs, headers); + assertThat(fakeTracer.transportAttrs).isSameInstanceAs(transportAttrs); + assertThat(fakeTracer.headers).isSameInstanceAs(headers); assertThat(headers.getAll(GrpclbConstants.TOKEN_METADATA_KEY)).containsExactly("token0001"); } @Test public void noToken() { TokenAttachingTracerFactory factory = new TokenAttachingTracerFactory(delegate); - ClientStreamTracer.StreamInfo info = ClientStreamTracer.StreamInfo.newBuilder() - .setTransportAttrs( - Attributes.newBuilder() - .set(GrpcAttributes.ATTR_CLIENT_EAG_ATTRS, Attributes.EMPTY).build()) - .build(); + ClientStreamTracer.StreamInfo info = ClientStreamTracer.StreamInfo.newBuilder().build(); Metadata headers = new Metadata(); // Preexisting token should be removed @@ -80,22 +88,25 @@ public void noToken() { ClientStreamTracer tracer = factory.newClientStreamTracer(info, headers); verify(delegate).newClientStreamTracer(same(info), same(headers)); - assertThat(tracer).isSameInstanceAs(fakeTracer); + Attributes transportAttrs = + Attributes.newBuilder().set(GrpcAttributes.ATTR_CLIENT_EAG_ATTRS, Attributes.EMPTY).build(); + tracer.streamCreated(transportAttrs, headers); + assertThat(fakeTracer.transportAttrs).isSameInstanceAs(transportAttrs); + assertThat(fakeTracer.headers).isSameInstanceAs(headers); assertThat(headers.get(GrpclbConstants.TOKEN_METADATA_KEY)).isNull(); } @Test public void nullDelegate() { TokenAttachingTracerFactory factory = new TokenAttachingTracerFactory(null); - ClientStreamTracer.StreamInfo info = ClientStreamTracer.StreamInfo.newBuilder() - .setTransportAttrs( - Attributes.newBuilder() - .set(GrpcAttributes.ATTR_CLIENT_EAG_ATTRS, Attributes.EMPTY).build()) - .build(); + ClientStreamTracer.StreamInfo info = ClientStreamTracer.StreamInfo.newBuilder().build(); Metadata headers = new Metadata(); ClientStreamTracer tracer = factory.newClientStreamTracer(info, headers); + tracer.streamCreated( + Attributes.newBuilder().set(GrpcAttributes.ATTR_CLIENT_EAG_ATTRS, Attributes.EMPTY).build(), + headers); assertThat(tracer).isNotNull(); assertThat(headers.get(GrpclbConstants.TOKEN_METADATA_KEY)).isNull(); } diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java b/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java index 758f99d5353..1b447a63c32 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java @@ -289,7 +289,7 @@ final SocketAddress getListenAddress() { new LinkedBlockingQueue<>(); private final ClientStreamTracer.Factory clientStreamTracerFactory = - new ClientStreamTracer.Factory() { + new ClientStreamTracer.InternalLimitedInfoFactory() { @Override public ClientStreamTracer newClientStreamTracer( ClientStreamTracer.StreamInfo info, Metadata headers) { @@ -375,7 +375,8 @@ protected final ClientInterceptor createCensusStatsClientInterceptor() { .getClientInterceptor( tagger, tagContextBinarySerializer, clientStatsRecorder, GrpcUtil.STOPWATCH_SUPPLIER, - true, true, true, false /* real-time metrics */); + true, true, true, + /* recordRealTimeMetrics= */ false); } protected final ServerStreamTracer.Factory createCustomCensusTracerFactory() { @@ -1179,6 +1180,7 @@ public void deadlineExceeded() throws Exception { public void deadlineExceededServerStreaming() throws Exception { // warm up the channel and JVM blockingStub.emptyCall(Empty.getDefaultInstance()); + assertStatsTrace("grpc.testing.TestService/EmptyCall", Status.Code.OK); ResponseParameters.Builder responseParameters = ResponseParameters.newBuilder() .setSize(1) .setIntervalUs(10000); @@ -1195,7 +1197,6 @@ public void deadlineExceededServerStreaming() throws Exception { recorder.awaitCompletion(); assertEquals(Status.DEADLINE_EXCEEDED.getCode(), Status.fromThrowable(recorder.getError()).getCode()); - assertStatsTrace("grpc.testing.TestService/EmptyCall", Status.Code.OK); if (metricsExpected()) { // Stream may not have been created when deadline is exceeded, thus we don't check tracer // stats. @@ -1239,6 +1240,12 @@ public void deadlineInPast() throws Exception { // warm up the channel blockingStub.emptyCall(Empty.getDefaultInstance()); + if (metricsExpected()) { + // clientStartRecord + clientStatsRecorder.pollRecord(5, TimeUnit.SECONDS); + // clientEndRecord + clientStatsRecorder.pollRecord(5, TimeUnit.SECONDS); + } try { blockingStub .withDeadlineAfter(-10, TimeUnit.SECONDS) @@ -1249,7 +1256,6 @@ public void deadlineInPast() throws Exception { assertThat(ex.getStatus().getDescription()) .startsWith("ClientCall started after deadline exceeded"); } - assertStatsTrace("grpc.testing.TestService/EmptyCall", Status.Code.OK); if (metricsExpected()) { MetricsRecord clientStartRecord = clientStatsRecorder.pollRecord(5, TimeUnit.SECONDS); checkStartTags(clientStartRecord, "grpc.testing.TestService/EmptyCall", true); diff --git a/netty/src/main/java/io/grpc/netty/NettyClientTransport.java b/netty/src/main/java/io/grpc/netty/NettyClientTransport.java index c3807986c9f..a7a1044059c 100644 --- a/netty/src/main/java/io/grpc/netty/NettyClientTransport.java +++ b/netty/src/main/java/io/grpc/netty/NettyClientTransport.java @@ -28,6 +28,7 @@ import io.grpc.Attributes; import io.grpc.CallOptions; import io.grpc.ChannelLogger; +import io.grpc.ClientStreamTracer; import io.grpc.InternalChannelz.SocketStats; import io.grpc.InternalLogId; import io.grpc.Metadata; @@ -167,14 +168,15 @@ public void operationComplete(ChannelFuture future) throws Exception { @Override public ClientStream newStream( - MethodDescriptor method, Metadata headers, CallOptions callOptions) { + MethodDescriptor method, Metadata headers, CallOptions callOptions, + ClientStreamTracer[] tracers) { Preconditions.checkNotNull(method, "method"); Preconditions.checkNotNull(headers, "headers"); if (channel == null) { - return new FailingClientStream(statusExplainingWhyTheChannelIsNull); + return new FailingClientStream(statusExplainingWhyTheChannelIsNull, tracers); } StatsTraceContext statsTraceCtx = - StatsTraceContext.newClientContext(callOptions, getAttributes(), headers); + StatsTraceContext.newClientContext(tracers, getAttributes(), headers); return new NettyClientStream( new NettyClientStream.TransportState( handler, diff --git a/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java b/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java index e4165e89243..018ca9b6594 100644 --- a/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java @@ -41,6 +41,7 @@ import io.grpc.Attributes; import io.grpc.CallOptions; import io.grpc.ChannelLogger; +import io.grpc.ClientStreamTracer; import io.grpc.Grpc; import io.grpc.InternalChannelz; import io.grpc.Metadata; @@ -828,7 +829,9 @@ private static class Rpc { } Rpc(NettyClientTransport transport, Metadata headers) { - stream = transport.newStream(METHOD, headers, CallOptions.DEFAULT); + stream = transport.newStream( + METHOD, headers, CallOptions.DEFAULT, + new ClientStreamTracer[]{ new ClientStreamTracer() {} }); stream.start(listener); stream.request(1); stream.writeMessage(new ByteArrayInputStream(MESSAGE.getBytes(UTF_8))); diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java index a001ddb73e7..121093716db 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java @@ -34,6 +34,7 @@ import com.squareup.okhttp.internal.http.StatusLine; import io.grpc.Attributes; import io.grpc.CallOptions; +import io.grpc.ClientStreamTracer; import io.grpc.Grpc; import io.grpc.HttpConnectProxiedSocketAddress; import io.grpc.InternalChannelz; @@ -387,12 +388,13 @@ public void ping(final PingCallback callback, Executor executor) { } @Override - public OkHttpClientStream newStream(final MethodDescriptor method, - final Metadata headers, CallOptions callOptions) { + public OkHttpClientStream newStream( + MethodDescriptor method, Metadata headers, CallOptions callOptions, + ClientStreamTracer[] tracers) { Preconditions.checkNotNull(method, "method"); Preconditions.checkNotNull(headers, "headers"); - StatsTraceContext statsTraceCtx = - StatsTraceContext.newClientContext(callOptions, attributes, headers); + StatsTraceContext statsTraceContext = + StatsTraceContext.newClientContext(tracers, getAttributes(), headers); // FIXME: it is likely wrong to pass the transportTracer here as it'll exit the lock's scope synchronized (lock) { // to make @GuardedBy linter happy return new OkHttpClientStream( @@ -406,7 +408,7 @@ public OkHttpClientStream newStream(final MethodDescriptor method, initialWindowSize, defaultAuthority, userAgent, - statsTraceCtx, + statsTraceContext, transportTracer, callOptions, useGetForSafeMethods); diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java index b03a2dedc00..b70b832a797 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java @@ -56,6 +56,7 @@ import com.google.common.util.concurrent.SettableFuture; import io.grpc.Attributes; import io.grpc.CallOptions; +import io.grpc.ClientStreamTracer; import io.grpc.HttpConnectProxiedSocketAddress; import io.grpc.InternalChannelz.SocketStats; import io.grpc.InternalChannelz.TransportStats; @@ -146,6 +147,9 @@ public class OkHttpClientTransportTest { private static final int DEFAULT_MAX_INBOUND_METADATA_SIZE = Integer.MAX_VALUE; private static final Attributes EAG_ATTRS = Attributes.EMPTY; private static final Logger logger = Logger.getLogger(OkHttpClientTransport.class.getName()); + private static final ClientStreamTracer[] tracers = new ClientStreamTracer[] { + new ClientStreamTracer() {} + }; @Rule public final Timeout globalTimeout = Timeout.seconds(10); @@ -299,7 +303,7 @@ public void close() throws SecurityException { MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); @@ -387,7 +391,7 @@ public void maxMessageSizeShouldBeEnforced() throws Exception { MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); assertContainStream(3); @@ -443,11 +447,11 @@ public void nextFrameThrowIoException() throws Exception { MockStreamListener listener1 = new MockStreamListener(); MockStreamListener listener2 = new MockStreamListener(); OkHttpClientStream stream1 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); stream1.request(1); OkHttpClientStream stream2 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); stream2.request(1); assertEquals(2, activeStreamCount()); @@ -477,7 +481,7 @@ public void nextFrameThrowsError() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); assertEquals(1, activeStreamCount()); @@ -498,7 +502,7 @@ public void nextFrameReturnFalse() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); frameReader.nextFrameAtEndOfStream(); @@ -516,7 +520,7 @@ public void readMessages() throws Exception { final String message = "Hello Client"; MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(numMessages); assertContainStream(3); @@ -566,7 +570,7 @@ public void invalidInboundHeadersCancelStream() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); assertContainStream(3); @@ -590,7 +594,7 @@ public void invalidInboundTrailersPropagateToMetadata() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); assertContainStream(3); @@ -610,7 +614,7 @@ public void readStatus() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); assertContainStream(3); frameHandler().headers(true, true, 3, 0, grpcResponseTrailers(), HeadersMode.HTTP_20_HEADERS); @@ -624,7 +628,7 @@ public void receiveReset() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); assertContainStream(3); frameHandler().rstStream(3, ErrorCode.PROTOCOL_ERROR); @@ -641,7 +645,7 @@ public void receiveResetNoError() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); assertContainStream(3); frameHandler().headers(false, false, 3, 0, grpcResponseHeaders(), HeadersMode.HTTP_20_HEADERS); @@ -661,7 +665,7 @@ public void cancelStream() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); getStream(3).cancel(Status.CANCELLED); verify(frameWriter, timeout(TIME_OUT_MS)).rstStream(eq(3), eq(ErrorCode.CANCEL)); @@ -676,7 +680,7 @@ public void addDefaultUserAgent() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); Header userAgentHeader = new Header(GrpcUtil.USER_AGENT_KEY.name(), GrpcUtil.getGrpcUserAgent("okhttp", null)); @@ -695,7 +699,7 @@ public void overrideDefaultUserAgent() throws Exception { startTransport(3, null, true, DEFAULT_MAX_MESSAGE_SIZE, INITIAL_WINDOW_SIZE, "fakeUserAgent"); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); List
expectedHeaders = Arrays.asList(HTTP_SCHEME_HEADER, METHOD_HEADER, new Header(Header.TARGET_AUTHORITY, "notarealauthority:80"), @@ -714,7 +718,7 @@ public void cancelStreamForDeadlineExceeded() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); getStream(3).cancel(Status.DEADLINE_EXCEEDED); verify(frameWriter, timeout(TIME_OUT_MS)).rstStream(eq(3), eq(ErrorCode.CANCEL)); @@ -728,7 +732,7 @@ public void writeMessage() throws Exception { final String message = "Hello Server"; MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); InputStream input = new ByteArrayInputStream(message.getBytes(UTF_8)); assertEquals(12, input.available()); @@ -772,12 +776,12 @@ public void windowUpdate() throws Exception { MockStreamListener listener1 = new MockStreamListener(); MockStreamListener listener2 = new MockStreamListener(); OkHttpClientStream stream1 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); stream1.request(2); OkHttpClientStream stream2 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); stream2.request(2); assertEquals(2, activeStreamCount()); @@ -838,7 +842,7 @@ public void windowUpdateWithInboundFlowControl() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); int messageLength = INITIAL_WINDOW_SIZE / 2 + 1; byte[] fakeMessage = new byte[messageLength]; @@ -874,7 +878,7 @@ public void outboundFlowControl() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); // Outbound window always starts at 65535 until changed by Settings.INITIAL_WINDOW_SIZE @@ -920,7 +924,7 @@ public void outboundFlowControl_smallWindowSize() throws Exception { MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); int messageLength = 75; @@ -963,7 +967,7 @@ public void outboundFlowControl_bigWindowSize() throws Exception { MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); int messageLength = 100000; @@ -999,7 +1003,7 @@ public void outboundFlowControlWithInitialWindowSizeChange() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); int messageLength = 20; setInitialWindowSize(HEADER_LENGTH + 10); @@ -1045,7 +1049,7 @@ public void outboundFlowControlWithInitialWindowSizeChangeInMiddleOfStream() thr initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); int messageLength = 20; setInitialWindowSize(HEADER_LENGTH + 10); @@ -1080,10 +1084,10 @@ public void stopNormally() throws Exception { MockStreamListener listener1 = new MockStreamListener(); MockStreamListener listener2 = new MockStreamListener(); OkHttpClientStream stream1 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); OkHttpClientStream stream2 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); assertEquals(2, activeStreamCount()); clientTransport.shutdown(SHUTDOWN_REASON); @@ -1110,11 +1114,11 @@ public void receiveGoAway() throws Exception { MockStreamListener listener1 = new MockStreamListener(); MockStreamListener listener2 = new MockStreamListener(); OkHttpClientStream stream1 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); stream1.request(1); OkHttpClientStream stream2 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); stream2.request(1); assertEquals(2, activeStreamCount()); @@ -1168,7 +1172,7 @@ public void streamIdExhausted() throws Exception { MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); @@ -1204,11 +1208,11 @@ public void pendingStreamSucceed() throws Exception { final MockStreamListener listener1 = new MockStreamListener(); final MockStreamListener listener2 = new MockStreamListener(); OkHttpClientStream stream1 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); // The second stream should be pending. OkHttpClientStream stream2 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); String sentMessage = "hello"; InputStream input = new ByteArrayInputStream(sentMessage.getBytes(UTF_8)); @@ -1241,7 +1245,7 @@ public void pendingStreamCancelled() throws Exception { setMaxConcurrentStreams(0); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); waitForStreamPending(1); stream.cancel(Status.CANCELLED); @@ -1260,11 +1264,11 @@ public void pendingStreamFailedByGoAway() throws Exception { final MockStreamListener listener1 = new MockStreamListener(); final MockStreamListener listener2 = new MockStreamListener(); OkHttpClientStream stream1 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); // The second stream should be pending. OkHttpClientStream stream2 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); waitForStreamPending(1); @@ -1290,7 +1294,7 @@ public void pendingStreamSucceedAfterShutdown() throws Exception { final MockStreamListener listener = new MockStreamListener(); // The second stream should be pending. OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); waitForStreamPending(1); @@ -1314,15 +1318,15 @@ public void pendingStreamFailedByIdExhausted() throws Exception { final MockStreamListener listener3 = new MockStreamListener(); OkHttpClientStream stream1 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); // The second and third stream should be pending. OkHttpClientStream stream2 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); OkHttpClientStream stream3 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream3.start(listener3); waitForStreamPending(2); @@ -1346,7 +1350,7 @@ public void receivingWindowExceeded() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); @@ -1398,7 +1402,7 @@ private void shouldHeadersBeFlushed(boolean shouldBeFlushed) throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); verify(frameWriter, timeout(TIME_OUT_MS)).synStream( eq(false), eq(false), eq(3), eq(0), ArgumentMatchers.
anyList()); @@ -1415,7 +1419,7 @@ public void receiveDataWithoutHeader() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); Buffer buffer = createMessageFrame(new byte[1]); @@ -1437,7 +1441,7 @@ public void receiveDataWithoutHeaderAndTrailer() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); Buffer buffer = createMessageFrame(new byte[1]); @@ -1459,7 +1463,7 @@ public void receiveLongEnoughDataWithoutHeaderAndTrailer() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); Buffer buffer = createMessageFrame(new byte[1000]); @@ -1480,7 +1484,7 @@ public void receiveDataForUnknownStreamUpdateConnectionWindow() throws Exception initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.cancel(Status.CANCELLED); @@ -1507,7 +1511,7 @@ public void receiveWindowUpdateForUnknownStream() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.cancel(Status.CANCELLED); // This should be ignored. @@ -1527,7 +1531,7 @@ public void shouldBeInitiallyReady() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); assertTrue(stream.isReady()); assertTrue(listener.isOnReadyCalled()); @@ -1545,7 +1549,7 @@ public void notifyOnReady() throws Exception { setInitialWindowSize(0); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); assertTrue(stream.isReady()); // Be notified at the beginning. @@ -1695,7 +1699,7 @@ public void writeBeforeConnected() throws Exception { final String message = "Hello Server"; MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); InputStream input = new ByteArrayInputStream(message.getBytes(UTF_8)); stream.writeMessage(input); @@ -1720,7 +1724,7 @@ public void cancelBeforeConnected() throws Exception { final String message = "Hello Server"; MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); InputStream input = new ByteArrayInputStream(message.getBytes(UTF_8)); stream.writeMessage(input); @@ -1738,7 +1742,7 @@ public void shutdownDuringConnecting() throws Exception { initTransportAndDelayConnected(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); clientTransport.shutdown(SHUTDOWN_REASON); allowTransportConnected(); @@ -1810,7 +1814,8 @@ public void unreachableServer() throws Exception { assertTrue(status.getCause().toString(), status.getCause() instanceof IOException); MockStreamListener streamListener = new MockStreamListener(); - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT).start(streamListener); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers) + .start(streamListener); streamListener.waitUntilStreamClosed(); assertEquals(Status.UNAVAILABLE.getCode(), streamListener.status.getCode()); } @@ -2054,13 +2059,13 @@ public void goAway_streamListenerRpcProgress() throws Exception { MockStreamListener listener2 = new MockStreamListener(); MockStreamListener listener3 = new MockStreamListener(); OkHttpClientStream stream1 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); OkHttpClientStream stream2 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); OkHttpClientStream stream3 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream3.start(listener3); waitForStreamPending(1); @@ -2094,13 +2099,13 @@ public void reset_streamListenerRpcProgress() throws Exception { MockStreamListener listener2 = new MockStreamListener(); MockStreamListener listener3 = new MockStreamListener(); OkHttpClientStream stream1 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); OkHttpClientStream stream2 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); OkHttpClientStream stream3 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream3.start(listener3); assertEquals(3, activeStreamCount()); @@ -2158,7 +2163,7 @@ private void waitForStreamPending(int expected) throws Exception { private void assertNewStreamFail() throws Exception { MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); listener.waitUntilStreamClosed(); assertFalse(listener.status.isOk()); diff --git a/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java b/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java index 5beefc3384c..dffbe3dade7 100644 --- a/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java @@ -31,8 +31,8 @@ import io.grpc.LoadBalancer; import io.grpc.Metadata; import io.grpc.Status; +import io.grpc.internal.ForwardingClientStreamTracer; import io.grpc.internal.ObjectPool; -import io.grpc.util.ForwardingClientStreamTracer; import io.grpc.util.ForwardingLoadBalancerHelper; import io.grpc.util.ForwardingSubchannel; import io.grpc.xds.ClusterImplLoadBalancerProvider.ClusterImplConfig; @@ -329,7 +329,8 @@ public String toString() { } } - private static final class CountingStreamTracerFactory extends ClientStreamTracer.Factory { + private static final class CountingStreamTracerFactory extends + ClientStreamTracer.InternalLimitedInfoFactory { private ClusterLocalityStats stats; private final AtomicLong inFlights; @Nullable diff --git a/xds/src/main/java/io/grpc/xds/OrcaPerRequestUtil.java b/xds/src/main/java/io/grpc/xds/OrcaPerRequestUtil.java index c193f5e35e5..156d53f638e 100644 --- a/xds/src/main/java/io/grpc/xds/OrcaPerRequestUtil.java +++ b/xds/src/main/java/io/grpc/xds/OrcaPerRequestUtil.java @@ -25,8 +25,8 @@ import io.grpc.ClientStreamTracer.StreamInfo; import io.grpc.LoadBalancer; import io.grpc.Metadata; +import io.grpc.internal.ForwardingClientStreamTracer; import io.grpc.protobuf.ProtoUtils; -import io.grpc.util.ForwardingClientStreamTracer; import java.util.ArrayList; import java.util.List; @@ -37,7 +37,7 @@ abstract class OrcaPerRequestUtil { private static final ClientStreamTracer NOOP_CLIENT_STREAM_TRACER = new ClientStreamTracer() {}; private static final ClientStreamTracer.Factory NOOP_CLIENT_STREAM_TRACER_FACTORY = - new ClientStreamTracer.Factory() { + new ClientStreamTracer.InternalLimitedInfoFactory() { @Override public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata headers) { return NOOP_CLIENT_STREAM_TRACER; @@ -189,7 +189,8 @@ public interface OrcaPerRequestReportListener { * per-request ORCA reports and push to registered listeners for calls they trace. */ @VisibleForTesting - static final class OrcaReportingTracerFactory extends ClientStreamTracer.Factory { + static final class OrcaReportingTracerFactory extends + ClientStreamTracer.InternalLimitedInfoFactory { @VisibleForTesting static final Metadata.Key ORCA_ENDPOINT_LOAD_METRICS_KEY = diff --git a/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java index 74aa85501a9..3b2a54c2c25 100644 --- a/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java @@ -341,8 +341,8 @@ private void subtest_maxConcurrentRequests_appliedByLbConfig(boolean enableCircu PickResult result = currentPicker.pickSubchannel(mock(PickSubchannelArgs.class)); assertThat(result.getStatus().isOk()).isTrue(); ClientStreamTracer.Factory streamTracerFactory = result.getStreamTracerFactory(); - streamTracerFactory.newClientStreamTracer(ClientStreamTracer.StreamInfo.newBuilder().build(), - new Metadata()); + streamTracerFactory.newClientStreamTracer( + ClientStreamTracer.StreamInfo.newBuilder().build(), new Metadata()); } ClusterStats clusterStats = Iterables.getOnlyElement(loadStatsManager.getClusterStatsReports(CLUSTER)); @@ -429,8 +429,8 @@ private void subtest_maxConcurrentRequests_appliedWithDefaultValue( PickResult result = currentPicker.pickSubchannel(mock(PickSubchannelArgs.class)); assertThat(result.getStatus().isOk()).isTrue(); ClientStreamTracer.Factory streamTracerFactory = result.getStreamTracerFactory(); - streamTracerFactory.newClientStreamTracer(ClientStreamTracer.StreamInfo.newBuilder().build(), - new Metadata()); + streamTracerFactory.newClientStreamTracer( + ClientStreamTracer.StreamInfo.newBuilder().build(), new Metadata()); } ClusterStats clusterStats = Iterables.getOnlyElement(loadStatsManager.getClusterStatsReports(CLUSTER)); From f6ce672369a26fc4eca56360ec4683cd48a5848b Mon Sep 17 00:00:00 2001 From: Eric Anderson Date: Mon, 2 Aug 2021 09:08:26 -0700 Subject: [PATCH 06/82] Revert "core: correcting a minor resource releasing issue" This reverts commit b9becb5c8e797f202ed0e566e752bd7a91e599e1. It is the caller's responsibility to close their InputStream. The change ended up closing the passed InputStream. --- .../java/io/grpc/util/CertificateUtils.java | 42 +++++++------------ 1 file changed, 15 insertions(+), 27 deletions(-) diff --git a/core/src/main/java/io/grpc/util/CertificateUtils.java b/core/src/main/java/io/grpc/util/CertificateUtils.java index e8bbc90cb36..980862d3836 100644 --- a/core/src/main/java/io/grpc/util/CertificateUtils.java +++ b/core/src/main/java/io/grpc/util/CertificateUtils.java @@ -65,36 +65,24 @@ public static X509Certificate[] getX509Certificates(InputStream inputStream) public static PrivateKey getPrivateKey(InputStream inputStream) throws UnsupportedEncodingException, IOException, NoSuchAlgorithmException, InvalidKeySpecException { - InputStreamReader isr = null; - BufferedReader reader = null; - try { - isr = new InputStreamReader(inputStream, "UTF-8"); - reader = new BufferedReader(isr); - String line; - while ((line = reader.readLine()) != null) { - if ("-----BEGIN PRIVATE KEY-----".equals(line)) { - break; - } + BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, "UTF-8")); + String line; + while ((line = reader.readLine()) != null) { + if ("-----BEGIN PRIVATE KEY-----".equals(line)) { + break; } - StringBuilder keyContent = new StringBuilder(); - while ((line = reader.readLine()) != null) { - if ("-----END PRIVATE KEY-----".equals(line)) { - break; - } - keyContent.append(line); - } - byte[] decodedKeyBytes = BaseEncoding.base64().decode(keyContent.toString()); - KeyFactory keyFactory = KeyFactory.getInstance("RSA"); - PKCS8EncodedKeySpec keySpec = new PKCS8EncodedKeySpec(decodedKeyBytes); - return keyFactory.generatePrivate(keySpec); - } finally { - if (null != reader) { - reader.close(); - } - if (null != isr) { - isr.close(); + } + StringBuilder keyContent = new StringBuilder(); + while ((line = reader.readLine()) != null) { + if ("-----END PRIVATE KEY-----".equals(line)) { + break; } + keyContent.append(line); } + byte[] decodedKeyBytes = BaseEncoding.base64().decode(keyContent.toString()); + KeyFactory keyFactory = KeyFactory.getInstance("RSA"); + PKCS8EncodedKeySpec keySpec = new PKCS8EncodedKeySpec(decodedKeyBytes); + return keyFactory.generatePrivate(keySpec); } } From 62b4364a77fe5ceef30afe98497a1ea9b5ca7902 Mon Sep 17 00:00:00 2001 From: Eric Anderson Date: Mon, 2 Aug 2021 13:09:40 -0700 Subject: [PATCH 07/82] api: Fix Javadoc reference to NameResolver.Args NameResolver.Helper was a short-lived class that didn't get very far. We chose NameResolver.Args instead and didn't mirror LoadBalancer. --- api/src/main/java/io/grpc/ProxyDetector.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/src/main/java/io/grpc/ProxyDetector.java b/api/src/main/java/io/grpc/ProxyDetector.java index 5202516bca7..7c04329c079 100644 --- a/api/src/main/java/io/grpc/ProxyDetector.java +++ b/api/src/main/java/io/grpc/ProxyDetector.java @@ -32,7 +32,7 @@ * underlying transport need to work together. * *

The {@link NameResolver} should invoke the {@link ProxyDetector} retrieved from the {@link - * NameResolver.Helper#getProxyDetector}, and pass the returned {@link ProxiedSocketAddress} to + * NameResolver.Args#getProxyDetector}, and pass the returned {@link ProxiedSocketAddress} to * {@link NameResolver.Listener#onAddresses}. The DNS name resolver shipped with gRPC is already * doing so. * From 75691c85889e37411fffc580c0c996117a775ca9 Mon Sep 17 00:00:00 2001 From: Eric Anderson Date: Mon, 2 Aug 2021 12:54:52 -0700 Subject: [PATCH 08/82] build.gradle: Remove non-standard MANIFEST.MF attributes Including the build user's name doesn't provide much value and may surprise some people. Built-JDK is actually wrong, as it is reporting Gradle's Java version, not the javac version. And Source-/Target- Compatibility isn't useful if nobody looks at it. Generally people just look at the bytecode version itself, which is much more reliable and doesn't have questions as to whether it should be '8' or '1.8'. --- build.gradle | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/build.gradle b/build.gradle index a77f3801f0e..393b91a30bd 100644 --- a/build.gradle +++ b/build.gradle @@ -269,11 +269,7 @@ subprojects { jar.manifest { attributes('Implementation-Title': name, - 'Implementation-Version': version, - 'Built-By': System.getProperty('user.name'), - 'Built-JDK': System.getProperty('java.version'), - 'Source-Compatibility': sourceCompatibility, - 'Target-Compatibility': targetCompatibility) + 'Implementation-Version': version) } javadoc.options { From 1833587597b63ed04823e4d48cc75a5f5b9feaed Mon Sep 17 00:00:00 2001 From: ZHANG Dapeng Date: Mon, 2 Aug 2021 15:54:55 -0700 Subject: [PATCH 09/82] binder: fix binder build (#8366) The binder module fail to compile because in #8355, the class BinderClientTransportTest was not refactored correspondingly. --- .../internal/BinderClientTransportTest.java | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/binder/src/androidTest/java/io/grpc/binder/internal/BinderClientTransportTest.java b/binder/src/androidTest/java/io/grpc/binder/internal/BinderClientTransportTest.java index 0d3c3bf4b51..b99114bb501 100644 --- a/binder/src/androidTest/java/io/grpc/binder/internal/BinderClientTransportTest.java +++ b/binder/src/androidTest/java/io/grpc/binder/internal/BinderClientTransportTest.java @@ -30,6 +30,7 @@ import com.google.protobuf.Empty; import io.grpc.Attributes; import io.grpc.CallOptions; +import io.grpc.ClientStreamTracer; import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.Server; @@ -68,6 +69,10 @@ */ @RunWith(AndroidJUnit4.class) public final class BinderClientTransportTest { + private static final ClientStreamTracer[] tracers = new ClientStreamTracer[] { + new ClientStreamTracer() {} + }; + private final Context appContext = ApplicationProvider.getApplicationContext(); MethodDescriptor.Marshaller marshaller = @@ -155,7 +160,8 @@ public void tearDown() throws Exception { @Test public void testShutdownBeforeStreamStart_b153326034() throws Exception { - ClientStream stream = transport.newStream(methodDesc, new Metadata(), CallOptions.DEFAULT); + ClientStream stream = transport.newStream( + methodDesc, new Metadata(), CallOptions.DEFAULT, tracers); transport.shutdownNow(Status.UNKNOWN.withDescription("reasons")); // This shouldn't throw an exception. @@ -165,7 +171,7 @@ public void testShutdownBeforeStreamStart_b153326034() throws Exception { @Test public void testRequestWhileStreamIsWaitingOnCall_b154088869() throws Exception { ClientStream stream = - transport.newStream(streamingMethodDesc, new Metadata(), CallOptions.DEFAULT); + transport.newStream(streamingMethodDesc, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(streamListener); stream.writeMessage(marshaller.stream(Empty.getDefaultInstance())); @@ -183,7 +189,7 @@ public void testRequestWhileStreamIsWaitingOnCall_b154088869() throws Exception @Test public void testTransactionForDiscardedCall_b155244043() throws Exception { ClientStream stream = - transport.newStream(streamingMethodDesc, new Metadata(), CallOptions.DEFAULT); + transport.newStream(streamingMethodDesc, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(streamListener); stream.writeMessage(marshaller.stream(Empty.getDefaultInstance())); @@ -201,7 +207,7 @@ public void testTransactionForDiscardedCall_b155244043() throws Exception { @Test public void testBadTransactionStreamThroughput_b163053382() throws Exception { ClientStream stream = - transport.newStream(streamingMethodDesc, new Metadata(), CallOptions.DEFAULT); + transport.newStream(streamingMethodDesc, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(streamListener); stream.writeMessage(marshaller.stream(Empty.getDefaultInstance())); @@ -220,7 +226,7 @@ public void testBadTransactionStreamThroughput_b163053382() throws Exception { @Test public void testMessageProducerClosedAfterStream_b169313545() { ClientStream stream = - transport.newStream(methodDesc, new Metadata(), CallOptions.DEFAULT); + transport.newStream(methodDesc, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(streamListener); stream.writeMessage(marshaller.stream(Empty.getDefaultInstance())); From f781d24ddd4ccb34e16cb900b3906f87e0ac6574 Mon Sep 17 00:00:00 2001 From: Eric Anderson Date: Mon, 2 Aug 2021 13:18:53 -0700 Subject: [PATCH 10/82] Mostly revert "Run binderchannel android tests. (#8306)" This partilaly reverts commit 5e18ff208aacd1d626716603a8aa57563b399ace. It leaves the compilation fix that was made to BinderClientTransportTest. Running instrumentation tests via firebase requires a `--app` argument. However, we don't have such an app and it isn't immediately clear how we'll go about making one. Revert the change to let android-testing to start passing again. This problem wasn't noticed before merging the original commit because android-testing is a post-commit CI. --- buildscripts/kokoro/android-interop.sh | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/buildscripts/kokoro/android-interop.sh b/buildscripts/kokoro/android-interop.sh index 8a8a2bc7bc5..5d9774bb12f 100755 --- a/buildscripts/kokoro/android-interop.sh +++ b/buildscripts/kokoro/android-interop.sh @@ -38,18 +38,3 @@ gcloud firebase test android run \ --device model=Nexus6P,version=23,locale=en,orientation=portrait \ --device model=Nexus6,version=22,locale=en,orientation=portrait \ --device model=Nexus6,version=21,locale=en,orientation=portrait - -# Build and run binder transport instrumentation tests on Firebase Test Lab -cd ../binder -../gradlew assembleDebug -../gradlew assembleDebugAndroidTest -gcloud firebase test android run \ - --type instrumentation \ - --test build/outputs/apk/androidTest/debug/grpc-binder-debug-androidTest.apk \ - --device model=Nexus6P,version=27,locale=en,orientation=portrait \ - --device model=Nexus6P,version=26,locale=en,orientation=portrait \ - --device model=Nexus6P,version=25,locale=en,orientation=portrait \ - --device model=Nexus6P,version=24,locale=en,orientation=portrait \ - --device model=Nexus6P,version=23,locale=en,orientation=portrait \ - --device model=Nexus6,version=22,locale=en,orientation=portrait \ - --device model=Nexus6,version=21,locale=en,orientation=portrait From 57bd087cdf416343296bf832092826111afab752 Mon Sep 17 00:00:00 2001 From: Eric Anderson Date: Mon, 2 Aug 2021 12:53:09 -0700 Subject: [PATCH 11/82] buildscripts: Build android instrumentation tests in android CI Binder's :build was missing. Cronet build failed without specifying Java 8 because of the transitive Guava dependency. --- buildscripts/kokoro/android.sh | 2 ++ cronet/build.gradle | 4 ++++ 2 files changed, 6 insertions(+) diff --git a/buildscripts/kokoro/android.sh b/buildscripts/kokoro/android.sh index 50337ef878b..3ca4fa88bc0 100755 --- a/buildscripts/kokoro/android.sh +++ b/buildscripts/kokoro/android.sh @@ -31,6 +31,8 @@ buildscripts/make_dependencies.sh :grpc-android-interop-testing:build \ :grpc-android:build \ :grpc-cronet:build \ + :grpc-binder:build \ + assembleAndroidTest \ publishToMavenLocal if [[ ! -z $(git status --porcelain) ]]; then diff --git a/cronet/build.gradle b/cronet/build.gradle index 2d73cc4194f..7a6d58a7e4b 100644 --- a/cronet/build.gradle +++ b/cronet/build.gradle @@ -28,6 +28,10 @@ android { consumerProguardFiles 'proguard-rules.pro' } } + compileOptions { + sourceCompatibility JavaVersion.VERSION_1_8 + targetCompatibility JavaVersion.VERSION_1_8 + } testOptions { unitTests { includeAndroidResources = true } } lintOptions { disable 'InvalidPackage' } } From 0d80c33bce84edd678726279df71a2f3755422ca Mon Sep 17 00:00:00 2001 From: sanjaypujare Date: Tue, 3 Aug 2021 13:01:09 -0700 Subject: [PATCH 12/82] xds: log error and fail start() if server-listener-resource-name-template not set or not using xds_v3 (#8375) --- .../io/grpc/xds/XdsClientWrapperForServerSds.java | 12 ++++++++++-- .../test/java/io/grpc/xds/XdsServerTestHelper.java | 2 +- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/XdsClientWrapperForServerSds.java b/xds/src/main/java/io/grpc/xds/XdsClientWrapperForServerSds.java index c73c071d951..9a1b659ef1e 100644 --- a/xds/src/main/java/io/grpc/xds/XdsClientWrapperForServerSds.java +++ b/xds/src/main/java/io/grpc/xds/XdsClientWrapperForServerSds.java @@ -125,14 +125,22 @@ public void onError(Status error) { } } }; + newServerApi = xdsClient.getBootstrapInfo().getServers().get(0).isUseProtocolV3(); + if (!newServerApi) { + reportError( + new XdsInitializationException( + "requires use of xds_v3 in xds bootstrap"), + true); + return; + } grpcServerResourceId = xdsClient.getBootstrapInfo() .getServerListenerResourceNameTemplate(); - newServerApi = xdsClient.getBootstrapInfo().getServers().get(0).isUseProtocolV3(); - if (newServerApi && grpcServerResourceId == null) { + if (grpcServerResourceId == null) { reportError( new XdsInitializationException( "missing server_listener_resource_name_template value in xds bootstrap"), true); + return; } grpcServerResourceId = grpcServerResourceId.replaceAll("%s", "0.0.0.0:" + port); xdsClient.watchLdsResource(grpcServerResourceId, listenerWatcher); diff --git a/xds/src/test/java/io/grpc/xds/XdsServerTestHelper.java b/xds/src/test/java/io/grpc/xds/XdsServerTestHelper.java index 071fbd8a108..2c455673239 100644 --- a/xds/src/test/java/io/grpc/xds/XdsServerTestHelper.java +++ b/xds/src/test/java/io/grpc/xds/XdsServerTestHelper.java @@ -47,7 +47,7 @@ public class XdsServerTestHelper { static final Bootstrapper.BootstrapInfo BOOTSTRAP_INFO = new Bootstrapper.BootstrapInfo( Arrays.asList( - new Bootstrapper.ServerInfo(SERVER_URI, InsecureChannelCredentials.create(), false)), + new Bootstrapper.ServerInfo(SERVER_URI, InsecureChannelCredentials.create(), true)), BOOTSTRAP_NODE, null, "grpc/server?udpa.resource.listening_address=%s"); From c77083f013bbd3f1bb199ba367d434432e28b682 Mon Sep 17 00:00:00 2001 From: ZHANG Dapeng Date: Wed, 4 Aug 2021 14:32:49 -0700 Subject: [PATCH 13/82] core: fix old ClientStreamTracer.Factory creating tracers twice (#8381) Fix a bug introduced in #8355 : old ClientStreamTracer.Factory implementation creates tracers twice. --- core/src/main/java/io/grpc/internal/GrpcUtil.java | 14 ++++++++++---- .../test/java/io/grpc/internal/GrpcUtilTest.java | 9 ++++++++- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/core/src/main/java/io/grpc/internal/GrpcUtil.java b/core/src/main/java/io/grpc/internal/GrpcUtil.java index 0e11e5ece39..5b5a062e95d 100644 --- a/core/src/main/java/io/grpc/internal/GrpcUtil.java +++ b/core/src/main/java/io/grpc/internal/GrpcUtil.java @@ -65,7 +65,6 @@ import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ThreadFactory; import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicReference; import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; @@ -785,15 +784,22 @@ static ClientStreamTracer newClientStreamTracer( } else { streamTracer = new ForwardingClientStreamTracer() { final ClientStreamTracer noop = new ClientStreamTracer() {}; - AtomicReference delegate = new AtomicReference<>(noop); + volatile ClientStreamTracer delegate = noop; void maybeInit(StreamInfo info, Metadata headers) { - delegate.compareAndSet(noop, streamTracerFactory.newClientStreamTracer(info, headers)); + if (delegate != noop) { + return; + } + synchronized (this) { + if (delegate == noop) { + delegate = streamTracerFactory.newClientStreamTracer(info, headers); + } + } } @Override protected ClientStreamTracer delegate() { - return delegate.get(); + return delegate; } @SuppressWarnings("deprecation") diff --git a/core/src/test/java/io/grpc/internal/GrpcUtilTest.java b/core/src/test/java/io/grpc/internal/GrpcUtilTest.java index 95d1c448f4f..6d2c21ddab8 100644 --- a/core/src/test/java/io/grpc/internal/GrpcUtilTest.java +++ b/core/src/test/java/io/grpc/internal/GrpcUtilTest.java @@ -38,6 +38,7 @@ import io.grpc.internal.ClientStreamListener.RpcProgress; import io.grpc.internal.GrpcUtil.Http2Error; import io.grpc.testing.TestMethodDescriptors; +import java.util.ArrayDeque; import java.util.concurrent.atomic.AtomicReference; import org.junit.Rule; import org.junit.Test; @@ -301,12 +302,14 @@ public void clientStreamTracerFactoryBackwardCompatibility() { final AtomicReference transportAttrsRef = new AtomicReference<>(); final ClientStreamTracer mockTracer = mock(ClientStreamTracer.class); final Metadata.Key key = Metadata.Key.of("fake-key", Metadata.ASCII_STRING_MARSHALLER); + final ArrayDeque tracers = new ArrayDeque<>(); ClientStreamTracer.Factory oldFactoryImpl = new ClientStreamTracer.Factory() { @SuppressWarnings("deprecation") @Override public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata headers) { transportAttrsRef.set(info.getTransportAttrs()); headers.put(key, "fake-value"); + tracers.offer(mockTracer); return mockTracer; } }; @@ -318,8 +321,12 @@ public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata header Attributes.newBuilder().set(Attributes.Key.create("foo"), "bar").build(); ClientStreamTracer tracer = GrpcUtil.newClientStreamTracer(oldFactoryImpl, info, metadata); tracer.streamCreated(transAttrs, metadata); - + assertThat(tracers.poll()).isSameInstanceAs(mockTracer); assertThat(transportAttrsRef.get()).isEqualTo(transAttrs); assertThat(metadata.get(key)).isEqualTo("fake-value"); + + tracer.streamClosed(Status.UNAVAILABLE); + // verify that newClientStreamTracer() is called no more than once + assertThat(tracers).isEmpty(); } } From 9dd0c66929b5df730d48ac9437fe2f1790c7c717 Mon Sep 17 00:00:00 2001 From: Nick Ufer Date: Thu, 24 Jun 2021 18:44:31 +0200 Subject: [PATCH 14/82] netty: removes TODO in test for NettyServer --- netty/src/test/java/io/grpc/netty/NettyServerTest.java | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/netty/src/test/java/io/grpc/netty/NettyServerTest.java b/netty/src/test/java/io/grpc/netty/NettyServerTest.java index 3f277ed4356..12dc5b9fa51 100644 --- a/netty/src/test/java/io/grpc/netty/NettyServerTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyServerTest.java @@ -486,9 +486,8 @@ public void run() {} assertEquals(ns.getListenSocketAddress(), socketStats.local); assertNull(socketStats.remote); - // TODO(zpencer): uncomment when sock options are exposed // by default, there are some socket options set on the listen socket - // assertThat(socketStats.socketOptions.additional).isNotEmpty(); + assertThat(socketStats.socketOptions.others).isNotEmpty(); // Cleanup ns.shutdown(); From 3668f2e52c8f8283b0d7004000921c143f0b61dc Mon Sep 17 00:00:00 2001 From: ZHANG Dapeng Date: Thu, 5 Aug 2021 18:22:37 -0700 Subject: [PATCH 15/82] core: fix bug RetriableStream cancel() racing with start() (#8386) There is a bug in the scenario of the following sequence of events: - `call.start()` - received retryable status and about to retry - The retry attempt Substream is created but not yet `drain()` - `call.cancel()` But `stream.cancel()` cannot be called prior to `stream.start()`, otherwise retry will cause a failure with IllegalStateException. The current RetriableStream code must be fixed to not cancel a stream until `start()` is called in `drain()`. --- .../io/grpc/internal/RetriableStream.java | 43 +++++++++++++------ 1 file changed, 29 insertions(+), 14 deletions(-) diff --git a/core/src/main/java/io/grpc/internal/RetriableStream.java b/core/src/main/java/io/grpc/internal/RetriableStream.java index 23725788466..396c7cedfe2 100644 --- a/core/src/main/java/io/grpc/internal/RetriableStream.java +++ b/core/src/main/java/io/grpc/internal/RetriableStream.java @@ -104,6 +104,7 @@ abstract class RetriableStream implements ClientStream { @GuardedBy("lock") private FutureCanceller scheduledHedging; private long nextBackoffIntervalNanos; + private Status cancellationStatus; RetriableStream( MethodDescriptor method, Metadata headers, @@ -244,14 +245,16 @@ private void drain(Substream substream) { int index = 0; int chunk = 0x80; List list = null; + boolean streamStarted = false; while (true) { State savedState; synchronized (lock) { savedState = state; - if (savedState.winningSubstream != null && savedState.winningSubstream != substream) { - // committed but not me + if (savedState.winningSubstream != null && savedState.winningSubstream != substream + && streamStarted) { + // committed but not me, to be cancelled break; } if (index == savedState.buffer.size()) { // I'm drained @@ -275,17 +278,22 @@ private void drain(Substream substream) { for (BufferEntry bufferEntry : list) { savedState = state; - if (savedState.winningSubstream != null && savedState.winningSubstream != substream) { - // committed but not me + if (savedState.winningSubstream != null && savedState.winningSubstream != substream + && streamStarted) { + // committed but not me, to be cancelled break; } - if (savedState.cancelled) { + if (savedState.cancelled && streamStarted) { checkState( savedState.winningSubstream == substream, "substream should be CANCELLED_BECAUSE_COMMITTED already"); + substream.stream.cancel(cancellationStatus); return; } bufferEntry.runWith(substream); + if (bufferEntry instanceof RetriableStream.StartEntry) { + streamStarted = true; + } } } @@ -299,6 +307,13 @@ private void drain(Substream substream) { @Nullable abstract Status prestart(); + class StartEntry implements BufferEntry { + @Override + public void runWith(Substream substream) { + substream.stream.start(new Sublistener(substream)); + } + } + /** Starts the first PRC attempt. */ @Override public final void start(ClientStreamListener listener) { @@ -311,13 +326,6 @@ public final void start(ClientStreamListener listener) { return; } - class StartEntry implements BufferEntry { - @Override - public void runWith(Substream substream) { - substream.stream.start(new Sublistener(substream)); - } - } - synchronized (lock) { state.buffer.add(new StartEntry()); } @@ -450,11 +458,18 @@ public final void cancel(Status reason) { return; } - state.winningSubstream.stream.cancel(reason); + Substream winningSubstreamToCancel = null; synchronized (lock) { - // This is not required, but causes a short-circuit in the draining process. + if (state.drainedSubstreams.contains(state.winningSubstream)) { + winningSubstreamToCancel = state.winningSubstream; + } else { // the winningSubstream will be cancelled while draining + cancellationStatus = reason; + } state = state.cancelled(); } + if (winningSubstreamToCancel != null) { + winningSubstreamToCancel.stream.cancel(reason); + } } private void delayOrExecute(BufferEntry bufferEntry) { From 0e7e0b4f57c95e6130a3e5cdca6a6651ff37399c Mon Sep 17 00:00:00 2001 From: Eric Anderson Date: Fri, 6 Aug 2021 09:10:26 -0700 Subject: [PATCH 16/82] api: Clarify Server APIs can be used before start() Fixes #8349 --- api/src/main/java/io/grpc/Server.java | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/api/src/main/java/io/grpc/Server.java b/api/src/main/java/io/grpc/Server.java index 781455b18ee..31e0a6478ed 100644 --- a/api/src/main/java/io/grpc/Server.java +++ b/api/src/main/java/io/grpc/Server.java @@ -43,7 +43,7 @@ public abstract class Server { * listening socket(s). * * @return {@code this} object - * @throws IllegalStateException if already started + * @throws IllegalStateException if already started or shut down * @throws IOException if unable to bind * @since 1.0.0 */ @@ -119,6 +119,9 @@ public List getMutableServices() { * {@link #awaitTermination()} or {@link #awaitTermination(long, TimeUnit)} needs to be called to * wait for existing calls to finish. * + *

Calling this method before {@code start()} will shut down and terminate the server like + * normal, but prevents starting the server in the future. + * * @return {@code this} object * @since 1.0.0 */ @@ -130,6 +133,9 @@ public List getMutableServices() { * return {@code false} immediately after this method returns. After this call returns, this * server has released the listening socket(s) and may be reused by another server. * + *

Calling this method before {@code start()} will shut down and terminate the server like + * normal, but prevents starting the server in the future. + * * @return {@code this} object * @since 1.0.0 */ @@ -157,6 +163,9 @@ public List getMutableServices() { /** * Waits for the server to become terminated, giving up if the timeout is reached. * + *

Calling this method before {@code start()} or {@code shutdown()} is permitted and does not + * change its behavior. + * * @return whether the server is terminated, as would be done by {@link #isTerminated()}. */ public abstract boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException; @@ -164,6 +173,9 @@ public List getMutableServices() { /** * Waits for the server to become terminated. * + *

Calling this method before {@code start()} or {@code shutdown()} is permitted and does not + * change its behavior. + * * @since 1.0.0 */ public abstract void awaitTermination() throws InterruptedException; From 7942f35c477c2aa91654b19541960b1ba7c44396 Mon Sep 17 00:00:00 2001 From: Eric Anderson Date: Fri, 6 Aug 2021 11:47:35 -0700 Subject: [PATCH 17/82] binder: Disable flaky SecurityPolicy tests Not using `@Ignore` because the tests can probably run successfully under Bazel. See #8391 --- binder/build.gradle | 2 ++ 1 file changed, 2 insertions(+) diff --git a/binder/build.gradle b/binder/build.gradle index 537c23a0092..c5bb9885623 100644 --- a/binder/build.gradle +++ b/binder/build.gradle @@ -13,6 +13,8 @@ android { srcDirs += "${projectDir}/../core/src/test/java/" setIncludes(["io/grpc/internal/FakeClock.java", "io/grpc/binder/**"]) + exclude 'io/grpc/binder/ServerSecurityPolicyTest.java' + exclude 'io/grpc/binder/SecurityPoliciesTest.java' } } androidTest { From 20ac1999d406a1cf3148aa8003cb6cfeaf19b986 Mon Sep 17 00:00:00 2001 From: Eric Anderson Date: Fri, 6 Aug 2021 14:14:17 -0700 Subject: [PATCH 18/82] stub: Mark Stub-based MetadataUtils methods deprecated We don't want other APIs to copy the stub-based API to attach the interceptor. The API has a shorter name, but isn't actually all that easier to use and isn't fluent like using the interceptor API. These are _very_ old methods, so we won't be quick to delete them. Seems we should have them deprecated at least a year or two; they are easy to maintain in the mean time. See API Review notes in #1789 --- .../integration/AbstractInteropTest.java | 26 +++++++++---------- .../main/java/io/grpc/stub/MetadataUtils.java | 4 +++ 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java b/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java index 1b447a63c32..8a6e41722ab 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java @@ -1043,19 +1043,18 @@ public void veryLargeResponse() throws Exception { @Test public void exchangeMetadataUnaryCall() throws Exception { - TestServiceGrpc.TestServiceBlockingStub stub = blockingStub; - // Capture the metadata exchange Metadata fixedHeaders = new Metadata(); // Send a context proto (as it's in the default extension registry) Messages.SimpleContext contextValue = Messages.SimpleContext.newBuilder().setValue("dog").build(); fixedHeaders.put(Util.METADATA_KEY, contextValue); - stub = MetadataUtils.attachHeaders(stub, fixedHeaders); // .. and expect it to be echoed back in trailers AtomicReference trailersCapture = new AtomicReference<>(); AtomicReference headersCapture = new AtomicReference<>(); - stub = MetadataUtils.captureMetadata(stub, headersCapture, trailersCapture); + TestServiceGrpc.TestServiceBlockingStub stub = blockingStub.withInterceptors( + MetadataUtils.newAttachHeadersInterceptor(fixedHeaders), + MetadataUtils.newCaptureMetadataInterceptor(headersCapture, trailersCapture)); assertNotNull(stub.emptyCall(EMPTY)); @@ -1066,19 +1065,18 @@ public void exchangeMetadataUnaryCall() throws Exception { @Test public void exchangeMetadataStreamingCall() throws Exception { - TestServiceGrpc.TestServiceStub stub = asyncStub; - // Capture the metadata exchange Metadata fixedHeaders = new Metadata(); // Send a context proto (as it's in the default extension registry) Messages.SimpleContext contextValue = Messages.SimpleContext.newBuilder().setValue("dog").build(); fixedHeaders.put(Util.METADATA_KEY, contextValue); - stub = MetadataUtils.attachHeaders(stub, fixedHeaders); // .. and expect it to be echoed back in trailers AtomicReference trailersCapture = new AtomicReference<>(); AtomicReference headersCapture = new AtomicReference<>(); - stub = MetadataUtils.captureMetadata(stub, headersCapture, trailersCapture); + TestServiceGrpc.TestServiceStub stub = asyncStub.withInterceptors( + MetadataUtils.newAttachHeadersInterceptor(fixedHeaders), + MetadataUtils.newCaptureMetadataInterceptor(headersCapture, trailersCapture)); List responseSizes = Arrays.asList(50, 100, 150, 200); Messages.StreamingOutputCallRequest.Builder streamingOutputBuilder = @@ -1490,11 +1488,11 @@ public void customMetadata() throws Exception { Metadata metadata = new Metadata(); metadata.put(Util.ECHO_INITIAL_METADATA_KEY, "test_initial_metadata_value"); metadata.put(Util.ECHO_TRAILING_METADATA_KEY, trailingBytes); - TestServiceGrpc.TestServiceBlockingStub blockingStub = this.blockingStub; - blockingStub = MetadataUtils.attachHeaders(blockingStub, metadata); AtomicReference headersCapture = new AtomicReference<>(); AtomicReference trailersCapture = new AtomicReference<>(); - blockingStub = MetadataUtils.captureMetadata(blockingStub, headersCapture, trailersCapture); + TestServiceGrpc.TestServiceBlockingStub blockingStub = this.blockingStub.withInterceptors( + MetadataUtils.newAttachHeadersInterceptor(metadata), + MetadataUtils.newCaptureMetadataInterceptor(headersCapture, trailersCapture)); SimpleResponse response = blockingStub.unaryCall(request); assertResponse(goldenResponse, response); @@ -1509,11 +1507,11 @@ public void customMetadata() throws Exception { metadata = new Metadata(); metadata.put(Util.ECHO_INITIAL_METADATA_KEY, "test_initial_metadata_value"); metadata.put(Util.ECHO_TRAILING_METADATA_KEY, trailingBytes); - TestServiceGrpc.TestServiceStub stub = asyncStub; - stub = MetadataUtils.attachHeaders(stub, metadata); headersCapture = new AtomicReference<>(); trailersCapture = new AtomicReference<>(); - stub = MetadataUtils.captureMetadata(stub, headersCapture, trailersCapture); + TestServiceGrpc.TestServiceStub stub = asyncStub.withInterceptors( + MetadataUtils.newAttachHeadersInterceptor(metadata), + MetadataUtils.newCaptureMetadataInterceptor(headersCapture, trailersCapture)); StreamRecorder recorder = StreamRecorder.create(); StreamObserver requestStream = diff --git a/stub/src/main/java/io/grpc/stub/MetadataUtils.java b/stub/src/main/java/io/grpc/stub/MetadataUtils.java index 0fedf3711f7..94dfb8e56ee 100644 --- a/stub/src/main/java/io/grpc/stub/MetadataUtils.java +++ b/stub/src/main/java/io/grpc/stub/MetadataUtils.java @@ -43,8 +43,10 @@ private MetadataUtils() {} * @param stub to bind the headers to. * @param extraHeaders the headers to be passed by each call on the returned stub. * @return an implementation of the stub with {@code extraHeaders} bound to each call. + * @deprecated Use {@code stub.withInterceptors(newAttachHeadersInterceptor(...))} instead. */ @ExperimentalApi("https://github.com/grpc/grpc-java/issues/1789") + @Deprecated public static > T attachHeaders(T stub, Metadata extraHeaders) { return stub.withInterceptors(newAttachHeadersInterceptor(extraHeaders)); } @@ -98,8 +100,10 @@ public void start(Listener responseListener, Metadata headers) { * @param trailersCapture to record the last received trailers * @return an implementation of the stub that allows to access the last received call's * headers and trailers via {@code headersCapture} and {@code trailersCapture}. + * @deprecated Use {@code stub.withInterceptors(newCaptureMetadataInterceptor())} instead. */ @ExperimentalApi("https://github.com/grpc/grpc-java/issues/1789") + @Deprecated public static > T captureMetadata( T stub, AtomicReference headersCapture, From cbda32a3c1c8e253d8d78afef4b5f934f0db1826 Mon Sep 17 00:00:00 2001 From: ZHANG Dapeng Date: Fri, 6 Aug 2021 18:32:55 -0700 Subject: [PATCH 19/82] core: fix RetriableStream edge case bug introduced in #8386 (#8393) While adding regression tests to #8386, I found a bug in an edge case: while retry attempt is draining the last buffered entry, if it is in the mean time committed and then we cancel the call, the stream will never be cancelled. See the regression test case `commitAndCancelWhileDraining()`. --- .../io/grpc/internal/RetriableStream.java | 38 +++--- .../io/grpc/internal/RetriableStreamTest.java | 126 ++++++++++++++++++ 2 files changed, 146 insertions(+), 18 deletions(-) diff --git a/core/src/main/java/io/grpc/internal/RetriableStream.java b/core/src/main/java/io/grpc/internal/RetriableStream.java index 396c7cedfe2..d19a260049b 100644 --- a/core/src/main/java/io/grpc/internal/RetriableStream.java +++ b/core/src/main/java/io/grpc/internal/RetriableStream.java @@ -252,10 +252,14 @@ private void drain(Substream substream) { synchronized (lock) { savedState = state; - if (savedState.winningSubstream != null && savedState.winningSubstream != substream - && streamStarted) { - // committed but not me, to be cancelled - break; + if (streamStarted) { + if (savedState.winningSubstream != null && savedState.winningSubstream != substream) { + // committed but not me, to be cancelled + break; + } + if (savedState.cancelled) { + break; + } } if (index == savedState.buffer.size()) { // I'm drained state = savedState.substreamDrained(substream); @@ -277,27 +281,25 @@ private void drain(Substream substream) { } for (BufferEntry bufferEntry : list) { - savedState = state; - if (savedState.winningSubstream != null && savedState.winningSubstream != substream - && streamStarted) { - // committed but not me, to be cancelled - break; - } - if (savedState.cancelled && streamStarted) { - checkState( - savedState.winningSubstream == substream, - "substream should be CANCELLED_BECAUSE_COMMITTED already"); - substream.stream.cancel(cancellationStatus); - return; - } bufferEntry.runWith(substream); if (bufferEntry instanceof RetriableStream.StartEntry) { streamStarted = true; } + if (streamStarted) { + savedState = state; + if (savedState.winningSubstream != null && savedState.winningSubstream != substream) { + // committed but not me, to be cancelled + break; + } + if (savedState.cancelled) { + break; + } + } } } - substream.stream.cancel(CANCELLED_BECAUSE_COMMITTED); + substream.stream.cancel( + state.winningSubstream == substream ? cancellationStatus : CANCELLED_BECAUSE_COMMITTED); } /** diff --git a/core/src/test/java/io/grpc/internal/RetriableStreamTest.java b/core/src/test/java/io/grpc/internal/RetriableStreamTest.java index 26c6fcf9b4e..95d2c2ba8b5 100644 --- a/core/src/test/java/io/grpc/internal/RetriableStreamTest.java +++ b/core/src/test/java/io/grpc/internal/RetriableStreamTest.java @@ -749,6 +749,91 @@ public void request(int numMessages) { inOrder.verify(mockStream2, never()).writeMessage(any(InputStream.class)); } + @Test + public void cancelWhileDraining() { + ArgumentCaptor sublistenerCaptor1 = + ArgumentCaptor.forClass(ClientStreamListener.class); + ClientStream mockStream1 = mock(ClientStream.class); + ClientStream mockStream2 = + mock( + ClientStream.class, + delegatesTo( + new NoopClientStream() { + @Override + public void request(int numMessages) { + retriableStream.cancel( + Status.CANCELLED.withDescription("cancelled while requesting")); + } + })); + + InOrder inOrder = inOrder(retriableStreamRecorder, mockStream1, mockStream2); + doReturn(mockStream1).when(retriableStreamRecorder).newSubstream(0); + retriableStream.start(masterListener); + inOrder.verify(mockStream1).start(sublistenerCaptor1.capture()); + retriableStream.request(3); + inOrder.verify(mockStream1).request(3); + + // retry + doReturn(mockStream2).when(retriableStreamRecorder).newSubstream(1); + sublistenerCaptor1.getValue().closed( + Status.fromCode(RETRIABLE_STATUS_CODE_1), PROCESSED, new Metadata()); + fakeClock.forwardTime((long) (INITIAL_BACKOFF_IN_SECONDS * FAKE_RANDOM), TimeUnit.SECONDS); + + inOrder.verify(mockStream2).start(any(ClientStreamListener.class)); + inOrder.verify(mockStream2).request(3); + inOrder.verify(retriableStreamRecorder).postCommit(); + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); + inOrder.verify(mockStream2).cancel(statusCaptor.capture()); + assertThat(statusCaptor.getValue().getCode()).isEqualTo(Code.CANCELLED); + assertThat(statusCaptor.getValue().getDescription()) + .isEqualTo("Stream thrown away because RetriableStream committed"); + verify(masterListener).closed( + statusCaptor.capture(), any(RpcProgress.class), any(Metadata.class)); + assertThat(statusCaptor.getValue().getCode()).isEqualTo(Code.CANCELLED); + assertThat(statusCaptor.getValue().getDescription()).isEqualTo("cancelled while requesting"); + } + + @Test + public void cancelWhileRetryStart() { + ArgumentCaptor sublistenerCaptor1 = + ArgumentCaptor.forClass(ClientStreamListener.class); + ClientStream mockStream1 = mock(ClientStream.class); + ClientStream mockStream2 = + mock( + ClientStream.class, + delegatesTo( + new NoopClientStream() { + @Override + public void start(ClientStreamListener listener) { + retriableStream.cancel( + Status.CANCELLED.withDescription("cancelled while retry start")); + } + })); + + InOrder inOrder = inOrder(retriableStreamRecorder, mockStream1, mockStream2); + doReturn(mockStream1).when(retriableStreamRecorder).newSubstream(0); + retriableStream.start(masterListener); + inOrder.verify(mockStream1).start(sublistenerCaptor1.capture()); + + // retry + doReturn(mockStream2).when(retriableStreamRecorder).newSubstream(1); + sublistenerCaptor1.getValue().closed( + Status.fromCode(RETRIABLE_STATUS_CODE_1), PROCESSED, new Metadata()); + fakeClock.forwardTime((long) (INITIAL_BACKOFF_IN_SECONDS * FAKE_RANDOM), TimeUnit.SECONDS); + + inOrder.verify(mockStream2).start(any(ClientStreamListener.class)); + inOrder.verify(retriableStreamRecorder).postCommit(); + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); + inOrder.verify(mockStream2).cancel(statusCaptor.capture()); + assertThat(statusCaptor.getValue().getCode()).isEqualTo(Code.CANCELLED); + assertThat(statusCaptor.getValue().getDescription()) + .isEqualTo("Stream thrown away because RetriableStream committed"); + verify(masterListener).closed( + statusCaptor.capture(), any(RpcProgress.class), any(Metadata.class)); + assertThat(statusCaptor.getValue().getCode()).isEqualTo(Code.CANCELLED); + assertThat(statusCaptor.getValue().getDescription()).isEqualTo("cancelled while retry start"); + } + @Test public void operationsAfterImmediateCommit() { ArgumentCaptor sublistenerCaptor1 = @@ -916,6 +1001,47 @@ public void start(ClientStreamListener listener) { verify(mockStream3).request(1); } + @Test + public void commitAndCancelWhileDraining() { + ClientStream mockStream1 = mock(ClientStream.class); + ClientStream mockStream2 = + mock( + ClientStream.class, + delegatesTo( + new NoopClientStream() { + @Override + public void start(ClientStreamListener listener) { + // commit while draining + listener.headersRead(new Metadata()); + // cancel while draining + retriableStream.cancel( + Status.CANCELLED.withDescription("cancelled while drained")); + } + })); + + when(retriableStreamRecorder.newSubstream(anyInt())) + .thenReturn(mockStream1, mockStream2); + + retriableStream.start(masterListener); + + ArgumentCaptor sublistenerCaptor1 = + ArgumentCaptor.forClass(ClientStreamListener.class); + verify(mockStream1).start(sublistenerCaptor1.capture()); + + ClientStreamListener listener1 = sublistenerCaptor1.getValue(); + + // retry + listener1.closed( + Status.fromCode(RETRIABLE_STATUS_CODE_1), PROCESSED, new Metadata()); + fakeClock.forwardTime((long) (INITIAL_BACKOFF_IN_SECONDS * FAKE_RANDOM), TimeUnit.SECONDS); + + verify(mockStream2).start(any(ClientStreamListener.class)); + verify(retriableStreamRecorder).postCommit(); + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); + verify(mockStream2).cancel(statusCaptor.capture()); + assertThat(statusCaptor.getValue().getCode()).isEqualTo(Code.CANCELLED); + assertThat(statusCaptor.getValue().getDescription()).isEqualTo("cancelled while drained"); + } @Test public void perRpcBufferLimitExceeded() { From bb06739cd735bf29f8076dd2217056aab3e980f7 Mon Sep 17 00:00:00 2001 From: yifeizhuang Date: Mon, 9 Aug 2021 09:32:36 -0700 Subject: [PATCH 20/82] xds: refactor xdsServer wrapper, modify filter chain matching handler for server routing config (#8333) This is split from #8318, refactoring changes include: 1. FilterChainMatchingHandler 1.1. Previously filter chain match is built-in in XdsServerCredential for xdsServer. (But it does not have to be XdsServerCredential.) The protocol negotiator associated with the XdsServerCredential does the filter chain match computation. Now filter chain match is through a FilterChainMatchingHandler and it always run. As a result, it sets attributes of sslContextProviderSupplier from xds config in protocol negotiation event. 1.2. The previous protocol negotiator associated with the XdsServerCredential is modified to just lookup the config in the attribute set above and decide to use xds config credential or fallback credential. 1.3. Previously credential is a must in XdsBuilder. Now credential becomes optional to allow routing config to be fetched. Xds TCP listener update will always be used to run filter chain match. Later, we will add routing config in filter chain match and apply http filter configs by installing ConfigApplyingInterceptor. 2. Removed xdsClientWrapperForServerXds, unnecessarily complicated. 3. Changed event attribute key. Previously filter chain matching happens in the xdsClientWrapperForServerXds, the xds client wrapper is passed to negotiation handler via attributes to allow protocol negotiator to trigger the filter chain matching computation. Now the attributes becomes an atomic config selector reference that xdsServerWrapper will inject by watching xds resources updates via xds client. 4. Previously there are multiple server states enum in xdsServerWrapper, this is removed because it is unnecessarily complicated. But there are still isServing status to avoid re-start delegate upon listener update. 5. Previously xdsServerWrapper ignores any xds updates once initial started, now we allow dynamic update to happen even if server is up. This is done via updating config selector atomic reference upon listener update. 6. Previously xdsServerWrapper synchronizes on the server object, this is modified to syncContext to be more manageable. --- xds/build.gradle | 2 +- .../io/grpc/xds/EnvoyServerProtoData.java | 2 +- ...ilterChainMatchingProtocolNegotiators.java | 372 +++++ .../io/grpc/xds/InternalXdsAttributes.java | 10 + .../xds/XdsClientWrapperForServerSds.java | 455 ------ .../java/io/grpc/xds/XdsServerBuilder.java | 45 +- .../java/io/grpc/xds/XdsServerWrapper.java | 444 ++++++ .../internal/sds/SdsProtocolNegotiators.java | 26 +- .../xds/internal/sds/ServerWrapperForXds.java | 368 ----- .../io/grpc/xds/FilterChainMatchTest.java | 941 ------------ ...rChainMatchingProtocolNegotiatorsTest.java | 1318 +++++++++++++++++ .../io/grpc/xds/ServerWrapperForXdsTest.java | 320 ---- .../XdsClientWrapperForServerSdsTestMisc.java | 458 ++++-- .../io/grpc/xds/XdsSdsClientServerTest.java | 175 ++- .../io/grpc/xds/XdsServerBuilderTest.java | 102 +- .../java/io/grpc/xds/XdsServerTestHelper.java | 215 +-- .../io/grpc/xds/XdsServerWrapperTest.java | 456 ++++++ .../sds/SdsProtocolNegotiatorsTest.java | 88 +- 18 files changed, 3207 insertions(+), 2590 deletions(-) create mode 100644 xds/src/main/java/io/grpc/xds/FilterChainMatchingProtocolNegotiators.java delete mode 100644 xds/src/main/java/io/grpc/xds/XdsClientWrapperForServerSds.java create mode 100644 xds/src/main/java/io/grpc/xds/XdsServerWrapper.java delete mode 100644 xds/src/main/java/io/grpc/xds/internal/sds/ServerWrapperForXds.java delete mode 100644 xds/src/test/java/io/grpc/xds/FilterChainMatchTest.java create mode 100644 xds/src/test/java/io/grpc/xds/FilterChainMatchingProtocolNegotiatorsTest.java delete mode 100644 xds/src/test/java/io/grpc/xds/ServerWrapperForXdsTest.java create mode 100644 xds/src/test/java/io/grpc/xds/XdsServerWrapperTest.java diff --git a/xds/build.gradle b/xds/build.gradle index ae8d8d208a9..fdee8fab203 100644 --- a/xds/build.gradle +++ b/xds/build.gradle @@ -97,9 +97,9 @@ javadoc { exclude 'io/grpc/xds/*LoadBalancer*' exclude 'io/grpc/xds/Bootstrapper.java' exclude 'io/grpc/xds/Envoy*' + exclude 'io/grpc/xds/FilterChainMatchingProtocolNegotiators.java' exclude 'io/grpc/xds/TlsContextManager.java' exclude 'io/grpc/xds/XdsAttributes.java' - exclude 'io/grpc/xds/XdsClientWrapperForServerSds.java' exclude 'io/grpc/xds/XdsInitializationException.java' exclude 'io/grpc/xds/XdsNameResolverProvider.java' exclude 'io/grpc/xds/internal/**' diff --git a/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java b/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java index f05e6dd6c9a..aa53d834d3b 100644 --- a/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java +++ b/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java @@ -313,7 +313,7 @@ public String toString() { /** * Corresponds to Envoy proto message {@link io.envoyproxy.envoy.api.v2.listener.FilterChain}. */ - public static final class FilterChain { + static final class FilterChain { // Unique name for the FilterChain. private final String name; // TODO(sanjaypujare): flatten structure by moving FilterChainMatch class members here. diff --git a/xds/src/main/java/io/grpc/xds/FilterChainMatchingProtocolNegotiators.java b/xds/src/main/java/io/grpc/xds/FilterChainMatchingProtocolNegotiators.java new file mode 100644 index 00000000000..34211d79751 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/FilterChainMatchingProtocolNegotiators.java @@ -0,0 +1,372 @@ +/* + * Copyright 2021 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.base.Preconditions.checkNotNull; +import static io.grpc.xds.InternalXdsAttributes.ATTR_FILTER_CHAIN_SELECTOR_REF; +import static io.grpc.xds.internal.sds.SdsProtocolNegotiators.ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Strings; +import com.google.common.collect.Iterables; +import com.google.protobuf.UInt32Value; +import io.grpc.Attributes; +import io.grpc.internal.ObjectPool; +import io.grpc.netty.GrpcHttp2ConnectionHandler; +import io.grpc.netty.InternalProtocolNegotiationEvent; +import io.grpc.netty.InternalProtocolNegotiator; +import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator; +import io.grpc.netty.ProtocolNegotiationEvent; +import io.grpc.xds.EnvoyServerProtoData.CidrRange; +import io.grpc.xds.EnvoyServerProtoData.ConnectionSourceType; +import io.grpc.xds.EnvoyServerProtoData.FilterChain; +import io.grpc.xds.EnvoyServerProtoData.FilterChainMatch; +import io.grpc.xds.FilterChainMatchingProtocolNegotiators.FilterChainMatchingHandler.FilterChainSelector; +import io.grpc.xds.internal.Matchers.CidrMatcher; +import io.grpc.xds.internal.sds.SslContextProviderSupplier; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.util.AsciiString; +import java.net.Inet6Address; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.Executor; +import java.util.concurrent.atomic.AtomicReference; +import java.util.logging.Level; +import java.util.logging.Logger; +import javax.annotation.Nullable; + + +/** + * Handles L4 filter chain match for the connection based on the xds configuration. + * */ +final class FilterChainMatchingProtocolNegotiators { + private static final Logger log = Logger.getLogger( + FilterChainMatchingProtocolNegotiators.class.getName()); + + private static final AsciiString SCHEME = AsciiString.of("http"); + + private FilterChainMatchingProtocolNegotiators() { + } + + @VisibleForTesting + static final class FilterChainMatchingHandler extends ChannelInboundHandlerAdapter { + + private final GrpcHttp2ConnectionHandler grpcHandler; + private final FilterChainSelector selector; + private final ProtocolNegotiator delegate; + + FilterChainMatchingHandler( + GrpcHttp2ConnectionHandler grpcHandler, FilterChainSelector selector, + ProtocolNegotiator delegate) { + this.grpcHandler = checkNotNull(grpcHandler, "grpcHandler"); + this.selector = checkNotNull(selector, "selector"); + this.delegate = checkNotNull(delegate, "delegate"); + } + + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + if (!(evt instanceof ProtocolNegotiationEvent)) { + super.userEventTriggered(ctx, evt); + return; + } + SelectedConfig config; + try { + config = selector.select( + (InetSocketAddress) ctx.channel().localAddress(), + (InetSocketAddress) ctx.channel().remoteAddress()); + } catch (IllegalStateException ex) { + log.log(Level.FINE, "Did not find exactly one filter chain: " + ex.getMessage()); + ctx.fireExceptionCaught(ex); + return; + } + ProtocolNegotiationEvent pne = (ProtocolNegotiationEvent) evt; + Attributes attr = InternalProtocolNegotiationEvent.getAttributes(pne) + .toBuilder().set(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER, + config.sslContextProviderSupplier).build(); + pne = InternalProtocolNegotiationEvent.withAttributes(pne, attr); + ctx.pipeline().replace(this, null, delegate.newHandler(grpcHandler)); + ctx.fireUserEventTriggered(pne); + } + + static final class FilterChainSelector { + public static final FilterChainSelector NO_FILTER_CHAIN = new FilterChainSelector( + Collections.emptyList(), null); + + private final List filterChainList; + @Nullable + private final SslContextProviderSupplier defaultSslContextProviderSupplier; + + FilterChainSelector(List filterChainList, + @Nullable SslContextProviderSupplier defaultSslContextProviderSupplier) { + checkNotNull(filterChainList, "filterChainList"); + this.filterChainList = filterChainList; + this.defaultSslContextProviderSupplier = defaultSslContextProviderSupplier; + } + + @VisibleForTesting + List getFilterChains() { + return filterChainList; + } + + @VisibleForTesting + SslContextProviderSupplier getDefaultSslContextProviderSupplier() { + return defaultSslContextProviderSupplier; + } + + /** + * Throws IllegalStateException when no exact one match, and we should close the connection. + */ + SelectedConfig select(InetSocketAddress localAddr, InetSocketAddress remoteAddr) { + Collection filterChains = new ArrayList<>(filterChainList); + filterChains = filterOnDestinationPort(filterChains); + filterChains = filterOnIpAddress(filterChains, localAddr.getAddress(), true); + filterChains = filterOnServerNames(filterChains); + filterChains = filterOnTransportProtocol(filterChains); + filterChains = filterOnApplicationProtocols(filterChains); + filterChains = + filterOnSourceType(filterChains, remoteAddr.getAddress(), localAddr.getAddress()); + filterChains = filterOnIpAddress(filterChains, remoteAddr.getAddress(), false); + filterChains = filterOnSourcePort(filterChains, remoteAddr.getPort()); + + if (filterChains.size() > 1) { + log.log(Level.FINE, "Found more than one matching filter chains: {0}", filterChains); + throw new IllegalStateException("Found more than one matching filter chains."); + // TODO(chengyuanzhang): should we just return any matched one? + } + if (filterChains.size() == 1) { + FilterChain selected = Iterables.getOnlyElement(filterChains); + return new SelectedConfig(selected.getSslContextProviderSupplier()); + } + if (defaultSslContextProviderSupplier != null) { + return new SelectedConfig(defaultSslContextProviderSupplier); + } + log.log(Level.FINE, "No matching filter chain found."); + throw new IllegalStateException("No matching filter chain found."); + } + + // reject if filer-chain-match has non-empty application_protocols + private static Collection filterOnApplicationProtocols( + Collection filterChains) { + ArrayList filtered = new ArrayList<>(filterChains.size()); + for (FilterChain filterChain : filterChains) { + FilterChainMatch filterChainMatch = filterChain.getFilterChainMatch(); + + if (filterChainMatch.getApplicationProtocols().isEmpty()) { + filtered.add(filterChain); + } + } + return filtered; + } + + // reject if filer-chain-match has non-empty transport protocol other than "raw_buffer" + private static Collection filterOnTransportProtocol( + Collection filterChains) { + ArrayList filtered = new ArrayList<>(filterChains.size()); + for (FilterChain filterChain : filterChains) { + FilterChainMatch filterChainMatch = filterChain.getFilterChainMatch(); + + String transportProtocol = filterChainMatch.getTransportProtocol(); + if (Strings.isNullOrEmpty(transportProtocol) || "raw_buffer".equals(transportProtocol)) { + filtered.add(filterChain); + } + } + return filtered; + } + + // reject if filer-chain-match has server_name(s) + private static Collection filterOnServerNames( + Collection filterChains) { + ArrayList filtered = new ArrayList<>(filterChains.size()); + for (FilterChain filterChain : filterChains) { + FilterChainMatch filterChainMatch = filterChain.getFilterChainMatch(); + + if (filterChainMatch.getServerNames().isEmpty()) { + filtered.add(filterChain); + } + } + return filtered; + } + + // destination_port present => Always fail match + private static Collection filterOnDestinationPort( + Collection filterChains) { + ArrayList filtered = new ArrayList<>(filterChains.size()); + for (FilterChain filterChain : filterChains) { + FilterChainMatch filterChainMatch = filterChain.getFilterChainMatch(); + + if (filterChainMatch.getDestinationPort() + == UInt32Value.getDefaultInstance().getValue()) { + filtered.add(filterChain); + } + } + return filtered; + } + + private static Collection filterOnSourcePort( + Collection filterChains, int sourcePort) { + ArrayList filteredOnMatch = new ArrayList<>(filterChains.size()); + ArrayList filteredOnEmpty = new ArrayList<>(filterChains.size()); + for (FilterChain filterChain : filterChains) { + FilterChainMatch filterChainMatch = filterChain.getFilterChainMatch(); + + List sourcePortsToMatch = filterChainMatch.getSourcePorts(); + if (sourcePortsToMatch.isEmpty()) { + filteredOnEmpty.add(filterChain); + } else if (sourcePortsToMatch.contains(sourcePort)) { + filteredOnMatch.add(filterChain); + } + } + // match against source port is more specific than match against empty list + return filteredOnMatch.isEmpty() ? filteredOnEmpty : filteredOnMatch; + } + + private static Collection filterOnSourceType( + Collection filterChains, InetAddress sourceAddress, + InetAddress destAddress) { + ArrayList filtered = new ArrayList<>(filterChains.size()); + for (FilterChain filterChain : filterChains) { + FilterChainMatch filterChainMatch = filterChain.getFilterChainMatch(); + ConnectionSourceType sourceType = + filterChainMatch.getConnectionSourceType(); + + boolean matching = false; + if (sourceType == ConnectionSourceType.SAME_IP_OR_LOOPBACK) { + matching = + sourceAddress.isLoopbackAddress() + || sourceAddress.isAnyLocalAddress() + || sourceAddress.equals(destAddress); + } else if (sourceType == ConnectionSourceType.EXTERNAL) { + matching = !sourceAddress.isLoopbackAddress() && !sourceAddress.isAnyLocalAddress(); + } else { // ANY or null + matching = true; + } + if (matching) { + filtered.add(filterChain); + } + } + return filtered; + } + + private static int getMatchingPrefixLength( + FilterChainMatch filterChainMatch, InetAddress address, boolean forDestination) { + boolean isIPv6 = address instanceof Inet6Address; + List cidrRanges = + forDestination + ? filterChainMatch.getPrefixRanges() + : filterChainMatch.getSourcePrefixRanges(); + int matchingPrefixLength; + if (cidrRanges.isEmpty()) { // if there is no CidrRange assume 0-length match + matchingPrefixLength = 0; + } else { + matchingPrefixLength = -1; + for (CidrRange cidrRange : cidrRanges) { + InetAddress cidrAddr = cidrRange.getAddressPrefix(); + boolean cidrIsIpv6 = cidrAddr instanceof Inet6Address; + if (isIPv6 == cidrIsIpv6) { + int prefixLen = cidrRange.getPrefixLen(); + CidrMatcher matcher = CidrMatcher.create(cidrAddr, prefixLen); + if (matcher.matches(address) && prefixLen > matchingPrefixLength) { + matchingPrefixLength = prefixLen; + } + } + } + } + return matchingPrefixLength; + } + + // use prefix_ranges (CIDR) and get the most specific matches + private static Collection filterOnIpAddress( + Collection filterChains, InetAddress address, boolean forDestination) { + // curent list of top ones + ArrayList topOnes = new ArrayList<>(filterChains.size()); + int topMatchingPrefixLen = -1; + for (FilterChain filterChain : filterChains) { + int currentMatchingPrefixLen = getMatchingPrefixLength( + filterChain.getFilterChainMatch(), address, forDestination); + + if (currentMatchingPrefixLen >= 0) { + if (currentMatchingPrefixLen < topMatchingPrefixLen) { + continue; + } + if (currentMatchingPrefixLen > topMatchingPrefixLen) { + topMatchingPrefixLen = currentMatchingPrefixLen; + topOnes.clear(); + } + topOnes.add(filterChain); + } + } + return topOnes; + } + } + } + + static final class FilterChainMatchingNegotiatorServerFactory + implements InternalProtocolNegotiator.ServerFactory { + private final InternalProtocolNegotiator.ServerFactory delegate; + + public FilterChainMatchingNegotiatorServerFactory( + InternalProtocolNegotiator.ServerFactory delegate) { + this.delegate = checkNotNull(delegate, "delegate"); + } + + @Override + public ProtocolNegotiator newNegotiator( + final ObjectPool offloadExecutorPool) { + + class FilterChainMatchingNegotiator implements ProtocolNegotiator { + + @Override + public AsciiString scheme() { + return SCHEME; + } + + @Override + public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { + AtomicReference filterChainSelectorRef = + grpcHandler.getEagAttributes().get(ATTR_FILTER_CHAIN_SELECTOR_REF); + checkNotNull(filterChainSelectorRef, "filterChainSelectorRef"); + return new FilterChainMatchingHandler(grpcHandler, filterChainSelectorRef.get(), + delegate.newNegotiator(offloadExecutorPool)); + } + + @Override + public void close() { + } + } + + return new FilterChainMatchingNegotiator(); + } + } + + /** + * The FilterChain level configuration. + */ + private static final class SelectedConfig { + @Nullable + private final SslContextProviderSupplier sslContextProviderSupplier; + + private SelectedConfig(@Nullable SslContextProviderSupplier sslContextProviderSupplier) { + this.sslContextProviderSupplier = sslContextProviderSupplier; + } + } +} diff --git a/xds/src/main/java/io/grpc/xds/InternalXdsAttributes.java b/xds/src/main/java/io/grpc/xds/InternalXdsAttributes.java index efeee0758a3..82eddd355af 100644 --- a/xds/src/main/java/io/grpc/xds/InternalXdsAttributes.java +++ b/xds/src/main/java/io/grpc/xds/InternalXdsAttributes.java @@ -22,8 +22,10 @@ import io.grpc.Internal; import io.grpc.NameResolver; import io.grpc.internal.ObjectPool; +import io.grpc.xds.FilterChainMatchingProtocolNegotiators.FilterChainMatchingHandler.FilterChainSelector; import io.grpc.xds.XdsNameResolverProvider.CallCounterProvider; import io.grpc.xds.internal.sds.SslContextProviderSupplier; +import java.util.concurrent.atomic.AtomicReference; /** * Internal attributes used for xDS implementation. Do not use. @@ -75,5 +77,13 @@ public final class InternalXdsAttributes { static final Attributes.Key ATTR_SERVER_WEIGHT = Attributes.Key.create("io.grpc.xds.InternalXdsAttributes.serverWeight"); + /** + * Filter chain match for network filters. + */ + @Grpc.TransportAttr + static final Attributes.Key> + ATTR_FILTER_CHAIN_SELECTOR_REF = Attributes.Key.create( + "io.grpc.xds.InternalXdsAttributes.filterChainSelectorRef"); + private InternalXdsAttributes() {} } diff --git a/xds/src/main/java/io/grpc/xds/XdsClientWrapperForServerSds.java b/xds/src/main/java/io/grpc/xds/XdsClientWrapperForServerSds.java deleted file mode 100644 index 9a1b659ef1e..00000000000 --- a/xds/src/main/java/io/grpc/xds/XdsClientWrapperForServerSds.java +++ /dev/null @@ -1,455 +0,0 @@ -/* - * Copyright 2020 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.xds; - -import static com.google.common.base.Preconditions.checkNotNull; -import static com.google.common.base.Preconditions.checkState; - -import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Strings; -import com.google.common.collect.ImmutableSet; -import com.google.protobuf.UInt32Value; -import io.grpc.Internal; -import io.grpc.Status; -import io.grpc.internal.ObjectPool; -import io.grpc.xds.EnvoyServerProtoData.CidrRange; -import io.grpc.xds.EnvoyServerProtoData.FilterChain; -import io.grpc.xds.EnvoyServerProtoData.FilterChainMatch; -import io.grpc.xds.internal.Matchers.CidrMatcher; -import io.grpc.xds.internal.sds.SslContextProviderSupplier; -import io.netty.channel.Channel; -import java.net.Inet6Address; -import java.net.InetAddress; -import java.net.InetSocketAddress; -import java.net.SocketAddress; -import java.util.ArrayList; -import java.util.HashSet; -import java.util.List; -import java.util.Set; -import java.util.concurrent.atomic.AtomicReference; -import java.util.logging.Level; -import java.util.logging.Logger; -import javax.annotation.Nullable; - -/** - * Serves as a wrapper for {@link XdsClient} used on the server side by {@link - * XdsServerBuilder}. - */ -@Internal -public final class XdsClientWrapperForServerSds { - private static final Logger logger = - Logger.getLogger(XdsClientWrapperForServerSds.class.getName()); - - private AtomicReference curListener = new AtomicReference<>(); - private ObjectPool xdsClientPool; - private final XdsNameResolverProvider.XdsClientPoolFactory xdsClientPoolFactory; - @Nullable private XdsClient xdsClient; - private final int port; - private XdsClient.LdsResourceWatcher listenerWatcher; - private boolean newServerApi; - private String grpcServerResourceId; - @VisibleForTesting final Set serverWatchers = new HashSet<>(); - - /** - * Creates a {@link XdsClientWrapperForServerSds}. - * - * @param port server's port for which listener config is needed. - */ - XdsClientWrapperForServerSds(int port) { - this(port, SharedXdsClientPoolProvider.getDefaultProvider()); - } - - @VisibleForTesting - XdsClientWrapperForServerSds(int port, - XdsNameResolverProvider.XdsClientPoolFactory xdsClientPoolFactory) { - this.port = port; - this.xdsClientPoolFactory = checkNotNull(xdsClientPoolFactory, "xdsClientPoolFactory"); - } - - @VisibleForTesting XdsClient getXdsClient() { - return xdsClient; - } - - public TlsContextManager getTlsContextManager() { - return xdsClient.getTlsContextManager(); - } - - /** Accepts an XdsClient and starts a watch. */ - @VisibleForTesting - public void start() { - try { - xdsClientPool = xdsClientPoolFactory.getOrCreate(); - } catch (Exception e) { - reportError(e, true); - return; - } - xdsClient = xdsClientPool.getObject(); - this.listenerWatcher = - new XdsClient.LdsResourceWatcher() { - @Override - public void onChanged(XdsClient.LdsUpdate update) { - releaseOldSuppliers(curListener.getAndSet(update.listener())); - reportSuccess(); - } - - @Override - public void onResourceDoesNotExist(String resourceName) { - logger.log(Level.WARNING, "Resource {0} is unavailable", resourceName); - releaseOldSuppliers(curListener.getAndSet(null)); - reportError(Status.NOT_FOUND.asException(), true); - } - - @Override - public void onError(Status error) { - logger.log( - Level.WARNING, "LdsResourceWatcher in XdsClientWrapperForServerSds: {0}", error); - if (isResourceAbsent(error)) { - releaseOldSuppliers(curListener.getAndSet(null)); - reportError(error.asException(), true); - } else { - reportError(error.asException(), false); - } - } - }; - newServerApi = xdsClient.getBootstrapInfo().getServers().get(0).isUseProtocolV3(); - if (!newServerApi) { - reportError( - new XdsInitializationException( - "requires use of xds_v3 in xds bootstrap"), - true); - return; - } - grpcServerResourceId = xdsClient.getBootstrapInfo() - .getServerListenerResourceNameTemplate(); - if (grpcServerResourceId == null) { - reportError( - new XdsInitializationException( - "missing server_listener_resource_name_template value in xds bootstrap"), - true); - return; - } - grpcServerResourceId = grpcServerResourceId.replaceAll("%s", "0.0.0.0:" + port); - xdsClient.watchLdsResource(grpcServerResourceId, listenerWatcher); - } - - // go thru the old listener and release all the old SslContextProviderSupplier - private void releaseOldSuppliers(EnvoyServerProtoData.Listener oldListener) { - if (oldListener != null) { - List filterChains = oldListener.getFilterChains(); - for (FilterChain filterChain : filterChains) { - releaseSupplier(filterChain); - } - releaseSupplier(oldListener.getDefaultFilterChain()); - } - } - - private static void releaseSupplier(FilterChain filterChain) { - if (filterChain != null) { - SslContextProviderSupplier sslContextProviderSupplier = - filterChain.getSslContextProviderSupplier(); - if (sslContextProviderSupplier != null) { - sslContextProviderSupplier.close(); - } - } - } - - /** Whether the throwable indicates our listener resource is absent/deleted. */ - private static boolean isResourceAbsent(Status status) { - Status.Code code = status.getCode(); - switch (code) { - case NOT_FOUND: - case INVALID_ARGUMENT: - case PERMISSION_DENIED: // means resource not available for us - case UNIMPLEMENTED: - case UNAUTHENTICATED: // same as above, resource not available for us - return true; - default: - return false; - } - } - - /** - * Locates the best matching FilterChain to the channel from the current listener and if found - * returns the SslContextProviderSupplier from that FilterChain, else null. - */ - @Nullable - public SslContextProviderSupplier getSslContextProviderSupplier(Channel channel) { - EnvoyServerProtoData.Listener copyListener = curListener.get(); - if (copyListener != null && channel != null) { - SocketAddress localAddress = channel.localAddress(); - SocketAddress remoteAddress = channel.remoteAddress(); - if (localAddress instanceof InetSocketAddress && remoteAddress instanceof InetSocketAddress) { - InetSocketAddress localInetAddr = (InetSocketAddress) localAddress; - InetSocketAddress remoteInetAddr = (InetSocketAddress) remoteAddress; - checkState( - port == localInetAddr.getPort(), - "Channel localAddress port does not match requested listener port"); - return getSslContextProviderSupplier(localInetAddr, remoteInetAddr, copyListener); - } - } - return null; - } - - /** - * Using the logic specified at - * https://www.envoyproxy.io/docs/envoy/latest/api-v2/api/v2/listener/listener_components.proto.html?highlight=filter%20chain#listener-filterchainmatch - * locate a matching filter and return the corresponding SslContextProviderSupplier or else - * return one from default filter chain. - * - * @param localInetAddr dest address of the inbound connection - * @param remoteInetAddr source address of the inbound connection - */ - private static SslContextProviderSupplier getSslContextProviderSupplier( - InetSocketAddress localInetAddr, InetSocketAddress remoteInetAddr, - EnvoyServerProtoData.Listener listener) { - List filterChains = listener.getFilterChains(); - - filterChains = filterOnDestinationPort(filterChains); - filterChains = filterOnIpAddress(filterChains, localInetAddr.getAddress(), true); - filterChains = filterOnServerNames(filterChains); - filterChains = filterOnTransportProtocol(filterChains); - filterChains = filterOnApplicationProtocols(filterChains); - filterChains = - filterOnSourceType(filterChains, remoteInetAddr.getAddress(), localInetAddr.getAddress()); - filterChains = filterOnIpAddress(filterChains, remoteInetAddr.getAddress(), false); - filterChains = filterOnSourcePort(filterChains, remoteInetAddr.getPort()); - - if (filterChains.size() > 1) { - // close the connection - throw new IllegalStateException("Found 2 matching filter-chains"); - } else if (filterChains.size() == 1) { - return filterChains.get(0).getSslContextProviderSupplier(); - } - if (listener.getDefaultFilterChain() == null) { - // close the connection - throw new RuntimeException( - "no matching filter chain. local: " + localInetAddr + " remote: " + remoteInetAddr); - } - return listener.getDefaultFilterChain().getSslContextProviderSupplier(); - } - - // reject if filer-chain-match has non-empty application_protocols - private static List filterOnApplicationProtocols(List filterChains) { - ArrayList filtered = new ArrayList<>(filterChains.size()); - for (FilterChain filterChain : filterChains) { - FilterChainMatch filterChainMatch = filterChain.getFilterChainMatch(); - - if (filterChainMatch.getApplicationProtocols().isEmpty()) { - filtered.add(filterChain); - } - } - return filtered; - } - - // reject if filer-chain-match has non-empty transport protocol other than "raw_buffer" - private static List filterOnTransportProtocol(List filterChains) { - ArrayList filtered = new ArrayList<>(filterChains.size()); - for (FilterChain filterChain : filterChains) { - FilterChainMatch filterChainMatch = filterChain.getFilterChainMatch(); - - String transportProtocol = filterChainMatch.getTransportProtocol(); - if ( Strings.isNullOrEmpty(transportProtocol) || "raw_buffer".equals(transportProtocol)) { - filtered.add(filterChain); - } - } - return filtered; - } - - // reject if filer-chain-match has server_name(s) - private static List filterOnServerNames(List filterChains) { - ArrayList filtered = new ArrayList<>(filterChains.size()); - for (FilterChain filterChain : filterChains) { - FilterChainMatch filterChainMatch = filterChain.getFilterChainMatch(); - - if (filterChainMatch.getServerNames().isEmpty()) { - filtered.add(filterChain); - } - } - return filtered; - } - - // destination_port present => Always fail match - private static List filterOnDestinationPort(List filterChains) { - ArrayList filtered = new ArrayList<>(filterChains.size()); - for (FilterChain filterChain : filterChains) { - FilterChainMatch filterChainMatch = filterChain.getFilterChainMatch(); - - if (filterChainMatch.getDestinationPort() == UInt32Value.getDefaultInstance().getValue()) { - filtered.add(filterChain); - } - } - return filtered; - } - - private static List filterOnSourcePort( - List filterChains, int sourcePort) { - ArrayList filteredOnMatch = new ArrayList<>(filterChains.size()); - ArrayList filteredOnEmpty = new ArrayList<>(filterChains.size()); - for (FilterChain filterChain : filterChains) { - FilterChainMatch filterChainMatch = filterChain.getFilterChainMatch(); - - List sourcePortsToMatch = filterChainMatch.getSourcePorts(); - if (sourcePortsToMatch.isEmpty()) { - filteredOnEmpty.add(filterChain); - } else if (sourcePortsToMatch.contains(sourcePort)) { - filteredOnMatch.add(filterChain); - } - } - // match against source port is more specific than match against empty list - return filteredOnMatch.isEmpty() ? filteredOnEmpty : filteredOnMatch; - } - - private static List filterOnSourceType( - List filterChains, InetAddress sourceAddress, InetAddress destAddress) { - ArrayList filtered = new ArrayList<>(filterChains.size()); - for (FilterChain filterChain : filterChains) { - FilterChainMatch filterChainMatch = filterChain.getFilterChainMatch(); - EnvoyServerProtoData.ConnectionSourceType sourceType = - filterChainMatch.getConnectionSourceType(); - - boolean matching = false; - if (sourceType == EnvoyServerProtoData.ConnectionSourceType.SAME_IP_OR_LOOPBACK) { - matching = - sourceAddress.isLoopbackAddress() - || sourceAddress.isAnyLocalAddress() - || sourceAddress.equals(destAddress); - } else if (sourceType == EnvoyServerProtoData.ConnectionSourceType.EXTERNAL) { - matching = !sourceAddress.isLoopbackAddress() && !sourceAddress.isAnyLocalAddress(); - } else { // ANY or null - matching = true; - } - if (matching) { - filtered.add(filterChain); - } - } - return filtered; - } - - private static int getMatchingPrefixLength( - FilterChainMatch filterChainMatch, InetAddress address, boolean forDestination) { - boolean isIPv6 = address instanceof Inet6Address; - List cidrRanges = - forDestination - ? filterChainMatch.getPrefixRanges() - : filterChainMatch.getSourcePrefixRanges(); - int matchingPrefixLength; - if (cidrRanges.isEmpty()) { // if there is no CidrRange assume 0-length match - matchingPrefixLength = 0; - } else { - matchingPrefixLength = -1; - for (CidrRange cidrRange : cidrRanges) { - InetAddress cidrAddr = cidrRange.getAddressPrefix(); - boolean cidrIsIpv6 = cidrAddr instanceof Inet6Address; - if (isIPv6 == cidrIsIpv6) { - int prefixLen = cidrRange.getPrefixLen(); - CidrMatcher matcher = CidrMatcher.create(cidrAddr, prefixLen); - if (matcher.matches(address) && prefixLen > matchingPrefixLength) { - matchingPrefixLength = prefixLen; - } - } - } - } - return matchingPrefixLength; - } - - // use prefix_ranges (CIDR) and get the most specific matches - private static List filterOnIpAddress( - List filterChains, InetAddress address, boolean forDestination) { - // curent list of top ones - ArrayList topOnes = new ArrayList<>(filterChains.size()); - int topMatchingPrefixLen = -1; - for (FilterChain filterChain : filterChains) { - int currentMatchingPrefixLen = - getMatchingPrefixLength(filterChain.getFilterChainMatch(), address, forDestination); - - if (currentMatchingPrefixLen >= 0) { - if (currentMatchingPrefixLen < topMatchingPrefixLen) { - continue; - } - if (currentMatchingPrefixLen > topMatchingPrefixLen) { - topMatchingPrefixLen = currentMatchingPrefixLen; - topOnes.clear(); - } - topOnes.add(filterChain); - } - } - return topOnes; - } - - /** Adds a {@link ServerWatcher} to the list. */ - public void addServerWatcher(ServerWatcher serverWatcher) { - checkNotNull(serverWatcher, "serverWatcher"); - synchronized (serverWatchers) { - serverWatchers.add(serverWatcher); - } - EnvoyServerProtoData.Listener copyListener = curListener.get(); - if (copyListener != null) { - serverWatcher.onListenerUpdate(); - } - } - - /** Removes a {@link ServerWatcher} from the list. */ - public void removeServerWatcher(ServerWatcher serverWatcher) { - checkNotNull(serverWatcher, "serverWatcher"); - synchronized (serverWatchers) { - serverWatchers.remove(serverWatcher); - } - } - - private Set getServerWatchers() { - synchronized (serverWatchers) { - return ImmutableSet.copyOf(serverWatchers); - } - } - - private void reportError(Throwable throwable, boolean isAbsent) { - for (ServerWatcher watcher : getServerWatchers()) { - watcher.onError(throwable, isAbsent); - } - } - - private void reportSuccess() { - for (ServerWatcher watcher : getServerWatchers()) { - watcher.onListenerUpdate(); - } - } - - @VisibleForTesting - public XdsClient.LdsResourceWatcher getListenerWatcher() { - return listenerWatcher; - } - - /** Watcher interface for the clients of this class. */ - public interface ServerWatcher { - - /** Called to report errors from the control plane including "not found". */ - void onError(Throwable throwable, boolean isAbsent); - - /** Called to report successful receipt of listener config. */ - void onListenerUpdate(); - } - - /** Shutdown this instance and release resources. */ - public void shutdown() { - logger.log(Level.FINER, "Shutdown"); - if (xdsClient != null) { - xdsClient.cancelLdsResourceWatch(grpcServerResourceId, listenerWatcher); - xdsClient = xdsClientPool.returnObject(xdsClient); - } - releaseOldSuppliers(curListener.getAndSet(null)); - } -} diff --git a/xds/src/main/java/io/grpc/xds/XdsServerBuilder.java b/xds/src/main/java/io/grpc/xds/XdsServerBuilder.java index d201c565caa..d0e12caec11 100644 --- a/xds/src/main/java/io/grpc/xds/XdsServerBuilder.java +++ b/xds/src/main/java/io/grpc/xds/XdsServerBuilder.java @@ -18,6 +18,7 @@ import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; +import static io.grpc.xds.InternalXdsAttributes.ATTR_FILTER_CHAIN_SELECTOR_REF; import com.google.common.annotations.VisibleForTesting; import com.google.errorprone.annotations.DoNotCall; @@ -29,10 +30,14 @@ import io.grpc.ServerBuilder; import io.grpc.ServerCredentials; import io.grpc.netty.InternalNettyServerBuilder; +import io.grpc.netty.InternalNettyServerCredentials; +import io.grpc.netty.InternalProtocolNegotiator; import io.grpc.netty.NettyServerBuilder; -import io.grpc.xds.internal.sds.SdsProtocolNegotiators; -import io.grpc.xds.internal.sds.ServerWrapperForXds; +import io.grpc.xds.FilterChainMatchingProtocolNegotiators.FilterChainMatchingHandler.FilterChainSelector; +import io.grpc.xds.FilterChainMatchingProtocolNegotiators.FilterChainMatchingNegotiatorServerFactory; +import io.grpc.xds.XdsNameResolverProvider.XdsClientPoolFactory; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; import java.util.logging.Logger; /** @@ -40,11 +45,12 @@ */ @ExperimentalApi("https://github.com/grpc/grpc-java/issues/7514") public final class XdsServerBuilder extends ForwardingServerBuilder { - private final NettyServerBuilder delegate; private final int port; private XdsServingStatusListener xdsServingStatusListener; private AtomicBoolean isServerBuilt = new AtomicBoolean(false); + private XdsClientPoolFactory xdsClientPoolFactory = + SharedXdsClientPoolProvider.getDefaultProvider(); private XdsServerBuilder(NettyServerBuilder nettyDelegate, int port) { this.delegate = nettyDelegate; @@ -67,37 +73,38 @@ public XdsServerBuilder xdsServingStatusListener( return this; } - /** - * Unsupported call. Users should only use {@link #forPort(int, ServerCredentials)}. - */ @DoNotCall("Unsupported. Use forPort(int, ServerCredentials) instead") public static ServerBuilder forPort(int port) { throw new UnsupportedOperationException( - "Unsupported call - use forPort(int, ServerCredentials)"); + "Unsupported call - use forPort(int, ServerCredentials)"); } /** Creates a gRPC server builder for the given port. */ public static XdsServerBuilder forPort(int port, ServerCredentials serverCredentials) { - NettyServerBuilder nettyDelegate = NettyServerBuilder.forPort(port, serverCredentials); + checkNotNull(serverCredentials, "serverCredentials"); + InternalProtocolNegotiator.ServerFactory originalNegotiatorFactory = + InternalNettyServerCredentials.toNegotiator(serverCredentials); + ServerCredentials wrappedCredentials = InternalNettyServerCredentials.create( + new FilterChainMatchingNegotiatorServerFactory(originalNegotiatorFactory)); + NettyServerBuilder nettyDelegate = NettyServerBuilder.forPort(port, wrappedCredentials); return new XdsServerBuilder(nettyDelegate, port); } @Override public Server build() { - return buildServer(new XdsClientWrapperForServerSds(port)); + checkState(isServerBuilt.compareAndSet(false, true), "Server already built!"); + AtomicReference filterChainSelectorRef = new AtomicReference<>(); + InternalNettyServerBuilder.eagAttributes(delegate, Attributes.newBuilder() + .set(ATTR_FILTER_CHAIN_SELECTOR_REF, filterChainSelectorRef) + .build()); + return new XdsServerWrapper("0.0.0.0:" + port, delegate, xdsServingStatusListener, + filterChainSelectorRef, xdsClientPoolFactory); } - /** - * Creates a Server using the given xdsClient. - */ @VisibleForTesting - ServerWrapperForXds buildServer( - XdsClientWrapperForServerSds xdsClient) { - checkState(isServerBuilt.compareAndSet(false, true), "Server already built!"); - InternalNettyServerBuilder.eagAttributes(delegate, Attributes.newBuilder() - .set(SdsProtocolNegotiators.SERVER_XDS_CLIENT, xdsClient) - .build()); - return new ServerWrapperForXds(delegate, xdsClient, xdsServingStatusListener); + XdsServerBuilder xdsClientPoolFactory(XdsClientPoolFactory xdsClientPoolFactory) { + this.xdsClientPoolFactory = checkNotNull(xdsClientPoolFactory, "xdsClientPoolFactory"); + return this; } /** diff --git a/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java b/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java new file mode 100644 index 00000000000..cc8ae1e282c --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java @@ -0,0 +1,444 @@ +/* + * Copyright 2021 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.util.concurrent.SettableFuture; +import io.grpc.Server; +import io.grpc.ServerBuilder; +import io.grpc.ServerServiceDefinition; +import io.grpc.Status; +import io.grpc.StatusException; +import io.grpc.SynchronizationContext; +import io.grpc.SynchronizationContext.ScheduledHandle; +import io.grpc.internal.GrpcUtil; +import io.grpc.internal.ObjectPool; +import io.grpc.internal.SharedResourceHolder; +import io.grpc.xds.EnvoyServerProtoData.FilterChain; +import io.grpc.xds.FilterChainMatchingProtocolNegotiators.FilterChainMatchingHandler.FilterChainSelector; +import io.grpc.xds.XdsClient.LdsResourceWatcher; +import io.grpc.xds.XdsClient.LdsUpdate; +import io.grpc.xds.XdsNameResolverProvider.XdsClientPoolFactory; +import io.grpc.xds.XdsServerBuilder.XdsServingStatusListener; +import io.grpc.xds.internal.sds.SslContextProviderSupplier; +import java.io.IOException; +import java.net.SocketAddress; +import java.util.ArrayList; +import java.util.Collections; +import java.util.EnumSet; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.logging.Level; +import java.util.logging.Logger; +import javax.annotation.Nullable; + +final class XdsServerWrapper extends Server { + private static final Logger logger = Logger.getLogger(XdsServerWrapper.class.getName()); + + private final SynchronizationContext syncContext = new SynchronizationContext( + new Thread.UncaughtExceptionHandler() { + @Override + public void uncaughtException(Thread t, Throwable e) { + logger.log(Level.SEVERE, "Exception!" + e); + // TODO(chengyuanzhang): implement cleanup. + } + }); + + @VisibleForTesting + static final long RETRY_DELAY_NANOS = TimeUnit.MINUTES.toNanos(1); + private final String listenerAddress; + private final ServerBuilder delegateBuilder; + private boolean sharedTimeService; + private final ScheduledExecutorService timeService; + private final XdsClientPoolFactory xdsClientPoolFactory; + private final XdsServingStatusListener listener; + private final AtomicReference filterChainSelectorRef; + private final AtomicBoolean started = new AtomicBoolean(false); + private final AtomicBoolean shutdown = new AtomicBoolean(false); + private boolean isServing; + private final CountDownLatch internalTerminationLatch = new CountDownLatch(1); + private final SettableFuture initialStartFuture = SettableFuture.create(); + private boolean initialStarted; + private ScheduledHandle restartTimer; + private ObjectPool xdsClientPool; + private XdsClient xdsClient; + private DiscoveryState discoveryState; + private volatile Server delegate; + + XdsServerWrapper( + String listenerAddress, + ServerBuilder delegateBuilder, + XdsServingStatusListener listener, + AtomicReference filterChainSelectorRef, + XdsClientPoolFactory xdsClientPoolFactory) { + this(listenerAddress, delegateBuilder, listener, filterChainSelectorRef, xdsClientPoolFactory, + SharedResourceHolder.get(GrpcUtil.TIMER_SERVICE)); + sharedTimeService = true; + } + + @VisibleForTesting + XdsServerWrapper( + String listenerAddress, + ServerBuilder delegateBuilder, + XdsServingStatusListener listener, + AtomicReference filterChainSelectorRef, + XdsClientPoolFactory xdsClientPoolFactory, + ScheduledExecutorService timeService) { + this.listenerAddress = checkNotNull(listenerAddress, "listenerAddress"); + this.delegateBuilder = checkNotNull(delegateBuilder, "delegateBuilder"); + this.listener = checkNotNull(listener, "listener"); + this.filterChainSelectorRef = checkNotNull(filterChainSelectorRef, "filterChainSelectorRef"); + this.xdsClientPoolFactory = checkNotNull(xdsClientPoolFactory, "xdsClientPoolFactory"); + this.timeService = checkNotNull(timeService, "timeService"); + this.delegate = delegateBuilder.build(); + } + + @Override + public Server start() throws IOException { + checkState(started.compareAndSet(false, true), "Already started"); + syncContext.execute(new Runnable() { + @Override + public void run() { + internalStart(); + } + }); + Exception exception; + try { + exception = initialStartFuture.get(); + } catch (InterruptedException | ExecutionException e) { + throw new RuntimeException(e); + } + if (exception != null) { + throw (exception instanceof IOException) ? (IOException) exception : + new IOException(exception); + } + return this; + } + + private void internalStart() { + try { + xdsClientPool = xdsClientPoolFactory.getOrCreate(); + } catch (Exception e) { + StatusException statusException = Status.UNAVAILABLE.withDescription( + "Failed to initialize xDS").withCause(e).asException(); + listener.onNotServing(statusException); + initialStartFuture.set(statusException); + return; + } + xdsClient = xdsClientPool.getObject(); + // TODO(chengyuanzhang): add an API on XdsClient indicating if it is using v3, don't get + // from bootstrap. + boolean useProtocolV3 = xdsClient.getBootstrapInfo().getServers().get(0).isUseProtocolV3(); + String listenerTemplate = xdsClient.getBootstrapInfo().getServerListenerResourceNameTemplate(); + if (!useProtocolV3 || listenerTemplate == null) { + StatusException statusException = + Status.UNAVAILABLE.withDescription( + "Can only support xDS v3 with listener resource name template").asException(); + listener.onNotServing(statusException); + initialStartFuture.set(statusException); + xdsClient = xdsClientPool.returnObject(xdsClient); + return; + } + discoveryState = new DiscoveryState(listenerTemplate.replaceAll("%s", listenerAddress)); + } + + @Override + public Server shutdown() { + if (!shutdown.compareAndSet(false, true)) { + return this; + } + syncContext.execute(new Runnable() { + @Override + public void run() { + if (!delegate.isShutdown()) { + delegate.shutdown(); + } + internalShutdown(); + } + }); + return this; + } + + @Override + public Server shutdownNow() { + if (!shutdown.compareAndSet(false, true)) { + return this; + } + syncContext.execute(new Runnable() { + @Override + public void run() { + if (!delegate.isShutdown()) { + delegate.shutdownNow(); + } + internalShutdown(); + } + }); + return this; + } + + // Must run in SynchronizationContext + private void internalShutdown() { + logger.log(Level.FINER, "Shutting down XdsServerWrapper"); + if (discoveryState != null) { + discoveryState.shutdown(); + } + if (xdsClient != null) { + xdsClient = xdsClientPool.returnObject(xdsClient); + } + if (restartTimer != null) { + restartTimer.cancel(); + } + if (sharedTimeService) { + SharedResourceHolder.release(GrpcUtil.TIMER_SERVICE, timeService); + } + isServing = false; + internalTerminationLatch.countDown(); + } + + @Override + public boolean isShutdown() { + return shutdown.get(); + } + + @Override + public boolean isTerminated() { + return internalTerminationLatch.getCount() == 0 && delegate.isTerminated(); + } + + @Override + public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException { + long startTime = System.nanoTime(); + if (!internalTerminationLatch.await(timeout, unit)) { + return false; + } + long remainingTime = unit.toNanos(timeout) - (System.nanoTime() - startTime); + return delegate.awaitTermination(remainingTime, TimeUnit.NANOSECONDS); + } + + @Override + public void awaitTermination() throws InterruptedException { + internalTerminationLatch.await(); + delegate.awaitTermination(); + } + + @Override + public int getPort() { + return delegate.getPort(); + } + + @Override + public List getListenSockets() { + return delegate.getListenSockets(); + } + + @Override + public List getServices() { + return delegate.getServices(); + } + + @Override + public List getImmutableServices() { + return delegate.getImmutableServices(); + } + + @Override + public List getMutableServices() { + return delegate.getMutableServices(); + } + + // Must run in SynchronizationContext + private void startDelegateServer() { + if (restartTimer != null && restartTimer.isPending()) { + return; + } + if (isServing) { + return; + } + if (delegate.isShutdown()) { + delegate = delegateBuilder.build(); + } + try { + delegate.start(); + listener.onServing(); + isServing = true; + if (!initialStarted) { + initialStarted = true; + initialStartFuture.set(null); + } + } catch (IOException e) { + logger.log(Level.FINE, "Fail to start delegate server: {0}", e); + if (!initialStarted) { + initialStarted = true; + initialStartFuture.set(e); + } + restartTimer = syncContext.schedule( + new RestartTask(), RETRY_DELAY_NANOS, TimeUnit.NANOSECONDS, timeService); + } + } + + private final class RestartTask implements Runnable { + @Override + public void run() { + startDelegateServer(); + } + } + + private final class DiscoveryState implements LdsResourceWatcher { + private final String resourceName; + // Most recently discovered filter chains. + private List filterChains = new ArrayList<>(); + // Most recently discovered default filter chain. + @Nullable + private FilterChain defaultFilterChain; + private boolean stopped; + + private DiscoveryState(String resourceName) { + this.resourceName = checkNotNull(resourceName, "resourceName"); + xdsClient.watchLdsResource(resourceName, this); + } + + @Override + public void onChanged(final LdsUpdate update) { + syncContext.execute(new Runnable() { + @Override + public void run() { + if (stopped) { + return; + } + checkNotNull(update.listener(), "update"); + filterChains = update.listener().getFilterChains(); + defaultFilterChain = update.listener().getDefaultFilterChain(); + updateSelector(); + } + }); + } + + @Override + public void onResourceDoesNotExist(final String resourceName) { + syncContext.execute(new Runnable() { + @Override + public void run() { + if (stopped) { + return; + } + StatusException statusException = Status.UNAVAILABLE.withDescription( + "Listener " + resourceName + " unavailable").asException(); + handleConfigNotFound(statusException); + } + }); + } + + @Override + public void onError(final Status error) { + syncContext.execute(new Runnable() { + @Override + public void run() { + if (stopped) { + return; + } + boolean isPermanentError = isPermanentError(error); + logger.log(Level.FINE, "{0} error from XdsClient: {1}", + new Object[]{isPermanentError ? "Permanent" : "Transient", error}); + if (isPermanentError) { + handleConfigNotFound(error.asException()); + } else if (!isServing) { + listener.onNotServing(error.asException()); + } + } + }); + } + + private void shutdown() { + stopped = true; + logger.log(Level.FINE, "Stop watching LDS resource {0}", resourceName); + xdsClient.cancelLdsResourceWatch(resourceName, this); + List toRelease = collectSslContextProviderSuppliers(); + filterChainSelectorRef.set(FilterChainSelector.NO_FILTER_CHAIN); + for (SslContextProviderSupplier s: toRelease) { + s.close(); + } + } + + private List collectSslContextProviderSuppliers() { + List toRelease = new ArrayList<>(); + FilterChainSelector selector = filterChainSelectorRef.get(); + if (selector != null) { + for (FilterChain f: selector.getFilterChains()) { + if (f.getSslContextProviderSupplier() != null) { + toRelease.add(f.getSslContextProviderSupplier()); + } + } + SslContextProviderSupplier defaultSupplier = + selector.getDefaultSslContextProviderSupplier(); + if (defaultSupplier != null) { + toRelease.add(defaultSupplier); + } + } + return toRelease; + } + + private void updateSelector() { + List toRelease = collectSslContextProviderSuppliers(); + FilterChainSelector selector = new FilterChainSelector( + Collections.unmodifiableList(filterChains), + defaultFilterChain == null ? null : defaultFilterChain.getSslContextProviderSupplier()); + filterChainSelectorRef.set(selector); + for (SslContextProviderSupplier s: toRelease) { + s.close(); + } + startDelegateServer(); + } + + private void handleConfigNotFound(StatusException exception) { + List toRelease = collectSslContextProviderSuppliers(); + filterChainSelectorRef.set(FilterChainSelector.NO_FILTER_CHAIN); + for (SslContextProviderSupplier s: toRelease) { + s.close(); + } + if (!initialStarted) { + initialStarted = true; + initialStartFuture.set(exception); + } + if (restartTimer != null) { + restartTimer.cancel(); + } + if (!delegate.isShutdown()) { + delegate.shutdown(); // let in-progress calls finish + } + isServing = false; + listener.onNotServing(exception); + } + } + + private static boolean isPermanentError(Status error) { + return EnumSet.of( + Status.Code.INTERNAL, + Status.Code.INVALID_ARGUMENT, + Status.Code.FAILED_PRECONDITION, + Status.Code.PERMISSION_DENIED, + Status.Code.UNAUTHENTICATED) + .contains(error.getCode()); + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/SdsProtocolNegotiators.java b/xds/src/main/java/io/grpc/xds/internal/sds/SdsProtocolNegotiators.java index 37161325746..0128fa53106 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/SdsProtocolNegotiators.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/SdsProtocolNegotiators.java @@ -29,7 +29,6 @@ import io.grpc.netty.InternalProtocolNegotiators; import io.grpc.netty.ProtocolNegotiationEvent; import io.grpc.xds.InternalXdsAttributes; -import io.grpc.xds.XdsClientWrapperForServerSds; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerAdapter; import io.netty.channel.ChannelHandlerContext; @@ -57,10 +56,12 @@ private SdsProtocolNegotiators() { private static final Logger logger = Logger.getLogger(SdsProtocolNegotiators.class.getName()); - public static final Attributes.Key SERVER_XDS_CLIENT - = Attributes.Key.create("serverXdsClient"); private static final AsciiString SCHEME = AsciiString.of("http"); + public static final Attributes.Key + ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER = + Attributes.Key.create("io.grpc.xds.internal.sds.server.sslContextProviderSupplier"); + /** * Returns a {@link InternalProtocolNegotiator.ClientFactory}. * @@ -253,10 +254,7 @@ public AsciiString scheme() { @Override public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { - XdsClientWrapperForServerSds xdsClientWrapperForServerSds = - grpcHandler.getEagAttributes().get(SERVER_XDS_CLIENT); - return new HandlerPickerHandler(grpcHandler, xdsClientWrapperForServerSds, - fallbackProtocolNegotiator); + return new HandlerPickerHandler(grpcHandler, fallbackProtocolNegotiator); } @Override @@ -267,25 +265,21 @@ public void close() {} static final class HandlerPickerHandler extends ChannelInboundHandlerAdapter { private final GrpcHttp2ConnectionHandler grpcHandler; - private final XdsClientWrapperForServerSds xdsClientWrapperForServerSds; @Nullable private final ProtocolNegotiator fallbackProtocolNegotiator; HandlerPickerHandler( GrpcHttp2ConnectionHandler grpcHandler, - @Nullable XdsClientWrapperForServerSds xdsClientWrapperForServerSds, - ProtocolNegotiator fallbackProtocolNegotiator) { + @Nullable ProtocolNegotiator fallbackProtocolNegotiator) { this.grpcHandler = checkNotNull(grpcHandler, "grpcHandler"); - this.xdsClientWrapperForServerSds = xdsClientWrapperForServerSds; this.fallbackProtocolNegotiator = fallbackProtocolNegotiator; } @Override public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { if (evt instanceof ProtocolNegotiationEvent) { - SslContextProviderSupplier sslContextProviderSupplier = - xdsClientWrapperForServerSds == null - ? null - : xdsClientWrapperForServerSds.getSslContextProviderSupplier(ctx.channel()); + ProtocolNegotiationEvent pne = (ProtocolNegotiationEvent)evt; + SslContextProviderSupplier sslContextProviderSupplier = InternalProtocolNegotiationEvent + .getAttributes(pne).get(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER); if (sslContextProviderSupplier == null) { if (fallbackProtocolNegotiator == null) { ctx.fireExceptionCaught(new CertStoreException("No certificate source found!")); @@ -297,7 +291,6 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc this, null, fallbackProtocolNegotiator.newHandler(grpcHandler)); - ProtocolNegotiationEvent pne = InternalProtocolNegotiationEvent.getDefault(); ctx.fireUserEventTriggered(pne); return; } else { @@ -307,7 +300,6 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc null, new ServerSdsHandler( grpcHandler, sslContextProviderSupplier)); - ProtocolNegotiationEvent pne = InternalProtocolNegotiationEvent.getDefault(); ctx.fireUserEventTriggered(pne); return; } diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/ServerWrapperForXds.java b/xds/src/main/java/io/grpc/xds/internal/sds/ServerWrapperForXds.java deleted file mode 100644 index 968c7385499..00000000000 --- a/xds/src/main/java/io/grpc/xds/internal/sds/ServerWrapperForXds.java +++ /dev/null @@ -1,368 +0,0 @@ -/* - * Copyright 2020 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.xds.internal.sds; - -import static com.google.common.base.Preconditions.checkNotNull; -import static com.google.common.base.Preconditions.checkState; - -import com.google.common.annotations.VisibleForTesting; -import com.google.common.util.concurrent.SettableFuture; -import io.grpc.Server; -import io.grpc.ServerBuilder; -import io.grpc.ServerServiceDefinition; -import io.grpc.Status; -import io.grpc.internal.GrpcUtil; -import io.grpc.internal.SharedResourceHolder; -import io.grpc.xds.XdsClientWrapperForServerSds; -import io.grpc.xds.XdsInitializationException; -import io.grpc.xds.XdsServerBuilder; -import java.io.IOException; -import java.net.BindException; -import java.net.SocketAddress; -import java.util.EnumSet; -import java.util.List; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.ScheduledFuture; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicReference; -import javax.annotation.Nullable; - -/** - * Wraps a {@link Server} delegate and {@link XdsClientWrapperForServerSds} and intercepts {@link - * Server#shutdown()} and {@link Server#start()} to shut down and start the - * {@link XdsClientWrapperForServerSds} object. - */ -@VisibleForTesting -public final class ServerWrapperForXds extends Server { - private Server delegate; - private final ServerBuilder delegateBuilder; - private final XdsClientWrapperForServerSds xdsClientWrapperForServerSds; - private XdsServerBuilder.XdsServingStatusListener xdsServingStatusListener; - @Nullable XdsClientWrapperForServerSds.ServerWatcher serverWatcher; - private AtomicBoolean started = new AtomicBoolean(); - private volatile ServingState currentServingState; - private final long delayForRetry; - private final TimeUnit timeUnitForDelayForRetry; - private StartRetryTask startRetryTask; - - @VisibleForTesting public enum ServingState { - // during start() i.e. first start - STARTING, - - // after start (1st or subsequent ones) - STARTED, - - // not serving due to listener deletion - NOT_SERVING, - - // enter serving mode after NOT_SERVING - ENTER_SERVING, - - // shut down - could be due to failure - SHUTDOWN - } - - /** Creates the wrapper object using the delegate passed. */ - public ServerWrapperForXds( - ServerBuilder delegateBuilder, - XdsClientWrapperForServerSds xdsClientWrapperForServerSds, - XdsServerBuilder.XdsServingStatusListener xdsServingStatusListener) { - this( - delegateBuilder, - xdsClientWrapperForServerSds, - xdsServingStatusListener, - 1L, - TimeUnit.MINUTES); - } - - /** Creates the wrapper object using params passed - used for tests. */ - @VisibleForTesting - public ServerWrapperForXds(ServerBuilder delegateBuilder, - XdsClientWrapperForServerSds xdsClientWrapperForServerSds, - XdsServerBuilder.XdsServingStatusListener xdsServingStatusListener, - long delayForRetry, TimeUnit timeUnitForDelayForRetry) { - this.delegateBuilder = checkNotNull(delegateBuilder, "delegateBuilder"); - this.delegate = delegateBuilder.build(); - this.xdsClientWrapperForServerSds = - checkNotNull(xdsClientWrapperForServerSds, "xdsClientWrapperForServerSds"); - this.xdsServingStatusListener = - checkNotNull(xdsServingStatusListener, "xdsServingStatusListener"); - this.delayForRetry = delayForRetry; - this.timeUnitForDelayForRetry = - checkNotNull(timeUnitForDelayForRetry, "timeUnitForDelayForRetry"); - } - - @Override - public Server start() throws IOException { - checkState(started.compareAndSet(false, true), "Already started"); - currentServingState = ServingState.STARTING; - SettableFuture future = addServerWatcher(); - xdsClientWrapperForServerSds.start(); - try { - Throwable throwable = future.get(); - if (throwable != null) { - removeServerWatcher(); - if (throwable instanceof IOException) { - throw (IOException) throwable; - } - throw new IOException(throwable); - } - } catch (InterruptedException | ExecutionException ex) { - removeServerWatcher(); - if (ex instanceof InterruptedException) { - Thread.currentThread().interrupt(); - } - throw new RuntimeException(ex); - } - return this; - } - - @VisibleForTesting public ServingState getCurrentServingState() { - return currentServingState; - } - - private SettableFuture addServerWatcher() { - final SettableFuture future = SettableFuture.create(); - serverWatcher = - new XdsClientWrapperForServerSds.ServerWatcher() { - @Override - public void onError(Throwable throwable, boolean isAbsent) { - synchronized (ServerWrapperForXds.this) { - if (currentServingState == ServingState.SHUTDOWN) { - return; - } else if (currentServingState == ServingState.STARTING) { - // during start - if (isPermanentErrorFromXds(throwable)) { - currentServingState = ServingState.SHUTDOWN; - future.set(throwable); - return; - } - xdsServingStatusListener.onNotServing(throwable); - } else { - // is one of STARTED, NOT_SERVING or ENTER_SERVING - if (isAbsent) { - xdsServingStatusListener.onNotServing(throwable); - if (currentServingState == ServingState.STARTED) { - // shutdown the server - delegate.shutdown(); // let existing calls finish on delegate - currentServingState = ServingState.NOT_SERVING; - } - } - } - } - } - - @Override - public void onListenerUpdate() { - synchronized (ServerWrapperForXds.this) { - if (currentServingState == ServingState.SHUTDOWN) { - return; - } else if (currentServingState == ServingState.STARTING) { - // during start() - try { - delegate.start(); - currentServingState = ServingState.STARTED; - xdsServingStatusListener.onServing(); - future.set(null); - } catch (IOException ioe) { - future.set(ioe); - } - } else if (currentServingState == ServingState.NOT_SERVING) { - // coming out of NOT_SERVING - currentServingState = ServingState.ENTER_SERVING; - startRetryTask = new StartRetryTask(); - startRetryTask.createTask(0L); - } - } - } - }; - xdsClientWrapperForServerSds.addServerWatcher(serverWatcher); - return future; - } - - private final class StartRetryTask implements Runnable { - - ScheduledExecutorService timerService; - AtomicReference> future = new AtomicReference<>(); - - private void createTask(long delay) { - if (timerService == null) { - timerService = SharedResourceHolder.get(GrpcUtil.TIMER_SERVICE); - } - future.set(timerService.schedule(this, delay, timeUnitForDelayForRetry)); - } - - private void rebuildAndRestartServer() { - delegate = delegateBuilder.build(); - try { - delegate = delegate.start(); - currentServingState = ServingState.STARTED; - xdsServingStatusListener.onServing(); - cleanUpStartRetryTask(); - } catch (IOException ioe) { - xdsServingStatusListener.onNotServing(ioe); - if (isRetriableErrorInDelegateStart(ioe)) { - createTask(delayForRetry); - } else { - // permanent failure - currentServingState = ServingState.SHUTDOWN; - cleanUpStartRetryTask(); - } - } - } - - @Override - public void run() { - if (currentServingState == ServingState.SHUTDOWN) { - return; - } else if (currentServingState != ServingState.ENTER_SERVING) { - throw new AssertionError("Wrong state:" + currentServingState); - } - rebuildAndRestartServer(); - } - - private void cleanUpStartRetryTask() { - synchronized (ServerWrapperForXds.this) { - if (timerService != null) { - timerService = SharedResourceHolder.release(GrpcUtil.TIMER_SERVICE, timerService); - } - startRetryTask = null; - } - } - - public void shutdownNow() { - ScheduledFuture oldValue = future.getAndSet(null); - if (oldValue != null) { - oldValue.cancel(true); - } - cleanUpStartRetryTask(); - } - } - - private void removeServerWatcher() { - synchronized (xdsClientWrapperForServerSds) { - if (serverWatcher != null) { - xdsClientWrapperForServerSds.removeServerWatcher(serverWatcher); - serverWatcher = null; - } - } - } - - // if the IOException indicates we can rebuild delegate and retry start... - private static boolean isRetriableErrorInDelegateStart(IOException ioe) { - if (ioe instanceof BindException) { - return true; - } - Throwable cause = ioe.getCause(); - return cause instanceof BindException; - } - - // if the Throwable indicates a permanent error in xDS processing - private static boolean isPermanentErrorFromXds(Throwable throwable) { - if (throwable instanceof XdsInitializationException) { - return true; - } - Status.Code code = Status.fromThrowable(throwable).getCode(); - return EnumSet.of( - Status.Code.INTERNAL, - Status.Code.INVALID_ARGUMENT, - Status.Code.FAILED_PRECONDITION, - Status.Code.PERMISSION_DENIED, - Status.Code.UNAUTHENTICATED) - .contains(code); - } - - private void cleanupStartRetryTaskAndShutdownDelegateAndXdsClient(boolean shutdownNow) { - Server delegateCopy = null; - synchronized (ServerWrapperForXds.this) { - if (startRetryTask != null) { - startRetryTask.shutdownNow(); - } - currentServingState = ServingState.SHUTDOWN; - if (delegate != null && !delegate.isShutdown()) { - delegateCopy = delegate; - } - } - if (delegateCopy != null) { - if (shutdownNow) { - delegateCopy.shutdownNow(); - } else { - delegateCopy.shutdown(); - } - } - xdsClientWrapperForServerSds.shutdown(); - } - - @Override - public Server shutdown() { - cleanupStartRetryTaskAndShutdownDelegateAndXdsClient(false); - return this; - } - - @Override - public Server shutdownNow() { - cleanupStartRetryTaskAndShutdownDelegateAndXdsClient(true); - return this; - } - - @Override - public boolean isShutdown() { - return delegate.isShutdown(); - } - - @Override - public boolean isTerminated() { - return delegate.isTerminated(); - } - - @Override - public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException { - return delegate.awaitTermination(timeout, unit); - } - - @Override - public void awaitTermination() throws InterruptedException { - delegate.awaitTermination(); - } - - @Override - public int getPort() { - return delegate.getPort(); - } - - @Override - public List getListenSockets() { - return delegate.getListenSockets(); - } - - @Override - public List getServices() { - return delegate.getServices(); - } - - @Override - public List getImmutableServices() { - return delegate.getImmutableServices(); - } - - @Override - public List getMutableServices() { - return delegate.getMutableServices(); - } -} diff --git a/xds/src/test/java/io/grpc/xds/FilterChainMatchTest.java b/xds/src/test/java/io/grpc/xds/FilterChainMatchTest.java deleted file mode 100644 index 8ee3e87a242..00000000000 --- a/xds/src/test/java/io/grpc/xds/FilterChainMatchTest.java +++ /dev/null @@ -1,941 +0,0 @@ -/* - * Copyright 2021 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.xds; - -import static com.google.common.truth.Truth.assertThat; -import static org.junit.Assert.fail; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - -import io.grpc.xds.Filter.NamedFilterConfig; -import io.grpc.xds.XdsClient.LdsUpdate; -import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil; -import io.grpc.xds.internal.sds.SslContextProviderSupplier; -import io.netty.channel.Channel; -import java.io.IOException; -import java.net.InetAddress; -import java.net.InetSocketAddress; -import java.net.UnknownHostException; -import java.util.Arrays; -import java.util.Collections; -import org.junit.After; -import org.junit.Before; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; - -/** Tests for {@link XdsClientWrapperForServerSds}. */ -@RunWith(JUnit4.class) -public class FilterChainMatchTest { - - private static final int PORT = 7000; - private static final String LOCAL_IP = "10.1.2.3"; // dest - private static final String REMOTE_IP = "10.4.2.3"; // source - private static final HttpConnectionManager HTTP_CONNECTION_MANAGER = - HttpConnectionManager.forRdsName( - 10L, "route-config", Collections.emptyList()); - - @Mock private Channel channel; - @Mock private TlsContextManager tlsContextManager; - - private XdsClientWrapperForServerSds xdsClientWrapperForServerSds; - private XdsClient.LdsResourceWatcher registeredWatcher; - - @Before - public void setUp() throws IOException { - MockitoAnnotations.initMocks(this); - xdsClientWrapperForServerSds = XdsServerTestHelper - .createXdsClientWrapperForServerSds(PORT, tlsContextManager); - registeredWatcher = - XdsServerTestHelper.startAndGetWatcher(xdsClientWrapperForServerSds); - } - - @After - public void tearDown() { - xdsClientWrapperForServerSds.shutdown(); - } - - private EnvoyServerProtoData.DownstreamTlsContext getDownstreamTlsContext() { - SslContextProviderSupplier sslContextProviderSupplier = - xdsClientWrapperForServerSds.getSslContextProviderSupplier(channel); - if (sslContextProviderSupplier != null) { - EnvoyServerProtoData.BaseTlsContext tlsContext = sslContextProviderSupplier.getTlsContext(); - assertThat(tlsContext).isInstanceOf(EnvoyServerProtoData.DownstreamTlsContext.class); - return (EnvoyServerProtoData.DownstreamTlsContext) tlsContext; - } - return null; - } - - @Test - public void singleFilterChainWithoutAlpn() throws UnknownHostException { - setupChannel(LOCAL_IP, REMOTE_IP, 15000); - EnvoyServerProtoData.FilterChainMatch filterChainMatch = - new EnvoyServerProtoData.FilterChainMatch( - 0, - Arrays.asList(), - Arrays.asList(), - Arrays.asList(), - EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList(), - Arrays.asList(), - null); - EnvoyServerProtoData.DownstreamTlsContext tlsContext = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); - EnvoyServerProtoData.FilterChain filterChain = new EnvoyServerProtoData.FilterChain( - "filter-chain-foo", filterChainMatch, HTTP_CONNECTION_MANAGER, tlsContext, - tlsContextManager); - EnvoyServerProtoData.Listener listener = - new EnvoyServerProtoData.Listener("listener1", LOCAL_IP, Arrays.asList(filterChain), null); - LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(listener); - registeredWatcher.onChanged(listenerUpdate); - EnvoyServerProtoData.DownstreamTlsContext tlsContext1 = getDownstreamTlsContext(); - assertThat(tlsContext1).isSameInstanceAs(tlsContext); - } - - @Test - public void singleFilterChainWithAlpn() throws UnknownHostException { - setupChannel(LOCAL_IP, REMOTE_IP, 15000); - EnvoyServerProtoData.FilterChainMatch filterChainMatch = - new EnvoyServerProtoData.FilterChainMatch( - 0, - Arrays.asList(), - Arrays.asList("managed-mtls"), - Arrays.asList(), - EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList(), - Arrays.asList(), - null); - EnvoyServerProtoData.DownstreamTlsContext tlsContext = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); - EnvoyServerProtoData.FilterChain filterChain = new EnvoyServerProtoData.FilterChain( - "filter-chain-foo", filterChainMatch, HTTP_CONNECTION_MANAGER, tlsContext, - tlsContextManager); - EnvoyServerProtoData.DownstreamTlsContext defaultTlsContext = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); - EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( - "filter-chain-bar", null, HTTP_CONNECTION_MANAGER, defaultTlsContext, - tlsContextManager); - EnvoyServerProtoData.Listener listener = new EnvoyServerProtoData.Listener( - "listener1", LOCAL_IP, Arrays.asList(filterChain), defaultFilterChain); - LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(listener); - registeredWatcher.onChanged(listenerUpdate); - EnvoyServerProtoData.DownstreamTlsContext tlsContext1 = getDownstreamTlsContext(); - assertThat(tlsContext1).isSameInstanceAs(defaultTlsContext); - } - - @Test - public void defaultFilterChain() throws UnknownHostException { - setupChannel(LOCAL_IP, REMOTE_IP, 15000); - EnvoyServerProtoData.DownstreamTlsContext tlsContext = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); - EnvoyServerProtoData.FilterChain filterChain = new EnvoyServerProtoData.FilterChain( - "filter-chain-foo", null, HTTP_CONNECTION_MANAGER, tlsContext, tlsContextManager); - EnvoyServerProtoData.Listener listener = - new EnvoyServerProtoData.Listener( - "listener1", LOCAL_IP, Arrays.asList(), filterChain); - LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(listener); - registeredWatcher.onChanged(listenerUpdate); - EnvoyServerProtoData.DownstreamTlsContext tlsContext1 = getDownstreamTlsContext(); - assertThat(tlsContext1).isSameInstanceAs(tlsContext); - } - - @Test - public void destPortFails_returnDefaultFilterChain() throws UnknownHostException { - setupChannel(LOCAL_IP, REMOTE_IP, 15000); - EnvoyServerProtoData.DownstreamTlsContext tlsContextWithDestPort = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); - EnvoyServerProtoData.FilterChainMatch filterChainMatchWithDestPort = - new EnvoyServerProtoData.FilterChainMatch( - PORT, - Arrays.asList(), - Arrays.asList("managed-mtls"), - Arrays.asList(), - EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList(), - Arrays.asList(), - null); - EnvoyServerProtoData.FilterChain filterChainWithDestPort = - new EnvoyServerProtoData.FilterChain( - "filter-chain-foo", filterChainMatchWithDestPort, HTTP_CONNECTION_MANAGER, - tlsContextWithDestPort, tlsContextManager); - EnvoyServerProtoData.DownstreamTlsContext tlsContextForDefaultFilterChain = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); - EnvoyServerProtoData.FilterChain defaultFilterChain = - new EnvoyServerProtoData.FilterChain( - "filter-chain-bar", null, HTTP_CONNECTION_MANAGER, - tlsContextForDefaultFilterChain, tlsContextManager); - EnvoyServerProtoData.Listener listener = - new EnvoyServerProtoData.Listener( - "listener1", LOCAL_IP, Arrays.asList(filterChainWithDestPort), defaultFilterChain); - LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(listener); - registeredWatcher.onChanged(listenerUpdate); - EnvoyServerProtoData.DownstreamTlsContext tlsContext1 = getDownstreamTlsContext(); - assertThat(tlsContext1).isSameInstanceAs(tlsContextForDefaultFilterChain); - } - - @Test - public void destPrefixRangeMatch() throws UnknownHostException { - setupChannel(LOCAL_IP, REMOTE_IP, 15000); - EnvoyServerProtoData.DownstreamTlsContext tlsContextMatch = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); - EnvoyServerProtoData.FilterChainMatch filterChainMatchWithMatch = - new EnvoyServerProtoData.FilterChainMatch( - 0, - Arrays.asList(new EnvoyServerProtoData.CidrRange("10.1.2.0", 24)), - Arrays.asList(), - Arrays.asList(), - EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList(), - Arrays.asList(), - null); - EnvoyServerProtoData.FilterChain filterChainWithMatch = new EnvoyServerProtoData.FilterChain( - "filter-chain-foo", filterChainMatchWithMatch, HTTP_CONNECTION_MANAGER, - tlsContextMatch, tlsContextManager); - EnvoyServerProtoData.DownstreamTlsContext tlsContextForDefaultFilterChain = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); - EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( - "filter-chain-bar", null, HTTP_CONNECTION_MANAGER, - tlsContextForDefaultFilterChain, tlsContextManager); - EnvoyServerProtoData.Listener listener = - new EnvoyServerProtoData.Listener( - "listener1", LOCAL_IP, Arrays.asList(filterChainWithMatch), defaultFilterChain); - LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(listener); - registeredWatcher.onChanged(listenerUpdate); - EnvoyServerProtoData.DownstreamTlsContext tlsContext1 = getDownstreamTlsContext(); - assertThat(tlsContext1).isSameInstanceAs(tlsContextMatch); - } - - @Test - public void destPrefixRangeMismatch_returnDefaultFilterChain() - throws UnknownHostException { - setupChannel(LOCAL_IP, REMOTE_IP, 15000); - EnvoyServerProtoData.DownstreamTlsContext tlsContextMismatch = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); - // 10.2.2.0/24 doesn't match LOCAL_IP - EnvoyServerProtoData.FilterChainMatch filterChainMatchWithMismatch = - new EnvoyServerProtoData.FilterChainMatch( - 0, - Arrays.asList(new EnvoyServerProtoData.CidrRange("10.2.2.0", 24)), - Arrays.asList(), - Arrays.asList(), - EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList(), - Arrays.asList(), - null); - EnvoyServerProtoData.FilterChain filterChainWithMismatch = - new EnvoyServerProtoData.FilterChain( - "filter-chain-foo", filterChainMatchWithMismatch, HTTP_CONNECTION_MANAGER, - tlsContextMismatch, tlsContextManager); - EnvoyServerProtoData.DownstreamTlsContext tlsContextForDefaultFilterChain = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); - EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( - "filter-chain-bar", null, HTTP_CONNECTION_MANAGER, - tlsContextForDefaultFilterChain, tlsContextManager); - EnvoyServerProtoData.Listener listener = - new EnvoyServerProtoData.Listener( - "listener1", LOCAL_IP, Arrays.asList(filterChainWithMismatch), defaultFilterChain); - LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(listener); - registeredWatcher.onChanged(listenerUpdate); - EnvoyServerProtoData.DownstreamTlsContext tlsContext1 = getDownstreamTlsContext(); - assertThat(tlsContext1).isSameInstanceAs(tlsContextForDefaultFilterChain); - } - - @Test - public void dest0LengthPrefixRange() - throws UnknownHostException { - setupChannel(LOCAL_IP, REMOTE_IP, 15000); - EnvoyServerProtoData.DownstreamTlsContext tlsContext0Length = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); - // 10.2.2.0/24 doesn't match LOCAL_IP - EnvoyServerProtoData.FilterChainMatch filterChainMatch0Length = - new EnvoyServerProtoData.FilterChainMatch( - 0, - Arrays.asList(new EnvoyServerProtoData.CidrRange("10.2.2.0", 0)), - Arrays.asList(), - Arrays.asList(), - EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList(), - Arrays.asList(), - null); - EnvoyServerProtoData.FilterChain filterChain0Length = new EnvoyServerProtoData.FilterChain( - "filter-chain-foo", filterChainMatch0Length, HTTP_CONNECTION_MANAGER, - tlsContext0Length, tlsContextManager); - EnvoyServerProtoData.DownstreamTlsContext tlsContextForDefaultFilterChain = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); - EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( - "filter-chain-bar", null, HTTP_CONNECTION_MANAGER, - tlsContextForDefaultFilterChain, tlsContextManager); - EnvoyServerProtoData.Listener listener = - new EnvoyServerProtoData.Listener( - "listener1", LOCAL_IP, Arrays.asList(filterChain0Length), defaultFilterChain); - LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(listener); - registeredWatcher.onChanged(listenerUpdate); - EnvoyServerProtoData.DownstreamTlsContext tlsContext1 = getDownstreamTlsContext(); - assertThat(tlsContext1).isSameInstanceAs(tlsContext0Length); - } - - @Test - public void destPrefixRange_moreSpecificWins() - throws UnknownHostException { - setupChannel(LOCAL_IP, REMOTE_IP, 15000); - EnvoyServerProtoData.DownstreamTlsContext tlsContextLessSpecific = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); - EnvoyServerProtoData.FilterChainMatch filterChainMatchLessSpecific = - new EnvoyServerProtoData.FilterChainMatch( - 0, - Arrays.asList(new EnvoyServerProtoData.CidrRange("10.1.2.0", 24)), - Arrays.asList(), - Arrays.asList(), - EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList(), - Arrays.asList(), - null); - EnvoyServerProtoData.FilterChain filterChainLessSpecific = - new EnvoyServerProtoData.FilterChain( - "filter-chain-foo", filterChainMatchLessSpecific, HTTP_CONNECTION_MANAGER, - tlsContextLessSpecific, tlsContextManager); - - EnvoyServerProtoData.DownstreamTlsContext tlsContextMoreSpecific = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); - EnvoyServerProtoData.FilterChainMatch filterChainMatchMoreSpecific = - new EnvoyServerProtoData.FilterChainMatch( - 0, - Arrays.asList(new EnvoyServerProtoData.CidrRange("10.1.2.2", 31)), - Arrays.asList(), - Arrays.asList(), - EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList(), - Arrays.asList(), - null); - EnvoyServerProtoData.FilterChain filterChainMoreSpecific = - new EnvoyServerProtoData.FilterChain( - "filter-chain-bar", filterChainMatchMoreSpecific, HTTP_CONNECTION_MANAGER, - tlsContextMoreSpecific, - tlsContextManager); - EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( - "filter-chain-baz", null, HTTP_CONNECTION_MANAGER, null, tlsContextManager); - EnvoyServerProtoData.Listener listener = - new EnvoyServerProtoData.Listener( - "listener1", - LOCAL_IP, - Arrays.asList(filterChainLessSpecific, filterChainMoreSpecific), - defaultFilterChain); - LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(listener); - registeredWatcher.onChanged(listenerUpdate); - EnvoyServerProtoData.DownstreamTlsContext tlsContext1 = getDownstreamTlsContext(); - assertThat(tlsContext1).isSameInstanceAs(tlsContextMoreSpecific); - } - - @Test - public void destPrefixRange_emptyListLessSpecific() - throws UnknownHostException { - setupChannel(LOCAL_IP, REMOTE_IP, 15000); - EnvoyServerProtoData.DownstreamTlsContext tlsContextLessSpecific = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); - EnvoyServerProtoData.FilterChainMatch filterChainMatchLessSpecific = - new EnvoyServerProtoData.FilterChainMatch( - 0, - Arrays.asList(), - Arrays.asList(), - Arrays.asList(), - EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList(), - Arrays.asList(), - null); - EnvoyServerProtoData.FilterChain filterChainLessSpecific = - new EnvoyServerProtoData.FilterChain( - "filter-chain-foo", filterChainMatchLessSpecific, HTTP_CONNECTION_MANAGER, - tlsContextLessSpecific, tlsContextManager); - - EnvoyServerProtoData.DownstreamTlsContext tlsContextMoreSpecific = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); - EnvoyServerProtoData.FilterChainMatch filterChainMatchMoreSpecific = - new EnvoyServerProtoData.FilterChainMatch( - 0, - Arrays.asList(new EnvoyServerProtoData.CidrRange("8.0.0.0", 5)), - Arrays.asList(), - Arrays.asList(), - EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList(), - Arrays.asList(), - null); - EnvoyServerProtoData.FilterChain filterChainMoreSpecific = - new EnvoyServerProtoData.FilterChain( - "filter-chain-bar", filterChainMatchMoreSpecific, HTTP_CONNECTION_MANAGER, - tlsContextMoreSpecific, - tlsContextManager); - EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( - "filter-chain-baz", null, HTTP_CONNECTION_MANAGER, null, tlsContextManager); - EnvoyServerProtoData.Listener listener = - new EnvoyServerProtoData.Listener( - "listener1", - LOCAL_IP, - Arrays.asList(filterChainLessSpecific, filterChainMoreSpecific), - defaultFilterChain); - LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(listener); - registeredWatcher.onChanged(listenerUpdate); - EnvoyServerProtoData.DownstreamTlsContext tlsContext1 = getDownstreamTlsContext(); - assertThat(tlsContext1).isSameInstanceAs(tlsContextMoreSpecific); - } - - @Test - public void destPrefixRangeIpv6_moreSpecificWins() - throws UnknownHostException { - setupChannel("FE80:0000:0000:0000:0202:B3FF:FE1E:8329", "2001:DB8::8:800:200C:417A", 15000); - EnvoyServerProtoData.DownstreamTlsContext tlsContextLessSpecific = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); - EnvoyServerProtoData.FilterChainMatch filterChainMatchLessSpecific = - new EnvoyServerProtoData.FilterChainMatch( - 0, - Arrays.asList(new EnvoyServerProtoData.CidrRange("FE80:0:0:0:0:0:0:0", 60)), - Arrays.asList(), - Arrays.asList(), - EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList(), - Arrays.asList(), - null); - EnvoyServerProtoData.FilterChain filterChainLessSpecific = - new EnvoyServerProtoData.FilterChain( - "filter-chain-foo", filterChainMatchLessSpecific, HTTP_CONNECTION_MANAGER, - tlsContextLessSpecific, tlsContextManager); - - EnvoyServerProtoData.DownstreamTlsContext tlsContextMoreSpecific = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); - EnvoyServerProtoData.FilterChainMatch filterChainMatchMoreSpecific = - new EnvoyServerProtoData.FilterChainMatch( - 0, - Arrays.asList(new EnvoyServerProtoData.CidrRange("FE80:0000:0000:0000:0202:0:0:0", 80)), - Arrays.asList(), - Arrays.asList(), - EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList(), - Arrays.asList(), - null); - EnvoyServerProtoData.FilterChain filterChainMoreSpecific = - new EnvoyServerProtoData.FilterChain( - "filter-chain-bar", filterChainMatchMoreSpecific, HTTP_CONNECTION_MANAGER, - tlsContextMoreSpecific, tlsContextManager); - EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( - "filter-chain-baz", null, HTTP_CONNECTION_MANAGER, null, tlsContextManager); - EnvoyServerProtoData.Listener listener = - new EnvoyServerProtoData.Listener( - "listener1", - "FE80:0000:0000:0000:0202:B3FF:FE1E:8329", - Arrays.asList(filterChainLessSpecific, filterChainMoreSpecific), - defaultFilterChain); - LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(listener); - registeredWatcher.onChanged(listenerUpdate); - EnvoyServerProtoData.DownstreamTlsContext tlsContext1 = getDownstreamTlsContext(); - assertThat(tlsContext1).isSameInstanceAs(tlsContextMoreSpecific); - } - - @Test - public void destPrefixRange_moreSpecificWith2Wins() - throws UnknownHostException { - setupChannel(LOCAL_IP, REMOTE_IP, 15000); - EnvoyServerProtoData.DownstreamTlsContext tlsContextMoreSpecificWith2 = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); - EnvoyServerProtoData.FilterChainMatch filterChainMatchMoreSpecificWith2 = - new EnvoyServerProtoData.FilterChainMatch( - 0, - Arrays.asList( - new EnvoyServerProtoData.CidrRange("10.1.2.0", 24), - new EnvoyServerProtoData.CidrRange(LOCAL_IP, 32)), - Arrays.asList(), - Arrays.asList(), - EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList(), - Arrays.asList(), - null); - EnvoyServerProtoData.FilterChain filterChainMoreSpecificWith2 = - new EnvoyServerProtoData.FilterChain( - "filter-chain-foo", filterChainMatchMoreSpecificWith2, HTTP_CONNECTION_MANAGER, - tlsContextMoreSpecificWith2, tlsContextManager); - - EnvoyServerProtoData.DownstreamTlsContext tlsContextLessSpecific = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); - EnvoyServerProtoData.FilterChainMatch filterChainMatchLessSpecific = - new EnvoyServerProtoData.FilterChainMatch( - 0, - Arrays.asList(new EnvoyServerProtoData.CidrRange("10.1.2.2", 31)), - Arrays.asList(), - Arrays.asList(), - EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList(), - Arrays.asList(), - null); - EnvoyServerProtoData.FilterChain filterChainLessSpecific = - new EnvoyServerProtoData.FilterChain( - "filter-chain-bar", filterChainMatchLessSpecific, HTTP_CONNECTION_MANAGER, - tlsContextLessSpecific, tlsContextManager); - EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( - "filter-chain-baz", null, HTTP_CONNECTION_MANAGER, null, tlsContextManager); - EnvoyServerProtoData.Listener listener = - new EnvoyServerProtoData.Listener( - "listener1", - LOCAL_IP, - Arrays.asList(filterChainMoreSpecificWith2, filterChainLessSpecific), - defaultFilterChain); - LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(listener); - registeredWatcher.onChanged(listenerUpdate); - EnvoyServerProtoData.DownstreamTlsContext tlsContext1 = getDownstreamTlsContext(); - assertThat(tlsContext1).isSameInstanceAs(tlsContextMoreSpecificWith2); - } - - @Test - public void sourceTypeMismatch_returnDefaultFilterChain() throws UnknownHostException { - setupChannel(LOCAL_IP, REMOTE_IP, 15000); - EnvoyServerProtoData.DownstreamTlsContext tlsContextMismatch = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); - EnvoyServerProtoData.FilterChainMatch filterChainMatchWithMismatch = - new EnvoyServerProtoData.FilterChainMatch( - 0, - Arrays.asList(), - Arrays.asList(), - Arrays.asList(), - EnvoyServerProtoData.ConnectionSourceType.SAME_IP_OR_LOOPBACK, - Arrays.asList(), - Arrays.asList(), - null); - EnvoyServerProtoData.FilterChain filterChainWithMismatch = - new EnvoyServerProtoData.FilterChain( - "filter-chain-foo", filterChainMatchWithMismatch, HTTP_CONNECTION_MANAGER, - tlsContextMismatch, tlsContextManager); - EnvoyServerProtoData.DownstreamTlsContext tlsContextForDefaultFilterChain = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); - EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( - "filter-chain-bar", null, HTTP_CONNECTION_MANAGER,tlsContextForDefaultFilterChain, - tlsContextManager); - EnvoyServerProtoData.Listener listener = - new EnvoyServerProtoData.Listener( - "listener1", LOCAL_IP, Arrays.asList(filterChainWithMismatch), defaultFilterChain); - LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(listener); - registeredWatcher.onChanged(listenerUpdate); - EnvoyServerProtoData.DownstreamTlsContext tlsContext1 = getDownstreamTlsContext(); - assertThat(tlsContext1).isSameInstanceAs(tlsContextForDefaultFilterChain); - } - - @Test - public void sourceTypeLocal() throws UnknownHostException { - setupChannel(LOCAL_IP, LOCAL_IP, 15000); - EnvoyServerProtoData.DownstreamTlsContext tlsContextMatch = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); - EnvoyServerProtoData.FilterChainMatch filterChainMatchWithMatch = - new EnvoyServerProtoData.FilterChainMatch( - 0, - Arrays.asList(), - Arrays.asList(), - Arrays.asList(), - EnvoyServerProtoData.ConnectionSourceType.SAME_IP_OR_LOOPBACK, - Arrays.asList(), - Arrays.asList(), - null); - EnvoyServerProtoData.FilterChain filterChainWithMatch = new EnvoyServerProtoData.FilterChain( - "filter-chain-foo", filterChainMatchWithMatch, HTTP_CONNECTION_MANAGER, tlsContextMatch, - tlsContextManager); - EnvoyServerProtoData.DownstreamTlsContext tlsContextForDefaultFilterChain = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); - EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( - "filter-chain-bar", null, HTTP_CONNECTION_MANAGER, tlsContextForDefaultFilterChain, - tlsContextManager); - EnvoyServerProtoData.Listener listener = - new EnvoyServerProtoData.Listener( - "listener1", LOCAL_IP, Arrays.asList(filterChainWithMatch), defaultFilterChain); - LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(listener); - registeredWatcher.onChanged(listenerUpdate); - EnvoyServerProtoData.DownstreamTlsContext tlsContext1 = getDownstreamTlsContext(); - assertThat(tlsContext1).isSameInstanceAs(tlsContextMatch); - } - - @Test - public void sourcePrefixRange_moreSpecificWith2Wins() - throws UnknownHostException { - setupChannel(LOCAL_IP, REMOTE_IP, 15000); - EnvoyServerProtoData.DownstreamTlsContext tlsContextMoreSpecificWith2 = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); - EnvoyServerProtoData.FilterChainMatch filterChainMatchMoreSpecificWith2 = - new EnvoyServerProtoData.FilterChainMatch( - 0, - Arrays.asList(), - Arrays.asList(), - Arrays.asList( - new EnvoyServerProtoData.CidrRange("10.4.2.0", 24), - new EnvoyServerProtoData.CidrRange(REMOTE_IP, 32)), - EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList(), - Arrays.asList(), - null); - EnvoyServerProtoData.FilterChain filterChainMoreSpecificWith2 = - new EnvoyServerProtoData.FilterChain( - "filter-chain-foo", filterChainMatchMoreSpecificWith2, HTTP_CONNECTION_MANAGER, - tlsContextMoreSpecificWith2, tlsContextManager); - - EnvoyServerProtoData.DownstreamTlsContext tlsContextLessSpecific = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); - EnvoyServerProtoData.FilterChainMatch filterChainMatchLessSpecific = - new EnvoyServerProtoData.FilterChainMatch( - 0, - Arrays.asList(), - Arrays.asList(), - Arrays.asList(new EnvoyServerProtoData.CidrRange("10.4.2.2", 31)), - EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList(), - Arrays.asList(), - null); - EnvoyServerProtoData.FilterChain filterChainLessSpecific = - new EnvoyServerProtoData.FilterChain( - "filter-chain-bar", filterChainMatchLessSpecific, HTTP_CONNECTION_MANAGER, - tlsContextLessSpecific, tlsContextManager); - EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( - "filter-chain-baz", null, HTTP_CONNECTION_MANAGER, null, tlsContextManager); - EnvoyServerProtoData.Listener listener = - new EnvoyServerProtoData.Listener( - "listener1", - LOCAL_IP, - Arrays.asList(filterChainMoreSpecificWith2, filterChainLessSpecific), - defaultFilterChain); - LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(listener); - registeredWatcher.onChanged(listenerUpdate); - EnvoyServerProtoData.DownstreamTlsContext tlsContext1 = getDownstreamTlsContext(); - assertThat(tlsContext1).isSameInstanceAs(tlsContextMoreSpecificWith2); - } - - @Test - public void sourcePrefixRange_2Matchers_expectException() - throws UnknownHostException { - setupChannel(LOCAL_IP, REMOTE_IP, 15000); - EnvoyServerProtoData.DownstreamTlsContext tlsContext1 = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); - EnvoyServerProtoData.FilterChainMatch filterChainMatch1 = - new EnvoyServerProtoData.FilterChainMatch( - 0, - Arrays.asList(), - Arrays.asList(), - Arrays.asList( - new EnvoyServerProtoData.CidrRange("10.4.2.0", 24), - new EnvoyServerProtoData.CidrRange("192.168.10.2", 32)), - EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList(), - Arrays.asList(), - null); - EnvoyServerProtoData.FilterChain filterChain1 = new EnvoyServerProtoData.FilterChain( - "filter-chain-foo", filterChainMatch1, HTTP_CONNECTION_MANAGER, tlsContext1, - tlsContextManager); - - EnvoyServerProtoData.DownstreamTlsContext tlsContext2 = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); - EnvoyServerProtoData.FilterChainMatch filterChainMatch2 = - new EnvoyServerProtoData.FilterChainMatch( - 0, - Arrays.asList(), - Arrays.asList(), - Arrays.asList(new EnvoyServerProtoData.CidrRange("10.4.2.0", 24)), - EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList(), - Arrays.asList(), - null); - EnvoyServerProtoData.FilterChain filterChain2 = new EnvoyServerProtoData.FilterChain( - "filter-chain-bar", filterChainMatch2, HTTP_CONNECTION_MANAGER, tlsContext2, - tlsContextManager); - EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( - "filter-chain-baz", null, HTTP_CONNECTION_MANAGER, null, null); - EnvoyServerProtoData.Listener listener = - new EnvoyServerProtoData.Listener( - "listener1", LOCAL_IP, Arrays.asList(filterChain1, filterChain2), defaultFilterChain); - LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(listener); - registeredWatcher.onChanged(listenerUpdate); - try { - xdsClientWrapperForServerSds.getSslContextProviderSupplier(channel); - fail("expect exception!"); - } catch (IllegalStateException ise) { - assertThat(ise).hasMessageThat().isEqualTo("Found 2 matching filter-chains"); - } - } - - @Test - public void sourcePortMatch_exactMatchWinsOverEmptyList() - throws UnknownHostException { - setupChannel(LOCAL_IP, REMOTE_IP, 15000); - EnvoyServerProtoData.DownstreamTlsContext tlsContextEmptySourcePorts = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); - EnvoyServerProtoData.FilterChainMatch filterChainMatchEmptySourcePorts = - new EnvoyServerProtoData.FilterChainMatch( - 0, - Arrays.asList(), - Arrays.asList(), - Arrays.asList( - new EnvoyServerProtoData.CidrRange("10.4.2.0", 24), - new EnvoyServerProtoData.CidrRange("10.4.2.2", 31)), - EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList(), - Arrays.asList(), - null); - EnvoyServerProtoData.FilterChain filterChainEmptySourcePorts = - new EnvoyServerProtoData.FilterChain( - "filter-chain-foo", filterChainMatchEmptySourcePorts, HTTP_CONNECTION_MANAGER, - tlsContextEmptySourcePorts, tlsContextManager); - - EnvoyServerProtoData.DownstreamTlsContext tlsContextSourcePortMatch = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); - EnvoyServerProtoData.FilterChainMatch filterChainMatchSourcePortMatch = - new EnvoyServerProtoData.FilterChainMatch( - 0, - Arrays.asList(), - Arrays.asList(), - Arrays.asList(new EnvoyServerProtoData.CidrRange("10.4.2.2", 31)), - EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList(7000, 15000), - Arrays.asList(), - null); - EnvoyServerProtoData.FilterChain filterChainSourcePortMatch = - new EnvoyServerProtoData.FilterChain( - "filter-chain-bar", filterChainMatchSourcePortMatch, HTTP_CONNECTION_MANAGER, - tlsContextSourcePortMatch, tlsContextManager); - EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( - "filter-chain-baz", null, HTTP_CONNECTION_MANAGER, null, tlsContextManager); - EnvoyServerProtoData.Listener listener = - new EnvoyServerProtoData.Listener( - "listener1", - LOCAL_IP, - Arrays.asList(filterChainEmptySourcePorts, filterChainSourcePortMatch), - defaultFilterChain); - LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(listener); - registeredWatcher.onChanged(listenerUpdate); - EnvoyServerProtoData.DownstreamTlsContext tlsContext1 = getDownstreamTlsContext(); - assertThat(tlsContext1).isSameInstanceAs(tlsContextSourcePortMatch); - } - - /** - * Create 6 filterChains: - 1st filter chain has dest port & specific prefix range but is - * eliminated due to dest port - 5 advance to next step: 1 is eliminated due to being less - * specific than the remaining 4. - 4 advance to 3rd step: source type external eliminates one - * with local source_type. - 3 advance to 4th step: more specific 2 get picked based on - * source-prefix range. - 5th step: out of 2 one with matching source port gets picked - */ - @Test - public void filterChain_5stepMatch() throws UnknownHostException { - setupChannel(LOCAL_IP, REMOTE_IP, 15000); - EnvoyServerProtoData.DownstreamTlsContext tlsContext1 = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); - EnvoyServerProtoData.DownstreamTlsContext tlsContext2 = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); - EnvoyServerProtoData.DownstreamTlsContext tlsContext3 = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT3", "VA3"); - EnvoyServerProtoData.DownstreamTlsContext tlsContext4 = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT4", "VA4"); - EnvoyServerProtoData.DownstreamTlsContext tlsContext5 = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT5", "VA5"); - EnvoyServerProtoData.DownstreamTlsContext tlsContext6 = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT6", "VA6"); - - // has dest port and specific prefix ranges: gets eliminated in step 1 - EnvoyServerProtoData.FilterChainMatch filterChainMatch1 = - new EnvoyServerProtoData.FilterChainMatch( - PORT, - Arrays.asList(), - Arrays.asList(), - Arrays.asList(new EnvoyServerProtoData.CidrRange(REMOTE_IP, 32)), - EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList(), - Arrays.asList(), - null); - EnvoyServerProtoData.FilterChain filterChain1 = new EnvoyServerProtoData.FilterChain( - "filter-chain-1", filterChainMatch1, HTTP_CONNECTION_MANAGER, tlsContext1, - tlsContextManager); - - // next 5 use prefix range: 4 with prefixLen of 30 and last one with 29 - - // has single prefix range: and less specific source prefix range: gets eliminated in step 4 - EnvoyServerProtoData.FilterChainMatch filterChainMatch2 = - new EnvoyServerProtoData.FilterChainMatch( - 0, - Arrays.asList(new EnvoyServerProtoData.CidrRange("10.1.2.0", 30)), - Arrays.asList(), - Arrays.asList(new EnvoyServerProtoData.CidrRange("10.4.0.0", 16)), - EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList(), - Arrays.asList(), - null); - EnvoyServerProtoData.FilterChain filterChain2 = new EnvoyServerProtoData.FilterChain( - "filter-chain-2", filterChainMatch2, HTTP_CONNECTION_MANAGER, tlsContext2, - tlsContextManager); - - // has prefix ranges with one not matching and source type local: gets eliminated in step 3 - EnvoyServerProtoData.FilterChainMatch filterChainMatch3 = - new EnvoyServerProtoData.FilterChainMatch( - 0, - Arrays.asList( - new EnvoyServerProtoData.CidrRange("192.168.2.0", 24), - new EnvoyServerProtoData.CidrRange("10.1.2.0", 30)), - Arrays.asList(), - Arrays.asList(), - EnvoyServerProtoData.ConnectionSourceType.SAME_IP_OR_LOOPBACK, - Arrays.asList(), - Arrays.asList(), - null); - EnvoyServerProtoData.FilterChain filterChain3 = new EnvoyServerProtoData.FilterChain( - "filter-chain-3", filterChainMatch3, HTTP_CONNECTION_MANAGER, tlsContext3, - tlsContextManager); - - // has prefix ranges with both matching and source type external but non matching source port: - // gets eliminated in step 5 - EnvoyServerProtoData.FilterChainMatch filterChainMatch4 = - new EnvoyServerProtoData.FilterChainMatch( - 0, - Arrays.asList( - new EnvoyServerProtoData.CidrRange("10.1.0.0", 16), - new EnvoyServerProtoData.CidrRange("10.1.2.0", 30)), - Arrays.asList(), - Arrays.asList(new EnvoyServerProtoData.CidrRange("10.4.2.0", 24)), - EnvoyServerProtoData.ConnectionSourceType.EXTERNAL, - Arrays.asList(16000, 9000), - Arrays.asList(), - null); - EnvoyServerProtoData.FilterChain filterChain4 = - new EnvoyServerProtoData.FilterChain( - "filter-chain-4", filterChainMatch4, HTTP_CONNECTION_MANAGER, tlsContext4, - tlsContextManager); - - // has prefix ranges with both matching and source type external and matching source port: this - // gets selected - EnvoyServerProtoData.FilterChainMatch filterChainMatch5 = - new EnvoyServerProtoData.FilterChainMatch( - 0, - Arrays.asList( - new EnvoyServerProtoData.CidrRange("10.1.0.0", 16), - new EnvoyServerProtoData.CidrRange("10.1.2.0", 30)), - Arrays.asList(), - Arrays.asList( - new EnvoyServerProtoData.CidrRange("10.4.2.0", 24), - new EnvoyServerProtoData.CidrRange("192.168.2.0", 24)), - EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList(15000, 8000), - Arrays.asList(), - null); - EnvoyServerProtoData.FilterChain filterChain5 = - new EnvoyServerProtoData.FilterChain( - "filter-chain-5", filterChainMatch5, HTTP_CONNECTION_MANAGER, tlsContext5, - tlsContextManager); - - // has prefix range with prefixLen of 29: gets eliminated in step 2 - EnvoyServerProtoData.FilterChainMatch filterChainMatch6 = - new EnvoyServerProtoData.FilterChainMatch( - 0, - Arrays.asList(new EnvoyServerProtoData.CidrRange("10.1.2.0", 29)), - Arrays.asList(), - Arrays.asList(), - EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList(), - Arrays.asList(), - null); - EnvoyServerProtoData.FilterChain filterChain6 = - new EnvoyServerProtoData.FilterChain( - "filter-chain-6", filterChainMatch6, HTTP_CONNECTION_MANAGER, tlsContext6, - tlsContextManager); - - EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( - "filter-chain-7", null, HTTP_CONNECTION_MANAGER, null, tlsContextManager); - EnvoyServerProtoData.Listener listener = - new EnvoyServerProtoData.Listener( - "listener1", - LOCAL_IP, - Arrays.asList( - filterChain1, filterChain2, filterChain3, filterChain4, filterChain5, filterChain6), - defaultFilterChain); - LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(listener); - registeredWatcher.onChanged(listenerUpdate); - EnvoyServerProtoData.DownstreamTlsContext tlsContextPicked = getDownstreamTlsContext(); - assertThat(tlsContextPicked).isSameInstanceAs(tlsContext5); - } - - @Test - public void filterChainMatch_unsupportedMatchers() - throws UnknownHostException { - setupChannel(LOCAL_IP, REMOTE_IP, 15000); - EnvoyServerProtoData.DownstreamTlsContext tlsContext1 = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "ROOTCA"); - EnvoyServerProtoData.DownstreamTlsContext tlsContext2 = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "ROOTCA"); - EnvoyServerProtoData.DownstreamTlsContext tlsContext3 = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT3", "ROOTCA"); - - EnvoyServerProtoData.FilterChainMatch filterChainMatch1 = - new EnvoyServerProtoData.FilterChainMatch( - 0 /* destinationPort */, - Collections.singletonList( - new EnvoyServerProtoData.CidrRange("10.1.0.0", 16)) /* prefixRange */, - Arrays.asList("managed-mtls", "h2") /* applicationProtocol */, - Collections.emptyList() /* sourcePrefixRanges */, - EnvoyServerProtoData.ConnectionSourceType.ANY /* sourceType */, - Collections.emptyList() /* sourcePorts */, - Arrays.asList("server1", "server2") /* serverNames */, - "tls" /* transportProtocol */); - - EnvoyServerProtoData.FilterChainMatch filterChainMatch2 = - new EnvoyServerProtoData.FilterChainMatch( - 0 /* destinationPort */, - Collections.singletonList( - new EnvoyServerProtoData.CidrRange("10.0.0.0", 8)) /* prefixRange */, - Collections.emptyList() /* applicationProtocol */, - Collections.emptyList() /* sourcePrefixRanges */, - EnvoyServerProtoData.ConnectionSourceType.ANY /* sourceType */, - Collections.emptyList() /* sourcePorts */, - Collections.emptyList() /* serverNames */, - "" /* transportProtocol */); - - EnvoyServerProtoData.FilterChainMatch defaultFilterChainMatch = - new EnvoyServerProtoData.FilterChainMatch( - 0 /* destinationPort */, - Collections.emptyList() /* prefixRange */, - Collections.emptyList() /* applicationProtocol */, - Collections.emptyList() /* sourcePrefixRanges */, - EnvoyServerProtoData.ConnectionSourceType.ANY /* sourceType */, - Collections.emptyList() /* sourcePorts */, - Collections.emptyList() /* serverNames */, - "" /* transportProtocol */); - - EnvoyServerProtoData.FilterChain filterChain1 = new EnvoyServerProtoData.FilterChain( - "filter-chain-foo", filterChainMatch1, HTTP_CONNECTION_MANAGER, tlsContext1, - mock(TlsContextManager.class)); - EnvoyServerProtoData.FilterChain filterChain2 = new EnvoyServerProtoData.FilterChain( - "filter-chain-bar", filterChainMatch2, HTTP_CONNECTION_MANAGER, tlsContext2, - mock(TlsContextManager.class)); - - EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( - "filter-chain-baz", defaultFilterChainMatch, HTTP_CONNECTION_MANAGER, tlsContext3, - mock(TlsContextManager.class)); - - EnvoyServerProtoData.Listener listener = new EnvoyServerProtoData.Listener( - "", "10.2.1.34:8000", Arrays.asList(filterChain1, filterChain2), defaultFilterChain); - LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(listener); - registeredWatcher.onChanged(listenerUpdate); - EnvoyServerProtoData.DownstreamTlsContext tlsContextPicked = getDownstreamTlsContext(); - // assert defaultFilterChain match - assertThat(tlsContextPicked.getCommonTlsContext().getTlsCertificateCertificateProviderInstance() - .getCertificateName()).isEqualTo("CERT3"); - } - - private void setupChannel(String localIp, String remoteIp, int remotePort) - throws UnknownHostException { - when(channel.localAddress()) - .thenReturn(new InetSocketAddress(InetAddress.getByName(localIp), PORT)); - when(channel.remoteAddress()) - .thenReturn(new InetSocketAddress(InetAddress.getByName(remoteIp), remotePort)); - } -} diff --git a/xds/src/test/java/io/grpc/xds/FilterChainMatchingProtocolNegotiatorsTest.java b/xds/src/test/java/io/grpc/xds/FilterChainMatchingProtocolNegotiatorsTest.java new file mode 100644 index 00000000000..c926acc9fcc --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/FilterChainMatchingProtocolNegotiatorsTest.java @@ -0,0 +1,1318 @@ +/* + * Copyright 2021 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.truth.Truth.assertThat; +import static io.grpc.xds.internal.sds.SdsProtocolNegotiators.ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER; +import static org.junit.Assert.fail; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.google.common.util.concurrent.SettableFuture; +import io.grpc.internal.TestUtils.NoopChannelLogger; +import io.grpc.netty.GrpcHttp2ConnectionHandler; +import io.grpc.netty.InternalProtocolNegotiationEvent; +import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator; +import io.grpc.netty.ProtocolNegotiationEvent; +import io.grpc.xds.EnvoyServerProtoData.CidrRange; +import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; +import io.grpc.xds.EnvoyServerProtoData.FilterChain; +import io.grpc.xds.EnvoyServerProtoData.FilterChainMatch; +import io.grpc.xds.Filter.NamedFilterConfig; +import io.grpc.xds.FilterChainMatchingProtocolNegotiators.FilterChainMatchingHandler; +import io.grpc.xds.FilterChainMatchingProtocolNegotiators.FilterChainMatchingHandler.FilterChainSelector; +import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil; +import io.grpc.xds.internal.sds.SslContextProviderSupplier; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.ChannelPromise; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.http2.DefaultHttp2Connection; +import io.netty.handler.codec.http2.DefaultHttp2ConnectionDecoder; +import io.netty.handler.codec.http2.DefaultHttp2ConnectionEncoder; +import io.netty.handler.codec.http2.DefaultHttp2FrameReader; +import io.netty.handler.codec.http2.DefaultHttp2FrameWriter; +import io.netty.handler.codec.http2.Http2ConnectionDecoder; +import io.netty.handler.codec.http2.Http2ConnectionEncoder; +import io.netty.handler.codec.http2.Http2Settings; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.net.UnknownHostException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +@RunWith(JUnit4.class) +public class FilterChainMatchingProtocolNegotiatorsTest { + @Rule + public final MockitoRule mocks = MockitoJUnit.rule(); + + private final GrpcHttp2ConnectionHandler grpcHandler = + FakeGrpcHttp2ConnectionHandler.newHandler(); + @Mock private TlsContextManager tlsContextManager; + private ProtocolNegotiationEvent event = InternalProtocolNegotiationEvent.getDefault(); + private ChannelPipeline pipeline; + private EmbeddedChannel channel; + private ChannelHandlerContext channelHandlerCtx; + @Mock + private ProtocolNegotiator mockDelegate; + private final SettableFuture sslSet = SettableFuture.create(); + private static final HttpConnectionManager HTTP_CONNECTION_MANAGER = createRds("routing-config"); + private static final String LOCAL_IP = "10.1.2.3"; // dest + private static final String REMOTE_IP = "10.4.2.3"; // source + private static final int PORT = 7000; + + @Test + public void filterChainMatch() throws Exception { + ChannelHandler next = new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + ProtocolNegotiationEvent e = (ProtocolNegotiationEvent)evt; + sslSet.set(InternalProtocolNegotiationEvent.getAttributes(e) + .get(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER)); + } + }; + when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); + FilterChain f0 = createFilterChain("filter-chain-0", createRds("r0")); + SslContextProviderSupplier defaultSsl = new SslContextProviderSupplier(createTls(), + mock(TlsContextManager.class)); + FilterChainSelector selector = new FilterChainSelector(Collections.singletonList(f0), + defaultSsl); + FilterChainMatchingHandler filterChainMatchingHandler = + new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); + setupChannel("172.168.1.1", "172.168.1.2", 80, filterChainMatchingHandler); + ChannelHandlerContext channelHandlerCtx = pipeline.context(filterChainMatchingHandler); + assertThat(channelHandlerCtx).isNotNull(); + + pipeline.fireUserEventTriggered(event); + channelHandlerCtx = pipeline.context(filterChainMatchingHandler); + assertThat(channelHandlerCtx).isNull(); + channel.runPendingTasks(); + assertThat(sslSet.isDone()).isTrue(); + assertThat(sslSet.get()).isEqualTo(f0.getSslContextProviderSupplier()); + channelHandlerCtx = pipeline.context(next); + assertThat(channelHandlerCtx).isNotNull(); + } + + @Test + @SuppressWarnings("unchecked") + public void nofilterChainMatch_defaultSslContext() throws Exception { + final SettableFuture sslSet = SettableFuture.create(); + ChannelHandler next = new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + ProtocolNegotiationEvent e = (ProtocolNegotiationEvent)evt; + sslSet.set(InternalProtocolNegotiationEvent.getAttributes(e) + .get(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER)); + } + }; + when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); + + SslContextProviderSupplier ssl = new SslContextProviderSupplier(createTls(), tlsContextManager); + FilterChainSelector selector = new FilterChainSelector(Collections.EMPTY_LIST, ssl); + FilterChainMatchingHandler filterChainMatchingHandler = + new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); + setupChannel("172.168.1.1", "172.168.1.2", 80, filterChainMatchingHandler); + ChannelHandlerContext channelHandlerCtx = pipeline.context(filterChainMatchingHandler); + assertThat(channelHandlerCtx).isNotNull(); + + pipeline.fireUserEventTriggered(event); + channelHandlerCtx = pipeline.context(filterChainMatchingHandler); + assertThat(channelHandlerCtx).isNull(); + + channel.runPendingTasks(); + assertThat(sslSet.isDone()).isTrue(); + assertThat(sslSet.get()).isEqualTo(ssl); + channelHandlerCtx = pipeline.context(next); + assertThat(channelHandlerCtx).isNotNull(); + } + + @Test + @SuppressWarnings("unchecked") + public void noFilterChainMatch_noDefaultSslContext() { + FilterChainSelector selector = new FilterChainSelector(Collections.EMPTY_LIST, null); + FilterChainMatchingHandler filterChainMatchingHandler = + new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); + setupChannel("172.168.1.1", "172.168.2.2", 90, filterChainMatchingHandler); + channelHandlerCtx = pipeline.context(filterChainMatchingHandler); + assertThat(channelHandlerCtx).isNotNull(); + + pipeline.fireUserEventTriggered(event); + channelHandlerCtx = pipeline.context(filterChainMatchingHandler); + assertThat(channelHandlerCtx).isNotNull(); + try { + channel.checkException(); + fail("exception expected!"); + } catch (Exception e) { + assertThat(e).isInstanceOf(IllegalStateException.class); + assertThat(e).hasMessageThat().contains("No matching filter chain found."); + } + } + + @Test + public void singleFilterChainWithoutAlpn() throws Exception { + EnvoyServerProtoData.FilterChainMatch filterChainMatch = + new EnvoyServerProtoData.FilterChainMatch( + 0, + Arrays.asList(), + Arrays.asList(), + Arrays.asList(), + EnvoyServerProtoData.ConnectionSourceType.ANY, + Arrays.asList(), + Arrays.asList(), + null); + EnvoyServerProtoData.DownstreamTlsContext tlsContext = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); + EnvoyServerProtoData.FilterChain filterChain = new EnvoyServerProtoData.FilterChain( + "filter-chain-foo", filterChainMatch, HTTP_CONNECTION_MANAGER, tlsContext, + tlsContextManager); + + FilterChainSelector selector = new FilterChainSelector(Arrays.asList( + filterChain), null); + FilterChainMatchingHandler filterChainMatchingHandler = + new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); + final SettableFuture sslSet = SettableFuture.create(); + ChannelHandler next = new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + ProtocolNegotiationEvent e = (ProtocolNegotiationEvent)evt; + sslSet.set(InternalProtocolNegotiationEvent.getAttributes(e) + .get(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER)); + } + }; + when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); + setupChannel(LOCAL_IP, REMOTE_IP, 15000, filterChainMatchingHandler); + pipeline.fireUserEventTriggered(event); + channel.runPendingTasks(); + assertThat(sslSet.isDone()).isTrue(); + assertThat(sslSet.get().getTlsContext()).isSameInstanceAs(tlsContext); + } + + @Test + public void singleFilterChainWithAlpn() throws Exception { + EnvoyServerProtoData.FilterChainMatch filterChainMatch = + new EnvoyServerProtoData.FilterChainMatch( + 0, + Arrays.asList(), + Arrays.asList("managed-mtls"), + Arrays.asList(), + EnvoyServerProtoData.ConnectionSourceType.ANY, + Arrays.asList(), + Arrays.asList(), + null); + EnvoyServerProtoData.DownstreamTlsContext tlsContext = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); + EnvoyServerProtoData.FilterChain filterChain = new EnvoyServerProtoData.FilterChain( + "filter-chain-foo", filterChainMatch, HTTP_CONNECTION_MANAGER, tlsContext, + tlsContextManager); + EnvoyServerProtoData.DownstreamTlsContext defaultTlsContext = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); + EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( + "filter-chain-bar", null, HTTP_CONNECTION_MANAGER, defaultTlsContext, + tlsContextManager); + FilterChainSelector selector = new FilterChainSelector(Arrays.asList( + filterChain), defaultFilterChain.getSslContextProviderSupplier()); + FilterChainMatchingHandler filterChainMatchingHandler = + new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); + + final SettableFuture sslSet = SettableFuture.create(); + ChannelHandler next = new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + ProtocolNegotiationEvent e = (ProtocolNegotiationEvent)evt; + sslSet.set(InternalProtocolNegotiationEvent.getAttributes(e) + .get(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER)); + } + }; + when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); + setupChannel(LOCAL_IP, REMOTE_IP, 15000, filterChainMatchingHandler); + pipeline.fireUserEventTriggered(event); + channel.runPendingTasks(); + assertThat(sslSet.isDone()).isTrue(); + assertThat(sslSet.get().getTlsContext()).isSameInstanceAs(defaultTlsContext); + } + + @Test + public void defaultFilterChain() throws Exception { + EnvoyServerProtoData.DownstreamTlsContext tlsContext = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); + EnvoyServerProtoData.FilterChain filterChain = new EnvoyServerProtoData.FilterChain( + "filter-chain-foo", null, HTTP_CONNECTION_MANAGER, tlsContext, tlsContextManager); + FilterChainSelector selector = new FilterChainSelector( + Arrays.asList(), + filterChain.getSslContextProviderSupplier()); + FilterChainMatchingHandler filterChainMatchingHandler = + new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); + + final SettableFuture sslSet = SettableFuture.create(); + ChannelHandler next = new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + ProtocolNegotiationEvent e = (ProtocolNegotiationEvent)evt; + sslSet.set(InternalProtocolNegotiationEvent.getAttributes(e) + .get(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER)); + } + }; + when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); + setupChannel(LOCAL_IP, REMOTE_IP, 15000, filterChainMatchingHandler); + pipeline.fireUserEventTriggered(event); + channel.runPendingTasks(); + assertThat(sslSet.isDone()).isTrue(); + assertThat(sslSet.get().getTlsContext()).isSameInstanceAs(tlsContext); + } + + @Test + public void destPortFails_returnDefaultFilterChain() throws Exception { + EnvoyServerProtoData.DownstreamTlsContext tlsContextWithDestPort = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); + EnvoyServerProtoData.FilterChainMatch filterChainMatchWithDestPort = + new EnvoyServerProtoData.FilterChainMatch( + PORT, + Arrays.asList(), + Arrays.asList("managed-mtls"), + Arrays.asList(), + EnvoyServerProtoData.ConnectionSourceType.ANY, + Arrays.asList(), + Arrays.asList(), + null); + EnvoyServerProtoData.FilterChain filterChainWithDestPort = + new EnvoyServerProtoData.FilterChain( + "filter-chain-foo", filterChainMatchWithDestPort, HTTP_CONNECTION_MANAGER, + tlsContextWithDestPort, tlsContextManager); + EnvoyServerProtoData.DownstreamTlsContext tlsContextForDefaultFilterChain = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); + EnvoyServerProtoData.FilterChain defaultFilterChain = + new EnvoyServerProtoData.FilterChain( + "filter-chain-bar", null, HTTP_CONNECTION_MANAGER, + tlsContextForDefaultFilterChain, tlsContextManager); + + FilterChainSelector selector = new FilterChainSelector(Arrays.asList( + filterChainWithDestPort), + defaultFilterChain.getSslContextProviderSupplier()); + FilterChainMatchingHandler filterChainMatchingHandler = + new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); + + final SettableFuture sslSet = SettableFuture.create(); + ChannelHandler next = new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + ProtocolNegotiationEvent e = (ProtocolNegotiationEvent)evt; + sslSet.set(InternalProtocolNegotiationEvent.getAttributes(e) + .get(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER)); + } + }; + when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); + setupChannel(LOCAL_IP, REMOTE_IP, 15000, filterChainMatchingHandler); + pipeline.fireUserEventTriggered(event); + channel.runPendingTasks(); + assertThat(sslSet.isDone()).isTrue(); + assertThat(sslSet.get().getTlsContext()).isSameInstanceAs(tlsContextForDefaultFilterChain); + } + + @Test + public void destPrefixRangeMatch() throws Exception { + EnvoyServerProtoData.DownstreamTlsContext tlsContextMatch = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); + EnvoyServerProtoData.FilterChainMatch filterChainMatchWithMatch = + new EnvoyServerProtoData.FilterChainMatch( + 0, + Arrays.asList(new EnvoyServerProtoData.CidrRange("10.1.2.0", 24)), + Arrays.asList(), + Arrays.asList(), + EnvoyServerProtoData.ConnectionSourceType.ANY, + Arrays.asList(), + Arrays.asList(), + null); + EnvoyServerProtoData.FilterChain filterChainWithMatch = new EnvoyServerProtoData.FilterChain( + "filter-chain-foo", filterChainMatchWithMatch, HTTP_CONNECTION_MANAGER, + tlsContextMatch, tlsContextManager); + EnvoyServerProtoData.DownstreamTlsContext tlsContextForDefaultFilterChain = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); + EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( + "filter-chain-bar", null, HTTP_CONNECTION_MANAGER, + tlsContextForDefaultFilterChain, tlsContextManager); + + FilterChainSelector selector = new FilterChainSelector(Arrays.asList( + filterChainWithMatch), + defaultFilterChain.getSslContextProviderSupplier()); + FilterChainMatchingHandler filterChainMatchingHandler = + new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); + + final SettableFuture sslSet = SettableFuture.create(); + ChannelHandler next = new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + ProtocolNegotiationEvent e = (ProtocolNegotiationEvent)evt; + sslSet.set(InternalProtocolNegotiationEvent.getAttributes(e) + .get(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER)); + } + }; + when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); + setupChannel(LOCAL_IP, REMOTE_IP, 15000, filterChainMatchingHandler); + pipeline.fireUserEventTriggered(event); + channel.runPendingTasks(); + assertThat(sslSet.isDone()).isTrue(); + assertThat(sslSet.get().getTlsContext()).isSameInstanceAs(tlsContextMatch); + } + + @Test + public void destPrefixRangeMismatch_returnDefaultFilterChain() + throws Exception { + EnvoyServerProtoData.DownstreamTlsContext tlsContextMismatch = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); + // 10.2.2.0/24 doesn't match LOCAL_IP + EnvoyServerProtoData.FilterChainMatch filterChainMatchWithMismatch = + new EnvoyServerProtoData.FilterChainMatch( + 0, + Arrays.asList(new EnvoyServerProtoData.CidrRange("10.2.2.0", 24)), + Arrays.asList(), + Arrays.asList(), + EnvoyServerProtoData.ConnectionSourceType.ANY, + Arrays.asList(), + Arrays.asList(), + null); + EnvoyServerProtoData.FilterChain filterChainWithMismatch = + new EnvoyServerProtoData.FilterChain( + "filter-chain-foo", filterChainMatchWithMismatch, HTTP_CONNECTION_MANAGER, + tlsContextMismatch, tlsContextManager); + EnvoyServerProtoData.DownstreamTlsContext tlsContextForDefaultFilterChain = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); + EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( + "filter-chain-bar", null, HTTP_CONNECTION_MANAGER, + tlsContextForDefaultFilterChain, tlsContextManager); + FilterChainSelector selector = new FilterChainSelector(Arrays.asList( + filterChainWithMismatch), + defaultFilterChain.getSslContextProviderSupplier()); + FilterChainMatchingHandler filterChainMatchingHandler = + new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); + + final SettableFuture sslSet = SettableFuture.create(); + ChannelHandler next = new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + ProtocolNegotiationEvent e = (ProtocolNegotiationEvent)evt; + sslSet.set(InternalProtocolNegotiationEvent.getAttributes(e) + .get(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER)); + } + }; + when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); + setupChannel(LOCAL_IP, REMOTE_IP, 15000, filterChainMatchingHandler); + pipeline.fireUserEventTriggered(event); + channel.runPendingTasks(); + assertThat(sslSet.isDone()).isTrue(); + assertThat(sslSet.get().getTlsContext()).isSameInstanceAs(tlsContextForDefaultFilterChain); + } + + @Test + public void dest0LengthPrefixRange() + throws Exception { + EnvoyServerProtoData.DownstreamTlsContext tlsContext0Length = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); + // 10.2.2.0/24 doesn't match LOCAL_IP + EnvoyServerProtoData.FilterChainMatch filterChainMatch0Length = + new EnvoyServerProtoData.FilterChainMatch( + 0, + Arrays.asList(new EnvoyServerProtoData.CidrRange("10.2.2.0", 0)), + Arrays.asList(), + Arrays.asList(), + EnvoyServerProtoData.ConnectionSourceType.ANY, + Arrays.asList(), + Arrays.asList(), + null); + EnvoyServerProtoData.FilterChain filterChain0Length = new EnvoyServerProtoData.FilterChain( + "filter-chain-foo", filterChainMatch0Length, HTTP_CONNECTION_MANAGER, + tlsContext0Length, tlsContextManager); + EnvoyServerProtoData.DownstreamTlsContext tlsContextForDefaultFilterChain = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); + EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( + "filter-chain-bar", null, HTTP_CONNECTION_MANAGER, + tlsContextForDefaultFilterChain, tlsContextManager); + + FilterChainSelector selector = new FilterChainSelector(Arrays.asList( + filterChain0Length), + defaultFilterChain.getSslContextProviderSupplier()); + FilterChainMatchingHandler filterChainMatchingHandler = + new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); + + final SettableFuture sslSet = SettableFuture.create(); + ChannelHandler next = new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + ProtocolNegotiationEvent e = (ProtocolNegotiationEvent)evt; + sslSet.set(InternalProtocolNegotiationEvent.getAttributes(e) + .get(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER)); + } + }; + when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); + setupChannel(LOCAL_IP, REMOTE_IP, 15000, filterChainMatchingHandler); + pipeline.fireUserEventTriggered(event); + channel.runPendingTasks(); + assertThat(sslSet.isDone()).isTrue(); + assertThat(sslSet.get().getTlsContext()).isSameInstanceAs(tlsContext0Length); + } + + @Test + public void destPrefixRange_moreSpecificWins() + throws Exception { + EnvoyServerProtoData.DownstreamTlsContext tlsContextLessSpecific = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); + EnvoyServerProtoData.FilterChainMatch filterChainMatchLessSpecific = + new EnvoyServerProtoData.FilterChainMatch( + 0, + Arrays.asList(new EnvoyServerProtoData.CidrRange("10.1.2.0", 24)), + Arrays.asList(), + Arrays.asList(), + EnvoyServerProtoData.ConnectionSourceType.ANY, + Arrays.asList(), + Arrays.asList(), + null); + EnvoyServerProtoData.FilterChain filterChainLessSpecific = + new EnvoyServerProtoData.FilterChain( + "filter-chain-foo", filterChainMatchLessSpecific, HTTP_CONNECTION_MANAGER, + tlsContextLessSpecific, tlsContextManager); + + EnvoyServerProtoData.DownstreamTlsContext tlsContextMoreSpecific = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); + EnvoyServerProtoData.FilterChainMatch filterChainMatchMoreSpecific = + new EnvoyServerProtoData.FilterChainMatch( + 0, + Arrays.asList(new EnvoyServerProtoData.CidrRange("10.1.2.2", 31)), + Arrays.asList(), + Arrays.asList(), + EnvoyServerProtoData.ConnectionSourceType.ANY, + Arrays.asList(), + Arrays.asList(), + null); + EnvoyServerProtoData.FilterChain filterChainMoreSpecific = + new EnvoyServerProtoData.FilterChain( + "filter-chain-bar", filterChainMatchMoreSpecific, HTTP_CONNECTION_MANAGER, + tlsContextMoreSpecific, + tlsContextManager); + EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( + "filter-chain-baz", null, HTTP_CONNECTION_MANAGER, null, tlsContextManager); + FilterChainSelector selector = new FilterChainSelector(Arrays.asList( + filterChainLessSpecific, filterChainMoreSpecific), + defaultFilterChain.getSslContextProviderSupplier()); + FilterChainMatchingHandler filterChainMatchingHandler = + new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); + + final SettableFuture sslSet = SettableFuture.create(); + ChannelHandler next = new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + ProtocolNegotiationEvent e = (ProtocolNegotiationEvent)evt; + sslSet.set(InternalProtocolNegotiationEvent.getAttributes(e) + .get(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER)); + } + }; + when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); + setupChannel(LOCAL_IP, REMOTE_IP, 15000, filterChainMatchingHandler); + pipeline.fireUserEventTriggered(event); + channel.runPendingTasks(); + assertThat(sslSet.isDone()).isTrue(); + assertThat(sslSet.get().getTlsContext()).isSameInstanceAs(tlsContextMoreSpecific); + } + + @Test + public void destPrefixRange_emptyListLessSpecific() + throws Exception { + EnvoyServerProtoData.DownstreamTlsContext tlsContextLessSpecific = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); + EnvoyServerProtoData.FilterChainMatch filterChainMatchLessSpecific = + new EnvoyServerProtoData.FilterChainMatch( + 0, + Arrays.asList(), + Arrays.asList(), + Arrays.asList(), + EnvoyServerProtoData.ConnectionSourceType.ANY, + Arrays.asList(), + Arrays.asList(), + null); + EnvoyServerProtoData.FilterChain filterChainLessSpecific = + new EnvoyServerProtoData.FilterChain( + "filter-chain-foo", filterChainMatchLessSpecific, HTTP_CONNECTION_MANAGER, + tlsContextLessSpecific, tlsContextManager); + + EnvoyServerProtoData.DownstreamTlsContext tlsContextMoreSpecific = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); + EnvoyServerProtoData.FilterChainMatch filterChainMatchMoreSpecific = + new EnvoyServerProtoData.FilterChainMatch( + 0, + Arrays.asList(new EnvoyServerProtoData.CidrRange("8.0.0.0", 5)), + Arrays.asList(), + Arrays.asList(), + EnvoyServerProtoData.ConnectionSourceType.ANY, + Arrays.asList(), + Arrays.asList(), + null); + EnvoyServerProtoData.FilterChain filterChainMoreSpecific = + new EnvoyServerProtoData.FilterChain( + "filter-chain-bar", filterChainMatchMoreSpecific, HTTP_CONNECTION_MANAGER, + tlsContextMoreSpecific, + tlsContextManager); + EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( + "filter-chain-baz", null, HTTP_CONNECTION_MANAGER, null, tlsContextManager); + FilterChainSelector selector = new FilterChainSelector(Arrays.asList( + filterChainLessSpecific, filterChainMoreSpecific), + defaultFilterChain.getSslContextProviderSupplier()); + FilterChainMatchingHandler filterChainMatchingHandler = + new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); + + final SettableFuture sslSet = SettableFuture.create(); + ChannelHandler next = new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + ProtocolNegotiationEvent e = (ProtocolNegotiationEvent)evt; + sslSet.set(InternalProtocolNegotiationEvent.getAttributes(e) + .get(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER)); + } + }; + when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); + setupChannel(LOCAL_IP, REMOTE_IP, 15000, filterChainMatchingHandler); + pipeline.fireUserEventTriggered(event); + channel.runPendingTasks(); + assertThat(sslSet.isDone()).isTrue(); + assertThat(sslSet.get().getTlsContext()).isSameInstanceAs(tlsContextMoreSpecific); + } + + @Test + public void destPrefixRangeIpv6_moreSpecificWins() + throws Exception { + EnvoyServerProtoData.DownstreamTlsContext tlsContextLessSpecific = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); + EnvoyServerProtoData.FilterChainMatch filterChainMatchLessSpecific = + new EnvoyServerProtoData.FilterChainMatch( + 0, + Arrays.asList(new EnvoyServerProtoData.CidrRange("FE80:0:0:0:0:0:0:0", 60)), + Arrays.asList(), + Arrays.asList(), + EnvoyServerProtoData.ConnectionSourceType.ANY, + Arrays.asList(), + Arrays.asList(), + null); + EnvoyServerProtoData.FilterChain filterChainLessSpecific = + new EnvoyServerProtoData.FilterChain( + "filter-chain-foo", filterChainMatchLessSpecific, HTTP_CONNECTION_MANAGER, + tlsContextLessSpecific, tlsContextManager); + + EnvoyServerProtoData.DownstreamTlsContext tlsContextMoreSpecific = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); + EnvoyServerProtoData.FilterChainMatch filterChainMatchMoreSpecific = + new EnvoyServerProtoData.FilterChainMatch( + 0, + Arrays.asList(new EnvoyServerProtoData.CidrRange("FE80:0000:0000:0000:0202:0:0:0", 80)), + Arrays.asList(), + Arrays.asList(), + EnvoyServerProtoData.ConnectionSourceType.ANY, + Arrays.asList(), + Arrays.asList(), + null); + EnvoyServerProtoData.FilterChain filterChainMoreSpecific = + new EnvoyServerProtoData.FilterChain( + "filter-chain-bar", filterChainMatchMoreSpecific, HTTP_CONNECTION_MANAGER, + tlsContextMoreSpecific, tlsContextManager); + EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( + "filter-chain-baz", null, HTTP_CONNECTION_MANAGER, null, tlsContextManager); + + FilterChainSelector selector = new FilterChainSelector(Arrays.asList( + filterChainLessSpecific, filterChainMoreSpecific), + defaultFilterChain.getSslContextProviderSupplier()); + FilterChainMatchingHandler filterChainMatchingHandler = + new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); + + final SettableFuture sslSet = SettableFuture.create(); + ChannelHandler next = new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + ProtocolNegotiationEvent e = (ProtocolNegotiationEvent)evt; + sslSet.set(InternalProtocolNegotiationEvent.getAttributes(e) + .get(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER)); + } + }; + when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); + + setupChannel("FE80:0000:0000:0000:0202:B3FF:FE1E:8329", "2001:DB8::8:800:200C:417A", + 15000, filterChainMatchingHandler); + pipeline.fireUserEventTriggered(event); + channel.runPendingTasks(); + assertThat(sslSet.isDone()).isTrue(); + assertThat(sslSet.get().getTlsContext()).isSameInstanceAs(tlsContextMoreSpecific); + } + + @Test + public void destPrefixRange_moreSpecificWith2Wins() + throws Exception { + EnvoyServerProtoData.DownstreamTlsContext tlsContextMoreSpecificWith2 = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); + EnvoyServerProtoData.FilterChainMatch filterChainMatchMoreSpecificWith2 = + new EnvoyServerProtoData.FilterChainMatch( + 0, + Arrays.asList( + new EnvoyServerProtoData.CidrRange("10.1.2.0", 24), + new EnvoyServerProtoData.CidrRange(LOCAL_IP, 32)), + Arrays.asList(), + Arrays.asList(), + EnvoyServerProtoData.ConnectionSourceType.ANY, + Arrays.asList(), + Arrays.asList(), + null); + EnvoyServerProtoData.FilterChain filterChainMoreSpecificWith2 = + new EnvoyServerProtoData.FilterChain( + "filter-chain-foo", filterChainMatchMoreSpecificWith2, HTTP_CONNECTION_MANAGER, + tlsContextMoreSpecificWith2, tlsContextManager); + + EnvoyServerProtoData.DownstreamTlsContext tlsContextLessSpecific = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); + EnvoyServerProtoData.FilterChainMatch filterChainMatchLessSpecific = + new EnvoyServerProtoData.FilterChainMatch( + 0, + Arrays.asList(new EnvoyServerProtoData.CidrRange("10.1.2.2", 31)), + Arrays.asList(), + Arrays.asList(), + EnvoyServerProtoData.ConnectionSourceType.ANY, + Arrays.asList(), + Arrays.asList(), + null); + EnvoyServerProtoData.FilterChain filterChainLessSpecific = + new EnvoyServerProtoData.FilterChain( + "filter-chain-bar", filterChainMatchLessSpecific, HTTP_CONNECTION_MANAGER, + tlsContextLessSpecific, tlsContextManager); + EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( + "filter-chain-baz", null, HTTP_CONNECTION_MANAGER, null, tlsContextManager); + + FilterChainSelector selector = new FilterChainSelector(Arrays.asList( + filterChainMoreSpecificWith2, filterChainLessSpecific), + defaultFilterChain.getSslContextProviderSupplier()); + FilterChainMatchingHandler filterChainMatchingHandler = + new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); + + final SettableFuture sslSet = SettableFuture.create(); + ChannelHandler next = new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + ProtocolNegotiationEvent e = (ProtocolNegotiationEvent)evt; + sslSet.set(InternalProtocolNegotiationEvent.getAttributes(e) + .get(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER)); + } + }; + when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); + setupChannel(LOCAL_IP, REMOTE_IP, 15000, filterChainMatchingHandler); + pipeline.fireUserEventTriggered(event); + channel.runPendingTasks(); + assertThat(sslSet.isDone()).isTrue(); + assertThat(sslSet.get().getTlsContext()).isSameInstanceAs(tlsContextMoreSpecificWith2); + } + + @Test + public void sourceTypeMismatch_returnDefaultFilterChain() throws Exception { + EnvoyServerProtoData.DownstreamTlsContext tlsContextMismatch = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); + EnvoyServerProtoData.FilterChainMatch filterChainMatchWithMismatch = + new EnvoyServerProtoData.FilterChainMatch( + 0, + Arrays.asList(), + Arrays.asList(), + Arrays.asList(), + EnvoyServerProtoData.ConnectionSourceType.SAME_IP_OR_LOOPBACK, + Arrays.asList(), + Arrays.asList(), + null); + EnvoyServerProtoData.FilterChain filterChainWithMismatch = + new EnvoyServerProtoData.FilterChain( + "filter-chain-foo", filterChainMatchWithMismatch, HTTP_CONNECTION_MANAGER, + tlsContextMismatch, tlsContextManager); + EnvoyServerProtoData.DownstreamTlsContext tlsContextForDefaultFilterChain = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); + EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( + "filter-chain-bar", null, HTTP_CONNECTION_MANAGER,tlsContextForDefaultFilterChain, + tlsContextManager); + FilterChainSelector selector = new FilterChainSelector(Arrays.asList( + filterChainWithMismatch), defaultFilterChain.getSslContextProviderSupplier()); + FilterChainMatchingHandler filterChainMatchingHandler = + new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); + + final SettableFuture sslSet = SettableFuture.create(); + ChannelHandler next = new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + ProtocolNegotiationEvent e = (ProtocolNegotiationEvent)evt; + sslSet.set(InternalProtocolNegotiationEvent.getAttributes(e) + .get(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER)); + } + }; + when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); + setupChannel(LOCAL_IP, REMOTE_IP, 15000, filterChainMatchingHandler); + pipeline.fireUserEventTriggered(event); + channel.runPendingTasks(); + assertThat(sslSet.isDone()).isTrue(); + assertThat(sslSet.get().getTlsContext()).isSameInstanceAs(tlsContextForDefaultFilterChain); + } + + @Test + public void sourceTypeLocal() throws Exception { + final SettableFuture sslSet = SettableFuture.create(); + ChannelHandler next = new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + ProtocolNegotiationEvent e = (ProtocolNegotiationEvent)evt; + sslSet.set(InternalProtocolNegotiationEvent.getAttributes(e) + .get(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER)); + } + }; + when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); + EnvoyServerProtoData.DownstreamTlsContext tlsContextMatch = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); + EnvoyServerProtoData.FilterChainMatch filterChainMatchWithMatch = + new EnvoyServerProtoData.FilterChainMatch( + 0, + Arrays.asList(), + Arrays.asList(), + Arrays.asList(), + EnvoyServerProtoData.ConnectionSourceType.SAME_IP_OR_LOOPBACK, + Arrays.asList(), + Arrays.asList(), + null); + EnvoyServerProtoData.FilterChain filterChainWithMatch = new EnvoyServerProtoData.FilterChain( + "filter-chain-foo", filterChainMatchWithMatch, HTTP_CONNECTION_MANAGER, tlsContextMatch, + tlsContextManager); + EnvoyServerProtoData.DownstreamTlsContext tlsContextForDefaultFilterChain = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); + EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( + "filter-chain-bar", null, HTTP_CONNECTION_MANAGER, tlsContextForDefaultFilterChain, + tlsContextManager); + + FilterChainSelector selector = new FilterChainSelector(Arrays.asList( + filterChainWithMatch), defaultFilterChain.getSslContextProviderSupplier()); + FilterChainMatchingHandler filterChainMatchingHandler = + new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); + setupChannel(LOCAL_IP, LOCAL_IP, 15000, filterChainMatchingHandler); + pipeline.fireUserEventTriggered(event); + channel.runPendingTasks(); + assertThat(sslSet.isDone()).isTrue(); + assertThat(sslSet.get().getTlsContext()).isSameInstanceAs(tlsContextMatch); + } + + @Test + public void sourcePrefixRange_moreSpecificWith2Wins() + throws Exception { + final SettableFuture sslSet = SettableFuture.create(); + ChannelHandler next = new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + ProtocolNegotiationEvent e = (ProtocolNegotiationEvent)evt; + sslSet.set(InternalProtocolNegotiationEvent.getAttributes(e) + .get(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER)); + } + }; + when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); + + EnvoyServerProtoData.DownstreamTlsContext tlsContextMoreSpecificWith2 = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); + EnvoyServerProtoData.FilterChainMatch filterChainMatchMoreSpecificWith2 = + new EnvoyServerProtoData.FilterChainMatch( + 0, + Arrays.asList(), + Arrays.asList(), + Arrays.asList( + new EnvoyServerProtoData.CidrRange("10.4.2.0", 24), + new EnvoyServerProtoData.CidrRange(REMOTE_IP, 32)), + EnvoyServerProtoData.ConnectionSourceType.ANY, + Arrays.asList(), + Arrays.asList(), + null); + EnvoyServerProtoData.FilterChain filterChainMoreSpecificWith2 = + new EnvoyServerProtoData.FilterChain( + "filter-chain-foo", filterChainMatchMoreSpecificWith2, HTTP_CONNECTION_MANAGER, + tlsContextMoreSpecificWith2, tlsContextManager); + + EnvoyServerProtoData.DownstreamTlsContext tlsContextLessSpecific = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); + EnvoyServerProtoData.FilterChainMatch filterChainMatchLessSpecific = + new EnvoyServerProtoData.FilterChainMatch( + 0, + Arrays.asList(), + Arrays.asList(), + Arrays.asList(new EnvoyServerProtoData.CidrRange("10.4.2.2", 31)), + EnvoyServerProtoData.ConnectionSourceType.ANY, + Arrays.asList(), + Arrays.asList(), + null); + EnvoyServerProtoData.FilterChain filterChainLessSpecific = + new EnvoyServerProtoData.FilterChain( + "filter-chain-bar", filterChainMatchLessSpecific, HTTP_CONNECTION_MANAGER, + tlsContextLessSpecific, tlsContextManager); + EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( + "filter-chain-baz", null, HTTP_CONNECTION_MANAGER, null, tlsContextManager); + FilterChainSelector selector = new FilterChainSelector(Arrays.asList( + filterChainMoreSpecificWith2, filterChainLessSpecific), + defaultFilterChain.getSslContextProviderSupplier()); + + FilterChainMatchingHandler filterChainMatchingHandler = + new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); + setupChannel(LOCAL_IP, REMOTE_IP, 15000, filterChainMatchingHandler); + pipeline.fireUserEventTriggered(event); + channel.runPendingTasks(); + assertThat(sslSet.isDone()).isTrue(); + assertThat(sslSet.get().getTlsContext()).isSameInstanceAs(tlsContextMoreSpecificWith2); + } + + @Test + public void sourcePrefixRange_2Matchers_expectException() + throws UnknownHostException { + final SettableFuture sslSet = SettableFuture.create(); + ChannelHandler next = new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + ProtocolNegotiationEvent e = (ProtocolNegotiationEvent)evt; + sslSet.set(InternalProtocolNegotiationEvent.getAttributes(e) + .get(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER)); + } + }; + when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); + + EnvoyServerProtoData.DownstreamTlsContext tlsContext1 = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); + EnvoyServerProtoData.FilterChainMatch filterChainMatch1 = + new EnvoyServerProtoData.FilterChainMatch( + 0, + Arrays.asList(), + Arrays.asList(), + Arrays.asList( + new EnvoyServerProtoData.CidrRange("10.4.2.0", 24), + new EnvoyServerProtoData.CidrRange("192.168.10.2", 32)), + EnvoyServerProtoData.ConnectionSourceType.ANY, + Arrays.asList(), + Arrays.asList(), + null); + EnvoyServerProtoData.FilterChain filterChain1 = new EnvoyServerProtoData.FilterChain( + "filter-chain-foo", filterChainMatch1, HTTP_CONNECTION_MANAGER, tlsContext1, + tlsContextManager); + + EnvoyServerProtoData.DownstreamTlsContext tlsContext2 = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); + EnvoyServerProtoData.FilterChainMatch filterChainMatch2 = + new EnvoyServerProtoData.FilterChainMatch( + 0, + Arrays.asList(), + Arrays.asList(), + Arrays.asList(new EnvoyServerProtoData.CidrRange("10.4.2.0", 24)), + EnvoyServerProtoData.ConnectionSourceType.ANY, + Arrays.asList(), + Arrays.asList(), + null); + EnvoyServerProtoData.FilterChain filterChain2 = new EnvoyServerProtoData.FilterChain( + "filter-chain-bar", filterChainMatch2, HTTP_CONNECTION_MANAGER, tlsContext2, + tlsContextManager); + EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( + "filter-chain-baz", null, HTTP_CONNECTION_MANAGER, null, null); + + FilterChainSelector selector = new FilterChainSelector(Arrays.asList( + filterChain1, filterChain2), + defaultFilterChain.getSslContextProviderSupplier()); + + + FilterChainMatchingHandler filterChainMatchingHandler = + new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); + setupChannel(LOCAL_IP, REMOTE_IP, 15000, filterChainMatchingHandler); + pipeline.fireUserEventTriggered(event); + channel.runPendingTasks(); + try { + channel.checkException(); + fail("expect exception!"); + } catch (IllegalStateException ise) { + assertThat(ise).hasMessageThat().isEqualTo("Found more than one matching filter chains."); + assertThat(sslSet.isDone()).isFalse(); + channelHandlerCtx = pipeline.context(filterChainMatchingHandler); + assertThat(channelHandlerCtx).isNotNull(); + } + } + + @Test + public void sourcePortMatch_exactMatchWinsOverEmptyList() throws Exception { + EnvoyServerProtoData.DownstreamTlsContext tlsContextEmptySourcePorts = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); + EnvoyServerProtoData.FilterChainMatch filterChainMatchEmptySourcePorts = + new EnvoyServerProtoData.FilterChainMatch( + 0, + Arrays.asList(), + Arrays.asList(), + Arrays.asList( + new EnvoyServerProtoData.CidrRange("10.4.2.0", 24), + new EnvoyServerProtoData.CidrRange("10.4.2.2", 31)), + EnvoyServerProtoData.ConnectionSourceType.ANY, + Arrays.asList(), + Arrays.asList(), + null); + EnvoyServerProtoData.FilterChain filterChainEmptySourcePorts = + new EnvoyServerProtoData.FilterChain( + "filter-chain-foo", filterChainMatchEmptySourcePorts, HTTP_CONNECTION_MANAGER, + tlsContextEmptySourcePorts, tlsContextManager); + + EnvoyServerProtoData.DownstreamTlsContext tlsContextSourcePortMatch = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); + EnvoyServerProtoData.FilterChainMatch filterChainMatchSourcePortMatch = + new EnvoyServerProtoData.FilterChainMatch( + 0, + Arrays.asList(), + Arrays.asList(), + Arrays.asList(new EnvoyServerProtoData.CidrRange("10.4.2.2", 31)), + EnvoyServerProtoData.ConnectionSourceType.ANY, + Arrays.asList(7000, 15000), + Arrays.asList(), + null); + EnvoyServerProtoData.FilterChain filterChainSourcePortMatch = + new EnvoyServerProtoData.FilterChain( + "filter-chain-bar", filterChainMatchSourcePortMatch, HTTP_CONNECTION_MANAGER, + tlsContextSourcePortMatch, tlsContextManager); + EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( + "filter-chain-baz", null, HTTP_CONNECTION_MANAGER, null, tlsContextManager); + + FilterChainSelector selector = new FilterChainSelector(Arrays.asList( + filterChainEmptySourcePorts, filterChainSourcePortMatch), + defaultFilterChain.getSslContextProviderSupplier()); + + FilterChainMatchingHandler filterChainMatchingHandler = + new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); + final SettableFuture sslSet = SettableFuture.create(); + ChannelHandler next = new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + ProtocolNegotiationEvent e = (ProtocolNegotiationEvent)evt; + sslSet.set(InternalProtocolNegotiationEvent.getAttributes(e) + .get(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER)); + } + }; + when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); + setupChannel(LOCAL_IP, REMOTE_IP, 15000, filterChainMatchingHandler); + pipeline.fireUserEventTriggered(event); + channel.runPendingTasks(); + assertThat(sslSet.isDone()).isTrue(); + assertThat(sslSet.get().getTlsContext()).isSameInstanceAs(tlsContextSourcePortMatch); + } + + /** + * Create 6 filterChains: - 1st filter chain has dest port & specific prefix range but is + * eliminated due to dest port - 5 advance to next step: 1 is eliminated due to being less + * specific than the remaining 4. - 4 advance to 3rd step: source type external eliminates one + * with local source_type. - 3 advance to 4th step: more specific 2 get picked based on + * source-prefix range. - 5th step: out of 2 one with matching source port gets picked + */ + @Test + public void filterChain_5stepMatch() throws Exception { + EnvoyServerProtoData.DownstreamTlsContext tlsContext1 = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); + EnvoyServerProtoData.DownstreamTlsContext tlsContext2 = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); + EnvoyServerProtoData.DownstreamTlsContext tlsContext3 = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT3", "VA3"); + EnvoyServerProtoData.DownstreamTlsContext tlsContext4 = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT4", "VA4"); + EnvoyServerProtoData.DownstreamTlsContext tlsContext5 = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT5", "VA5"); + EnvoyServerProtoData.DownstreamTlsContext tlsContext6 = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT6", "VA6"); + + // has dest port and specific prefix ranges: gets eliminated in step 1 + EnvoyServerProtoData.FilterChainMatch filterChainMatch1 = + new EnvoyServerProtoData.FilterChainMatch( + PORT, + Arrays.asList(), + Arrays.asList(), + Arrays.asList(new EnvoyServerProtoData.CidrRange(REMOTE_IP, 32)), + EnvoyServerProtoData.ConnectionSourceType.ANY, + Arrays.asList(), + Arrays.asList(), + null); + EnvoyServerProtoData.FilterChain filterChain1 = new EnvoyServerProtoData.FilterChain( + "filter-chain-1", filterChainMatch1, HTTP_CONNECTION_MANAGER, tlsContext1, + tlsContextManager); + + // next 5 use prefix range: 4 with prefixLen of 30 and last one with 29 + + // has single prefix range: and less specific source prefix range: gets eliminated in step 4 + EnvoyServerProtoData.FilterChainMatch filterChainMatch2 = + new EnvoyServerProtoData.FilterChainMatch( + 0, + Arrays.asList(new EnvoyServerProtoData.CidrRange("10.1.2.0", 30)), + Arrays.asList(), + Arrays.asList(new EnvoyServerProtoData.CidrRange("10.4.0.0", 16)), + EnvoyServerProtoData.ConnectionSourceType.ANY, + Arrays.asList(), + Arrays.asList(), + null); + EnvoyServerProtoData.FilterChain filterChain2 = new EnvoyServerProtoData.FilterChain( + "filter-chain-2", filterChainMatch2, HTTP_CONNECTION_MANAGER, tlsContext2, + tlsContextManager); + + // has prefix ranges with one not matching and source type local: gets eliminated in step 3 + EnvoyServerProtoData.FilterChainMatch filterChainMatch3 = + new EnvoyServerProtoData.FilterChainMatch( + 0, + Arrays.asList( + new EnvoyServerProtoData.CidrRange("192.168.2.0", 24), + new EnvoyServerProtoData.CidrRange("10.1.2.0", 30)), + Arrays.asList(), + Arrays.asList(), + EnvoyServerProtoData.ConnectionSourceType.SAME_IP_OR_LOOPBACK, + Arrays.asList(), + Arrays.asList(), + null); + EnvoyServerProtoData.FilterChain filterChain3 = new EnvoyServerProtoData.FilterChain( + "filter-chain-3", filterChainMatch3, HTTP_CONNECTION_MANAGER, tlsContext3, + tlsContextManager); + + // has prefix ranges with both matching and source type external but non matching source port: + // gets eliminated in step 5 + EnvoyServerProtoData.FilterChainMatch filterChainMatch4 = + new EnvoyServerProtoData.FilterChainMatch( + 0, + Arrays.asList( + new EnvoyServerProtoData.CidrRange("10.1.0.0", 16), + new EnvoyServerProtoData.CidrRange("10.1.2.0", 30)), + Arrays.asList(), + Arrays.asList(new EnvoyServerProtoData.CidrRange("10.4.2.0", 24)), + EnvoyServerProtoData.ConnectionSourceType.EXTERNAL, + Arrays.asList(16000, 9000), + Arrays.asList(), + null); + EnvoyServerProtoData.FilterChain filterChain4 = + new EnvoyServerProtoData.FilterChain( + "filter-chain-4", filterChainMatch4, HTTP_CONNECTION_MANAGER, tlsContext4, + tlsContextManager); + + // has prefix ranges with both matching and source type external and matching source port: this + // gets selected + EnvoyServerProtoData.FilterChainMatch filterChainMatch5 = + new EnvoyServerProtoData.FilterChainMatch( + 0, + Arrays.asList( + new EnvoyServerProtoData.CidrRange("10.1.0.0", 16), + new EnvoyServerProtoData.CidrRange("10.1.2.0", 30)), + Arrays.asList(), + Arrays.asList( + new EnvoyServerProtoData.CidrRange("10.4.2.0", 24), + new EnvoyServerProtoData.CidrRange("192.168.2.0", 24)), + EnvoyServerProtoData.ConnectionSourceType.ANY, + Arrays.asList(15000, 8000), + Arrays.asList(), + null); + EnvoyServerProtoData.FilterChain filterChain5 = + new EnvoyServerProtoData.FilterChain( + "filter-chain-5", filterChainMatch5, HTTP_CONNECTION_MANAGER, tlsContext5, + tlsContextManager); + + // has prefix range with prefixLen of 29: gets eliminated in step 2 + EnvoyServerProtoData.FilterChainMatch filterChainMatch6 = + new EnvoyServerProtoData.FilterChainMatch( + 0, + Arrays.asList(new EnvoyServerProtoData.CidrRange("10.1.2.0", 29)), + Arrays.asList(), + Arrays.asList(), + EnvoyServerProtoData.ConnectionSourceType.ANY, + Arrays.asList(), + Arrays.asList(), + null); + EnvoyServerProtoData.FilterChain filterChain6 = + new EnvoyServerProtoData.FilterChain( + "filter-chain-6", filterChainMatch6, HTTP_CONNECTION_MANAGER, tlsContext6, + tlsContextManager); + + EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( + "filter-chain-7", null, HTTP_CONNECTION_MANAGER, null, tlsContextManager); + FilterChainSelector selector = new FilterChainSelector(Arrays.asList( + filterChain1, filterChain2, filterChain3, filterChain4, filterChain5, filterChain6), + defaultFilterChain.getSslContextProviderSupplier()); + FilterChainMatchingHandler filterChainMatchingHandler = + new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); + + final SettableFuture sslSet = SettableFuture.create(); + ChannelHandler next = new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + ProtocolNegotiationEvent e = (ProtocolNegotiationEvent)evt; + sslSet.set(InternalProtocolNegotiationEvent.getAttributes(e) + .get(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER)); + } + }; + when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); + setupChannel(LOCAL_IP, REMOTE_IP, 15000, filterChainMatchingHandler); + pipeline.fireUserEventTriggered(event); + channel.runPendingTasks(); + assertThat(sslSet.isDone()).isTrue(); + assertThat(sslSet.get().getTlsContext()).isSameInstanceAs(tlsContext5); + } + + @Test + public void filterChainMatch_unsupportedMatchers() throws Exception { + EnvoyServerProtoData.DownstreamTlsContext tlsContext1 = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "ROOTCA"); + EnvoyServerProtoData.DownstreamTlsContext tlsContext2 = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "ROOTCA"); + EnvoyServerProtoData.DownstreamTlsContext tlsContext3 = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT3", "ROOTCA"); + + EnvoyServerProtoData.FilterChainMatch filterChainMatch1 = + new EnvoyServerProtoData.FilterChainMatch( + 0 /* destinationPort */, + Collections.singletonList( + new EnvoyServerProtoData.CidrRange("10.1.0.0", 16)) /* prefixRange */, + Arrays.asList("managed-mtls", "h2") /* applicationProtocol */, + Collections.emptyList() /* sourcePrefixRanges */, + EnvoyServerProtoData.ConnectionSourceType.ANY /* sourceType */, + Collections.emptyList() /* sourcePorts */, + Arrays.asList("server1", "server2") /* serverNames */, + "tls" /* transportProtocol */); + + EnvoyServerProtoData.FilterChainMatch filterChainMatch2 = + new EnvoyServerProtoData.FilterChainMatch( + 0 /* destinationPort */, + Collections.singletonList( + new EnvoyServerProtoData.CidrRange("10.0.0.0", 8)) /* prefixRange */, + Collections.emptyList() /* applicationProtocol */, + Collections.emptyList() /* sourcePrefixRanges */, + EnvoyServerProtoData.ConnectionSourceType.ANY /* sourceType */, + Collections.emptyList() /* sourcePorts */, + Collections.emptyList() /* serverNames */, + "" /* transportProtocol */); + + EnvoyServerProtoData.FilterChainMatch defaultFilterChainMatch = + new EnvoyServerProtoData.FilterChainMatch( + 0 /* destinationPort */, + Collections.emptyList() /* prefixRange */, + Collections.emptyList() /* applicationProtocol */, + Collections.emptyList() /* sourcePrefixRanges */, + EnvoyServerProtoData.ConnectionSourceType.ANY /* sourceType */, + Collections.emptyList() /* sourcePorts */, + Collections.emptyList() /* serverNames */, + "" /* transportProtocol */); + + EnvoyServerProtoData.FilterChain filterChain1 = new EnvoyServerProtoData.FilterChain( + "filter-chain-foo", filterChainMatch1, HTTP_CONNECTION_MANAGER, tlsContext1, + mock(TlsContextManager.class)); + EnvoyServerProtoData.FilterChain filterChain2 = new EnvoyServerProtoData.FilterChain( + "filter-chain-bar", filterChainMatch2, HTTP_CONNECTION_MANAGER, tlsContext2, + mock(TlsContextManager.class)); + + EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( + "filter-chain-baz", defaultFilterChainMatch, HTTP_CONNECTION_MANAGER, tlsContext3, + mock(TlsContextManager.class)); + + FilterChainSelector selector = new FilterChainSelector(Arrays.asList( + filterChain1, filterChain2), defaultFilterChain.getSslContextProviderSupplier()); + + FilterChainMatchingHandler filterChainMatchingHandler = + new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); + final SettableFuture sslSet = SettableFuture.create(); + ChannelHandler next = new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + ProtocolNegotiationEvent e = (ProtocolNegotiationEvent)evt; + sslSet.set(InternalProtocolNegotiationEvent.getAttributes(e) + .get(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER)); + } + }; + when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); + setupChannel(LOCAL_IP, REMOTE_IP, 15000, filterChainMatchingHandler); + pipeline.fireUserEventTriggered(event); + channel.runPendingTasks(); + assertThat(sslSet.isDone()).isTrue(); + assertThat(sslSet.get().getTlsContext().getCommonTlsContext() + .getTlsCertificateCertificateProviderInstance() + .getCertificateName()).isEqualTo("CERT3"); + } + + private static HttpConnectionManager createRds(String name) { + return HttpConnectionManager.forRdsName(0L, name, + new ArrayList()); + } + + private FilterChain createFilterChain(String name, HttpConnectionManager hcm) { + return new FilterChain(name, createMatch(), + hcm, createTls(), tlsContextManager); + } + + private FilterChainMatch createMatch() { + return new FilterChainMatch( + 0, + Arrays.asList(), + Arrays.asList(), + Arrays.asList(), + EnvoyServerProtoData.ConnectionSourceType.ANY, + Arrays.asList(), + Arrays.asList(), + null); + } + + private EnvoyServerProtoData.DownstreamTlsContext createTls() { + return DownstreamTlsContext.fromEnvoyProtoDownstreamTlsContext( + io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.DownstreamTlsContext + .getDefaultInstance()); + } + + private void setupChannel(final String localIp, final String remoteIp, final int remotePort, + FilterChainMatchingHandler matchingHandler) { + channel = + new EmbeddedChannel() { + @Override + public SocketAddress localAddress() { + return new InetSocketAddress(localIp, 80); + } + + @Override + public SocketAddress remoteAddress() { + return new InetSocketAddress(remoteIp, remotePort); + } + }; + pipeline = channel.pipeline(); + pipeline.addLast(matchingHandler); + } + + private static final class FakeGrpcHttp2ConnectionHandler extends GrpcHttp2ConnectionHandler { + FakeGrpcHttp2ConnectionHandler( + ChannelPromise channelUnused, + Http2ConnectionDecoder decoder, + Http2ConnectionEncoder encoder, + Http2Settings initialSettings) { + super(channelUnused, decoder, encoder, initialSettings, new NoopChannelLogger()); + } + + static FakeGrpcHttp2ConnectionHandler newHandler() { + DefaultHttp2Connection conn = new DefaultHttp2Connection(/*server=*/ false); + DefaultHttp2ConnectionEncoder encoder = + new DefaultHttp2ConnectionEncoder(conn, new DefaultHttp2FrameWriter()); + DefaultHttp2ConnectionDecoder decoder = + new DefaultHttp2ConnectionDecoder(conn, encoder, new DefaultHttp2FrameReader()); + Http2Settings settings = new Http2Settings(); + return new FakeGrpcHttp2ConnectionHandler( + /*channelUnused=*/ null, decoder, encoder, settings); + } + + @Override + public String getAuthority() { + return "authority"; + } + } +} diff --git a/xds/src/test/java/io/grpc/xds/ServerWrapperForXdsTest.java b/xds/src/test/java/io/grpc/xds/ServerWrapperForXdsTest.java deleted file mode 100644 index c4e888f5439..00000000000 --- a/xds/src/test/java/io/grpc/xds/ServerWrapperForXdsTest.java +++ /dev/null @@ -1,320 +0,0 @@ -/* - * Copyright 2021 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.xds; - -import static com.google.common.truth.Truth.assertThat; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.doReturn; -import static org.mockito.Mockito.doThrow; -import static org.mockito.Mockito.inOrder; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.reset; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import com.google.common.util.concurrent.SettableFuture; -import io.grpc.Server; -import io.grpc.ServerBuilder; -import io.grpc.Status; -import io.grpc.StatusException; -import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil; -import io.grpc.xds.internal.sds.ServerWrapperForXds; -import java.io.IOException; -import java.net.BindException; -import java.net.NoRouteToHostException; -import java.util.List; -import java.util.concurrent.CancellationException; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.Executors; -import java.util.concurrent.Future; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; -import org.junit.Before; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; -import org.mockito.ArgumentCaptor; -import org.mockito.InOrder; -import org.mockito.invocation.InvocationOnMock; -import org.mockito.stubbing.Answer; - -/** - * Unit tests for {@link ServerWrapperForXds}. - */ -@RunWith(JUnit4.class) -public class ServerWrapperForXdsTest { - - private ServerWrapperForXds serverWrapperForXds; - private ServerBuilder mockDelegateBuilder; - private int port; - private XdsClientWrapperForServerSds xdsClientWrapperForServerSds; - private XdsServerBuilder.XdsServingStatusListener mockXdsServingStatusListener; - private XdsClient.LdsResourceWatcher listenerWatcher; - private Server mockServer; - private TlsContextManager tlsContextManager; - - @Before - public void setUp() throws IOException { - port = XdsServerTestHelper.findFreePort(); - mockDelegateBuilder = mock(ServerBuilder.class); - tlsContextManager = mock(TlsContextManager.class); - xdsClientWrapperForServerSds = XdsServerTestHelper - .createXdsClientWrapperForServerSds(port, tlsContextManager); - mockXdsServingStatusListener = mock(XdsServerBuilder.XdsServingStatusListener.class); - listenerWatcher = - XdsServerTestHelper.startAndGetWatcher(xdsClientWrapperForServerSds); - mockServer = mock(Server.class); - when(mockDelegateBuilder.build()).thenReturn(mockServer); - serverWrapperForXds = new ServerWrapperForXds(mockDelegateBuilder, - xdsClientWrapperForServerSds, - mockXdsServingStatusListener, - 100, TimeUnit.MILLISECONDS); - } - - private Future startServerAsync() throws InterruptedException { - final SettableFuture settableFuture = SettableFuture.create(); - Executors.newSingleThreadExecutor().execute(new Runnable() { - @Override - public void run() { - try { - serverWrapperForXds.start(); - settableFuture.set(null); - } catch (Throwable e) { - settableFuture.set(e); - } - } - }); - // wait until xdsClientWrapperForServerSds.serverWatchers populated - for (int i = 0; i < 10; i++) { - synchronized (xdsClientWrapperForServerSds.serverWatchers) { - if (!xdsClientWrapperForServerSds.serverWatchers.isEmpty()) { - break; - } - } - Thread.sleep(100L); - } - return settableFuture; - } - - @Test - public void start() - throws InterruptedException, TimeoutException, ExecutionException, IOException { - Future future = startServerAsync(); - listenerWatcher.onError(Status.ABORTED); - verifyCapturedCodeAndNotServing(Status.Code.ABORTED, ServerWrapperForXds.ServingState.STARTING); - XdsServerTestHelper.generateListenerUpdate( - listenerWatcher, - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"), - tlsContextManager); - Throwable exception = future.get(2, TimeUnit.SECONDS); - assertThat(exception).isNull(); - assertThat(serverWrapperForXds.getCurrentServingState()) - .isEqualTo(ServerWrapperForXds.ServingState.STARTED); - listenerWatcher.onResourceDoesNotExist("name"); - verifyCapturedCodeAndNotServing(Status.Code.NOT_FOUND, - ServerWrapperForXds.ServingState.NOT_SERVING); - listenerWatcher.onError(Status.NOT_FOUND); - verifyCapturedCodeAndNotServing(Status.Code.NOT_FOUND, - ServerWrapperForXds.ServingState.NOT_SERVING); - listenerWatcher.onError(Status.INVALID_ARGUMENT); - verifyCapturedCodeAndNotServing(Status.Code.INVALID_ARGUMENT, - ServerWrapperForXds.ServingState.NOT_SERVING); - listenerWatcher.onError(Status.PERMISSION_DENIED); - verifyCapturedCodeAndNotServing(Status.Code.PERMISSION_DENIED, - ServerWrapperForXds.ServingState.NOT_SERVING); - listenerWatcher.onError(Status.UNIMPLEMENTED); - verifyCapturedCodeAndNotServing(Status.Code.UNIMPLEMENTED, - ServerWrapperForXds.ServingState.NOT_SERVING); - listenerWatcher.onError(Status.UNAUTHENTICATED); - verifyCapturedCodeAndNotServing(Status.Code.UNAUTHENTICATED, - ServerWrapperForXds.ServingState.NOT_SERVING); - listenerWatcher.onError(Status.ABORTED); - verifyCapturedCodeAndNotServing(null, ServerWrapperForXds.ServingState.NOT_SERVING); - Server mockServer1 = mock(Server.class); - Server mockServer2 = mock(Server.class); - Server mockServer3 = mock(Server.class); - final SettableFuture settableFutureForThrow = SettableFuture.create(); - final SettableFuture settableFutureToSignalStart = SettableFuture.create(); - doAnswer(new Answer() { - @Override - public Server answer(InvocationOnMock invocation) throws Throwable { - settableFutureToSignalStart.set(null); - throw settableFutureForThrow.get(); - } - }).when(mockServer1).start(); - doThrow(new BindException()).when(mockServer2).start(); - doReturn(mockServer3).when(mockServer3).start(); - when(mockDelegateBuilder.build()).thenReturn(mockServer1, mockServer2, mockServer3); - new Thread(new Runnable() { - @Override - public void run() { - XdsServerTestHelper.generateListenerUpdate( - listenerWatcher, - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"), - tlsContextManager); - } - }).start(); - assertThat(settableFutureToSignalStart.get()).isNull(); - assertThat(serverWrapperForXds.getCurrentServingState()) - .isEqualTo(ServerWrapperForXds.ServingState.ENTER_SERVING); - settableFutureForThrow.set(new IOException(new BindException())); - Thread.sleep(1000L); - ArgumentCaptor argCaptor = ArgumentCaptor.forClass(null); - InOrder inOrder = inOrder(mockXdsServingStatusListener); - inOrder.verify(mockXdsServingStatusListener, times(2)).onNotServing(argCaptor.capture()); - List throwableList = argCaptor.getAllValues(); - assertThat(throwableList.size()).isEqualTo(2); - Throwable throwable = throwableList.remove(0); - assertThat(throwable).isInstanceOf(IOException.class); - assertThat(throwable.getCause()).isInstanceOf(BindException.class); - throwable = throwableList.remove(0); - assertThat(throwable).isInstanceOf(BindException.class); - inOrder.verify(mockXdsServingStatusListener).onServing(); - assertThat(serverWrapperForXds.getCurrentServingState()) - .isEqualTo(ServerWrapperForXds.ServingState.STARTED); - serverWrapperForXds.shutdown(); - } - - @Test - public void delegateInitialStartError() - throws InterruptedException, TimeoutException, ExecutionException, IOException { - Future future = startServerAsync(); - doThrow(new IOException("test exception")).when(mockServer).start(); - new Thread(new Runnable() { - @Override - public void run() { - XdsServerTestHelper.generateListenerUpdate( - listenerWatcher, - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"), - tlsContextManager); - } - }).start(); - Throwable exception = future.get(2, TimeUnit.SECONDS); - assertThat(exception).isInstanceOf(IOException.class); - assertThat(exception).hasMessageThat().isEqualTo("test exception"); - } - - private void verifyCapturedCodeAndNotServing(Status.Code expected, - ServerWrapperForXds.ServingState servingState) { - ArgumentCaptor argCaptor = ArgumentCaptor.forClass(null); - verify(mockXdsServingStatusListener, times(expected != null ? 1 : 0)) - .onNotServing(argCaptor.capture()); - if (expected != null) { - Throwable throwable = argCaptor.getValue(); - assertThat(throwable).isInstanceOf(StatusException.class); - Status captured = ((StatusException) throwable).getStatus(); - assertThat(captured.getCode()).isEqualTo(expected); - } - assertThat(serverWrapperForXds.getCurrentServingState()).isEqualTo(servingState); - reset(mockXdsServingStatusListener); - } - - @Test - public void start_internalError() - throws InterruptedException, TimeoutException, ExecutionException { - Future future = startServerAsync(); - listenerWatcher.onError(Status.INTERNAL); - Throwable exception = future.get(2, TimeUnit.SECONDS); - assertThat(exception).isInstanceOf(IOException.class); - Throwable cause = exception.getCause(); - assertThat(cause).isInstanceOf(StatusException.class); - assertThat(((StatusException) cause).getStatus().getCode()).isEqualTo(Status.Code.INTERNAL); - assertThat(serverWrapperForXds.getCurrentServingState()) - .isEqualTo(ServerWrapperForXds.ServingState.SHUTDOWN); - } - - @Test - public void delegateStartError_shutdown() - throws InterruptedException, TimeoutException, ExecutionException, IOException { - Future future = startServerAsync(); - XdsServerTestHelper.generateListenerUpdate( - listenerWatcher, - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"), - tlsContextManager); - Throwable exception = future.get(2, TimeUnit.SECONDS); - assertThat(exception).isNull(); - assertThat(serverWrapperForXds.getCurrentServingState()) - .isEqualTo(ServerWrapperForXds.ServingState.STARTED); - listenerWatcher.onResourceDoesNotExist("name"); - verifyCapturedCodeAndNotServing(Status.Code.NOT_FOUND, - ServerWrapperForXds.ServingState.NOT_SERVING); - Server mockServer = mock(Server.class); - doThrow(new IOException(new NoRouteToHostException())).when(mockServer).start(); - when(mockDelegateBuilder.build()).thenReturn(mockServer); - XdsServerTestHelper.generateListenerUpdate( - listenerWatcher, - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT3", "VA3"), - tlsContextManager); - Thread.sleep(100L); - assertThat(serverWrapperForXds.getCurrentServingState()) - .isEqualTo(ServerWrapperForXds.ServingState.SHUTDOWN); - } - - @Test - public void shutdownDuringRestart() - throws InterruptedException, TimeoutException, ExecutionException, IOException { - Future future = startServerAsync(); - XdsServerTestHelper.generateListenerUpdate( - listenerWatcher, - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"), - tlsContextManager); - Throwable exception = future.get(2, TimeUnit.SECONDS); - assertThat(exception).isNull(); - assertThat(serverWrapperForXds.getCurrentServingState()) - .isEqualTo(ServerWrapperForXds.ServingState.STARTED); - listenerWatcher.onResourceDoesNotExist("name"); - verifyCapturedCodeAndNotServing(Status.Code.NOT_FOUND, - ServerWrapperForXds.ServingState.NOT_SERVING); - Server mockServer = mock(Server.class); - final SettableFuture settableFutureForStart = SettableFuture.create(); - final SettableFuture settableFutureToSignalStart = SettableFuture.create(); - final SettableFuture settableFutureForInterrupt = SettableFuture.create(); - doAnswer(new Answer() { - @Override - public Server answer(InvocationOnMock invocation) - throws InterruptedException, ExecutionException { - settableFutureToSignalStart.set(null); - try { - settableFutureForStart.get(); - } catch (InterruptedException | CancellationException e) { - settableFutureForInterrupt.set(e); - throw e; - } - return null; // never reach here - } - }).when(mockServer).start(); - when(mockDelegateBuilder.build()).thenReturn(mockServer); - new Thread(new Runnable() { - @Override - public void run() { - XdsServerTestHelper.generateListenerUpdate( - listenerWatcher, - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"), - tlsContextManager); - } - }).start(); - assertThat(settableFutureToSignalStart.get()).isNull(); - assertThat(serverWrapperForXds.getCurrentServingState()) - .isEqualTo(ServerWrapperForXds.ServingState.ENTER_SERVING); - serverWrapperForXds.shutdown(); - assertThat(serverWrapperForXds.getCurrentServingState()) - .isEqualTo(ServerWrapperForXds.ServingState.SHUTDOWN); - Throwable interruptedException = settableFutureForInterrupt.get(1L, TimeUnit.SECONDS); - assertThat(interruptedException).isInstanceOf(InterruptedException.class); - } -} diff --git a/xds/src/test/java/io/grpc/xds/XdsClientWrapperForServerSdsTestMisc.java b/xds/src/test/java/io/grpc/xds/XdsClientWrapperForServerSdsTestMisc.java index a8566a682bb..0f92687f443 100644 --- a/xds/src/test/java/io/grpc/xds/XdsClientWrapperForServerSdsTestMisc.java +++ b/xds/src/test/java/io/grpc/xds/XdsClientWrapperForServerSdsTestMisc.java @@ -17,6 +17,8 @@ package io.grpc.xds; import static com.google.common.truth.Truth.assertThat; +import static io.grpc.xds.FilterChainMatchingProtocolNegotiators.FilterChainMatchingHandler.FilterChainSelector.NO_FILTER_CHAIN; +import static io.grpc.xds.internal.sds.SdsProtocolNegotiators.ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER; import static org.junit.Assert.fail; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.any; @@ -27,50 +29,84 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import com.google.common.util.concurrent.SettableFuture; +import io.grpc.Server; +import io.grpc.ServerBuilder; import io.grpc.Status; -import io.grpc.StatusException; import io.grpc.inprocess.InProcessSocketAddress; +import io.grpc.internal.TestUtils.NoopChannelLogger; +import io.grpc.netty.GrpcHttp2ConnectionHandler; +import io.grpc.netty.InternalProtocolNegotiationEvent; +import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator; +import io.grpc.netty.ProtocolNegotiationEvent; import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; +import io.grpc.xds.FilterChainMatchingProtocolNegotiators.FilterChainMatchingHandler; +import io.grpc.xds.FilterChainMatchingProtocolNegotiators.FilterChainMatchingHandler.FilterChainSelector; import io.grpc.xds.XdsClient.LdsUpdate; +import io.grpc.xds.XdsServerBuilder.XdsServingStatusListener; +import io.grpc.xds.XdsServerTestHelper.FakeXdsClient; +import io.grpc.xds.XdsServerTestHelper.FakeXdsClientPoolFactory; import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil; import io.grpc.xds.internal.sds.SslContextProvider; import io.grpc.xds.internal.sds.SslContextProviderSupplier; -import io.netty.channel.Channel; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.ChannelPromise; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.http2.DefaultHttp2Connection; +import io.netty.handler.codec.http2.DefaultHttp2ConnectionDecoder; +import io.netty.handler.codec.http2.DefaultHttp2ConnectionEncoder; +import io.netty.handler.codec.http2.DefaultHttp2FrameReader; +import io.netty.handler.codec.http2.DefaultHttp2FrameWriter; +import io.netty.handler.codec.http2.Http2ConnectionDecoder; +import io.netty.handler.codec.http2.Http2ConnectionEncoder; +import io.netty.handler.codec.http2.Http2Settings; import java.io.IOException; import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.SocketAddress; -import java.net.UnknownHostException; import java.util.Arrays; import java.util.Collections; -import org.junit.After; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; -import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; -/** Tests for {@link XdsClientWrapperForServerSds}. */ +/** Migration test XdsServerWrapper from previous XdsClientWrapperForServerSds. */ @RunWith(JUnit4.class) public class XdsClientWrapperForServerSdsTestMisc { private static final int PORT = 7000; - @Mock private Channel channel; + private EmbeddedChannel channel; + private ChannelPipeline pipeline; @Mock private TlsContextManager tlsContextManager; - @Mock private XdsClientWrapperForServerSds.ServerWatcher mockServerWatcher; - - private XdsClientWrapperForServerSds xdsClientWrapperForServerSds; - private XdsClient.LdsResourceWatcher registeredWatcher; private InetSocketAddress localAddress; private DownstreamTlsContext tlsContext1; private DownstreamTlsContext tlsContext2; private DownstreamTlsContext tlsContext3; + @Mock + private ServerBuilder mockBuilder; + @Mock + Server mockServer; + @Mock + private XdsServingStatusListener listener; + private FakeXdsClient xdsClient = new FakeXdsClient(); + private AtomicReference selectorRef = new AtomicReference<>(); + private XdsServerWrapper xdsServerWrapper; + + @Before - public void setUp() throws IOException { + public void setUp() { MockitoAnnotations.initMocks(this); tlsContext1 = CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); @@ -78,54 +114,51 @@ public void setUp() throws IOException { CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); tlsContext3 = CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT3", "VA3"); - xdsClientWrapperForServerSds = XdsServerTestHelper - .createXdsClientWrapperForServerSds(PORT, tlsContextManager); - } - - @After - public void tearDown() { - xdsClientWrapperForServerSds.shutdown(); - } - - @Test - public void nonInetSocketAddress_expectNull() throws UnknownHostException { - registeredWatcher = - XdsServerTestHelper.startAndGetWatcher(xdsClientWrapperForServerSds); - assertThat( - sendListenerUpdate(new InProcessSocketAddress("test1"), null, null, tlsContextManager)) - .isNull(); + when(mockBuilder.build()).thenReturn(mockServer); + when(mockServer.isShutdown()).thenReturn(false); + xdsServerWrapper = new XdsServerWrapper("0.0.0.0:" + PORT, mockBuilder, listener, + selectorRef, new FakeXdsClientPoolFactory(xdsClient)); } @Test - public void nonMatchingPort_expectException() throws UnknownHostException { - registeredWatcher = - XdsServerTestHelper.startAndGetWatcher(xdsClientWrapperForServerSds); - try { - InetAddress ipLocalAddress = InetAddress.getByName("10.1.2.3"); - InetSocketAddress localAddress = new InetSocketAddress(ipLocalAddress, PORT + 1); - sendListenerUpdate(localAddress, null, null, tlsContextManager); - fail("exception expected"); - } catch (IllegalStateException expected) { - assertThat(expected) - .hasMessageThat() - .isEqualTo("Channel localAddress port does not match requested listener port"); - } + public void nonInetSocketAddress_expectNull() throws Exception { + sendListenerUpdate(new InProcessSocketAddress("test1"), null, null, tlsContextManager); + assertThat(getSslContextProviderSupplier(selectorRef.get())).isNull(); } @Test - public void emptyFilterChain_expectNull() throws UnknownHostException { - registeredWatcher = - XdsServerTestHelper.startAndGetWatcher(xdsClientWrapperForServerSds); + public void emptyFilterChain_expectNull() throws Exception { InetAddress ipLocalAddress = InetAddress.getByName("10.1.2.3"); - InetSocketAddress localAddress = new InetSocketAddress(ipLocalAddress, PORT); - ArgumentCaptor listenerWatcherCaptor = ArgumentCaptor - .forClass(null); - XdsClient xdsClient = xdsClientWrapperForServerSds.getXdsClient(); - verify(xdsClient) - .watchLdsResource(eq("grpc/server?udpa.resource.listening_address=0.0.0.0:" + PORT), - listenerWatcherCaptor.capture()); - XdsClient.LdsResourceWatcher registeredWatcher = listenerWatcherCaptor.getValue(); - when(channel.localAddress()).thenReturn(localAddress); + final InetSocketAddress localAddress = new InetSocketAddress(ipLocalAddress, PORT); + InetAddress ipRemoteAddress = InetAddress.getByName("10.4.5.6"); + final InetSocketAddress remoteAddress = new InetSocketAddress(ipRemoteAddress, 1234); + channel = new EmbeddedChannel() { + @Override + public SocketAddress localAddress() { + return localAddress; + } + + @Override + public SocketAddress remoteAddress() { + return remoteAddress; + } + }; + pipeline = channel.pipeline(); + + final SettableFuture start = SettableFuture.create(); + Executors.newSingleThreadExecutor().execute(new Runnable() { + @Override + public void run() { + try { + start.set(xdsServerWrapper.start()); + } catch (Exception ex) { + start.setException(ex); + } + } + }); + String ldsWatched = xdsClient.ldsResource.get(5, TimeUnit.SECONDS); + assertThat(ldsWatched).isEqualTo("grpc/server?udpa.resource.listening_address=0.0.0.0:" + PORT); + EnvoyServerProtoData.Listener listener = new EnvoyServerProtoData.Listener( "listener1", @@ -133,50 +166,111 @@ public void emptyFilterChain_expectNull() throws UnknownHostException { Collections.emptyList(), null); LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(listener); - registeredWatcher.onChanged(listenerUpdate); - DownstreamTlsContext tlsContext = getDownstreamTlsContext(); - assertThat(tlsContext).isNull(); - } - - @Test - public void registerServerWatcher_afterListenerUpdate() throws UnknownHostException { - registerWatcherAndCreateListenerUpdate(tlsContext1); - verify(mockServerWatcher).onListenerUpdate(); + xdsClient.ldsWatcher.onChanged(listenerUpdate); + start.get(5, TimeUnit.SECONDS); + FilterChainSelector selector = selectorRef.get(); + assertThat(getSslContextProviderSupplier(selector)).isNull(); } @Test - public void registerServerWatcher_notifyNotFound() throws UnknownHostException { - commonErrorCheck(true, Status.NOT_FOUND, true); + public void registerServerWatcher_notifyNotFound() throws Exception { + final SettableFuture start = SettableFuture.create(); + Executors.newSingleThreadExecutor().execute(new Runnable() { + @Override + public void run() { + try { + start.set(xdsServerWrapper.start()); + } catch (Exception ex) { + start.setException(ex); + } + } + }); + String ldsWatched = xdsClient.ldsResource.get(5, TimeUnit.SECONDS); + xdsClient.ldsWatcher.onResourceDoesNotExist(ldsWatched); + try { + start.get(5, TimeUnit.SECONDS); + fail("Start should throw exception"); + } catch (ExecutionException ex) { + assertThat(ex.getCause()).isInstanceOf(IOException.class); + } + assertThat(selectorRef.get()).isSameInstanceAs(NO_FILTER_CHAIN); } @Test - public void registerServerWatcher_notifyInternalError() throws UnknownHostException { - commonErrorCheck(false, Status.INTERNAL, false); + public void registerServerWatcher_notifyInternalError() throws Exception { + final SettableFuture start = SettableFuture.create(); + Executors.newSingleThreadExecutor().execute(new Runnable() { + @Override + public void run() { + try { + start.set(xdsServerWrapper.start()); + } catch (Exception ex) { + start.setException(ex); + } + } + }); + xdsClient.ldsResource.get(5, TimeUnit.SECONDS); + xdsClient.ldsWatcher.onError(Status.INTERNAL); + try { + start.get(5, TimeUnit.SECONDS); + fail("Start should throw exception"); + } catch (ExecutionException ex) { + assertThat(ex.getCause()).isInstanceOf(IOException.class); + } + assertThat(selectorRef.get()).isSameInstanceAs(NO_FILTER_CHAIN); } @Test - public void registerServerWatcher_notifyPermDeniedError() throws UnknownHostException { - commonErrorCheck(false, Status.PERMISSION_DENIED, true); + public void registerServerWatcher_notifyPermDeniedError() throws Exception { + final SettableFuture start = SettableFuture.create(); + Executors.newSingleThreadExecutor().execute(new Runnable() { + @Override + public void run() { + try { + start.set(xdsServerWrapper.start()); + } catch (Exception ex) { + start.setException(ex); + } + } + }); + xdsClient.ldsResource.get(5, TimeUnit.SECONDS); + xdsClient.ldsWatcher.onError(Status.PERMISSION_DENIED); + try { + start.get(5, TimeUnit.SECONDS); + fail("Start should throw exception"); + } catch (ExecutionException ex) { + assertThat(ex.getCause()).isInstanceOf(IOException.class); + } + assertThat(selectorRef.get()).isSameInstanceAs(NO_FILTER_CHAIN); } @Test - public void releaseOldSupplierOnChanged_noCloseDueToLazyLoading() throws UnknownHostException { - registerWatcherAndCreateListenerUpdate(tlsContext1); - XdsServerTestHelper.generateListenerUpdate(registeredWatcher, tlsContext2, tlsContextManager); + public void releaseOldSupplierOnChanged_noCloseDueToLazyLoading() throws Exception { + InetAddress ipLocalAddress = InetAddress.getByName("10.1.2.3"); + localAddress = new InetSocketAddress(ipLocalAddress, PORT); + sendListenerUpdate(localAddress, tlsContext2, null, + tlsContextManager); verify(tlsContextManager, never()) .findOrCreateServerSslContextProvider(any(DownstreamTlsContext.class)); } @Test - public void releaseOldSupplierOnChangedOnShutdown_verifyClose() throws UnknownHostException { + public void releaseOldSupplierOnChangedOnShutdown_verifyClose() throws Exception { SslContextProvider sslContextProvider1 = mock(SslContextProvider.class); when(tlsContextManager.findOrCreateServerSslContextProvider(eq(tlsContext1))) .thenReturn(sslContextProvider1); - registerWatcherAndCreateListenerUpdate(tlsContext1); - callUpdateSslContext(channel); + InetAddress ipLocalAddress = InetAddress.getByName("10.1.2.3"); + localAddress = new InetSocketAddress(ipLocalAddress, PORT); + sendListenerUpdate(localAddress, tlsContext1, null, + tlsContextManager); + SslContextProviderSupplier returnedSupplier = getSslContextProviderSupplier(selectorRef.get()); + assertThat(returnedSupplier.getTlsContext()).isSameInstanceAs(tlsContext1); + callUpdateSslContext(returnedSupplier); XdsServerTestHelper - .generateListenerUpdate(registeredWatcher, Arrays.asList(1234), tlsContext2, + .generateListenerUpdate(xdsClient, Arrays.asList(1234), tlsContext2, tlsContext3, tlsContextManager); + returnedSupplier = getSslContextProviderSupplier(selectorRef.get()); + assertThat(returnedSupplier.getTlsContext()).isSameInstanceAs(tlsContext2); verify(tlsContextManager, times(1)).releaseServerSslContextProvider(eq(sslContextProvider1)); reset(tlsContextManager); SslContextProvider sslContextProvider2 = mock(SslContextProvider.class); @@ -185,129 +279,173 @@ public void releaseOldSupplierOnChangedOnShutdown_verifyClose() throws UnknownHo SslContextProvider sslContextProvider3 = mock(SslContextProvider.class); when(tlsContextManager.findOrCreateServerSslContextProvider(eq(tlsContext3))) .thenReturn(sslContextProvider3); - callUpdateSslContext(channel); + callUpdateSslContext(returnedSupplier); InetAddress ipRemoteAddress = InetAddress.getByName("10.4.5.6"); - InetSocketAddress remoteAddress = new InetSocketAddress(ipRemoteAddress, 1111); - when(channel.remoteAddress()).thenReturn(remoteAddress); - callUpdateSslContext(channel); - XdsClient mockXdsClient = xdsClientWrapperForServerSds.getXdsClient(); - xdsClientWrapperForServerSds.shutdown(); - verify(mockXdsClient, times(1)) - .cancelLdsResourceWatch(eq("grpc/server?udpa.resource.listening_address=0.0.0.0:" + PORT), - eq(registeredWatcher)); + final InetSocketAddress remoteAddress = new InetSocketAddress(ipRemoteAddress, 1111); + channel = new EmbeddedChannel() { + @Override + public SocketAddress localAddress() { + return localAddress; + } + + @Override + public SocketAddress remoteAddress() { + return remoteAddress; + } + }; + pipeline = channel.pipeline(); + returnedSupplier = getSslContextProviderSupplier(selectorRef.get()); + assertThat(returnedSupplier.getTlsContext()).isSameInstanceAs(tlsContext3); + callUpdateSslContext(returnedSupplier); + xdsServerWrapper.shutdown(); + assertThat(xdsClient.ldsResource).isNull(); verify(tlsContextManager, never()).releaseServerSslContextProvider(eq(sslContextProvider1)); verify(tlsContextManager, times(1)).releaseServerSslContextProvider(eq(sslContextProvider2)); verify(tlsContextManager, times(1)).releaseServerSslContextProvider(eq(sslContextProvider3)); } @Test - public void releaseOldSupplierOnNotFound_verifyClose() throws UnknownHostException { + public void releaseOldSupplierOnNotFound_verifyClose() throws Exception { SslContextProvider sslContextProvider1 = mock(SslContextProvider.class); when(tlsContextManager.findOrCreateServerSslContextProvider(eq(tlsContext1))) .thenReturn(sslContextProvider1); - registerWatcherAndCreateListenerUpdate(tlsContext1); - callUpdateSslContext(channel); - registeredWatcher.onResourceDoesNotExist("not-found Error"); + InetAddress ipLocalAddress = InetAddress.getByName("10.1.2.3"); + localAddress = new InetSocketAddress(ipLocalAddress, PORT); + sendListenerUpdate(localAddress, tlsContext1, null, + tlsContextManager); + SslContextProviderSupplier returnedSupplier = + getSslContextProviderSupplier(selectorRef.get()); + assertThat(returnedSupplier.getTlsContext()).isSameInstanceAs(tlsContext1); + callUpdateSslContext(returnedSupplier); + xdsClient.ldsWatcher.onResourceDoesNotExist("not-found Error"); verify(tlsContextManager, times(1)).releaseServerSslContextProvider(eq(sslContextProvider1)); } @Test - public void releaseOldSupplierOnPermDeniedError_verifyClose() throws UnknownHostException { + public void releaseOldSupplierOnPermDeniedError_verifyClose() throws Exception { SslContextProvider sslContextProvider1 = mock(SslContextProvider.class); when(tlsContextManager.findOrCreateServerSslContextProvider(eq(tlsContext1))) .thenReturn(sslContextProvider1); - registerWatcherAndCreateListenerUpdate(tlsContext1); - callUpdateSslContext(channel); - registeredWatcher.onError(Status.PERMISSION_DENIED); + InetAddress ipLocalAddress = InetAddress.getByName("10.1.2.3"); + localAddress = new InetSocketAddress(ipLocalAddress, PORT); + sendListenerUpdate(localAddress, tlsContext1, null, + tlsContextManager); + SslContextProviderSupplier returnedSupplier = + getSslContextProviderSupplier(selectorRef.get()); + assertThat(returnedSupplier.getTlsContext()).isSameInstanceAs(tlsContext1); + callUpdateSslContext(returnedSupplier); + xdsClient.ldsWatcher.onError(Status.PERMISSION_DENIED); verify(tlsContextManager, times(1)).releaseServerSslContextProvider(eq(sslContextProvider1)); } @Test - public void releaseOldSupplierOnInternalError_noClose() throws UnknownHostException { + public void releaseOldSupplierOnTemporaryError_noClose() throws Exception { SslContextProvider sslContextProvider1 = mock(SslContextProvider.class); when(tlsContextManager.findOrCreateServerSslContextProvider(eq(tlsContext1))) .thenReturn(sslContextProvider1); - registerWatcherAndCreateListenerUpdate(tlsContext1); - callUpdateSslContext(channel); - registeredWatcher.onError(Status.INTERNAL); + InetAddress ipLocalAddress = InetAddress.getByName("10.1.2.3"); + localAddress = new InetSocketAddress(ipLocalAddress, PORT); + sendListenerUpdate(localAddress, tlsContext1, null, + tlsContextManager); + SslContextProviderSupplier returnedSupplier = + getSslContextProviderSupplier(selectorRef.get()); + assertThat(returnedSupplier.getTlsContext()).isSameInstanceAs(tlsContext1); + callUpdateSslContext(returnedSupplier); + xdsClient.ldsWatcher.onError(Status.CANCELLED); verify(tlsContextManager, never()).releaseServerSslContextProvider(eq(sslContextProvider1)); } - private void callUpdateSslContext(Channel channel) { - SslContextProviderSupplier sslContextProviderSupplier = - xdsClientWrapperForServerSds.getSslContextProviderSupplier(channel); + private void callUpdateSslContext(SslContextProviderSupplier sslContextProviderSupplier) { assertThat(sslContextProviderSupplier).isNotNull(); SslContextProvider.Callback callback = mock(SslContextProvider.Callback.class); sslContextProviderSupplier.updateSslContext(callback); } - private void registerWatcherAndCreateListenerUpdate(DownstreamTlsContext tlsContext) - throws UnknownHostException { - registeredWatcher = - XdsServerTestHelper.startAndGetWatcher(xdsClientWrapperForServerSds); - InetAddress ipLocalAddress = InetAddress.getByName("10.1.2.3"); - localAddress = new InetSocketAddress(ipLocalAddress, PORT); - xdsClientWrapperForServerSds.addServerWatcher(mockServerWatcher); - DownstreamTlsContext returnedTlsContext = sendListenerUpdate(localAddress, tlsContext, null, - tlsContextManager); - assertThat(returnedTlsContext).isSameInstanceAs(tlsContext); - } + private void sendListenerUpdate( + final SocketAddress localAddress, DownstreamTlsContext tlsContext, + DownstreamTlsContext tlsContextForDefaultFilterChain, TlsContextManager tlsContextManager) + throws Exception { + final SettableFuture start = SettableFuture.create(); + Executors.newSingleThreadExecutor().execute(new Runnable() { + @Override + public void run() { + try { + start.set(xdsServerWrapper.start()); + } catch (Exception ex) { + start.setException(ex); + } + } + }); + xdsClient.ldsResource.get(5, TimeUnit.SECONDS); + XdsServerTestHelper + .generateListenerUpdate(xdsClient, Arrays.asList(), tlsContext, + tlsContextForDefaultFilterChain, tlsContextManager); + start.get(5, TimeUnit.SECONDS); + InetAddress ipRemoteAddress = InetAddress.getByName("10.4.5.6"); + final InetSocketAddress remoteAddress = new InetSocketAddress(ipRemoteAddress, 1234); + channel = new EmbeddedChannel() { + @Override + public SocketAddress localAddress() { + return localAddress; + } - private void commonErrorCheck(boolean generateResourceDoesNotExist, Status status, - boolean isAbsent) throws UnknownHostException { - registerWatcherAndCreateListenerUpdate(tlsContext1); - reset(mockServerWatcher); - if (generateResourceDoesNotExist) { - registeredWatcher.onResourceDoesNotExist("not-found Error"); - } else { - registeredWatcher.onError(status); - } - ArgumentCaptor argCaptor = ArgumentCaptor.forClass(null); - verify(mockServerWatcher).onError(argCaptor.capture(), eq(isAbsent)); - Throwable throwable = argCaptor.getValue(); - assertThat(throwable).isInstanceOf(StatusException.class); - Status captured = ((StatusException) throwable).getStatus(); - assertThat(captured.getCode()).isEqualTo(status.getCode()); - if (isAbsent) { - assertThat(xdsClientWrapperForServerSds.getSslContextProviderSupplier(channel)).isNull(); - } else { - assertThat(xdsClientWrapperForServerSds.getSslContextProviderSupplier(channel)).isNotNull(); - } + @Override + public SocketAddress remoteAddress() { + return remoteAddress; + } + }; + pipeline = channel.pipeline(); } - private DownstreamTlsContext sendListenerUpdate( - SocketAddress localAddress, DownstreamTlsContext tlsContext, - DownstreamTlsContext tlsContextForDefaultFilterChain, TlsContextManager tlsContextManager) - throws UnknownHostException { - when(channel.localAddress()).thenReturn(localAddress); - InetAddress ipRemoteAddress = InetAddress.getByName("10.4.5.6"); - InetSocketAddress remoteAddress = new InetSocketAddress(ipRemoteAddress, 1234); - when(channel.remoteAddress()).thenReturn(remoteAddress); - XdsServerTestHelper - .generateListenerUpdate(registeredWatcher, Arrays.asList(), tlsContext, - tlsContextForDefaultFilterChain, tlsContextManager); - return getDownstreamTlsContext(); + private SslContextProviderSupplier getSslContextProviderSupplier( + FilterChainSelector selector) throws Exception { + final SettableFuture sslSet = SettableFuture.create(); + ChannelHandler next = new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + ProtocolNegotiationEvent e = (ProtocolNegotiationEvent)evt; + sslSet.set(InternalProtocolNegotiationEvent.getAttributes(e) + .get(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER)); + ctx.pipeline().remove(this); + } + }; + ProtocolNegotiator mockDelegate = mock(ProtocolNegotiator.class); + GrpcHttp2ConnectionHandler grpcHandler = FakeGrpcHttp2ConnectionHandler.newHandler(); + when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); + FilterChainMatchingHandler filterChainMatchingHandler = + new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); + pipeline.addLast(filterChainMatchingHandler); + ProtocolNegotiationEvent event = InternalProtocolNegotiationEvent.getDefault(); + pipeline.fireUserEventTriggered(event); + channel.runPendingTasks(); + sslSet.set(InternalProtocolNegotiationEvent.getAttributes(event) + .get(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER)); + return sslSet.get(); } - private DownstreamTlsContext getDownstreamTlsContext() { - SslContextProviderSupplier sslContextProviderSupplier = - xdsClientWrapperForServerSds.getSslContextProviderSupplier(channel); - if (sslContextProviderSupplier != null) { - EnvoyServerProtoData.BaseTlsContext tlsContext = sslContextProviderSupplier.getTlsContext(); - assertThat(tlsContext).isInstanceOf(DownstreamTlsContext.class); - return (DownstreamTlsContext)tlsContext; + private static final class FakeGrpcHttp2ConnectionHandler extends GrpcHttp2ConnectionHandler { + FakeGrpcHttp2ConnectionHandler( + ChannelPromise channelUnused, + Http2ConnectionDecoder decoder, + Http2ConnectionEncoder encoder, + Http2Settings initialSettings) { + super(channelUnused, decoder, encoder, initialSettings, new NoopChannelLogger()); } - return null; - } - /** Creates XdsClientWrapperForServerSds: also used by other classes. */ - public static XdsClientWrapperForServerSds createXdsClientWrapperForServerSds( - int port, DownstreamTlsContext downstreamTlsContext, TlsContextManager tlsContextManager) { - XdsClientWrapperForServerSds xdsClientWrapperForServerSds = - XdsServerTestHelper.createXdsClientWrapperForServerSds(port, tlsContextManager); - xdsClientWrapperForServerSds.start(); - XdsSdsClientServerTest.generateListenerUpdateToWatcher( - downstreamTlsContext, xdsClientWrapperForServerSds.getListenerWatcher(), tlsContextManager); - return xdsClientWrapperForServerSds; + static FakeGrpcHttp2ConnectionHandler newHandler() { + DefaultHttp2Connection conn = new DefaultHttp2Connection(/*server=*/ false); + DefaultHttp2ConnectionEncoder encoder = + new DefaultHttp2ConnectionEncoder(conn, new DefaultHttp2FrameWriter()); + DefaultHttp2ConnectionDecoder decoder = + new DefaultHttp2ConnectionDecoder(conn, encoder, new DefaultHttp2FrameReader()); + Http2Settings settings = new Http2Settings(); + return new FakeGrpcHttp2ConnectionHandler( + /*channelUnused=*/ null, decoder, encoder, settings); + } + + @Override + public String getAuthority() { + return "authority"; + } } } diff --git a/xds/src/test/java/io/grpc/xds/XdsSdsClientServerTest.java b/xds/src/test/java/io/grpc/xds/XdsSdsClientServerTest.java index c71841b2678..ecfa9e6b9bc 100644 --- a/xds/src/test/java/io/grpc/xds/XdsSdsClientServerTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsSdsClientServerTest.java @@ -29,6 +29,8 @@ import static org.junit.Assert.fail; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.util.concurrent.SettableFuture; import io.grpc.Attributes; import io.grpc.EquivalentAddressGroup; import io.grpc.Grpc; @@ -38,6 +40,7 @@ import io.grpc.NameResolver; import io.grpc.NameResolverProvider; import io.grpc.NameResolverRegistry; +import io.grpc.Server; import io.grpc.ServerCredentials; import io.grpc.Status; import io.grpc.StatusRuntimeException; @@ -48,13 +51,19 @@ import io.grpc.testing.protobuf.SimpleServiceGrpc; import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; +import io.grpc.xds.Filter.FilterConfig; import io.grpc.xds.Filter.NamedFilterConfig; +import io.grpc.xds.VirtualHost.Route; +import io.grpc.xds.VirtualHost.Route.RouteMatch; +import io.grpc.xds.VirtualHost.Route.RouteMatch.PathMatcher; import io.grpc.xds.XdsClient.LdsUpdate; +import io.grpc.xds.XdsServerTestHelper.FakeXdsClient; +import io.grpc.xds.XdsServerTestHelper.FakeXdsClientPoolFactory; +import io.grpc.xds.internal.Matchers.HeaderMatcher; import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil; import io.grpc.xds.internal.sds.SslContextProviderSupplier; import io.grpc.xds.internal.sds.TlsContextManagerImpl; import io.netty.handler.ssl.NotSslRecordException; -import java.io.IOException; import java.net.Inet4Address; import java.net.InetSocketAddress; import java.net.URI; @@ -63,10 +72,11 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; import javax.net.ssl.SSLException; import javax.net.ssl.SSLHandshakeException; import org.junit.After; -import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; @@ -86,14 +96,9 @@ public class XdsSdsClientServerTest { private Bootstrapper.BootstrapInfo bootstrapInfoForServer = null; private TlsContextManagerImpl tlsContextManagerForClient; private TlsContextManagerImpl tlsContextManagerForServer; - - @Before - public void setUp() throws Exception { - port = XdsServerTestHelper.findFreePort(); - URI expectedUri = new URI("sdstest://localhost:" + port); - fakeNameResolverFactory = new FakeNameResolverFactory.Builder(expectedUri).build(); - NameResolverRegistry.getDefaultRegistry().register(fakeNameResolverFactory); - } + private FakeXdsClient xdsClient = new FakeXdsClient(); + private FakeXdsClientPoolFactory fakePoolFactory = new FakeXdsClientPoolFactory(xdsClient); + private static final String OVERRIDE_AUTHORITY = "foo.test.google.fr"; @After public void tearDown() { @@ -103,16 +108,17 @@ public void tearDown() { } @Test - public void plaintextClientServer() throws IOException, URISyntaxException { + public void plaintextClientServer() throws Exception { buildServerWithTlsContext(/* downstreamTlsContext= */ null); SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = - getBlockingStub(/* upstreamTlsContext= */ null, /* overrideAuthority= */ null); + getBlockingStub(/* upstreamTlsContext= */ null, + /* overrideAuthority= */ null); assertThat(unaryRpc("buddy", blockingStub)).isEqualTo("Hello buddy"); } @Test - public void nullFallbackCredentials_expectException() throws IOException, URISyntaxException { + public void nullFallbackCredentials_expectException() throws Exception { try { buildServerWithTlsContext(/* downstreamTlsContext= */ null, /* fallbackCredentials= */ null); fail("exception expected"); @@ -123,7 +129,7 @@ public void nullFallbackCredentials_expectException() throws IOException, URISyn /** TLS channel - no mTLS. */ @Test - public void tlsClientServer_noClientAuthentication() throws IOException, URISyntaxException { + public void tlsClientServer_noClientAuthentication() throws Exception { DownstreamTlsContext downstreamTlsContext = setBootstrapInfoAndBuildDownstreamTlsContext(null, null, null, null, false, false); buildServerWithTlsContext(downstreamTlsContext); @@ -134,13 +140,13 @@ public void tlsClientServer_noClientAuthentication() throws IOException, URISynt CLIENT_PEM_FILE, false); SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = - getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ "foo.test.google.fr"); + getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); assertThat(unaryRpc(/* requestMessage= */ "buddy", blockingStub)).isEqualTo("Hello buddy"); } @Test public void requireClientAuth_noClientCert_expectException() - throws IOException, URISyntaxException { + throws Exception { DownstreamTlsContext downstreamTlsContext = setBootstrapInfoAndBuildDownstreamTlsContext(null, null, null, null, true, true); buildServerWithTlsContext(downstreamTlsContext); @@ -168,7 +174,7 @@ public void requireClientAuth_noClientCert_expectException() } @Test - public void noClientAuth_sendBadClientCert_passes() throws IOException, URISyntaxException { + public void noClientAuth_sendBadClientCert_passes() throws Exception { DownstreamTlsContext downstreamTlsContext = setBootstrapInfoAndBuildDownstreamTlsContext(null, null, null, null, false, false); buildServerWithTlsContext(downstreamTlsContext); @@ -183,12 +189,12 @@ public void noClientAuth_sendBadClientCert_passes() throws IOException, URISynta } @Test - public void mtls_badClientCert_expectException() throws IOException, URISyntaxException { + public void mtls_badClientCert_expectException() throws Exception { UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContext( BAD_CLIENT_KEY_FILE, BAD_CLIENT_PEM_FILE, true); try { - performMtlsTestAndGetListenerWatcher(upstreamTlsContext, false, null, null, null, null); + performMtlsTestAndGetListenerWatcher(upstreamTlsContext, null, null, null, null); fail("exception expected"); } catch (StatusRuntimeException sre) { if (sre.getCause() instanceof SSLHandshakeException) { @@ -202,27 +208,18 @@ public void mtls_badClientCert_expectException() throws IOException, URISyntaxEx } } - /** mTLS - client auth enabled. */ - @Test - public void mtlsClientServer_withClientAuthentication() throws IOException, URISyntaxException { - UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContext( - CLIENT_KEY_FILE, - CLIENT_PEM_FILE, true); - performMtlsTestAndGetListenerWatcher(upstreamTlsContext, false, null, null, null, null); - } - /** mTLS - client auth enabled - using {@link XdsChannelCredentials} API. */ @Test public void mtlsClientServer_withClientAuthentication_withXdsChannelCreds() - throws IOException, URISyntaxException { + throws Exception { UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContext( CLIENT_KEY_FILE, CLIENT_PEM_FILE, true); - performMtlsTestAndGetListenerWatcher(upstreamTlsContext, true, null, null, null, null); + performMtlsTestAndGetListenerWatcher(upstreamTlsContext, null, null, null, null); } @Test - public void tlsServer_plaintextClient_expectException() throws IOException, URISyntaxException { + public void tlsServer_plaintextClient_expectException() throws Exception { DownstreamTlsContext downstreamTlsContext = setBootstrapInfoAndBuildDownstreamTlsContext(null, null, null, null, false, false); buildServerWithTlsContext(downstreamTlsContext); @@ -239,7 +236,7 @@ public void tlsServer_plaintextClient_expectException() throws IOException, URIS } @Test - public void plaintextServer_tlsClient_expectException() throws IOException, URISyntaxException { + public void plaintextServer_tlsClient_expectException() throws Exception { buildServerWithTlsContext(/* downstreamTlsContext= */ null); // for TLS, client only needs trustCa @@ -261,19 +258,20 @@ public void plaintextServer_tlsClient_expectException() throws IOException, URIS /** mTLS - client auth enabled then update server certs to untrusted. */ @Test public void mtlsClientServer_changeServerContext_expectException() - throws IOException, URISyntaxException { + throws Exception { UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContext( CLIENT_KEY_FILE, CLIENT_PEM_FILE, true); - XdsClient.LdsResourceWatcher listenerWatcher = - performMtlsTestAndGetListenerWatcher(upstreamTlsContext, false, "cert-instance-name2", + performMtlsTestAndGetListenerWatcher(upstreamTlsContext, "cert-instance-name2", BAD_SERVER_KEY_FILE, BAD_SERVER_PEM_FILE, CA_PEM_FILE); DownstreamTlsContext downstreamTlsContext = CommonTlsContextTestsUtil.buildDownstreamTlsContext( "cert-instance-name2", true, true); - generateListenerUpdateToWatcher(downstreamTlsContext, listenerWatcher, - tlsContextManagerForServer); + EnvoyServerProtoData.Listener listener = buildListener("listener1", "0.0.0.0", + downstreamTlsContext, + tlsContextManagerForServer); + xdsClient.deliverLdsUpdate(LdsUpdate.forTcpListener(listener)); try { SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = getBlockingStub(upstreamTlsContext, "foo.test.google.fr"); @@ -285,27 +283,20 @@ public void mtlsClientServer_changeServerContext_expectException() } } - private XdsClient.LdsResourceWatcher performMtlsTestAndGetListenerWatcher( - UpstreamTlsContext upstreamTlsContext, boolean newApi, String certInstanceName2, + private void performMtlsTestAndGetListenerWatcher( + UpstreamTlsContext upstreamTlsContext, String certInstanceName2, String privateKey2, String cert2, String trustCa2) - throws IOException, URISyntaxException { + throws Exception { DownstreamTlsContext downstreamTlsContext = setBootstrapInfoAndBuildDownstreamTlsContext(certInstanceName2, privateKey2, cert2, trustCa2, true, true); - final XdsClientWrapperForServerSds xdsClientWrapperForServerSds = - createXdsClientWrapperForServerSds(port); buildServerWithFallbackServerCredentials( - xdsClientWrapperForServerSds, InsecureServerCredentials.create(), downstreamTlsContext); - - XdsClient.LdsResourceWatcher listenerWatcher = xdsClientWrapperForServerSds - .getListenerWatcher(); + InsecureServerCredentials.create(), downstreamTlsContext); - SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = newApi - ? getBlockingStub(upstreamTlsContext, "foo.test.google.fr") : - getBlockingStub(upstreamTlsContext, "foo.test.google.fr"); + SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = + getBlockingStub(upstreamTlsContext, OVERRIDE_AUTHORITY); assertThat(unaryRpc("buddy", blockingStub)).isEqualTo("Hello buddy"); - return listenerWatcher; } private DownstreamTlsContext setBootstrapInfoAndBuildDownstreamTlsContext( @@ -330,36 +321,25 @@ private UpstreamTlsContext setBootstrapInfoAndBuildUpstreamTlsContext(String cli } private void buildServerWithTlsContext(DownstreamTlsContext downstreamTlsContext) - throws IOException { + throws Exception { buildServerWithTlsContext(downstreamTlsContext, InsecureServerCredentials.create()); } private void buildServerWithTlsContext( DownstreamTlsContext downstreamTlsContext, ServerCredentials fallbackCredentials) - throws IOException { - XdsClientWrapperForServerSds xdsClientWrapperForServerSds = - createXdsClientWrapperForServerSds(port); - xdsClientWrapperForServerSds.start(); - buildServerWithFallbackServerCredentials( - xdsClientWrapperForServerSds, fallbackCredentials, downstreamTlsContext); + throws Exception { + buildServerWithFallbackServerCredentials(fallbackCredentials, downstreamTlsContext); } private void buildServerWithFallbackServerCredentials( - XdsClientWrapperForServerSds xdsClientWrapperForServerSds, ServerCredentials fallbackCredentials, DownstreamTlsContext downstreamTlsContext) - throws IOException { + throws Exception { ServerCredentials xdsCredentials = XdsServerCredentials.create(fallbackCredentials); - buildServer(port, xdsCredentials, xdsClientWrapperForServerSds, downstreamTlsContext); - } - - /** Creates XdsClientWrapperForServerSds. */ - private XdsClientWrapperForServerSds createXdsClientWrapperForServerSds(int port) { - tlsContextManagerForServer = new TlsContextManagerImpl(bootstrapInfoForServer); - XdsClientWrapperForServerSds xdsClientWrapperForServerSds = - XdsServerTestHelper.createXdsClientWrapperForServerSds(port, tlsContextManagerForServer); - xdsClientWrapperForServerSds.start(); - return xdsClientWrapperForServerSds; + XdsServerBuilder builder = XdsServerBuilder.forPort(0, xdsCredentials) + .xdsClientPoolFactory(fakePoolFactory) + .addService(new SimpleServiceImpl()); + buildServer(builder, downstreamTlsContext); } static void generateListenerUpdateToWatcher( @@ -372,18 +352,21 @@ static void generateListenerUpdateToWatcher( } private void buildServer( - int port, - ServerCredentials serverCredentials, - XdsClientWrapperForServerSds xdsClientWrapperForServerSds, + XdsServerBuilder builder, DownstreamTlsContext downstreamTlsContext) - throws IOException { - XdsServerBuilder builder = XdsServerBuilder.forPort(port, serverCredentials) - .addService(new SimpleServiceImpl()); + throws Exception { tlsContextManagerForServer = new TlsContextManagerImpl(bootstrapInfoForServer); - XdsServerTestHelper.generateListenerUpdate( - xdsClientWrapperForServerSds.getListenerWatcher(), downstreamTlsContext, - tlsContextManagerForServer); - cleanupRule.register(builder.buildServer(xdsClientWrapperForServerSds)).start(); + XdsServerWrapper xdsServer = (XdsServerWrapper) builder.build(); + SettableFuture startFuture = startServerAsync(xdsServer); + EnvoyServerProtoData.Listener listener = buildListener("listener1", "10.1.2.3", + downstreamTlsContext, tlsContextManagerForServer); + LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(listener); + xdsClient.deliverLdsUpdate(listenerUpdate); + startFuture.get(10, TimeUnit.SECONDS); + port = xdsServer.getPort(); + URI expectedUri = new URI("sdstest://localhost:" + port); + fakeNameResolverFactory = new FakeNameResolverFactory.Builder(expectedUri).build(); + NameResolverRegistry.getDefaultRegistry().register(fakeNameResolverFactory); } static EnvoyServerProtoData.Listener buildListener( @@ -399,9 +382,19 @@ static EnvoyServerProtoData.Listener buildListener( Arrays.asList(), Arrays.asList(), null); - // HttpConnectionManager currently not used for server side. - HttpConnectionManager httpConnectionManager = HttpConnectionManager.forRdsName( - 0L, "does not matter", Collections.emptyList()); + String fullPath = "/" + SimpleServiceGrpc.SERVICE_NAME + "/" + "UnaryRpc"; + RouteMatch routeMatch = + RouteMatch.create( + PathMatcher.fromPath(fullPath, true), + Collections.emptyList(), null); + VirtualHost virtualHost = VirtualHost.create( + "virtual-host", Collections.singletonList(OVERRIDE_AUTHORITY), + Arrays.asList(Route.forAction(routeMatch, null, + ImmutableMap.of())), + ImmutableMap.of()); + HttpConnectionManager httpConnectionManager = HttpConnectionManager.forVirtualHosts( + 0L, Collections.singletonList(virtualHost), + new ArrayList()); EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( "filter-chain-foo", filterChainMatch, httpConnectionManager, tlsContext, tlsContextManager); @@ -445,6 +438,24 @@ private static String unaryRpc( return response.getResponseMessage(); } + private SettableFuture startServerAsync(final Server xdsServer) throws Exception { + cleanupRule.register(xdsServer); + final SettableFuture settableFuture = SettableFuture.create(); + Executors.newSingleThreadExecutor().execute(new Runnable() { + @Override + public void run() { + try { + xdsServer.start(); + settableFuture.set(null); + } catch (Throwable e) { + settableFuture.set(e); + } + } + }); + xdsClient.ldsResource.get(8000, TimeUnit.MILLISECONDS); + return settableFuture; + } + private static class SimpleServiceImpl extends SimpleServiceGrpc.SimpleServiceImplBase { @Override diff --git a/xds/src/test/java/io/grpc/xds/XdsServerBuilderTest.java b/xds/src/test/java/io/grpc/xds/XdsServerBuilderTest.java index 0b174a4a313..4f4844f0b0b 100644 --- a/xds/src/test/java/io/grpc/xds/XdsServerBuilderTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsServerBuilderTest.java @@ -33,8 +33,9 @@ import io.grpc.Status; import io.grpc.StatusException; import io.grpc.testing.GrpcCleanupRule; +import io.grpc.xds.XdsServerTestHelper.FakeXdsClient; +import io.grpc.xds.XdsServerTestHelper.FakeXdsClientPoolFactory; import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil; -import io.grpc.xds.internal.sds.ServerWrapperForXds; import java.io.IOException; import java.net.InetSocketAddress; import java.net.ServerSocket; @@ -51,6 +52,7 @@ import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; +// TODO (zivy@): move certain tests down to XdsServerWrapperTest, or up to XdsSdsClientServerTest. /** * Unit tests for {@link XdsServerBuilder}. */ @@ -59,31 +61,27 @@ public class XdsServerBuilderTest { @Rule public final GrpcCleanupRule cleanupRule = new GrpcCleanupRule(); private XdsServerBuilder builder; - private ServerWrapperForXds xdsServer; - private XdsClient.LdsResourceWatcher listenerWatcher; + private XdsServerWrapper xdsServer; private int port; - private XdsClientWrapperForServerSds xdsClientWrapperForServerSds; private TlsContextManager tlsContextManager; + private FakeXdsClient xdsClient = new FakeXdsClient(); private void buildServer(XdsServerBuilder.XdsServingStatusListener xdsServingStatusListener) throws IOException { buildBuilder(xdsServingStatusListener); - xdsServer = cleanupRule.register(builder.buildServer(xdsClientWrapperForServerSds)); + xdsServer = cleanupRule.register((XdsServerWrapper) builder.build()); } private void buildBuilder(XdsServerBuilder.XdsServingStatusListener xdsServingStatusListener) throws IOException { - port = XdsServerTestHelper.findFreePort(); builder = XdsServerBuilder.forPort( port, XdsServerCredentials.create(InsecureServerCredentials.create())); + builder.xdsClientPoolFactory(new FakeXdsClientPoolFactory(xdsClient)); if (xdsServingStatusListener != null) { - builder = builder.xdsServingStatusListener(xdsServingStatusListener); + builder.xdsServingStatusListener(xdsServingStatusListener); } tlsContextManager = mock(TlsContextManager.class); - xdsClientWrapperForServerSds = XdsServerTestHelper - .createXdsClientWrapperForServerSds(port, tlsContextManager); - listenerWatcher = XdsServerTestHelper.startAndGetWatcher(xdsClientWrapperForServerSds); } private void verifyServer( @@ -99,7 +97,7 @@ private void verifyServer( assertThat(list).hasSize(1); InetSocketAddress socketAddress = (InetSocketAddress) list.get(0); assertThat(socketAddress.getAddress().isAnyLocalAddress()).isTrue(); - assertThat(socketAddress.getPort()).isEqualTo(port); + assertThat(socketAddress.getPort()).isGreaterThan(-1); if (mockXdsServingStatusListener != null) { if (notServingStatus != null) { ArgumentCaptor argCaptor = ArgumentCaptor.forClass(null); @@ -117,10 +115,11 @@ private void verifyServer( private void verifyShutdown() throws InterruptedException { xdsServer.shutdown(); xdsServer.awaitTermination(500L, TimeUnit.MILLISECONDS); - assertThat(xdsClientWrapperForServerSds.getXdsClient()).isNull(); + assertThat(xdsClient.isShutDown()).isTrue(); } - private Future startServerAsync() throws InterruptedException { + private Future startServerAsync() throws + InterruptedException, TimeoutException, ExecutionException { final SettableFuture settableFuture = SettableFuture.create(); Executors.newSingleThreadExecutor().execute(new Runnable() { @Override @@ -133,15 +132,7 @@ public void run() { } } }); - // wait until xdsClientWrapperForServerSds.serverWatchers populated - for (int i = 0; i < 10; i++) { - synchronized (xdsClientWrapperForServerSds.serverWatchers) { - if (!xdsClientWrapperForServerSds.serverWatchers.isEmpty()) { - break; - } - } - Thread.sleep(100L); - } + xdsClient.ldsResource.get(5000, TimeUnit.MILLISECONDS); return settableFuture; } @@ -151,29 +142,29 @@ public void xdsServerStartAndShutdown() buildServer(null); Future future = startServerAsync(); XdsServerTestHelper.generateListenerUpdate( - listenerWatcher, + xdsClient, CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"), - tlsContextManager); + tlsContextManager); verifyServer(future, null, null); verifyShutdown(); } @Test - public void xdsServerStartAfterListenerUpdate() + public void xdsServerRestartAfterListenerUpdate() throws IOException, InterruptedException, TimeoutException, ExecutionException { buildServer(null); + Future future = startServerAsync(); XdsServerTestHelper.generateListenerUpdate( - listenerWatcher, + xdsClient, CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"), tlsContextManager); - xdsServer.start(); try { xdsServer.start(); fail("expected exception"); } catch (IllegalStateException expected) { assertThat(expected).hasMessageThat().contains("Already started"); } - verifyServer(null,null, null); + verifyServer(future,null, null); } @Test @@ -184,48 +175,34 @@ public void xdsServerStartAndShutdownWithXdsServingStatusListener() buildServer(mockXdsServingStatusListener); Future future = startServerAsync(); XdsServerTestHelper.generateListenerUpdate( - listenerWatcher, + xdsClient, CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"), tlsContextManager); verifyServer(future, mockXdsServingStatusListener, null); } @Test - public void xdsServer_serverWatcher() - throws IOException, InterruptedException, TimeoutException, ExecutionException { + public void xdsServer_discoverState() throws Exception { XdsServerBuilder.XdsServingStatusListener mockXdsServingStatusListener = mock(XdsServerBuilder.XdsServingStatusListener.class); buildServer(mockXdsServingStatusListener); Future future = startServerAsync(); - listenerWatcher.onError(Status.ABORTED); - ArgumentCaptor argCaptor = ArgumentCaptor.forClass(null); - verify(mockXdsServingStatusListener).onNotServing(argCaptor.capture()); - Throwable throwable = argCaptor.getValue(); - assertThat(throwable).isInstanceOf(StatusException.class); - Status captured = ((StatusException) throwable).getStatus(); - assertThat(captured.getCode()).isEqualTo(Status.Code.ABORTED); - assertThat(xdsClientWrapperForServerSds.serverWatchers).hasSize(1); - assertThat(future.isDone()).isFalse(); + XdsServerTestHelper.generateListenerUpdate( + xdsClient, + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"), + tlsContextManager); + future.get(5000, TimeUnit.MILLISECONDS); + xdsClient.ldsWatcher.onError(Status.ABORTED); + verify(mockXdsServingStatusListener, never()).onNotServing(any(StatusException.class)); reset(mockXdsServingStatusListener); - listenerWatcher.onError(Status.NOT_FOUND); - argCaptor = ArgumentCaptor.forClass(null); - verify(mockXdsServingStatusListener).onNotServing(argCaptor.capture()); - throwable = argCaptor.getValue(); - assertThat(throwable).isInstanceOf(StatusException.class); - captured = ((StatusException) throwable).getStatus(); - assertThat(captured.getCode()).isEqualTo(Status.Code.NOT_FOUND); + xdsClient.ldsWatcher.onError(Status.CANCELLED); + verify(mockXdsServingStatusListener, never()).onNotServing(any(StatusException.class)); reset(mockXdsServingStatusListener); - listenerWatcher.onResourceDoesNotExist("not found error"); - argCaptor = ArgumentCaptor.forClass(null); - verify(mockXdsServingStatusListener).onNotServing(argCaptor.capture()); - throwable = argCaptor.getValue(); - assertThat(throwable).isInstanceOf(StatusException.class); - captured = ((StatusException) throwable).getStatus(); - assertThat(captured.getCode()).isEqualTo(Status.Code.NOT_FOUND); - assertThat(future.isDone()).isFalse(); + xdsClient.ldsWatcher.onResourceDoesNotExist("not found error"); + verify(mockXdsServingStatusListener).onNotServing(any(StatusException.class)); reset(mockXdsServingStatusListener); XdsServerTestHelper.generateListenerUpdate( - listenerWatcher, + xdsClient, CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"), tlsContextManager); verifyServer(future, mockXdsServingStatusListener, null); @@ -236,12 +213,13 @@ public void xdsServer_startError() throws IOException, InterruptedException, TimeoutException, ExecutionException { XdsServerBuilder.XdsServingStatusListener mockXdsServingStatusListener = mock(XdsServerBuilder.XdsServingStatusListener.class); + ServerSocket serverSocket = new ServerSocket(0); + port = serverSocket.getLocalPort(); buildServer(mockXdsServingStatusListener); Future future = startServerAsync(); // create port conflict for start to fail - ServerSocket serverSocket = new ServerSocket(port); XdsServerTestHelper.generateListenerUpdate( - listenerWatcher, + xdsClient, CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"), tlsContextManager); Throwable exception = future.get(5, TimeUnit.SECONDS); @@ -259,16 +237,16 @@ public void xdsServerStartSecondUpdateAndError() buildServer(mockXdsServingStatusListener); Future future = startServerAsync(); XdsServerTestHelper.generateListenerUpdate( - listenerWatcher, + xdsClient, CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"), tlsContextManager); XdsServerTestHelper.generateListenerUpdate( - listenerWatcher, + xdsClient, CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"), tlsContextManager); verify(mockXdsServingStatusListener, never()).onNotServing(any(Throwable.class)); verifyServer(future, mockXdsServingStatusListener, null); - listenerWatcher.onError(Status.ABORTED); + xdsClient.ldsWatcher.onError(Status.ABORTED); verifyServer(null, mockXdsServingStatusListener, null); } @@ -295,7 +273,7 @@ public void xdsServer_2ndSetter_expectException() throws IOException { .builder("mock").build(); when(mockBindableService.bindService()).thenReturn(serverServiceDefinition); builder.addService(mockBindableService); - xdsServer = cleanupRule.register(builder.buildServer(xdsClientWrapperForServerSds)); + xdsServer = cleanupRule.register((XdsServerWrapper) builder.build()); try { builder.addService(mock(BindableService.class)); fail("exception expected"); diff --git a/xds/src/test/java/io/grpc/xds/XdsServerTestHelper.java b/xds/src/test/java/io/grpc/xds/XdsServerTestHelper.java index 2c455673239..078299945e5 100644 --- a/xds/src/test/java/io/grpc/xds/XdsServerTestHelper.java +++ b/xds/src/test/java/io/grpc/xds/XdsServerTestHelper.java @@ -16,23 +16,27 @@ package io.grpc.xds; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; +import static com.google.common.truth.Truth.assertThat; +import com.google.common.collect.ImmutableMap; +import com.google.common.util.concurrent.SettableFuture; import io.grpc.InsecureChannelCredentials; import io.grpc.internal.ObjectPool; +import io.grpc.xds.Bootstrapper.BootstrapInfo; +import io.grpc.xds.EnvoyServerProtoData.FilterChain; +import io.grpc.xds.EnvoyServerProtoData.Listener; +import io.grpc.xds.Filter.FilterConfig; import io.grpc.xds.Filter.NamedFilterConfig; +import io.grpc.xds.VirtualHost.Route; import io.grpc.xds.XdsClient.LdsUpdate; -import java.io.IOException; -import java.net.ServerSocket; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.concurrent.CountDownLatch; import javax.annotation.Nullable; -import org.mockito.ArgumentCaptor; /** * Helper methods related to {@link XdsServerBuilder} and related classes. @@ -52,8 +56,61 @@ public class XdsServerTestHelper { null, "grpc/server?udpa.resource.listening_address=%s"); + static void generateListenerUpdate(FakeXdsClient xdsClient, + EnvoyServerProtoData.DownstreamTlsContext tlsContext, + TlsContextManager tlsContextManager) { + EnvoyServerProtoData.Listener listener = buildTestListener("listener1", "10.1.2.3", + Arrays.asList(), tlsContext, null, tlsContextManager); + LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(listener); + xdsClient.deliverLdsUpdate(listenerUpdate); + } + + static void generateListenerUpdate( + FakeXdsClient xdsClient, List sourcePorts, + EnvoyServerProtoData.DownstreamTlsContext tlsContext, + EnvoyServerProtoData.DownstreamTlsContext tlsContextForDefaultFilterChain, + TlsContextManager tlsContextManager) { + EnvoyServerProtoData.Listener listener = buildTestListener("listener1", "10.1.2.3", sourcePorts, + tlsContext, tlsContextForDefaultFilterChain, tlsContextManager); + LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(listener); + xdsClient.deliverLdsUpdate(listenerUpdate); + } + + static EnvoyServerProtoData.Listener buildTestListener( + String name, String address, List sourcePorts, + EnvoyServerProtoData.DownstreamTlsContext tlsContext, + EnvoyServerProtoData.DownstreamTlsContext tlsContextForDefaultFilterChain, + TlsContextManager tlsContextManager) { + EnvoyServerProtoData.FilterChainMatch filterChainMatch1 = + new EnvoyServerProtoData.FilterChainMatch( + 0, + Arrays.asList(), + Arrays.asList(), + Arrays.asList(), + null, + sourcePorts, + Arrays.asList(), + null); + VirtualHost virtualHost = + VirtualHost.create( + "virtual-host", Collections.singletonList("auth"), new ArrayList(), + ImmutableMap.of()); + HttpConnectionManager httpConnectionManager = HttpConnectionManager.forVirtualHosts( + 0L, Collections.singletonList(virtualHost), new ArrayList()); + EnvoyServerProtoData.FilterChain filterChain1 = new EnvoyServerProtoData.FilterChain( + "filter-chain-foo", filterChainMatch1, httpConnectionManager, tlsContext, + tlsContextManager); + EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( + "filter-chain-bar", null, httpConnectionManager, tlsContextForDefaultFilterChain, + tlsContextManager); + EnvoyServerProtoData.Listener listener = + new EnvoyServerProtoData.Listener( + name, address, Arrays.asList(filterChain1), defaultFilterChain); + return listener; + } + static final class FakeXdsClientPoolFactory - implements XdsNameResolverProvider.XdsClientPoolFactory { + implements XdsNameResolverProvider.XdsClientPoolFactory { private XdsClient xdsClient; @@ -82,103 +139,77 @@ public XdsClient getObject() { @Override public XdsClient returnObject(Object object) { + xdsClient.shutdown(); return null; } }; } } - /** Create an XdsClientWrapperForServerSds with a mock XdsClient. */ - public static XdsClientWrapperForServerSds createXdsClientWrapperForServerSds(int port, - TlsContextManager tlsContextManager) { - FakeXdsClientPoolFactory fakeXdsClientPoolFactory = new FakeXdsClientPoolFactory( - buildMockXdsClient(tlsContextManager)); - return new XdsClientWrapperForServerSds(port, fakeXdsClientPoolFactory); - } + static final class FakeXdsClient extends XdsClient { + boolean shutdown; + SettableFuture ldsResource = SettableFuture.create(); + LdsResourceWatcher ldsWatcher; + CountDownLatch rdsCount = new CountDownLatch(1); + final Map rdsWatchers = new HashMap<>(); - private static XdsClient buildMockXdsClient(TlsContextManager tlsContextManager) { - XdsClient xdsClient = mock(XdsClient.class); - when(xdsClient.getBootstrapInfo()).thenReturn(BOOTSTRAP_INFO); - when(xdsClient.getTlsContextManager()).thenReturn(tlsContextManager); - return xdsClient; - } + @Override + public TlsContextManager getTlsContextManager() { + return null; + } - static XdsClient.LdsResourceWatcher startAndGetWatcher( - XdsClientWrapperForServerSds xdsClientWrapperForServerSds) { - xdsClientWrapperForServerSds.start(); - XdsClient mockXdsClient = xdsClientWrapperForServerSds.getXdsClient(); - ArgumentCaptor listenerWatcherCaptor = - ArgumentCaptor.forClass(null); - verify(mockXdsClient).watchLdsResource(any(String.class), listenerWatcherCaptor.capture()); - return listenerWatcherCaptor.getValue(); - } + @Override + public BootstrapInfo getBootstrapInfo() { + return BOOTSTRAP_INFO; + } - /** - * Creates a {@link XdsClient.LdsUpdate} with {@link - * io.grpc.xds.EnvoyServerProtoData.FilterChain} with a destination port and an optional {@link - * EnvoyServerProtoData.DownstreamTlsContext}. - * @param registeredWatcher the watcher on which to generate the update - * @param tlsContext if non-null, used to populate filterChain - */ - static void generateListenerUpdate( - XdsClient.LdsResourceWatcher registeredWatcher, - EnvoyServerProtoData.DownstreamTlsContext tlsContext, TlsContextManager tlsContextManager) { - EnvoyServerProtoData.Listener listener = buildTestListener("listener1", "10.1.2.3", - Arrays.asList(), tlsContext, null, tlsContextManager); - LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(listener); - registeredWatcher.onChanged(listenerUpdate); - } + @Override + void watchLdsResource(String resourceName, LdsResourceWatcher watcher) { + assertThat(ldsWatcher).isNull(); + ldsWatcher = watcher; + ldsResource.set(resourceName); + } - static void generateListenerUpdate( - XdsClient.LdsResourceWatcher registeredWatcher, List sourcePorts, - EnvoyServerProtoData.DownstreamTlsContext tlsContext, - EnvoyServerProtoData.DownstreamTlsContext tlsContextForDefaultFilterChain, - TlsContextManager tlsContextManager) { - EnvoyServerProtoData.Listener listener = buildTestListener("listener1", "10.1.2.3", sourcePorts, - tlsContext, tlsContextForDefaultFilterChain, tlsContextManager); - LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(listener); - registeredWatcher.onChanged(listenerUpdate); - } + @Override + void cancelLdsResourceWatch(String resourceName, LdsResourceWatcher watcher) { + assertThat(ldsWatcher).isNotNull(); + ldsResource = null; + ldsWatcher = null; + } - public static void generateListenerUpdate( - XdsClient.LdsResourceWatcher registeredWatcher, EnvoyServerProtoData.Listener listener) { - registeredWatcher.onChanged(LdsUpdate.forTcpListener(listener)); - } + @Override + void watchRdsResource(String resourceName, RdsResourceWatcher watcher) { + rdsWatchers.put(resourceName, watcher); + rdsCount.countDown(); + } - static int findFreePort() throws IOException { - try (ServerSocket socket = new ServerSocket(0)) { - socket.setReuseAddress(true); - return socket.getLocalPort(); + @Override + void cancelRdsResourceWatch(String resourceName, RdsResourceWatcher watcher) { + rdsWatchers.remove(resourceName); } - } - static EnvoyServerProtoData.Listener buildTestListener( - String name, String address, List sourcePorts, - EnvoyServerProtoData.DownstreamTlsContext tlsContext, - EnvoyServerProtoData.DownstreamTlsContext tlsContextForDefaultFilterChain, - TlsContextManager tlsContextManager) { - EnvoyServerProtoData.FilterChainMatch filterChainMatch1 = - new EnvoyServerProtoData.FilterChainMatch( - 0, - Arrays.asList(), - Arrays.asList(), - Arrays.asList(), - null, - sourcePorts, - Arrays.asList(), - null); - // HttpConnectionManager currently not used for server side. - HttpConnectionManager httpConnectionManager = HttpConnectionManager.forRdsName( - 0L, "does not matter", Collections.emptyList()); - EnvoyServerProtoData.FilterChain filterChain1 = new EnvoyServerProtoData.FilterChain( - "filter-chain-foo", filterChainMatch1, httpConnectionManager, tlsContext, - tlsContextManager); - EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( - "filter-chain-bar", null, httpConnectionManager, tlsContextForDefaultFilterChain, - tlsContextManager); - EnvoyServerProtoData.Listener listener = - new EnvoyServerProtoData.Listener( - name, address, Arrays.asList(filterChain1), defaultFilterChain); - return listener; + @Override + void shutdown() { + shutdown = true; + } + + @Override + boolean isShutDown() { + return shutdown; + } + + void deliverLdsUpdate(List filterChains, + FilterChain defaultFilterChain) { + ldsWatcher.onChanged(LdsUpdate.forTcpListener(new Listener( + "listener", "0.0.0.0:1", filterChains, defaultFilterChain))); + } + + void deliverLdsUpdate(LdsUpdate ldsUpdate) { + ldsWatcher.onChanged(ldsUpdate); + } + + void deliverRdsUpdate(String rdsName, List virtualHosts) { + rdsWatchers.get(rdsName).onChanged(new RdsUpdate(virtualHosts)); + } } } diff --git a/xds/src/test/java/io/grpc/xds/XdsServerWrapperTest.java b/xds/src/test/java/io/grpc/xds/XdsServerWrapperTest.java new file mode 100644 index 00000000000..3463d2b14a9 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/XdsServerWrapperTest.java @@ -0,0 +1,456 @@ +/* + * Copyright 2021 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package io.grpc.xds; + +import static com.google.common.truth.Truth.assertThat; +import static io.grpc.xds.XdsServerWrapper.RETRY_DELAY_NANOS; +import static org.junit.Assert.fail; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.common.collect.ImmutableMap; +import com.google.common.util.concurrent.SettableFuture; +import io.grpc.InsecureChannelCredentials; +import io.grpc.Server; +import io.grpc.ServerBuilder; +import io.grpc.Status; +import io.grpc.StatusException; +import io.grpc.internal.FakeClock; +import io.grpc.xds.EnvoyServerProtoData.FilterChain; +import io.grpc.xds.Filter.FilterConfig; +import io.grpc.xds.Filter.NamedFilterConfig; +import io.grpc.xds.FilterChainMatchingProtocolNegotiators.FilterChainMatchingHandler.FilterChainSelector; +import io.grpc.xds.VirtualHost.Route; +import io.grpc.xds.XdsServerBuilder.XdsServingStatusListener; +import io.grpc.xds.XdsServerTestHelper.FakeXdsClient; +import io.grpc.xds.XdsServerTestHelper.FakeXdsClientPoolFactory; +import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil; +import io.grpc.xds.internal.sds.SslContextProviderSupplier; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +@RunWith(JUnit4.class) +public class XdsServerWrapperTest { + @Rule + public final MockitoRule mocks = MockitoJUnit.rule(); + + @Mock + private ServerBuilder mockBuilder; + @Mock + private Server mockServer; + @Mock + private static TlsContextManager tlsContextManager; + @Mock + private XdsServingStatusListener listener; + + private AtomicReference selectorRef = new AtomicReference<>(); + private FakeClock executor = new FakeClock(); + private FakeXdsClient xdsClient = new FakeXdsClient(); + private XdsServerWrapper xdsServerWrapper; + + @Before + public void setup() { + when(mockBuilder.build()).thenReturn(mockServer); + xdsServerWrapper = new XdsServerWrapper("0.0.0.0:1", mockBuilder, listener, + selectorRef, new FakeXdsClientPoolFactory(xdsClient), + executor.getScheduledExecutorService()); + } + + @Test + public void testBootstrap_notV3() throws Exception { + Bootstrapper.BootstrapInfo b = + new Bootstrapper.BootstrapInfo( + Arrays.asList( + new Bootstrapper.ServerInfo("uri", InsecureChannelCredentials.create(), false)), + EnvoyProtoData.Node.newBuilder().setId("id").build(), + null, + "grpc/server?udpa.resource.listening_address=%s"); + verifyBootstrapFail(b); + } + + @Test + public void testBootstrap_noTemplate() throws Exception { + Bootstrapper.BootstrapInfo b = + new Bootstrapper.BootstrapInfo( + Arrays.asList( + new Bootstrapper.ServerInfo("uri", InsecureChannelCredentials.create(), true)), + EnvoyProtoData.Node.newBuilder().setId("id").build(), + null, + null); + verifyBootstrapFail(b); + } + + private void verifyBootstrapFail(Bootstrapper.BootstrapInfo b) throws Exception { + XdsClient xdsClient = mock(XdsClient.class); + when(xdsClient.getBootstrapInfo()).thenReturn(b); + xdsServerWrapper = new XdsServerWrapper("0.0.0.0:1", mockBuilder, listener, + selectorRef, new FakeXdsClientPoolFactory(xdsClient)); + final SettableFuture start = SettableFuture.create(); + Executors.newSingleThreadExecutor().execute(new Runnable() { + @Override + public void run() { + try { + start.set(xdsServerWrapper.start()); + } catch (Exception ex) { + start.setException(ex); + } + } + }); + try { + start.get(5000, TimeUnit.MILLISECONDS); + fail("Start should throw exception"); + } catch (ExecutionException ex) { + assertThat(ex.getCause()).isInstanceOf(IOException.class); + Throwable cause = ex.getCause().getCause(); + assertThat(cause).isInstanceOf(StatusException.class); + assertThat(((StatusException)cause).getStatus().getCode()) + .isEqualTo(Status.UNAVAILABLE.getCode()); + } + } + + + @Test + public void shutdown() throws Exception { + final SettableFuture start = SettableFuture.create(); + Executors.newSingleThreadExecutor().execute(new Runnable() { + @Override + public void run() { + try { + start.set(xdsServerWrapper.start()); + } catch (Exception ex) { + start.setException(ex); + } + } + }); + String ldsWatched = xdsClient.ldsResource.get(5, TimeUnit.SECONDS); + assertThat(ldsWatched).isEqualTo("grpc/server?udpa.resource.listening_address=0.0.0.0:1"); + HttpConnectionManager hcm_virtual = HttpConnectionManager.forVirtualHosts( + 0L, Collections.singletonList(createVirtualHost("virtual-host-0")), + new ArrayList()); + EnvoyServerProtoData.FilterChain f0 = createFilterChain("filter-chain-0", hcm_virtual); + SslContextProviderSupplier sslSupplier = f0.getSslContextProviderSupplier(); + xdsClient.deliverLdsUpdate(Collections.singletonList(f0), null); + start.get(5000, TimeUnit.MILLISECONDS); + verify(mockServer).start(); + xdsServerWrapper.shutdown(); + assertThat(xdsServerWrapper.isShutdown()).isTrue(); + assertThat(xdsClient.ldsResource).isNull(); + assertThat(xdsClient.shutdown).isTrue(); + verify(mockServer).shutdown(); + assertThat(sslSupplier.isShutdown()).isTrue(); + when(mockServer.isTerminated()).thenReturn(true); + when(mockServer.awaitTermination(anyLong(), any(TimeUnit.class))).thenReturn(true); + assertThat(xdsServerWrapper.awaitTermination(5, TimeUnit.SECONDS)).isTrue(); + xdsServerWrapper.awaitTermination(); + assertThat(xdsServerWrapper.isTerminated()).isTrue(); + assertThat(start.get()).isSameInstanceAs(xdsServerWrapper); + } + + @Test + public void shutdown_afterResourceNotExist() throws Exception { + final SettableFuture start = SettableFuture.create(); + Executors.newSingleThreadExecutor().execute(new Runnable() { + @Override + public void run() { + try { + start.set(xdsServerWrapper.start()); + } catch (Exception ex) { + start.setException(ex); + } + } + }); + String ldsResource = xdsClient.ldsResource.get(5, TimeUnit.SECONDS); + xdsClient.ldsWatcher.onResourceDoesNotExist(ldsResource); + try { + start.get(5000, TimeUnit.MILLISECONDS); + fail("Start should throw exception"); + } catch (ExecutionException ex) { + assertThat(ex.getCause()).isInstanceOf(IOException.class); + } + verify(mockBuilder, times(1)).build(); + verify(mockServer, never()).start(); + verify(mockServer).shutdown(); + when(mockServer.isShutdown()).thenReturn(true); + when(mockServer.isTerminated()).thenReturn(true); + verify(listener, times(1)).onNotServing(any(Throwable.class)); + xdsServerWrapper.shutdown(); + assertThat(xdsServerWrapper.isShutdown()).isTrue(); + assertThat(xdsClient.ldsResource).isNull(); + assertThat(xdsClient.shutdown).isTrue(); + verify(mockBuilder, times(1)).build(); + verify(mockServer, times(1)).shutdown(); + xdsServerWrapper.awaitTermination(1, TimeUnit.SECONDS); + assertThat(xdsServerWrapper.isTerminated()).isTrue(); + } + + @Test + public void shutdown_pendingRetry() throws Exception { + final SettableFuture start = SettableFuture.create(); + Executors.newSingleThreadExecutor().execute(new Runnable() { + @Override + public void run() { + try { + start.set(xdsServerWrapper.start()); + } catch (Exception ex) { + start.setException(ex); + } + } + }); + xdsClient.ldsResource.get(5, TimeUnit.SECONDS); + when(mockServer.start()).thenThrow(new IOException("error!")); + FilterChain filterChain = createFilterChain("filter-chain-1", createRds("rds")); + SslContextProviderSupplier sslSupplier = filterChain.getSslContextProviderSupplier(); + xdsClient.deliverLdsUpdate(Collections.singletonList(filterChain), null); + + try { + start.get(5000, TimeUnit.MILLISECONDS); + fail("Start should throw exception"); + } catch (ExecutionException ex) { + assertThat(ex.getCause()).isInstanceOf(IOException.class); + } + assertThat(executor.getPendingTasks().size()).isEqualTo(1); + verify(mockServer).start(); + verify(mockServer, never()).shutdown(); + xdsServerWrapper.shutdown(); + verify(mockServer).shutdown(); + when(mockServer.isShutdown()).thenReturn(true); + when(mockServer.isTerminated()).thenReturn(true); + assertThat(sslSupplier.isShutdown()).isTrue(); + assertThat(executor.getPendingTasks().size()).isEqualTo(0); + verify(listener, never()).onNotServing(any(Throwable.class)); + verify(listener, never()).onServing(); + xdsServerWrapper.awaitTermination(); + assertThat(xdsServerWrapper.isTerminated()).isTrue(); + } + + @Test + public void discoverState_virtualhost() throws Exception { + final SettableFuture start = SettableFuture.create(); + Executors.newSingleThreadExecutor().execute(new Runnable() { + @Override + public void run() { + try { + start.set(xdsServerWrapper.start()); + } catch (Exception ex) { + start.setException(ex); + } + } + }); + String ldsWatched = xdsClient.ldsResource.get(5, TimeUnit.SECONDS); + assertThat(ldsWatched).isEqualTo("grpc/server?udpa.resource.listening_address=0.0.0.0:1"); + VirtualHost virtualHost = + VirtualHost.create( + "virtual-host", Collections.singletonList("auth"), new ArrayList(), + ImmutableMap.of()); + HttpConnectionManager httpConnectionManager = HttpConnectionManager.forVirtualHosts( + 0L, Collections.singletonList(virtualHost), new ArrayList()); + EnvoyServerProtoData.FilterChain filterChain = new EnvoyServerProtoData.FilterChain( + "filter-chain-foo", createMatch(), httpConnectionManager, createTls(), + tlsContextManager); + xdsClient.deliverLdsUpdate(Collections.singletonList(filterChain), null); + start.get(5000, TimeUnit.MILLISECONDS); + FilterChainSelector selector = selectorRef.get(); + assertThat(ldsWatched).isEqualTo("grpc/server?udpa.resource.listening_address=0.0.0.0:1"); + assertThat(selector.getFilterChains()).isEqualTo(Collections.singletonList(filterChain)); + verify(listener).onServing(); + verify(mockServer).start(); + } + + @Test + public void initialStartIoException() throws Exception { + final SettableFuture start = SettableFuture.create(); + Executors.newSingleThreadExecutor().execute(new Runnable() { + @Override + public void run() { + try { + start.set(xdsServerWrapper.start()); + } catch (Exception ex) { + start.setException(ex); + } + } + }); + when(mockServer.start()).thenThrow(new IOException("error!")); + xdsClient.ldsResource.get(5, TimeUnit.SECONDS); + FilterChain filterChain = createFilterChain("filter-chain-1", createRds("rds")); + xdsClient.deliverLdsUpdate(Collections.singletonList(filterChain), null); + try { + start.get(5000, TimeUnit.MILLISECONDS); + fail("Start should throw exception"); + } catch (ExecutionException ex) { + assertThat(ex.getCause()).isInstanceOf(IOException.class); + assertThat(ex.getCause().getMessage()).isEqualTo("error!"); + } + } + + @Test + public void error() throws Exception { + final SettableFuture start = SettableFuture.create(); + Executors.newSingleThreadExecutor().execute(new Runnable() { + @Override + public void run() { + try { + start.set(xdsServerWrapper.start()); + } catch (Exception ex) { + start.setException(ex); + } + } + }); + String ldsResource = xdsClient.ldsResource.get(5, TimeUnit.SECONDS); + xdsClient.ldsWatcher.onResourceDoesNotExist(ldsResource); + try { + start.get(5000, TimeUnit.MILLISECONDS); + fail("Start should throw exception"); + } catch (ExecutionException ex) { + assertThat(ex.getCause()).isInstanceOf(IOException.class); + } + verify(listener, times(1)).onNotServing(any(StatusException.class)); + verify(mockBuilder, times(1)).build(); + verify(mockServer, times(1)).shutdown(); + when(mockServer.isShutdown()).thenReturn(true); + xdsClient.ldsWatcher.onError(Status.INTERNAL); + assertThat(selectorRef.get()).isSameInstanceAs(FilterChainSelector.NO_FILTER_CHAIN); + verify(mockBuilder, times(1)).build(); + verify(listener, times(2)).onNotServing(any(StatusException.class)); + verify(mockServer, times(1)).shutdown(); + + when(mockServer.start()).thenThrow(new IOException("error!")) + .thenReturn(mockServer); + when(mockServer.isShutdown()).thenReturn(true).thenReturn(false); + FilterChain filterChain = createFilterChain("filter-chain-1", createRds("rds")); + SslContextProviderSupplier sslSupplier = filterChain.getSslContextProviderSupplier(); + xdsClient.deliverLdsUpdate(Collections.singletonList(filterChain), null); + assertThat(executor.forwardNanos(RETRY_DELAY_NANOS)).isEqualTo(1); + verify(mockBuilder, times(2)).build(); + verify(mockServer, times(2)).start(); + verify(listener, times(1)).onServing(); + verify(listener, times(2)).onNotServing(any(StatusException.class)); + assertThat(selectorRef.get().getFilterChains()).isEqualTo( + Collections.singletonList(filterChain)); + assertThat(selectorRef.get().getDefaultSslContextProviderSupplier()).isNull(); + assertThat(sslSupplier.isShutdown()).isFalse(); + + // xds update after start + filterChain = createFilterChain("filter-chain-2", createRds("rds")); + FilterChain f1 = createFilterChain("filter-chain-2-0", createRds("rds")); + SslContextProviderSupplier s1 = filterChain.getSslContextProviderSupplier(); + xdsClient.deliverLdsUpdate(Collections.singletonList(filterChain), f1); + + verify(mockBuilder, times(2)).build(); + verify(mockServer, times(2)).start(); + verify(listener, times(1)).onServing(); + verify(listener, times(2)).onNotServing(any(StatusException.class)); + assertThat(selectorRef.get().getFilterChains()) + .isEqualTo(Collections.singletonList(filterChain)); + assertThat(selectorRef.get().getDefaultSslContextProviderSupplier()) + .isEqualTo(f1.getSslContextProviderSupplier()); + assertThat(sslSupplier.isShutdown()).isTrue(); + assertThat(s1.isShutdown()).isFalse(); + + // not serving after serving + xdsClient.ldsWatcher.onError(Status.INTERNAL); + verify(mockServer, times(2)).shutdown(); + when(mockServer.isShutdown()).thenReturn(true); + assertThat(selectorRef.get()).isSameInstanceAs(FilterChainSelector.NO_FILTER_CHAIN); + verify(listener, times(3)).onNotServing(any(StatusException.class)); + assertThat(s1.isShutdown()).isTrue(); + + // cancel retry + when(mockServer.start()).thenThrow(new IOException("error1!")) + .thenThrow(new IOException("error2!")) + .thenReturn(mockServer); + xdsClient.deliverLdsUpdate(Collections.singletonList(filterChain), null); + verify(mockBuilder, times(3)).build(); + when(mockServer.isShutdown()).thenReturn(false); + verify(mockServer, times(3)).start(); + verify(listener, times(1)).onServing(); + verify(listener, times(3)).onNotServing(any(StatusException.class)); + assertThat(selectorRef.get().getFilterChains()).isEqualTo(Collections.singletonList( + filterChain) + ); + assertThat(executor.numPendingTasks()).isEqualTo(1); + xdsClient.ldsWatcher.onResourceDoesNotExist(ldsResource); + verify(mockServer, times(3)).shutdown(); + when(mockServer.isShutdown()).thenReturn(true); + verify(listener, times(4)).onNotServing(any(StatusException.class)); + assertThat(executor.numPendingTasks()).isEqualTo(0); + + // serving after not serving + xdsClient.deliverLdsUpdate(Collections.singletonList(filterChain), null); + verify(mockBuilder, times(4)).build(); + verify(mockServer, times(4)).start(); + verify(listener, times(1)).onServing(); + verify(listener, times(4)).onNotServing(any(StatusException.class)); + assertThat(executor.forwardNanos(RETRY_DELAY_NANOS)).isEqualTo(1); + verify(listener, times(2)).onServing(); + assertThat(selectorRef.get().getFilterChains()).isEqualTo(Collections.singletonList( + filterChain) + ); + } + + + private FilterChain createFilterChain(String name, HttpConnectionManager hcm) { + return new EnvoyServerProtoData.FilterChain(name, createMatch(), + hcm, createTls(), tlsContextManager); + } + + private VirtualHost createVirtualHost(String name) { + return VirtualHost.create( + name, Collections.singletonList("auth"), new ArrayList(), + ImmutableMap.of()); + } + + private HttpConnectionManager createRds(String name) { + return HttpConnectionManager.forRdsName(0L, name, + new ArrayList()); + } + + private EnvoyServerProtoData.FilterChainMatch createMatch() { + return new EnvoyServerProtoData.FilterChainMatch( + 0, + Arrays.asList(), + Arrays.asList(), + Arrays.asList(), + EnvoyServerProtoData.ConnectionSourceType.ANY, + Arrays.asList(), + Arrays.asList(), + null); + } + + private EnvoyServerProtoData.DownstreamTlsContext createTls() { + return CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); + } +} diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/SdsProtocolNegotiatorsTest.java b/xds/src/test/java/io/grpc/xds/internal/sds/SdsProtocolNegotiatorsTest.java index a8a7a9c9e30..4c89aa4b79a 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/SdsProtocolNegotiatorsTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/sds/SdsProtocolNegotiatorsTest.java @@ -22,6 +22,7 @@ import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_PEM_FILE; import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_KEY_FILE; import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE; +import static io.grpc.xds.internal.sds.SdsProtocolNegotiators.ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -42,16 +43,13 @@ import io.grpc.netty.InternalProtocolNegotiationEvent; import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator; import io.grpc.netty.InternalProtocolNegotiators; +import io.grpc.netty.ProtocolNegotiationEvent; import io.grpc.xds.Bootstrapper; import io.grpc.xds.CommonBootstrapperTestUtils; -import io.grpc.xds.EnvoyServerProtoData; import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; import io.grpc.xds.InternalXdsAttributes; import io.grpc.xds.TlsContextManager; -import io.grpc.xds.XdsClientWrapperForServerSds; -import io.grpc.xds.XdsClientWrapperForServerSdsTestMisc; -import io.grpc.xds.XdsServerTestHelper; import io.grpc.xds.internal.sds.SdsProtocolNegotiators.ClientSdsHandler; import io.grpc.xds.internal.sds.SdsProtocolNegotiators.ClientSdsProtocolNegotiator; import io.netty.channel.ChannelHandler; @@ -74,7 +72,6 @@ import java.net.InetSocketAddress; import java.net.SocketAddress; import java.security.cert.CertStoreException; -import java.util.Arrays; import java.util.Iterator; import java.util.Map; import java.util.concurrent.ExecutionException; @@ -216,18 +213,19 @@ public SocketAddress remoteAddress() { "google_cloud_private_spiffe-server", true, true); TlsContextManagerImpl tlsContextManager = new TlsContextManagerImpl(bootstrapInfoForServer); - XdsClientWrapperForServerSds xdsClientWrapperForServerSds = - XdsClientWrapperForServerSdsTestMisc.createXdsClientWrapperForServerSds( - 80, downstreamTlsContext, tlsContextManager); SdsProtocolNegotiators.HandlerPickerHandler handlerPickerHandler = - new SdsProtocolNegotiators.HandlerPickerHandler(grpcHandler, xdsClientWrapperForServerSds, - InternalProtocolNegotiators.serverPlaintext()); + new SdsProtocolNegotiators.HandlerPickerHandler(grpcHandler, + InternalProtocolNegotiators.serverPlaintext()); pipeline.addLast(handlerPickerHandler); channelHandlerCtx = pipeline.context(handlerPickerHandler); assertThat(channelHandlerCtx).isNotNull(); // should find HandlerPickerHandler // kick off protocol negotiation: should replace HandlerPickerHandler with ServerSdsHandler - pipeline.fireUserEventTriggered(InternalProtocolNegotiationEvent.getDefault()); + ProtocolNegotiationEvent event = InternalProtocolNegotiationEvent.getDefault(); + Attributes attr = InternalProtocolNegotiationEvent.getAttributes(event) + .toBuilder().set(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER, + new SslContextProviderSupplier(downstreamTlsContext, tlsContextManager)).build(); + pipeline.fireUserEventTriggered(InternalProtocolNegotiationEvent.withAttributes(event, attr)); channelHandlerCtx = pipeline.context(handlerPickerHandler); assertThat(channelHandlerCtx).isNull(); channelHandlerCtx = pipeline.context(SdsProtocolNegotiators.ServerSdsHandler.class); @@ -278,23 +276,19 @@ public SocketAddress localAddress() { } }; pipeline = channel.pipeline(); - DownstreamTlsContext downstreamTlsContext = - DownstreamTlsContext.fromEnvoyProtoDownstreamTlsContext( - io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.DownstreamTlsContext - .getDefaultInstance()); - XdsClientWrapperForServerSds xdsClientWrapperForServerSds = - XdsClientWrapperForServerSdsTestMisc.createXdsClientWrapperForServerSds( - 80, downstreamTlsContext, mock(TlsContextManager.class)); SdsProtocolNegotiators.HandlerPickerHandler handlerPickerHandler = new SdsProtocolNegotiators.HandlerPickerHandler( - grpcHandler, xdsClientWrapperForServerSds, mockProtocolNegotiator); + grpcHandler, mockProtocolNegotiator); pipeline.addLast(handlerPickerHandler); channelHandlerCtx = pipeline.context(handlerPickerHandler); assertThat(channelHandlerCtx).isNotNull(); // should find HandlerPickerHandler // kick off protocol negotiation: should replace HandlerPickerHandler with ServerSdsHandler - pipeline.fireUserEventTriggered(InternalProtocolNegotiationEvent.getDefault()); + ProtocolNegotiationEvent event = InternalProtocolNegotiationEvent.getDefault(); + Attributes attr = InternalProtocolNegotiationEvent.getAttributes(event) + .toBuilder().set(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER, null).build(); + pipeline.fireUserEventTriggered(InternalProtocolNegotiationEvent.withAttributes(event, attr)); channelHandlerCtx = pipeline.context(handlerPickerHandler); assertThat(channelHandlerCtx).isNull(); channel.runPendingTasks(); // need this for tasks to execute on eventLoop @@ -311,8 +305,7 @@ public void serverSdsHandler_nullTlsContext_expectFallbackProtocolNegotiator() { when(mockProtocolNegotiator.newHandler(grpcHandler)).thenReturn(mockChannelHandler); SdsProtocolNegotiators.HandlerPickerHandler handlerPickerHandler = new SdsProtocolNegotiators.HandlerPickerHandler( - grpcHandler, /* xdsClientWrapperForServerSds= */ null, - mockProtocolNegotiator); + grpcHandler, mockProtocolNegotiator); pipeline.addLast(handlerPickerHandler); channelHandlerCtx = pipeline.context(handlerPickerHandler); assertThat(channelHandlerCtx).isNotNull(); // should find HandlerPickerHandler @@ -332,8 +325,7 @@ public void serverSdsHandler_nullTlsContext_expectFallbackProtocolNegotiator() { public void nullTlsContext_nullFallbackProtocolNegotiator_expectException() { SdsProtocolNegotiators.HandlerPickerHandler handlerPickerHandler = new SdsProtocolNegotiators.HandlerPickerHandler( - grpcHandler, /* xdsClientWrapperForServerSds= */ null, - null); + grpcHandler, null); pipeline.addLast(handlerPickerHandler); channelHandlerCtx = pipeline.context(handlerPickerHandler); assertThat(channelHandlerCtx).isNotNull(); // should find HandlerPickerHandler @@ -351,54 +343,6 @@ public void nullTlsContext_nullFallbackProtocolNegotiator_expectException() { } } - @Test - public void noMatchingFilterChain_expectException() { - // we need InetSocketAddress instead of EmbeddedSocketAddress as localAddress for this test - channel = - new EmbeddedChannel() { - @Override - public SocketAddress localAddress() { - return new InetSocketAddress("172.168.1.1", 80); - } - - @Override - public SocketAddress remoteAddress() { - return new InetSocketAddress("172.168.2.2", 90); - } - }; - pipeline = channel.pipeline(); - Bootstrapper.BootstrapInfo bootstrapInfoForServer = CommonBootstrapperTestUtils - .buildBootstrapInfo("google_cloud_private_spiffe-server", SERVER_1_KEY_FILE, - SERVER_1_PEM_FILE, CA_PEM_FILE, null, null, null, null); - - TlsContextManagerImpl tlsContextManager = new TlsContextManagerImpl(bootstrapInfoForServer); - XdsClientWrapperForServerSds xdsClientWrapperForServerSds = - XdsServerTestHelper.createXdsClientWrapperForServerSds(80, tlsContextManager); - xdsClientWrapperForServerSds.start(); - EnvoyServerProtoData.Listener listener = new EnvoyServerProtoData.Listener( - "listener1", "0.0.0.0", Arrays.asList(), null); - XdsServerTestHelper.generateListenerUpdate( - xdsClientWrapperForServerSds.getListenerWatcher(), listener); - - SdsProtocolNegotiators.HandlerPickerHandler handlerPickerHandler = - new SdsProtocolNegotiators.HandlerPickerHandler(grpcHandler, xdsClientWrapperForServerSds, - InternalProtocolNegotiators.serverPlaintext()); - pipeline.addLast(handlerPickerHandler); - channelHandlerCtx = pipeline.context(handlerPickerHandler); - assertThat(channelHandlerCtx).isNotNull(); // should find HandlerPickerHandler - - // kick off protocol negotiation - pipeline.fireUserEventTriggered(InternalProtocolNegotiationEvent.getDefault()); - channelHandlerCtx = pipeline.context(handlerPickerHandler); - assertThat(channelHandlerCtx).isNotNull(); // HandlerPickerHandler still there - try { - channel.checkException(); - fail("exception expected!"); - } catch (Exception e) { - assertThat(e).hasMessageThat().contains("no matching filter chain"); - } - } - @Test public void clientSdsProtocolNegotiatorNewHandler_fireProtocolNegotiationEvent() throws InterruptedException, TimeoutException, ExecutionException { From 51d1484c3cd057ba1b57790646e8449bda7a6c80 Mon Sep 17 00:00:00 2001 From: Eric Anderson Date: Mon, 9 Aug 2021 16:06:44 -0700 Subject: [PATCH 21/82] api: Document that NameResolvers shouldn't block Fixes #8190 --- api/src/main/java/io/grpc/NameResolver.java | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/api/src/main/java/io/grpc/NameResolver.java b/api/src/main/java/io/grpc/NameResolver.java index cd3a137dfca..f4c05aa6a64 100644 --- a/api/src/main/java/io/grpc/NameResolver.java +++ b/api/src/main/java/io/grpc/NameResolver.java @@ -50,7 +50,9 @@ *

Implementations don't need to be thread-safe. All methods are guaranteed to * be called sequentially. Additionally, all methods that have side-effects, i.e., * {@link #start(Listener2)}, {@link #shutdown} and {@link #refresh} are called from the same - * {@link SynchronizationContext} as returned by {@link Args#getSynchronizationContext}. + * {@link SynchronizationContext} as returned by {@link Args#getSynchronizationContext}. Do + * not block within the synchronization context; blocking I/O and time-consuming tasks + * should be offloaded to a separate thread, generally {@link Args#getOffloadExecutor}. * * @since 1.0.0 */ From 96a5c25056662985b83b0b078e4811b8774b3321 Mon Sep 17 00:00:00 2001 From: skyguard1 Date: Tue, 10 Aug 2021 11:22:44 +0800 Subject: [PATCH 22/82] rls: fix routeLookupClient may be null in RlsLoadBalancer.requestConnection() (#8379) --- rls/src/main/java/io/grpc/rls/RlsLoadBalancer.java | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/rls/src/main/java/io/grpc/rls/RlsLoadBalancer.java b/rls/src/main/java/io/grpc/rls/RlsLoadBalancer.java index 12903044a21..289098e2554 100644 --- a/rls/src/main/java/io/grpc/rls/RlsLoadBalancer.java +++ b/rls/src/main/java/io/grpc/rls/RlsLoadBalancer.java @@ -82,7 +82,9 @@ public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { @Override public void requestConnection() { - routeLookupClient.requestConnection(); + if (routeLookupClient != null) { + routeLookupClient.requestConnection(); + } } @Override From 1eb1d157a7b25e9fbd07dd794ff1ec2ffa30def2 Mon Sep 17 00:00:00 2001 From: yifeizhuang Date: Wed, 11 Aug 2021 10:12:20 -0700 Subject: [PATCH 23/82] xds: allow injecting bootstrapOverride in xdsNameResolverProvider (#8358) --- .../grpc/xds/SharedXdsClientPoolProvider.java | 2 +- .../java/io/grpc/xds/XdsNameResolver.java | 11 ++-- .../io/grpc/xds/XdsNameResolverProvider.java | 28 +++++++++- .../grpc/xds/XdsNameResolverProviderTest.java | 55 +++++++++++++++++++ .../java/io/grpc/xds/XdsNameResolverTest.java | 11 ++-- 5 files changed, 92 insertions(+), 15 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/SharedXdsClientPoolProvider.java b/xds/src/main/java/io/grpc/xds/SharedXdsClientPoolProvider.java index cc87b9c6b6f..95eef3e3d80 100644 --- a/xds/src/main/java/io/grpc/xds/SharedXdsClientPoolProvider.java +++ b/xds/src/main/java/io/grpc/xds/SharedXdsClientPoolProvider.java @@ -52,7 +52,7 @@ final class SharedXdsClientPoolProvider implements XdsClientPoolFactory { private final AtomicReference> bootstrapOverride = new AtomicReference<>(); private volatile ObjectPool xdsClientPool; - private SharedXdsClientPoolProvider() { + SharedXdsClientPoolProvider() { this(new BootstrapperImpl()); } diff --git a/xds/src/main/java/io/grpc/xds/XdsNameResolver.java b/xds/src/main/java/io/grpc/xds/XdsNameResolver.java index 3ae6346c158..787336e4b54 100644 --- a/xds/src/main/java/io/grpc/xds/XdsNameResolver.java +++ b/xds/src/main/java/io/grpc/xds/XdsNameResolver.java @@ -124,22 +124,25 @@ final class XdsNameResolver extends NameResolver { private ResolveState resolveState; XdsNameResolver(String name, ServiceConfigParser serviceConfigParser, - SynchronizationContext syncContext, ScheduledExecutorService scheduler) { + SynchronizationContext syncContext, ScheduledExecutorService scheduler, + @Nullable Map bootstrapOverride) { this(name, serviceConfigParser, syncContext, scheduler, SharedXdsClientPoolProvider.getDefaultProvider(), ThreadSafeRandomImpl.instance, - FilterRegistry.getDefaultRegistry()); + FilterRegistry.getDefaultRegistry(), bootstrapOverride); } @VisibleForTesting XdsNameResolver(String name, ServiceConfigParser serviceConfigParser, SynchronizationContext syncContext, ScheduledExecutorService scheduler, XdsClientPoolFactory xdsClientPoolFactory, ThreadSafeRandom random, - FilterRegistry filterRegistry) { + FilterRegistry filterRegistry, @Nullable Map bootstrapOverride) { authority = GrpcUtil.checkAuthority(checkNotNull(name, "name")); this.serviceConfigParser = checkNotNull(serviceConfigParser, "serviceConfigParser"); this.syncContext = checkNotNull(syncContext, "syncContext"); this.scheduler = checkNotNull(scheduler, "scheduler"); - this.xdsClientPoolFactory = checkNotNull(xdsClientPoolFactory, "xdsClientPoolFactory"); + this.xdsClientPoolFactory = bootstrapOverride == null ? checkNotNull(xdsClientPoolFactory, + "xdsClientPoolFactory") : new SharedXdsClientPoolProvider(); + this.xdsClientPoolFactory.setBootstrapOverride(bootstrapOverride); this.random = checkNotNull(random, "random"); this.filterRegistry = checkNotNull(filterRegistry, "filterRegistry"); logId = InternalLogId.allocate("xds-resolver", name); diff --git a/xds/src/main/java/io/grpc/xds/XdsNameResolverProvider.java b/xds/src/main/java/io/grpc/xds/XdsNameResolverProvider.java index 40aa4f919e9..03d88a9752e 100644 --- a/xds/src/main/java/io/grpc/xds/XdsNameResolverProvider.java +++ b/xds/src/main/java/io/grpc/xds/XdsNameResolverProvider.java @@ -42,10 +42,31 @@ public final class XdsNameResolverProvider extends NameResolverProvider { private static final String SCHEME = "xds"; + private final String scheme; + private final Map bootstrapOverride; + + public XdsNameResolverProvider() { + this(SCHEME, null); + } + + private XdsNameResolverProvider(String scheme, + @Nullable Map bootstrapOverride) { + this.scheme = checkNotNull(scheme, "scheme"); + this.bootstrapOverride = bootstrapOverride; + } + + /** + * A convenient method to allow creating a {@link XdsNameResolverProvider} with custom scheme + * and bootstrap. + */ + public static XdsNameResolverProvider createForTest(String scheme, + @Nullable Map bootstrapOverride) { + return new XdsNameResolverProvider(scheme, bootstrapOverride); + } @Override public XdsNameResolver newNameResolver(URI targetUri, Args args) { - if (SCHEME.equals(targetUri.getScheme())) { + if (scheme.equals(targetUri.getScheme())) { String targetPath = checkNotNull(targetUri.getPath(), "targetPath"); Preconditions.checkArgument( targetPath.startsWith("/"), @@ -54,14 +75,15 @@ public XdsNameResolver newNameResolver(URI targetUri, Args args) { targetUri); String name = targetPath.substring(1); return new XdsNameResolver(name, args.getServiceConfigParser(), - args.getSynchronizationContext(), args.getScheduledExecutorService()); + args.getSynchronizationContext(), args.getScheduledExecutorService(), + bootstrapOverride); } return null; } @Override public String getDefaultScheme() { - return SCHEME; + return scheme; } @Override diff --git a/xds/src/test/java/io/grpc/xds/XdsNameResolverProviderTest.java b/xds/src/test/java/io/grpc/xds/XdsNameResolverProviderTest.java index ba1e561410f..32850b441d7 100644 --- a/xds/src/test/java/io/grpc/xds/XdsNameResolverProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsNameResolverProviderTest.java @@ -20,15 +20,20 @@ import static org.junit.Assert.fail; import static org.mockito.Mockito.mock; +import com.google.common.collect.ImmutableMap; import io.grpc.ChannelLogger; import io.grpc.InternalServiceProviders; import io.grpc.NameResolver; import io.grpc.NameResolver.ServiceConfigParser; import io.grpc.NameResolverProvider; +import io.grpc.NameResolverRegistry; import io.grpc.SynchronizationContext; import io.grpc.internal.FakeClock; import io.grpc.internal.GrpcUtil; import java.net.URI; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -114,4 +119,54 @@ public void invalidName_hostnameContainsUnderscore() { // Expected } } + + @Test + public void newProvider_multipleScheme() { + NameResolverRegistry registry = NameResolverRegistry.getDefaultRegistry(); + XdsNameResolverProvider provider0 = XdsNameResolverProvider.createForTest("no-scheme", null); + registry.register(provider0); + XdsNameResolverProvider provider1 = XdsNameResolverProvider.createForTest("new-xds-scheme", + new HashMap()); + registry.register(provider1); + assertThat(registry.asFactory() + .newNameResolver(URI.create("xds:///localhost"), args)).isNotNull(); + assertThat(registry.asFactory() + .newNameResolver(URI.create("new-xds-scheme:///localhost"), args)).isNotNull(); + assertThat(registry.asFactory() + .newNameResolver(URI.create("no-scheme:///localhost"), args)).isNotNull(); + registry.deregister(provider1); + assertThat(registry.asFactory() + .newNameResolver(URI.create("new-xds-scheme:///localhost"), args)).isNull(); + registry.deregister(provider0); + assertThat(registry.asFactory() + .newNameResolver(URI.create("xds:///localhost"), args)).isNotNull(); + } + + @Test + public void newProvider_overrideBootstrap() { + Map b = ImmutableMap.of( + "node", ImmutableMap.of( + "id", "ENVOY_NODE_ID", + "cluster", "ENVOY_CLUSTER"), + "xds_servers", Collections.singletonList( + ImmutableMap.of( + "server_uri", "trafficdirector.googleapis.com:443", + "channel_creds", Collections.singletonList( + ImmutableMap.of("type", "insecure") + ) + ) + ) + ); + NameResolverRegistry registry = new NameResolverRegistry(); + XdsNameResolverProvider provider = XdsNameResolverProvider.createForTest("no-scheme", b); + registry.register(provider); + NameResolver resolver = registry.asFactory() + .newNameResolver(URI.create("no-scheme:///localhost"), args); + resolver.start(mock(NameResolver.Listener2.class)); + assertThat(resolver).isInstanceOf(XdsNameResolver.class); + assertThat(((XdsNameResolver)resolver).getXdsClient().getBootstrapInfo().getNode().getId()) + .isEqualTo("ENVOY_NODE_ID"); + resolver.shutdown(); + registry.deregister(provider); + } } diff --git a/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java b/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java index cb2b4481fad..22d7302f207 100644 --- a/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java @@ -153,7 +153,7 @@ public void setUp() { new FaultFilter(mockRandom, new AtomicLong()), RouterFilter.INSTANCE); resolver = new XdsNameResolver(AUTHORITY, serviceConfigParser, syncContext, scheduler, - xdsClientPoolFactory, mockRandom, filterRegistry); + xdsClientPoolFactory, mockRandom, filterRegistry, null); } @After @@ -172,7 +172,6 @@ public void resolving_failToCreateXdsClientPool() { XdsClientPoolFactory xdsClientPoolFactory = new XdsClientPoolFactory() { @Override public void setBootstrapOverride(Map bootstrap) { - throw new UnsupportedOperationException("Should not be called"); } @Override @@ -187,7 +186,7 @@ public ObjectPool getOrCreate() throws XdsInitializationException { } }; resolver = new XdsNameResolver(AUTHORITY, serviceConfigParser, syncContext, scheduler, - xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry()); + xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), null); resolver.start(mockListener); verify(mockListener).onError(errorCaptor.capture()); Status error = errorCaptor.getValue(); @@ -437,7 +436,7 @@ public void retryPolicyInPerMethodConfigGeneratedByResolverIsValid() { ServiceConfigParser realParser = new ScParser( true, 5, 5, new AutoConfiguredLoadBalancerFactory("pick-first")); resolver = new XdsNameResolver(AUTHORITY, realParser, syncContext, scheduler, - xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry()); + xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), null); resolver.start(mockListener); FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); RetryPolicy retryPolicy = RetryPolicy.create( @@ -640,7 +639,7 @@ public void resolved_rpcHashingByChannelId() { resolver.shutdown(); reset(mockListener); resolver = new XdsNameResolver(AUTHORITY, serviceConfigParser, syncContext, scheduler, - xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry()); + xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), null); resolver.start(mockListener); xdsClient = (FakeXdsClient) resolver.getXdsClient(); xdsClient.deliverLdsUpdate( @@ -1701,10 +1700,8 @@ public void routeMatching_withHeaders() { } private final class FakeXdsClientPoolFactory implements XdsClientPoolFactory { - @Override public void setBootstrapOverride(Map bootstrap) { - throw new UnsupportedOperationException("Should not be called"); } @Override From fd2a58a55e54c04250f77533c44d1e05c0e5985b Mon Sep 17 00:00:00 2001 From: ZHANG Dapeng Date: Wed, 11 Aug 2021 10:24:37 -0700 Subject: [PATCH 24/82] all: implement retry stats (#8362) --- .../main/java/io/grpc/ClientStreamTracer.java | 30 +- .../java/io/grpc/ManagedChannelBuilder.java | 3 - .../io/grpc/census/CensusStatsModule.java | 230 +++++++---- .../io/grpc/census/CensusTracingModule.java | 47 ++- .../io/grpc/census/CensusModulesTest.java | 269 +++++++++++- .../java/io/grpc/internal/ClientCallImpl.java | 3 +- .../main/java/io/grpc/internal/GrpcUtil.java | 3 +- .../io/grpc/internal/ManagedChannelImpl.java | 11 +- .../internal/ManagedChannelImplBuilder.java | 12 - .../java/io/grpc/internal/OobChannel.java | 2 +- .../io/grpc/internal/RetriableStream.java | 35 +- .../io/grpc/internal/SubchannelChannel.java | 2 +- .../io/grpc/internal/RetriableStreamTest.java | 3 +- interop-testing/build.gradle | 1 + .../integration/AbstractInteropTest.java | 23 ++ .../grpc/testing/integration/RetryTest.java | 382 +++++++++++++++--- 16 files changed, 864 insertions(+), 192 deletions(-) diff --git a/api/src/main/java/io/grpc/ClientStreamTracer.java b/api/src/main/java/io/grpc/ClientStreamTracer.java index 6a5d3cc3397..bb836ac82e1 100644 --- a/api/src/main/java/io/grpc/ClientStreamTracer.java +++ b/api/src/main/java/io/grpc/ClientStreamTracer.java @@ -97,11 +97,15 @@ public abstract static class InternalLimitedInfoFactory extends Factory {} public static final class StreamInfo { private final Attributes transportAttrs; private final CallOptions callOptions; + private final int previousAttempts; private final boolean isTransparentRetry; - StreamInfo(Attributes transportAttrs, CallOptions callOptions, boolean isTransparentRetry) { + StreamInfo( + Attributes transportAttrs, CallOptions callOptions, int previousAttempts, + boolean isTransparentRetry) { this.transportAttrs = checkNotNull(transportAttrs, "transportAttrs"); this.callOptions = checkNotNull(callOptions, "callOptions"); + this.previousAttempts = previousAttempts; this.isTransparentRetry = isTransparentRetry; } @@ -124,6 +128,15 @@ public CallOptions getCallOptions() { return callOptions; } + /** + * Returns the number of preceding attempts for the RPC. + * + * @since 1.40.0 + */ + public int getPreviousAttempts() { + return previousAttempts; + } + /** * Whether the stream is a transparent retry. * @@ -142,6 +155,7 @@ public Builder toBuilder() { return new Builder() .setCallOptions(callOptions) .setTransportAttrs(transportAttrs) + .setPreviousAttempts(previousAttempts) .setIsTransparentRetry(isTransparentRetry); } @@ -159,6 +173,7 @@ public String toString() { return MoreObjects.toStringHelper(this) .add("transportAttrs", transportAttrs) .add("callOptions", callOptions) + .add("previousAttempts", previousAttempts) .add("isTransparentRetry", isTransparentRetry) .toString(); } @@ -171,6 +186,7 @@ public String toString() { public static final class Builder { private Attributes transportAttrs = Attributes.EMPTY; private CallOptions callOptions = CallOptions.DEFAULT; + private int previousAttempts; private boolean isTransparentRetry; Builder() { @@ -197,6 +213,16 @@ public Builder setCallOptions(CallOptions callOptions) { return this; } + /** + * Set the number of preceding attempts of the RPC. + * + * @since 1.40.0 + */ + public Builder setPreviousAttempts(int previousAttempts) { + this.previousAttempts = previousAttempts; + return this; + } + /** * Sets whether the stream is a transparent retry. * @@ -211,7 +237,7 @@ public Builder setIsTransparentRetry(boolean isTransparentRetry) { * Builds a new StreamInfo. */ public StreamInfo build() { - return new StreamInfo(transportAttrs, callOptions, isTransparentRetry); + return new StreamInfo(transportAttrs, callOptions, previousAttempts, isTransparentRetry); } } } diff --git a/api/src/main/java/io/grpc/ManagedChannelBuilder.java b/api/src/main/java/io/grpc/ManagedChannelBuilder.java index e4a4611541d..73e66ed6dc4 100644 --- a/api/src/main/java/io/grpc/ManagedChannelBuilder.java +++ b/api/src/main/java/io/grpc/ManagedChannelBuilder.java @@ -479,9 +479,6 @@ public T disableRetry() { * transparent retries, which are safe for non-idempotent RPCs. Service config is ideally provided * by the name resolver, but may also be specified via {@link #defaultServiceConfig}. * - *

For the current release, this method may have a side effect that disables Census stats and - * tracing. - * * @return this * @since 1.11.0 */ diff --git a/census/src/main/java/io/grpc/census/CensusStatsModule.java b/census/src/main/java/io/grpc/census/CensusStatsModule.java index ac5f4e705e3..6faeb575ccc 100644 --- a/census/src/main/java/io/grpc/census/CensusStatsModule.java +++ b/census/src/main/java/io/grpc/census/CensusStatsModule.java @@ -17,7 +17,6 @@ package io.grpc.census; import static com.google.common.base.Preconditions.checkNotNull; -import static com.google.common.base.Preconditions.checkState; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Stopwatch; @@ -28,16 +27,20 @@ import io.grpc.ClientCall; import io.grpc.ClientInterceptor; import io.grpc.ClientStreamTracer; +import io.grpc.ClientStreamTracer.StreamInfo; import io.grpc.Context; +import io.grpc.Deadline; import io.grpc.ForwardingClientCall.SimpleForwardingClientCall; import io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener; import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.ServerStreamTracer; import io.grpc.Status; +import io.grpc.Status.Code; import io.grpc.StreamTracer; import io.grpc.census.internal.DeprecatedCensusConstants; import io.opencensus.contrib.grpc.metrics.RpcMeasureConstants; +import io.opencensus.stats.Measure; import io.opencensus.stats.Measure.MeasureDouble; import io.opencensus.stats.Measure.MeasureLong; import io.opencensus.stats.MeasureMap; @@ -51,9 +54,11 @@ import io.opencensus.tags.propagation.TagContextSerializationException; import io.opencensus.tags.unsafe.ContextUtils; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicLongFieldUpdater; -import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; @@ -61,9 +66,10 @@ /** * Provides factories for {@link StreamTracer} that records stats to Census. * - *

On the client-side, a factory is created for each call, because ClientCall starts earlier than - * the ClientStream, and in some cases may even not create a ClientStream at all. Therefore, it's - * the factory that reports the summary to Census. + *

On the client-side, a factory is created for each call, and the factory creates a stream + * tracer for each attempt. If there is no stream created when the call is ended, we still create a + * tracer. It's the tracer that reports per-attempt stats, and the factory that reports the stats + * of the overall RPC, such as RETRIES_PER_CALL, to Census. * *

On the server-side, there is only one ServerStream per each ServerCall, and ServerStream * starts earlier than the ServerCall. Therefore, only one tracer is created per stream/call and @@ -168,7 +174,6 @@ private void recordRealTimeMetric(TagContext ctx, MeasureLong measure, long valu } private static final class ClientTracer extends ClientStreamTracer { - @Nullable private static final AtomicLongFieldUpdater outboundMessageCountUpdater; @Nullable private static final AtomicLongFieldUpdater inboundMessageCountUpdater; @Nullable private static final AtomicLongFieldUpdater outboundWireSizeUpdater; @@ -222,21 +227,31 @@ private static final class ClientTracer extends ClientStreamTracer { inboundUncompressedSizeUpdater = tmpInboundUncompressedSizeUpdater; } - private final CensusStatsModule module; + final Stopwatch stopwatch; + final CallAttemptsTracerFactory attemptsState; + final AtomicBoolean inboundReceivedOrClosed = new AtomicBoolean(); + final CensusStatsModule module; final TagContext parentCtx; - private final TagContext startCtx; - + final TagContext startCtx; + final StreamInfo info; volatile long outboundMessageCount; volatile long inboundMessageCount; volatile long outboundWireSize; volatile long inboundWireSize; volatile long outboundUncompressedSize; volatile long inboundUncompressedSize; - - ClientTracer(CensusStatsModule module, TagContext parentCtx, TagContext startCtx) { - this.module = checkNotNull(module, "module"); + long roundtripNanos; + Code statusCode; + + ClientTracer( + CallAttemptsTracerFactory attemptsState, CensusStatsModule module, TagContext parentCtx, + TagContext startCtx, StreamInfo info) { + this.attemptsState = attemptsState; + this.module = module; this.parentCtx = parentCtx; - this.startCtx = checkNotNull(startCtx, "startCtx"); + this.startCtx = startCtx; + this.info = info; + this.stopwatch = module.stopwatchSupplier.get().start(); } @Override @@ -296,6 +311,11 @@ public void inboundUncompressedSize(long bytes) { @Override @SuppressWarnings("NonAtomicVolatileUpdate") public void inboundMessage(int seqNo) { + if (inboundReceivedOrClosed.compareAndSet(false, true)) { + // Because inboundUncompressedSize() might be called after streamClosed(), + // we will report stats in callEnded(). Note that this attempt is already committed. + attemptsState.inboundMetricTracer = this; + } if (inboundMessageCountUpdater != null) { inboundMessageCountUpdater.getAndIncrement(this); } else { @@ -316,14 +336,74 @@ public void outboundMessage(int seqNo) { module.recordRealTimeMetric( startCtx, RpcMeasureConstants.GRPC_CLIENT_SENT_MESSAGES_PER_METHOD, 1); } + + @Override + public void streamClosed(Status status) { + attemptsState.attemptEnded(); + stopwatch.stop(); + roundtripNanos = stopwatch.elapsed(TimeUnit.NANOSECONDS); + Deadline deadline = info.getCallOptions().getDeadline(); + statusCode = status.getCode(); + if (statusCode == Status.Code.CANCELLED && deadline != null) { + // When the server's deadline expires, it can only reset the stream with CANCEL and no + // description. Since our timer may be delayed in firing, we double-check the deadline and + // turn the failure into the likely more helpful DEADLINE_EXCEEDED status. + if (deadline.isExpired()) { + statusCode = Code.DEADLINE_EXCEEDED; + } + } + if (inboundReceivedOrClosed.compareAndSet(false, true)) { + if (module.recordFinishedRpcs) { + // Stream is closed early. So no need to record metrics for any inbound events after this + // point. + recordFinishedRpc(); + } + } // Otherwise will report stats in callEnded() to guarantee all inbound metrics are recorded. + } + + void recordFinishedRpc() { + MeasureMap measureMap = module.statsRecorder.newMeasureMap() + // TODO(songya): remove the deprecated measure constants once they are completed removed. + .put(DeprecatedCensusConstants.RPC_CLIENT_FINISHED_COUNT, 1) + // The latency is double value + .put( + DeprecatedCensusConstants.RPC_CLIENT_ROUNDTRIP_LATENCY, + roundtripNanos / NANOS_PER_MILLI) + .put(DeprecatedCensusConstants.RPC_CLIENT_REQUEST_COUNT, outboundMessageCount) + .put(DeprecatedCensusConstants.RPC_CLIENT_RESPONSE_COUNT, inboundMessageCount) + .put(DeprecatedCensusConstants.RPC_CLIENT_REQUEST_BYTES, outboundWireSize) + .put(DeprecatedCensusConstants.RPC_CLIENT_RESPONSE_BYTES, inboundWireSize) + .put( + DeprecatedCensusConstants.RPC_CLIENT_UNCOMPRESSED_REQUEST_BYTES, + outboundUncompressedSize) + .put( + DeprecatedCensusConstants.RPC_CLIENT_UNCOMPRESSED_RESPONSE_BYTES, + inboundUncompressedSize); + if (statusCode != Code.OK) { + measureMap.put(DeprecatedCensusConstants.RPC_CLIENT_ERROR_COUNT, 1); + } + TagValue statusTag = TagValue.create(statusCode.toString()); + measureMap.record( + module + .tagger + .toBuilder(startCtx) + .putLocal(RpcMeasureConstants.GRPC_CLIENT_STATUS, statusTag) + .build()); + } } @VisibleForTesting static final class CallAttemptsTracerFactory extends ClientStreamTracer.InternalLimitedInfoFactory { - @Nullable - private static final AtomicReferenceFieldUpdater - streamTracerUpdater; + static final MeasureLong RETRIES_PER_CALL = + Measure.MeasureLong.create( + "grpc.io/client/retries_per_call", "Number of retries per call", "1"); + static final MeasureLong TRANSPARENT_RETRIES_PER_CALL = + Measure.MeasureLong.create( + "grpc.io/client/transparent_retries_per_call", "Transparent retries per call", "1"); + static final MeasureDouble RETRY_DELAY_PER_CALL = + Measure.MeasureDouble.create( + "grpc.io/client/retry_delay_per_call", "Retry delay per call", "ms"); @Nullable private static final AtomicIntegerFieldUpdater callEndedUpdater; @@ -334,40 +414,45 @@ static final class CallAttemptsTracerFactory extends * (potentially racy) direct updates of the volatile variables. */ static { - AtomicReferenceFieldUpdater tmpStreamTracerUpdater; AtomicIntegerFieldUpdater tmpCallEndedUpdater; try { - tmpStreamTracerUpdater = - AtomicReferenceFieldUpdater.newUpdater( - CallAttemptsTracerFactory.class, ClientTracer.class, "streamTracer"); tmpCallEndedUpdater = AtomicIntegerFieldUpdater.newUpdater(CallAttemptsTracerFactory.class, "callEnded"); } catch (Throwable t) { logger.log(Level.SEVERE, "Creating atomic field updaters failed", t); - tmpStreamTracerUpdater = null; tmpCallEndedUpdater = null; } - streamTracerUpdater = tmpStreamTracerUpdater; callEndedUpdater = tmpCallEndedUpdater; } + ClientTracer inboundMetricTracer; private final CensusStatsModule module; private final Stopwatch stopwatch; - private volatile ClientTracer streamTracer; private volatile int callEnded; private final TagContext parentCtx; private final TagContext startCtx; + private final String fullMethodName; + + // TODO(zdapeng): optimize memory allocation using AtomicFieldUpdater. + private final AtomicLong attemptsPerCall = new AtomicLong(); + private final AtomicLong transparentRetriesPerCall = new AtomicLong(); + private final AtomicLong retryDelayNanos = new AtomicLong(); + private final AtomicLong lastInactiveTimeStamp = new AtomicLong(); + private final AtomicInteger activeStreams = new AtomicInteger(); + private final AtomicBoolean activated = new AtomicBoolean(); CallAttemptsTracerFactory( CensusStatsModule module, TagContext parentCtx, String fullMethodName) { - this.module = checkNotNull(module); - this.parentCtx = checkNotNull(parentCtx); + this.module = checkNotNull(module, "module"); + this.parentCtx = checkNotNull(parentCtx, "parentCtx"); + this.fullMethodName = checkNotNull(fullMethodName, "fullMethodName"); + this.stopwatch = module.stopwatchSupplier.get().start(); TagValue methodTag = TagValue.create(fullMethodName); - this.startCtx = module.tagger.toBuilder(parentCtx) + startCtx = module.tagger.toBuilder(parentCtx) .putLocal(RpcMeasureConstants.GRPC_CLIENT_METHOD, methodTag) .build(); - this.stopwatch = module.stopwatchSupplier.get().start(); if (module.recordStartedRpcs) { + // Record here in case newClientStreamTracer() would never be called. module.statsRecorder.newMeasureMap() .put(DeprecatedCensusConstants.RPC_CLIENT_STARTED_COUNT, 1) .record(startCtx); @@ -375,30 +460,37 @@ static final class CallAttemptsTracerFactory extends } @Override - public ClientStreamTracer newClientStreamTracer( - ClientStreamTracer.StreamInfo info, Metadata headers) { - ClientTracer tracer = new ClientTracer(module, parentCtx, startCtx); - // TODO(zhangkun83): Once retry or hedging is implemented, a ClientCall may start more than - // one streams. We will need to update this file to support them. - if (streamTracerUpdater != null) { - checkState( - streamTracerUpdater.compareAndSet(this, null, tracer), - "Are you creating multiple streams per call? This class doesn't yet support this case"); + public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata metadata) { + ClientTracer tracer = new ClientTracer(this, module, parentCtx, startCtx, info); + if (activeStreams.incrementAndGet() == 1) { + if (!activated.compareAndSet(false, true)) { + retryDelayNanos.addAndGet(stopwatch.elapsed(TimeUnit.NANOSECONDS)); + } + } + if (module.recordStartedRpcs && attemptsPerCall.get() > 0) { + module.statsRecorder.newMeasureMap() + .put(DeprecatedCensusConstants.RPC_CLIENT_STARTED_COUNT, 1) + .record(startCtx); + } + if (info.isTransparentRetry()) { + transparentRetriesPerCall.incrementAndGet(); } else { - checkState( - streamTracer == null, - "Are you creating multiple streams per call? This class doesn't yet support this case"); - streamTracer = tracer; + attemptsPerCall.incrementAndGet(); } return tracer; } - /** - * Record a finished call and mark the current time as the end time. - * - *

Can be called from any thread without synchronization. Calling it the second time or more - * is a no-op. - */ + // Called whenever each attempt is ended. + void attemptEnded() { + if (activeStreams.decrementAndGet() == 0) { + // Race condition between two extremely close events does not matter because the difference + // in the result would be very small. + long lastInactiveTimeStamp = + this.lastInactiveTimeStamp.getAndSet(stopwatch.elapsed(TimeUnit.NANOSECONDS)); + retryDelayNanos.addAndGet(-lastInactiveTimeStamp); + } + } + void callEnded(Status status) { if (callEndedUpdater != null) { if (callEndedUpdater.getAndSet(this, 1) != 0) { @@ -414,36 +506,30 @@ void callEnded(Status status) { return; } stopwatch.stop(); - long roundtripNanos = stopwatch.elapsed(TimeUnit.NANOSECONDS); - ClientTracer tracer = streamTracer; - if (tracer == null) { - tracer = new ClientTracer(module, parentCtx, startCtx); + if (attemptsPerCall.get() == 0) { + ClientTracer tracer = new ClientTracer(this, module, parentCtx, startCtx, null); + tracer.roundtripNanos = stopwatch.elapsed(TimeUnit.NANOSECONDS); + tracer.statusCode = status.getCode(); + tracer.recordFinishedRpc(); + } else if (inboundMetricTracer != null) { + inboundMetricTracer.recordFinishedRpc(); } - MeasureMap measureMap = module.statsRecorder.newMeasureMap() - // TODO(songya): remove the deprecated measure constants once they are completed removed. - .put(DeprecatedCensusConstants.RPC_CLIENT_FINISHED_COUNT, 1) - // The latency is double value - .put( - DeprecatedCensusConstants.RPC_CLIENT_ROUNDTRIP_LATENCY, - roundtripNanos / NANOS_PER_MILLI) - .put(DeprecatedCensusConstants.RPC_CLIENT_REQUEST_COUNT, tracer.outboundMessageCount) - .put(DeprecatedCensusConstants.RPC_CLIENT_RESPONSE_COUNT, tracer.inboundMessageCount) - .put(DeprecatedCensusConstants.RPC_CLIENT_REQUEST_BYTES, tracer.outboundWireSize) - .put(DeprecatedCensusConstants.RPC_CLIENT_RESPONSE_BYTES, tracer.inboundWireSize) - .put( - DeprecatedCensusConstants.RPC_CLIENT_UNCOMPRESSED_REQUEST_BYTES, - tracer.outboundUncompressedSize) - .put( - DeprecatedCensusConstants.RPC_CLIENT_UNCOMPRESSED_RESPONSE_BYTES, - tracer.inboundUncompressedSize); - if (!status.isOk()) { - measureMap.put(DeprecatedCensusConstants.RPC_CLIENT_ERROR_COUNT, 1); + + long retriesPerCall = 0; + long attempts = attemptsPerCall.get(); + if (attempts > 0) { + retriesPerCall = attempts - 1; } + MeasureMap measureMap = module.statsRecorder.newMeasureMap() + .put(RETRIES_PER_CALL, retriesPerCall) + .put(TRANSPARENT_RETRIES_PER_CALL, transparentRetriesPerCall.get()) + .put(RETRY_DELAY_PER_CALL, retryDelayNanos.get() / NANOS_PER_MILLI); + TagValue methodTag = TagValue.create(fullMethodName); TagValue statusTag = TagValue.create(status.getCode().toString()); measureMap.record( - module - .tagger - .toBuilder(startCtx) + module.tagger + .toBuilder(parentCtx) + .putLocal(RpcMeasureConstants.GRPC_CLIENT_METHOD, methodTag) .putLocal(RpcMeasureConstants.GRPC_CLIENT_STATUS, statusTag) .build()); } diff --git a/census/src/main/java/io/grpc/census/CensusTracingModule.java b/census/src/main/java/io/grpc/census/CensusTracingModule.java index dac62206fd2..08d5fe3ca97 100644 --- a/census/src/main/java/io/grpc/census/CensusTracingModule.java +++ b/census/src/main/java/io/grpc/census/CensusTracingModule.java @@ -32,6 +32,7 @@ import io.grpc.MethodDescriptor; import io.grpc.ServerStreamTracer; import io.grpc.StreamTracer; +import io.opencensus.trace.AttributeValue; import io.opencensus.trace.BlankSpan; import io.opencensus.trace.EndSpanOptions; import io.opencensus.trace.MessageEvent; @@ -60,7 +61,8 @@ final class CensusTracingModule { private static final Logger logger = Logger.getLogger(CensusTracingModule.class.getName()); - @Nullable private static final AtomicIntegerFieldUpdater callEndedUpdater; + @Nullable + private static final AtomicIntegerFieldUpdater callEndedUpdater; @Nullable private static final AtomicIntegerFieldUpdater streamClosedUpdater; @@ -70,11 +72,11 @@ final class CensusTracingModule { * (potentially racy) direct updates of the volatile variables. */ static { - AtomicIntegerFieldUpdater tmpCallEndedUpdater; + AtomicIntegerFieldUpdater tmpCallEndedUpdater; AtomicIntegerFieldUpdater tmpStreamClosedUpdater; try { tmpCallEndedUpdater = - AtomicIntegerFieldUpdater.newUpdater(ClientCallTracer.class, "callEnded"); + AtomicIntegerFieldUpdater.newUpdater(CallAttemptsTracerFactory.class, "callEnded"); tmpStreamClosedUpdater = AtomicIntegerFieldUpdater.newUpdater(ServerTracer.class, "streamClosed"); } catch (Throwable t) { @@ -116,11 +118,12 @@ public SpanContext parseBytes(byte[] serialized) { } /** - * Creates a {@link ClientCallTracer} for a new call. + * Creates a {@link CallAttemptsTracerFactory} for a new call. */ @VisibleForTesting - ClientCallTracer newClientCallTracer(@Nullable Span parentSpan, MethodDescriptor method) { - return new ClientCallTracer(parentSpan, method); + CallAttemptsTracerFactory newClientCallTracer( + @Nullable Span parentSpan, MethodDescriptor method) { + return new CallAttemptsTracerFactory(parentSpan, method); } /** @@ -223,19 +226,21 @@ private static void recordMessageEvent( } @VisibleForTesting - final class ClientCallTracer extends ClientStreamTracer.InternalLimitedInfoFactory { + final class CallAttemptsTracerFactory extends ClientStreamTracer.InternalLimitedInfoFactory { volatile int callEnded; private final boolean isSampledToLocalTracing; private final Span span; + private final String fullMethodName; - ClientCallTracer(@Nullable Span parentSpan, MethodDescriptor method) { + CallAttemptsTracerFactory(@Nullable Span parentSpan, MethodDescriptor method) { checkNotNull(method, "method"); this.isSampledToLocalTracing = method.isSampledToLocalTracing(); + this.fullMethodName = method.getFullMethodName(); this.span = censusTracer .spanBuilderWithExplicitParent( - generateTraceSpanName(false, method.getFullMethodName()), + generateTraceSpanName(false, fullMethodName), parentSpan) .setRecordEvents(true) .startSpan(); @@ -244,7 +249,17 @@ final class ClientCallTracer extends ClientStreamTracer.InternalLimitedInfoFacto @Override public ClientStreamTracer newClientStreamTracer( ClientStreamTracer.StreamInfo info, Metadata headers) { - return new ClientTracer(span, tracingHeader); + Span attemptSpan = censusTracer + .spanBuilderWithExplicitParent( + "Attempt." + fullMethodName.replace('/', '.'), + span) + .setRecordEvents(true) + .startSpan(); + attemptSpan.putAttribute( + "previous-rpc-attempts", AttributeValue.longAttributeValue(info.getPreviousAttempts())); + attemptSpan.putAttribute( + "transparent-retry", AttributeValue.booleanAttributeValue(info.isTransparentRetry())); + return new ClientTracer(attemptSpan, tracingHeader, isSampledToLocalTracing); } /** @@ -271,10 +286,13 @@ void callEnded(io.grpc.Status status) { private static final class ClientTracer extends ClientStreamTracer { private final Span span; final Metadata.Key tracingHeader; + final boolean isSampledToLocalTracing; - ClientTracer(Span span, Metadata.Key tracingHeader) { + ClientTracer( + Span span, Metadata.Key tracingHeader, boolean isSampledToLocalTracing) { this.span = checkNotNull(span, "span"); this.tracingHeader = tracingHeader; + this.isSampledToLocalTracing = isSampledToLocalTracing; } @Override @@ -298,6 +316,11 @@ public void inboundMessageRead( recordMessageEvent( span, MessageEvent.Type.RECEIVED, seqNo, optionalWireSize, optionalUncompressedSize); } + + @Override + public void streamClosed(io.grpc.Status status) { + span.end(createEndSpanOptions(status, isSampledToLocalTracing)); + } } @@ -388,7 +411,7 @@ public ClientCall interceptCall( // Safe usage of the unsafe trace API because CONTEXT_SPAN_KEY.get() returns the same value // as Tracer.getCurrentSpan() except when no value available when the return value is null // for the direct access and BlankSpan when Tracer API is used. - final ClientCallTracer tracerFactory = + final CallAttemptsTracerFactory tracerFactory = newClientCallTracer(ContextUtils.getValue(Context.current()), method); ClientCall call = next.newCall( diff --git a/census/src/test/java/io/grpc/census/CensusModulesTest.java b/census/src/test/java/io/grpc/census/CensusModulesTest.java index fd3a049f7a4..d285c8fe8c2 100644 --- a/census/src/test/java/io/grpc/census/CensusModulesTest.java +++ b/census/src/test/java/io/grpc/census/CensusModulesTest.java @@ -18,6 +18,9 @@ import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertWithMessage; +import static io.grpc.census.CensusStatsModule.CallAttemptsTracerFactory.RETRIES_PER_CALL; +import static io.grpc.census.CensusStatsModule.CallAttemptsTracerFactory.RETRY_DELAY_PER_CALL; +import static io.grpc.census.CensusStatsModule.CallAttemptsTracerFactory.TRANSPARENT_RETRIES_PER_CALL; import static java.util.concurrent.TimeUnit.MILLISECONDS; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; @@ -58,6 +61,7 @@ import io.grpc.ServerStreamTracer; import io.grpc.ServerStreamTracer.ServerCallInfo; import io.grpc.Status; +import io.grpc.census.CensusTracingModule.CallAttemptsTracerFactory; import io.grpc.census.internal.DeprecatedCensusConstants; import io.grpc.internal.FakeClock; import io.grpc.internal.testing.StatsTestUtils; @@ -81,6 +85,7 @@ import io.opencensus.stats.View; import io.opencensus.tags.TagContext; import io.opencensus.tags.TagValue; +import io.opencensus.trace.AttributeValue; import io.opencensus.trace.BlankSpan; import io.opencensus.trace.EndSpanOptions; import io.opencensus.trace.MessageEvent; @@ -173,10 +178,12 @@ public String parse(InputStream stream) { private final Random random = new Random(1234); private final Span fakeClientParentSpan = MockableSpan.generateRandomSpan(random); private final Span spyClientSpan = spy(MockableSpan.generateRandomSpan(random)); - private final SpanContext fakeClientSpanContext = spyClientSpan.getContext(); + private final Span spyAttemptSpan = spy(MockableSpan.generateRandomSpan(random)); + private final SpanContext fakeAttemptSpanContext = spyAttemptSpan.getContext(); private final Span spyServerSpan = spy(MockableSpan.generateRandomSpan(random)); private final byte[] binarySpanContext = new byte[]{3, 1, 5}; private final SpanBuilder spyClientSpanBuilder = spy(new MockableSpan.Builder()); + private final SpanBuilder spyAttemptSpanBuilder = spy(new MockableSpan.Builder()); private final SpanBuilder spyServerSpanBuilder = spy(new MockableSpan.Builder()); @Rule @@ -201,15 +208,20 @@ public String parse(InputStream stream) { @Before public void setUp() throws Exception { when(spyClientSpanBuilder.startSpan()).thenReturn(spyClientSpan); - when(tracer.spanBuilderWithExplicitParent(anyString(), ArgumentMatchers.any())) + when(spyAttemptSpanBuilder.startSpan()).thenReturn(spyAttemptSpan); + when(tracer.spanBuilderWithExplicitParent( + eq("Sent.package1.service2.method3"), ArgumentMatchers.any())) .thenReturn(spyClientSpanBuilder); + when(tracer.spanBuilderWithExplicitParent( + eq("Attempt.package1.service2.method3"), ArgumentMatchers.any())) + .thenReturn(spyAttemptSpanBuilder); when(spyServerSpanBuilder.startSpan()).thenReturn(spyServerSpan); when(tracer.spanBuilderWithRemoteParent(anyString(), ArgumentMatchers.any())) .thenReturn(spyServerSpanBuilder); when(mockTracingPropagationHandler.toByteArray(any(SpanContext.class))) .thenReturn(binarySpanContext); when(mockTracingPropagationHandler.fromByteArray(any(byte[].class))) - .thenReturn(fakeClientSpanContext); + .thenReturn(fakeAttemptSpanContext); censusStats = new CensusStatsModule( tagger, tagCtxSerializer, statsRecorder, fakeClock.getStopwatchSupplier(), @@ -292,7 +304,7 @@ public ClientCall interceptCall( assertEquals(2, capturedCallOptions.get().getStreamTracerFactories().size()); assertTrue( capturedCallOptions.get().getStreamTracerFactories().get(0) - instanceof CensusTracingModule.ClientCallTracer); + instanceof CallAttemptsTracerFactory); assertTrue( capturedCallOptions.get().getStreamTracerFactories().get(1) instanceof CensusStatsModule.CallAttemptsTracerFactory); @@ -355,6 +367,7 @@ record = statsRecorder.pollRecord(); .setSampleToLocalSpanStore(false) .build()); verify(spyClientSpan, never()).end(); + assertZeroRetryRecorded(); } @Test @@ -489,11 +502,200 @@ private void subtestClientBasicStatsDefaultContext( DeprecatedCensusConstants.RPC_CLIENT_UNCOMPRESSED_RESPONSE_BYTES)); assertEquals(30 + 100 + 16 + 24, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_ROUNDTRIP_LATENCY)); + assertZeroRetryRecorded(); } else { assertNull(statsRecorder.pollRecord()); } } + // This test is only unit-testing the stat recording logic. The retry behavior is faked. + @Test + public void recordRetryStats() { + CensusStatsModule localCensusStats = + new CensusStatsModule( + tagger, tagCtxSerializer, statsRecorder, fakeClock.getStopwatchSupplier(), + true, true, true, true); + CensusStatsModule.CallAttemptsTracerFactory callAttemptsTracerFactory = + new CensusStatsModule.CallAttemptsTracerFactory( + localCensusStats, tagger.empty(), method.getFullMethodName()); + ClientStreamTracer tracer = + callAttemptsTracerFactory.newClientStreamTracer(STREAM_INFO, new Metadata()); + + StatsTestUtils.MetricsRecord record = statsRecorder.pollRecord(); + assertEquals(1, record.tags.size()); + TagValue methodTag = record.tags.get(RpcMeasureConstants.GRPC_CLIENT_METHOD); + assertEquals(method.getFullMethodName(), methodTag.asString()); + assertEquals( + 1, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_STARTED_COUNT)); + + fakeClock.forwardTime(30, MILLISECONDS); + tracer.outboundHeaders(); + fakeClock.forwardTime(100, MILLISECONDS); + tracer.outboundMessage(0); + assertRealTimeMetric( + RpcMeasureConstants.GRPC_CLIENT_SENT_MESSAGES_PER_METHOD, 1, true, true); + tracer.outboundMessage(1); + assertRealTimeMetric( + RpcMeasureConstants.GRPC_CLIENT_SENT_MESSAGES_PER_METHOD, 1, true, true); + tracer.outboundWireSize(1028); + assertRealTimeMetric( + RpcMeasureConstants.GRPC_CLIENT_SENT_BYTES_PER_METHOD, 1028, true, true); + tracer.outboundUncompressedSize(1128); + fakeClock.forwardTime(24, MILLISECONDS); + tracer.streamClosed(Status.UNAVAILABLE); + record = statsRecorder.pollRecord(); + methodTag = record.tags.get(RpcMeasureConstants.GRPC_CLIENT_METHOD); + assertEquals(method.getFullMethodName(), methodTag.asString()); + TagValue statusTag = record.tags.get(RpcMeasureConstants.GRPC_CLIENT_STATUS); + assertEquals(Status.Code.UNAVAILABLE.toString(), statusTag.asString()); + assertEquals( + 1, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_FINISHED_COUNT)); + assertEquals(1, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_ERROR_COUNT)); + assertEquals( + 2, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_REQUEST_COUNT)); + assertEquals( + 1028, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_REQUEST_BYTES)); + assertEquals( + 1128, + record.getMetricAsLongOrFail( + DeprecatedCensusConstants.RPC_CLIENT_UNCOMPRESSED_REQUEST_BYTES)); + assertEquals( + 30 + 100 + 24, + record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_ROUNDTRIP_LATENCY)); + + // faking retry + fakeClock.forwardTime(1000, MILLISECONDS); + tracer = callAttemptsTracerFactory.newClientStreamTracer(STREAM_INFO, new Metadata()); + record = statsRecorder.pollRecord(); + assertEquals(1, record.tags.size()); + methodTag = record.tags.get(RpcMeasureConstants.GRPC_CLIENT_METHOD); + assertEquals(method.getFullMethodName(), methodTag.asString()); + assertEquals( + 1, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_STARTED_COUNT)); + tracer.outboundHeaders(); + tracer.outboundMessage(0); + assertRealTimeMetric( + RpcMeasureConstants.GRPC_CLIENT_SENT_MESSAGES_PER_METHOD, 1, true, true); + tracer.outboundMessage(1); + assertRealTimeMetric( + RpcMeasureConstants.GRPC_CLIENT_SENT_MESSAGES_PER_METHOD, 1, true, true); + tracer.outboundWireSize(1028); + assertRealTimeMetric( + RpcMeasureConstants.GRPC_CLIENT_SENT_BYTES_PER_METHOD, 1028, true, true); + tracer.outboundUncompressedSize(1128); + fakeClock.forwardTime(100, MILLISECONDS); + tracer.streamClosed(Status.NOT_FOUND); + record = statsRecorder.pollRecord(); + methodTag = record.tags.get(RpcMeasureConstants.GRPC_CLIENT_METHOD); + assertEquals(method.getFullMethodName(), methodTag.asString()); + statusTag = record.tags.get(RpcMeasureConstants.GRPC_CLIENT_STATUS); + assertEquals(Status.Code.NOT_FOUND.toString(), statusTag.asString()); + assertEquals( + 1, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_FINISHED_COUNT)); + assertEquals(1, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_ERROR_COUNT)); + assertEquals( + 2, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_REQUEST_COUNT)); + assertEquals( + 1028, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_REQUEST_BYTES)); + assertEquals( + 1128, + record.getMetricAsLongOrFail( + DeprecatedCensusConstants.RPC_CLIENT_UNCOMPRESSED_REQUEST_BYTES)); + assertEquals( + 100 , + record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_ROUNDTRIP_LATENCY)); + + // fake transparent retry + fakeClock.forwardTime(10, MILLISECONDS); + tracer = callAttemptsTracerFactory.newClientStreamTracer( + STREAM_INFO.toBuilder().setIsTransparentRetry(true).build(), new Metadata()); + record = statsRecorder.pollRecord(); + assertEquals(1, record.tags.size()); + methodTag = record.tags.get(RpcMeasureConstants.GRPC_CLIENT_METHOD); + assertEquals(method.getFullMethodName(), methodTag.asString()); + assertEquals( + 1, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_STARTED_COUNT)); + tracer.streamClosed(Status.UNAVAILABLE); + record = statsRecorder.pollRecord(); + statusTag = record.tags.get(RpcMeasureConstants.GRPC_CLIENT_STATUS); + assertEquals(Status.Code.UNAVAILABLE.toString(), statusTag.asString()); + assertEquals( + 1, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_FINISHED_COUNT)); + assertEquals(1, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_ERROR_COUNT)); + assertEquals( + 0, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_REQUEST_COUNT)); + assertEquals( + 0, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_REQUEST_BYTES)); + + // fake another transparent retry + fakeClock.forwardTime(10, MILLISECONDS); + tracer = callAttemptsTracerFactory.newClientStreamTracer( + STREAM_INFO.toBuilder().setIsTransparentRetry(true).build(), new Metadata()); + record = statsRecorder.pollRecord(); + assertEquals(1, record.tags.size()); + assertEquals( + 1, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_STARTED_COUNT)); + tracer.outboundHeaders(); + tracer.outboundMessage(0); + assertRealTimeMetric( + RpcMeasureConstants.GRPC_CLIENT_SENT_MESSAGES_PER_METHOD, 1, true, true); + tracer.outboundMessage(1); + assertRealTimeMetric( + RpcMeasureConstants.GRPC_CLIENT_SENT_MESSAGES_PER_METHOD, 1, true, true); + tracer.outboundWireSize(1028); + assertRealTimeMetric( + RpcMeasureConstants.GRPC_CLIENT_SENT_BYTES_PER_METHOD, 1028, true, true); + tracer.outboundUncompressedSize(1128); + fakeClock.forwardTime(16, MILLISECONDS); + tracer.inboundMessage(0); + assertRealTimeMetric( + RpcMeasureConstants.GRPC_CLIENT_RECEIVED_MESSAGES_PER_METHOD, 1, true, true); + tracer.inboundWireSize(33); + assertRealTimeMetric( + RpcMeasureConstants.GRPC_CLIENT_RECEIVED_BYTES_PER_METHOD, 33, true, true); + tracer.inboundUncompressedSize(67); + fakeClock.forwardTime(24, MILLISECONDS); + // RPC succeeded + tracer.streamClosed(Status.OK); + callAttemptsTracerFactory.callEnded(Status.OK); + + record = statsRecorder.pollRecord(); + statusTag = record.tags.get(RpcMeasureConstants.GRPC_CLIENT_STATUS); + assertEquals(Status.Code.OK.toString(), statusTag.asString()); + assertEquals( + 1, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_FINISHED_COUNT)); + assertThat(record.getMetric(DeprecatedCensusConstants.RPC_CLIENT_ERROR_COUNT)).isNull(); + assertEquals( + 2, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_REQUEST_COUNT)); + assertEquals( + 1028, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_REQUEST_BYTES)); + assertEquals( + 1128, + record.getMetricAsLongOrFail( + DeprecatedCensusConstants.RPC_CLIENT_UNCOMPRESSED_REQUEST_BYTES)); + assertEquals( + 1, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_RESPONSE_COUNT)); + assertEquals( + 33, + record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_RESPONSE_BYTES)); + assertEquals( + 67, + record.getMetricAsLongOrFail( + DeprecatedCensusConstants.RPC_CLIENT_UNCOMPRESSED_RESPONSE_BYTES)); + assertEquals( + 16 + 24 , + record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_ROUNDTRIP_LATENCY)); + + record = statsRecorder.pollRecord(); + methodTag = record.tags.get(RpcMeasureConstants.GRPC_CLIENT_METHOD); + assertEquals(method.getFullMethodName(), methodTag.asString()); + statusTag = record.tags.get(RpcMeasureConstants.GRPC_CLIENT_STATUS); + assertEquals(Status.Code.OK.toString(), statusTag.asString()); + assertThat(record.getMetric(RETRIES_PER_CALL)).isEqualTo(1); + assertThat(record.getMetric(TRANSPARENT_RETRIES_PER_CALL)).isEqualTo(2); + assertThat(record.getMetric(RETRY_DELAY_PER_CALL)).isEqualTo(1000D + 10 + 10); + } + private void assertRealTimeMetric( Measure measure, long expectedValue, boolean recordRealTimeMetrics, boolean clientSide) { StatsTestUtils.MetricsRecord record = statsRecorder.pollRecord(); @@ -517,16 +719,28 @@ private void assertRealTimeMetric( assertEquals(expectedValue, record.getMetricAsLongOrFail(measure)); } + private void assertZeroRetryRecorded() { + StatsTestUtils.MetricsRecord record = statsRecorder.pollRecord(); + TagValue methodTag = record.tags.get(RpcMeasureConstants.GRPC_CLIENT_METHOD); + assertEquals(method.getFullMethodName(), methodTag.asString()); + assertThat(record.getMetric(RETRIES_PER_CALL)).isEqualTo(0); + assertThat(record.getMetric(TRANSPARENT_RETRIES_PER_CALL)).isEqualTo(0); + assertThat(record.getMetric(RETRY_DELAY_PER_CALL)).isEqualTo(0D); + } + @Test public void clientBasicTracingDefaultSpan() { - CensusTracingModule.ClientCallTracer callTracer = + CallAttemptsTracerFactory callTracer = censusTracing.newClientCallTracer(null, method); Metadata headers = new Metadata(); ClientStreamTracer clientStreamTracer = callTracer.newClientStreamTracer(STREAM_INFO, headers); clientStreamTracer.streamCreated(Attributes.EMPTY, headers); verify(tracer).spanBuilderWithExplicitParent( eq("Sent.package1.service2.method3"), ArgumentMatchers.isNull()); + verify(tracer).spanBuilderWithExplicitParent( + eq("Attempt.package1.service2.method3"), eq(spyClientSpan)); verify(spyClientSpan, never()).end(any(EndSpanOptions.class)); + verify(spyAttemptSpan, never()).end(any(EndSpanOptions.class)); clientStreamTracer.outboundMessage(0); clientStreamTracer.outboundMessageSent(0, 882, -1); @@ -538,8 +752,12 @@ public void clientBasicTracingDefaultSpan() { clientStreamTracer.streamClosed(Status.OK); callTracer.callEnded(Status.OK); - InOrder inOrder = inOrder(spyClientSpan); - inOrder.verify(spyClientSpan, times(3)).addMessageEvent(messageEventCaptor.capture()); + InOrder inOrder = inOrder(spyClientSpan, spyAttemptSpan); + inOrder.verify(spyAttemptSpan) + .putAttribute("previous-rpc-attempts", AttributeValue.longAttributeValue(0)); + inOrder.verify(spyAttemptSpan) + .putAttribute("transparent-retry", AttributeValue.booleanAttributeValue(false)); + inOrder.verify(spyAttemptSpan, times(3)).addMessageEvent(messageEventCaptor.capture()); List events = messageEventCaptor.getAllValues(); assertEquals( MessageEvent.builder(MessageEvent.Type.SENT, 0).setCompressedMessageSize(882).build(), @@ -553,18 +771,23 @@ public void clientBasicTracingDefaultSpan() { .setUncompressedMessageSize(90) .build(), events.get(2)); + inOrder.verify(spyAttemptSpan).end( + EndSpanOptions.builder() + .setStatus(io.opencensus.trace.Status.OK) + .setSampleToLocalSpanStore(false) + .build()); inOrder.verify(spyClientSpan).end( EndSpanOptions.builder() .setStatus(io.opencensus.trace.Status.OK) .setSampleToLocalSpanStore(false) .build()); - verifyNoMoreInteractions(spyClientSpan); + inOrder.verifyNoMoreInteractions(); verifyNoMoreInteractions(tracer); } @Test public void clientTracingSampledToLocalSpanStore() { - CensusTracingModule.ClientCallTracer callTracer = + CallAttemptsTracerFactory callTracer = censusTracing.newClientCallTracer(null, sampledMethod); callTracer.callEnded(Status.OK); @@ -631,11 +854,12 @@ record = statsRecorder.pollRecord(); 3000, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_ROUNDTRIP_LATENCY)); assertNull(record.getMetric(DeprecatedCensusConstants.RPC_CLIENT_SERVER_ELAPSED_TIME)); + assertZeroRetryRecorded(); } @Test public void clientStreamNeverCreatedStillRecordTracing() { - CensusTracingModule.ClientCallTracer callTracer = + CallAttemptsTracerFactory callTracer = censusTracing.newClientCallTracer(fakeClientParentSpan, method); verify(tracer).spanBuilderWithExplicitParent( eq("Sent.package1.service2.method3"), same(fakeClientParentSpan)); @@ -770,6 +994,7 @@ private void subtestStatsHeadersPropagateTags(boolean propagate, boolean recordS assertNull(clientRecord.getMetric(DeprecatedCensusConstants.RPC_CLIENT_ERROR_COUNT)); TagValue clientPropagatedTag = clientRecord.tags.get(StatsTestUtils.EXTRA_TAG); assertEquals("extra-tag-value-897", clientPropagatedTag.asString()); + assertZeroRetryRecorded(); } if (!recordStats) { @@ -812,16 +1037,18 @@ public void statsHeaderMalformed() { @Test public void traceHeadersPropagateSpanContext() throws Exception { - CensusTracingModule.ClientCallTracer callTracer = + CallAttemptsTracerFactory callTracer = censusTracing.newClientCallTracer(fakeClientParentSpan, method); Metadata headers = new Metadata(); ClientStreamTracer streamTracer = callTracer.newClientStreamTracer(STREAM_INFO, headers); streamTracer.streamCreated(Attributes.EMPTY, headers); - verify(mockTracingPropagationHandler).toByteArray(same(fakeClientSpanContext)); + verify(mockTracingPropagationHandler).toByteArray(same(fakeAttemptSpanContext)); verifyNoMoreInteractions(mockTracingPropagationHandler); verify(tracer).spanBuilderWithExplicitParent( eq("Sent.package1.service2.method3"), same(fakeClientParentSpan)); + verify(tracer).spanBuilderWithExplicitParent( + eq("Attempt.package1.service2.method3"), same(spyClientSpan)); verify(spyClientSpanBuilder).setRecordEvents(eq(true)); verifyNoMoreInteractions(tracer); assertTrue(headers.containsKey(censusTracing.tracingHeader)); @@ -831,7 +1058,7 @@ public void traceHeadersPropagateSpanContext() throws Exception { method.getFullMethodName(), headers); verify(mockTracingPropagationHandler).fromByteArray(same(binarySpanContext)); verify(tracer).spanBuilderWithRemoteParent( - eq("Recv.package1.service2.method3"), same(spyClientSpan.getContext())); + eq("Recv.package1.service2.method3"), same(spyAttemptSpan.getContext())); verify(spyServerSpanBuilder).setRecordEvents(eq(true)); Context filteredContext = serverTracer.filterContext(Context.ROOT); @@ -840,7 +1067,7 @@ public void traceHeadersPropagateSpanContext() throws Exception { @Test public void traceHeaders_propagateSpanContext() throws Exception { - CensusTracingModule.ClientCallTracer callTracer = + CallAttemptsTracerFactory callTracer = censusTracing.newClientCallTracer(fakeClientParentSpan, method); Metadata headers = new Metadata(); @@ -854,10 +1081,12 @@ public void traceHeaders_propagateSpanContext() throws Exception { public void traceHeaders_missingCensusImpl_notPropagateSpanContext() throws Exception { reset(spyClientSpanBuilder); + reset(spyAttemptSpanBuilder); when(spyClientSpanBuilder.startSpan()).thenReturn(BlankSpan.INSTANCE); + when(spyAttemptSpanBuilder.startSpan()).thenReturn(BlankSpan.INSTANCE); Metadata headers = new Metadata(); - CensusTracingModule.ClientCallTracer callTracer = + CallAttemptsTracerFactory callTracer = censusTracing.newClientCallTracer(BlankSpan.INSTANCE, method); callTracer.newClientStreamTracer(STREAM_INFO, headers).streamCreated(Attributes.EMPTY, headers); @@ -867,14 +1096,16 @@ public void traceHeaders_missingCensusImpl_notPropagateSpanContext() @Test public void traceHeaders_clientMissingCensusImpl_preservingHeaders() throws Exception { reset(spyClientSpanBuilder); + reset(spyAttemptSpanBuilder); when(spyClientSpanBuilder.startSpan()).thenReturn(BlankSpan.INSTANCE); + when(spyAttemptSpanBuilder.startSpan()).thenReturn(BlankSpan.INSTANCE); Metadata headers = new Metadata(); headers.put( Metadata.Key.of("never-used-key-bin", Metadata.BINARY_BYTE_MARSHALLER), new byte[] {}); Set originalHeaderKeys = new HashSet<>(headers.keys()); - CensusTracingModule.ClientCallTracer callTracer = + CallAttemptsTracerFactory callTracer = censusTracing.newClientCallTracer(BlankSpan.INSTANCE, method); callTracer.newClientStreamTracer(STREAM_INFO, headers).streamCreated(Attributes.EMPTY, headers); @@ -885,9 +1116,9 @@ public void traceHeaders_clientMissingCensusImpl_preservingHeaders() throws Exce public void traceHeaderMalformed() throws Exception { // As comparison, normal header parsing Metadata headers = new Metadata(); - headers.put(censusTracing.tracingHeader, fakeClientSpanContext); + headers.put(censusTracing.tracingHeader, fakeAttemptSpanContext); // mockTracingPropagationHandler was stubbed to always return fakeServerParentSpanContext - assertSame(spyClientSpan.getContext(), headers.get(censusTracing.tracingHeader)); + assertSame(spyAttemptSpan.getContext(), headers.get(censusTracing.tracingHeader)); // Make BinaryPropagationHandler always throw when parsing the header when(mockTracingPropagationHandler.fromByteArray(any(byte[].class))) @@ -895,7 +1126,7 @@ public void traceHeaderMalformed() throws Exception { headers = new Metadata(); assertNull(headers.get(censusTracing.tracingHeader)); - headers.put(censusTracing.tracingHeader, fakeClientSpanContext); + headers.put(censusTracing.tracingHeader, fakeAttemptSpanContext); assertSame(SpanContext.INVALID, headers.get(censusTracing.tracingHeader)); assertNotSame(spyClientSpan.getContext(), SpanContext.INVALID); diff --git a/core/src/main/java/io/grpc/internal/ClientCallImpl.java b/core/src/main/java/io/grpc/internal/ClientCallImpl.java index 28cd3351203..dd17244e2a5 100644 --- a/core/src/main/java/io/grpc/internal/ClientCallImpl.java +++ b/core/src/main/java/io/grpc/internal/ClientCallImpl.java @@ -255,7 +255,8 @@ public void runInContext() { effectiveDeadline, context.getDeadline(), callOptions.getDeadline()); stream = clientStreamProvider.newStream(method, callOptions, headers, context); } else { - ClientStreamTracer[] tracers = GrpcUtil.getClientStreamTracers(callOptions, headers, false); + ClientStreamTracer[] tracers = + GrpcUtil.getClientStreamTracers(callOptions, headers, 0, false); stream = new FailingClientStream( DEADLINE_EXCEEDED.withDescription( "ClientCall started after deadline exceeded: " + effectiveDeadline), diff --git a/core/src/main/java/io/grpc/internal/GrpcUtil.java b/core/src/main/java/io/grpc/internal/GrpcUtil.java index 5b5a062e95d..54f6d2f41d5 100644 --- a/core/src/main/java/io/grpc/internal/GrpcUtil.java +++ b/core/src/main/java/io/grpc/internal/GrpcUtil.java @@ -757,11 +757,12 @@ public ListenableFuture getStats() { /** Gets stream tracers based on CallOptions. */ public static ClientStreamTracer[] getClientStreamTracers( - CallOptions callOptions, Metadata headers, boolean isTransparentRetry) { + CallOptions callOptions, Metadata headers, int previousAttempts, boolean isTransparentRetry) { List factories = callOptions.getStreamTracerFactories(); ClientStreamTracer[] tracers = new ClientStreamTracer[factories.size() + 1]; StreamInfo streamInfo = StreamInfo.newBuilder() .setCallOptions(callOptions) + .setPreviousAttempts(previousAttempts) .setIsTransparentRetry(isTransparentRetry) .build(); for (int i = 0; i < factories.size(); i++) { diff --git a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java index 87162d9aba2..6cd5598e2a6 100644 --- a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java +++ b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java @@ -533,7 +533,7 @@ public ClientStream newStream( getTransport(new PickSubchannelArgsImpl(method, headers, callOptions)); Context origContext = context.attach(); ClientStreamTracer[] tracers = GrpcUtil.getClientStreamTracers( - callOptions, headers, /* isTransparentRetry= */ false); + callOptions, headers, 0, /* isTransparentRetry= */ false); try { return transport.newStream(method, headers, callOptions, tracers); } finally { @@ -572,10 +572,11 @@ void postCommit() { @Override ClientStream newSubstream( - Metadata newHeaders, ClientStreamTracer.Factory factory, boolean isTransparentRetry) { + Metadata newHeaders, ClientStreamTracer.Factory factory, int previousAttempts, + boolean isTransparentRetry) { CallOptions newOptions = callOptions.withStreamTracerFactory(factory); - ClientStreamTracer[] tracers = - GrpcUtil.getClientStreamTracers(newOptions, newHeaders, isTransparentRetry); + ClientStreamTracer[] tracers = GrpcUtil.getClientStreamTracers( + newOptions, newHeaders, previousAttempts, isTransparentRetry); ClientTransport transport = getTransport(new PickSubchannelArgsImpl(method, newHeaders, newOptions)); Context origContext = context.attach(); @@ -624,7 +625,7 @@ ClientStream newSubstream( channelLogger = new ChannelLoggerImpl(channelTracer, timeProvider); ProxyDetector proxyDetector = builder.proxyDetector != null ? builder.proxyDetector : GrpcUtil.DEFAULT_PROXY_DETECTOR; - this.retryEnabled = builder.retryEnabled && !builder.temporarilyDisableRetry; + this.retryEnabled = builder.retryEnabled; this.loadBalancerFactory = new AutoConfiguredLoadBalancerFactory(builder.defaultLbPolicy); this.offloadExecutorHolder = new ExecutorHolder( diff --git a/core/src/main/java/io/grpc/internal/ManagedChannelImplBuilder.java b/core/src/main/java/io/grpc/internal/ManagedChannelImplBuilder.java index d42b3832136..cad4ece233e 100644 --- a/core/src/main/java/io/grpc/internal/ManagedChannelImplBuilder.java +++ b/core/src/main/java/io/grpc/internal/ManagedChannelImplBuilder.java @@ -143,10 +143,6 @@ public static ManagedChannelBuilder forTarget(String target) { long retryBufferSize = DEFAULT_RETRY_BUFFER_SIZE_IN_BYTES; long perRpcBufferLimit = DEFAULT_PER_RPC_BUFFER_LIMIT_IN_BYTES; boolean retryEnabled = false; // TODO(zdapeng): default to true - // Temporarily disable retry when stats or tracing is enabled to avoid breakage, until we know - // what should be the desired behavior for retry + stats/tracing. - // TODO(zdapeng): delete me - boolean temporarilyDisableRetry; InternalChannelz channelz = InternalChannelz.instance(); int maxTraceEvents; @@ -460,8 +456,6 @@ public ManagedChannelImplBuilder disableRetry() { @Override public ManagedChannelImplBuilder enableRetry() { retryEnabled = true; - statsEnabled = false; - tracingEnabled = false; return this; } @@ -592,9 +586,6 @@ public void setStatsRecordRealTimeMetrics(boolean value) { /** * Disable or enable tracing features. Enabled by default. - * - *

For the current release, calling {@code setTracingEnabled(true)} may have a side effect that - * disables retry. */ public void setTracingEnabled(boolean value) { tracingEnabled = value; @@ -642,9 +633,7 @@ public ManagedChannel build() { List getEffectiveInterceptors() { List effectiveInterceptors = new ArrayList<>(this.interceptors); - temporarilyDisableRetry = false; if (statsEnabled) { - temporarilyDisableRetry = true; ClientInterceptor statsInterceptor = null; try { Class censusStatsAccessor = @@ -679,7 +668,6 @@ List getEffectiveInterceptors() { } } if (tracingEnabled) { - temporarilyDisableRetry = true; ClientInterceptor tracingInterceptor = null; try { Class censusTracingAccessor = diff --git a/core/src/main/java/io/grpc/internal/OobChannel.java b/core/src/main/java/io/grpc/internal/OobChannel.java index b628842efe4..589824ae10e 100644 --- a/core/src/main/java/io/grpc/internal/OobChannel.java +++ b/core/src/main/java/io/grpc/internal/OobChannel.java @@ -88,7 +88,7 @@ final class OobChannel extends ManagedChannel implements InternalInstrumented method, CallOptions callOptions, Metadata headers, Context context) { ClientStreamTracer[] tracers = GrpcUtil.getClientStreamTracers( - callOptions, headers, /* isTransparentRetry= */ false); + callOptions, headers, 0, /* isTransparentRetry= */ false); Context origContext = context.attach(); // delayed transport's newStream() always acquires a lock, but concurrent performance doesn't // matter here because OOB communication should be sparse, and it's not on application RPC's diff --git a/core/src/main/java/io/grpc/internal/RetriableStream.java b/core/src/main/java/io/grpc/internal/RetriableStream.java index d19a260049b..3d277bbe2fc 100644 --- a/core/src/main/java/io/grpc/internal/RetriableStream.java +++ b/core/src/main/java/io/grpc/internal/RetriableStream.java @@ -218,7 +218,7 @@ public ClientStreamTracer newClientStreamTracer( Metadata newHeaders = updateHeaders(headers, previousAttemptCount); // NOTICE: This set _must_ be done before stream.start() and it actually is. - sub.stream = newSubstream(newHeaders, tracerFactory, isTransparentRetry); + sub.stream = newSubstream(newHeaders, tracerFactory, previousAttemptCount, isTransparentRetry); return sub; } @@ -227,7 +227,8 @@ public ClientStreamTracer newClientStreamTracer( * Client stream is not yet started. */ abstract ClientStream newSubstream( - Metadata headers, ClientStreamTracer.Factory tracerFactory, boolean isTransparentRetry); + Metadata headers, ClientStreamTracer.Factory tracerFactory, int previousAttempts, + boolean isTransparentRetry); /** Adds grpc-previous-rpc-attempts in the headers of a retry/hedging RPC. */ @VisibleForTesting @@ -869,24 +870,26 @@ public void run() { synchronized (lock) { scheduledRetry = scheduledRetryCopy = new FutureCanceller(lock); } - scheduledRetryCopy.setFuture( - scheduledExecutorService.schedule( + class RetryBackoffRunnable implements Runnable { + @Override + public void run() { + callExecutor.execute( new Runnable() { @Override public void run() { - callExecutor.execute( - new Runnable() { - @Override - public void run() { - // retry - Substream newSubstream = createSubstream( - substream.previousAttemptCount + 1, - false); - drain(newSubstream); - } - }); + // retry + Substream newSubstream = createSubstream( + substream.previousAttemptCount + 1, + false); + drain(newSubstream); } - }, + }); + } + } + + scheduledRetryCopy.setFuture( + scheduledExecutorService.schedule( + new RetryBackoffRunnable(), retryPlan.backoffNanos, TimeUnit.NANOSECONDS)); return; diff --git a/core/src/main/java/io/grpc/internal/SubchannelChannel.java b/core/src/main/java/io/grpc/internal/SubchannelChannel.java index 1380a6bc716..a1d454ed2fb 100644 --- a/core/src/main/java/io/grpc/internal/SubchannelChannel.java +++ b/core/src/main/java/io/grpc/internal/SubchannelChannel.java @@ -59,7 +59,7 @@ public ClientStream newStream(MethodDescriptor method, transport = notReadyTransport; } ClientStreamTracer[] tracers = GrpcUtil.getClientStreamTracers( - callOptions, headers, /* isTransparentRetry= */ false); + callOptions, headers, 0, /* isTransparentRetry= */ false); Context origContext = context.attach(); try { return transport.newStream(method, headers, callOptions, tracers); diff --git a/core/src/test/java/io/grpc/internal/RetriableStreamTest.java b/core/src/test/java/io/grpc/internal/RetriableStreamTest.java index 95d2c2ba8b5..c9ea504e18b 100644 --- a/core/src/test/java/io/grpc/internal/RetriableStreamTest.java +++ b/core/src/test/java/io/grpc/internal/RetriableStreamTest.java @@ -164,7 +164,8 @@ void postCommit() { @Override ClientStream newSubstream( - Metadata metadata, ClientStreamTracer.Factory tracerFactory, boolean isTransparentRetry) { + Metadata metadata, ClientStreamTracer.Factory tracerFactory, int previousAttempts, + boolean isTransparentRetry) { bufferSizeTracer = tracerFactory.newClientStreamTracer(STREAM_INFO, metadata); int actualPreviousRpcAttemptsInHeader = metadata.get(GRPC_PREVIOUS_RPC_ATTEMPTS) == null diff --git a/interop-testing/build.gradle b/interop-testing/build.gradle index 14c92a9fd1d..852d5882cce 100644 --- a/interop-testing/build.gradle +++ b/interop-testing/build.gradle @@ -44,6 +44,7 @@ dependencies { project(':grpc-grpclb') testImplementation project(':grpc-context').sourceSets.test.output, project(':grpc-api').sourceSets.test.output, + project(':grpc-core').sourceSets.test.output, libraries.mockito alpnagent libraries.jetty_alpn_agent } diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java b/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java index 8a6e41722ab..693d9b2af7c 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java @@ -92,6 +92,9 @@ import io.grpc.testing.integration.Messages.StreamingOutputCallRequest; import io.grpc.testing.integration.Messages.StreamingOutputCallResponse; import io.opencensus.contrib.grpc.metrics.RpcMeasureConstants; +import io.opencensus.stats.Measure; +import io.opencensus.stats.Measure.MeasureDouble; +import io.opencensus.stats.Measure.MeasureLong; import io.opencensus.tags.TagKey; import io.opencensus.tags.TagValue; import io.opencensus.trace.Span; @@ -152,6 +155,15 @@ public abstract class AbstractInteropTest { * SETTINGS/WINDOW_UPDATE exchange. */ public static final int TEST_FLOW_CONTROL_WINDOW = 65 * 1024; + private static final MeasureLong RETRIES_PER_CALL = + Measure.MeasureLong.create( + "grpc.io/client/retries_per_call", "Number of retries per call", "1"); + private static final MeasureLong TRANSPARENT_RETRIES_PER_CALL = + Measure.MeasureLong.create( + "grpc.io/client/transparent_retries_per_call", "Transparent retries per call", "1"); + private static final MeasureDouble RETRY_DELAY_PER_CALL = + Measure.MeasureDouble.create( + "grpc.io/client/retry_delay_per_call", "Retry delay per call", "ms"); private static final FakeTagger tagger = new FakeTagger(); private static final FakeTagContextBinarySerializer tagContextBinarySerializer = @@ -1234,6 +1246,7 @@ public void deadlineInPast() throws Exception { checkEndTags( clientEndRecord, "grpc.testing.TestService/EmptyCall", Status.DEADLINE_EXCEEDED.getCode(), true); + assertZeroRetryRecorded(); } // warm up the channel @@ -1243,6 +1256,7 @@ public void deadlineInPast() throws Exception { clientStatsRecorder.pollRecord(5, TimeUnit.SECONDS); // clientEndRecord clientStatsRecorder.pollRecord(5, TimeUnit.SECONDS); + assertZeroRetryRecorded(); } try { blockingStub @@ -1261,6 +1275,7 @@ public void deadlineInPast() throws Exception { checkEndTags( clientEndRecord, "grpc.testing.TestService/EmptyCall", Status.DEADLINE_EXCEEDED.getCode(), true); + assertZeroRetryRecorded(); } } @@ -1978,6 +1993,13 @@ private void assertStatsTrace(String method, Status.Code status) { assertStatsTrace(method, status, null, null); } + private void assertZeroRetryRecorded() { + MetricsRecord retryRecord = clientStatsRecorder.pollRecord(); + assertThat(retryRecord.getMetric(RETRIES_PER_CALL)).isEqualTo(0); + assertThat(retryRecord.getMetric(TRANSPARENT_RETRIES_PER_CALL)).isEqualTo(0); + assertThat(retryRecord.getMetric(RETRY_DELAY_PER_CALL)).isEqualTo(0D); + } + private void assertClientStatsTrace(String method, Status.Code code, Collection requests, Collection responses) { // Tracer-based stats @@ -2007,6 +2029,7 @@ private void assertClientStatsTrace(String method, Status.Code code, if (requests != null && responses != null) { checkCensus(clientEndRecord, false, requests, responses); } + assertZeroRetryRecorded(); } } diff --git a/interop-testing/src/test/java/io/grpc/testing/integration/RetryTest.java b/interop-testing/src/test/java/io/grpc/testing/integration/RetryTest.java index 4824f05313b..bdf39e8546a 100644 --- a/interop-testing/src/test/java/io/grpc/testing/integration/RetryTest.java +++ b/interop-testing/src/test/java/io/grpc/testing/integration/RetryTest.java @@ -17,13 +17,21 @@ package io.grpc.testing.integration; import static com.google.common.truth.Truth.assertThat; +import static java.util.concurrent.TimeUnit.SECONDS; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.never; import static org.mockito.Mockito.timeout; import static org.mockito.Mockito.verify; +import com.google.common.collect.ImmutableMap; +import io.grpc.Attributes; import io.grpc.CallOptions; import io.grpc.ClientCall; +import io.grpc.ClientInterceptor; +import io.grpc.ClientStreamTracer; +import io.grpc.ClientStreamTracer.StreamInfo; +import io.grpc.Deadline; +import io.grpc.Deadline.Ticker; import io.grpc.IntegerMarshaller; import io.grpc.ManagedChannel; import io.grpc.Metadata; @@ -36,7 +44,15 @@ import io.grpc.ServerMethodDefinition; import io.grpc.ServerServiceDefinition; import io.grpc.Status; +import io.grpc.Status.Code; import io.grpc.StringMarshaller; +import io.grpc.census.InternalCensusStatsAccessor; +import io.grpc.census.internal.DeprecatedCensusConstants; +import io.grpc.internal.FakeClock; +import io.grpc.internal.testing.StatsTestUtils.FakeStatsRecorder; +import io.grpc.internal.testing.StatsTestUtils.FakeTagContextBinarySerializer; +import io.grpc.internal.testing.StatsTestUtils.FakeTagger; +import io.grpc.internal.testing.StatsTestUtils.MetricsRecord; import io.grpc.netty.NettyChannelBuilder; import io.grpc.netty.NettyServerBuilder; import io.grpc.testing.GrpcCleanupRule; @@ -45,11 +61,20 @@ import io.netty.channel.local.LocalAddress; import io.netty.channel.local.LocalChannel; import io.netty.channel.local.LocalServerChannel; +import io.netty.util.concurrent.ScheduledFuture; +import io.opencensus.contrib.grpc.metrics.RpcMeasureConstants; +import io.opencensus.stats.Measure; +import io.opencensus.stats.Measure.MeasureDouble; +import io.opencensus.stats.Measure.MeasureLong; +import io.opencensus.tags.TagValue; import java.util.Arrays; import java.util.HashMap; import java.util.Map; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import org.junit.Ignore; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; @@ -61,65 +86,111 @@ @RunWith(JUnit4.class) public class RetryTest { + private static final FakeTagger tagger = new FakeTagger(); + private static final FakeTagContextBinarySerializer tagContextBinarySerializer = + new FakeTagContextBinarySerializer(); + private static final MeasureLong RETRIES_PER_CALL = + Measure.MeasureLong.create( + "grpc.io/client/retries_per_call", "Number of retries per call", "1"); + private static final MeasureLong TRANSPARENT_RETRIES_PER_CALL = + Measure.MeasureLong.create( + "grpc.io/client/transparent_retries_per_call", "Transparent retries per call", "1"); + private static final MeasureDouble RETRY_DELAY_PER_CALL = + Measure.MeasureDouble.create( + "grpc.io/client/retry_delay_per_call", "Retry delay per call", "ms"); + @Rule public final MockitoRule mocks = MockitoJUnit.rule(); @Rule public final GrpcCleanupRule cleanupRule = new GrpcCleanupRule(); + private final FakeClock fakeClock = new FakeClock(); @Mock private ClientCall.Listener mockCallListener; - - @Test - public void retryUntilBufferLimitExceeded() throws Exception { - String message = "String of length 20."; - int bufferLimit = message.length() * 2 - 1; // Can buffer no more than 1 message. - - MethodDescriptor clientStreamingMethod = - MethodDescriptor.newBuilder() - .setType(MethodType.CLIENT_STREAMING) - .setFullMethodName("service/method") - .setRequestMarshaller(new StringMarshaller()) - .setResponseMarshaller(new IntegerMarshaller()) - .build(); - final LinkedBlockingQueue> serverCalls = - new LinkedBlockingQueue<>(); - ServerMethodDefinition methodDefinition = ServerMethodDefinition.create( - clientStreamingMethod, - new ServerCallHandler() { - @Override - public Listener startCall(ServerCall call, Metadata headers) { - serverCalls.offer(call); - return new Listener() {}; + private CountDownLatch backoffLatch = new CountDownLatch(1); + private final EventLoopGroup group = new DefaultEventLoopGroup() { + @SuppressWarnings("FutureReturnValueIgnored") + @Override + public ScheduledFuture schedule( + final Runnable command, final long delay, final TimeUnit unit) { + if (!command.getClass().getName().contains("RetryBackoffRunnable")) { + return super.schedule(command, delay, unit); + } + fakeClock.getScheduledExecutorService().schedule( + new Runnable() { + @Override + public void run() { + group.execute(command); + } + }, + delay, + unit); + backoffLatch.countDown(); + return super.schedule( + new Runnable() { + @Override + public void run() {} // no-op + }, + 0, + TimeUnit.NANOSECONDS); + } + }; + private final FakeStatsRecorder clientStatsRecorder = new FakeStatsRecorder(); + private final ClientInterceptor statsInterceptor = + InternalCensusStatsAccessor.getClientInterceptor( + tagger, tagContextBinarySerializer, clientStatsRecorder, + fakeClock.getStopwatchSupplier(), true, true, true, + /* recordRealTimeMetrics= */ true); + private final MethodDescriptor clientStreamingMethod = + MethodDescriptor.newBuilder() + .setType(MethodType.CLIENT_STREAMING) + .setFullMethodName("service/method") + .setRequestMarshaller(new StringMarshaller()) + .setResponseMarshaller(new IntegerMarshaller()) + .build(); + private final LinkedBlockingQueue> serverCalls = + new LinkedBlockingQueue<>(); + private final ServerMethodDefinition methodDefinition = + ServerMethodDefinition.create( + clientStreamingMethod, + new ServerCallHandler() { + @Override + public Listener startCall(ServerCall call, Metadata headers) { + serverCalls.offer(call); + return new Listener() {}; + } } - } - ); - ServerServiceDefinition serviceDefinition = - ServerServiceDefinition.builder(clientStreamingMethod.getServiceName()) - .addMethod(methodDefinition) - .build(); - EventLoopGroup group = new DefaultEventLoopGroup(); - LocalAddress localAddress = new LocalAddress("RetryTest.retryUntilBufferLimitExceeded"); - Server localServer = cleanupRule.register(NettyServerBuilder.forAddress(localAddress) + ); + private final ServerServiceDefinition serviceDefinition = + ServerServiceDefinition.builder(clientStreamingMethod.getServiceName()) + .addMethod(methodDefinition) + .build(); + private final LocalAddress localAddress = new LocalAddress(this.getClass().getName()); + private Server localServer; + private ManagedChannel channel; + private Map retryPolicy = null; + private long bufferLimit = 1L << 20; // 1M + + private void startNewServer() throws Exception { + localServer = cleanupRule.register(NettyServerBuilder.forAddress(localAddress) .channelType(LocalServerChannel.class) .bossEventLoopGroup(group) .workerEventLoopGroup(group) .addService(serviceDefinition) .build()); localServer.start(); + } - Map retryPolicy = new HashMap<>(); - retryPolicy.put("maxAttempts", 4D); - retryPolicy.put("initialBackoff", "10s"); - retryPolicy.put("maxBackoff", "10s"); - retryPolicy.put("backoffMultiplier", 1D); - retryPolicy.put("retryableStatusCodes", Arrays.asList("UNAVAILABLE")); + private void createNewChannel() { Map methodConfig = new HashMap<>(); Map name = new HashMap<>(); name.put("service", "service"); methodConfig.put("name", Arrays.asList(name)); - methodConfig.put("retryPolicy", retryPolicy); + if (retryPolicy != null) { + methodConfig.put("retryPolicy", retryPolicy); + } Map rawServiceConfig = new HashMap<>(); rawServiceConfig.put("methodConfig", Arrays.asList(methodConfig)); - ManagedChannel channel = cleanupRule.register( + channel = cleanupRule.register( NettyChannelBuilder.forAddress(localAddress) .channelType(LocalChannel.class) .eventLoopGroup(group) @@ -127,23 +198,100 @@ public Listener startCall(ServerCall call, Metadata hea .enableRetry() .perRpcBufferLimit(bufferLimit) .defaultServiceConfig(rawServiceConfig) + .intercept(statsInterceptor) .build()); + } + + private void elapseBackoff(long time, TimeUnit unit) throws Exception { + assertThat(backoffLatch.await(5, SECONDS)).isTrue(); + backoffLatch = new CountDownLatch(1); + fakeClock.forwardTime(time, unit); + } + + private void assertRpcStartedRecorded() throws Exception { + MetricsRecord record = clientStatsRecorder.pollRecord(5, SECONDS); + assertThat(record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_STARTED_COUNT)) + .isEqualTo(1); + } + + private void assertOutboundMessageRecorded() throws Exception { + MetricsRecord record = clientStatsRecorder.pollRecord(5, SECONDS); + assertThat( + record.getMetricAsLongOrFail( + RpcMeasureConstants.GRPC_CLIENT_SENT_MESSAGES_PER_METHOD)) + .isEqualTo(1); + } + + private void assertInboundMessageRecorded() throws Exception { + MetricsRecord record = clientStatsRecorder.pollRecord(5, SECONDS); + assertThat( + record.getMetricAsLongOrFail( + RpcMeasureConstants.GRPC_CLIENT_RECEIVED_MESSAGES_PER_METHOD)) + .isEqualTo(1); + } + + private void assertOutboundWireSizeRecorded(long length) throws Exception { + MetricsRecord record = clientStatsRecorder.pollRecord(5, SECONDS); + assertThat(record.getMetricAsLongOrFail(RpcMeasureConstants.GRPC_CLIENT_SENT_BYTES_PER_METHOD)) + .isEqualTo(length); + } + + private void assertInboundWireSizeRecorded(long length) throws Exception { + MetricsRecord record = clientStatsRecorder.pollRecord(5, SECONDS); + assertThat( + record.getMetricAsLongOrFail(RpcMeasureConstants.GRPC_CLIENT_RECEIVED_BYTES_PER_METHOD)) + .isEqualTo(length); + } + + private void assertRpcStatusRecorded( + Status.Code code, long roundtripLatencyMs, long outboundMessages) throws Exception { + MetricsRecord record = clientStatsRecorder.pollRecord(5, SECONDS); + TagValue statusTag = record.tags.get(RpcMeasureConstants.GRPC_CLIENT_STATUS); + assertThat(statusTag.asString()).isEqualTo(code.toString()); + assertThat(record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_FINISHED_COUNT)) + .isEqualTo(1); + assertThat(record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_ROUNDTRIP_LATENCY)) + .isEqualTo(roundtripLatencyMs); + assertThat(record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_REQUEST_COUNT)) + .isEqualTo(outboundMessages); + } + + private void assertRetryStatsRecorded( + int numRetries, int numTransparentRetries, long retryDelayMs) throws Exception { + MetricsRecord record = clientStatsRecorder.pollRecord(5, SECONDS); + assertThat(record.getMetricAsLongOrFail(RETRIES_PER_CALL)).isEqualTo(numRetries); + assertThat(record.getMetricAsLongOrFail(TRANSPARENT_RETRIES_PER_CALL)) + .isEqualTo(numTransparentRetries); + assertThat(record.getMetricAsLongOrFail(RETRY_DELAY_PER_CALL)).isEqualTo(retryDelayMs); + } + + @Test + public void retryUntilBufferLimitExceeded() throws Exception { + String message = "String of length 20."; + + startNewServer(); + bufferLimit = message.length() * 2L - 1; // Can buffer no more than 1 message. + retryPolicy = ImmutableMap.builder() + .put("maxAttempts", 4D) + .put("initialBackoff", "10s") + .put("maxBackoff", "10s") + .put("backoffMultiplier", 1D) + .put("retryableStatusCodes", Arrays.asList("UNAVAILABLE")) + .build(); + createNewChannel(); ClientCall call = channel.newCall(clientStreamingMethod, CallOptions.DEFAULT); call.start(mockCallListener, new Metadata()); call.sendMessage(message); - ServerCall serverCall = serverCalls.poll(5, TimeUnit.SECONDS); + ServerCall serverCall = serverCalls.poll(5, SECONDS); serverCall.request(2); // trigger retry - Metadata pushBackMetadata = new Metadata(); - pushBackMetadata.put( - Metadata.Key.of("grpc-retry-pushback-ms", Metadata.ASCII_STRING_MARSHALLER), - "0"); // retry immediately serverCall.close( Status.UNAVAILABLE.withDescription("original attempt failed"), - pushBackMetadata); + new Metadata()); + elapseBackoff(10, SECONDS); // 2nd attempt received - serverCall = serverCalls.poll(5, TimeUnit.SECONDS); + serverCall = serverCalls.poll(5, SECONDS); serverCall.request(2); verify(mockCallListener, never()).onClose(any(Status.class), any(Metadata.class)); // send one more message, should exceed buffer limit @@ -157,4 +305,146 @@ public Listener startCall(ServerCall call, Metadata hea verify(mockCallListener, timeout(5000)).onClose(statusCaptor.capture(), any(Metadata.class)); assertThat(statusCaptor.getValue().getDescription()).contains("2nd attempt failed"); } + + @Test + public void statsRecorded() throws Exception { + startNewServer(); + retryPolicy = ImmutableMap.builder() + .put("maxAttempts", 4D) + .put("initialBackoff", "10s") + .put("maxBackoff", "10s") + .put("backoffMultiplier", 1D) + .put("retryableStatusCodes", Arrays.asList("UNAVAILABLE")) + .build(); + createNewChannel(); + + ClientCall call = channel.newCall(clientStreamingMethod, CallOptions.DEFAULT); + call.start(mockCallListener, new Metadata()); + assertRpcStartedRecorded(); + String message = "String of length 20."; + call.sendMessage(message); + assertOutboundMessageRecorded(); + ServerCall serverCall = serverCalls.poll(5, SECONDS); + serverCall.request(2); + assertOutboundWireSizeRecorded(message.length()); + // original attempt latency + fakeClock.forwardTime(1, SECONDS); + // trigger retry + serverCall.close( + Status.UNAVAILABLE.withDescription("original attempt failed"), + new Metadata()); + assertRpcStatusRecorded(Status.Code.UNAVAILABLE, 1000, 1); + elapseBackoff(10, SECONDS); + assertRpcStartedRecorded(); + assertOutboundMessageRecorded(); + serverCall = serverCalls.poll(5, SECONDS); + serverCall.request(2); + assertOutboundWireSizeRecorded(message.length()); + message = "new message"; + call.sendMessage(message); + assertOutboundMessageRecorded(); + assertOutboundWireSizeRecorded(message.length()); + // retry attempt latency + fakeClock.forwardTime(2, SECONDS); + serverCall.sendHeaders(new Metadata()); + serverCall.sendMessage(3); + call.request(1); + assertInboundMessageRecorded(); + assertInboundWireSizeRecorded(1); + serverCall.close(Status.OK, new Metadata()); + assertRpcStatusRecorded(Status.Code.OK, 2000, 2); + assertRetryStatsRecorded(1, 0, 10_000); + } + + @Test + public void serverCancelledAndClientDeadlineExceeded() throws Exception { + startNewServer(); + createNewChannel(); + + class CloseDelayedTracer extends ClientStreamTracer { + @Override + public void streamClosed(Status status) { + fakeClock.forwardTime(10, SECONDS); + } + } + + class CloseDelayedTracerFactory extends ClientStreamTracer.InternalLimitedInfoFactory { + @Override + public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata headers) { + return new CloseDelayedTracer(); + } + } + + CallOptions callOptions = CallOptions.DEFAULT + .withDeadline(Deadline.after( + 10, + SECONDS, + new Ticker() { + @Override + public long nanoTime() { + return fakeClock.getTicker().read(); + } + })) + .withStreamTracerFactory(new CloseDelayedTracerFactory()); + ClientCall call = channel.newCall(clientStreamingMethod, callOptions); + call.start(mockCallListener, new Metadata()); + assertRpcStartedRecorded(); + ServerCall serverCall = serverCalls.poll(5, SECONDS); + serverCall.close(Status.CANCELLED, new Metadata()); + assertRpcStatusRecorded(Code.DEADLINE_EXCEEDED, 10_000, 0); + assertRetryStatsRecorded(0, 0, 0); + } + + @Ignore("flaky because old transportReportStatus() is not completely migrated yet") + @Test + public void transparentRetryStatsRecorded() throws Exception { + startNewServer(); + createNewChannel(); + + final AtomicBoolean transparentRetryTriggered = new AtomicBoolean(); + class TransparentRetryTriggeringTracer extends ClientStreamTracer { + + @Override + public void streamCreated(Attributes transportAttrs, Metadata metadata) { + if (transparentRetryTriggered.get()) { + return; + } + localServer.shutdownNow(); + } + + @Override + public void streamClosed(Status status) { + if (transparentRetryTriggered.get()) { + return; + } + transparentRetryTriggered.set(true); + try { + startNewServer(); + channel.resetConnectBackoff(); + channel.getState(true); + } catch (Exception e) { + throw new AssertionError("local server can not be restarted", e); + } + } + } + + class TransparentRetryTracerFactory extends ClientStreamTracer.InternalLimitedInfoFactory { + @Override + public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata headers) { + return new TransparentRetryTriggeringTracer(); + } + } + + CallOptions callOptions = CallOptions.DEFAULT + .withWaitForReady() + .withStreamTracerFactory(new TransparentRetryTracerFactory()); + ClientCall call = channel.newCall(clientStreamingMethod, callOptions); + call.start(mockCallListener, new Metadata()); + assertRpcStartedRecorded(); + assertRpcStatusRecorded(Code.UNAVAILABLE, 0, 0); + assertRpcStartedRecorded(); + call.cancel("cancel", null); + assertRpcStatusRecorded(Code.CANCELLED, 0, 0); + assertRetryStatsRecorded(0, 1, 0); + } } From 21429023436699bee926101956d3dc2a7205d288 Mon Sep 17 00:00:00 2001 From: ZHANG Dapeng Date: Wed, 11 Aug 2021 10:25:57 -0700 Subject: [PATCH 25/82] core: fix retry flow control issue (#8401) There has been an issue about flow control when retry is enabled. Currently we call `masterListener.onReady()` whenever `substreamListener.onReady()` is called. The user's `onReady()` implementation might do ``` while(observer.isReady()) { // send one more message. } ``` However, currently if the `RetriableStream` is still draining, `isReady()` is false, and user's `onReady()` exits immediately. And because `substreamListener.onReady()` is already called, it may not be called again after drained. This PR fixes the issue by - Use a SerializeExecutor to call all `masterListener` callbacks. - Once `RetriableStream` is drained, check `isReady()` and if so call `onReady()`. - Once `substreamListener.onReady()` is called, check `isReady()` and only if so we call `masterListener.onReady()`. --- .../io/grpc/internal/RetriableStream.java | 100 +++++++++++++++--- .../io/grpc/internal/RetriableStreamTest.java | 88 ++++++++++++++- 2 files changed, 170 insertions(+), 18 deletions(-) diff --git a/core/src/main/java/io/grpc/internal/RetriableStream.java b/core/src/main/java/io/grpc/internal/RetriableStream.java index 3d277bbe2fc..1fb8d3c43bd 100644 --- a/core/src/main/java/io/grpc/internal/RetriableStream.java +++ b/core/src/main/java/io/grpc/internal/RetriableStream.java @@ -30,8 +30,10 @@ import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.Status; +import io.grpc.SynchronizationContext; import io.grpc.internal.ClientStreamListener.RpcProgress; import java.io.InputStream; +import java.lang.Thread.UncaughtExceptionHandler; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; @@ -64,6 +66,16 @@ abstract class RetriableStream implements ClientStream { private final MethodDescriptor method; private final Executor callExecutor; + private final Executor listenerSerializeExecutor = new SynchronizationContext( + new UncaughtExceptionHandler() { + @Override + public void uncaughtException(Thread t, Throwable e) { + throw Status.fromThrowable(e) + .withDescription("Uncaught exception in the SynchronizationContext. Re-thrown.") + .asRuntimeException(); + } + } + ); private final ScheduledExecutorService scheduledExecutorService; // Must not modify it. private final Metadata headers; @@ -105,6 +117,7 @@ abstract class RetriableStream implements ClientStream { private FutureCanceller scheduledHedging; private long nextBackoffIntervalNanos; private Status cancellationStatus; + private boolean isClosed; RetriableStream( MethodDescriptor method, Metadata headers, @@ -247,6 +260,7 @@ private void drain(Substream substream) { int chunk = 0x80; List list = null; boolean streamStarted = false; + Runnable onReadyRunnable = null; while (true) { State savedState; @@ -264,7 +278,18 @@ private void drain(Substream substream) { } if (index == savedState.buffer.size()) { // I'm drained state = savedState.substreamDrained(substream); - return; + if (!isReady()) { + return; + } + onReadyRunnable = new Runnable() { + @Override + public void run() { + if (!isClosed) { + masterListener.onReady(); + } + } + }; + break; } if (substream.closed) { @@ -299,6 +324,11 @@ private void drain(Substream substream) { } } + if (onReadyRunnable != null) { + listenerSerializeExecutor.execute(onReadyRunnable); + return; + } + substream.stream.cancel( state.winningSubstream == substream ? cancellationStatus : CANCELLED_BECAUSE_COMMITTED); } @@ -450,14 +480,22 @@ public void run() { } @Override - public final void cancel(Status reason) { + public final void cancel(final Status reason) { Substream noopSubstream = new Substream(0 /* previousAttempts doesn't matter here */); noopSubstream.stream = new NoopClientStream(); Runnable runnable = commit(noopSubstream); if (runnable != null) { - masterListener.closed(reason, RpcProgress.PROCESSED, new Metadata()); runnable.run(); + listenerSerializeExecutor.execute( + new Runnable() { + @Override + public void run() { + isClosed = true; + masterListener.closed(reason, RpcProgress.PROCESSED, new Metadata()); + + } + }); return; } @@ -771,18 +809,25 @@ private final class Sublistener implements ClientStreamListener { } @Override - public void headersRead(Metadata headers) { + public void headersRead(final Metadata headers) { commitAndRun(substream); if (state.winningSubstream == substream) { - masterListener.headersRead(headers); if (throttle != null) { throttle.onSuccess(); } + listenerSerializeExecutor.execute( + new Runnable() { + @Override + public void run() { + masterListener.headersRead(headers); + } + }); } } @Override - public void closed(Status status, RpcProgress rpcProgress, Metadata trailers) { + public void closed( + final Status status, final RpcProgress rpcProgress, final Metadata trailers) { synchronized (lock) { state = state.substreamClosed(substream); closedSubstreamsInsight.append(status.getCode()); @@ -793,7 +838,14 @@ public void closed(Status status, RpcProgress rpcProgress, Metadata trailers) { if (substream.bufferLimitExceeded) { commitAndRun(substream); if (state.winningSubstream == substream) { - masterListener.closed(status, rpcProgress, trailers); + listenerSerializeExecutor.execute( + new Runnable() { + @Override + public void run() { + isClosed = true; + masterListener.closed(status, rpcProgress, trailers); + } + }); } return; } @@ -900,7 +952,14 @@ public void run() { commitAndRun(substream); if (state.winningSubstream == substream) { - masterListener.closed(status, rpcProgress, trailers); + listenerSerializeExecutor.execute( + new Runnable() { + @Override + public void run() { + isClosed = true; + masterListener.closed(status, rpcProgress, trailers); + } + }); } } @@ -970,22 +1029,37 @@ private Integer getPushbackMills(Metadata trailer) { } @Override - public void messagesAvailable(MessageProducer producer) { + public void messagesAvailable(final MessageProducer producer) { State savedState = state; checkState( savedState.winningSubstream != null, "Headers should be received prior to messages."); if (savedState.winningSubstream != substream) { return; } - masterListener.messagesAvailable(producer); + listenerSerializeExecutor.execute( + new Runnable() { + @Override + public void run() { + masterListener.messagesAvailable(producer); + } + }); } @Override public void onReady() { // FIXME(#7089): hedging case is broken. - // TODO(zdapeng): optimization: if the substream is not drained yet, delay onReady() once - // drained and if is still ready. - masterListener.onReady(); + if (!isReady()) { + return; + } + listenerSerializeExecutor.execute( + new Runnable() { + @Override + public void run() { + if (!isClosed) { + masterListener.onReady(); + } + } + }); } } diff --git a/core/src/test/java/io/grpc/internal/RetriableStreamTest.java b/core/src/test/java/io/grpc/internal/RetriableStreamTest.java index c9ea504e18b..8b851573b21 100644 --- a/core/src/test/java/io/grpc/internal/RetriableStreamTest.java +++ b/core/src/test/java/io/grpc/internal/RetriableStreamTest.java @@ -256,6 +256,7 @@ public Void answer(InvocationOnMock in) { ArgumentCaptor sublistenerCaptor1 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream1).start(sublistenerCaptor1.capture()); + inOrder.verify(mockStream1).isReady(); inOrder.verifyNoMoreInteractions(); retriableStream.sendMessage("msg1"); @@ -308,6 +309,7 @@ public Void answer(InvocationOnMock in) { inOrder.verify(mockStream2).writeMessage(any(InputStream.class)); inOrder.verify(mockStream2).request(456); inOrder.verify(mockStream2, times(2)).writeMessage(any(InputStream.class)); + inOrder.verify(mockStream2).isReady(); inOrder.verifyNoMoreInteractions(); // send more messages @@ -356,6 +358,7 @@ public Void answer(InvocationOnMock in) { inOrder.verify(mockStream3).writeMessage(any(InputStream.class)); inOrder.verify(mockStream3).request(456); inOrder.verify(mockStream3, times(7)).writeMessage(any(InputStream.class)); + inOrder.verify(mockStream3).isReady(); inOrder.verifyNoMoreInteractions(); InsightBuilder insight = new InsightBuilder(); @@ -637,6 +640,7 @@ public void retry_cancelWhileBackoff() { ArgumentCaptor sublistenerCaptor1 = ArgumentCaptor.forClass(ClientStreamListener.class); verify(mockStream1).start(sublistenerCaptor1.capture()); + verify(mockStream1).isReady(); // retry ClientStream mockStream2 = mock(ClientStream.class); @@ -656,7 +660,7 @@ public void retry_cancelWhileBackoff() { @Test public void operationsWhileDraining() { - ArgumentCaptor sublistenerCaptor1 = + final ArgumentCaptor sublistenerCaptor1 = ArgumentCaptor.forClass(ClientStreamListener.class); final AtomicReference sublistenerCaptor2 = new AtomicReference<>(); @@ -669,10 +673,16 @@ public void operationsWhileDraining() { @Override public void request(int numMessages) { retriableStream.sendMessage("substream1 request " + numMessages); + sublistenerCaptor1.getValue().onReady(); if (numMessages > 1) { retriableStream.request(--numMessages); } } + + @Override + public boolean isReady() { + return true; + } })); final ClientStream mockStream2 = @@ -688,7 +698,7 @@ public void start(ClientStreamListener listener) { @Override public void request(int numMessages) { retriableStream.sendMessage("substream2 request " + numMessages); - + sublistenerCaptor2.get().onReady(); if (numMessages == 3) { sublistenerCaptor2.get().headersRead(new Metadata()); } @@ -699,9 +709,14 @@ public void request(int numMessages) { retriableStream.cancel(cancelStatus); } } + + @Override + public boolean isReady() { + return true; + } })); - InOrder inOrder = inOrder(retriableStreamRecorder, mockStream1, mockStream2); + InOrder inOrder = inOrder(retriableStreamRecorder, mockStream1, mockStream2, masterListener); doReturn(mockStream1).when(retriableStreamRecorder).newSubstream(0); retriableStream.start(masterListener); @@ -716,6 +731,7 @@ public void request(int numMessages) { inOrder.verify(mockStream1).writeMessage(any(InputStream.class)); // msg "substream1 request 2" inOrder.verify(mockStream1).request(1); inOrder.verify(mockStream1).writeMessage(any(InputStream.class)); // msg "substream1 request 1" + inOrder.verify(masterListener).onReady(); // retry doReturn(mockStream2).when(retriableStreamRecorder).newSubstream(1); @@ -743,8 +759,8 @@ public void request(int numMessages) { // msg "substream2 request 2" inOrder.verify(mockStream2, times(2)).writeMessage(any(InputStream.class)); inOrder.verify(mockStream2).request(100); - - verify(mockStream2).cancel(cancelStatus); + inOrder.verify(mockStream2).cancel(cancelStatus); + inOrder.verify(masterListener, never()).onReady(); // "substream2 request 1" will never be sent inOrder.verify(mockStream2, never()).writeMessage(any(InputStream.class)); @@ -1073,6 +1089,7 @@ public void perRpcBufferLimitExceededDuringBackoff() { ArgumentCaptor sublistenerCaptor1 = ArgumentCaptor.forClass(ClientStreamListener.class); verify(mockStream1).start(sublistenerCaptor1.capture()); + verify(mockStream1).isReady(); bufferSizeTracer.outboundWireSize(PER_RPC_BUFFER_LIMIT - 1); @@ -1089,6 +1106,7 @@ public void perRpcBufferLimitExceededDuringBackoff() { fakeClock.forwardTime((long) (INITIAL_BACKOFF_IN_SECONDS * FAKE_RANDOM), TimeUnit.SECONDS); verify(mockStream2).start(any(ClientStreamListener.class)); + verify(mockStream2).isReady(); // bufferLimitExceeded bufferSizeTracer.outboundWireSize(PER_RPC_BUFFER_LIMIT - 1); @@ -1152,6 +1170,7 @@ public void expBackoff_maxBackoff_maxRetryAttempts() { ArgumentCaptor sublistenerCaptor1 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream1).start(sublistenerCaptor1.capture()); + inOrder.verify(mockStream1).isReady(); inOrder.verifyNoMoreInteractions(); @@ -1167,6 +1186,7 @@ public void expBackoff_maxBackoff_maxRetryAttempts() { ArgumentCaptor sublistenerCaptor2 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream2).start(sublistenerCaptor2.capture()); + inOrder.verify(mockStream2).isReady(); inOrder.verifyNoMoreInteractions(); // retry2 @@ -1183,6 +1203,7 @@ public void expBackoff_maxBackoff_maxRetryAttempts() { ArgumentCaptor sublistenerCaptor3 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream3).start(sublistenerCaptor3.capture()); + inOrder.verify(mockStream3).isReady(); inOrder.verifyNoMoreInteractions(); // retry3 @@ -1200,6 +1221,7 @@ public void expBackoff_maxBackoff_maxRetryAttempts() { ArgumentCaptor sublistenerCaptor4 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream4).start(sublistenerCaptor4.capture()); + inOrder.verify(mockStream4).isReady(); inOrder.verifyNoMoreInteractions(); // retry4 @@ -1214,6 +1236,7 @@ public void expBackoff_maxBackoff_maxRetryAttempts() { ArgumentCaptor sublistenerCaptor5 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream5).start(sublistenerCaptor5.capture()); + inOrder.verify(mockStream5).isReady(); inOrder.verifyNoMoreInteractions(); // retry5 @@ -1228,6 +1251,7 @@ public void expBackoff_maxBackoff_maxRetryAttempts() { ArgumentCaptor sublistenerCaptor6 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream6).start(sublistenerCaptor6.capture()); + inOrder.verify(mockStream6).isReady(); inOrder.verifyNoMoreInteractions(); // can not retry any more @@ -1258,6 +1282,7 @@ public void pushback() { ArgumentCaptor sublistenerCaptor1 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream1).start(sublistenerCaptor1.capture()); + inOrder.verify(mockStream1).isReady(); inOrder.verifyNoMoreInteractions(); @@ -1276,6 +1301,7 @@ public void pushback() { ArgumentCaptor sublistenerCaptor2 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream2).start(sublistenerCaptor2.capture()); + inOrder.verify(mockStream2).isReady(); inOrder.verifyNoMoreInteractions(); // retry2 @@ -1293,6 +1319,7 @@ public void pushback() { ArgumentCaptor sublistenerCaptor3 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream3).start(sublistenerCaptor3.capture()); + inOrder.verify(mockStream3).isReady(); inOrder.verifyNoMoreInteractions(); // retry3 @@ -1307,6 +1334,7 @@ public void pushback() { ArgumentCaptor sublistenerCaptor4 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream4).start(sublistenerCaptor4.capture()); + inOrder.verify(mockStream4).isReady(); inOrder.verifyNoMoreInteractions(); // retry4 @@ -1323,6 +1351,7 @@ public void pushback() { ArgumentCaptor sublistenerCaptor5 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream5).start(sublistenerCaptor5.capture()); + inOrder.verify(mockStream5).isReady(); inOrder.verifyNoMoreInteractions(); // retry5 @@ -1340,6 +1369,7 @@ public void pushback() { ArgumentCaptor sublistenerCaptor6 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream6).start(sublistenerCaptor6.capture()); + inOrder.verify(mockStream6).isReady(); inOrder.verifyNoMoreInteractions(); // can not retry any more even pushback is positive @@ -1597,6 +1627,7 @@ public void transparentRetry() { ArgumentCaptor sublistenerCaptor1 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream1).start(sublistenerCaptor1.capture()); + inOrder.verify(mockStream1).isReady(); inOrder.verifyNoMoreInteractions(); // transparent retry @@ -1608,6 +1639,7 @@ public void transparentRetry() { ArgumentCaptor sublistenerCaptor2 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream2).start(sublistenerCaptor2.capture()); + inOrder.verify(mockStream2).isReady(); inOrder.verifyNoMoreInteractions(); verify(retriableStreamRecorder, never()).postCommit(); assertEquals(0, fakeClock.numPendingTasks()); @@ -1623,6 +1655,7 @@ public void transparentRetry() { ArgumentCaptor sublistenerCaptor3 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream3).start(sublistenerCaptor3.capture()); + inOrder.verify(mockStream3).isReady(); inOrder.verifyNoMoreInteractions(); verify(retriableStreamRecorder, never()).postCommit(); assertEquals(0, fakeClock.numPendingTasks()); @@ -1645,6 +1678,7 @@ public void normalRetry_thenNoTransparentRetry_butNormalRetry() { ArgumentCaptor sublistenerCaptor1 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream1).start(sublistenerCaptor1.capture()); + inOrder.verify(mockStream1).isReady(); inOrder.verifyNoMoreInteractions(); // normal retry @@ -1658,6 +1692,7 @@ public void normalRetry_thenNoTransparentRetry_butNormalRetry() { ArgumentCaptor sublistenerCaptor2 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream2).start(sublistenerCaptor2.capture()); + inOrder.verify(mockStream2).isReady(); inOrder.verifyNoMoreInteractions(); verify(retriableStreamRecorder, never()).postCommit(); assertEquals(0, fakeClock.numPendingTasks()); @@ -1674,6 +1709,7 @@ public void normalRetry_thenNoTransparentRetry_butNormalRetry() { ArgumentCaptor sublistenerCaptor3 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream3).start(sublistenerCaptor3.capture()); + inOrder.verify(mockStream3).isReady(); inOrder.verifyNoMoreInteractions(); verify(retriableStreamRecorder, never()).postCommit(); } @@ -1695,6 +1731,7 @@ public void normalRetry_thenNoTransparentRetry_andNoMoreRetry() { ArgumentCaptor sublistenerCaptor1 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream1).start(sublistenerCaptor1.capture()); + inOrder.verify(mockStream1).isReady(); inOrder.verifyNoMoreInteractions(); // normal retry @@ -1708,6 +1745,7 @@ public void normalRetry_thenNoTransparentRetry_andNoMoreRetry() { ArgumentCaptor sublistenerCaptor2 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream2).start(sublistenerCaptor2.capture()); + inOrder.verify(mockStream2).isReady(); inOrder.verifyNoMoreInteractions(); verify(retriableStreamRecorder, never()).postCommit(); assertEquals(0, fakeClock.numPendingTasks()); @@ -1738,6 +1776,7 @@ method, new Metadata(), channelBufferUsed, PER_RPC_BUFFER_LIMIT, CHANNEL_BUFFER_ ArgumentCaptor sublistenerCaptor1 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream1).start(sublistenerCaptor1.capture()); + inOrder.verify(mockStream1).isReady(); inOrder.verifyNoMoreInteractions(); // transparent retry @@ -1750,6 +1789,7 @@ method, new Metadata(), channelBufferUsed, PER_RPC_BUFFER_LIMIT, CHANNEL_BUFFER_ ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(retriableStreamRecorder).postCommit(); inOrder.verify(mockStream2).start(sublistenerCaptor2.capture()); + inOrder.verify(mockStream2).isReady(); inOrder.verifyNoMoreInteractions(); assertEquals(0, fakeClock.numPendingTasks()); } @@ -1768,6 +1808,7 @@ public void droppedShouldNeverRetry() { ArgumentCaptor sublistenerCaptor1 = ArgumentCaptor.forClass(ClientStreamListener.class); verify(mockStream1).start(sublistenerCaptor1.capture()); + verify(mockStream1).isReady(); // drop and verify no retry Status status = Status.fromCode(RETRIABLE_STATUS_CODE_1); @@ -1839,6 +1880,7 @@ public Void answer(InvocationOnMock in) { ArgumentCaptor sublistenerCaptor1 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream1).start(sublistenerCaptor1.capture()); + inOrder.verify(mockStream1).isReady(); inOrder.verifyNoMoreInteractions(); hedgingStream.sendMessage("msg1"); @@ -1880,6 +1922,8 @@ public Void answer(InvocationOnMock in) { inOrder.verify(mockStream2, times(2)).flush(); inOrder.verify(mockStream2).writeMessage(any(InputStream.class)); inOrder.verify(mockStream2).request(456); + inOrder.verify(mockStream1).isReady(); + inOrder.verify(mockStream2).isReady(); inOrder.verifyNoMoreInteractions(); // send more messages @@ -1917,6 +1961,9 @@ public Void answer(InvocationOnMock in) { inOrder.verify(mockStream3).writeMessage(any(InputStream.class)); inOrder.verify(mockStream3).request(456); inOrder.verify(mockStream3, times(2)).writeMessage(any(InputStream.class)); + inOrder.verify(mockStream1).isReady(); + inOrder.verify(mockStream2).isReady(); + inOrder.verify(mockStream3).isReady(); inOrder.verifyNoMoreInteractions(); // send one more message @@ -1959,6 +2006,9 @@ public Void answer(InvocationOnMock in) { inOrder.verify(mockStream4).writeMessage(any(InputStream.class)); inOrder.verify(mockStream4).request(456); inOrder.verify(mockStream4, times(4)).writeMessage(any(InputStream.class)); + inOrder.verify(mockStream1).isReady(); + inOrder.verify(mockStream2).isReady(); + inOrder.verify(mockStream4).isReady(); inOrder.verifyNoMoreInteractions(); InsightBuilder insight = new InsightBuilder(); @@ -2009,6 +2059,7 @@ public void hedging_maxAttempts() { ArgumentCaptor sublistenerCaptor1 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream1).start(sublistenerCaptor1.capture()); + inOrder.verify(mockStream1).isReady(); inOrder.verifyNoMoreInteractions(); fakeClock.forwardTime(HEDGING_DELAY_IN_SECONDS, TimeUnit.SECONDS); @@ -2016,6 +2067,7 @@ public void hedging_maxAttempts() { ArgumentCaptor sublistenerCaptor2 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream2).start(sublistenerCaptor2.capture()); + inOrder.verify(mockStream2).isReady(); inOrder.verifyNoMoreInteractions(); fakeClock.forwardTime(HEDGING_DELAY_IN_SECONDS, TimeUnit.SECONDS); @@ -2023,6 +2075,7 @@ public void hedging_maxAttempts() { ArgumentCaptor sublistenerCaptor3 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream3).start(sublistenerCaptor3.capture()); + inOrder.verify(mockStream3).isReady(); inOrder.verifyNoMoreInteractions(); fakeClock.forwardTime(HEDGING_DELAY_IN_SECONDS, TimeUnit.SECONDS); @@ -2030,6 +2083,7 @@ public void hedging_maxAttempts() { ArgumentCaptor sublistenerCaptor4 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream4).start(sublistenerCaptor4.capture()); + inOrder.verify(mockStream4).isReady(); inOrder.verifyNoMoreInteractions(); // a random one of the hedges fails @@ -2041,6 +2095,7 @@ public void hedging_maxAttempts() { ArgumentCaptor sublistenerCaptor5 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream5).start(sublistenerCaptor5.capture()); + inOrder.verify(mockStream5).isReady(); inOrder.verifyNoMoreInteractions(); fakeClock.forwardTime(HEDGING_DELAY_IN_SECONDS, TimeUnit.SECONDS); @@ -2048,6 +2103,7 @@ public void hedging_maxAttempts() { ArgumentCaptor sublistenerCaptor6 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream6).start(sublistenerCaptor6.capture()); + inOrder.verify(mockStream6).isReady(); inOrder.verifyNoMoreInteractions(); fakeClock.forwardTime(HEDGING_DELAY_IN_SECONDS, TimeUnit.SECONDS); @@ -2092,6 +2148,7 @@ public void hedging_receiveHeaders() { ArgumentCaptor sublistenerCaptor1 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream1).start(sublistenerCaptor1.capture()); + inOrder.verify(mockStream1).isReady(); inOrder.verifyNoMoreInteractions(); fakeClock.forwardTime(HEDGING_DELAY_IN_SECONDS, TimeUnit.SECONDS); @@ -2099,6 +2156,7 @@ public void hedging_receiveHeaders() { ArgumentCaptor sublistenerCaptor2 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream2).start(sublistenerCaptor2.capture()); + inOrder.verify(mockStream2).isReady(); inOrder.verifyNoMoreInteractions(); fakeClock.forwardTime(HEDGING_DELAY_IN_SECONDS, TimeUnit.SECONDS); @@ -2106,6 +2164,7 @@ public void hedging_receiveHeaders() { ArgumentCaptor sublistenerCaptor3 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream3).start(sublistenerCaptor3.capture()); + inOrder.verify(mockStream3).isReady(); inOrder.verifyNoMoreInteractions(); // a random one of the hedges receives headers @@ -2143,6 +2202,7 @@ public void hedging_pushback_negative() { ArgumentCaptor sublistenerCaptor1 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream1).start(sublistenerCaptor1.capture()); + inOrder.verify(mockStream1).isReady(); inOrder.verifyNoMoreInteractions(); fakeClock.forwardTime(HEDGING_DELAY_IN_SECONDS, TimeUnit.SECONDS); @@ -2150,6 +2210,7 @@ public void hedging_pushback_negative() { ArgumentCaptor sublistenerCaptor2 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream2).start(sublistenerCaptor2.capture()); + inOrder.verify(mockStream2).isReady(); inOrder.verifyNoMoreInteractions(); fakeClock.forwardTime(HEDGING_DELAY_IN_SECONDS, TimeUnit.SECONDS); @@ -2157,6 +2218,7 @@ public void hedging_pushback_negative() { ArgumentCaptor sublistenerCaptor3 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream3).start(sublistenerCaptor3.capture()); + inOrder.verify(mockStream3).isReady(); inOrder.verifyNoMoreInteractions(); // a random one of the hedges receives a negative pushback @@ -2188,6 +2250,7 @@ public void hedging_pushback_positive() { ArgumentCaptor sublistenerCaptor1 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream1).start(sublistenerCaptor1.capture()); + inOrder.verify(mockStream1).isReady(); inOrder.verifyNoMoreInteractions(); fakeClock.forwardTime(HEDGING_DELAY_IN_SECONDS, TimeUnit.SECONDS); @@ -2195,6 +2258,7 @@ public void hedging_pushback_positive() { ArgumentCaptor sublistenerCaptor2 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream2).start(sublistenerCaptor2.capture()); + inOrder.verify(mockStream2).isReady(); inOrder.verifyNoMoreInteractions(); @@ -2212,6 +2276,7 @@ public void hedging_pushback_positive() { ArgumentCaptor sublistenerCaptor3 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream3).start(sublistenerCaptor3.capture()); + inOrder.verify(mockStream3).isReady(); inOrder.verifyNoMoreInteractions(); // hedge2 receives a pushback for HEDGING_DELAY_IN_SECONDS - 1 second @@ -2225,6 +2290,7 @@ public void hedging_pushback_positive() { ArgumentCaptor sublistenerCaptor4 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream4).start(sublistenerCaptor4.capture()); + inOrder.verify(mockStream4).isReady(); inOrder.verifyNoMoreInteractions(); // commit @@ -2254,6 +2320,7 @@ public void hedging_cancelled() { ArgumentCaptor sublistenerCaptor1 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream1).start(sublistenerCaptor1.capture()); + inOrder.verify(mockStream1).isReady(); inOrder.verifyNoMoreInteractions(); fakeClock.forwardTime(HEDGING_DELAY_IN_SECONDS, TimeUnit.SECONDS); @@ -2261,6 +2328,8 @@ public void hedging_cancelled() { ArgumentCaptor sublistenerCaptor2 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream2).start(sublistenerCaptor2.capture()); + inOrder.verify(mockStream1).isReady(); + inOrder.verify(mockStream2).isReady(); inOrder.verifyNoMoreInteractions(); Status status = Status.CANCELLED.withDescription("cancelled"); @@ -2275,6 +2344,8 @@ public void hedging_cancelled() { assertEquals(CANCELLED_BECAUSE_COMMITTED, statusCaptor.getValue().getDescription()); inOrder.verify(retriableStreamRecorder).postCommit(); + inOrder.verify(masterListener).closed( + any(Status.class), any(RpcProgress.class), any(Metadata.class)); inOrder.verifyNoMoreInteractions(); } @@ -2289,6 +2360,7 @@ public void hedging_perRpcBufferLimitExceeded() { ArgumentCaptor sublistenerCaptor1 = ArgumentCaptor.forClass(ClientStreamListener.class); verify(mockStream1).start(sublistenerCaptor1.capture()); + verify(mockStream1).isReady(); ClientStreamTracer bufferSizeTracer1 = bufferSizeTracer; bufferSizeTracer1.outboundWireSize(PER_RPC_BUFFER_LIMIT - 1); @@ -2297,6 +2369,8 @@ public void hedging_perRpcBufferLimitExceeded() { ArgumentCaptor sublistenerCaptor2 = ArgumentCaptor.forClass(ClientStreamListener.class); verify(mockStream2).start(sublistenerCaptor2.capture()); + verify(mockStream1, times(2)).isReady(); + verify(mockStream2).isReady(); ClientStreamTracer bufferSizeTracer2 = bufferSizeTracer; bufferSizeTracer2.outboundWireSize(PER_RPC_BUFFER_LIMIT - 1); @@ -2313,6 +2387,7 @@ public void hedging_perRpcBufferLimitExceeded() { verify(retriableStreamRecorder).postCommit(); verifyNoMoreInteractions(mockStream1); + verify(mockStream2).isReady(); verifyNoMoreInteractions(mockStream2); } @@ -2327,6 +2402,7 @@ public void hedging_channelBufferLimitExceeded() { ArgumentCaptor sublistenerCaptor1 = ArgumentCaptor.forClass(ClientStreamListener.class); verify(mockStream1).start(sublistenerCaptor1.capture()); + verify(mockStream1).isReady(); ClientStreamTracer bufferSizeTracer1 = bufferSizeTracer; bufferSizeTracer1.outboundWireSize(100); @@ -2335,6 +2411,8 @@ public void hedging_channelBufferLimitExceeded() { ArgumentCaptor sublistenerCaptor2 = ArgumentCaptor.forClass(ClientStreamListener.class); verify(mockStream2).start(sublistenerCaptor2.capture()); + verify(mockStream1, times(2)).isReady(); + verify(mockStream2).isReady(); ClientStreamTracer bufferSizeTracer2 = bufferSizeTracer; bufferSizeTracer2.outboundWireSize(100); From 2a636420ef649d8dab22906decee2c5892ad08fd Mon Sep 17 00:00:00 2001 From: Lidi Zheng Date: Wed, 11 Aug 2021 14:01:21 -0700 Subject: [PATCH 26/82] Update xDS client/server image per-branch tag after build (#8400) --- buildscripts/kokoro/xds-k8s.sh | 8 +++++++- buildscripts/xds-k8s/cloudbuild.yaml | 4 ++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/buildscripts/kokoro/xds-k8s.sh b/buildscripts/kokoro/xds-k8s.sh index cafd884ccaf..0a234f2c6ef 100755 --- a/buildscripts/kokoro/xds-k8s.sh +++ b/buildscripts/kokoro/xds-k8s.sh @@ -54,10 +54,16 @@ build_test_app_docker_images() { cp -v "${docker_dir}/"*.Dockerfile "${build_dir}" cp -v "${docker_dir}/"*.properties "${build_dir}" cp -rv "${SRC_DIR}/${BUILD_APP_PATH}" "${build_dir}" + # Pick a branch name for the built image + if [[ -n $KOKORO_JOB_NAME ]]; then + branch_name=$(echo "$KOKORO_JOB_NAME" | sed -E 's|^grpc/java/([^/]+)/.*|\1|') + else + branch_name='experimental' + fi # Run Google Cloud Build gcloud builds submit "${build_dir}" \ --config "${docker_dir}/cloudbuild.yaml" \ - --substitutions "_SERVER_IMAGE_NAME=${SERVER_IMAGE_NAME},_CLIENT_IMAGE_NAME=${CLIENT_IMAGE_NAME},COMMIT_SHA=${GIT_COMMIT}" + --substitutions "_SERVER_IMAGE_NAME=${SERVER_IMAGE_NAME},_CLIENT_IMAGE_NAME=${CLIENT_IMAGE_NAME},COMMIT_SHA=${GIT_COMMIT},BRANCH_NAME=${branch_name}" # TODO(sergiitk): extra "cosmetic" tags for versioned branches, e.g. v1.34.x # TODO(sergiitk): do this when adding support for custom configs per version } diff --git a/buildscripts/xds-k8s/cloudbuild.yaml b/buildscripts/xds-k8s/cloudbuild.yaml index 03c57489214..577ed73ce58 100644 --- a/buildscripts/xds-k8s/cloudbuild.yaml +++ b/buildscripts/xds-k8s/cloudbuild.yaml @@ -3,6 +3,7 @@ steps: args: - 'build' - '--tag=${_SERVER_IMAGE_NAME}:${COMMIT_SHA}' + - '--tag=${_SERVER_IMAGE_NAME}:${BRANCH_NAME}' - '--file=test-server.Dockerfile' - '.' @@ -10,6 +11,7 @@ steps: args: - 'build' - '--tag=${_CLIENT_IMAGE_NAME}:${COMMIT_SHA}' + - '--tag=${_CLIENT_IMAGE_NAME}:${BRANCH_NAME}' - '--file=test-client.Dockerfile' - '.' @@ -19,4 +21,6 @@ substitutions: images: - '${_SERVER_IMAGE_NAME}:${COMMIT_SHA}' + - '${_SERVER_IMAGE_NAME}:${BRANCH_NAME}' - '${_CLIENT_IMAGE_NAME}:${COMMIT_SHA}' + - '${_CLIENT_IMAGE_NAME}:${BRANCH_NAME}' From bdf9a964764a1383f0d980298b818f044e26d405 Mon Sep 17 00:00:00 2001 From: ZHANG Dapeng Date: Wed, 11 Aug 2021 14:44:23 -0700 Subject: [PATCH 27/82] core: enable retry by default (#8402) Stabilize `enableRetry()` and `disableRetry()`. Disable retry in `ManagedChannelImplTest` because each call attempt will fork the headers to a new instance, and add a ClientStreamTracer.Factory for bufferSizeLimit in CallOptions, which makes verification not straightforward. --- api/src/main/java/io/grpc/ManagedChannelBuilder.java | 2 -- .../java/io/grpc/internal/ManagedChannelImplBuilder.java | 2 +- .../test/java/io/grpc/internal/ManagedChannelImplTest.java | 5 +++++ 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/api/src/main/java/io/grpc/ManagedChannelBuilder.java b/api/src/main/java/io/grpc/ManagedChannelBuilder.java index 73e66ed6dc4..98b22807ccc 100644 --- a/api/src/main/java/io/grpc/ManagedChannelBuilder.java +++ b/api/src/main/java/io/grpc/ManagedChannelBuilder.java @@ -467,7 +467,6 @@ public T perRpcBufferLimit(long bytes) { * @return this * @since 1.11.0 */ - @ExperimentalApi("https://github.com/grpc/grpc-java/issues/3982") public T disableRetry() { throw new UnsupportedOperationException(); } @@ -482,7 +481,6 @@ public T disableRetry() { * @return this * @since 1.11.0 */ - @ExperimentalApi("https://github.com/grpc/grpc-java/issues/3982") public T enableRetry() { throw new UnsupportedOperationException(); } diff --git a/core/src/main/java/io/grpc/internal/ManagedChannelImplBuilder.java b/core/src/main/java/io/grpc/internal/ManagedChannelImplBuilder.java index cad4ece233e..26c48fc8596 100644 --- a/core/src/main/java/io/grpc/internal/ManagedChannelImplBuilder.java +++ b/core/src/main/java/io/grpc/internal/ManagedChannelImplBuilder.java @@ -142,7 +142,7 @@ public static ManagedChannelBuilder forTarget(String target) { int maxHedgedAttempts = 5; long retryBufferSize = DEFAULT_RETRY_BUFFER_SIZE_IN_BYTES; long perRpcBufferLimit = DEFAULT_PER_RPC_BUFFER_LIMIT_IN_BYTES; - boolean retryEnabled = false; // TODO(zdapeng): default to true + boolean retryEnabled = true; InternalChannelz channelz = InternalChannelz.instance(); int maxTraceEvents; diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java index ccfb5f074c5..668411d7ecc 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java @@ -355,6 +355,7 @@ public void close() throws SecurityException { channelBuilder = new ManagedChannelImplBuilder(TARGET, new UnsupportedClientTransportFactoryBuilder(), new FixedPortProvider(DEFAULT_PORT)); + channelBuilder.disableRetry(); configureBuilder(channelBuilder); } @@ -1881,6 +1882,7 @@ public void oobChannelHasNoChannelCallCredentials() { TARGET, InsecureChannelCredentials.create(), new FakeCallCredentials(metadataKey, channelCredValue), new UnsupportedClientTransportFactoryBuilder(), new FixedPortProvider(DEFAULT_PORT)); + channelBuilder.disableRetry(); configureBuilder(channelBuilder); createChannel(); @@ -1933,6 +1935,7 @@ public void oobChannelHasNoChannelCallCredentials() { new FakeNameResolverFactory.Builder(URI.create("oobauthority")).build()) .defaultLoadBalancingPolicy(MOCK_POLICY_NAME) .idleTimeout(ManagedChannelImplBuilder.IDLE_MODE_MAX_TIMEOUT_DAYS, TimeUnit.DAYS) + .disableRetry() // irrelevant to what we test, disable retry to make verification easy .build(); oob.getState(true); ArgumentCaptor helperCaptor = ArgumentCaptor.forClass(Helper.class); @@ -1980,6 +1983,7 @@ public SwapChannelCredentialsResult answer(InvocationOnMock invocation) { TARGET, InsecureChannelCredentials.create(), new FakeCallCredentials(metadataKey, channelCredValue), new UnsupportedClientTransportFactoryBuilder(), new FixedPortProvider(DEFAULT_PORT)); + channelBuilder.disableRetry(); configureBuilder(channelBuilder); createChannel(); @@ -2017,6 +2021,7 @@ public SwapChannelCredentialsResult answer(InvocationOnMock invocation) { new FakeNameResolverFactory.Builder(URI.create("fake://oobauthority/")).build()) .defaultLoadBalancingPolicy(MOCK_POLICY_NAME) .idleTimeout(ManagedChannelImplBuilder.IDLE_MODE_MAX_TIMEOUT_DAYS, TimeUnit.DAYS) + .disableRetry() // irrelevant to what we test, disable retry to make verification easy .build(); oob.getState(true); ArgumentCaptor helperCaptor = ArgumentCaptor.forClass(Helper.class); From c8db48e2b19cedebc3c975185939154d80241a67 Mon Sep 17 00:00:00 2001 From: ZHANG Dapeng Date: Thu, 12 Aug 2021 10:01:32 -0700 Subject: [PATCH 28/82] xds: enable xDS retry by default (#8403) --- xds/src/main/java/io/grpc/xds/ClientXdsClient.java | 4 ++-- xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/ClientXdsClient.java b/xds/src/main/java/io/grpc/xds/ClientXdsClient.java index 4ae6651784f..93bfeb36b34 100644 --- a/xds/src/main/java/io/grpc/xds/ClientXdsClient.java +++ b/xds/src/main/java/io/grpc/xds/ClientXdsClient.java @@ -134,8 +134,8 @@ final class ClientXdsClient extends AbstractXdsClient { || Boolean.parseBoolean(System.getenv("GRPC_XDS_EXPERIMENTAL_FAULT_INJECTION")); @VisibleForTesting static boolean enableRetry = - !Strings.isNullOrEmpty(System.getenv("GRPC_XDS_EXPERIMENTAL_ENABLE_RETRY")) - && Boolean.parseBoolean(System.getenv("GRPC_XDS_EXPERIMENTAL_ENABLE_RETRY")); + Strings.isNullOrEmpty(System.getenv("GRPC_XDS_EXPERIMENTAL_ENABLE_RETRY")) + || Boolean.parseBoolean(System.getenv("GRPC_XDS_EXPERIMENTAL_ENABLE_RETRY")); private static final String TYPE_URL_HTTP_CONNECTION_MANAGER_V2 = "type.googleapis.com/envoy.config.filter.network.http_connection_manager.v2" diff --git a/xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java b/xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java index 4b12ebd71c8..60ce2befe45 100644 --- a/xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java +++ b/xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java @@ -136,7 +136,7 @@ public class ClientXdsClientDataTest { @Before public void setUp() { originalEnableRetry = ClientXdsClient.enableRetry; - assertThat(originalEnableRetry).isFalse(); + assertThat(originalEnableRetry).isTrue(); } @After From 6a6a5279c0dc067a83799e009e4375e536ef4959 Mon Sep 17 00:00:00 2001 From: Lidi Zheng Date: Thu, 12 Aug 2021 13:45:19 -0700 Subject: [PATCH 29/82] Add a branch name in xds_url_map's CloudBuild (#8405) --- buildscripts/kokoro/xds_url_map.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/buildscripts/kokoro/xds_url_map.sh b/buildscripts/kokoro/xds_url_map.sh index cbb1552835b..b791528461d 100755 --- a/buildscripts/kokoro/xds_url_map.sh +++ b/buildscripts/kokoro/xds_url_map.sh @@ -54,7 +54,7 @@ build_test_app_docker_images() { # Run Google Cloud Build gcloud builds submit "${build_dir}" \ --config "${docker_dir}/cloudbuild.yaml" \ - --substitutions "_CLIENT_IMAGE_NAME=${CLIENT_IMAGE_NAME},COMMIT_SHA=${GIT_COMMIT}" + --substitutions "_CLIENT_IMAGE_NAME=${CLIENT_IMAGE_NAME},COMMIT_SHA=${GIT_COMMIT},BRANCH_NAME=experimental" # TODO(sergiitk): extra "cosmetic" tags for versioned branches, e.g. v1.34.x # TODO(sergiitk): do this when adding support for custom configs per version } From 3e9488be25ec214074f6037ab223b7a8234973a0 Mon Sep 17 00:00:00 2001 From: Eric Anderson Date: Tue, 17 Aug 2021 09:47:27 -0700 Subject: [PATCH 30/82] buildscripts: Increase memory for Gradle in Android CI We've still been seeing random memory-related failures with the Android CI, but it is nowhere near as severe as it was. But even when running locally with "-Xmx512m -XX:MaxMetaspaceSize=512m" I get failures. Our CI environment has lots of RAM; let's use it. --- buildscripts/kokoro/android.sh | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/buildscripts/kokoro/android.sh b/buildscripts/kokoro/android.sh index 3ca4fa88bc0..7b9e7f53885 100755 --- a/buildscripts/kokoro/android.sh +++ b/buildscripts/kokoro/android.sh @@ -18,8 +18,9 @@ export OS_NAME=$(uname) cat <> gradle.properties # defaults to -Xmx512m -XX:MaxMetaspaceSize=256m # https://docs.gradle.org/current/userguide/build_environment.html#sec:configuring_jvm_memory -# Increased due to java.lang.OutOfMemoryError: Metaspace failures -org.gradle.jvmargs=-Xmx512m -XX:MaxMetaspaceSize=512m +# Increased due to java.lang.OutOfMemoryError: Metaspace failures, "JVM heap +# space is exhausted", and to increase build speed +org.gradle.jvmargs=-Xmx2048m -XX:MaxMetaspaceSize=512m EOF echo y | ${ANDROID_HOME}/tools/bin/sdkmanager "build-tools;28.0.3" From 90606abdf13d36e08f3076c769df976d56a2b290 Mon Sep 17 00:00:00 2001 From: yifeizhuang Date: Thu, 29 Jul 2021 23:41:20 +0000 Subject: [PATCH 31/82] Update README etc to reference 1.40.0 --- README.md | 36 ++++++++++++------------ cronet/README.md | 2 +- documentation/android-channel-builder.md | 4 +-- 3 files changed, 21 insertions(+), 21 deletions(-) diff --git a/README.md b/README.md index cde4222ba7b..10e437bed1a 100644 --- a/README.md +++ b/README.md @@ -31,8 +31,8 @@ For a guided tour, take a look at the [quick start guide](https://grpc.io/docs/languages/java/quickstart) or the more explanatory [gRPC basics](https://grpc.io/docs/languages/java/basics). -The [examples](https://github.com/grpc/grpc-java/tree/v1.39.0/examples) and the -[Android example](https://github.com/grpc/grpc-java/tree/v1.39.0/examples/android) +The [examples](https://github.com/grpc/grpc-java/tree/v1.40.0/examples) and the +[Android example](https://github.com/grpc/grpc-java/tree/v1.40.0/examples/android) are standalone projects that showcase the usage of gRPC. Download @@ -43,17 +43,17 @@ Download [the JARs][]. Or for Maven with non-Android, add to your `pom.xml`: io.grpc grpc-netty-shaded - 1.39.0 + 1.40.0 io.grpc grpc-protobuf - 1.39.0 + 1.40.0 io.grpc grpc-stub - 1.39.0 + 1.40.0 org.apache.tomcat @@ -65,23 +65,23 @@ Download [the JARs][]. Or for Maven with non-Android, add to your `pom.xml`: Or for Gradle with non-Android, add to your dependencies: ```gradle -implementation 'io.grpc:grpc-netty-shaded:1.39.0' -implementation 'io.grpc:grpc-protobuf:1.39.0' -implementation 'io.grpc:grpc-stub:1.39.0' +implementation 'io.grpc:grpc-netty-shaded:1.40.0' +implementation 'io.grpc:grpc-protobuf:1.40.0' +implementation 'io.grpc:grpc-stub:1.40.0' compileOnly 'org.apache.tomcat:annotations-api:6.0.53' // necessary for Java 9+ ``` For Android client, use `grpc-okhttp` instead of `grpc-netty-shaded` and `grpc-protobuf-lite` instead of `grpc-protobuf`: ```gradle -implementation 'io.grpc:grpc-okhttp:1.39.0' -implementation 'io.grpc:grpc-protobuf-lite:1.39.0' -implementation 'io.grpc:grpc-stub:1.39.0' +implementation 'io.grpc:grpc-okhttp:1.40.0' +implementation 'io.grpc:grpc-protobuf-lite:1.40.0' +implementation 'io.grpc:grpc-stub:1.40.0' compileOnly 'org.apache.tomcat:annotations-api:6.0.53' // necessary for Java 9+ ``` [the JARs]: -https://search.maven.org/search?q=g:io.grpc%20AND%20v:1.39.0 +https://search.maven.org/search?q=g:io.grpc%20AND%20v:1.40.0 Development snapshots are available in [Sonatypes's snapshot repository](https://oss.sonatype.org/content/repositories/snapshots/). @@ -111,9 +111,9 @@ For protobuf-based codegen integrated with the Maven build system, you can use protobuf-maven-plugin 0.6.1 - com.google.protobuf:protoc:3.17.2:exe:${os.detected.classifier} + com.google.protobuf:protoc:3.17.3:exe:${os.detected.classifier} grpc-java - io.grpc:protoc-gen-grpc-java:1.39.0:exe:${os.detected.classifier} + io.grpc:protoc-gen-grpc-java:1.40.0:exe:${os.detected.classifier} @@ -139,11 +139,11 @@ plugins { protobuf { protoc { - artifact = "com.google.protobuf:protoc:3.17.2" + artifact = "com.google.protobuf:protoc:3.17.3" } plugins { grpc { - artifact = 'io.grpc:protoc-gen-grpc-java:1.39.0' + artifact = 'io.grpc:protoc-gen-grpc-java:1.40.0' } } generateProtoTasks { @@ -172,11 +172,11 @@ plugins { protobuf { protoc { - artifact = "com.google.protobuf:protoc:3.17.2" + artifact = "com.google.protobuf:protoc:3.17.3" } plugins { grpc { - artifact = 'io.grpc:protoc-gen-grpc-java:1.39.0' + artifact = 'io.grpc:protoc-gen-grpc-java:1.40.0' } } generateProtoTasks { diff --git a/cronet/README.md b/cronet/README.md index bd5329e5192..580c1e4acbe 100644 --- a/cronet/README.md +++ b/cronet/README.md @@ -26,7 +26,7 @@ In your app module's `build.gradle` file, include a dependency on both `grpc-cro Google Play Services Client Library for Cronet ``` -implementation 'io.grpc:grpc-cronet:1.39.0' +implementation 'io.grpc:grpc-cronet:1.40.0' implementation 'com.google.android.gms:play-services-cronet:16.0.0' ``` diff --git a/documentation/android-channel-builder.md b/documentation/android-channel-builder.md index 93447639197..fc491af8498 100644 --- a/documentation/android-channel-builder.md +++ b/documentation/android-channel-builder.md @@ -36,8 +36,8 @@ In your `build.gradle` file, include a dependency on both `grpc-android` and `grpc-okhttp`: ``` -implementation 'io.grpc:grpc-android:1.39.0' -implementation 'io.grpc:grpc-okhttp:1.39.0' +implementation 'io.grpc:grpc-android:1.40.0' +implementation 'io.grpc:grpc-okhttp:1.40.0' ``` You also need permission to access the device's network state in your From 2c2ebaebd5a93acec92fbd2708faac582db99371 Mon Sep 17 00:00:00 2001 From: ZhenLian Date: Tue, 17 Aug 2021 16:13:30 -0700 Subject: [PATCH 32/82] advancedtls: adding AdvancedTlsX509TrustManager and AdvancedTlsX509KeyManager (#8175) * add advanced TLS classes and tests --- core/BUILD.bazel | 1 + .../grpc/util/AdvancedTlsX509KeyManager.java | 234 +++++++++ .../util/AdvancedTlsX509TrustManager.java | 361 +++++++++++++ .../java/io/grpc/netty/AdvancedTlsTest.java | 477 ++++++++++++++++++ 4 files changed, 1073 insertions(+) create mode 100644 core/src/main/java/io/grpc/util/AdvancedTlsX509KeyManager.java create mode 100644 core/src/main/java/io/grpc/util/AdvancedTlsX509TrustManager.java create mode 100644 netty/src/test/java/io/grpc/netty/AdvancedTlsTest.java diff --git a/core/BUILD.bazel b/core/BUILD.bazel index c50e86a511c..60a08798d58 100644 --- a/core/BUILD.bazel +++ b/core/BUILD.bazel @@ -60,6 +60,7 @@ java_library( "@com_google_code_findbugs_jsr305//jar", "@com_google_guava_guava//jar", "@com_google_j2objc_j2objc_annotations//jar", + "@org_codehaus_mojo_animal_sniffer_annotations//jar", ], ) diff --git a/core/src/main/java/io/grpc/util/AdvancedTlsX509KeyManager.java b/core/src/main/java/io/grpc/util/AdvancedTlsX509KeyManager.java new file mode 100644 index 00000000000..adaa1e6e69a --- /dev/null +++ b/core/src/main/java/io/grpc/util/AdvancedTlsX509KeyManager.java @@ -0,0 +1,234 @@ +/* + * Copyright 2021 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.util; + +import static com.google.common.base.Preconditions.checkNotNull; + +import io.grpc.ExperimentalApi; +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.net.Socket; +import java.security.NoSuchAlgorithmException; +import java.security.Principal; +import java.security.PrivateKey; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; +import java.security.spec.InvalidKeySpecException; +import java.util.Arrays; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.TimeUnit; +import java.util.logging.Level; +import java.util.logging.Logger; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.X509ExtendedKeyManager; + +/** + * AdvancedTlsX509KeyManager is an {@code X509ExtendedKeyManager} that allows users to configure + * advanced TLS features, such as private key and certificate chain reloading, etc. + */ +@ExperimentalApi("https://github.com/grpc/grpc-java/issues/8024") +public final class AdvancedTlsX509KeyManager extends X509ExtendedKeyManager { + private static final Logger log = Logger.getLogger(AdvancedTlsX509KeyManager.class.getName()); + + // The credential information sent to peers to prove our identity. + private volatile KeyInfo keyInfo; + + /** + * Constructs an AdvancedTlsX509KeyManager. + */ + public AdvancedTlsX509KeyManager() throws CertificateException { } + + @Override + public PrivateKey getPrivateKey(String alias) { + if (alias.equals("default")) { + return this.keyInfo.key; + } + return null; + } + + @Override + public X509Certificate[] getCertificateChain(String alias) { + if (alias.equals("default")) { + return Arrays.copyOf(this.keyInfo.certs, this.keyInfo.certs.length); + } + return null; + } + + @Override + public String[] getClientAliases(String keyType, Principal[] issuers) { + return new String[] {"default"}; + } + + @Override + public String chooseClientAlias(String[] keyType, Principal[] issuers, Socket socket) { + return "default"; + } + + @Override + public String chooseEngineClientAlias(String[] keyType, Principal[] issuers, SSLEngine engine) { + return "default"; + } + + @Override + public String[] getServerAliases(String keyType, Principal[] issuers) { + return new String[] {"default"}; + } + + @Override + public String chooseServerAlias(String keyType, Principal[] issuers, Socket socket) { + return "default"; + } + + @Override + public String chooseEngineServerAlias(String keyType, Principal[] issuers, + SSLEngine engine) { + return "default"; + } + + /** + * Updates the current cached private key and cert chains. + * + * @param key the private key that is going to be used + * @param certs the certificate chain that is going to be used + */ + public void updateIdentityCredentials(PrivateKey key, X509Certificate[] certs) + throws CertificateException { + // TODO(ZhenLian): explore possibilities to do a crypto check here. + this.keyInfo = new KeyInfo(checkNotNull(key, "key"), checkNotNull(certs, "certs")); + } + + /** + * Schedules a {@code ScheduledExecutorService} to read private key and certificate chains from + * the local file paths periodically, and update the cached identity credentials if they are both + * updated. + * + * @param keyFile the file on disk holding the private key + * @param certFile the file on disk holding the certificate chain + * @param period the period between successive read-and-update executions + * @param unit the time unit of the initialDelay and period parameters + * @param executor the execute service we use to read and update the credentials + * @return an object that caller should close when the file refreshes are not needed + */ + public Closeable updateIdentityCredentialsFromFile(File keyFile, File certFile, + long period, TimeUnit unit, ScheduledExecutorService executor) { + final ScheduledFuture future = + executor.scheduleWithFixedDelay( + new LoadFilePathExecution(keyFile, certFile), 0, period, unit); + return new Closeable() { + @Override public void close() { + future.cancel(false); + } + }; + } + + private static class KeyInfo { + // The private key and the cert chain we will use to send to peers to prove our identity. + final PrivateKey key; + final X509Certificate[] certs; + + public KeyInfo(PrivateKey key, X509Certificate[] certs) { + this.key = key; + this.certs = certs; + } + } + + private class LoadFilePathExecution implements Runnable { + File keyFile; + File certFile; + long currentKeyTime; + long currentCertTime; + + public LoadFilePathExecution(File keyFile, File certFile) { + this.keyFile = keyFile; + this.certFile = certFile; + this.currentKeyTime = 0; + this.currentCertTime = 0; + } + + @Override + public void run() { + try { + UpdateResult newResult = readAndUpdate(this.keyFile, this.certFile, this.currentKeyTime, + this.currentCertTime); + if (newResult.success) { + this.currentKeyTime = newResult.keyTime; + this.currentCertTime = newResult.certTime; + } + } catch (CertificateException | IOException | NoSuchAlgorithmException + | InvalidKeySpecException e) { + log.log(Level.SEVERE, "Failed refreshing private key and certificate chain from files. " + + "Using previous ones", e); + } + } + } + + private static class UpdateResult { + boolean success; + long keyTime; + long certTime; + + public UpdateResult(boolean success, long keyTime, long certTime) { + this.success = success; + this.keyTime = keyTime; + this.certTime = certTime; + } + } + + /** + * Reads the private key and certificates specified in the path locations. Updates {@code key} and + * {@code cert} if both of their modified time changed since last read. + * + * @param keyFile the file on disk holding the private key + * @param certFile the file on disk holding the certificate chain + * @param oldKeyTime the time when the private key file is modified during last execution + * @param oldCertTime the time when the certificate chain file is modified during last execution + * @return the result of this update execution + */ + private UpdateResult readAndUpdate(File keyFile, File certFile, long oldKeyTime, long oldCertTime) + throws IOException, CertificateException, NoSuchAlgorithmException, InvalidKeySpecException { + long newKeyTime = keyFile.lastModified(); + long newCertTime = certFile.lastModified(); + // We only update when both the key and the certs are updated. + if (newKeyTime != oldKeyTime && newCertTime != oldCertTime) { + FileInputStream keyInputStream = new FileInputStream(keyFile); + try { + PrivateKey key = CertificateUtils.getPrivateKey(keyInputStream); + FileInputStream certInputStream = new FileInputStream(certFile); + try { + X509Certificate[] certs = CertificateUtils.getX509Certificates(certInputStream); + updateIdentityCredentials(key, certs); + return new UpdateResult(true, newKeyTime, newCertTime); + } finally { + certInputStream.close(); + } + } finally { + keyInputStream.close(); + } + } + return new UpdateResult(false, oldKeyTime, oldCertTime); + } + + /** + * Mainly used to avoid throwing IO Exceptions in java.io.Closeable. + */ + public interface Closeable extends java.io.Closeable { + @Override public void close(); + } +} + diff --git a/core/src/main/java/io/grpc/util/AdvancedTlsX509TrustManager.java b/core/src/main/java/io/grpc/util/AdvancedTlsX509TrustManager.java new file mode 100644 index 00000000000..ea8e74b1a9e --- /dev/null +++ b/core/src/main/java/io/grpc/util/AdvancedTlsX509TrustManager.java @@ -0,0 +1,361 @@ +/* + * Copyright 2021 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.util; + +import io.grpc.ExperimentalApi; +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.net.Socket; +import java.security.KeyStore; +import java.security.KeyStoreException; +import java.security.NoSuchAlgorithmException; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.TimeUnit; +import java.util.logging.Level; +import java.util.logging.Logger; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLParameters; +import javax.net.ssl.SSLSocket; +import javax.net.ssl.TrustManager; +import javax.net.ssl.TrustManagerFactory; +import javax.net.ssl.X509ExtendedTrustManager; +import org.codehaus.mojo.animal_sniffer.IgnoreJRERequirement; + +/** + * AdvancedTlsX509TrustManager is an {@code X509ExtendedTrustManager} that allows users to configure + * advanced TLS features, such as root certificate reloading, peer cert custom verification, etc. + * For Android users: this class is only supported in API level 24 and above. + */ +@ExperimentalApi("https://github.com/grpc/grpc-java/issues/8024") +@IgnoreJRERequirement +public final class AdvancedTlsX509TrustManager extends X509ExtendedTrustManager { + private static final Logger log = Logger.getLogger(AdvancedTlsX509TrustManager.class.getName()); + + private final Verification verification; + private final SslSocketAndEnginePeerVerifier socketAndEnginePeerVerifier; + + // The delegated trust manager used to perform traditional certificate verification. + private volatile X509ExtendedTrustManager delegateManager = null; + + private AdvancedTlsX509TrustManager(Verification verification, + SslSocketAndEnginePeerVerifier socketAndEnginePeerVerifier) throws CertificateException { + this.verification = verification; + this.socketAndEnginePeerVerifier = socketAndEnginePeerVerifier; + } + + @Override + public void checkClientTrusted(X509Certificate[] chain, String authType) + throws CertificateException { + throw new CertificateException( + "Not enough information to validate peer. SSLEngine or Socket required."); + } + + @Override + public void checkClientTrusted(X509Certificate[] chain, String authType, Socket socket) + throws CertificateException { + checkTrusted(chain, authType, null, socket, false); + } + + @Override + public void checkClientTrusted(X509Certificate[] chain, String authType, SSLEngine engine) + throws CertificateException { + checkTrusted(chain, authType, engine, null, false); + } + + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType, SSLEngine engine) + throws CertificateException { + checkTrusted(chain, authType, engine, null, true); + } + + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType) + throws CertificateException { + throw new CertificateException( + "Not enough information to validate peer. SSLEngine or Socket required."); + } + + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType, Socket socket) + throws CertificateException { + checkTrusted(chain, authType, null, socket, true); + } + + @Override + public X509Certificate[] getAcceptedIssuers() { + if (this.delegateManager == null) { + return new X509Certificate[0]; + } + return this.delegateManager.getAcceptedIssuers(); + } + + /** + * Uses the default trust certificates stored on user's local system. + * After this is used, functions that will provide new credential + * data(e.g. updateTrustCredentials(), updateTrustCredentialsFromFile()) should not be called. + */ + public void useSystemDefaultTrustCerts() throws CertificateException, KeyStoreException, + NoSuchAlgorithmException { + // Passing a null value of KeyStore would make {@code TrustManagerFactory} attempt to use + // system-default trust CA certs. + this.delegateManager = createDelegateTrustManager(null); + } + + /** + * Updates the current cached trust certificates as well as the key store. + * + * @param trustCerts the trust certificates that are going to be used + */ + public void updateTrustCredentials(X509Certificate[] trustCerts) throws CertificateException, + KeyStoreException, NoSuchAlgorithmException, IOException { + KeyStore keyStore = KeyStore.getInstance(KeyStore.getDefaultType()); + keyStore.load(null, null); + int i = 1; + for (X509Certificate cert: trustCerts) { + String alias = Integer.toString(i); + keyStore.setCertificateEntry(alias, cert); + i++; + } + X509ExtendedTrustManager newDelegateManager = createDelegateTrustManager(keyStore); + this.delegateManager = newDelegateManager; + } + + private static X509ExtendedTrustManager createDelegateTrustManager(KeyStore keyStore) + throws CertificateException, KeyStoreException, NoSuchAlgorithmException { + TrustManagerFactory tmf = TrustManagerFactory.getInstance( + TrustManagerFactory.getDefaultAlgorithm()); + tmf.init(keyStore); + X509ExtendedTrustManager delegateManager = null; + TrustManager[] tms = tmf.getTrustManagers(); + // Iterate over the returned trust managers, looking for an instance of X509TrustManager. + // If found, use that as the delegate trust manager. + for (int j = 0; j < tms.length; j++) { + if (tms[j] instanceof X509ExtendedTrustManager) { + delegateManager = (X509ExtendedTrustManager) tms[j]; + break; + } + } + if (delegateManager == null) { + throw new CertificateException( + "Failed to find X509ExtendedTrustManager with default TrustManager algorithm " + + TrustManagerFactory.getDefaultAlgorithm()); + } + return delegateManager; + } + + private void checkTrusted(X509Certificate[] chain, String authType, SSLEngine sslEngine, + Socket socket, boolean checkingServer) throws CertificateException { + if (chain == null || chain.length == 0) { + throw new IllegalArgumentException( + "Want certificate verification but got null or empty certificates"); + } + if (sslEngine == null && socket == null) { + throw new CertificateException( + "Not enough information to validate peer. SSLEngine or Socket required."); + } + if (this.verification != Verification.InsecurelySkipAllVerification) { + X509ExtendedTrustManager currentDelegateManager = this.delegateManager; + if (currentDelegateManager == null) { + throw new CertificateException("No trust roots configured"); + } + if (checkingServer) { + String algorithm = this.verification == Verification.CertificateAndHostNameVerification + ? "HTTPS" : ""; + if (sslEngine != null) { + SSLParameters sslParams = sslEngine.getSSLParameters(); + sslParams.setEndpointIdentificationAlgorithm(algorithm); + sslEngine.setSSLParameters(sslParams); + currentDelegateManager.checkServerTrusted(chain, authType, sslEngine); + } else { + if (!(socket instanceof SSLSocket)) { + throw new CertificateException("socket is not a type of SSLSocket"); + } + SSLSocket sslSocket = (SSLSocket)socket; + SSLParameters sslParams = sslSocket.getSSLParameters(); + sslParams.setEndpointIdentificationAlgorithm(algorithm); + sslSocket.setSSLParameters(sslParams); + currentDelegateManager.checkServerTrusted(chain, authType, sslSocket); + } + } else { + currentDelegateManager.checkClientTrusted(chain, authType, sslEngine); + } + } + // Perform the additional peer cert check. + if (socketAndEnginePeerVerifier != null) { + if (sslEngine != null) { + socketAndEnginePeerVerifier.verifyPeerCertificate(chain, authType, sslEngine); + } else { + socketAndEnginePeerVerifier.verifyPeerCertificate(chain, authType, socket); + } + } + } + + /** + * Schedules a {@code ScheduledExecutorService} to read trust certificates from a local file path + * periodically, and update the cached trust certs if there is an update. + * + * @param trustCertFile the file on disk holding the trust certificates + * @param period the period between successive read-and-update executions + * @param unit the time unit of the initialDelay and period parameters + * @param executor the execute service we use to read and update the credentials + * @return an object that caller should close when the file refreshes are not needed + */ + public Closeable updateTrustCredentialsFromFile(File trustCertFile, long period, TimeUnit unit, + ScheduledExecutorService executor) { + final ScheduledFuture future = + executor.scheduleWithFixedDelay( + new LoadFilePathExecution(trustCertFile), 0, period, unit); + return new Closeable() { + @Override public void close() { + future.cancel(false); + } + }; + } + + private class LoadFilePathExecution implements Runnable { + File file; + long currentTime; + + public LoadFilePathExecution(File file) { + this.file = file; + this.currentTime = 0; + } + + @Override + public void run() { + try { + this.currentTime = readAndUpdate(this.file, this.currentTime); + } catch (CertificateException | IOException | KeyStoreException + | NoSuchAlgorithmException e) { + log.log(Level.SEVERE, "Failed refreshing trust CAs from file. Using previous CAs", e); + } + } + } + + /** + * Reads the trust certificates specified in the path location, and update the key store if the + * modified time has changed since last read. + * + * @param trustCertFile the file on disk holding the trust certificates + * @param oldTime the time when the trust file is modified during last execution + * @return oldTime if failed or the modified time is not changed, otherwise the new modified time + */ + private long readAndUpdate(File trustCertFile, long oldTime) + throws CertificateException, IOException, KeyStoreException, NoSuchAlgorithmException { + long newTime = trustCertFile.lastModified(); + if (newTime == oldTime) { + return oldTime; + } + FileInputStream inputStream = new FileInputStream(trustCertFile); + try { + X509Certificate[] certificates = CertificateUtils.getX509Certificates(inputStream); + updateTrustCredentials(certificates); + return newTime; + } finally { + inputStream.close(); + } + } + + // Mainly used to avoid throwing IO Exceptions in java.io.Closeable. + public interface Closeable extends java.io.Closeable { + @Override public void close(); + } + + public static Builder newBuilder() { + return new Builder(); + } + + // The verification mode when authenticating the peer certificate. + public enum Verification { + // This is the DEFAULT and RECOMMENDED mode for most applications. + // Setting this on the client side will do the certificate and hostname verification, while + // setting this on the server side will only do the certificate verification. + CertificateAndHostNameVerification, + // This SHOULD be chosen only when you know what the implication this will bring, and have a + // basic understanding about TLS. + // It SHOULD be accompanied with proper additional peer identity checks set through + // {@code PeerVerifier}(nit: why this @code not working?). Failing to do so will leave + // applications to MITM attack. + // Also note that this will only take effect if the underlying SDK implementation invokes + // checkClientTrusted/checkServerTrusted with the {@code SSLEngine} parameter while doing + // verification. + // Setting this on either side will only do the certificate verification. + CertificateOnlyVerification, + // Setting is very DANGEROUS. Please try to avoid this in a real production environment, unless + // you are a super advanced user intended to re-implement the whole verification logic on your + // own. A secure verification might include: + // 1. proper verification on the peer certificate chain + // 2. proper checks on the identity of the peer certificate + InsecurelySkipAllVerification, + } + + // Additional custom peer verification check. + // It will be used when checkClientTrusted/checkServerTrusted is called with the {@code Socket} or + // the {@code SSLEngine} parameter. + public interface SslSocketAndEnginePeerVerifier { + /** + * Verifies the peer certificate chain. For more information, please refer to + * {@code X509ExtendedTrustManager}. + * + * @param peerCertChain the certificate chain sent from the peer + * @param authType the key exchange algorithm used, e.g. "RSA", "DHE_DSS", etc + * @param socket the socket used for this connection. This parameter can be null, which + * indicates that implementations need not check the ssl parameters + */ + void verifyPeerCertificate(X509Certificate[] peerCertChain, String authType, Socket socket) + throws CertificateException; + + /** + * Verifies the peer certificate chain. For more information, please refer to + * {@code X509ExtendedTrustManager}. + * + * @param peerCertChain the certificate chain sent from the peer + * @param authType the key exchange algorithm used, e.g. "RSA", "DHE_DSS", etc + * @param engine the engine used for this connection. This parameter can be null, which + * indicates that implementations need not check the ssl parameters + */ + void verifyPeerCertificate(X509Certificate[] peerCertChain, String authType, SSLEngine engine) + throws CertificateException; + } + + public static final class Builder { + + private Verification verification = Verification.CertificateAndHostNameVerification; + private SslSocketAndEnginePeerVerifier socketAndEnginePeerVerifier; + + private Builder() {} + + public Builder setVerification(Verification verification) { + this.verification = verification; + return this; + } + + public Builder setSslSocketAndEnginePeerVerifier(SslSocketAndEnginePeerVerifier verifier) { + this.socketAndEnginePeerVerifier = verifier; + return this; + } + + public AdvancedTlsX509TrustManager build() throws CertificateException { + return new AdvancedTlsX509TrustManager(this.verification, this.socketAndEnginePeerVerifier); + } + } +} + diff --git a/netty/src/test/java/io/grpc/netty/AdvancedTlsTest.java b/netty/src/test/java/io/grpc/netty/AdvancedTlsTest.java new file mode 100644 index 00000000000..294fbcd4a9a --- /dev/null +++ b/netty/src/test/java/io/grpc/netty/AdvancedTlsTest.java @@ -0,0 +1,477 @@ +/* + * Copyright 2021 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.netty; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + +import com.google.common.util.concurrent.MoreExecutors; +import io.grpc.ChannelCredentials; +import io.grpc.Grpc; +import io.grpc.ManagedChannel; +import io.grpc.Server; +import io.grpc.ServerCredentials; +import io.grpc.StatusRuntimeException; +import io.grpc.TlsChannelCredentials; +import io.grpc.TlsServerCredentials; +import io.grpc.TlsServerCredentials.ClientAuth; +import io.grpc.internal.testing.TestUtils; +import io.grpc.stub.StreamObserver; +import io.grpc.testing.protobuf.SimpleRequest; +import io.grpc.testing.protobuf.SimpleResponse; +import io.grpc.testing.protobuf.SimpleServiceGrpc; +import io.grpc.util.AdvancedTlsX509KeyManager; +import io.grpc.util.AdvancedTlsX509TrustManager; +import io.grpc.util.AdvancedTlsX509TrustManager.SslSocketAndEnginePeerVerifier; +import io.grpc.util.AdvancedTlsX509TrustManager.Verification; +import io.grpc.util.CertificateUtils; + +import java.io.Closeable; +import java.io.File; +import java.io.IOException; +import java.net.Socket; +import java.security.NoSuchAlgorithmException; +import java.security.PrivateKey; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; +import java.security.spec.InvalidKeySpecException; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import javax.net.ssl.SSLEngine; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +public class AdvancedTlsTest { + public static final String SERVER_0_KEY_FILE = "server0.key"; + public static final String SERVER_0_PEM_FILE = "server0.pem"; + public static final String CLIENT_0_KEY_FILE = "client.key"; + public static final String CLIENT_0_PEM_FILE = "client.pem"; + public static final String CA_PEM_FILE = "ca.pem"; + public static final String SERVER_BAD_KEY_FILE = "badserver.key"; + public static final String SERVER_BAD_PEM_FILE = "badserver.pem"; + + private ScheduledExecutorService executor; + private Server server; + private ManagedChannel channel; + + private File caCertFile; + private File serverKey0File; + private File serverCert0File; + private File clientKey0File; + private File clientCert0File; + private X509Certificate[] caCert; + private PrivateKey serverKey0; + private X509Certificate[] serverCert0; + private PrivateKey clientKey0; + private X509Certificate[] clientCert0; + private PrivateKey serverKeyBad; + private X509Certificate[] serverCertBad; + + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + @Before + public void setUp() + throws NoSuchAlgorithmException, IOException, CertificateException, InvalidKeySpecException { + executor = Executors.newSingleThreadScheduledExecutor(); + caCertFile = TestUtils.loadCert(CA_PEM_FILE); + serverKey0File = TestUtils.loadCert(SERVER_0_KEY_FILE); + serverCert0File = TestUtils.loadCert(SERVER_0_PEM_FILE); + clientKey0File = TestUtils.loadCert(CLIENT_0_KEY_FILE); + clientCert0File = TestUtils.loadCert(CLIENT_0_PEM_FILE); + caCert = CertificateUtils.getX509Certificates( + TestUtils.class.getResourceAsStream("/certs/" + CA_PEM_FILE)); + serverKey0 = CertificateUtils.getPrivateKey( + TestUtils.class.getResourceAsStream("/certs/" + SERVER_0_KEY_FILE)); + serverCert0 = CertificateUtils.getX509Certificates( + TestUtils.class.getResourceAsStream("/certs/" + SERVER_0_PEM_FILE)); + clientKey0 = CertificateUtils.getPrivateKey( + TestUtils.class.getResourceAsStream("/certs/" + CLIENT_0_KEY_FILE)); + clientCert0 = CertificateUtils.getX509Certificates( + TestUtils.class.getResourceAsStream("/certs/" + CLIENT_0_PEM_FILE)); + serverKeyBad = CertificateUtils.getPrivateKey( + TestUtils.class.getResourceAsStream("/certs/" + SERVER_BAD_KEY_FILE)); + serverCertBad = CertificateUtils.getX509Certificates( + TestUtils.class.getResourceAsStream("/certs/" + SERVER_BAD_PEM_FILE)); + } + + @After + public void tearDown() { + if (server != null) { + server.shutdown(); + } + if (channel != null) { + channel.shutdown(); + } + MoreExecutors.shutdownAndAwaitTermination(executor, 5, TimeUnit.SECONDS); + } + + @Test + public void basicMutualTlsTest() throws Exception { + // Create & start a server. + ServerCredentials serverCredentials = TlsServerCredentials.newBuilder() + .keyManager(serverCert0File, serverKey0File).trustManager(caCertFile) + .clientAuth(ClientAuth.REQUIRE).build(); + server = Grpc.newServerBuilderForPort(0, serverCredentials).addService( + new SimpleServiceImpl()).build().start(); + // Create a client to connect. + ChannelCredentials channelCredentials = TlsChannelCredentials.newBuilder() + .keyManager(clientCert0File, clientKey0File).trustManager(caCertFile).build(); + channel = Grpc.newChannelBuilderForAddress("localhost", server.getPort(), channelCredentials) + .overrideAuthority("foo.test.google.com.au").build(); + // Start the connection. + try { + SimpleServiceGrpc.SimpleServiceBlockingStub client = + SimpleServiceGrpc.newBlockingStub(channel); + // Send an actual request, via the full GRPC & network stack, and check that a proper + // response comes back. + client.unaryRpc(SimpleRequest.getDefaultInstance()); + } catch (StatusRuntimeException e) { + e.printStackTrace(); + fail("Failed to make a connection"); + e.printStackTrace(); + } + } + + @Test + public void advancedTlsKeyManagerTrustManagerMutualTlsTest() throws Exception { + // Create a server with the key manager and trust manager. + AdvancedTlsX509KeyManager serverKeyManager = new AdvancedTlsX509KeyManager(); + serverKeyManager.updateIdentityCredentials(serverKey0, serverCert0); + AdvancedTlsX509TrustManager serverTrustManager = AdvancedTlsX509TrustManager.newBuilder() + .setVerification(Verification.CertificateOnlyVerification) + .build(); + serverTrustManager.updateTrustCredentials(caCert); + ServerCredentials serverCredentials = TlsServerCredentials.newBuilder() + .keyManager(serverKeyManager).trustManager(serverTrustManager) + .clientAuth(ClientAuth.REQUIRE).build(); + server = Grpc.newServerBuilderForPort(0, serverCredentials).addService( + new SimpleServiceImpl()).build().start(); + TimeUnit.SECONDS.sleep(5); + // Create a client with the key manager and trust manager. + AdvancedTlsX509KeyManager clientKeyManager = new AdvancedTlsX509KeyManager(); + clientKeyManager.updateIdentityCredentials(clientKey0, clientCert0); + AdvancedTlsX509TrustManager clientTrustManager = AdvancedTlsX509TrustManager.newBuilder() + .setVerification(Verification.CertificateAndHostNameVerification) + .build(); + clientTrustManager.updateTrustCredentials(caCert); + ChannelCredentials channelCredentials = TlsChannelCredentials.newBuilder() + .keyManager(clientKeyManager).trustManager(clientTrustManager).build(); + channel = Grpc.newChannelBuilderForAddress("localhost", server.getPort(), channelCredentials) + .overrideAuthority("foo.test.google.com.au").build(); + // Start the connection. + try { + SimpleServiceGrpc.SimpleServiceBlockingStub client = + SimpleServiceGrpc.newBlockingStub(channel); + client.unaryRpc(SimpleRequest.getDefaultInstance()); + } catch (StatusRuntimeException e) { + fail("Failed to make a connection"); + e.printStackTrace(); + } + } + + @Test + public void trustManagerCustomVerifierMutualTlsTest() throws Exception { + AdvancedTlsX509KeyManager serverKeyManager = new AdvancedTlsX509KeyManager(); + serverKeyManager.updateIdentityCredentials(serverKey0, serverCert0); + // Set server's custom verification based on the information of clientCert0. + AdvancedTlsX509TrustManager serverTrustManager = AdvancedTlsX509TrustManager.newBuilder() + .setVerification(Verification.CertificateOnlyVerification) + .setSslSocketAndEnginePeerVerifier( + new SslSocketAndEnginePeerVerifier() { + @Override + public void verifyPeerCertificate(X509Certificate[] peerCertChain, String authType, + Socket socket) throws CertificateException { + if (peerCertChain == null || peerCertChain.length == 0) { + throw new CertificateException("peerCertChain is empty"); + } + X509Certificate leafCert = peerCertChain[0]; + if (!leafCert.getSubjectDN().getName().contains("testclient")) { + throw new CertificateException("SslSocketAndEnginePeerVerifier failed"); + } + } + + @Override + public void verifyPeerCertificate(X509Certificate[] peerCertChain, String authType, + SSLEngine engine) throws CertificateException { + if (peerCertChain == null || peerCertChain.length == 0) { + throw new CertificateException("peerCertChain is empty"); + } + X509Certificate leafCert = peerCertChain[0]; + if (!leafCert.getSubjectDN().getName().contains("testclient")) { + throw new CertificateException("SslSocketAndEnginePeerVerifier failed"); + } + } + }) + .build(); + serverTrustManager.updateTrustCredentials(caCert); + ServerCredentials serverCredentials = TlsServerCredentials.newBuilder() + .keyManager(serverKeyManager).trustManager(serverTrustManager) + .clientAuth(ClientAuth.REQUIRE).build(); + server = Grpc.newServerBuilderForPort(0, serverCredentials).addService( + new SimpleServiceImpl()).build().start(); + TimeUnit.SECONDS.sleep(5); + + AdvancedTlsX509KeyManager clientKeyManager = new AdvancedTlsX509KeyManager(); + clientKeyManager.updateIdentityCredentials(clientKey0, clientCert0); + // Set client's custom verification based on the information of serverCert0. + AdvancedTlsX509TrustManager clientTrustManager = AdvancedTlsX509TrustManager.newBuilder() + .setVerification(Verification.CertificateOnlyVerification) + .setSslSocketAndEnginePeerVerifier( + new SslSocketAndEnginePeerVerifier() { + @Override + public void verifyPeerCertificate(X509Certificate[] peerCertChain, String authType, + Socket socket) throws CertificateException { + if (peerCertChain == null || peerCertChain.length == 0) { + throw new CertificateException("peerCertChain is empty"); + } + X509Certificate leafCert = peerCertChain[0]; + if (!leafCert.getSubjectDN().getName().contains("*.test.google.com.au")) { + throw new CertificateException("SslSocketAndEnginePeerVerifier failed"); + } + } + + @Override + public void verifyPeerCertificate(X509Certificate[] peerCertChain, String authType, + SSLEngine engine) throws CertificateException { + if (peerCertChain == null || peerCertChain.length == 0) { + throw new CertificateException("peerCertChain is empty"); + } + X509Certificate leafCert = peerCertChain[0]; + if (!leafCert.getSubjectDN().getName().contains("*.test.google.com.au")) { + throw new CertificateException("SslSocketAndEnginePeerVerifier failed"); + } + } + }) + .build(); + clientTrustManager.updateTrustCredentials(caCert); + ChannelCredentials channelCredentials = TlsChannelCredentials.newBuilder() + .keyManager(clientKeyManager).trustManager(clientTrustManager).build(); + channel = Grpc.newChannelBuilderForAddress( + "localhost", server.getPort(), channelCredentials).build(); + // Start the connection. + try { + SimpleServiceGrpc.SimpleServiceBlockingStub client = + SimpleServiceGrpc.newBlockingStub(channel); + client.unaryRpc(SimpleRequest.getDefaultInstance()); + } catch (StatusRuntimeException e) { + fail("Failed to make a connection"); + e.printStackTrace(); + } + } + + @Test + public void trustManagerInsecurelySkipAllTest() throws Exception { + AdvancedTlsX509KeyManager serverKeyManager = new AdvancedTlsX509KeyManager(); + // Even if we provide bad credentials for the server, the test should still pass, because we + // will configure the client to skip all checks later. + serverKeyManager.updateIdentityCredentials(serverKeyBad, serverCertBad); + AdvancedTlsX509TrustManager serverTrustManager = AdvancedTlsX509TrustManager.newBuilder() + .setVerification(Verification.CertificateOnlyVerification) + .setSslSocketAndEnginePeerVerifier( + new SslSocketAndEnginePeerVerifier() { + @Override + public void verifyPeerCertificate(X509Certificate[] peerCertChain, String authType, + Socket socket) throws CertificateException { } + + @Override + public void verifyPeerCertificate(X509Certificate[] peerCertChain, String authType, + SSLEngine engine) throws CertificateException { } + }) + .build(); + serverTrustManager.updateTrustCredentials(caCert); + ServerCredentials serverCredentials = TlsServerCredentials.newBuilder() + .keyManager(serverKeyManager).trustManager(serverTrustManager) + .clientAuth(ClientAuth.REQUIRE).build(); + server = Grpc.newServerBuilderForPort(0, serverCredentials).addService( + new SimpleServiceImpl()).build().start(); + TimeUnit.SECONDS.sleep(5); + + AdvancedTlsX509KeyManager clientKeyManager = new AdvancedTlsX509KeyManager(); + clientKeyManager.updateIdentityCredentials(clientKey0, clientCert0); + // Set the client to skip all checks, including traditional certificate verification. + // Note this is very dangerous in production environment - only do so if you are confident on + // what you are doing! + AdvancedTlsX509TrustManager clientTrustManager = AdvancedTlsX509TrustManager.newBuilder() + .setVerification(Verification.InsecurelySkipAllVerification) + .setSslSocketAndEnginePeerVerifier( + new SslSocketAndEnginePeerVerifier() { + @Override + public void verifyPeerCertificate(X509Certificate[] peerCertChain, String authType, + Socket socket) throws CertificateException { } + + @Override + public void verifyPeerCertificate(X509Certificate[] peerCertChain, String authType, + SSLEngine engine) throws CertificateException { } + }) + .build(); + clientTrustManager.updateTrustCredentials(caCert); + ChannelCredentials channelCredentials = TlsChannelCredentials.newBuilder() + .keyManager(clientKeyManager).trustManager(clientTrustManager).build(); + channel = Grpc.newChannelBuilderForAddress( + "localhost", server.getPort(), channelCredentials).build(); + // Start the connection. + try { + SimpleServiceGrpc.SimpleServiceBlockingStub client = + SimpleServiceGrpc.newBlockingStub(channel); + client.unaryRpc(SimpleRequest.getDefaultInstance()); + } catch (StatusRuntimeException e) { + fail("Failed to make a connection"); + e.printStackTrace(); + } + } + + @Test + public void onFileReloadingKeyManagerTrustManagerTest() throws Exception { + // Create & start a server. + AdvancedTlsX509KeyManager serverKeyManager = new AdvancedTlsX509KeyManager(); + Closeable serverKeyShutdown = serverKeyManager.updateIdentityCredentialsFromFile(serverKey0File, + serverCert0File, 100, TimeUnit.MILLISECONDS, executor); + AdvancedTlsX509TrustManager serverTrustManager = AdvancedTlsX509TrustManager.newBuilder() + .setVerification(Verification.CertificateOnlyVerification) + .build(); + Closeable serverTrustShutdown = serverTrustManager.updateTrustCredentialsFromFile(caCertFile, + 100, TimeUnit.MILLISECONDS, executor); + ServerCredentials serverCredentials = TlsServerCredentials.newBuilder() + .keyManager(serverKeyManager).trustManager(serverTrustManager) + .clientAuth(ClientAuth.REQUIRE).build(); + server = Grpc.newServerBuilderForPort(0, serverCredentials).addService( + new SimpleServiceImpl()).build().start(); + TimeUnit.SECONDS.sleep(5); + // Create a client to connect. + AdvancedTlsX509KeyManager clientKeyManager = new AdvancedTlsX509KeyManager(); + Closeable clientKeyShutdown = clientKeyManager.updateIdentityCredentialsFromFile(clientKey0File, + clientCert0File,100, TimeUnit.MILLISECONDS, executor); + AdvancedTlsX509TrustManager clientTrustManager = AdvancedTlsX509TrustManager.newBuilder() + .setVerification(Verification.CertificateAndHostNameVerification) + .build(); + Closeable clientTrustShutdown = clientTrustManager.updateTrustCredentialsFromFile(caCertFile, + 100, TimeUnit.MILLISECONDS, executor); + ChannelCredentials channelCredentials = TlsChannelCredentials.newBuilder() + .keyManager(clientKeyManager).trustManager(clientTrustManager).build(); + channel = Grpc.newChannelBuilderForAddress("localhost", server.getPort(), channelCredentials) + .overrideAuthority("foo.test.google.com.au").build(); + // Start the connection. + try { + SimpleServiceGrpc.SimpleServiceBlockingStub client = + SimpleServiceGrpc.newBlockingStub(channel); + // Send an actual request, via the full GRPC & network stack, and check that a proper + // response comes back. + client.unaryRpc(SimpleRequest.getDefaultInstance()); + } catch (StatusRuntimeException e) { + e.printStackTrace(); + fail("Find error: " + e.getMessage()); + } + // Clean up. + serverKeyShutdown.close(); + serverTrustShutdown.close(); + clientKeyShutdown.close(); + clientTrustShutdown.close(); + } + + @Test + public void keyManagerAliasesTest() throws Exception { + AdvancedTlsX509KeyManager km = new AdvancedTlsX509KeyManager(); + assertArrayEquals( + new String[] {"default"}, km.getClientAliases("", null)); + assertEquals( + "default", km.chooseClientAlias(new String[] {"default"}, null, null)); + assertArrayEquals( + new String[] {"default"}, km.getServerAliases("", null)); + assertEquals( + "default", km.chooseServerAlias("default", null, null)); + } + + @Test + public void trustManagerCheckTrustedWithSocketTest() throws Exception { + AdvancedTlsX509TrustManager tm = AdvancedTlsX509TrustManager.newBuilder() + .setVerification(Verification.InsecurelySkipAllVerification).build(); + tm.updateTrustCredentials(caCert); + tm.checkClientTrusted(serverCert0, "RSA", new Socket()); + tm.useSystemDefaultTrustCerts(); + tm.checkServerTrusted(clientCert0, "RSA", new Socket()); + } + + @Test + public void trustManagerCheckClientTrustedWithoutParameterTest() throws Exception { + exceptionRule.expect(CertificateException.class); + exceptionRule.expectMessage( + "Not enough information to validate peer. SSLEngine or Socket required."); + AdvancedTlsX509TrustManager tm = AdvancedTlsX509TrustManager.newBuilder() + .setVerification(Verification.InsecurelySkipAllVerification).build(); + tm.checkClientTrusted(serverCert0, "RSA"); + } + + @Test + public void trustManagerCheckServerTrustedWithoutParameterTest() throws Exception { + exceptionRule.expect(CertificateException.class); + exceptionRule.expectMessage( + "Not enough information to validate peer. SSLEngine or Socket required."); + AdvancedTlsX509TrustManager tm = AdvancedTlsX509TrustManager.newBuilder() + .setVerification(Verification.InsecurelySkipAllVerification).build(); + tm.checkServerTrusted(serverCert0, "RSA"); + } + + @Test + public void trustManagerEmptyChainTest() throws Exception { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage( + "Want certificate verification but got null or empty certificates"); + AdvancedTlsX509TrustManager tm = AdvancedTlsX509TrustManager.newBuilder() + .setVerification(Verification.CertificateOnlyVerification) + .build(); + tm.updateTrustCredentials(caCert); + tm.checkClientTrusted(null, "RSA", (SSLEngine) null); + } + + @Test + public void trustManagerBadCustomVerificationTest() throws Exception { + exceptionRule.expect(CertificateException.class); + exceptionRule.expectMessage("Bad Custom Verification"); + AdvancedTlsX509TrustManager tm = AdvancedTlsX509TrustManager.newBuilder() + .setVerification(Verification.CertificateOnlyVerification) + .setSslSocketAndEnginePeerVerifier( + new SslSocketAndEnginePeerVerifier() { + @Override + public void verifyPeerCertificate(X509Certificate[] peerCertChain, String authType, + Socket socket) throws CertificateException { + throw new CertificateException("Bad Custom Verification"); + } + + @Override + public void verifyPeerCertificate(X509Certificate[] peerCertChain, String authType, + SSLEngine engine) throws CertificateException { + throw new CertificateException("Bad Custom Verification"); + } + }).build(); + tm.updateTrustCredentials(caCert); + tm.checkClientTrusted(serverCert0, "RSA", new Socket()); + } + + private static class SimpleServiceImpl extends SimpleServiceGrpc.SimpleServiceImplBase { + @Override + public void unaryRpc(SimpleRequest req, StreamObserver respOb) { + respOb.onNext(SimpleResponse.getDefaultInstance()); + respOb.onCompleted(); + } + } +} From 8026ccde4bf193a34574c51682cec1c22575ec5b Mon Sep 17 00:00:00 2001 From: Eric Anderson Date: Wed, 18 Aug 2021 09:33:09 -0700 Subject: [PATCH 33/82] netty: Don't use old-style classpath for shadow plugin Seems it was introduced unnecessarily in dc74a31b. This also removes the jcenter reference which is a repository that no longer receives updates. --- netty/shaded/build.gradle | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/netty/shaded/build.gradle b/netty/shaded/build.gradle index 521256ea13d..9cb3de9a252 100644 --- a/netty/shaded/build.gradle +++ b/netty/shaded/build.gradle @@ -4,16 +4,6 @@ import org.gradle.api.file.FileTreeElement import shadow.org.apache.tools.zip.ZipOutputStream import shadow.org.apache.tools.zip.ZipEntry - -buildscript { - repositories { - jcenter() - } - dependencies { - classpath "com.github.jengelman.gradle.plugins:shadow:6.1.0" - } -} - plugins { id "java" id "maven-publish" From e32e177d5aad38ba1160adf677798a1298a8f941 Mon Sep 17 00:00:00 2001 From: Eric Anderson Date: Mon, 9 Aug 2021 19:14:31 -0700 Subject: [PATCH 34/82] xds: Avoid logging and throwing errors The FINE logging was just repeating the exceptions. But really, it is trivial to avoid exceptions in this case and that is beneficial because it will avoid an expensive error handling path in something that is trivial to trigger remotely. The WARNING may be a bit much if connections don't match the filter chains often in production, but it seems most likely a misconfiguration and not something that would be seen often. --- ...ilterChainMatchingProtocolNegotiators.java | 25 +++++++++---------- ...rChainMatchingProtocolNegotiatorsTest.java | 15 ++++------- 2 files changed, 17 insertions(+), 23 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/FilterChainMatchingProtocolNegotiators.java b/xds/src/main/java/io/grpc/xds/FilterChainMatchingProtocolNegotiators.java index 34211d79751..596510ef05f 100644 --- a/xds/src/main/java/io/grpc/xds/FilterChainMatchingProtocolNegotiators.java +++ b/xds/src/main/java/io/grpc/xds/FilterChainMatchingProtocolNegotiators.java @@ -38,6 +38,7 @@ import io.grpc.xds.FilterChainMatchingProtocolNegotiators.FilterChainMatchingHandler.FilterChainSelector; import io.grpc.xds.internal.Matchers.CidrMatcher; import io.grpc.xds.internal.sds.SslContextProviderSupplier; +import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; @@ -59,6 +60,7 @@ /** * Handles L4 filter chain match for the connection based on the xds configuration. * */ +@SuppressWarnings("FutureReturnValueIgnored") // Netty doesn't follow this pattern final class FilterChainMatchingProtocolNegotiators { private static final Logger log = Logger.getLogger( FilterChainMatchingProtocolNegotiators.class.getName()); @@ -89,14 +91,13 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc super.userEventTriggered(ctx, evt); return; } - SelectedConfig config; - try { - config = selector.select( - (InetSocketAddress) ctx.channel().localAddress(), - (InetSocketAddress) ctx.channel().remoteAddress()); - } catch (IllegalStateException ex) { - log.log(Level.FINE, "Did not find exactly one filter chain: " + ex.getMessage()); - ctx.fireExceptionCaught(ex); + SelectedConfig config = selector.select( + (InetSocketAddress) ctx.channel().localAddress(), + (InetSocketAddress) ctx.channel().remoteAddress()); + if (config == null) { + log.log(Level.WARNING, "Connection from {0} to {1} has no matching filter chain. Closing", + new Object[] {ctx.channel().remoteAddress(), ctx.channel().localAddress()}); + ctx.close().addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE); return; } ProtocolNegotiationEvent pne = (ProtocolNegotiationEvent) evt; @@ -149,9 +150,8 @@ SelectedConfig select(InetSocketAddress localAddr, InetSocketAddress remoteAddr) filterChains = filterOnSourcePort(filterChains, remoteAddr.getPort()); if (filterChains.size() > 1) { - log.log(Level.FINE, "Found more than one matching filter chains: {0}", filterChains); - throw new IllegalStateException("Found more than one matching filter chains."); - // TODO(chengyuanzhang): should we just return any matched one? + throw new IllegalStateException("Found more than one matching filter chains. This should " + + "not be possible as ClientXdsClient validated the chains for uniqueness."); } if (filterChains.size() == 1) { FilterChain selected = Iterables.getOnlyElement(filterChains); @@ -160,8 +160,7 @@ SelectedConfig select(InetSocketAddress localAddr, InetSocketAddress remoteAddr) if (defaultSslContextProviderSupplier != null) { return new SelectedConfig(defaultSslContextProviderSupplier); } - log.log(Level.FINE, "No matching filter chain found."); - throw new IllegalStateException("No matching filter chain found."); + return null; } // reject if filer-chain-match has non-empty application_protocols diff --git a/xds/src/test/java/io/grpc/xds/FilterChainMatchingProtocolNegotiatorsTest.java b/xds/src/test/java/io/grpc/xds/FilterChainMatchingProtocolNegotiatorsTest.java index c926acc9fcc..7ae21901d09 100644 --- a/xds/src/test/java/io/grpc/xds/FilterChainMatchingProtocolNegotiatorsTest.java +++ b/xds/src/test/java/io/grpc/xds/FilterChainMatchingProtocolNegotiatorsTest.java @@ -160,16 +160,10 @@ public void noFilterChainMatch_noDefaultSslContext() { channelHandlerCtx = pipeline.context(filterChainMatchingHandler); assertThat(channelHandlerCtx).isNotNull(); + assertThat(channel.closeFuture().isDone()).isFalse(); pipeline.fireUserEventTriggered(event); - channelHandlerCtx = pipeline.context(filterChainMatchingHandler); - assertThat(channelHandlerCtx).isNotNull(); - try { - channel.checkException(); - fail("exception expected!"); - } catch (Exception e) { - assertThat(e).isInstanceOf(IllegalStateException.class); - assertThat(e).hasMessageThat().contains("No matching filter chain found."); - } + channel.runPendingTasks(); + assertThat(channel.closeFuture().isDone()).isTrue(); } @Test @@ -943,7 +937,8 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { channel.checkException(); fail("expect exception!"); } catch (IllegalStateException ise) { - assertThat(ise).hasMessageThat().isEqualTo("Found more than one matching filter chains."); + assertThat(ise).hasMessageThat().isEqualTo("Found more than one matching filter chains. This " + + "should not be possible as ClientXdsClient validated the chains for uniqueness."); assertThat(sslSet.isDone()).isFalse(); channelHandlerCtx = pipeline.context(filterChainMatchingHandler); assertThat(channelHandlerCtx).isNotNull(); From 29172a96657f3875b900879d37f808014cd6b123 Mon Sep 17 00:00:00 2001 From: ZHANG Dapeng Date: Fri, 20 Aug 2021 11:02:03 -0700 Subject: [PATCH 35/82] interop-testing: fix misleading log message (#8426) `logger.log(Level.WARNING, "Rpc failed: {0}", t)` will just print a literal "Rpc failed: {0}" followed by exception details. --- .../main/java/io/grpc/testing/integration/XdsTestClient.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/XdsTestClient.java b/interop-testing/src/main/java/io/grpc/testing/integration/XdsTestClient.java index d48be9f5031..087152dca64 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/XdsTestClient.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/XdsTestClient.java @@ -369,7 +369,7 @@ public void onCompleted() { @Override public void onError(Throwable t) { if (printResponse) { - logger.log(Level.WARNING, "Rpc failed: {0}", t); + logger.log(Level.WARNING, "Rpc failed", t); } handleRpcError(requestId, config.rpcType, Status.fromThrowable(t), savedWatchers); From c54fcba2eede2160cab127a19fa9cdde0b7f1e02 Mon Sep 17 00:00:00 2001 From: Lidi Zheng Date: Fri, 20 Aug 2021 12:12:54 -0700 Subject: [PATCH 36/82] Extend the xds_url_map job's timeout to 90 minutes (#8429) As title. We recently had one flake caused by the Kokoro job timeout. --- buildscripts/kokoro/xds_url_map.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/buildscripts/kokoro/xds_url_map.cfg b/buildscripts/kokoro/xds_url_map.cfg index 36ff8398b0c..4b5be84f880 100644 --- a/buildscripts/kokoro/xds_url_map.cfg +++ b/buildscripts/kokoro/xds_url_map.cfg @@ -2,7 +2,7 @@ # Location of the continuous shell script in repository. build_file: "grpc-java/buildscripts/kokoro/xds_url_map.sh" -timeout_mins: 60 +timeout_mins: 90 action { define_artifacts { From e45aab085c45eed46c94cd705e2e53cdf69a3f29 Mon Sep 17 00:00:00 2001 From: Terry Wilson Date: Fri, 20 Aug 2021 14:42:01 -0700 Subject: [PATCH 37/82] core: Don't mark calls as cancelled if they are successfully completed. (#8408) The semantics around cancel vary slightly between ServerCall and CancellableContext - the context should always be cancelled regardless of the outcome of the call while the ServerCall should only be cancelled on a non-OK status. This fixes a bug where the ServerCall was always marked cancelled regardless of call status. Fixes #5882 --- core/src/main/java/io/grpc/internal/ServerCallImpl.java | 8 +++++++- .../test/java/io/grpc/internal/ServerCallImplTest.java | 4 ++++ core/src/test/java/io/grpc/internal/ServerImplTest.java | 6 ++++-- 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/core/src/main/java/io/grpc/internal/ServerCallImpl.java b/core/src/main/java/io/grpc/internal/ServerCallImpl.java index 6f123e76678..f82d87cade0 100644 --- a/core/src/main/java/io/grpc/internal/ServerCallImpl.java +++ b/core/src/main/java/io/grpc/internal/ServerCallImpl.java @@ -279,7 +279,11 @@ public ServerStreamListenerImpl( new Context.CancellationListener() { @Override public void cancelled(Context context) { - ServerStreamListenerImpl.this.call.cancelled = true; + // If the context has a cancellation cause then something exceptional happened + // and we should also mark the call as cancelled. + if (context.cancellationCause() != null) { + ServerStreamListenerImpl.this.call.cancelled = true; + } } }, MoreExecutors.directExecutor()); @@ -355,6 +359,8 @@ private void closedInternal(Status status) { } finally { // Cancel context after delivering RPC closure notification to allow the application to // clean up and update any state based on whether onComplete or onCancel was called. + // Note that in failure situations JumpToApplicationThreadServerStreamListener has already + // closed the context. In these situations this cancel() call will be a no-op. context.cancel(null); } } diff --git a/core/src/test/java/io/grpc/internal/ServerCallImplTest.java b/core/src/test/java/io/grpc/internal/ServerCallImplTest.java index 5130bc05aa7..ea49b94e8aa 100644 --- a/core/src/test/java/io/grpc/internal/ServerCallImplTest.java +++ b/core/src/test/java/io/grpc/internal/ServerCallImplTest.java @@ -18,6 +18,7 @@ import static com.google.common.base.Charsets.UTF_8; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -378,6 +379,9 @@ public void streamListener_closedOk() { verify(callListener).onComplete(); assertTrue(context.isCancelled()); assertNull(context.cancellationCause()); + // The call considers cancellation to be an exceptional situation so it should + // not be cancelled with an OK status. + assertFalse(call.isCancelled()); } @Test diff --git a/core/src/test/java/io/grpc/internal/ServerImplTest.java b/core/src/test/java/io/grpc/internal/ServerImplTest.java index 2a9dbd5a1fe..0f5c510f97c 100644 --- a/core/src/test/java/io/grpc/internal/ServerImplTest.java +++ b/core/src/test/java/io/grpc/internal/ServerImplTest.java @@ -1196,7 +1196,9 @@ public void testStreamClose_clientOkTriggersDelayedCancellation() throws Excepti context, contextCancelled, null); // For close status OK: - // isCancelled is expected to be true after all pending work is done + // The context isCancelled is expected to be true after all pending work is done, + // but for the call it should be false as it gets set cancelled only if the call + // fails with a non-OK status. assertFalse(callReference.get().isCancelled()); assertFalse(context.get().isCancelled()); streamListener.closed(Status.OK); @@ -1204,7 +1206,7 @@ public void testStreamClose_clientOkTriggersDelayedCancellation() throws Excepti assertFalse(context.get().isCancelled()); assertEquals(1, executor.runDueTasks()); - assertTrue(callReference.get().isCancelled()); + assertFalse(callReference.get().isCancelled()); assertTrue(context.get().isCancelled()); assertTrue(contextCancelled.get()); } From cae23393668bdabc3a0e438780ed1bc07e87c1dd Mon Sep 17 00:00:00 2001 From: ZHANG Dapeng Date: Tue, 24 Aug 2021 11:27:02 -0700 Subject: [PATCH 38/82] xds: fix RingHash LB null pointer issue (#8438) --- xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java b/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java index e91e76090ab..036f77f7cd1 100644 --- a/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java +++ b/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java @@ -188,8 +188,10 @@ private void handleClusterDiscovered() { if (root.result.lbPolicy() == LbPolicy.RING_HASH) { lbProvider = lbRegistry.getProvider("ring_hash"); lbConfig = new RingHashConfig(root.result.minRingSize(), root.result.maxRingSize()); - } else { + } + if (lbProvider == null) { lbProvider = lbRegistry.getProvider("round_robin"); + lbConfig = null; } ClusterResolverConfig config = new ClusterResolverConfig( Collections.unmodifiableList(instances), new PolicySelection(lbProvider, lbConfig)); From 6776fa7c8b22e4e44d6e0d77e1361259b8a62ffb Mon Sep 17 00:00:00 2001 From: ZHANG Dapeng Date: Tue, 24 Aug 2021 13:09:33 -0700 Subject: [PATCH 39/82] xds: enable ring hash by default (#8442) --- .../main/java/io/grpc/xds/RingHashLoadBalancerProvider.java | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/xds/src/main/java/io/grpc/xds/RingHashLoadBalancerProvider.java b/xds/src/main/java/io/grpc/xds/RingHashLoadBalancerProvider.java index af613b26078..102a18000a9 100644 --- a/xds/src/main/java/io/grpc/xds/RingHashLoadBalancerProvider.java +++ b/xds/src/main/java/io/grpc/xds/RingHashLoadBalancerProvider.java @@ -17,6 +17,7 @@ package io.grpc.xds; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Strings; import io.grpc.Internal; import io.grpc.LoadBalancer; import io.grpc.LoadBalancer.Helper; @@ -45,7 +46,8 @@ public final class RingHashLoadBalancerProvider extends LoadBalancerProvider { static final long MAX_RING_SIZE = 8 * 1024 * 1024L; private static final boolean enableRingHash = - Boolean.parseBoolean(System.getenv("GRPC_XDS_EXPERIMENTAL_ENABLE_RING_HASH")); + Strings.isNullOrEmpty(System.getenv("GRPC_XDS_EXPERIMENTAL_ENABLE_RING_HASH")) + || Boolean.parseBoolean(System.getenv("GRPC_XDS_EXPERIMENTAL_ENABLE_RING_HASH")); @Override public LoadBalancer newLoadBalancer(Helper helper) { From fddc6552b392ae921fc08cbd494606bb83657e1d Mon Sep 17 00:00:00 2001 From: yifeizhuang Date: Tue, 24 Aug 2021 14:58:14 -0700 Subject: [PATCH 40/82] upgrade cronet to 92.4515.131 (#8445) --- build.gradle | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/build.gradle b/build.gradle index 393b91a30bd..adae5b70d1b 100644 --- a/build.gradle +++ b/build.gradle @@ -146,8 +146,8 @@ subprojects { autovalue: "com.google.auto.value:auto-value:${autovalueVersion}", autovalue_annotation: "com.google.auto.value:auto-value-annotations:${autovalueVersion}", errorprone: "com.google.errorprone:error_prone_annotations:2.4.0", - cronet_api: 'org.chromium.net:cronet-api:76.3809.111', - cronet_embedded: 'org.chromium.net:cronet-embedded:66.3359.158', + cronet_api: 'org.chromium.net:cronet-api:92.4515.131', + cronet_embedded: 'org.chromium.net:cronet-embedded:92.4515.131', gson: "com.google.code.gson:gson:2.8.6", guava: "com.google.guava:guava:${guavaVersion}", javax_annotation: 'org.apache.tomcat:annotations-api:6.0.53', From 48219d902a159246cc870fbed54fda58692691d9 Mon Sep 17 00:00:00 2001 From: yifeizhuang Date: Tue, 24 Aug 2021 16:33:12 -0700 Subject: [PATCH 41/82] fix import warning (#8441) --- netty/src/test/java/io/grpc/netty/AdvancedTlsTest.java | 3 +++ .../xds/FilterChainMatchingProtocolNegotiatorsTest.java | 6 ++---- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/netty/src/test/java/io/grpc/netty/AdvancedTlsTest.java b/netty/src/test/java/io/grpc/netty/AdvancedTlsTest.java index 294fbcd4a9a..df76481a12d 100644 --- a/netty/src/test/java/io/grpc/netty/AdvancedTlsTest.java +++ b/netty/src/test/java/io/grpc/netty/AdvancedTlsTest.java @@ -59,7 +59,10 @@ import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +@RunWith(JUnit4.class) public class AdvancedTlsTest { public static final String SERVER_0_KEY_FILE = "server0.key"; public static final String SERVER_0_PEM_FILE = "server0.pem"; diff --git a/xds/src/test/java/io/grpc/xds/FilterChainMatchingProtocolNegotiatorsTest.java b/xds/src/test/java/io/grpc/xds/FilterChainMatchingProtocolNegotiatorsTest.java index 7ae21901d09..21224f73885 100644 --- a/xds/src/test/java/io/grpc/xds/FilterChainMatchingProtocolNegotiatorsTest.java +++ b/xds/src/test/java/io/grpc/xds/FilterChainMatchingProtocolNegotiatorsTest.java @@ -118,7 +118,6 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { } @Test - @SuppressWarnings("unchecked") public void nofilterChainMatch_defaultSslContext() throws Exception { final SettableFuture sslSet = SettableFuture.create(); ChannelHandler next = new ChannelInboundHandlerAdapter() { @@ -132,7 +131,7 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); SslContextProviderSupplier ssl = new SslContextProviderSupplier(createTls(), tlsContextManager); - FilterChainSelector selector = new FilterChainSelector(Collections.EMPTY_LIST, ssl); + FilterChainSelector selector = new FilterChainSelector(new ArrayList(), ssl); FilterChainMatchingHandler filterChainMatchingHandler = new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); setupChannel("172.168.1.1", "172.168.1.2", 80, filterChainMatchingHandler); @@ -151,9 +150,8 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { } @Test - @SuppressWarnings("unchecked") public void noFilterChainMatch_noDefaultSslContext() { - FilterChainSelector selector = new FilterChainSelector(Collections.EMPTY_LIST, null); + FilterChainSelector selector = new FilterChainSelector(new ArrayList(), null); FilterChainMatchingHandler filterChainMatchingHandler = new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); setupChannel("172.168.1.1", "172.168.2.2", 90, filterChainMatchingHandler); From 8a5694b7f8d13fa317ae51b4cd5a961ed46adda3 Mon Sep 17 00:00:00 2001 From: ZHANG Dapeng Date: Wed, 25 Aug 2021 10:58:42 -0700 Subject: [PATCH 42/82] Update README etc to reference 1.40.1 (#8448) --- README.md | 30 ++++++++++++------------ cronet/README.md | 2 +- documentation/android-channel-builder.md | 4 ++-- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/README.md b/README.md index 10e437bed1a..6611b0ef1af 100644 --- a/README.md +++ b/README.md @@ -31,8 +31,8 @@ For a guided tour, take a look at the [quick start guide](https://grpc.io/docs/languages/java/quickstart) or the more explanatory [gRPC basics](https://grpc.io/docs/languages/java/basics). -The [examples](https://github.com/grpc/grpc-java/tree/v1.40.0/examples) and the -[Android example](https://github.com/grpc/grpc-java/tree/v1.40.0/examples/android) +The [examples](https://github.com/grpc/grpc-java/tree/v1.40.1/examples) and the +[Android example](https://github.com/grpc/grpc-java/tree/v1.40.1/examples/android) are standalone projects that showcase the usage of gRPC. Download @@ -43,17 +43,17 @@ Download [the JARs][]. Or for Maven with non-Android, add to your `pom.xml`: io.grpc grpc-netty-shaded - 1.40.0 + 1.40.1 io.grpc grpc-protobuf - 1.40.0 + 1.40.1 io.grpc grpc-stub - 1.40.0 + 1.40.1 org.apache.tomcat @@ -65,23 +65,23 @@ Download [the JARs][]. Or for Maven with non-Android, add to your `pom.xml`: Or for Gradle with non-Android, add to your dependencies: ```gradle -implementation 'io.grpc:grpc-netty-shaded:1.40.0' -implementation 'io.grpc:grpc-protobuf:1.40.0' -implementation 'io.grpc:grpc-stub:1.40.0' +implementation 'io.grpc:grpc-netty-shaded:1.40.1' +implementation 'io.grpc:grpc-protobuf:1.40.1' +implementation 'io.grpc:grpc-stub:1.40.1' compileOnly 'org.apache.tomcat:annotations-api:6.0.53' // necessary for Java 9+ ``` For Android client, use `grpc-okhttp` instead of `grpc-netty-shaded` and `grpc-protobuf-lite` instead of `grpc-protobuf`: ```gradle -implementation 'io.grpc:grpc-okhttp:1.40.0' -implementation 'io.grpc:grpc-protobuf-lite:1.40.0' -implementation 'io.grpc:grpc-stub:1.40.0' +implementation 'io.grpc:grpc-okhttp:1.40.1' +implementation 'io.grpc:grpc-protobuf-lite:1.40.1' +implementation 'io.grpc:grpc-stub:1.40.1' compileOnly 'org.apache.tomcat:annotations-api:6.0.53' // necessary for Java 9+ ``` [the JARs]: -https://search.maven.org/search?q=g:io.grpc%20AND%20v:1.40.0 +https://search.maven.org/search?q=g:io.grpc%20AND%20v:1.40.1 Development snapshots are available in [Sonatypes's snapshot repository](https://oss.sonatype.org/content/repositories/snapshots/). @@ -113,7 +113,7 @@ For protobuf-based codegen integrated with the Maven build system, you can use com.google.protobuf:protoc:3.17.3:exe:${os.detected.classifier} grpc-java - io.grpc:protoc-gen-grpc-java:1.40.0:exe:${os.detected.classifier} + io.grpc:protoc-gen-grpc-java:1.40.1:exe:${os.detected.classifier} @@ -143,7 +143,7 @@ protobuf { } plugins { grpc { - artifact = 'io.grpc:protoc-gen-grpc-java:1.40.0' + artifact = 'io.grpc:protoc-gen-grpc-java:1.40.1' } } generateProtoTasks { @@ -176,7 +176,7 @@ protobuf { } plugins { grpc { - artifact = 'io.grpc:protoc-gen-grpc-java:1.40.0' + artifact = 'io.grpc:protoc-gen-grpc-java:1.40.1' } } generateProtoTasks { diff --git a/cronet/README.md b/cronet/README.md index 580c1e4acbe..c982604bdac 100644 --- a/cronet/README.md +++ b/cronet/README.md @@ -26,7 +26,7 @@ In your app module's `build.gradle` file, include a dependency on both `grpc-cro Google Play Services Client Library for Cronet ``` -implementation 'io.grpc:grpc-cronet:1.40.0' +implementation 'io.grpc:grpc-cronet:1.40.1' implementation 'com.google.android.gms:play-services-cronet:16.0.0' ``` diff --git a/documentation/android-channel-builder.md b/documentation/android-channel-builder.md index fc491af8498..113b20159b9 100644 --- a/documentation/android-channel-builder.md +++ b/documentation/android-channel-builder.md @@ -36,8 +36,8 @@ In your `build.gradle` file, include a dependency on both `grpc-android` and `grpc-okhttp`: ``` -implementation 'io.grpc:grpc-android:1.40.0' -implementation 'io.grpc:grpc-okhttp:1.40.0' +implementation 'io.grpc:grpc-android:1.40.1' +implementation 'io.grpc:grpc-okhttp:1.40.1' ``` You also need permission to access the device's network state in your From 3cb0696b1fba1b4eece4024a96e5ed88ecbe2517 Mon Sep 17 00:00:00 2001 From: ZhenLian Date: Wed, 25 Aug 2021 16:13:09 -0700 Subject: [PATCH 43/82] advancedtls: change enum to use UPPER_SNAKE_CASE (#8446) --- .../util/AdvancedTlsX509TrustManager.java | 12 ++++----- .../java/io/grpc/netty/AdvancedTlsTest.java | 26 +++++++++---------- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/core/src/main/java/io/grpc/util/AdvancedTlsX509TrustManager.java b/core/src/main/java/io/grpc/util/AdvancedTlsX509TrustManager.java index ea8e74b1a9e..f6e366d3219 100644 --- a/core/src/main/java/io/grpc/util/AdvancedTlsX509TrustManager.java +++ b/core/src/main/java/io/grpc/util/AdvancedTlsX509TrustManager.java @@ -171,13 +171,13 @@ private void checkTrusted(X509Certificate[] chain, String authType, SSLEngine ss throw new CertificateException( "Not enough information to validate peer. SSLEngine or Socket required."); } - if (this.verification != Verification.InsecurelySkipAllVerification) { + if (this.verification != Verification.INSECURELY_SKIP_ALL_VERIFICATION) { X509ExtendedTrustManager currentDelegateManager = this.delegateManager; if (currentDelegateManager == null) { throw new CertificateException("No trust roots configured"); } if (checkingServer) { - String algorithm = this.verification == Verification.CertificateAndHostNameVerification + String algorithm = this.verification == Verification.CERTIFICATE_AND_HOST_NAME_VERIFICATION ? "HTTPS" : ""; if (sslEngine != null) { SSLParameters sslParams = sslEngine.getSSLParameters(); @@ -288,7 +288,7 @@ public enum Verification { // This is the DEFAULT and RECOMMENDED mode for most applications. // Setting this on the client side will do the certificate and hostname verification, while // setting this on the server side will only do the certificate verification. - CertificateAndHostNameVerification, + CERTIFICATE_AND_HOST_NAME_VERIFICATION, // This SHOULD be chosen only when you know what the implication this will bring, and have a // basic understanding about TLS. // It SHOULD be accompanied with proper additional peer identity checks set through @@ -298,13 +298,13 @@ public enum Verification { // checkClientTrusted/checkServerTrusted with the {@code SSLEngine} parameter while doing // verification. // Setting this on either side will only do the certificate verification. - CertificateOnlyVerification, + CERTIFICATE_ONLY_VERIFICATION, // Setting is very DANGEROUS. Please try to avoid this in a real production environment, unless // you are a super advanced user intended to re-implement the whole verification logic on your // own. A secure verification might include: // 1. proper verification on the peer certificate chain // 2. proper checks on the identity of the peer certificate - InsecurelySkipAllVerification, + INSECURELY_SKIP_ALL_VERIFICATION, } // Additional custom peer verification check. @@ -338,7 +338,7 @@ void verifyPeerCertificate(X509Certificate[] peerCertChain, String authType, SSL public static final class Builder { - private Verification verification = Verification.CertificateAndHostNameVerification; + private Verification verification = Verification.CERTIFICATE_AND_HOST_NAME_VERIFICATION; private SslSocketAndEnginePeerVerifier socketAndEnginePeerVerifier; private Builder() {} diff --git a/netty/src/test/java/io/grpc/netty/AdvancedTlsTest.java b/netty/src/test/java/io/grpc/netty/AdvancedTlsTest.java index df76481a12d..7dd5ec75e54 100644 --- a/netty/src/test/java/io/grpc/netty/AdvancedTlsTest.java +++ b/netty/src/test/java/io/grpc/netty/AdvancedTlsTest.java @@ -161,7 +161,7 @@ public void advancedTlsKeyManagerTrustManagerMutualTlsTest() throws Exception { AdvancedTlsX509KeyManager serverKeyManager = new AdvancedTlsX509KeyManager(); serverKeyManager.updateIdentityCredentials(serverKey0, serverCert0); AdvancedTlsX509TrustManager serverTrustManager = AdvancedTlsX509TrustManager.newBuilder() - .setVerification(Verification.CertificateOnlyVerification) + .setVerification(Verification.CERTIFICATE_ONLY_VERIFICATION) .build(); serverTrustManager.updateTrustCredentials(caCert); ServerCredentials serverCredentials = TlsServerCredentials.newBuilder() @@ -174,7 +174,7 @@ public void advancedTlsKeyManagerTrustManagerMutualTlsTest() throws Exception { AdvancedTlsX509KeyManager clientKeyManager = new AdvancedTlsX509KeyManager(); clientKeyManager.updateIdentityCredentials(clientKey0, clientCert0); AdvancedTlsX509TrustManager clientTrustManager = AdvancedTlsX509TrustManager.newBuilder() - .setVerification(Verification.CertificateAndHostNameVerification) + .setVerification(Verification.CERTIFICATE_AND_HOST_NAME_VERIFICATION) .build(); clientTrustManager.updateTrustCredentials(caCert); ChannelCredentials channelCredentials = TlsChannelCredentials.newBuilder() @@ -198,7 +198,7 @@ public void trustManagerCustomVerifierMutualTlsTest() throws Exception { serverKeyManager.updateIdentityCredentials(serverKey0, serverCert0); // Set server's custom verification based on the information of clientCert0. AdvancedTlsX509TrustManager serverTrustManager = AdvancedTlsX509TrustManager.newBuilder() - .setVerification(Verification.CertificateOnlyVerification) + .setVerification(Verification.CERTIFICATE_ONLY_VERIFICATION) .setSslSocketAndEnginePeerVerifier( new SslSocketAndEnginePeerVerifier() { @Override @@ -238,7 +238,7 @@ public void verifyPeerCertificate(X509Certificate[] peerCertChain, String authTy clientKeyManager.updateIdentityCredentials(clientKey0, clientCert0); // Set client's custom verification based on the information of serverCert0. AdvancedTlsX509TrustManager clientTrustManager = AdvancedTlsX509TrustManager.newBuilder() - .setVerification(Verification.CertificateOnlyVerification) + .setVerification(Verification.CERTIFICATE_ONLY_VERIFICATION) .setSslSocketAndEnginePeerVerifier( new SslSocketAndEnginePeerVerifier() { @Override @@ -289,7 +289,7 @@ public void trustManagerInsecurelySkipAllTest() throws Exception { // will configure the client to skip all checks later. serverKeyManager.updateIdentityCredentials(serverKeyBad, serverCertBad); AdvancedTlsX509TrustManager serverTrustManager = AdvancedTlsX509TrustManager.newBuilder() - .setVerification(Verification.CertificateOnlyVerification) + .setVerification(Verification.CERTIFICATE_ONLY_VERIFICATION) .setSslSocketAndEnginePeerVerifier( new SslSocketAndEnginePeerVerifier() { @Override @@ -315,7 +315,7 @@ public void verifyPeerCertificate(X509Certificate[] peerCertChain, String authTy // Note this is very dangerous in production environment - only do so if you are confident on // what you are doing! AdvancedTlsX509TrustManager clientTrustManager = AdvancedTlsX509TrustManager.newBuilder() - .setVerification(Verification.InsecurelySkipAllVerification) + .setVerification(Verification.INSECURELY_SKIP_ALL_VERIFICATION) .setSslSocketAndEnginePeerVerifier( new SslSocketAndEnginePeerVerifier() { @Override @@ -350,7 +350,7 @@ public void onFileReloadingKeyManagerTrustManagerTest() throws Exception { Closeable serverKeyShutdown = serverKeyManager.updateIdentityCredentialsFromFile(serverKey0File, serverCert0File, 100, TimeUnit.MILLISECONDS, executor); AdvancedTlsX509TrustManager serverTrustManager = AdvancedTlsX509TrustManager.newBuilder() - .setVerification(Verification.CertificateOnlyVerification) + .setVerification(Verification.CERTIFICATE_ONLY_VERIFICATION) .build(); Closeable serverTrustShutdown = serverTrustManager.updateTrustCredentialsFromFile(caCertFile, 100, TimeUnit.MILLISECONDS, executor); @@ -365,7 +365,7 @@ public void onFileReloadingKeyManagerTrustManagerTest() throws Exception { Closeable clientKeyShutdown = clientKeyManager.updateIdentityCredentialsFromFile(clientKey0File, clientCert0File,100, TimeUnit.MILLISECONDS, executor); AdvancedTlsX509TrustManager clientTrustManager = AdvancedTlsX509TrustManager.newBuilder() - .setVerification(Verification.CertificateAndHostNameVerification) + .setVerification(Verification.CERTIFICATE_AND_HOST_NAME_VERIFICATION) .build(); Closeable clientTrustShutdown = clientTrustManager.updateTrustCredentialsFromFile(caCertFile, 100, TimeUnit.MILLISECONDS, executor); @@ -407,7 +407,7 @@ public void keyManagerAliasesTest() throws Exception { @Test public void trustManagerCheckTrustedWithSocketTest() throws Exception { AdvancedTlsX509TrustManager tm = AdvancedTlsX509TrustManager.newBuilder() - .setVerification(Verification.InsecurelySkipAllVerification).build(); + .setVerification(Verification.INSECURELY_SKIP_ALL_VERIFICATION).build(); tm.updateTrustCredentials(caCert); tm.checkClientTrusted(serverCert0, "RSA", new Socket()); tm.useSystemDefaultTrustCerts(); @@ -420,7 +420,7 @@ public void trustManagerCheckClientTrustedWithoutParameterTest() throws Exceptio exceptionRule.expectMessage( "Not enough information to validate peer. SSLEngine or Socket required."); AdvancedTlsX509TrustManager tm = AdvancedTlsX509TrustManager.newBuilder() - .setVerification(Verification.InsecurelySkipAllVerification).build(); + .setVerification(Verification.INSECURELY_SKIP_ALL_VERIFICATION).build(); tm.checkClientTrusted(serverCert0, "RSA"); } @@ -430,7 +430,7 @@ public void trustManagerCheckServerTrustedWithoutParameterTest() throws Exceptio exceptionRule.expectMessage( "Not enough information to validate peer. SSLEngine or Socket required."); AdvancedTlsX509TrustManager tm = AdvancedTlsX509TrustManager.newBuilder() - .setVerification(Verification.InsecurelySkipAllVerification).build(); + .setVerification(Verification.INSECURELY_SKIP_ALL_VERIFICATION).build(); tm.checkServerTrusted(serverCert0, "RSA"); } @@ -440,7 +440,7 @@ public void trustManagerEmptyChainTest() throws Exception { exceptionRule.expectMessage( "Want certificate verification but got null or empty certificates"); AdvancedTlsX509TrustManager tm = AdvancedTlsX509TrustManager.newBuilder() - .setVerification(Verification.CertificateOnlyVerification) + .setVerification(Verification.CERTIFICATE_ONLY_VERIFICATION) .build(); tm.updateTrustCredentials(caCert); tm.checkClientTrusted(null, "RSA", (SSLEngine) null); @@ -451,7 +451,7 @@ public void trustManagerBadCustomVerificationTest() throws Exception { exceptionRule.expect(CertificateException.class); exceptionRule.expectMessage("Bad Custom Verification"); AdvancedTlsX509TrustManager tm = AdvancedTlsX509TrustManager.newBuilder() - .setVerification(Verification.CertificateOnlyVerification) + .setVerification(Verification.CERTIFICATE_ONLY_VERIFICATION) .setSslSocketAndEnginePeerVerifier( new SslSocketAndEnginePeerVerifier() { @Override From f1b699bbf1684a6c30e8d0da99e7e93a2dc784fe Mon Sep 17 00:00:00 2001 From: Alexander Polcyn Date: Wed, 25 Aug 2021 12:57:28 -0700 Subject: [PATCH 44/82] Update default XDS server name in C2P resolver --- .../main/java/io/grpc/xds/GoogleCloudToProdNameResolver.java | 2 +- .../java/io/grpc/xds/GoogleCloudToProdNameResolverTest.java | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/GoogleCloudToProdNameResolver.java b/xds/src/main/java/io/grpc/xds/GoogleCloudToProdNameResolver.java index 55ec809772f..9abdf12f175 100644 --- a/xds/src/main/java/io/grpc/xds/GoogleCloudToProdNameResolver.java +++ b/xds/src/main/java/io/grpc/xds/GoogleCloudToProdNameResolver.java @@ -184,7 +184,7 @@ public void run() { ImmutableMap.of("TRAFFICDIRECTOR_DIRECTPATH_C2P_IPV6_CAPABLE", true)); } ImmutableMap.Builder serverBuilder = ImmutableMap.builder(); - String server_uri = "directpath-trafficdirector.googleapis.com"; + String server_uri = "directpath-pa.googleapis.com"; if (serverUriOverride != null && serverUriOverride.length() > 0) { server_uri = serverUriOverride; } diff --git a/xds/src/test/java/io/grpc/xds/GoogleCloudToProdNameResolverTest.java b/xds/src/test/java/io/grpc/xds/GoogleCloudToProdNameResolverTest.java index 07e957b24c4..421b2a1dd0a 100644 --- a/xds/src/test/java/io/grpc/xds/GoogleCloudToProdNameResolverTest.java +++ b/xds/src/test/java/io/grpc/xds/GoogleCloudToProdNameResolverTest.java @@ -189,7 +189,7 @@ public void onGcpAndNoProvidedBootstrapDelegateToXds() { Map server = Iterables.getOnlyElement( (List>) bootstrap.get("xds_servers")); assertThat(server).containsExactly( - "server_uri", "directpath-trafficdirector.googleapis.com", + "server_uri", "directpath-pa.googleapis.com", "channel_creds", ImmutableList.of(ImmutableMap.of("type", "google_default")), "server_features", ImmutableList.of("xds_v3")); } From df4ac5973c6d382175feea609d684c26f467bc2f Mon Sep 17 00:00:00 2001 From: Terry Wilson Date: Thu, 26 Aug 2021 14:42:27 -0700 Subject: [PATCH 45/82] core: Exit idle mode in enterIdle() if there are pending calls or delayed transport. This change assures that if there are only calls in real transport the channel will remain in idle mode. Idle mode will be exited if there are calls in delayed transport to allow them to be processed. --- .../grpc/internal/InUseStateAggregator.java | 15 ++++ .../io/grpc/internal/ManagedChannelImpl.java | 5 +- .../internal/InUseStateAggregatorTest.java | 63 ++++++++++++++++ .../ManagedChannelImplIdlenessTest.java | 75 +++++++++++++++++++ 4 files changed, 157 insertions(+), 1 deletion(-) create mode 100644 core/src/test/java/io/grpc/internal/InUseStateAggregatorTest.java diff --git a/core/src/main/java/io/grpc/internal/InUseStateAggregator.java b/core/src/main/java/io/grpc/internal/InUseStateAggregator.java index f4f3a186d88..f3d870e8797 100644 --- a/core/src/main/java/io/grpc/internal/InUseStateAggregator.java +++ b/core/src/main/java/io/grpc/internal/InUseStateAggregator.java @@ -53,6 +53,21 @@ public final boolean isInUse() { return !inUseObjects.isEmpty(); } + /** + * Returns {@code true} if any of the given objects are in use. + * + * @param objects The objects to consider. + * @return {@code true} if any of the given objects are in use. + */ + public final boolean anyObjectInUse(Object... objects) { + for (Object object : objects) { + if (inUseObjects.contains(object)) { + return true; + } + } + return false; + } + /** * Called when the aggregated in-use state has changed to true, which means at least one object is * in use. diff --git a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java index 6cd5598e2a6..2e079078fc7 100644 --- a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java +++ b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java @@ -423,7 +423,10 @@ private void enterIdleMode() { delayedTransport.reprocess(null); channelLogger.log(ChannelLogLevel.INFO, "Entering IDLE state"); channelStateManager.gotoState(IDLE); - if (inUseStateAggregator.isInUse()) { + // If the inUseStateAggregator still considers pending calls to be queued up or the delayed + // transport to be holding some we need to exit idle mode to give these calls a chance to + // be processed. + if (inUseStateAggregator.anyObjectInUse(pendingCallsInUseObject, delayedTransport)) { exitIdleMode(); } } diff --git a/core/src/test/java/io/grpc/internal/InUseStateAggregatorTest.java b/core/src/test/java/io/grpc/internal/InUseStateAggregatorTest.java new file mode 100644 index 00000000000..e1bbc063ea3 --- /dev/null +++ b/core/src/test/java/io/grpc/internal/InUseStateAggregatorTest.java @@ -0,0 +1,63 @@ +/* + * Copyright 2021 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +import static org.junit.Assert.assertTrue; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Unit tests for {@link InUseStateAggregator}. + */ +@RunWith(JUnit4.class) +public class InUseStateAggregatorTest { + + private InUseStateAggregator aggregator; + + @Before + public void setUp() { + aggregator = new InUseStateAggregator() { + @Override + protected void handleInUse() { + } + + @Override + protected void handleNotInUse() { + } + }; + } + + @Test + public void anyObjectInUse() { + String objectOne = "1"; + String objectTwo = "2"; + String objectThree = "3"; + + aggregator.updateObjectInUse(objectOne, true); + assertTrue(aggregator.anyObjectInUse(objectOne)); + + aggregator.updateObjectInUse(objectTwo, true); + aggregator.updateObjectInUse(objectThree, true); + assertTrue(aggregator.anyObjectInUse(objectOne, objectTwo, objectThree)); + + aggregator.updateObjectInUse(objectTwo, false); + assertTrue(aggregator.anyObjectInUse(objectOne, objectThree)); + } +} diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplIdlenessTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplIdlenessTest.java index 6a76f75c8b7..30e137cba22 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelImplIdlenessTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplIdlenessTest.java @@ -26,7 +26,9 @@ import static org.mockito.AdditionalAnswers.delegatesTo; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isA; import static org.mockito.ArgumentMatchers.same; +import static org.mockito.Mockito.atMostOnce; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; @@ -284,6 +286,35 @@ public void delayedTransportHoldsOffIdleness() throws Exception { verify(mockLoadBalancer).shutdown(); } + @Test + public void pendingCallExitsIdleAfterEnter() throws Exception { + // Create a pending call without starting it. + channel.newCall(method, CallOptions.DEFAULT); + + channel.enterIdle(); + + // Just the existence of a non-started, pending call means the channel cannot stay + // in idle mode because the expectation is that the pending call will also need to + // be handled. + verify(mockNameResolver, times(2)).start(any(NameResolver.Listener2.class)); + } + + @Test + public void delayedTransportExitsIdleAfterEnter() throws Exception { + // Start a new call that will go to the delayed transport + ClientCall call = channel.newCall(method, CallOptions.DEFAULT); + call.start(mockCallListener, new Metadata()); + deliverResolutionResult(); + + channel.enterIdle(); + + // Since we have a call in delayed transport, the call to enterIdle() should have resulted in + // the channel going to idle mode and then immediately exiting. We confirm this by verifying + // that the name resolver was started up twice - once when the call was first created and a + // second time after exiting idle mode. + verify(mockNameResolver, times(2)).start(any(NameResolver.Listener2.class)); + } + @Test public void realTransportsHoldsOffIdleness() throws Exception { final EquivalentAddressGroup addressGroup = servers.get(1); @@ -332,6 +363,50 @@ public void realTransportsHoldsOffIdleness() throws Exception { verify(mockLoadBalancer).shutdown(); } + @Test + public void enterIdleWhileRealTransportInProgress() { + final EquivalentAddressGroup addressGroup = servers.get(1); + + // Start a call, which goes to delayed transport + ClientCall call = channel.newCall(method, CallOptions.DEFAULT); + call.start(mockCallListener, new Metadata()); + + // Verify that we have exited the idle mode + ArgumentCaptor helperCaptor = ArgumentCaptor.forClass(null); + verify(mockLoadBalancerProvider).newLoadBalancer(helperCaptor.capture()); + deliverResolutionResult(); + Helper helper = helperCaptor.getValue(); + + // Create a subchannel for the real transport to happen on. + Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); + requestConnectionSafely(helper, subchannel); + MockClientTransportInfo t0 = newTransports.poll(); + t0.listener.transportReady(); + + SubchannelPicker mockPicker = mock(SubchannelPicker.class); + when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))) + .thenReturn(PickResult.withSubchannel(subchannel)); + updateBalancingStateSafely(helper, READY, mockPicker); + + // Delayed transport creates real streams in the app executor + executor.runDueTasks(); + + // Move transport to the in-use state + t0.listener.transportInUse(true); + + // Now we enter Idle mode while real transport is happening + channel.enterIdle(); + + // Verify that the name resolver and the load balance were shut down. + verify(mockNameResolver).shutdown(); + verify(mockLoadBalancer).shutdown(); + + // When there are no pending streams, the call to enterIdle() should stick and + // we remain in idle mode. We verify this by making sure that the name resolver + // was not started up more than once (the initial startup). + verify(mockNameResolver, atMostOnce()).start(isA(NameResolver.Listener2.class)); + } + @Test public void updateSubchannelAddresses_newAddressConnects() { ClientCall call = channel.newCall(method, CallOptions.DEFAULT); From 46d47d52d9744642b8970873b277f91e69155593 Mon Sep 17 00:00:00 2001 From: Kurt Alfred Kluever Date: Fri, 27 Aug 2021 14:24:27 -0400 Subject: [PATCH 46/82] Update error-prone to the latest release (2.9.0) (#8456) required as a prerequisite to using `@InlineMe.` --- build.gradle | 2 +- repositories.bzl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/build.gradle b/build.gradle index adae5b70d1b..6c099e0cf39 100644 --- a/build.gradle +++ b/build.gradle @@ -145,7 +145,7 @@ subprojects { animalsniffer_annotations: "org.codehaus.mojo:animal-sniffer-annotations:1.19", autovalue: "com.google.auto.value:auto-value:${autovalueVersion}", autovalue_annotation: "com.google.auto.value:auto-value-annotations:${autovalueVersion}", - errorprone: "com.google.errorprone:error_prone_annotations:2.4.0", + errorprone: "com.google.errorprone:error_prone_annotations:2.9.0", cronet_api: 'org.chromium.net:cronet-api:92.4515.131', cronet_embedded: 'org.chromium.net:cronet-embedded:92.4515.131', gson: "com.google.code.gson:gson:2.8.6", diff --git a/repositories.bzl b/repositories.bzl index ad50272d286..6cf75aa7bb6 100644 --- a/repositories.bzl +++ b/repositories.bzl @@ -16,7 +16,7 @@ IO_GRPC_GRPC_JAVA_ARTIFACTS = [ "com.google.auth:google-auth-library-oauth2-http:0.22.0", "com.google.code.findbugs:jsr305:3.0.2", "com.google.code.gson:gson:jar:2.8.6", - "com.google.errorprone:error_prone_annotations:2.4.0", + "com.google.errorprone:error_prone_annotations:2.9.0", "com.google.guava:failureaccess:1.0.1", "com.google.guava:guava:30.1-android", "com.google.j2objc:j2objc-annotations:1.3", From 0f6380b470a8bf864b71d2ac419ae15b631a4a80 Mon Sep 17 00:00:00 2001 From: yifeizhuang Date: Fri, 27 Aug 2021 13:30:47 -0700 Subject: [PATCH 47/82] xds: server side xDS routing and config application (#8318) Added routing config discovery for HCM in LdsUpdate in XdsServerWrapper. This can be LDS inline or through RDS. Deal with inflight SslContextProviderSupplier resource handling. Discovered routing config is updated to FilterChainSelectorRef. Added routing config data field in FilterChainSelector. Filter chain matching would resulting in setting a new attribute key for server routing config. Filter chain matching logics mostly not changed. Installed ConfigApplyingInterceptor in XdsServerWrapper's delegateBuilder. It fetches server routing config attribute set above. It does routing match and creates server interceptors for the http filters as a result. --- ...ilterChainMatchingProtocolNegotiators.java | 50 +- .../main/java/io/grpc/xds/RoutingUtils.java | 219 ++++++ .../java/io/grpc/xds/XdsServerBuilder.java | 3 +- .../java/io/grpc/xds/XdsServerWrapper.java | 382 +++++++++- ...rChainMatchingProtocolNegotiatorsTest.java | 493 ++++++------- .../XdsClientWrapperForServerSdsTestMisc.java | 2 +- .../io/grpc/xds/XdsSdsClientServerTest.java | 19 +- .../io/grpc/xds/XdsServerBuilderTest.java | 2 +- .../java/io/grpc/xds/XdsServerTestHelper.java | 2 +- .../io/grpc/xds/XdsServerWrapperTest.java | 692 ++++++++++++++++-- 10 files changed, 1440 insertions(+), 424 deletions(-) create mode 100644 xds/src/main/java/io/grpc/xds/RoutingUtils.java diff --git a/xds/src/main/java/io/grpc/xds/FilterChainMatchingProtocolNegotiators.java b/xds/src/main/java/io/grpc/xds/FilterChainMatchingProtocolNegotiators.java index 596510ef05f..0c8780fe744 100644 --- a/xds/src/main/java/io/grpc/xds/FilterChainMatchingProtocolNegotiators.java +++ b/xds/src/main/java/io/grpc/xds/FilterChainMatchingProtocolNegotiators.java @@ -18,6 +18,7 @@ import static com.google.common.base.Preconditions.checkNotNull; import static io.grpc.xds.InternalXdsAttributes.ATTR_FILTER_CHAIN_SELECTOR_REF; +import static io.grpc.xds.XdsServerWrapper.ATTR_SERVER_ROUTING_CONFIG; import static io.grpc.xds.internal.sds.SdsProtocolNegotiators.ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER; import com.google.common.annotations.VisibleForTesting; @@ -36,6 +37,7 @@ import io.grpc.xds.EnvoyServerProtoData.FilterChain; import io.grpc.xds.EnvoyServerProtoData.FilterChainMatch; import io.grpc.xds.FilterChainMatchingProtocolNegotiators.FilterChainMatchingHandler.FilterChainSelector; +import io.grpc.xds.XdsServerWrapper.ServerRoutingConfig; import io.grpc.xds.internal.Matchers.CidrMatcher; import io.grpc.xds.internal.sds.SslContextProviderSupplier; import io.netty.channel.ChannelFutureListener; @@ -50,6 +52,7 @@ import java.util.Collection; import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.concurrent.Executor; import java.util.concurrent.atomic.AtomicReference; import java.util.logging.Level; @@ -101,9 +104,11 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc return; } ProtocolNegotiationEvent pne = (ProtocolNegotiationEvent) evt; - Attributes attr = InternalProtocolNegotiationEvent.getAttributes(pne) - .toBuilder().set(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER, - config.sslContextProviderSupplier).build(); + // TODO(zivy): merge into one key and take care of this outer class visibility. + Attributes attr = InternalProtocolNegotiationEvent.getAttributes(pne).toBuilder() + .set(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER, config.sslContextProviderSupplier) + .set(ATTR_SERVER_ROUTING_CONFIG, config.routingConfig) + .build(); pne = InternalProtocolNegotiationEvent.withAttributes(pne, attr); ctx.pipeline().replace(this, null, delegate.newHandler(grpcHandler)); ctx.fireUserEventTriggered(pne); @@ -111,22 +116,29 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc static final class FilterChainSelector { public static final FilterChainSelector NO_FILTER_CHAIN = new FilterChainSelector( - Collections.emptyList(), null); - - private final List filterChainList; + Collections.emptyMap(), null, null); + private final Map routingConfigs; @Nullable private final SslContextProviderSupplier defaultSslContextProviderSupplier; + @Nullable + private final ServerRoutingConfig defaultRoutingConfig; - FilterChainSelector(List filterChainList, - @Nullable SslContextProviderSupplier defaultSslContextProviderSupplier) { - checkNotNull(filterChainList, "filterChainList"); - this.filterChainList = filterChainList; + FilterChainSelector(Map routingConfigs, + @Nullable SslContextProviderSupplier defaultSslContextProviderSupplier, + @Nullable ServerRoutingConfig defaultRoutingConfig) { + this.routingConfigs = checkNotNull(routingConfigs, "routingConfigs"); this.defaultSslContextProviderSupplier = defaultSslContextProviderSupplier; + this.defaultRoutingConfig = defaultRoutingConfig; + } + + @VisibleForTesting + Map getRoutingConfigs() { + return routingConfigs; } @VisibleForTesting - List getFilterChains() { - return filterChainList; + ServerRoutingConfig getDefaultRoutingConfig() { + return defaultRoutingConfig; } @VisibleForTesting @@ -138,7 +150,7 @@ SslContextProviderSupplier getDefaultSslContextProviderSupplier() { * Throws IllegalStateException when no exact one match, and we should close the connection. */ SelectedConfig select(InetSocketAddress localAddr, InetSocketAddress remoteAddr) { - Collection filterChains = new ArrayList<>(filterChainList); + Collection filterChains = routingConfigs.keySet(); filterChains = filterOnDestinationPort(filterChains); filterChains = filterOnIpAddress(filterChains, localAddr.getAddress(), true); filterChains = filterOnServerNames(filterChains); @@ -155,10 +167,11 @@ SelectedConfig select(InetSocketAddress localAddr, InetSocketAddress remoteAddr) } if (filterChains.size() == 1) { FilterChain selected = Iterables.getOnlyElement(filterChains); - return new SelectedConfig(selected.getSslContextProviderSupplier()); + return new SelectedConfig( + routingConfigs.get(selected), selected.getSslContextProviderSupplier()); } - if (defaultSslContextProviderSupplier != null) { - return new SelectedConfig(defaultSslContextProviderSupplier); + if (defaultRoutingConfig != null) { + return new SelectedConfig(defaultRoutingConfig, defaultSslContextProviderSupplier); } return null; } @@ -361,10 +374,13 @@ public void close() { * The FilterChain level configuration. */ private static final class SelectedConfig { + private final ServerRoutingConfig routingConfig; @Nullable private final SslContextProviderSupplier sslContextProviderSupplier; - private SelectedConfig(@Nullable SslContextProviderSupplier sslContextProviderSupplier) { + private SelectedConfig(ServerRoutingConfig routingConfig, + @Nullable SslContextProviderSupplier sslContextProviderSupplier) { + this.routingConfig = checkNotNull(routingConfig, "routingConfig"); this.sslContextProviderSupplier = sslContextProviderSupplier; } } diff --git a/xds/src/main/java/io/grpc/xds/RoutingUtils.java b/xds/src/main/java/io/grpc/xds/RoutingUtils.java new file mode 100644 index 00000000000..8bf879f43b0 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/RoutingUtils.java @@ -0,0 +1,219 @@ +/* + * Copyright 2021 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.base.Preconditions.checkArgument; + +import com.google.common.base.Joiner; +import io.grpc.Metadata; +import io.grpc.xds.VirtualHost.Route.RouteMatch; +import io.grpc.xds.VirtualHost.Route.RouteMatch.PathMatcher; +import io.grpc.xds.internal.Matchers.FractionMatcher; +import io.grpc.xds.internal.Matchers.HeaderMatcher; +import java.util.List; +import java.util.Locale; +import javax.annotation.Nullable; + +/** + * Utilities for performing virtual host domain name matching and route matching. + */ +// TODO(chengyuanzhang): clean up implementations in XdsNameResolver. +final class RoutingUtils { + // Prevent instantiation. + private RoutingUtils() { + } + + /** + * Returns the {@link VirtualHost} with the best match domain for the given hostname. + */ + @Nullable + static VirtualHost findVirtualHostForHostName(List virtualHosts, String hostName) { + // Domain search order: + // 1. Exact domain names: ``www.foo.com``. + // 2. Suffix domain wildcards: ``*.foo.com`` or ``*-bar.foo.com``. + // 3. Prefix domain wildcards: ``foo.*`` or ``foo-*``. + // 4. Special wildcard ``*`` matching any domain. + // + // The longest wildcards match first. + // Assuming only a single virtual host in the entire route configuration can match + // on ``*`` and a domain must be unique across all virtual hosts. + int matchingLen = -1; // longest length of wildcard pattern that matches host name + boolean exactMatchFound = false; // true if a virtual host with exactly matched domain found + VirtualHost targetVirtualHost = null; // target VirtualHost with longest matched domain + for (VirtualHost vHost : virtualHosts) { + for (String domain : vHost.domains()) { + boolean selected = false; + if (matchHostName(hostName, domain)) { // matching + if (!domain.contains("*")) { // exact matching + exactMatchFound = true; + targetVirtualHost = vHost; + break; + } else if (domain.length() > matchingLen) { // longer matching pattern + selected = true; + } else if (domain.length() == matchingLen && domain.startsWith("*")) { // suffix matching + selected = true; + } + } + if (selected) { + matchingLen = domain.length(); + targetVirtualHost = vHost; + } + } + if (exactMatchFound) { + break; + } + } + return targetVirtualHost; + } + + /** + * Returns {@code true} iff {@code hostName} matches the domain name {@code pattern} with + * case-insensitive. + * + *

Wildcard pattern rules: + *

    + *
  1. A single asterisk (*) matches any domain.
  2. + *
  3. Asterisk (*) is only permitted in the left-most or the right-most part of the pattern, + * but not both.
  4. + *
+ */ + private static boolean matchHostName(String hostName, String pattern) { + checkArgument(hostName.length() != 0 && !hostName.startsWith(".") && !hostName.endsWith("."), + "Invalid host name"); + checkArgument(pattern.length() != 0 && !pattern.startsWith(".") && !pattern.endsWith("."), + "Invalid pattern/domain name"); + + hostName = hostName.toLowerCase(Locale.US); + pattern = pattern.toLowerCase(Locale.US); + // hostName and pattern are now in lower case -- domain names are case-insensitive. + + if (!pattern.contains("*")) { + // Not a wildcard pattern -- hostName and pattern must match exactly. + return hostName.equals(pattern); + } + // Wildcard pattern + + if (pattern.length() == 1) { + return true; + } + + int index = pattern.indexOf('*'); + + // At most one asterisk (*) is allowed. + if (pattern.indexOf('*', index + 1) != -1) { + return false; + } + + // Asterisk can only match prefix or suffix. + if (index != 0 && index != pattern.length() - 1) { + return false; + } + + // HostName must be at least as long as the pattern because asterisk has to + // match one or more characters. + if (hostName.length() < pattern.length()) { + return false; + } + + if (index == 0 && hostName.endsWith(pattern.substring(1))) { + // Prefix matching fails. + return true; + } + + // Pattern matches hostname if suffix matching succeeds. + return index == pattern.length() - 1 + && hostName.startsWith(pattern.substring(0, pattern.length() - 1)); + } + + /** + * Returns {@code true} iff the given {@link RouteMatch} matches the RPC's full method name and + * headers. + */ + static boolean matchRoute(RouteMatch routeMatch, String fullMethodName, + Metadata headers, ThreadSafeRandom random) { + if (!matchPath(routeMatch.pathMatcher(), fullMethodName)) { + return false; + } + for (HeaderMatcher headerMatcher : routeMatch.headerMatchers()) { + if (!matchHeader(headerMatcher, getHeaderValue(headers, headerMatcher.name()))) { + return false; + } + } + FractionMatcher fraction = routeMatch.fractionMatcher(); + return fraction == null || random.nextInt(fraction.denominator()) < fraction.numerator(); + } + + private static boolean matchPath(PathMatcher pathMatcher, String fullMethodName) { + if (pathMatcher.path() != null) { + return pathMatcher.caseSensitive() + ? pathMatcher.path().equals(fullMethodName) + : pathMatcher.path().equalsIgnoreCase(fullMethodName); + } else if (pathMatcher.prefix() != null) { + return pathMatcher.caseSensitive() + ? fullMethodName.startsWith(pathMatcher.prefix()) + : fullMethodName.toLowerCase().startsWith(pathMatcher.prefix().toLowerCase()); + } + return pathMatcher.regEx().matches(fullMethodName); + } + + private static boolean matchHeader(HeaderMatcher headerMatcher, @Nullable String value) { + if (headerMatcher.present() != null) { + return (value == null) == headerMatcher.present().equals(headerMatcher.inverted()); + } + if (value == null) { + return false; + } + boolean baseMatch; + if (headerMatcher.exactValue() != null) { + baseMatch = headerMatcher.exactValue().equals(value); + } else if (headerMatcher.safeRegEx() != null) { + baseMatch = headerMatcher.safeRegEx().matches(value); + } else if (headerMatcher.range() != null) { + long numValue; + try { + numValue = Long.parseLong(value); + baseMatch = numValue >= headerMatcher.range().start() + && numValue <= headerMatcher.range().end(); + } catch (NumberFormatException ignored) { + baseMatch = false; + } + } else if (headerMatcher.prefix() != null) { + baseMatch = value.startsWith(headerMatcher.prefix()); + } else { + baseMatch = value.endsWith(headerMatcher.suffix()); + } + return baseMatch != headerMatcher.inverted(); + } + + @Nullable + private static String getHeaderValue(Metadata headers, String headerName) { + if (headerName.endsWith(Metadata.BINARY_HEADER_SUFFIX)) { + return null; + } + if (headerName.equals("content-type")) { + return "application/grpc"; + } + Metadata.Key key; + try { + key = Metadata.Key.of(headerName, Metadata.ASCII_STRING_MARSHALLER); + } catch (IllegalArgumentException e) { + return null; + } + Iterable values = headers.getAll(key); + return values == null ? null : Joiner.on(",").join(values); + } +} diff --git a/xds/src/main/java/io/grpc/xds/XdsServerBuilder.java b/xds/src/main/java/io/grpc/xds/XdsServerBuilder.java index d0e12caec11..34879fd8cd0 100644 --- a/xds/src/main/java/io/grpc/xds/XdsServerBuilder.java +++ b/xds/src/main/java/io/grpc/xds/XdsServerBuilder.java @@ -49,6 +49,7 @@ public final class XdsServerBuilder extends ForwardingServerBuilder ATTR_SERVER_ROUTING_CONFIG = + Attributes.Key.create("io.grpc.xds.ServerWrapper.serverRoutingConfig"); + @VisibleForTesting static final long RETRY_DELAY_NANOS = TimeUnit.MINUTES.toNanos(1); private final String listenerAddress; private final ServerBuilder delegateBuilder; private boolean sharedTimeService; private final ScheduledExecutorService timeService; + private final FilterRegistry filterRegistry; + private final ThreadSafeRandom random = ThreadSafeRandomImpl.instance; private final XdsClientPoolFactory xdsClientPoolFactory; private final XdsServingStatusListener listener; private final AtomicReference filterChainSelectorRef; @@ -92,9 +118,10 @@ public void uncaughtException(Thread t, Throwable e) { ServerBuilder delegateBuilder, XdsServingStatusListener listener, AtomicReference filterChainSelectorRef, - XdsClientPoolFactory xdsClientPoolFactory) { + XdsClientPoolFactory xdsClientPoolFactory, + FilterRegistry filterRegistry) { this(listenerAddress, delegateBuilder, listener, filterChainSelectorRef, xdsClientPoolFactory, - SharedResourceHolder.get(GrpcUtil.TIMER_SERVICE)); + filterRegistry, SharedResourceHolder.get(GrpcUtil.TIMER_SERVICE)); sharedTimeService = true; } @@ -105,13 +132,16 @@ public void uncaughtException(Thread t, Throwable e) { XdsServingStatusListener listener, AtomicReference filterChainSelectorRef, XdsClientPoolFactory xdsClientPoolFactory, + FilterRegistry filterRegistry, ScheduledExecutorService timeService) { this.listenerAddress = checkNotNull(listenerAddress, "listenerAddress"); this.delegateBuilder = checkNotNull(delegateBuilder, "delegateBuilder"); + this.delegateBuilder.intercept(new ConfigApplyingInterceptor()); this.listener = checkNotNull(listener, "listener"); this.filterChainSelectorRef = checkNotNull(filterChainSelectorRef, "filterChainSelectorRef"); this.xdsClientPoolFactory = checkNotNull(xdsClientPoolFactory, "xdsClientPoolFactory"); this.timeService = checkNotNull(timeService, "timeService"); + this.filterRegistry = checkNotNull(filterRegistry,"filterRegistry"); this.delegate = delegateBuilder.build(); } @@ -148,8 +178,6 @@ private void internalStart() { return; } xdsClient = xdsClientPool.getObject(); - // TODO(chengyuanzhang): add an API on XdsClient indicating if it is using v3, don't get - // from bootstrap. boolean useProtocolV3 = xdsClient.getBootstrapInfo().getServers().get(0).isUseProtocolV3(); String listenerTemplate = xdsClient.getBootstrapInfo().getServerListenerResourceNameTemplate(); if (!useProtocolV3 || listenerTemplate == null) { @@ -307,6 +335,10 @@ public void run() { private final class DiscoveryState implements LdsResourceWatcher { private final String resourceName; + // RDS resource name is the key. + private final Map routeDiscoveryStates = new HashMap<>(); + // Track pending RDS resources using rds name. + private final Set pendingRds = new HashSet<>(); // Most recently discovered filter chains. private List filterChains = new ArrayList<>(); // Most recently discovered default filter chain. @@ -328,9 +360,44 @@ public void run() { return; } checkNotNull(update.listener(), "update"); + if (!pendingRds.isEmpty()) { + // filter chain state has not yet been applied to filterChainSelectorRef and there are + // two sets of sslContextProviderSuppliers, so we release the old ones. + releaseSuppliersInFlight(); + pendingRds.clear(); + } filterChains = update.listener().getFilterChains(); defaultFilterChain = update.listener().getDefaultFilterChain(); - updateSelector(); + List allFilterChains = filterChains; + if (defaultFilterChain != null) { + allFilterChains = new ArrayList<>(filterChains); + allFilterChains.add(defaultFilterChain); + } + Set allRds = new HashSet<>(); + for (FilterChain filterChain : allFilterChains) { + HttpConnectionManager hcm = filterChain.getHttpConnectionManager(); + if (hcm.virtualHosts() == null) { + RouteDiscoveryState rdsState = routeDiscoveryStates.get(hcm.rdsName()); + if (rdsState == null) { + rdsState = new RouteDiscoveryState(hcm.rdsName()); + routeDiscoveryStates.put(hcm.rdsName(), rdsState); + xdsClient.watchRdsResource(hcm.rdsName(), rdsState); + } + if (rdsState.isPending) { + pendingRds.add(hcm.rdsName()); + } + allRds.add(hcm.rdsName()); + } + } + for (Map.Entry entry: routeDiscoveryStates.entrySet()) { + if (!allRds.contains(entry.getKey())) { + xdsClient.cancelRdsResourceWatch(entry.getKey(), entry.getValue()); + } + } + routeDiscoveryStates.keySet().retainAll(allRds); + if (pendingRds.isEmpty()) { + updateSelector(true); + } } }); } @@ -372,47 +439,61 @@ public void run() { private void shutdown() { stopped = true; + cleanUpRouteDiscoveryStates(); logger.log(Level.FINE, "Stop watching LDS resource {0}", resourceName); xdsClient.cancelLdsResourceWatch(resourceName, this); - List toRelease = collectSslContextProviderSuppliers(); + List toRelease = getSuppliersInUse(); filterChainSelectorRef.set(FilterChainSelector.NO_FILTER_CHAIN); for (SslContextProviderSupplier s: toRelease) { s.close(); } + releaseSuppliersInFlight(); } - private List collectSslContextProviderSuppliers() { - List toRelease = new ArrayList<>(); - FilterChainSelector selector = filterChainSelectorRef.get(); - if (selector != null) { - for (FilterChain f: selector.getFilterChains()) { - if (f.getSslContextProviderSupplier() != null) { - toRelease.add(f.getSslContextProviderSupplier()); - } - } - SslContextProviderSupplier defaultSupplier = - selector.getDefaultSslContextProviderSupplier(); - if (defaultSupplier != null) { - toRelease.add(defaultSupplier); - } + /** + * Use firstTimeNoPendingRds to indicate that the previous SslContextProviderSuppliers in + * filterChainSelectorRef should be released. Call updateSelector(true) when all routing are + * just complete and the newest filter chain is ready to be applied to the + * filterChainSelectorRef. Call updateSelector(false) for subsequent routing update + * corresponding to the same filter chain list. + */ + private void updateSelector(boolean firstTimeNoPendingRds) { + Map filterChainRouting = new HashMap<>(); + for (FilterChain filterChain: filterChains) { + filterChainRouting.put(filterChain, generateRoutingConfig(filterChain)); } - return toRelease; - } - - private void updateSelector() { - List toRelease = collectSslContextProviderSuppliers(); FilterChainSelector selector = new FilterChainSelector( - Collections.unmodifiableList(filterChains), - defaultFilterChain == null ? null : defaultFilterChain.getSslContextProviderSupplier()); + Collections.unmodifiableMap(filterChainRouting), + defaultFilterChain == null ? null : defaultFilterChain.getSslContextProviderSupplier(), + defaultFilterChain == null ? null : generateRoutingConfig(defaultFilterChain)); + List toRelease = Collections.emptyList(); + if (firstTimeNoPendingRds) { + toRelease = getSuppliersInUse(); + } filterChainSelectorRef.set(selector); - for (SslContextProviderSupplier s: toRelease) { - s.close(); + for (SslContextProviderSupplier e: toRelease) { + e.close(); } startDelegateServer(); } + private ServerRoutingConfig generateRoutingConfig(FilterChain filterChain) { + HttpConnectionManager hcm = filterChain.getHttpConnectionManager(); + if (hcm.virtualHosts() != null) { + return ServerRoutingConfig.create(hcm.httpFilterConfigs(), hcm.virtualHosts()); + } else { + RouteDiscoveryState rds = routeDiscoveryStates.get(hcm.rdsName()); + if (rds != null && rds.savedVirtualHosts != null) { + return ServerRoutingConfig.create(hcm.httpFilterConfigs(), rds.savedVirtualHosts); + } else { + return ServerRoutingConfig.FAILING_ROUTING_CONFIG; + } + } + } + private void handleConfigNotFound(StatusException exception) { - List toRelease = collectSslContextProviderSuppliers(); + cleanUpRouteDiscoveryStates(); + List toRelease = getSuppliersInUse(); filterChainSelectorRef.set(FilterChainSelector.NO_FILTER_CHAIN); for (SslContextProviderSupplier s: toRelease) { s.close(); @@ -430,15 +511,240 @@ private void handleConfigNotFound(StatusException exception) { isServing = false; listener.onNotServing(exception); } + + private void cleanUpRouteDiscoveryStates() { + for (RouteDiscoveryState rdsState : routeDiscoveryStates.values()) { + String rdsName = rdsState.resourceName; + logger.log(Level.FINE, "Stop watching RDS resource {0}", rdsName); + xdsClient.cancelRdsResourceWatch(rdsName, rdsState); + } + routeDiscoveryStates.clear(); + } + + private List getSuppliersInUse() { + List toRelease = new ArrayList<>(); + FilterChainSelector selector = filterChainSelectorRef.get(); + if (selector != null) { + for (FilterChain f: selector.getRoutingConfigs().keySet()) { + if (f.getSslContextProviderSupplier() != null) { + toRelease.add(f.getSslContextProviderSupplier()); + } + } + SslContextProviderSupplier defaultSupplier = + selector.getDefaultSslContextProviderSupplier(); + if (defaultSupplier != null) { + toRelease.add(defaultSupplier); + } + } + return toRelease; + } + + private void releaseSuppliersInFlight() { + SslContextProviderSupplier supplier; + for (FilterChain filterChain : filterChains) { + supplier = filterChain.getSslContextProviderSupplier(); + if (supplier != null) { + supplier.close(); + } + } + if (defaultFilterChain != null + && (supplier = defaultFilterChain.getSslContextProviderSupplier()) != null) { + supplier.close(); + } + } + + private final class RouteDiscoveryState implements RdsResourceWatcher { + private final String resourceName; + @Nullable + private List savedVirtualHosts; + private boolean isPending = true; + + private RouteDiscoveryState(String resourceName) { + this.resourceName = checkNotNull(resourceName, "resourceName"); + } + + @Override + public void onChanged(final RdsUpdate update) { + syncContext.execute(new Runnable() { + @Override + public void run() { + if (!routeDiscoveryStates.containsKey(resourceName)) { + return; + } + savedVirtualHosts = update.virtualHosts; + maybeUpdateSelector(); + } + }); + } + + @Override + public void onResourceDoesNotExist(final String resourceName) { + syncContext.execute(new Runnable() { + @Override + public void run() { + if (!routeDiscoveryStates.containsKey(resourceName)) { + return; + } + logger.log(Level.WARNING, "Rds {0} unavailable", resourceName); + savedVirtualHosts = null; + maybeUpdateSelector(); + } + }); + } + + @Override + public void onError(final Status error) { + syncContext.execute(new Runnable() { + @Override + public void run() { + if (!routeDiscoveryStates.containsKey(resourceName)) { + return; + } + logger.log(Level.WARNING, "Error loading RDS resource {0} from XdsClient: {1}.", + new Object[]{resourceName, error}); + maybeUpdateSelector(); + } + }); + } + + // Update the selector to use the most recently updated configs only after all rds have been + // discovered, i.e. pendingRds is empty. Do the updateSelector even after rds are already + // fully discovered and new change comes. + private void maybeUpdateSelector() { + isPending = false; + boolean isLastPending = pendingRds.remove(resourceName); + if (pendingRds.isEmpty()) { + updateSelector(isLastPending); + } + } + } + + private boolean isPermanentError(Status error) { + return EnumSet.of( + Status.Code.INTERNAL, + Status.Code.INVALID_ARGUMENT, + Status.Code.FAILED_PRECONDITION, + Status.Code.PERMISSION_DENIED, + Status.Code.UNAUTHENTICATED) + .contains(error.getCode()); + } } - private static boolean isPermanentError(Status error) { - return EnumSet.of( - Status.Code.INTERNAL, - Status.Code.INVALID_ARGUMENT, - Status.Code.FAILED_PRECONDITION, - Status.Code.PERMISSION_DENIED, - Status.Code.UNAUTHENTICATED) - .contains(error.getCode()); + @VisibleForTesting + final class ConfigApplyingInterceptor implements ServerInterceptor { + private final ServerInterceptor noopInterceptor = new ServerInterceptor() { + @Override + public Listener interceptCall(ServerCall call, + Metadata headers, ServerCallHandler next) { + return next.startCall(call, headers); + } + }; + + @Override + public Listener interceptCall(ServerCall call, + Metadata headers, ServerCallHandler next) { + ServerRoutingConfig routingConfig = call.getAttributes().get(ATTR_SERVER_ROUTING_CONFIG); + if (routingConfig == null + || routingConfig.equals(ServerRoutingConfig.FAILING_ROUTING_CONFIG)) { + String errorMsg = "Missing xDS routing config. " + (routingConfig == null ? "" : + "RDS config unavailable."); + call.close(Status.UNAVAILABLE.withDescription(errorMsg), new Metadata()); + return new Listener() {}; + } + VirtualHost virtualHost = RoutingUtils.findVirtualHostForHostName( + routingConfig.virtualHosts(), call.getAuthority()); + if (virtualHost == null) { + call.close( + Status.UNAVAILABLE.withDescription("Could not find xDS virtual host matching RPC"), + new Metadata()); + return new Listener() {}; + } + Route selectedRoute = null; + Map selectedOverrideConfigs = + new HashMap<>(virtualHost.filterConfigOverrides()); + MethodDescriptor method = call.getMethodDescriptor(); + for (Route route : virtualHost.routes()) { + if (RoutingUtils.matchRoute( + route.routeMatch(), "/" + method.getFullMethodName(), headers, random)) { + selectedRoute = route; + selectedOverrideConfigs.putAll(route.filterConfigOverrides()); + break; + } + } + if (selectedRoute == null) { + call.close( + Status.UNAVAILABLE.withDescription("Could not find xDS route matching RPC"), + new Metadata()); + return new ServerCall.Listener() {}; + } + List filterInterceptors = new ArrayList<>(); + for (NamedFilterConfig namedFilterConfig : routingConfig.httpFilterConfigs()) { + FilterConfig filterConfig = namedFilterConfig.filterConfig; + Filter filter = filterRegistry.get(filterConfig.typeUrl()); + if (filter instanceof ServerInterceptorBuilder) { + ServerInterceptor interceptor = + ((ServerInterceptorBuilder) filter).buildServerInterceptor( + filterConfig, selectedOverrideConfigs.get(namedFilterConfig.name)); + if (interceptor != null) { + filterInterceptors.add(interceptor); + } + } else { + call.close( + Status.UNAVAILABLE.withDescription("HttpFilterConfig(type URL: " + + filterConfig.typeUrl() + ") is not supported on server-side."), + new Metadata()); + return new Listener() {}; + } + } + ServerInterceptor interceptor = combineInterceptors(filterInterceptors); + return interceptor.interceptCall(call, headers, next); + } + + private ServerInterceptor combineInterceptors(final List interceptors) { + if (interceptors.isEmpty()) { + return noopInterceptor; + } + if (interceptors.size() == 1) { + return interceptors.get(0); + } + return new ServerInterceptor() { + @Override + public Listener interceptCall(ServerCall call, + Metadata headers, ServerCallHandler next) { + // intercept forward + for (int i = interceptors.size() - 1; i >= 0; i--) { + next = InternalServerInterceptors.interceptCallHandlerCreate( + interceptors.get(i), next); + } + return next.startCall(call, headers); + } + }; + } + } + + /** + * The HttpConnectionManager level configuration. + */ + @AutoValue + abstract static class ServerRoutingConfig { + private static final ServerRoutingConfig FAILING_ROUTING_CONFIG = + new AutoValue_XdsServerWrapper_ServerRoutingConfig( + ImmutableList.of(), ImmutableList.of()); + + // Top level http filter configs. + abstract ImmutableList httpFilterConfigs(); + + abstract ImmutableList virtualHosts(); + + /** + * Server routing configuration. + * */ + public static ServerRoutingConfig create(List httpFilterConfigs, + List virtualHosts) { + checkNotNull(httpFilterConfigs, "httpFilterConfigs"); + checkNotNull(virtualHosts, "virtualHosts"); + return new AutoValue_XdsServerWrapper_ServerRoutingConfig( + ImmutableList.copyOf(httpFilterConfigs), ImmutableList.copyOf(virtualHosts)); + } } } diff --git a/xds/src/test/java/io/grpc/xds/FilterChainMatchingProtocolNegotiatorsTest.java b/xds/src/test/java/io/grpc/xds/FilterChainMatchingProtocolNegotiatorsTest.java index 21224f73885..a7a1a7c62e3 100644 --- a/xds/src/test/java/io/grpc/xds/FilterChainMatchingProtocolNegotiatorsTest.java +++ b/xds/src/test/java/io/grpc/xds/FilterChainMatchingProtocolNegotiatorsTest.java @@ -17,24 +17,27 @@ package io.grpc.xds; import static com.google.common.truth.Truth.assertThat; +import static io.grpc.xds.XdsServerWrapper.ATTR_SERVER_ROUTING_CONFIG; import static io.grpc.xds.internal.sds.SdsProtocolNegotiators.ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER; import static org.junit.Assert.fail; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import com.google.common.collect.ImmutableMap; import com.google.common.util.concurrent.SettableFuture; import io.grpc.internal.TestUtils.NoopChannelLogger; import io.grpc.netty.GrpcHttp2ConnectionHandler; import io.grpc.netty.InternalProtocolNegotiationEvent; import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator; import io.grpc.netty.ProtocolNegotiationEvent; -import io.grpc.xds.EnvoyServerProtoData.CidrRange; import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; import io.grpc.xds.EnvoyServerProtoData.FilterChain; -import io.grpc.xds.EnvoyServerProtoData.FilterChainMatch; +import io.grpc.xds.Filter.FilterConfig; import io.grpc.xds.Filter.NamedFilterConfig; import io.grpc.xds.FilterChainMatchingProtocolNegotiators.FilterChainMatchingHandler; import io.grpc.xds.FilterChainMatchingProtocolNegotiators.FilterChainMatchingHandler.FilterChainSelector; +import io.grpc.xds.VirtualHost.Route; +import io.grpc.xds.XdsServerWrapper.ServerRoutingConfig; import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil; import io.grpc.xds.internal.sds.SslContextProviderSupplier; import io.netty.channel.ChannelHandler; @@ -57,6 +60,8 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; +import java.util.Map; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; @@ -79,59 +84,24 @@ public class FilterChainMatchingProtocolNegotiatorsTest { private ChannelHandlerContext channelHandlerCtx; @Mock private ProtocolNegotiator mockDelegate; - private final SettableFuture sslSet = SettableFuture.create(); private static final HttpConnectionManager HTTP_CONNECTION_MANAGER = createRds("routing-config"); private static final String LOCAL_IP = "10.1.2.3"; // dest private static final String REMOTE_IP = "10.4.2.3"; // source private static final int PORT = 7000; - @Test - public void filterChainMatch() throws Exception { - ChannelHandler next = new ChannelInboundHandlerAdapter() { - @Override - public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { - ProtocolNegotiationEvent e = (ProtocolNegotiationEvent)evt; - sslSet.set(InternalProtocolNegotiationEvent.getAttributes(e) - .get(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER)); - } - }; - when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); - FilterChain f0 = createFilterChain("filter-chain-0", createRds("r0")); - SslContextProviderSupplier defaultSsl = new SslContextProviderSupplier(createTls(), - mock(TlsContextManager.class)); - FilterChainSelector selector = new FilterChainSelector(Collections.singletonList(f0), - defaultSsl); - FilterChainMatchingHandler filterChainMatchingHandler = - new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); - setupChannel("172.168.1.1", "172.168.1.2", 80, filterChainMatchingHandler); - ChannelHandlerContext channelHandlerCtx = pipeline.context(filterChainMatchingHandler); - assertThat(channelHandlerCtx).isNotNull(); - - pipeline.fireUserEventTriggered(event); - channelHandlerCtx = pipeline.context(filterChainMatchingHandler); - assertThat(channelHandlerCtx).isNull(); - channel.runPendingTasks(); - assertThat(sslSet.isDone()).isTrue(); - assertThat(sslSet.get()).isEqualTo(f0.getSslContextProviderSupplier()); - channelHandlerCtx = pipeline.context(next); - assertThat(channelHandlerCtx).isNotNull(); - } - @Test public void nofilterChainMatch_defaultSslContext() throws Exception { final SettableFuture sslSet = SettableFuture.create(); - ChannelHandler next = new ChannelInboundHandlerAdapter() { - @Override - public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { - ProtocolNegotiationEvent e = (ProtocolNegotiationEvent)evt; - sslSet.set(InternalProtocolNegotiationEvent.getAttributes(e) - .get(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER)); - } - }; + final SettableFuture routingSettable = SettableFuture.create(); + ChannelHandler next = captureAttrHandler(sslSet, routingSettable); when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); - SslContextProviderSupplier ssl = new SslContextProviderSupplier(createTls(), tlsContextManager); - FilterChainSelector selector = new FilterChainSelector(new ArrayList(), ssl); + SslContextProviderSupplier defaultSsl = new SslContextProviderSupplier(createTls(), + tlsContextManager); + ServerRoutingConfig noopConfig = ServerRoutingConfig.create( + new ArrayList(), new ArrayList()); + FilterChainSelector selector = new FilterChainSelector( + new HashMap(), defaultSsl, noopConfig); FilterChainMatchingHandler filterChainMatchingHandler = new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); setupChannel("172.168.1.1", "172.168.1.2", 80, filterChainMatchingHandler); @@ -144,14 +114,16 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { channel.runPendingTasks(); assertThat(sslSet.isDone()).isTrue(); - assertThat(sslSet.get()).isEqualTo(ssl); + assertThat(sslSet.get()).isEqualTo(defaultSsl); + assertThat(routingSettable.get()).isEqualTo(noopConfig); channelHandlerCtx = pipeline.context(next); assertThat(channelHandlerCtx).isNotNull(); } @Test public void noFilterChainMatch_noDefaultSslContext() { - FilterChainSelector selector = new FilterChainSelector(new ArrayList(), null); + FilterChainSelector selector = new FilterChainSelector( + new HashMap(), null, null); FilterChainMatchingHandler filterChainMatchingHandler = new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); setupChannel("172.168.1.1", "172.168.2.2", 90, filterChainMatchingHandler); @@ -182,24 +154,22 @@ public void singleFilterChainWithoutAlpn() throws Exception { "filter-chain-foo", filterChainMatch, HTTP_CONNECTION_MANAGER, tlsContext, tlsContextManager); - FilterChainSelector selector = new FilterChainSelector(Arrays.asList( - filterChain), null); + ServerRoutingConfig noopConfig = ServerRoutingConfig.create( + new ArrayList(), new ArrayList()); + FilterChainSelector selector = new FilterChainSelector(ImmutableMap.of(filterChain, noopConfig), + null, null); FilterChainMatchingHandler filterChainMatchingHandler = new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); final SettableFuture sslSet = SettableFuture.create(); - ChannelHandler next = new ChannelInboundHandlerAdapter() { - @Override - public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { - ProtocolNegotiationEvent e = (ProtocolNegotiationEvent)evt; - sslSet.set(InternalProtocolNegotiationEvent.getAttributes(e) - .get(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER)); - } - }; + final SettableFuture routingSettable = SettableFuture.create(); + ChannelHandler next = captureAttrHandler(sslSet, routingSettable); when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); setupChannel(LOCAL_IP, REMOTE_IP, 15000, filterChainMatchingHandler); pipeline.fireUserEventTriggered(event); channel.runPendingTasks(); assertThat(sslSet.isDone()).isTrue(); + assertThat(sslSet.get()).isEqualTo(filterChain.getSslContextProviderSupplier()); + assertThat(routingSettable.get()).isEqualTo(noopConfig); assertThat(sslSet.get().getTlsContext()).isSameInstanceAs(tlsContext); } @@ -225,57 +195,26 @@ public void singleFilterChainWithAlpn() throws Exception { EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( "filter-chain-bar", null, HTTP_CONNECTION_MANAGER, defaultTlsContext, tlsContextManager); - FilterChainSelector selector = new FilterChainSelector(Arrays.asList( - filterChain), defaultFilterChain.getSslContextProviderSupplier()); + ServerRoutingConfig noopConfig = ServerRoutingConfig.create( + new ArrayList(), new ArrayList()); + FilterChainSelector selector = new FilterChainSelector( + ImmutableMap.of(filterChain, randomConfig("no-match")), + defaultFilterChain.getSslContextProviderSupplier(), noopConfig); FilterChainMatchingHandler filterChainMatchingHandler = new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); final SettableFuture sslSet = SettableFuture.create(); - ChannelHandler next = new ChannelInboundHandlerAdapter() { - @Override - public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { - ProtocolNegotiationEvent e = (ProtocolNegotiationEvent)evt; - sslSet.set(InternalProtocolNegotiationEvent.getAttributes(e) - .get(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER)); - } - }; + final SettableFuture routingSettable = SettableFuture.create(); + ChannelHandler next = captureAttrHandler(sslSet, routingSettable); when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); setupChannel(LOCAL_IP, REMOTE_IP, 15000, filterChainMatchingHandler); pipeline.fireUserEventTriggered(event); channel.runPendingTasks(); - assertThat(sslSet.isDone()).isTrue(); + assertThat(sslSet.get()).isEqualTo(defaultFilterChain.getSslContextProviderSupplier()); + assertThat(routingSettable.get()).isEqualTo(noopConfig); assertThat(sslSet.get().getTlsContext()).isSameInstanceAs(defaultTlsContext); } - @Test - public void defaultFilterChain() throws Exception { - EnvoyServerProtoData.DownstreamTlsContext tlsContext = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); - EnvoyServerProtoData.FilterChain filterChain = new EnvoyServerProtoData.FilterChain( - "filter-chain-foo", null, HTTP_CONNECTION_MANAGER, tlsContext, tlsContextManager); - FilterChainSelector selector = new FilterChainSelector( - Arrays.asList(), - filterChain.getSslContextProviderSupplier()); - FilterChainMatchingHandler filterChainMatchingHandler = - new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); - - final SettableFuture sslSet = SettableFuture.create(); - ChannelHandler next = new ChannelInboundHandlerAdapter() { - @Override - public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { - ProtocolNegotiationEvent e = (ProtocolNegotiationEvent)evt; - sslSet.set(InternalProtocolNegotiationEvent.getAttributes(e) - .get(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER)); - } - }; - when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); - setupChannel(LOCAL_IP, REMOTE_IP, 15000, filterChainMatchingHandler); - pipeline.fireUserEventTriggered(event); - channel.runPendingTasks(); - assertThat(sslSet.isDone()).isTrue(); - assertThat(sslSet.get().getTlsContext()).isSameInstanceAs(tlsContext); - } - @Test public void destPortFails_returnDefaultFilterChain() throws Exception { EnvoyServerProtoData.DownstreamTlsContext tlsContextWithDestPort = @@ -301,27 +240,28 @@ public void destPortFails_returnDefaultFilterChain() throws Exception { "filter-chain-bar", null, HTTP_CONNECTION_MANAGER, tlsContextForDefaultFilterChain, tlsContextManager); - FilterChainSelector selector = new FilterChainSelector(Arrays.asList( - filterChainWithDestPort), - defaultFilterChain.getSslContextProviderSupplier()); + ServerRoutingConfig routingConfig = ServerRoutingConfig.create( + new ArrayList(), Arrays.asList(createVirtualHost("virtual"))); + ServerRoutingConfig defaultRoutingConfig = ServerRoutingConfig.create( + new ArrayList(), new ArrayList()); + FilterChainSelector selector = new FilterChainSelector( + ImmutableMap.of(filterChainWithDestPort, routingConfig), + defaultFilterChain.getSslContextProviderSupplier(), defaultRoutingConfig); + FilterChainMatchingHandler filterChainMatchingHandler = new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); final SettableFuture sslSet = SettableFuture.create(); - ChannelHandler next = new ChannelInboundHandlerAdapter() { - @Override - public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { - ProtocolNegotiationEvent e = (ProtocolNegotiationEvent)evt; - sslSet.set(InternalProtocolNegotiationEvent.getAttributes(e) - .get(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER)); - } - }; + final SettableFuture routingSettable = SettableFuture.create(); + ChannelHandler next = captureAttrHandler(sslSet, routingSettable); when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); setupChannel(LOCAL_IP, REMOTE_IP, 15000, filterChainMatchingHandler); pipeline.fireUserEventTriggered(event); channel.runPendingTasks(); - assertThat(sslSet.isDone()).isTrue(); - assertThat(sslSet.get().getTlsContext()).isSameInstanceAs(tlsContextForDefaultFilterChain); + assertThat(sslSet.get()).isEqualTo(defaultFilterChain.getSslContextProviderSupplier()); + assertThat(routingSettable.get()).isEqualTo(defaultRoutingConfig); + assertThat(sslSet.get().getTlsContext()) + .isSameInstanceAs(tlsContextForDefaultFilterChain); } @Test @@ -347,26 +287,24 @@ public void destPrefixRangeMatch() throws Exception { "filter-chain-bar", null, HTTP_CONNECTION_MANAGER, tlsContextForDefaultFilterChain, tlsContextManager); - FilterChainSelector selector = new FilterChainSelector(Arrays.asList( - filterChainWithMatch), - defaultFilterChain.getSslContextProviderSupplier()); + ServerRoutingConfig noopConfig = ServerRoutingConfig.create( + new ArrayList(), new ArrayList()); + FilterChainSelector selector = new FilterChainSelector( + ImmutableMap.of(filterChainWithMatch, noopConfig), + defaultFilterChain.getSslContextProviderSupplier(), randomConfig("no-match")); + FilterChainMatchingHandler filterChainMatchingHandler = new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); final SettableFuture sslSet = SettableFuture.create(); - ChannelHandler next = new ChannelInboundHandlerAdapter() { - @Override - public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { - ProtocolNegotiationEvent e = (ProtocolNegotiationEvent)evt; - sslSet.set(InternalProtocolNegotiationEvent.getAttributes(e) - .get(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER)); - } - }; + final SettableFuture routingSettable = SettableFuture.create(); + ChannelHandler next = captureAttrHandler(sslSet, routingSettable); when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); setupChannel(LOCAL_IP, REMOTE_IP, 15000, filterChainMatchingHandler); pipeline.fireUserEventTriggered(event); channel.runPendingTasks(); - assertThat(sslSet.isDone()).isTrue(); + assertThat(sslSet.get()).isEqualTo(filterChainWithMatch.getSslContextProviderSupplier()); + assertThat(routingSettable.get()).isEqualTo(noopConfig); assertThat(sslSet.get().getTlsContext()).isSameInstanceAs(tlsContextMatch); } @@ -395,26 +333,25 @@ public void destPrefixRangeMismatch_returnDefaultFilterChain() EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( "filter-chain-bar", null, HTTP_CONNECTION_MANAGER, tlsContextForDefaultFilterChain, tlsContextManager); - FilterChainSelector selector = new FilterChainSelector(Arrays.asList( - filterChainWithMismatch), - defaultFilterChain.getSslContextProviderSupplier()); + ServerRoutingConfig noopConfig = ServerRoutingConfig.create( + new ArrayList(), new ArrayList()); + FilterChainSelector selector = new FilterChainSelector( + ImmutableMap.of(filterChainWithMismatch, randomConfig("no-match")), + defaultFilterChain.getSslContextProviderSupplier(), noopConfig); + FilterChainMatchingHandler filterChainMatchingHandler = new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); final SettableFuture sslSet = SettableFuture.create(); - ChannelHandler next = new ChannelInboundHandlerAdapter() { - @Override - public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { - ProtocolNegotiationEvent e = (ProtocolNegotiationEvent)evt; - sslSet.set(InternalProtocolNegotiationEvent.getAttributes(e) - .get(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER)); - } - }; + final SettableFuture routingSettable = SettableFuture.create(); + ChannelHandler next = captureAttrHandler(sslSet, routingSettable); when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); setupChannel(LOCAL_IP, REMOTE_IP, 15000, filterChainMatchingHandler); pipeline.fireUserEventTriggered(event); channel.runPendingTasks(); assertThat(sslSet.isDone()).isTrue(); + assertThat(sslSet.get()).isEqualTo(defaultFilterChain.getSslContextProviderSupplier()); + assertThat(routingSettable.get()).isEqualTo(noopConfig); assertThat(sslSet.get().getTlsContext()).isSameInstanceAs(tlsContextForDefaultFilterChain); } @@ -443,26 +380,23 @@ public void dest0LengthPrefixRange() "filter-chain-bar", null, HTTP_CONNECTION_MANAGER, tlsContextForDefaultFilterChain, tlsContextManager); - FilterChainSelector selector = new FilterChainSelector(Arrays.asList( - filterChain0Length), - defaultFilterChain.getSslContextProviderSupplier()); + ServerRoutingConfig noopConfig = ServerRoutingConfig.create( + new ArrayList(), new ArrayList()); + FilterChainSelector selector = new FilterChainSelector( + ImmutableMap.of(filterChain0Length, noopConfig), + defaultFilterChain.getSslContextProviderSupplier(), null); FilterChainMatchingHandler filterChainMatchingHandler = new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); final SettableFuture sslSet = SettableFuture.create(); - ChannelHandler next = new ChannelInboundHandlerAdapter() { - @Override - public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { - ProtocolNegotiationEvent e = (ProtocolNegotiationEvent)evt; - sslSet.set(InternalProtocolNegotiationEvent.getAttributes(e) - .get(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER)); - } - }; + final SettableFuture routingSettable = SettableFuture.create(); + ChannelHandler next = captureAttrHandler(sslSet, routingSettable); when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); setupChannel(LOCAL_IP, REMOTE_IP, 15000, filterChainMatchingHandler); pipeline.fireUserEventTriggered(event); channel.runPendingTasks(); - assertThat(sslSet.isDone()).isTrue(); + assertThat(sslSet.get()).isEqualTo(filterChain0Length.getSslContextProviderSupplier()); + assertThat(routingSettable.get()).isEqualTo(noopConfig); assertThat(sslSet.get().getTlsContext()).isSameInstanceAs(tlsContext0Length); } @@ -505,26 +439,25 @@ public void destPrefixRange_moreSpecificWins() tlsContextManager); EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( "filter-chain-baz", null, HTTP_CONNECTION_MANAGER, null, tlsContextManager); - FilterChainSelector selector = new FilterChainSelector(Arrays.asList( - filterChainLessSpecific, filterChainMoreSpecific), - defaultFilterChain.getSslContextProviderSupplier()); + ServerRoutingConfig noopConfig = ServerRoutingConfig.create( + new ArrayList(), new ArrayList()); + FilterChainSelector selector = new FilterChainSelector( + ImmutableMap.of(filterChainLessSpecific, randomConfig("no-match"), + filterChainMoreSpecific, noopConfig), + defaultFilterChain.getSslContextProviderSupplier(), randomConfig("default")); + FilterChainMatchingHandler filterChainMatchingHandler = new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); final SettableFuture sslSet = SettableFuture.create(); - ChannelHandler next = new ChannelInboundHandlerAdapter() { - @Override - public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { - ProtocolNegotiationEvent e = (ProtocolNegotiationEvent)evt; - sslSet.set(InternalProtocolNegotiationEvent.getAttributes(e) - .get(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER)); - } - }; + final SettableFuture routingSettable = SettableFuture.create(); + ChannelHandler next = captureAttrHandler(sslSet, routingSettable); when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); setupChannel(LOCAL_IP, REMOTE_IP, 15000, filterChainMatchingHandler); pipeline.fireUserEventTriggered(event); channel.runPendingTasks(); - assertThat(sslSet.isDone()).isTrue(); + assertThat(sslSet.get()).isEqualTo(filterChainMoreSpecific.getSslContextProviderSupplier()); + assertThat(routingSettable.get()).isEqualTo(noopConfig); assertThat(sslSet.get().getTlsContext()).isSameInstanceAs(tlsContextMoreSpecific); } @@ -567,26 +500,24 @@ public void destPrefixRange_emptyListLessSpecific() tlsContextManager); EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( "filter-chain-baz", null, HTTP_CONNECTION_MANAGER, null, tlsContextManager); - FilterChainSelector selector = new FilterChainSelector(Arrays.asList( - filterChainLessSpecific, filterChainMoreSpecific), - defaultFilterChain.getSslContextProviderSupplier()); + ServerRoutingConfig noopConfig = ServerRoutingConfig.create( + new ArrayList(), new ArrayList()); + FilterChainSelector selector = new FilterChainSelector( + ImmutableMap.of(filterChainLessSpecific, randomConfig("no-match"), + filterChainMoreSpecific, noopConfig), + defaultFilterChain.getSslContextProviderSupplier(), randomConfig("default")); FilterChainMatchingHandler filterChainMatchingHandler = new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); final SettableFuture sslSet = SettableFuture.create(); - ChannelHandler next = new ChannelInboundHandlerAdapter() { - @Override - public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { - ProtocolNegotiationEvent e = (ProtocolNegotiationEvent)evt; - sslSet.set(InternalProtocolNegotiationEvent.getAttributes(e) - .get(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER)); - } - }; + final SettableFuture routingSettable = SettableFuture.create(); + ChannelHandler next = captureAttrHandler(sslSet, routingSettable); when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); setupChannel(LOCAL_IP, REMOTE_IP, 15000, filterChainMatchingHandler); pipeline.fireUserEventTriggered(event); channel.runPendingTasks(); - assertThat(sslSet.isDone()).isTrue(); + assertThat(sslSet.get()).isEqualTo(filterChainMoreSpecific.getSslContextProviderSupplier()); + assertThat(routingSettable.get()).isEqualTo(noopConfig); assertThat(sslSet.get().getTlsContext()).isSameInstanceAs(tlsContextMoreSpecific); } @@ -628,29 +559,27 @@ public void destPrefixRangeIpv6_moreSpecificWins() tlsContextMoreSpecific, tlsContextManager); EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( "filter-chain-baz", null, HTTP_CONNECTION_MANAGER, null, tlsContextManager); + ServerRoutingConfig noopConfig = ServerRoutingConfig.create( + new ArrayList(), new ArrayList()); + FilterChainSelector selector = new FilterChainSelector( + ImmutableMap.of(filterChainLessSpecific, randomConfig("no-match"), + filterChainMoreSpecific, noopConfig), + defaultFilterChain.getSslContextProviderSupplier(), randomConfig("default")); - FilterChainSelector selector = new FilterChainSelector(Arrays.asList( - filterChainLessSpecific, filterChainMoreSpecific), - defaultFilterChain.getSslContextProviderSupplier()); FilterChainMatchingHandler filterChainMatchingHandler = new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); final SettableFuture sslSet = SettableFuture.create(); - ChannelHandler next = new ChannelInboundHandlerAdapter() { - @Override - public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { - ProtocolNegotiationEvent e = (ProtocolNegotiationEvent)evt; - sslSet.set(InternalProtocolNegotiationEvent.getAttributes(e) - .get(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER)); - } - }; + final SettableFuture routingSettable = SettableFuture.create(); + ChannelHandler next = captureAttrHandler(sslSet, routingSettable); when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); setupChannel("FE80:0000:0000:0000:0202:B3FF:FE1E:8329", "2001:DB8::8:800:200C:417A", 15000, filterChainMatchingHandler); pipeline.fireUserEventTriggered(event); channel.runPendingTasks(); - assertThat(sslSet.isDone()).isTrue(); + assertThat(sslSet.get()).isEqualTo(filterChainMoreSpecific.getSslContextProviderSupplier()); + assertThat(routingSettable.get()).isEqualTo(noopConfig); assertThat(sslSet.get().getTlsContext()).isSameInstanceAs(tlsContextMoreSpecific); } @@ -695,26 +624,25 @@ public void destPrefixRange_moreSpecificWith2Wins() EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( "filter-chain-baz", null, HTTP_CONNECTION_MANAGER, null, tlsContextManager); - FilterChainSelector selector = new FilterChainSelector(Arrays.asList( - filterChainMoreSpecificWith2, filterChainLessSpecific), - defaultFilterChain.getSslContextProviderSupplier()); + ServerRoutingConfig noopConfig = ServerRoutingConfig.create( + new ArrayList(), new ArrayList()); + FilterChainSelector selector = new FilterChainSelector( + ImmutableMap.of(filterChainMoreSpecificWith2, noopConfig, + filterChainLessSpecific, randomConfig("no-match")), + defaultFilterChain.getSslContextProviderSupplier(), randomConfig("default")); FilterChainMatchingHandler filterChainMatchingHandler = new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); final SettableFuture sslSet = SettableFuture.create(); - ChannelHandler next = new ChannelInboundHandlerAdapter() { - @Override - public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { - ProtocolNegotiationEvent e = (ProtocolNegotiationEvent)evt; - sslSet.set(InternalProtocolNegotiationEvent.getAttributes(e) - .get(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER)); - } - }; + final SettableFuture routingSettable = SettableFuture.create(); + ChannelHandler next = captureAttrHandler(sslSet, routingSettable); when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); setupChannel(LOCAL_IP, REMOTE_IP, 15000, filterChainMatchingHandler); pipeline.fireUserEventTriggered(event); channel.runPendingTasks(); - assertThat(sslSet.isDone()).isTrue(); + assertThat(sslSet.get()).isEqualTo( + filterChainMoreSpecificWith2.getSslContextProviderSupplier()); + assertThat(routingSettable.get()).isEqualTo(noopConfig); assertThat(sslSet.get().getTlsContext()).isSameInstanceAs(tlsContextMoreSpecificWith2); } @@ -741,39 +669,31 @@ public void sourceTypeMismatch_returnDefaultFilterChain() throws Exception { EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( "filter-chain-bar", null, HTTP_CONNECTION_MANAGER,tlsContextForDefaultFilterChain, tlsContextManager); - FilterChainSelector selector = new FilterChainSelector(Arrays.asList( - filterChainWithMismatch), defaultFilterChain.getSslContextProviderSupplier()); + ServerRoutingConfig noopConfig = ServerRoutingConfig.create( + new ArrayList(), new ArrayList()); + FilterChainSelector selector = new FilterChainSelector( + ImmutableMap.of(filterChainWithMismatch, randomConfig("no-match")), + defaultFilterChain.getSslContextProviderSupplier(), noopConfig); FilterChainMatchingHandler filterChainMatchingHandler = new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); final SettableFuture sslSet = SettableFuture.create(); - ChannelHandler next = new ChannelInboundHandlerAdapter() { - @Override - public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { - ProtocolNegotiationEvent e = (ProtocolNegotiationEvent)evt; - sslSet.set(InternalProtocolNegotiationEvent.getAttributes(e) - .get(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER)); - } - }; + final SettableFuture routingSettable = SettableFuture.create(); + ChannelHandler next = captureAttrHandler(sslSet, routingSettable); when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); setupChannel(LOCAL_IP, REMOTE_IP, 15000, filterChainMatchingHandler); pipeline.fireUserEventTriggered(event); channel.runPendingTasks(); - assertThat(sslSet.isDone()).isTrue(); + assertThat(sslSet.get()).isEqualTo(defaultFilterChain.getSslContextProviderSupplier()); + assertThat(routingSettable.get()).isEqualTo(noopConfig); assertThat(sslSet.get().getTlsContext()).isSameInstanceAs(tlsContextForDefaultFilterChain); } @Test public void sourceTypeLocal() throws Exception { final SettableFuture sslSet = SettableFuture.create(); - ChannelHandler next = new ChannelInboundHandlerAdapter() { - @Override - public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { - ProtocolNegotiationEvent e = (ProtocolNegotiationEvent)evt; - sslSet.set(InternalProtocolNegotiationEvent.getAttributes(e) - .get(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER)); - } - }; + final SettableFuture routingSettable = SettableFuture.create(); + ChannelHandler next = captureAttrHandler(sslSet, routingSettable); when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); EnvoyServerProtoData.DownstreamTlsContext tlsContextMatch = CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); @@ -796,14 +716,18 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { "filter-chain-bar", null, HTTP_CONNECTION_MANAGER, tlsContextForDefaultFilterChain, tlsContextManager); - FilterChainSelector selector = new FilterChainSelector(Arrays.asList( - filterChainWithMatch), defaultFilterChain.getSslContextProviderSupplier()); + ServerRoutingConfig noopConfig = ServerRoutingConfig.create( + new ArrayList(), new ArrayList()); + FilterChainSelector selector = new FilterChainSelector( + ImmutableMap.of(filterChainWithMatch, noopConfig), + defaultFilterChain.getSslContextProviderSupplier(), randomConfig("default")); FilterChainMatchingHandler filterChainMatchingHandler = new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); setupChannel(LOCAL_IP, LOCAL_IP, 15000, filterChainMatchingHandler); pipeline.fireUserEventTriggered(event); channel.runPendingTasks(); - assertThat(sslSet.isDone()).isTrue(); + assertThat(sslSet.get()).isEqualTo(filterChainWithMatch.getSslContextProviderSupplier()); + assertThat(routingSettable.get()).isEqualTo(noopConfig); assertThat(sslSet.get().getTlsContext()).isSameInstanceAs(tlsContextMatch); } @@ -811,14 +735,8 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { public void sourcePrefixRange_moreSpecificWith2Wins() throws Exception { final SettableFuture sslSet = SettableFuture.create(); - ChannelHandler next = new ChannelInboundHandlerAdapter() { - @Override - public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { - ProtocolNegotiationEvent e = (ProtocolNegotiationEvent)evt; - sslSet.set(InternalProtocolNegotiationEvent.getAttributes(e) - .get(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER)); - } - }; + final SettableFuture routingSettable = SettableFuture.create(); + ChannelHandler next = captureAttrHandler(sslSet, routingSettable); when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); EnvoyServerProtoData.DownstreamTlsContext tlsContextMoreSpecificWith2 = @@ -858,16 +776,22 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { tlsContextLessSpecific, tlsContextManager); EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( "filter-chain-baz", null, HTTP_CONNECTION_MANAGER, null, tlsContextManager); - FilterChainSelector selector = new FilterChainSelector(Arrays.asList( - filterChainMoreSpecificWith2, filterChainLessSpecific), - defaultFilterChain.getSslContextProviderSupplier()); + + ServerRoutingConfig noopConfig = ServerRoutingConfig.create( + new ArrayList(), new ArrayList()); + FilterChainSelector selector = new FilterChainSelector( + ImmutableMap.of(filterChainMoreSpecificWith2, noopConfig, + filterChainLessSpecific, randomConfig("no-match")), + defaultFilterChain.getSslContextProviderSupplier(), randomConfig("default")); FilterChainMatchingHandler filterChainMatchingHandler = new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); setupChannel(LOCAL_IP, REMOTE_IP, 15000, filterChainMatchingHandler); pipeline.fireUserEventTriggered(event); channel.runPendingTasks(); - assertThat(sslSet.isDone()).isTrue(); + assertThat(sslSet.get()).isEqualTo( + filterChainMoreSpecificWith2.getSslContextProviderSupplier()); + assertThat(routingSettable.get()).isEqualTo(noopConfig); assertThat(sslSet.get().getTlsContext()).isSameInstanceAs(tlsContextMoreSpecificWith2); } @@ -921,10 +845,11 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( "filter-chain-baz", null, HTTP_CONNECTION_MANAGER, null, null); - FilterChainSelector selector = new FilterChainSelector(Arrays.asList( - filterChain1, filterChain2), - defaultFilterChain.getSslContextProviderSupplier()); - + ServerRoutingConfig noopConfig = ServerRoutingConfig.create( + new ArrayList(), new ArrayList()); + FilterChainSelector selector = new FilterChainSelector( + ImmutableMap.of(filterChain1, noopConfig, filterChain2, noopConfig), + defaultFilterChain.getSslContextProviderSupplier(), noopConfig); FilterChainMatchingHandler filterChainMatchingHandler = new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); @@ -983,26 +908,24 @@ public void sourcePortMatch_exactMatchWinsOverEmptyList() throws Exception { EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( "filter-chain-baz", null, HTTP_CONNECTION_MANAGER, null, tlsContextManager); - FilterChainSelector selector = new FilterChainSelector(Arrays.asList( - filterChainEmptySourcePorts, filterChainSourcePortMatch), - defaultFilterChain.getSslContextProviderSupplier()); + ServerRoutingConfig noopConfig = ServerRoutingConfig.create( + new ArrayList(), new ArrayList()); + FilterChainSelector selector = new FilterChainSelector( + ImmutableMap.of(filterChainEmptySourcePorts, randomConfig("no-match"), + filterChainSourcePortMatch, noopConfig), + defaultFilterChain.getSslContextProviderSupplier(), randomConfig("default")); FilterChainMatchingHandler filterChainMatchingHandler = new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); final SettableFuture sslSet = SettableFuture.create(); - ChannelHandler next = new ChannelInboundHandlerAdapter() { - @Override - public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { - ProtocolNegotiationEvent e = (ProtocolNegotiationEvent)evt; - sslSet.set(InternalProtocolNegotiationEvent.getAttributes(e) - .get(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER)); - } - }; + final SettableFuture routingSettable = SettableFuture.create(); + ChannelHandler next = captureAttrHandler(sslSet, routingSettable); when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); setupChannel(LOCAL_IP, REMOTE_IP, 15000, filterChainMatchingHandler); pipeline.fireUserEventTriggered(event); channel.runPendingTasks(); - assertThat(sslSet.isDone()).isTrue(); + assertThat(sslSet.get()).isEqualTo(filterChainSourcePortMatch.getSslContextProviderSupplier()); + assertThat(routingSettable.get()).isEqualTo(noopConfig); assertThat(sslSet.get().getTlsContext()).isSameInstanceAs(tlsContextSourcePortMatch); } @@ -1135,26 +1058,31 @@ public void filterChain_5stepMatch() throws Exception { EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( "filter-chain-7", null, HTTP_CONNECTION_MANAGER, null, tlsContextManager); - FilterChainSelector selector = new FilterChainSelector(Arrays.asList( - filterChain1, filterChain2, filterChain3, filterChain4, filterChain5, filterChain6), - defaultFilterChain.getSslContextProviderSupplier()); + + ServerRoutingConfig noopConfig = ServerRoutingConfig.create( + new ArrayList(), new ArrayList()); + Map map = new HashMap<>(); + map.put(filterChain1, randomConfig("1")); + map.put(filterChain2, randomConfig("2")); + map.put(filterChain3, randomConfig("3")); + map.put(filterChain4, randomConfig("4")); + map.put(filterChain5, noopConfig); + map.put(filterChain6, randomConfig("6")); + FilterChainSelector selector = new FilterChainSelector( + map, defaultFilterChain.getSslContextProviderSupplier(), randomConfig("default")); + FilterChainMatchingHandler filterChainMatchingHandler = new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); final SettableFuture sslSet = SettableFuture.create(); - ChannelHandler next = new ChannelInboundHandlerAdapter() { - @Override - public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { - ProtocolNegotiationEvent e = (ProtocolNegotiationEvent)evt; - sslSet.set(InternalProtocolNegotiationEvent.getAttributes(e) - .get(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER)); - } - }; + final SettableFuture routingSettable = SettableFuture.create(); + ChannelHandler next = captureAttrHandler(sslSet, routingSettable); when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); setupChannel(LOCAL_IP, REMOTE_IP, 15000, filterChainMatchingHandler); pipeline.fireUserEventTriggered(event); channel.runPendingTasks(); - assertThat(sslSet.isDone()).isTrue(); + assertThat(sslSet.get()).isEqualTo(filterChain5.getSslContextProviderSupplier()); + assertThat(routingSettable.get()).isEqualTo(noopConfig); assertThat(sslSet.get().getTlsContext()).isSameInstanceAs(tlsContext5); } @@ -1213,25 +1141,23 @@ public void filterChainMatch_unsupportedMatchers() throws Exception { "filter-chain-baz", defaultFilterChainMatch, HTTP_CONNECTION_MANAGER, tlsContext3, mock(TlsContextManager.class)); - FilterChainSelector selector = new FilterChainSelector(Arrays.asList( - filterChain1, filterChain2), defaultFilterChain.getSslContextProviderSupplier()); + ServerRoutingConfig noopConfig = ServerRoutingConfig.create( + new ArrayList(), new ArrayList()); + FilterChainSelector selector = new FilterChainSelector( + ImmutableMap.of(filterChain1, randomConfig("1"), filterChain2, randomConfig("2")), + defaultFilterChain.getSslContextProviderSupplier(), noopConfig); FilterChainMatchingHandler filterChainMatchingHandler = new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); final SettableFuture sslSet = SettableFuture.create(); - ChannelHandler next = new ChannelInboundHandlerAdapter() { - @Override - public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { - ProtocolNegotiationEvent e = (ProtocolNegotiationEvent)evt; - sslSet.set(InternalProtocolNegotiationEvent.getAttributes(e) - .get(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER)); - } - }; + final SettableFuture routingSettable = SettableFuture.create(); + ChannelHandler next = captureAttrHandler(sslSet, routingSettable); when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); setupChannel(LOCAL_IP, REMOTE_IP, 15000, filterChainMatchingHandler); pipeline.fireUserEventTriggered(event); channel.runPendingTasks(); - assertThat(sslSet.isDone()).isTrue(); + assertThat(sslSet.get()).isEqualTo(defaultFilterChain.getSslContextProviderSupplier()); + assertThat(routingSettable.get()).isEqualTo(noopConfig); assertThat(sslSet.get().getTlsContext().getCommonTlsContext() .getTlsCertificateCertificateProviderInstance() .getCertificateName()).isEqualTo("CERT3"); @@ -1242,21 +1168,15 @@ private static HttpConnectionManager createRds(String name) { new ArrayList()); } - private FilterChain createFilterChain(String name, HttpConnectionManager hcm) { - return new FilterChain(name, createMatch(), - hcm, createTls(), tlsContextManager); + private static VirtualHost createVirtualHost(String name) { + return VirtualHost.create( + name, Collections.singletonList("auth"), new ArrayList(), + ImmutableMap.of()); } - private FilterChainMatch createMatch() { - return new FilterChainMatch( - 0, - Arrays.asList(), - Arrays.asList(), - Arrays.asList(), - EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList(), - Arrays.asList(), - null); + private static ServerRoutingConfig randomConfig(String domain) { + return ServerRoutingConfig.create( + new ArrayList(), Arrays.asList(createVirtualHost(domain))); } private EnvoyServerProtoData.DownstreamTlsContext createTls() { @@ -1283,6 +1203,21 @@ public SocketAddress remoteAddress() { pipeline.addLast(matchingHandler); } + private static ChannelHandler captureAttrHandler( + final SettableFuture sslSet, + final SettableFuture routingSettable) { + return new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + ProtocolNegotiationEvent e = (ProtocolNegotiationEvent)evt; + sslSet.set(InternalProtocolNegotiationEvent.getAttributes(e) + .get(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER)); + routingSettable.set(InternalProtocolNegotiationEvent.getAttributes(e) + .get(ATTR_SERVER_ROUTING_CONFIG)); + } + }; + } + private static final class FakeGrpcHttp2ConnectionHandler extends GrpcHttp2ConnectionHandler { FakeGrpcHttp2ConnectionHandler( ChannelPromise channelUnused, diff --git a/xds/src/test/java/io/grpc/xds/XdsClientWrapperForServerSdsTestMisc.java b/xds/src/test/java/io/grpc/xds/XdsClientWrapperForServerSdsTestMisc.java index 0f92687f443..532cb282b26 100644 --- a/xds/src/test/java/io/grpc/xds/XdsClientWrapperForServerSdsTestMisc.java +++ b/xds/src/test/java/io/grpc/xds/XdsClientWrapperForServerSdsTestMisc.java @@ -117,7 +117,7 @@ public void setUp() { when(mockBuilder.build()).thenReturn(mockServer); when(mockServer.isShutdown()).thenReturn(false); xdsServerWrapper = new XdsServerWrapper("0.0.0.0:" + PORT, mockBuilder, listener, - selectorRef, new FakeXdsClientPoolFactory(xdsClient)); + selectorRef, new FakeXdsClientPoolFactory(xdsClient), FilterRegistry.newRegistry()); } @Test diff --git a/xds/src/test/java/io/grpc/xds/XdsSdsClientServerTest.java b/xds/src/test/java/io/grpc/xds/XdsSdsClientServerTest.java index ecfa9e6b9bc..579542a2777 100644 --- a/xds/src/test/java/io/grpc/xds/XdsSdsClientServerTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsSdsClientServerTest.java @@ -113,7 +113,7 @@ public void plaintextClientServer() throws Exception { SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = getBlockingStub(/* upstreamTlsContext= */ null, - /* overrideAuthority= */ null); + /* overrideAuthority= */ OVERRIDE_AUTHORITY); assertThat(unaryRpc("buddy", blockingStub)).isEqualTo("Hello buddy"); } @@ -157,7 +157,7 @@ public void requireClientAuth_noClientCert_expectException() CLIENT_PEM_FILE, false); SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = - getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ "foo.test.google.fr"); + getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); try { unaryRpc(/* requestMessage= */ "buddy", blockingStub); fail("exception expected"); @@ -184,7 +184,7 @@ public void noClientAuth_sendBadClientCert_passes() throws Exception { BAD_CLIENT_PEM_FILE, true); SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = - getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ "foo.test.google.fr"); + getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); assertThat(unaryRpc("buddy", blockingStub)).isEqualTo("Hello buddy"); } @@ -245,7 +245,7 @@ public void plaintextServer_tlsClient_expectException() throws Exception { CLIENT_PEM_FILE, false); SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = - getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ "foo.test.google.fr"); + getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); try { unaryRpc("buddy", blockingStub); fail("exception expected"); @@ -274,7 +274,7 @@ public void mtlsClientServer_changeServerContext_expectException() xdsClient.deliverLdsUpdate(LdsUpdate.forTcpListener(listener)); try { SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = - getBlockingStub(upstreamTlsContext, "foo.test.google.fr"); + getBlockingStub(upstreamTlsContext, OVERRIDE_AUTHORITY); assertThat(unaryRpc("buddy", blockingStub)).isEqualTo("Hello buddy"); fail("exception expected"); } catch (StatusRuntimeException sre) { @@ -342,15 +342,6 @@ private void buildServerWithFallbackServerCredentials( buildServer(builder, downstreamTlsContext); } - static void generateListenerUpdateToWatcher( - DownstreamTlsContext tlsContext, XdsClient.LdsResourceWatcher registeredWatcher, - TlsContextManager tlsContextManager) { - EnvoyServerProtoData.Listener listener = buildListener("listener1", "0.0.0.0", tlsContext, - tlsContextManager); - LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(listener); - registeredWatcher.onChanged(listenerUpdate); - } - private void buildServer( XdsServerBuilder builder, DownstreamTlsContext downstreamTlsContext) diff --git a/xds/src/test/java/io/grpc/xds/XdsServerBuilderTest.java b/xds/src/test/java/io/grpc/xds/XdsServerBuilderTest.java index 4f4844f0b0b..476dc10a16a 100644 --- a/xds/src/test/java/io/grpc/xds/XdsServerBuilderTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsServerBuilderTest.java @@ -205,7 +205,7 @@ public void xdsServer_discoverState() throws Exception { xdsClient, CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"), tlsContextManager); - verifyServer(future, mockXdsServingStatusListener, null); + verifyServer(null, mockXdsServingStatusListener, null); } @Test diff --git a/xds/src/test/java/io/grpc/xds/XdsServerTestHelper.java b/xds/src/test/java/io/grpc/xds/XdsServerTestHelper.java index 078299945e5..f289c4726fb 100644 --- a/xds/src/test/java/io/grpc/xds/XdsServerTestHelper.java +++ b/xds/src/test/java/io/grpc/xds/XdsServerTestHelper.java @@ -179,7 +179,7 @@ void cancelLdsResourceWatch(String resourceName, LdsResourceWatcher watcher) { @Override void watchRdsResource(String resourceName, RdsResourceWatcher watcher) { - rdsWatchers.put(resourceName, watcher); + assertThat(rdsWatchers.put(resourceName, watcher)).isNull(); //re-register is not allowed. rdsCount.countDown(); } diff --git a/xds/src/test/java/io/grpc/xds/XdsServerWrapperTest.java b/xds/src/test/java/io/grpc/xds/XdsServerWrapperTest.java index 3463d2b14a9..f4ab7025e43 100644 --- a/xds/src/test/java/io/grpc/xds/XdsServerWrapperTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsServerWrapperTest.java @@ -18,38 +18,59 @@ package io.grpc.xds; import static com.google.common.truth.Truth.assertThat; +import static io.grpc.xds.XdsServerWrapper.ATTR_SERVER_ROUTING_CONFIG; import static io.grpc.xds.XdsServerWrapper.RETRY_DELAY_NANOS; import static org.junit.Assert.fail; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.mockito.Mockito.withSettings; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.util.concurrent.SettableFuture; +import io.grpc.Attributes; import io.grpc.InsecureChannelCredentials; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; import io.grpc.Server; import io.grpc.ServerBuilder; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; import io.grpc.Status; import io.grpc.StatusException; import io.grpc.internal.FakeClock; +import io.grpc.testing.TestMethodDescriptors; import io.grpc.xds.EnvoyServerProtoData.FilterChain; import io.grpc.xds.Filter.FilterConfig; import io.grpc.xds.Filter.NamedFilterConfig; +import io.grpc.xds.Filter.ServerInterceptorBuilder; import io.grpc.xds.FilterChainMatchingProtocolNegotiators.FilterChainMatchingHandler.FilterChainSelector; import io.grpc.xds.VirtualHost.Route; +import io.grpc.xds.VirtualHost.Route.RouteMatch; +import io.grpc.xds.VirtualHost.Route.RouteMatch.PathMatcher; +import io.grpc.xds.XdsClient.RdsResourceWatcher; +import io.grpc.xds.XdsClient.RdsUpdate; import io.grpc.xds.XdsServerBuilder.XdsServingStatusListener; import io.grpc.xds.XdsServerTestHelper.FakeXdsClient; import io.grpc.xds.XdsServerTestHelper.FakeXdsClientPoolFactory; +import io.grpc.xds.XdsServerWrapper.ConfigApplyingInterceptor; +import io.grpc.xds.XdsServerWrapper.ServerRoutingConfig; +import io.grpc.xds.internal.Matchers.HeaderMatcher; import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil; import io.grpc.xds.internal.sds.SslContextProviderSupplier; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.List; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; @@ -59,6 +80,7 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.junit.MockitoJUnit; import org.mockito.junit.MockitoRule; @@ -80,6 +102,7 @@ public class XdsServerWrapperTest { private AtomicReference selectorRef = new AtomicReference<>(); private FakeClock executor = new FakeClock(); private FakeXdsClient xdsClient = new FakeXdsClient(); + private FilterRegistry filterRegistry = FilterRegistry.getDefaultRegistry(); private XdsServerWrapper xdsServerWrapper; @Before @@ -87,7 +110,7 @@ public void setup() { when(mockBuilder.build()).thenReturn(mockServer); xdsServerWrapper = new XdsServerWrapper("0.0.0.0:1", mockBuilder, listener, selectorRef, new FakeXdsClientPoolFactory(xdsClient), - executor.getScheduledExecutorService()); + filterRegistry, executor.getScheduledExecutorService()); } @Test @@ -118,7 +141,7 @@ private void verifyBootstrapFail(Bootstrapper.BootstrapInfo b) throws Exception XdsClient xdsClient = mock(XdsClient.class); when(xdsClient.getBootstrapInfo()).thenReturn(b); xdsServerWrapper = new XdsServerWrapper("0.0.0.0:1", mockBuilder, listener, - selectorRef, new FakeXdsClientPoolFactory(xdsClient)); + selectorRef, new FakeXdsClientPoolFactory(xdsClient), filterRegistry); final SettableFuture start = SettableFuture.create(); Executors.newSingleThreadExecutor().execute(new Runnable() { @Override @@ -142,7 +165,6 @@ public void run() { } } - @Test public void shutdown() throws Exception { final SettableFuture start = SettableFuture.create(); @@ -161,9 +183,12 @@ public void run() { HttpConnectionManager hcm_virtual = HttpConnectionManager.forVirtualHosts( 0L, Collections.singletonList(createVirtualHost("virtual-host-0")), new ArrayList()); - EnvoyServerProtoData.FilterChain f0 = createFilterChain("filter-chain-0", hcm_virtual); - SslContextProviderSupplier sslSupplier = f0.getSslContextProviderSupplier(); - xdsClient.deliverLdsUpdate(Collections.singletonList(f0), null); + FilterChain f0 = createFilterChain("filter-chain-0", hcm_virtual); + FilterChain f1 = createFilterChain("filter-chain-1", createRds("rds")); + xdsClient.deliverLdsUpdate(Collections.singletonList(f0), f1); + xdsClient.rdsCount.await(5, TimeUnit.SECONDS); + xdsClient.deliverRdsUpdate("rds", + Collections.singletonList(createVirtualHost("virtual-host-1"))); start.get(5000, TimeUnit.MILLISECONDS); verify(mockServer).start(); xdsServerWrapper.shutdown(); @@ -171,7 +196,8 @@ public void run() { assertThat(xdsClient.ldsResource).isNull(); assertThat(xdsClient.shutdown).isTrue(); verify(mockServer).shutdown(); - assertThat(sslSupplier.isShutdown()).isTrue(); + assertThat(f0.getSslContextProviderSupplier().isShutdown()).isTrue(); + assertThat(f1.getSslContextProviderSupplier().isShutdown()).isTrue(); when(mockServer.isTerminated()).thenReturn(true); when(mockServer.awaitTermination(anyLong(), any(TimeUnit.class))).thenReturn(true); assertThat(xdsServerWrapper.awaitTermination(5, TimeUnit.SECONDS)).isTrue(); @@ -180,6 +206,43 @@ public void run() { assertThat(start.get()).isSameInstanceAs(xdsServerWrapper); } + @Test + public void shutdown_inflight() throws Exception { + final SettableFuture start = SettableFuture.create(); + Executors.newSingleThreadExecutor().execute(new Runnable() { + @Override + public void run() { + try { + start.set(xdsServerWrapper.start()); + } catch (Exception ex) { + start.setException(ex); + } + } + }); + String ldsWatched = xdsClient.ldsResource.get(5, TimeUnit.SECONDS); + assertThat(ldsWatched).isEqualTo("grpc/server?udpa.resource.listening_address=0.0.0.0:1"); + HttpConnectionManager hcm_virtual = HttpConnectionManager.forVirtualHosts( + 0L, Collections.singletonList(createVirtualHost("virtual-host-0")), + new ArrayList()); + FilterChain f0 = createFilterChain("filter-chain-0", createRds("rds")); + FilterChain f1 = createFilterChain("filter-chain-1", hcm_virtual); + xdsClient.deliverLdsUpdate(Collections.singletonList(f0), f1); + xdsServerWrapper.shutdown(); + when(mockServer.isTerminated()).thenReturn(true); + when(mockServer.awaitTermination(anyLong(), any(TimeUnit.class))).thenReturn(true); + assertThat(xdsServerWrapper.awaitTermination(5, TimeUnit.SECONDS)).isTrue(); + xdsServerWrapper.awaitTermination(); + assertThat(xdsServerWrapper.isTerminated()).isTrue(); + verify(mockServer, never()).start(); + assertThat(xdsServerWrapper.isShutdown()).isTrue(); + assertThat(xdsClient.ldsResource).isNull(); + assertThat(xdsClient.shutdown).isTrue(); + verify(mockServer).shutdown(); + assertThat(f0.getSslContextProviderSupplier().isShutdown()).isTrue(); + assertThat(f1.getSslContextProviderSupplier().isShutdown()).isTrue(); + assertThat(start.isDone()).isFalse(); //shall we set initialStatus when shutdown? + } + @Test public void shutdown_afterResourceNotExist() throws Exception { final SettableFuture start = SettableFuture.create(); @@ -235,7 +298,9 @@ public void run() { FilterChain filterChain = createFilterChain("filter-chain-1", createRds("rds")); SslContextProviderSupplier sslSupplier = filterChain.getSslContextProviderSupplier(); xdsClient.deliverLdsUpdate(Collections.singletonList(filterChain), null); - + xdsClient.rdsCount.await(5, TimeUnit.SECONDS); + xdsClient.deliverRdsUpdate("rds", + Collections.singletonList(createVirtualHost("virtual-host-1"))); try { start.get(5000, TimeUnit.MILLISECONDS); fail("Start should throw exception"); @@ -247,7 +312,6 @@ public void run() { verify(mockServer, never()).shutdown(); xdsServerWrapper.shutdown(); verify(mockServer).shutdown(); - when(mockServer.isShutdown()).thenReturn(true); when(mockServer.isTerminated()).thenReturn(true); assertThat(sslSupplier.isShutdown()).isTrue(); assertThat(executor.getPendingTasks().size()).isEqualTo(0); @@ -257,6 +321,35 @@ public void run() { assertThat(xdsServerWrapper.isTerminated()).isTrue(); } + @Test + public void initialStartIoException() throws Exception { + final SettableFuture start = SettableFuture.create(); + Executors.newSingleThreadExecutor().execute(new Runnable() { + @Override + public void run() { + try { + start.set(xdsServerWrapper.start()); + } catch (Exception ex) { + start.setException(ex); + } + } + }); + when(mockServer.start()).thenThrow(new IOException("error!")); + xdsClient.ldsResource.get(5, TimeUnit.SECONDS); + FilterChain filterChain = createFilterChain("filter-chain-1", createRds("rds")); + xdsClient.deliverLdsUpdate(Collections.singletonList(filterChain), null); + xdsClient.rdsCount.await(5, TimeUnit.SECONDS); + xdsClient.deliverRdsUpdate("rds", + Collections.singletonList(createVirtualHost("virtual-host-1"))); + try { + start.get(5000, TimeUnit.MILLISECONDS); + fail("Start should throw exception"); + } catch (ExecutionException ex) { + assertThat(ex.getCause()).isInstanceOf(IOException.class); + assertThat(ex.getCause().getMessage()).isEqualTo("error!"); + } + } + @Test public void discoverState_virtualhost() throws Exception { final SettableFuture start = SettableFuture.create(); @@ -285,13 +378,16 @@ public void run() { start.get(5000, TimeUnit.MILLISECONDS); FilterChainSelector selector = selectorRef.get(); assertThat(ldsWatched).isEqualTo("grpc/server?udpa.resource.listening_address=0.0.0.0:1"); - assertThat(selector.getFilterChains()).isEqualTo(Collections.singletonList(filterChain)); + assertThat(selector.getRoutingConfigs()).isEqualTo(ImmutableMap.of( + filterChain, ServerRoutingConfig.create(httpConnectionManager.httpFilterConfigs(), + httpConnectionManager.virtualHosts()) + )); verify(listener).onServing(); verify(mockServer).start(); } @Test - public void initialStartIoException() throws Exception { + public void discoverState_rds() throws Exception { final SettableFuture start = SettableFuture.create(); Executors.newSingleThreadExecutor().execute(new Runnable() { @Override @@ -303,19 +399,163 @@ public void run() { } } }); - when(mockServer.start()).thenThrow(new IOException("error!")); - xdsClient.ldsResource.get(5, TimeUnit.SECONDS); - FilterChain filterChain = createFilterChain("filter-chain-1", createRds("rds")); - xdsClient.deliverLdsUpdate(Collections.singletonList(filterChain), null); - try { - start.get(5000, TimeUnit.MILLISECONDS); - fail("Start should throw exception"); - } catch (ExecutionException ex) { - assertThat(ex.getCause()).isInstanceOf(IOException.class); - assertThat(ex.getCause().getMessage()).isEqualTo("error!"); - } + String ldsWatched = xdsClient.ldsResource.get(5, TimeUnit.SECONDS); + assertThat(ldsWatched).isEqualTo("grpc/server?udpa.resource.listening_address=0.0.0.0:1"); + VirtualHost virtualHost = createVirtualHost("virtual-host-0"); + HttpConnectionManager hcm_virtual = HttpConnectionManager.forVirtualHosts( + 0L, Collections.singletonList(virtualHost), new ArrayList()); + EnvoyServerProtoData.FilterChain f0 = createFilterChain("filter-chain-0", hcm_virtual); + EnvoyServerProtoData.FilterChain f1 = createFilterChain("filter-chain-1", createRds("r0")); + xdsClient.rdsCount = new CountDownLatch(3); + xdsClient.deliverLdsUpdate(Arrays.asList(f0, f1), null); + assertThat(start.isDone()).isFalse(); + assertThat(selectorRef.get()).isNull(); + verify(mockServer, never()).start(); + verify(listener, never()).onServing(); + + EnvoyServerProtoData.FilterChain f2 = createFilterChain("filter-chain-2", createRds("r1")); + EnvoyServerProtoData.FilterChain f3 = createFilterChain("filter-chain-3", createRds("r2")); + xdsClient.deliverLdsUpdate(Arrays.asList(f0, f2), f3); + verify(mockServer, never()).start(); + verify(listener, never()).onServing(); + xdsClient.rdsCount.await(5, TimeUnit.SECONDS); + + xdsClient.deliverRdsUpdate("r1", + Collections.singletonList(createVirtualHost("virtual-host-1"))); + verify(mockServer, never()).start(); + xdsClient.deliverRdsUpdate("r2", + Collections.singletonList(createVirtualHost("virtual-host-2"))); + start.get(5000, TimeUnit.MILLISECONDS); + verify(mockServer).start(); + assertThat(selectorRef.get().getRoutingConfigs()).isEqualTo(ImmutableMap.of( + f0, ServerRoutingConfig.create( + hcm_virtual.httpFilterConfigs(), hcm_virtual.virtualHosts()), + f2, ServerRoutingConfig.create(f2.getHttpConnectionManager().httpFilterConfigs(), + Collections.singletonList(createVirtualHost("virtual-host-1"))) + )); + assertThat(selectorRef.get().getDefaultRoutingConfig()).isEqualTo( + ServerRoutingConfig.create(f3.getHttpConnectionManager().httpFilterConfigs(), + Collections.singletonList(createVirtualHost("virtual-host-2")))); + assertThat(selectorRef.get().getDefaultSslContextProviderSupplier()).isEqualTo( + f3.getSslContextProviderSupplier()); + } + + @Test + public void discoverState_oneRdsToMultipleFilterChain() throws Exception { + final SettableFuture start = SettableFuture.create(); + Executors.newSingleThreadExecutor().execute(new Runnable() { + @Override + public void run() { + try { + start.set(xdsServerWrapper.start()); + } catch (Exception ex) { + start.setException(ex); + } + } + }); + String ldsWatched = xdsClient.ldsResource.get(5, TimeUnit.SECONDS); + assertThat(ldsWatched).isEqualTo("grpc/server?udpa.resource.listening_address=0.0.0.0:1"); + EnvoyServerProtoData.FilterChain f0 = createFilterChain("filter-chain-0", createRds("r0")); + EnvoyServerProtoData.FilterChain f1 = createFilterChain("filter-chain-1", createRds("r0")); + EnvoyServerProtoData.FilterChain f2 = createFilterChain("filter-chain-2", createRds("r0")); + + xdsClient.rdsCount = new CountDownLatch(1); + xdsClient.deliverLdsUpdate(Arrays.asList(f0, f1), f2); + assertThat(start.isDone()).isFalse(); + assertThat(selectorRef.get()).isNull(); + + xdsClient.rdsCount.await(5, TimeUnit.SECONDS); + xdsClient.deliverRdsUpdate("r0", + Collections.singletonList(createVirtualHost("virtual-host-0"))); + start.get(5000, TimeUnit.MILLISECONDS); + verify(mockServer, times(1)).start(); + assertThat(selectorRef.get().getRoutingConfigs()).isEqualTo(ImmutableMap.of( + f0, ServerRoutingConfig.create( + f0.getHttpConnectionManager().httpFilterConfigs(), + Collections.singletonList(createVirtualHost("virtual-host-0"))), + f1, ServerRoutingConfig.create(f1.getHttpConnectionManager().httpFilterConfigs(), + Collections.singletonList(createVirtualHost("virtual-host-0"))) + )); + assertThat(selectorRef.get().getDefaultRoutingConfig()).isEqualTo( + ServerRoutingConfig.create(f2.getHttpConnectionManager().httpFilterConfigs(), + Collections.singletonList(createVirtualHost("virtual-host-0")))); + assertThat(selectorRef.get().getDefaultSslContextProviderSupplier()).isSameInstanceAs( + f2.getSslContextProviderSupplier()); + + EnvoyServerProtoData.FilterChain f3 = createFilterChain("filter-chain-3", createRds("r0")); + EnvoyServerProtoData.FilterChain f4 = createFilterChain("filter-chain-4", createRds("r1")); + xdsClient.rdsCount = new CountDownLatch(1); + xdsClient.deliverLdsUpdate(Arrays.asList(f1, f3), f4); + xdsClient.rdsCount.await(5, TimeUnit.SECONDS); + xdsClient.deliverRdsUpdate("r1", + Collections.singletonList(createVirtualHost("virtual-host-1"))); + xdsClient.deliverRdsUpdate("r0", + Collections.singletonList(createVirtualHost("virtual-host-0"))); + assertThat(selectorRef.get().getRoutingConfigs()).isEqualTo(ImmutableMap.of( + f1, ServerRoutingConfig.create( + f1.getHttpConnectionManager().httpFilterConfigs(), + Collections.singletonList(createVirtualHost("virtual-host-0"))), + f3, ServerRoutingConfig.create(f3.getHttpConnectionManager().httpFilterConfigs(), + Collections.singletonList(createVirtualHost("virtual-host-0"))) + )); + assertThat(selectorRef.get().getDefaultRoutingConfig()).isEqualTo( + ServerRoutingConfig.create(f4.getHttpConnectionManager().httpFilterConfigs(), + Collections.singletonList(createVirtualHost("virtual-host-1")))); + assertThat(selectorRef.get().getDefaultSslContextProviderSupplier()).isSameInstanceAs( + f4.getSslContextProviderSupplier()); + verify(mockServer, times(1)).start(); + xdsServerWrapper.shutdown(); + verify(mockServer, times(1)).shutdown(); + when(mockServer.isTerminated()).thenReturn(true); + xdsServerWrapper.awaitTermination(); + assertThat(xdsServerWrapper.isTerminated()).isTrue(); + } + + @Test + public void discoverState_rds_onError_and_resourceNotExist() throws Exception { + final SettableFuture start = SettableFuture.create(); + Executors.newSingleThreadExecutor().execute(new Runnable() { + @Override + public void run() { + try { + start.set(xdsServerWrapper.start()); + } catch (Exception ex) { + start.setException(ex); + } + } + }); + String ldsWatched = xdsClient.ldsResource.get(5, TimeUnit.SECONDS); + assertThat(ldsWatched).isEqualTo("grpc/server?udpa.resource.listening_address=0.0.0.0:1"); + VirtualHost virtualHost = createVirtualHost("virtual-host-0"); + HttpConnectionManager hcm_virtual = HttpConnectionManager.forVirtualHosts( + 0L, Collections.singletonList(virtualHost), new ArrayList()); + EnvoyServerProtoData.FilterChain f0 = createFilterChain("filter-chain-0", hcm_virtual); + EnvoyServerProtoData.FilterChain f1 = createFilterChain("filter-chain-1", createRds("r0")); + xdsClient.deliverLdsUpdate(Arrays.asList(f0, f1), null); + xdsClient.rdsCount.await(); + xdsClient.rdsWatchers.get("r0").onError(Status.CANCELLED); + start.get(5000, TimeUnit.MILLISECONDS); + assertThat(selectorRef.get().getRoutingConfigs().get(f1)).isEqualTo(ServerRoutingConfig.create( + ImmutableList.of(), ImmutableList.of()) + ); + xdsClient.deliverRdsUpdate("r0", + Collections.singletonList(createVirtualHost("virtual-host-1"))); + assertThat(selectorRef.get().getRoutingConfigs().get(f1)).isEqualTo( + ServerRoutingConfig.create(f1.getHttpConnectionManager().httpFilterConfigs(), + Collections.singletonList(createVirtualHost("virtual-host-1")))); + + xdsClient.rdsWatchers.get("r0").onError(Status.CANCELLED); + assertThat(selectorRef.get().getRoutingConfigs().get(f1)).isEqualTo( + ServerRoutingConfig.create(f1.getHttpConnectionManager().httpFilterConfigs(), + Collections.singletonList(createVirtualHost("virtual-host-1")))); + + xdsClient.rdsWatchers.get("r0").onResourceDoesNotExist("r0"); + assertThat(selectorRef.get().getRoutingConfigs().get(f1)).isEqualTo(ServerRoutingConfig.create( + ImmutableList.of(), ImmutableList.of()) + ); } + @Test public void error() throws Exception { final SettableFuture start = SettableFuture.create(); @@ -339,106 +579,386 @@ public void run() { } verify(listener, times(1)).onNotServing(any(StatusException.class)); verify(mockBuilder, times(1)).build(); - verify(mockServer, times(1)).shutdown(); - when(mockServer.isShutdown()).thenReturn(true); + FilterChain filterChain0 = createFilterChain("filter-chain-0", createRds("rds")); + SslContextProviderSupplier sslSupplier0 = filterChain0.getSslContextProviderSupplier(); + xdsClient.deliverLdsUpdate(Collections.singletonList(filterChain0), null); xdsClient.ldsWatcher.onError(Status.INTERNAL); assertThat(selectorRef.get()).isSameInstanceAs(FilterChainSelector.NO_FILTER_CHAIN); + assertThat(xdsClient.rdsWatchers).isEmpty(); verify(mockBuilder, times(1)).build(); verify(listener, times(2)).onNotServing(any(StatusException.class)); - verify(mockServer, times(1)).shutdown(); + assertThat(sslSupplier0.isShutdown()).isFalse(); when(mockServer.start()).thenThrow(new IOException("error!")) .thenReturn(mockServer); - when(mockServer.isShutdown()).thenReturn(true).thenReturn(false); - FilterChain filterChain = createFilterChain("filter-chain-1", createRds("rds")); - SslContextProviderSupplier sslSupplier = filterChain.getSslContextProviderSupplier(); - xdsClient.deliverLdsUpdate(Collections.singletonList(filterChain), null); + FilterChain filterChain1 = createFilterChain("filter-chain-1", createRds("rds")); + SslContextProviderSupplier sslSupplier1 = filterChain1.getSslContextProviderSupplier(); + xdsClient.deliverLdsUpdate(Collections.singletonList(filterChain1), null); + assertThat(sslSupplier0.isShutdown()).isTrue(); + xdsClient.deliverRdsUpdate("rds", + Collections.singletonList(createVirtualHost("virtual-host-1"))); + RdsResourceWatcher saveRdsWatcher = xdsClient.rdsWatchers.get("rds"); assertThat(executor.forwardNanos(RETRY_DELAY_NANOS)).isEqualTo(1); - verify(mockBuilder, times(2)).build(); + verify(mockBuilder, times(1)).build(); verify(mockServer, times(2)).start(); verify(listener, times(1)).onServing(); - verify(listener, times(2)).onNotServing(any(StatusException.class)); - assertThat(selectorRef.get().getFilterChains()).isEqualTo( - Collections.singletonList(filterChain)); - assertThat(selectorRef.get().getDefaultSslContextProviderSupplier()).isNull(); - assertThat(sslSupplier.isShutdown()).isFalse(); - + assertThat(selectorRef.get().getRoutingConfigs()).isEqualTo(ImmutableMap.of( + filterChain1, ServerRoutingConfig.create( + filterChain1.getHttpConnectionManager().httpFilterConfigs(), + Collections.singletonList(createVirtualHost("virtual-host-1"))) + )); // xds update after start - filterChain = createFilterChain("filter-chain-2", createRds("rds")); - FilterChain f1 = createFilterChain("filter-chain-2-0", createRds("rds")); - SslContextProviderSupplier s1 = filterChain.getSslContextProviderSupplier(); - xdsClient.deliverLdsUpdate(Collections.singletonList(filterChain), f1); - - verify(mockBuilder, times(2)).build(); + xdsClient.deliverRdsUpdate("rds", + Collections.singletonList(createVirtualHost("virtual-host-2"))); + assertThat(sslSupplier1.isShutdown()).isFalse(); + xdsClient.ldsWatcher.onError(Status.DEADLINE_EXCEEDED); + verify(mockBuilder, times(1)).build(); verify(mockServer, times(2)).start(); - verify(listener, times(1)).onServing(); verify(listener, times(2)).onNotServing(any(StatusException.class)); - assertThat(selectorRef.get().getFilterChains()) - .isEqualTo(Collections.singletonList(filterChain)); - assertThat(selectorRef.get().getDefaultSslContextProviderSupplier()) - .isEqualTo(f1.getSslContextProviderSupplier()); - assertThat(sslSupplier.isShutdown()).isTrue(); - assertThat(s1.isShutdown()).isFalse(); + assertThat(selectorRef.get().getRoutingConfigs()).isEqualTo(ImmutableMap.of( + filterChain1, ServerRoutingConfig.create( + filterChain1.getHttpConnectionManager().httpFilterConfigs(), + Collections.singletonList(createVirtualHost("virtual-host-2"))) + )); + assertThat(sslSupplier1.isShutdown()).isFalse(); // not serving after serving - xdsClient.ldsWatcher.onError(Status.INTERNAL); - verify(mockServer, times(2)).shutdown(); + xdsClient.ldsWatcher.onResourceDoesNotExist(ldsResource); + assertThat(xdsClient.rdsWatchers).isEmpty(); + verify(mockServer, times(3)).shutdown(); when(mockServer.isShutdown()).thenReturn(true); assertThat(selectorRef.get()).isSameInstanceAs(FilterChainSelector.NO_FILTER_CHAIN); verify(listener, times(3)).onNotServing(any(StatusException.class)); - assertThat(s1.isShutdown()).isTrue(); + assertThat(sslSupplier1.isShutdown()).isTrue(); + // no op + saveRdsWatcher.onChanged( + new RdsUpdate(Collections.singletonList(createVirtualHost("virtual-host-1")))); + verify(mockBuilder, times(1)).build(); + verify(mockServer, times(2)).start(); + verify(listener, times(1)).onServing(); // cancel retry when(mockServer.start()).thenThrow(new IOException("error1!")) .thenThrow(new IOException("error2!")) .thenReturn(mockServer); - xdsClient.deliverLdsUpdate(Collections.singletonList(filterChain), null); - verify(mockBuilder, times(3)).build(); + FilterChain filterChain2 = createFilterChain("filter-chain-2", createRds("rds")); + SslContextProviderSupplier sslSupplier2 = filterChain2.getSslContextProviderSupplier(); + xdsClient.deliverLdsUpdate(Collections.singletonList(filterChain2), null); + xdsClient.deliverRdsUpdate("rds", + Collections.singletonList(createVirtualHost("virtual-host-1"))); + assertThat(sslSupplier1.isShutdown()).isTrue(); + verify(mockBuilder, times(2)).build(); when(mockServer.isShutdown()).thenReturn(false); verify(mockServer, times(3)).start(); verify(listener, times(1)).onServing(); verify(listener, times(3)).onNotServing(any(StatusException.class)); - assertThat(selectorRef.get().getFilterChains()).isEqualTo(Collections.singletonList( - filterChain) - ); + assertThat(selectorRef.get().getRoutingConfigs()).isEqualTo(ImmutableMap.of( + filterChain2, ServerRoutingConfig.create( + filterChain2.getHttpConnectionManager().httpFilterConfigs(), + Collections.singletonList(createVirtualHost("virtual-host-1"))) + )); assertThat(executor.numPendingTasks()).isEqualTo(1); xdsClient.ldsWatcher.onResourceDoesNotExist(ldsResource); - verify(mockServer, times(3)).shutdown(); - when(mockServer.isShutdown()).thenReturn(true); + verify(mockServer, times(4)).shutdown(); verify(listener, times(4)).onNotServing(any(StatusException.class)); + when(mockServer.isShutdown()).thenReturn(true); assertThat(executor.numPendingTasks()).isEqualTo(0); + assertThat(sslSupplier2.isShutdown()).isTrue(); // serving after not serving - xdsClient.deliverLdsUpdate(Collections.singletonList(filterChain), null); - verify(mockBuilder, times(4)).build(); + FilterChain filterChain3 = createFilterChain("filter-chain-2", createRds("rds")); + SslContextProviderSupplier sslSupplier3 = filterChain3.getSslContextProviderSupplier(); + xdsClient.deliverLdsUpdate(Collections.singletonList(filterChain3), null); + xdsClient.deliverRdsUpdate("rds", + Collections.singletonList(createVirtualHost("virtual-host-1"))); + verify(mockBuilder, times(3)).build(); verify(mockServer, times(4)).start(); verify(listener, times(1)).onServing(); + when(mockServer.isShutdown()).thenReturn(false); verify(listener, times(4)).onNotServing(any(StatusException.class)); - assertThat(executor.forwardNanos(RETRY_DELAY_NANOS)).isEqualTo(1); - verify(listener, times(2)).onServing(); - assertThat(selectorRef.get().getFilterChains()).isEqualTo(Collections.singletonList( - filterChain) - ); + assertThat(selectorRef.get().getRoutingConfigs()).isEqualTo(ImmutableMap.of( + filterChain3, ServerRoutingConfig.create( + filterChain3.getHttpConnectionManager().httpFilterConfigs(), + Collections.singletonList(createVirtualHost("virtual-host-1"))) + )); + xdsServerWrapper.shutdown(); + verify(mockServer, times(5)).shutdown(); + assertThat(sslSupplier3.isShutdown()).isTrue(); + when(mockServer.awaitTermination(anyLong(), any(TimeUnit.class))).thenReturn(true); + assertThat(xdsServerWrapper.awaitTermination(5, TimeUnit.SECONDS)).isTrue(); } + @Test + @SuppressWarnings("unchecked") + public void interceptor_notServerInterceptor() throws Exception { + ArgumentCaptor interceptorCaptor = + ArgumentCaptor.forClass(ConfigApplyingInterceptor.class); + final SettableFuture start = SettableFuture.create(); + Executors.newSingleThreadExecutor().execute(new Runnable() { + @Override + public void run() { + try { + start.set(xdsServerWrapper.start()); + } catch (Exception ex) { + start.setException(ex); + } + } + }); + xdsClient.ldsResource.get(5, TimeUnit.SECONDS); + verify(mockBuilder).intercept(interceptorCaptor.capture()); + ConfigApplyingInterceptor interceptor = interceptorCaptor.getValue(); + ServerRoutingConfig routingConfig = createRoutingConfig("/FooService/barMethod", + "foo.google.com", "filter-type-url"); + ServerCall serverCall = mock(ServerCall.class); + when(serverCall.getAttributes()).thenReturn( + Attributes.newBuilder().set(ATTR_SERVER_ROUTING_CONFIG, routingConfig).build()); + when(serverCall.getMethodDescriptor()).thenReturn(createMethod("FooService/barMethod")); + when(serverCall.getAuthority()).thenReturn("foo.google.com"); + + Filter filter = mock(Filter.class); + when(filter.typeUrls()).thenReturn(new String[]{"filter-type-url"}); + filterRegistry.register(filter); + ServerCallHandler next = mock(ServerCallHandler.class); + interceptor.interceptCall(serverCall, new Metadata(), next); + verify(next, never()).startCall(any(ServerCall.class), any(Metadata.class)); + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); + verify(serverCall).close(statusCaptor.capture(), any(Metadata.class)); + Status status = statusCaptor.getValue(); + assertThat(status.getCode()).isEqualTo(Status.UNAVAILABLE.getCode()); + assertThat(status.getDescription()).isEqualTo( + "HttpFilterConfig(type URL: filter-type-url) is not supported on server-side."); + } - private FilterChain createFilterChain(String name, HttpConnectionManager hcm) { + @Test + @SuppressWarnings("unchecked") + public void interceptor_virtualHostNotMatch() throws Exception { + ArgumentCaptor interceptorCaptor = + ArgumentCaptor.forClass(ConfigApplyingInterceptor.class); + final SettableFuture start = SettableFuture.create(); + Executors.newSingleThreadExecutor().execute(new Runnable() { + @Override + public void run() { + try { + start.set(xdsServerWrapper.start()); + } catch (Exception ex) { + start.setException(ex); + } + } + }); + xdsClient.ldsResource.get(5, TimeUnit.SECONDS); + verify(mockBuilder).intercept(interceptorCaptor.capture()); + ConfigApplyingInterceptor interceptor = interceptorCaptor.getValue(); + ServerRoutingConfig routingConfig = createRoutingConfig("/FooService/barMethod", + "foo.google.com", "filter-type-url"); + ServerCall serverCall = mock(ServerCall.class); + when(serverCall.getAttributes()).thenReturn( + Attributes.newBuilder().set(ATTR_SERVER_ROUTING_CONFIG, routingConfig).build()); + when(serverCall.getAuthority()).thenReturn("not-match.google.com"); + + Filter filter = mock(Filter.class); + when(filter.typeUrls()).thenReturn(new String[]{"filter-type-url"}); + filterRegistry.register(filter); + ServerCallHandler next = mock(ServerCallHandler.class); + interceptor.interceptCall(serverCall, new Metadata(), next); + verify(next, never()).startCall(any(ServerCall.class), any(Metadata.class)); + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); + verify(serverCall).close(statusCaptor.capture(), any(Metadata.class)); + Status status = statusCaptor.getValue(); + assertThat(status.getCode()).isEqualTo(Status.UNAVAILABLE.getCode()); + assertThat(status.getDescription()).isEqualTo("Could not find xDS virtual host matching RPC"); + } + + @Test + @SuppressWarnings("unchecked") + public void interceptor_routeNotMatch() throws Exception { + ArgumentCaptor interceptorCaptor = + ArgumentCaptor.forClass(ConfigApplyingInterceptor.class); + final SettableFuture start = SettableFuture.create(); + Executors.newSingleThreadExecutor().execute(new Runnable() { + @Override + public void run() { + try { + start.set(xdsServerWrapper.start()); + } catch (Exception ex) { + start.setException(ex); + } + } + }); + xdsClient.ldsResource.get(5, TimeUnit.SECONDS); + verify(mockBuilder).intercept(interceptorCaptor.capture()); + ConfigApplyingInterceptor interceptor = interceptorCaptor.getValue(); + ServerRoutingConfig routingConfig = createRoutingConfig("/FooService/barMethod", + "foo.google.com", "filter-type-url"); + ServerCall serverCall = mock(ServerCall.class); + when(serverCall.getAttributes()).thenReturn( + Attributes.newBuilder().set(ATTR_SERVER_ROUTING_CONFIG, routingConfig).build()); + when(serverCall.getMethodDescriptor()).thenReturn(createMethod("NotMatchMethod")); + when(serverCall.getAuthority()).thenReturn("foo.google.com"); + + Filter filter = mock(Filter.class); + when(filter.typeUrls()).thenReturn(new String[]{"filter-type-url"}); + filterRegistry.register(filter); + ServerCallHandler next = mock(ServerCallHandler.class); + interceptor.interceptCall(serverCall, new Metadata(), next); + verify(next, never()).startCall(any(ServerCall.class), any(Metadata.class)); + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); + verify(serverCall).close(statusCaptor.capture(), any(Metadata.class)); + Status status = statusCaptor.getValue(); + assertThat(status.getCode()).isEqualTo(Status.UNAVAILABLE.getCode()); + assertThat(status.getDescription()).isEqualTo("Could not find xDS route matching RPC"); + } + + @Test + @SuppressWarnings("unchecked") + public void interceptor_failingRouterConfig() throws Exception { + ArgumentCaptor interceptorCaptor = + ArgumentCaptor.forClass(ConfigApplyingInterceptor.class); + final SettableFuture start = SettableFuture.create(); + Executors.newSingleThreadExecutor().execute(new Runnable() { + @Override + public void run() { + try { + start.set(xdsServerWrapper.start()); + } catch (Exception ex) { + start.setException(ex); + } + } + }); + xdsClient.ldsResource.get(5, TimeUnit.SECONDS); + verify(mockBuilder).intercept(interceptorCaptor.capture()); + ConfigApplyingInterceptor interceptor = interceptorCaptor.getValue(); + ServerRoutingConfig failingConfig = ServerRoutingConfig.create( + ImmutableList.of(), ImmutableList.of()); + ServerCall serverCall = mock(ServerCall.class); + when(serverCall.getAttributes()).thenReturn( + Attributes.newBuilder().set(ATTR_SERVER_ROUTING_CONFIG, failingConfig).build()); + + ServerCallHandler next = mock(ServerCallHandler.class); + interceptor.interceptCall(serverCall, new Metadata(), next); + verify(next, never()).startCall(any(ServerCall.class), any(Metadata.class)); + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); + verify(serverCall).close(statusCaptor.capture(), any(Metadata.class)); + Status status = statusCaptor.getValue(); + assertThat(status.getCode()).isEqualTo(Status.UNAVAILABLE.getCode()); + assertThat(status.getDescription()).isEqualTo( + "Missing xDS routing config. RDS config unavailable."); + } + + @Test + @SuppressWarnings("unchecked") + public void interceptors() throws Exception { + ArgumentCaptor interceptorCaptor = + ArgumentCaptor.forClass(ConfigApplyingInterceptor.class); + final SettableFuture start = SettableFuture.create(); + Executors.newSingleThreadExecutor().execute(new Runnable() { + @Override + public void run() { + try { + start.set(xdsServerWrapper.start()); + } catch (Exception ex) { + start.setException(ex); + } + } + }); + xdsClient.ldsResource.get(5, TimeUnit.SECONDS); + verify(mockBuilder).intercept(interceptorCaptor.capture()); + final ConfigApplyingInterceptor interceptor = interceptorCaptor.getValue(); + RouteMatch routeMatch = + RouteMatch.create( + PathMatcher.fromPath("/FooService/barMethod", true), + Collections.emptyList(), null); + Filter filter = mock(Filter.class, withSettings() + .extraInterfaces(ServerInterceptorBuilder.class)); + when(filter.typeUrls()).thenReturn(new String[]{"filter-type-url"}); + filterRegistry.register(filter); + FilterConfig f0 = mock(FilterConfig.class); + FilterConfig f0Override = mock(FilterConfig.class); + when(f0.typeUrl()).thenReturn("filter-type-url"); + final List interceptorTrace = new ArrayList<>(); + ServerInterceptor interceptor0 = new ServerInterceptor() { + @Override + public ServerCall.Listener interceptCall(ServerCall call, + Metadata headers, ServerCallHandler next) { + interceptorTrace.add(0); + return next.startCall(call, headers); + } + }; + ServerInterceptor interceptor1 = new ServerInterceptor() { + @Override + public ServerCall.Listener interceptCall(ServerCall call, + Metadata headers, ServerCallHandler next) { + interceptorTrace.add(1); + return next.startCall(call, headers); + } + }; + + VirtualHost virtualHost = VirtualHost.create( + "v1", Collections.singletonList("foo.google.com"), + Arrays.asList(Route.forAction(routeMatch, null, + ImmutableMap.of())), + ImmutableMap.of("filter-config-name-0", f0Override)); + ServerRoutingConfig routingConfig = ServerRoutingConfig.create( + Arrays.asList(new NamedFilterConfig("filter-config-name-0", f0), + new NamedFilterConfig("filter-config-name-1", f0)), + Collections.singletonList(virtualHost) + ); + ServerCall serverCall = mock(ServerCall.class); + ServerCallHandler mockNext = mock(ServerCallHandler.class); + final ServerCall.Listener listener = new ServerCall.Listener() {}; + when(mockNext.startCall(any(ServerCall.class), any(Metadata.class))).thenReturn(listener); + when(serverCall.getAttributes()).thenReturn( + Attributes.newBuilder().set(ATTR_SERVER_ROUTING_CONFIG, routingConfig).build()); + when(serverCall.getMethodDescriptor()).thenReturn(createMethod("FooService/barMethod")); + when(serverCall.getAuthority()).thenReturn("foo.google.com"); + + when(((ServerInterceptorBuilder)filter).buildServerInterceptor(f0, f0Override)) + .thenReturn(null); + when(((ServerInterceptorBuilder)filter).buildServerInterceptor(f0, null)) + .thenReturn(null); + ServerCall.Listener configApplyingInterceptorListener = + interceptor.interceptCall(serverCall, new Metadata(), mockNext); + assertThat(configApplyingInterceptorListener).isSameInstanceAs(listener); + verify(mockNext).startCall(eq(serverCall), any(Metadata.class)); + assertThat(interceptorTrace).isEqualTo(Arrays.asList()); + + when(((ServerInterceptorBuilder)filter).buildServerInterceptor(f0, f0Override)) + .thenReturn(null); + when(((ServerInterceptorBuilder)filter).buildServerInterceptor(f0, null)) + .thenReturn(interceptor0); + configApplyingInterceptorListener = interceptor.interceptCall( + serverCall, new Metadata(), mockNext); + assertThat(configApplyingInterceptorListener).isSameInstanceAs(listener); + verify(mockNext, times(2)).startCall(eq(serverCall), any(Metadata.class)); + assertThat(interceptorTrace).isEqualTo(Arrays.asList(0)); + + when(((ServerInterceptorBuilder)filter).buildServerInterceptor(f0, f0Override)) + .thenReturn(interceptor0); + when(((ServerInterceptorBuilder)filter).buildServerInterceptor(f0, null)) + .thenReturn(interceptor1); + configApplyingInterceptorListener = interceptor.interceptCall( + serverCall, new Metadata(), mockNext); + assertThat(configApplyingInterceptorListener).isSameInstanceAs(listener); + verify(mockNext, times(3)).startCall(eq(serverCall), any(Metadata.class)); + assertThat(interceptorTrace).isEqualTo(Arrays.asList(0, 0, 1)); + } + + private static FilterChain createFilterChain(String name, HttpConnectionManager hcm) { return new EnvoyServerProtoData.FilterChain(name, createMatch(), hcm, createTls(), tlsContextManager); } - private VirtualHost createVirtualHost(String name) { + private static VirtualHost createVirtualHost(String name) { return VirtualHost.create( name, Collections.singletonList("auth"), new ArrayList(), ImmutableMap.of()); } - private HttpConnectionManager createRds(String name) { + private static HttpConnectionManager createRds(String name) { return HttpConnectionManager.forRdsName(0L, name, - new ArrayList()); + Arrays.asList(new NamedFilterConfig("named-config-" + name, null))); } - private EnvoyServerProtoData.FilterChainMatch createMatch() { + private static EnvoyServerProtoData.FilterChainMatch createMatch() { return new EnvoyServerProtoData.FilterChainMatch( 0, Arrays.asList(), @@ -450,7 +970,35 @@ private EnvoyServerProtoData.FilterChainMatch createMatch() { null); } - private EnvoyServerProtoData.DownstreamTlsContext createTls() { + private static ServerRoutingConfig createRoutingConfig(String path, String domain, + String filterType) { + RouteMatch routeMatch = + RouteMatch.create( + PathMatcher.fromPath(path, true), + Collections.emptyList(), null); + VirtualHost virtualHost = VirtualHost.create( + "v1", Collections.singletonList(domain), + Arrays.asList(Route.forAction(routeMatch, null, + ImmutableMap.of())), + Collections.emptyMap()); + FilterConfig f0 = mock(FilterConfig.class); + when(f0.typeUrl()).thenReturn(filterType); + return ServerRoutingConfig.create( + Arrays.asList(new NamedFilterConfig("filter-config-name-0", f0)), + Collections.singletonList(virtualHost) + ); + } + + private static MethodDescriptor createMethod(String path) { + return MethodDescriptor.newBuilder() + .setType(MethodDescriptor.MethodType.UNKNOWN) + .setFullMethodName(path) + .setRequestMarshaller(TestMethodDescriptors.voidMarshaller()) + .setResponseMarshaller(TestMethodDescriptors.voidMarshaller()) + .build(); + } + + private static EnvoyServerProtoData.DownstreamTlsContext createTls() { return CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); } } From f3337f28cee14f1d98249c13edd14ec8d4bc699d Mon Sep 17 00:00:00 2001 From: Kurt Alfred Kluever Date: Fri, 27 Aug 2021 17:11:06 -0400 Subject: [PATCH 48/82] stub: Add @InlineMe to deprecated gRPC APIs (#8457) Read more @ https://errorprone.info/docs/inlineme --- stub/BUILD.bazel | 1 + stub/build.gradle | 1 + stub/src/main/java/io/grpc/stub/MetadataUtils.java | 10 ++++++++++ 3 files changed, 12 insertions(+) diff --git a/stub/BUILD.bazel b/stub/BUILD.bazel index 181ffe0485d..c65b01a23dc 100644 --- a/stub/BUILD.bazel +++ b/stub/BUILD.bazel @@ -8,6 +8,7 @@ java_library( "//api", "//context", "@com_google_code_findbugs_jsr305//jar", + "@com_google_errorprone_error_prone_annotations//jar", "@com_google_guava_guava//jar", "@com_google_j2objc_j2objc_annotations//jar", ], diff --git a/stub/build.gradle b/stub/build.gradle index 4076460377c..2b5a6a4edb6 100644 --- a/stub/build.gradle +++ b/stub/build.gradle @@ -10,6 +10,7 @@ description = "gRPC: Stub" dependencies { api project(':grpc-api'), libraries.guava + implementation libraries.errorprone testImplementation libraries.truth, project(':grpc-testing') signature "org.codehaus.mojo.signature:java17:1.0@signature" diff --git a/stub/src/main/java/io/grpc/stub/MetadataUtils.java b/stub/src/main/java/io/grpc/stub/MetadataUtils.java index 94dfb8e56ee..5395ba9b5e3 100644 --- a/stub/src/main/java/io/grpc/stub/MetadataUtils.java +++ b/stub/src/main/java/io/grpc/stub/MetadataUtils.java @@ -18,6 +18,7 @@ import static com.google.common.base.Preconditions.checkNotNull; +import com.google.errorprone.annotations.InlineMe; import io.grpc.CallOptions; import io.grpc.Channel; import io.grpc.ClientCall; @@ -47,6 +48,10 @@ private MetadataUtils() {} */ @ExperimentalApi("https://github.com/grpc/grpc-java/issues/1789") @Deprecated + @InlineMe( + replacement = + "stub.withInterceptors(MetadataUtils.newAttachHeadersInterceptor(extraHeaders))", + imports = "io.grpc.stub.MetadataUtils") public static > T attachHeaders(T stub, Metadata extraHeaders) { return stub.withInterceptors(newAttachHeadersInterceptor(extraHeaders)); } @@ -104,6 +109,11 @@ public void start(Listener responseListener, Metadata headers) { */ @ExperimentalApi("https://github.com/grpc/grpc-java/issues/1789") @Deprecated + @InlineMe( + replacement = + "stub.withInterceptors(MetadataUtils.newCaptureMetadataInterceptor(headersCapture," + + " trailersCapture))", + imports = "io.grpc.stub.MetadataUtils") public static > T captureMetadata( T stub, AtomicReference headersCapture, From 137bdaa868d487153521b6aeaeaa84c0490667f1 Mon Sep 17 00:00:00 2001 From: apolcyn Date: Fri, 27 Aug 2021 15:30:43 -0700 Subject: [PATCH 49/82] interop-testing: add soak test cases to test service client --- android-interop-testing/build.gradle | 1 + interop-testing/build.gradle | 1 + .../integration/AbstractInteropTest.java | 123 ++++++++++++++++++ .../grpc/testing/integration/TestCases.java | 4 +- .../integration/TestServiceClient.java | 50 +++++++ .../testing/integration/TestCasesTest.java | 4 +- 6 files changed, 181 insertions(+), 2 deletions(-) diff --git a/android-interop-testing/build.gradle b/android-interop-testing/build.gradle index f1c1d491233..bbf1fcfe99e 100644 --- a/android-interop-testing/build.gradle +++ b/android-interop-testing/build.gradle @@ -62,6 +62,7 @@ dependencies { project(':grpc-protobuf-lite'), project(':grpc-stub'), project(':grpc-testing'), + libraries.hdrhistogram, libraries.junit, libraries.truth, libraries.opencensus_contrib_grpc_metrics diff --git a/interop-testing/build.gradle b/interop-testing/build.gradle index 852d5882cce..944c0daab81 100644 --- a/interop-testing/build.gradle +++ b/interop-testing/build.gradle @@ -27,6 +27,7 @@ dependencies { project(':grpc-stub'), project(':grpc-testing'), project(path: ':grpc-xds', configuration: 'shadow'), + libraries.hdrhistogram, libraries.junit, libraries.truth, libraries.opencensus_contrib_grpc_metrics, diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java b/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java index 693d9b2af7c..33d263e95de 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java @@ -127,6 +127,7 @@ import javax.annotation.Nullable; import javax.net.ssl.SSLPeerUnverifiedException; import javax.net.ssl.SSLSession; +import org.HdrHistogram.Histogram; import org.junit.After; import org.junit.Assert; import org.junit.Assume; @@ -1886,6 +1887,128 @@ public void googleDefaultCredentials( assertResponse(goldenResponse, response); } + private static class SoakIterationResult { + public SoakIterationResult(long latencyMs, Status status) { + this.latencyMs = latencyMs; + this.status = status; + } + + public long getLatencyMs() { + return latencyMs; + } + + public Status getStatus() { + return status; + } + + private long latencyMs = -1; + private Status status = Status.OK; + } + + private SoakIterationResult performOneSoakIteration(boolean resetChannel) throws Exception { + long startNs = System.nanoTime(); + Status status = Status.OK; + ManagedChannel soakChannel = channel; + TestServiceGrpc.TestServiceBlockingStub soakStub = blockingStub; + if (resetChannel) { + soakChannel = createChannel(); + soakStub = TestServiceGrpc.newBlockingStub(soakChannel); + } + try { + final SimpleRequest request = + SimpleRequest.newBuilder() + .setResponseSize(314159) + .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(new byte[271828]))) + .build(); + final SimpleResponse goldenResponse = + SimpleResponse.newBuilder() + .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(new byte[314159]))) + .build(); + assertResponse(goldenResponse, soakStub.unaryCall(request)); + } catch (StatusRuntimeException e) { + status = e.getStatus(); + } + long elapsedNs = System.nanoTime() - startNs; + if (resetChannel) { + soakChannel.shutdownNow(); + soakChannel.awaitTermination(10, TimeUnit.SECONDS); + } + return new SoakIterationResult(TimeUnit.NANOSECONDS.toMillis(elapsedNs), status); + } + + /** + * Runs large unary RPCs in a loop with configurable failure thresholds + * and channel creation behavior. + */ + public void performSoakTest( + boolean resetChannelPerIteration, + int soakIterations, + int maxFailures, + int maxAcceptablePerIterationLatencyMs, + int overallTimeoutSeconds) + throws Exception { + int iterationsDone = 0; + int totalFailures = 0; + Histogram latencies = new Histogram(4 /* number of significant value digits */); + long startNs = System.nanoTime(); + for (int i = 0; i < soakIterations; i++) { + if (System.nanoTime() - startNs >= TimeUnit.SECONDS.toNanos(overallTimeoutSeconds)) { + break; + } + SoakIterationResult result = performOneSoakIteration(resetChannelPerIteration); + System.err.print( + String.format( + "soak iteration: %d elapsed: %d ms", i, result.getLatencyMs())); + if (!result.getStatus().equals(Status.OK)) { + totalFailures++; + System.err.println(String.format(" failed: %s", result.getStatus())); + } else if (result.getLatencyMs() > maxAcceptablePerIterationLatencyMs) { + totalFailures++; + System.err.println( + String.format( + " exceeds max acceptable latency: %d", maxAcceptablePerIterationLatencyMs)); + } else { + System.err.println(" succeeded"); + } + iterationsDone++; + latencies.recordValue(result.getLatencyMs()); + } + System.err.println( + String.format( + "soak test ran: %d / %d iterations\n" + + "total failures: %d\n" + + "max failures threshold: %d\n" + + "max acceptable per iteration latency ms: %d\n" + + " p50 soak iteration latency: %d ms\n" + + " p90 soak iteration latency: %d ms\n" + + "p100 soak iteration latency: %d ms\n" + + "See breakdown above for which iterations succeeded, failed, and " + + "why for more info.", + iterationsDone, + soakIterations, + totalFailures, + maxFailures, + maxAcceptablePerIterationLatencyMs, + latencies.getValueAtPercentile(50), + latencies.getValueAtPercentile(90), + latencies.getValueAtPercentile(100))); + // check if we timed out + String timeoutErrorMessage = + String.format( + "soak test consumed all %d seconds of time and quit early, only " + + "having ran %d out of desired %d iterations.", + overallTimeoutSeconds, + iterationsDone, + soakIterations); + assertEquals(timeoutErrorMessage, iterationsDone, soakIterations); + // check if we had too many failures + String tooManyFailuresErrorMessage = + String.format( + "soak test total failures: %d exceeds max failures threshold: %d.", + totalFailures, maxFailures); + assertTrue(tooManyFailuresErrorMessage, totalFailures <= maxFailures); + } + protected static void assertSuccess(StreamRecorder recorder) { if (recorder.getError() != null) { throw new AssertionError(recorder.getError()); diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/TestCases.java b/interop-testing/src/main/java/io/grpc/testing/integration/TestCases.java index 2d1648e157a..39afaa99d6e 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/TestCases.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/TestCases.java @@ -54,7 +54,9 @@ public enum TestCases { CANCEL_AFTER_FIRST_RESPONSE("cancel on first response"), TIMEOUT_ON_SLEEPING_SERVER("timeout before receiving a response"), VERY_LARGE_REQUEST("very large request"), - PICK_FIRST_UNARY("all requests are sent to one server despite multiple servers are resolved"); + PICK_FIRST_UNARY("all requests are sent to one server despite multiple servers are resolved"), + RPC_SOAK("sends 'soak_iterations' large_unary rpcs in a loop, each on the same channel"), + CHANNEL_SOAK("sends 'soak_iterations' large_unary rpcs in a loop, each on a new channel"); private final String description; diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceClient.java b/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceClient.java index f8546880eae..914db12e5a8 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceClient.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceClient.java @@ -86,6 +86,11 @@ public static void main(String[] args) throws Exception { private boolean fullStreamDecompression; private int localHandshakerPort = -1; private Map serviceConfig = null; + private int soakIterations = 10; + private int soakMaxFailures = 0; + private int soakPerIterationMaxAcceptableLatencyMs = 1000; + private int soakOverallTimeoutSeconds = + soakIterations * soakPerIterationMaxAcceptableLatencyMs / 1000; private Tester tester = new Tester(); @@ -150,6 +155,14 @@ void parseArgs(String[] args) throws Exception { @SuppressWarnings("unchecked") Map map = (Map) JsonParser.parse(value); serviceConfig = map; + } else if ("soak_iterations".equals(key)) { + soakIterations = Integer.parseInt(value); + } else if ("soak_max_failures".equals(key)) { + soakMaxFailures = Integer.parseInt(value); + } else if ("soak_per_iteration_max_acceptable_latency_ms".equals(key)) { + soakPerIterationMaxAcceptableLatencyMs = Integer.parseInt(value); + } else if ("soak_overall_timeout_seconds".equals(key)) { + soakOverallTimeoutSeconds = Integer.parseInt(value); } else { System.err.println("Unknown argument: " + key); usage = true; @@ -196,6 +209,23 @@ void parseArgs(String[] args) throws Exception { + "\n --service_config_json=SERVICE_CONFIG_JSON" + "\n Disables service config lookups and sets the provided " + "\n string as the default service config." + + "\n --soak_iterations The number of iterations to use for the two soak " + + "\n tests: rpc_soak and channel_soak. Default " + + c.soakIterations + + "\n --soak_max_failures The number of iterations in soak tests that are " + + "\n allowed to fail (either due to non-OK status code or " + + "\n exceeding the per-iteration max acceptable latency). " + + "\n Default " + c.soakMaxFailures + + "\n --soak_per_iteration_max_acceptable_latency_ms " + + "\n The number of milliseconds a single iteration in the " + + "\n two soak tests (rpc_soak and channel_soak) should " + + "\n take. Default " + + c.soakPerIterationMaxAcceptableLatencyMs + + "\n --soak_overall_timeout_seconds " + + "\n The overall number of seconds after which a soak test " + + "\n should stop and fail, if the desired number of " + + "\n iterations have not yet completed. Default " + + c.soakOverallTimeoutSeconds ); System.exit(1); } @@ -412,6 +442,26 @@ private void runTest(TestCases testCase) throws Exception { break; } + case RPC_SOAK: { + tester.performSoakTest( + false /* resetChannelPerIteration */, + soakIterations, + soakMaxFailures, + soakPerIterationMaxAcceptableLatencyMs, + soakOverallTimeoutSeconds); + break; + } + + case CHANNEL_SOAK: { + tester.performSoakTest( + true /* resetChannelPerIteration */, + soakIterations, + soakMaxFailures, + soakPerIterationMaxAcceptableLatencyMs, + soakOverallTimeoutSeconds); + break; + } + default: throw new IllegalArgumentException("Unknown test case: " + testCase); } diff --git a/interop-testing/src/test/java/io/grpc/testing/integration/TestCasesTest.java b/interop-testing/src/test/java/io/grpc/testing/integration/TestCasesTest.java index 4e511c1cfe5..14a98514918 100644 --- a/interop-testing/src/test/java/io/grpc/testing/integration/TestCasesTest.java +++ b/interop-testing/src/test/java/io/grpc/testing/integration/TestCasesTest.java @@ -73,7 +73,9 @@ public void testCaseNamesShouldMapToEnums() { "client_compressed_unary_noprobe", "client_compressed_streaming_noprobe", "very_large_request", - "pick_first_unary" + "pick_first_unary", + "channel_soak", + "rpc_soak" }; assertEquals(testCases.length + additionalTestCases.length, TestCases.values().length); From b3ef588520f302a15e243a35714b572663a9e19c Mon Sep 17 00:00:00 2001 From: yifeizhuang Date: Fri, 27 Aug 2021 16:35:23 -0700 Subject: [PATCH 50/82] Fix Java Style (#8458) --- .../java/io/grpc/xds/GoogleCloudToProdNameResolver.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/GoogleCloudToProdNameResolver.java b/xds/src/main/java/io/grpc/xds/GoogleCloudToProdNameResolver.java index 9abdf12f175..2845d0a00e8 100644 --- a/xds/src/main/java/io/grpc/xds/GoogleCloudToProdNameResolver.java +++ b/xds/src/main/java/io/grpc/xds/GoogleCloudToProdNameResolver.java @@ -184,11 +184,11 @@ public void run() { ImmutableMap.of("TRAFFICDIRECTOR_DIRECTPATH_C2P_IPV6_CAPABLE", true)); } ImmutableMap.Builder serverBuilder = ImmutableMap.builder(); - String server_uri = "directpath-pa.googleapis.com"; + String serverUri = "directpath-pa.googleapis.com"; if (serverUriOverride != null && serverUriOverride.length() > 0) { - server_uri = serverUriOverride; + serverUri = serverUriOverride; } - serverBuilder.put("server_uri", server_uri); + serverBuilder.put("server_uri", serverUri); serverBuilder.put("channel_creds", ImmutableList.of(ImmutableMap.of("type", "google_default"))); serverBuilder.put("server_features", ImmutableList.of("xds_v3")); From 5cc94a548870c1db00c55bc737b004483356da82 Mon Sep 17 00:00:00 2001 From: Eric Anderson Date: Mon, 30 Aug 2021 11:15:05 -0700 Subject: [PATCH 51/82] stub: Document StreamObserver is an async API Missing docs were brought up in #8423 --- .../java/io/grpc/stub/ClientCallStreamObserver.java | 3 ++- .../java/io/grpc/stub/ServerCallStreamObserver.java | 3 ++- stub/src/main/java/io/grpc/stub/StreamObserver.java | 10 ++++++++++ 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/stub/src/main/java/io/grpc/stub/ClientCallStreamObserver.java b/stub/src/main/java/io/grpc/stub/ClientCallStreamObserver.java index ea09bb99d55..5fb70c76de3 100644 --- a/stub/src/main/java/io/grpc/stub/ClientCallStreamObserver.java +++ b/stub/src/main/java/io/grpc/stub/ClientCallStreamObserver.java @@ -20,7 +20,8 @@ /** * A refinement of {@link CallStreamObserver} that allows for lower-level interaction with - * client calls. + * client calls. An instance of this class is obtained via {@link ClientResponseObserver}, or by + * manually casting the {@code StreamObserver} returned by a stub. * *

Like {@code StreamObserver}, implementations are not required to be thread-safe; if multiple * threads will be writing to an instance concurrently, the application must synchronize its calls. diff --git a/stub/src/main/java/io/grpc/stub/ServerCallStreamObserver.java b/stub/src/main/java/io/grpc/stub/ServerCallStreamObserver.java index 3ba1bf563ef..a4d4564a46d 100644 --- a/stub/src/main/java/io/grpc/stub/ServerCallStreamObserver.java +++ b/stub/src/main/java/io/grpc/stub/ServerCallStreamObserver.java @@ -18,7 +18,8 @@ /** * A refinement of {@link CallStreamObserver} to allows for interaction with call - * cancellation events on the server side. + * cancellation events on the server side. An instance of this class is obtained by casting the + * {@code StreamObserver} passed as an argument to service implementations. * *

Like {@code StreamObserver}, implementations are not required to be thread-safe; if multiple * threads will be writing to an instance concurrently, the application must synchronize its calls. diff --git a/stub/src/main/java/io/grpc/stub/StreamObserver.java b/stub/src/main/java/io/grpc/stub/StreamObserver.java index 92040d9bc58..cf7cc258961 100644 --- a/stub/src/main/java/io/grpc/stub/StreamObserver.java +++ b/stub/src/main/java/io/grpc/stub/StreamObserver.java @@ -31,6 +31,16 @@ * not need to be synchronized together; incoming and outgoing directions are independent. * Since individual {@code StreamObserver}s are not thread-safe, if multiple threads will be * writing to a {@code StreamObserver} concurrently, the application must synchronize calls. + * + *

This API is asynchronous, so methods may return before the operation completes. The API + * provides no guarantees for how quickly an operation will complete, so utilizing flow control via + * {@link ClientCallStreamObserver} and {@link ServerCallStreamObserver} to avoid excessive + * buffering is recommended for streaming RPCs. gRPC's implementation of {@code onError()} on + * client-side causes the RPC to be cancelled and discards all messages, so completes quickly. + * + *

gRPC guarantees it does not block on I/O in its implementation, but applications are allowed + * to perform blocking operations in their implementations. However, doing so will delay other + * callbacks because the methods cannot be called concurrently. */ public interface StreamObserver { /** From 40f70ca3c1c3a5d4189da3e9788d6ccbeaf1e879 Mon Sep 17 00:00:00 2001 From: Lidi Zheng Date: Tue, 31 Aug 2021 10:35:51 -0700 Subject: [PATCH 52/82] Change to a non-workload-identity GKE cluster (#8461) Part of grpc/grpc#27189 and b/198291728. By disabling the workload identity, we should be able to run tests faster and avoid future IAM policy size issue. Kokoro run: https://fusion2.corp.google.com/invocations/b52b1684-47de-406d-a9f6-644909755f34/targets --- buildscripts/kokoro/xds_url_map.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/buildscripts/kokoro/xds_url_map.sh b/buildscripts/kokoro/xds_url_map.sh index b791528461d..d8487582980 100755 --- a/buildscripts/kokoro/xds_url_map.sh +++ b/buildscripts/kokoro/xds_url_map.sh @@ -4,8 +4,8 @@ set -eo pipefail # Constants readonly GITHUB_REPOSITORY_NAME="grpc-java" # GKE Cluster -readonly GKE_CLUSTER_NAME="interop-test-psm-sec-v2-us-central1-a" -readonly GKE_CLUSTER_ZONE="us-central1-a" +readonly GKE_CLUSTER_NAME="interop-test-psm-basic" +readonly GKE_CLUSTER_ZONE="us-central1-c" ## xDS test client Docker images readonly CLIENT_IMAGE_NAME="gcr.io/grpc-testing/xds-interop/java-client" readonly FORCE_IMAGE_BUILD="${FORCE_IMAGE_BUILD:-0}" From 4fa612ae3d8c3d38e84184f633fc4fbf02b93a35 Mon Sep 17 00:00:00 2001 From: Sergii Tkachenko Date: Tue, 31 Aug 2021 19:45:37 -0400 Subject: [PATCH 53/82] xds: fix java style --- .../java/io/grpc/xds/XdsServerWrapperTest.java | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/xds/src/test/java/io/grpc/xds/XdsServerWrapperTest.java b/xds/src/test/java/io/grpc/xds/XdsServerWrapperTest.java index f4ab7025e43..4c91d5758f9 100644 --- a/xds/src/test/java/io/grpc/xds/XdsServerWrapperTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsServerWrapperTest.java @@ -221,11 +221,11 @@ public void run() { }); String ldsWatched = xdsClient.ldsResource.get(5, TimeUnit.SECONDS); assertThat(ldsWatched).isEqualTo("grpc/server?udpa.resource.listening_address=0.0.0.0:1"); - HttpConnectionManager hcm_virtual = HttpConnectionManager.forVirtualHosts( + HttpConnectionManager hcmVirtual = HttpConnectionManager.forVirtualHosts( 0L, Collections.singletonList(createVirtualHost("virtual-host-0")), new ArrayList()); FilterChain f0 = createFilterChain("filter-chain-0", createRds("rds")); - FilterChain f1 = createFilterChain("filter-chain-1", hcm_virtual); + FilterChain f1 = createFilterChain("filter-chain-1", hcmVirtual); xdsClient.deliverLdsUpdate(Collections.singletonList(f0), f1); xdsServerWrapper.shutdown(); when(mockServer.isTerminated()).thenReturn(true); @@ -402,9 +402,9 @@ public void run() { String ldsWatched = xdsClient.ldsResource.get(5, TimeUnit.SECONDS); assertThat(ldsWatched).isEqualTo("grpc/server?udpa.resource.listening_address=0.0.0.0:1"); VirtualHost virtualHost = createVirtualHost("virtual-host-0"); - HttpConnectionManager hcm_virtual = HttpConnectionManager.forVirtualHosts( + HttpConnectionManager hcmVirtual = HttpConnectionManager.forVirtualHosts( 0L, Collections.singletonList(virtualHost), new ArrayList()); - EnvoyServerProtoData.FilterChain f0 = createFilterChain("filter-chain-0", hcm_virtual); + EnvoyServerProtoData.FilterChain f0 = createFilterChain("filter-chain-0", hcmVirtual); EnvoyServerProtoData.FilterChain f1 = createFilterChain("filter-chain-1", createRds("r0")); xdsClient.rdsCount = new CountDownLatch(3); xdsClient.deliverLdsUpdate(Arrays.asList(f0, f1), null); @@ -429,7 +429,7 @@ public void run() { verify(mockServer).start(); assertThat(selectorRef.get().getRoutingConfigs()).isEqualTo(ImmutableMap.of( f0, ServerRoutingConfig.create( - hcm_virtual.httpFilterConfigs(), hcm_virtual.virtualHosts()), + hcmVirtual.httpFilterConfigs(), hcmVirtual.virtualHosts()), f2, ServerRoutingConfig.create(f2.getHttpConnectionManager().httpFilterConfigs(), Collections.singletonList(createVirtualHost("virtual-host-1"))) )); @@ -527,9 +527,9 @@ public void run() { String ldsWatched = xdsClient.ldsResource.get(5, TimeUnit.SECONDS); assertThat(ldsWatched).isEqualTo("grpc/server?udpa.resource.listening_address=0.0.0.0:1"); VirtualHost virtualHost = createVirtualHost("virtual-host-0"); - HttpConnectionManager hcm_virtual = HttpConnectionManager.forVirtualHosts( + HttpConnectionManager hcmVirtual = HttpConnectionManager.forVirtualHosts( 0L, Collections.singletonList(virtualHost), new ArrayList()); - EnvoyServerProtoData.FilterChain f0 = createFilterChain("filter-chain-0", hcm_virtual); + EnvoyServerProtoData.FilterChain f0 = createFilterChain("filter-chain-0", hcmVirtual); EnvoyServerProtoData.FilterChain f1 = createFilterChain("filter-chain-1", createRds("r0")); xdsClient.deliverLdsUpdate(Arrays.asList(f0, f1), null); xdsClient.rdsCount.await(); From 7cde473efa4ffa3f530baf9e9a07298c80ed543f Mon Sep 17 00:00:00 2001 From: Terry Wilson Date: Wed, 1 Sep 2021 09:49:20 -0700 Subject: [PATCH 54/82] core/auth: Remove CallCredentials2 (#8464) - Removes CallCredentials2 - Removes CallCredentials2ApplyingTest - Adds two tests from CallCredentials2ApplyingTest to CallCredentialsApplyingTest - Updates GoogleAuthLibraryCallCredentials to extend from CallCredentials instead of CallCredentials2 --- .../main/java/io/grpc/CallCredentials2.java | 73 ---- .../GoogleAuthLibraryCallCredentials.java | 6 +- .../CallCredentials2ApplyingTest.java | 351 ------------------ .../internal/CallCredentialsApplyingTest.java | 45 +++ 4 files changed, 47 insertions(+), 428 deletions(-) delete mode 100644 api/src/main/java/io/grpc/CallCredentials2.java delete mode 100644 core/src/test/java/io/grpc/internal/CallCredentials2ApplyingTest.java diff --git a/api/src/main/java/io/grpc/CallCredentials2.java b/api/src/main/java/io/grpc/CallCredentials2.java deleted file mode 100644 index fdb7f51070a..00000000000 --- a/api/src/main/java/io/grpc/CallCredentials2.java +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Copyright 2016 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc; - -import java.util.concurrent.Executor; - -/** - * The new interface for {@link CallCredentials}. - * - *

THIS CLASS NAME IS TEMPORARY and is part of a migration. This class will BE DELETED as it - * replaces {@link CallCredentials} in short-term. THIS CLASS IS ONLY REFERENCED BY IMPLEMENTIONS. - * All consumers should be always referencing {@link CallCredentials}. - * - * @deprecated the new interface has been promoted into {@link CallCredentials}. Implementations - * should switch back to "{@code extends CallCredentials}". - */ -@Deprecated -@ExperimentalApi("https://github.com/grpc/grpc-java/issues/4901") -public abstract class CallCredentials2 extends CallCredentials { - /** - * Pass the credential data to the given {@link MetadataApplier}, which will propagate it to the - * request metadata. - * - *

It is called for each individual RPC, within the {@link Context} of the call, before the - * stream is about to be created on a transport. Implementations should not block in this - * method. If metadata is not immediately available, e.g., needs to be fetched from network, the - * implementation may give the {@code applier} to an asynchronous task which will eventually call - * the {@code applier}. The RPC proceeds only after the {@code applier} is called. - * - * @param requestInfo request-related information - * @param appExecutor The application thread-pool. It is provided to the implementation in case it - * needs to perform blocking operations. - * @param applier The outlet of the produced headers. It can be called either before or after this - * method returns. - */ - @ExperimentalApi("https://github.com/grpc/grpc-java/issues/1914") - public abstract void applyRequestMetadata( - RequestInfo requestInfo, Executor appExecutor, MetadataApplier applier); - - @Override - public final void applyRequestMetadata( - RequestInfo requestInfo, Executor appExecutor, - final CallCredentials.MetadataApplier applier) { - applyRequestMetadata(requestInfo, appExecutor, new MetadataApplier() { - @Override - public void apply(Metadata headers) { - applier.apply(headers); - } - - @Override - public void fail(Status status) { - applier.fail(status); - } - }); - } - - @ExperimentalApi("https://github.com/grpc/grpc-java/issues/1914") - public abstract static class MetadataApplier extends CallCredentials.MetadataApplier {} -} diff --git a/auth/src/main/java/io/grpc/auth/GoogleAuthLibraryCallCredentials.java b/auth/src/main/java/io/grpc/auth/GoogleAuthLibraryCallCredentials.java index 852fba73b20..4b95a6c7f4d 100644 --- a/auth/src/main/java/io/grpc/auth/GoogleAuthLibraryCallCredentials.java +++ b/auth/src/main/java/io/grpc/auth/GoogleAuthLibraryCallCredentials.java @@ -42,11 +42,9 @@ import javax.annotation.Nullable; /** - * Wraps {@link Credentials} as a {@link CallCredentials}. + * Wraps {@link Credentials} as a {@link io.grpc.CallCredentials}. */ -// TODO(zhangkun83): remove the suppression after we change the base class to CallCredential -@SuppressWarnings("deprecation") -final class GoogleAuthLibraryCallCredentials extends io.grpc.CallCredentials2 { +final class GoogleAuthLibraryCallCredentials extends io.grpc.CallCredentials { private static final Logger log = Logger.getLogger(GoogleAuthLibraryCallCredentials.class.getName()); private static final JwtHelper jwtHelper diff --git a/core/src/test/java/io/grpc/internal/CallCredentials2ApplyingTest.java b/core/src/test/java/io/grpc/internal/CallCredentials2ApplyingTest.java deleted file mode 100644 index 963a586319b..00000000000 --- a/core/src/test/java/io/grpc/internal/CallCredentials2ApplyingTest.java +++ /dev/null @@ -1,351 +0,0 @@ -/* - * Copyright 2016 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.internal; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertSame; -import static org.junit.Assert.assertTrue; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.same; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.doThrow; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import io.grpc.Attributes; -import io.grpc.CallCredentials.MetadataApplier; -import io.grpc.CallCredentials.RequestInfo; -import io.grpc.CallOptions; -import io.grpc.ChannelLogger; -import io.grpc.ClientStreamTracer; -import io.grpc.IntegerMarshaller; -import io.grpc.Metadata; -import io.grpc.MethodDescriptor; -import io.grpc.SecurityLevel; -import io.grpc.Status; -import io.grpc.StringMarshaller; -import java.net.SocketAddress; -import java.util.concurrent.Executor; -import org.junit.Before; -import org.junit.Rule; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; -import org.mockito.ArgumentCaptor; -import org.mockito.ArgumentMatchers; -import org.mockito.Mock; -import org.mockito.invocation.InvocationOnMock; -import org.mockito.junit.MockitoJUnit; -import org.mockito.junit.MockitoRule; -import org.mockito.stubbing.Answer; - -/** - * Unit test for {@link CallCredentials2} applying functionality implemented by {@link - * CallCredentialsApplyingTransportFactory} and {@link MetadataApplierImpl}. - */ -@SuppressWarnings("deprecation") -@RunWith(JUnit4.class) -public class CallCredentials2ApplyingTest { - @Rule - public final MockitoRule mocks = MockitoJUnit.rule(); - - @Mock - private ClientTransportFactory mockTransportFactory; - - @Mock - private ConnectionClientTransport mockTransport; - - @Mock - private ClientStream mockStream; - - @Mock - private io.grpc.CallCredentials2 mockCreds; - - @Mock - private Executor mockExecutor; - - @Mock - private SocketAddress address; - - // Noop logger; - @Mock - private ChannelLogger channelLogger; - - private static final String AUTHORITY = "testauthority"; - private static final String USER_AGENT = "testuseragent"; - private static final Attributes.Key ATTR_KEY = Attributes.Key.create("somekey"); - private static final String ATTR_VALUE = "somevalue"; - private static final MethodDescriptor method = - MethodDescriptor.newBuilder() - .setType(MethodDescriptor.MethodType.UNKNOWN) - .setFullMethodName("service/method") - .setRequestMarshaller(new StringMarshaller()) - .setResponseMarshaller(new IntegerMarshaller()) - .build(); - private static final Metadata.Key ORIG_HEADER_KEY = - Metadata.Key.of("header1", Metadata.ASCII_STRING_MARSHALLER); - private static final String ORIG_HEADER_VALUE = "some original header value"; - private static final Metadata.Key CREDS_KEY = - Metadata.Key.of("test-creds", Metadata.ASCII_STRING_MARSHALLER); - private static final String CREDS_VALUE = "some credentials"; - private static final ClientStreamTracer[] tracers = new ClientStreamTracer[] { - new ClientStreamTracer() {} - }; - - private final Metadata origHeaders = new Metadata(); - private ForwardingConnectionClientTransport transport; - private CallOptions callOptions; - - @Before - public void setUp() { - ClientTransportFactory.ClientTransportOptions clientTransportOptions = - new ClientTransportFactory.ClientTransportOptions() - .setAuthority(AUTHORITY) - .setUserAgent(USER_AGENT); - - origHeaders.put(ORIG_HEADER_KEY, ORIG_HEADER_VALUE); - when(mockTransportFactory.newClientTransport(address, clientTransportOptions, channelLogger)) - .thenReturn(mockTransport); - when(mockTransport.newStream( - same(method), any(Metadata.class), any(CallOptions.class), - ArgumentMatchers.any())) - .thenReturn(mockStream); - ClientTransportFactory transportFactory = new CallCredentialsApplyingTransportFactory( - mockTransportFactory, null, mockExecutor); - transport = (ForwardingConnectionClientTransport) - transportFactory.newClientTransport(address, clientTransportOptions, channelLogger); - callOptions = CallOptions.DEFAULT.withCallCredentials(mockCreds); - verify(mockTransportFactory).newClientTransport(address, clientTransportOptions, channelLogger); - assertSame(mockTransport, transport.delegate()); - } - - @Test - public void parameterPropagation_base() { - Attributes transportAttrs = Attributes.newBuilder().set(ATTR_KEY, ATTR_VALUE).build(); - when(mockTransport.getAttributes()).thenReturn(transportAttrs); - - transport.newStream(method, origHeaders, callOptions, tracers); - - ArgumentCaptor infoCaptor = ArgumentCaptor.forClass(null); - verify(mockCreds).applyRequestMetadata( - infoCaptor.capture(), same(mockExecutor), - any(io.grpc.CallCredentials2.MetadataApplier.class)); - RequestInfo info = infoCaptor.getValue(); - assertSame(method, info.getMethodDescriptor()); - assertSame(ATTR_VALUE, info.getTransportAttrs().get(ATTR_KEY)); - assertSame(AUTHORITY, info.getAuthority()); - assertSame(SecurityLevel.NONE, info.getSecurityLevel()); - } - - @Test - public void parameterPropagation_transportSetSecurityLevel() { - Attributes transportAttrs = Attributes.newBuilder() - .set(ATTR_KEY, ATTR_VALUE) - .set(GrpcAttributes.ATTR_SECURITY_LEVEL, SecurityLevel.INTEGRITY) - .build(); - when(mockTransport.getAttributes()).thenReturn(transportAttrs); - - transport.newStream(method, origHeaders, callOptions, tracers); - - ArgumentCaptor infoCaptor = ArgumentCaptor.forClass(null); - verify(mockCreds).applyRequestMetadata( - infoCaptor.capture(), same(mockExecutor), - any(io.grpc.CallCredentials2.MetadataApplier.class)); - RequestInfo info = infoCaptor.getValue(); - assertSame(method, info.getMethodDescriptor()); - assertSame(ATTR_VALUE, info.getTransportAttrs().get(ATTR_KEY)); - assertSame(AUTHORITY, info.getAuthority()); - assertSame(SecurityLevel.INTEGRITY, info.getSecurityLevel()); - } - - @Test - public void parameterPropagation_callOptionsSetAuthority() { - Attributes transportAttrs = Attributes.newBuilder() - .set(ATTR_KEY, ATTR_VALUE) - .build(); - when(mockTransport.getAttributes()).thenReturn(transportAttrs); - Executor anotherExecutor = mock(Executor.class); - - transport.newStream( - method, origHeaders, - callOptions.withAuthority("calloptions-authority").withExecutor(anotherExecutor), - tracers); - - ArgumentCaptor infoCaptor = ArgumentCaptor.forClass(null); - verify(mockCreds).applyRequestMetadata( - infoCaptor.capture(), same(anotherExecutor), - any(io.grpc.CallCredentials2.MetadataApplier.class)); - RequestInfo info = infoCaptor.getValue(); - assertSame(method, info.getMethodDescriptor()); - assertSame(ATTR_VALUE, info.getTransportAttrs().get(ATTR_KEY)); - assertEquals("calloptions-authority", info.getAuthority()); - assertSame(SecurityLevel.NONE, info.getSecurityLevel()); - } - - @Test - public void credentialThrows() { - final RuntimeException ex = new RuntimeException(); - when(mockTransport.getAttributes()).thenReturn(Attributes.EMPTY); - doThrow(ex).when(mockCreds).applyRequestMetadata( - any(RequestInfo.class), same(mockExecutor), - any(io.grpc.CallCredentials2.MetadataApplier.class)); - - FailingClientStream stream = - (FailingClientStream) transport.newStream(method, origHeaders, callOptions, tracers); - - verify(mockTransport, never()).newStream( - any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), - ArgumentMatchers.any()); - assertEquals(Status.Code.UNAUTHENTICATED, stream.getError().getCode()); - assertSame(ex, stream.getError().getCause()); - transport.shutdown(Status.UNAVAILABLE); - assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) - instanceof FailingClientStream); - verify(mockTransport).shutdown(Status.UNAVAILABLE); - } - - @Test - public void applyMetadata_inline() { - when(mockTransport.getAttributes()).thenReturn(Attributes.EMPTY); - doAnswer(new Answer() { - @Override - public Void answer(InvocationOnMock invocation) throws Throwable { - MetadataApplier applier = (MetadataApplier) invocation.getArguments()[2]; - Metadata headers = new Metadata(); - headers.put(CREDS_KEY, CREDS_VALUE); - applier.apply(headers); - return null; - } - }).when(mockCreds).applyRequestMetadata( - any(RequestInfo.class), same(mockExecutor), - any(io.grpc.CallCredentials2.MetadataApplier.class)); - - ClientStream stream = transport.newStream(method, origHeaders, callOptions, tracers); - - verify(mockTransport).newStream(method, origHeaders, callOptions, tracers); - assertSame(mockStream, stream); - assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY)); - assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY)); - transport.shutdown(Status.UNAVAILABLE); - assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) - instanceof FailingClientStream); - verify(mockTransport).shutdown(Status.UNAVAILABLE); - } - - @Test - public void fail_inline() { - final Status error = Status.FAILED_PRECONDITION.withDescription("channel not secure for creds"); - when(mockTransport.getAttributes()).thenReturn(Attributes.EMPTY); - doAnswer(new Answer() { - @Override - public Void answer(InvocationOnMock invocation) throws Throwable { - MetadataApplier applier = (MetadataApplier) invocation.getArguments()[2]; - applier.fail(error); - return null; - } - }).when(mockCreds).applyRequestMetadata( - any(RequestInfo.class), same(mockExecutor), - any(io.grpc.CallCredentials2.MetadataApplier.class)); - - FailingClientStream stream = - (FailingClientStream) transport.newStream(method, origHeaders, callOptions, tracers); - - verify(mockTransport, never()).newStream( - any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), - ArgumentMatchers.any()); - assertSame(error, stream.getError()); - transport.shutdownNow(Status.UNAVAILABLE); - assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) - instanceof FailingClientStream); - verify(mockTransport).shutdownNow(Status.UNAVAILABLE); - } - - @Test - public void applyMetadata_delayed() { - when(mockTransport.getAttributes()).thenReturn(Attributes.EMPTY); - - // Will call applyRequestMetadata(), which is no-op. - DelayedStream stream = (DelayedStream) transport.newStream( - method, origHeaders, callOptions, tracers); - - ArgumentCaptor applierCaptor = ArgumentCaptor.forClass(null); - verify(mockCreds).applyRequestMetadata( - any(RequestInfo.class), same(mockExecutor), applierCaptor.capture()); - verify(mockTransport, never()).newStream( - any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), - ArgumentMatchers.any()); - - transport.shutdown(Status.UNAVAILABLE); - verify(mockTransport, never()).shutdown(Status.UNAVAILABLE); - - Metadata headers = new Metadata(); - headers.put(CREDS_KEY, CREDS_VALUE); - applierCaptor.getValue().apply(headers); - - verify(mockTransport).newStream(method, origHeaders, callOptions, tracers); - assertSame(mockStream, stream.getRealStream()); - assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY)); - assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY)); - assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) - instanceof FailingClientStream); - verify(mockTransport).shutdown(Status.UNAVAILABLE); - } - - @Test - public void fail_delayed() { - when(mockTransport.getAttributes()).thenReturn(Attributes.EMPTY); - - // Will call applyRequestMetadata(), which is no-op. - DelayedStream stream = (DelayedStream) transport.newStream( - method, origHeaders, callOptions, tracers); - - ArgumentCaptor applierCaptor = ArgumentCaptor.forClass(null); - verify(mockCreds).applyRequestMetadata( - any(RequestInfo.class), same(mockExecutor), applierCaptor.capture()); - - Status error = Status.FAILED_PRECONDITION.withDescription("channel not secure for creds"); - applierCaptor.getValue().fail(error); - - verify(mockTransport, never()).newStream( - any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), - ArgumentMatchers.any()); - FailingClientStream failingStream = (FailingClientStream) stream.getRealStream(); - assertSame(error, failingStream.getError()); - transport.shutdown(Status.UNAVAILABLE); - assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) - instanceof FailingClientStream); - verify(mockTransport).shutdown(Status.UNAVAILABLE); - } - - @Test - public void noCreds() { - callOptions = callOptions.withCallCredentials(null); - ClientStream stream = transport.newStream(method, origHeaders, callOptions, tracers); - - verify(mockTransport).newStream(method, origHeaders, callOptions, tracers); - assertSame(mockStream, stream); - assertNull(origHeaders.get(CREDS_KEY)); - assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY)); - transport.shutdown(Status.UNAVAILABLE); - assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) - instanceof FailingClientStream); - verify(mockTransport).shutdown(Status.UNAVAILABLE); - } -} diff --git a/core/src/test/java/io/grpc/internal/CallCredentialsApplyingTest.java b/core/src/test/java/io/grpc/internal/CallCredentialsApplyingTest.java index ef49e66bf2d..2f0ce1070b1 100644 --- a/core/src/test/java/io/grpc/internal/CallCredentialsApplyingTest.java +++ b/core/src/test/java/io/grpc/internal/CallCredentialsApplyingTest.java @@ -176,6 +176,51 @@ public void parameterPropagation_overrideByCallOptions() { assertSame(SecurityLevel.INTEGRITY, info.getSecurityLevel()); } + @Test + public void parameterPropagation_transportSetSecurityLevel() { + Attributes transportAttrs = Attributes.newBuilder() + .set(ATTR_KEY, ATTR_VALUE) + .set(GrpcAttributes.ATTR_SECURITY_LEVEL, SecurityLevel.INTEGRITY) + .build(); + when(mockTransport.getAttributes()).thenReturn(transportAttrs); + + transport.newStream(method, origHeaders, callOptions, tracers); + + ArgumentCaptor infoCaptor = ArgumentCaptor.forClass(null); + verify(mockCreds).applyRequestMetadata( + infoCaptor.capture(), same(mockExecutor), + any(io.grpc.CallCredentials.MetadataApplier.class)); + RequestInfo info = infoCaptor.getValue(); + assertSame(method, info.getMethodDescriptor()); + assertSame(ATTR_VALUE, info.getTransportAttrs().get(ATTR_KEY)); + assertSame(AUTHORITY, info.getAuthority()); + assertSame(SecurityLevel.INTEGRITY, info.getSecurityLevel()); + } + + @Test + public void parameterPropagation_callOptionsSetAuthority() { + Attributes transportAttrs = Attributes.newBuilder() + .set(ATTR_KEY, ATTR_VALUE) + .build(); + when(mockTransport.getAttributes()).thenReturn(transportAttrs); + Executor anotherExecutor = mock(Executor.class); + + transport.newStream( + method, origHeaders, + callOptions.withAuthority("calloptions-authority").withExecutor(anotherExecutor), + tracers); + + ArgumentCaptor infoCaptor = ArgumentCaptor.forClass(null); + verify(mockCreds).applyRequestMetadata( + infoCaptor.capture(), same(anotherExecutor), + any(io.grpc.CallCredentials.MetadataApplier.class)); + RequestInfo info = infoCaptor.getValue(); + assertSame(method, info.getMethodDescriptor()); + assertSame(ATTR_VALUE, info.getTransportAttrs().get(ATTR_KEY)); + assertEquals("calloptions-authority", info.getAuthority()); + assertSame(SecurityLevel.NONE, info.getSecurityLevel()); + } + @Test public void credentialThrows() { final RuntimeException ex = new RuntimeException(); From b0b250024f659b1879cbed6bc0ef7d833b1b7303 Mon Sep 17 00:00:00 2001 From: sanjaypujare Date: Wed, 1 Sep 2021 10:49:33 -0700 Subject: [PATCH 55/82] xds: fix implementation to comply with gRFC for security (#8468) --- .../main/java/io/grpc/xds/Bootstrapper.java | 6 +- .../java/io/grpc/xds/ClientXdsClient.java | 124 +++--- .../CertProviderClientSslContextProvider.java | 5 +- .../CertProviderServerSslContextProvider.java | 5 +- .../CertProviderSslContextProvider.java | 14 +- .../io/grpc/xds/ClientXdsClientDataTest.java | 367 ++++++------------ .../io/grpc/xds/ClientXdsClientTestBase.java | 9 +- .../sds/CommonTlsContextTestsUtil.java | 2 +- 8 files changed, 204 insertions(+), 328 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/Bootstrapper.java b/xds/src/main/java/io/grpc/xds/Bootstrapper.java index 08d11d174d4..e1ba4aa176e 100644 --- a/xds/src/main/java/io/grpc/xds/Bootstrapper.java +++ b/xds/src/main/java/io/grpc/xds/Bootstrapper.java @@ -84,7 +84,8 @@ public static class CertificateProviderInfo { private final String pluginName; private final Map config; - CertificateProviderInfo(String pluginName, Map config) { + @VisibleForTesting + public CertificateProviderInfo(String pluginName, Map config) { this.pluginName = checkNotNull(pluginName, "pluginName"); this.config = checkNotNull(config, "config"); } @@ -135,8 +136,9 @@ public Node getNode() { } /** Returns the cert-providers config map. */ + @Nullable public Map getCertProviders() { - return Collections.unmodifiableMap(certProviders); + return certProviders == null ? null : Collections.unmodifiableMap(certProviders); } @Nullable diff --git a/xds/src/main/java/io/grpc/xds/ClientXdsClient.java b/xds/src/main/java/io/grpc/xds/ClientXdsClient.java index 93bfeb36b34..47e73d11266 100644 --- a/xds/src/main/java/io/grpc/xds/ClientXdsClient.java +++ b/xds/src/main/java/io/grpc/xds/ClientXdsClient.java @@ -266,14 +266,20 @@ private LdsUpdate processClientSideListener( private LdsUpdate processServerSideListener( Listener proto, Set rdsResources, boolean parseHttpFilter) throws ResourceInvalidException { + Set certProviderInstances = null; + if (getBootstrapInfo() != null && getBootstrapInfo().getCertProviders() != null) { + certProviderInstances = getBootstrapInfo().getCertProviders().keySet(); + } return LdsUpdate.forTcpListener(parseServerSideListener( - proto, rdsResources, tlsContextManager, filterRegistry, parseHttpFilter)); + proto, rdsResources, tlsContextManager, filterRegistry, certProviderInstances, + parseHttpFilter)); } @VisibleForTesting static EnvoyServerProtoData.Listener parseServerSideListener( Listener proto, Set rdsResources, TlsContextManager tlsContextManager, - FilterRegistry filterRegistry, boolean parseHttpFilter) throws ResourceInvalidException { + FilterRegistry filterRegistry, Set certProviderInstances, boolean parseHttpFilter) + throws ResourceInvalidException { if (!proto.getTrafficDirection().equals(TrafficDirection.INBOUND)) { throw new ResourceInvalidException( "Listener " + proto.getName() + " with invalid traffic direction: " @@ -309,13 +315,13 @@ static EnvoyServerProtoData.Listener parseServerSideListener( for (io.envoyproxy.envoy.config.listener.v3.FilterChain fc : proto.getFilterChainsList()) { filterChains.add( parseFilterChain(fc, rdsResources, tlsContextManager, filterRegistry, uniqueSet, - parseHttpFilter)); + certProviderInstances, parseHttpFilter)); } FilterChain defaultFilterChain = null; if (proto.hasDefaultFilterChain()) { defaultFilterChain = parseFilterChain( proto.getDefaultFilterChain(), rdsResources, tlsContextManager, filterRegistry, - null, parseHttpFilter); + null, certProviderInstances, parseHttpFilter); } return new EnvoyServerProtoData.Listener( @@ -326,7 +332,7 @@ static EnvoyServerProtoData.Listener parseServerSideListener( static FilterChain parseFilterChain( io.envoyproxy.envoy.config.listener.v3.FilterChain proto, Set rdsResources, TlsContextManager tlsContextManager, FilterRegistry filterRegistry, - Set uniqueSet, boolean parseHttpFilters) + Set uniqueSet, Set certProviderInstances, boolean parseHttpFilters) throws ResourceInvalidException { io.grpc.xds.HttpConnectionManager httpConnectionManager = null; HashSet uniqueNames = new HashSet<>(); @@ -380,7 +386,7 @@ static FilterChain parseFilterChain( } downstreamTlsContext = EnvoyServerProtoData.DownstreamTlsContext.fromEnvoyProtoDownstreamTlsContext( - validateDownstreamTlsContext(downstreamTlsContextProto)); + validateDownstreamTlsContext(downstreamTlsContextProto, certProviderInstances)); } String name = proto.getName(); @@ -399,13 +405,12 @@ static FilterChain parseFilterChain( } @VisibleForTesting - static io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.DownstreamTlsContext - validateDownstreamTlsContext( - io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.DownstreamTlsContext - downstreamTlsContext) + static DownstreamTlsContext validateDownstreamTlsContext( + DownstreamTlsContext downstreamTlsContext, Set certProviderInstances) throws ResourceInvalidException { if (downstreamTlsContext.hasCommonTlsContext()) { - validateCommonTlsContext(downstreamTlsContext.getCommonTlsContext(), true); + validateCommonTlsContext(downstreamTlsContext.getCommonTlsContext(), certProviderInstances, + true); } else { throw new ResourceInvalidException( "common-tls-context is required in downstream-tls-context"); @@ -414,22 +419,6 @@ static FilterChain parseFilterChain( throw new ResourceInvalidException( "downstream-tls-context with require-sni is not supported"); } - if (downstreamTlsContext.hasSessionTicketKeys()) { - throw new ResourceInvalidException( - "downstream-tls-context with session_ticket_keys is not supported"); - } - if (downstreamTlsContext.hasSessionTicketKeysSdsSecretConfig()) { - throw new ResourceInvalidException( - "downstream-tls-context with session_ticket_keys_sds_secret_config is not supported"); - } - if (downstreamTlsContext.hasDisableStatelessSessionResumption()) { - throw new ResourceInvalidException( - "downstream-tls-context with disable_stateless_session_resumption is not supported"); - } - if (downstreamTlsContext.hasSessionTimeout()) { - throw new ResourceInvalidException( - "downstream-tls-context with session_timeout is not supported"); - } DownstreamTlsContext.OcspStaplePolicy ocspStaplePolicy = downstreamTlsContext .getOcspStaplePolicy(); if (ocspStaplePolicy != DownstreamTlsContext.OcspStaplePolicy.UNRECOGNIZED @@ -444,30 +433,22 @@ static FilterChain parseFilterChain( @VisibleForTesting static io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext validateUpstreamTlsContext( - io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext upstreamTlsContext) + io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext upstreamTlsContext, + Set certProviderInstances) throws ResourceInvalidException { if (upstreamTlsContext.hasCommonTlsContext()) { - validateCommonTlsContext(upstreamTlsContext.getCommonTlsContext(), false); + validateCommonTlsContext(upstreamTlsContext.getCommonTlsContext(), certProviderInstances, + false); } else { throw new ResourceInvalidException("common-tls-context is required in upstream-tls-context"); } - if (!Strings.isNullOrEmpty(upstreamTlsContext.getSni())) { - throw new ResourceInvalidException("upstream-tls-context with sni is not supported"); - } - if (upstreamTlsContext.getAllowRenegotiation()) { - throw new ResourceInvalidException( - "upstream-tls-context with allow_renegotiation is not supported"); - } - if (upstreamTlsContext.hasMaxSessionKeys()) { - throw new ResourceInvalidException( - "upstream-tls-context with max_session_keys is not supported"); - } return upstreamTlsContext; } @VisibleForTesting static void validateCommonTlsContext( - CommonTlsContext commonTlsContext, boolean server) throws ResourceInvalidException { + CommonTlsContext commonTlsContext, Set certProviderInstances, boolean server) + throws ResourceInvalidException { if (commonTlsContext.hasCustomHandshaker()) { throw new ResourceInvalidException( "common-tls-context with custom_handshaker is not supported"); @@ -492,6 +473,7 @@ static void validateCommonTlsContext( "common-tls-context with validation_context_certificate_provider_instance is not" + " supported"); } + String certInstanceName = null; if (!commonTlsContext.hasTlsCertificateCertificateProviderInstance()) { if (server) { throw new ResourceInvalidException( @@ -509,7 +491,18 @@ static void validateCommonTlsContext( throw new ResourceInvalidException( "common-tls-context with tls_certificate_certificate_provider is not supported"); } + } else { + certInstanceName = commonTlsContext.getTlsCertificateCertificateProviderInstance() + .getInstanceName(); } + if (certInstanceName != null) { + if (certProviderInstances == null || !certProviderInstances.contains(certInstanceName)) { + throw new ResourceInvalidException( + "CertificateProvider instance name '" + certInstanceName + + "' not defined in the bootstrap file."); + } + } + String rootCaInstanceName = null; if (!commonTlsContext.hasCombinedValidationContext()) { if (!server) { throw new ResourceInvalidException( @@ -523,6 +516,8 @@ static void validateCommonTlsContext( "validation_context_certificate_provider_instance is required in" + " combined_validation_context"); } + rootCaInstanceName = combinedCertificateValidationContext + .getValidationContextCertificateProviderInstance().getInstanceName(); if (combinedCertificateValidationContext.hasDefaultValidationContext()) { CertificateValidationContext certificateValidationContext = combinedCertificateValidationContext.getDefaultValidationContext(); @@ -530,14 +525,6 @@ static void validateCommonTlsContext( throw new ResourceInvalidException( "match_subject_alt_names only allowed in upstream_tls_context"); } - if (certificateValidationContext.hasTrustedCa()) { - throw new ResourceInvalidException( - "trusted_ca in default_validation_context is not supported"); - } - if (certificateValidationContext.hasWatchedDirectory()) { - throw new ResourceInvalidException( - "watched_directory in default_validation_context is not supported"); - } if (certificateValidationContext.getVerifyCertificateSpkiCount() > 0) { throw new ResourceInvalidException( "verify_certificate_spki in default_validation_context is not supported"); @@ -554,23 +541,19 @@ static void validateCommonTlsContext( if (certificateValidationContext.hasCrl()) { throw new ResourceInvalidException("crl in default_validation_context is not supported"); } - if (certificateValidationContext.getAllowExpiredCertificate()) { - throw new ResourceInvalidException( - "allow_expired_certificate in default_validation_context is not supported"); - } - CertificateValidationContext.TrustChainVerification trustChainVerification - = certificateValidationContext.getTrustChainVerification(); - if (trustChainVerification - != CertificateValidationContext.TrustChainVerification.VERIFY_TRUST_CHAIN) { - throw new ResourceInvalidException( - "Only VERIFY_TRUST_CHAIN for trust_chain_verification supported"); - } if (certificateValidationContext.hasCustomValidatorConfig()) { throw new ResourceInvalidException( "custom_validator_config in default_validation_context is not supported"); } } } + if (rootCaInstanceName != null) { + if (certProviderInstances == null || !certProviderInstances.contains(rootCaInstanceName)) { + throw new ResourceInvalidException( + "ValidationContextProvider instance name '" + rootCaInstanceName + + "' not defined in the bootstrap file."); + } + } } private static void checkForUniqueness(Set uniqueSet, @@ -1397,7 +1380,11 @@ protected void handleCdsResponse(String versionInfo, List resources, String // Process Cluster into CdsUpdate. CdsUpdate cdsUpdate; try { - cdsUpdate = parseCluster(cluster, retainedEdsResources); + Set certProviderInstances = null; + if (getBootstrapInfo() != null && getBootstrapInfo().getCertProviders() != null) { + certProviderInstances = getBootstrapInfo().getCertProviders().keySet(); + } + cdsUpdate = parseCluster(cluster, retainedEdsResources, certProviderInstances); } catch (ResourceInvalidException e) { errors.add( "CDS response Cluster '" + clusterName + "' validation error: " + e.getMessage()); @@ -1426,12 +1413,14 @@ protected void handleCdsResponse(String versionInfo, List resources, String } @VisibleForTesting - static CdsUpdate parseCluster(Cluster cluster, Set retainedEdsResources) + static CdsUpdate parseCluster(Cluster cluster, Set retainedEdsResources, + Set certProviderInstances) throws ResourceInvalidException { StructOrError structOrError; switch (cluster.getClusterDiscoveryTypeCase()) { case TYPE: - structOrError = parseNonAggregateCluster(cluster, retainedEdsResources); + structOrError = parseNonAggregateCluster(cluster, retainedEdsResources, + certProviderInstances); break; case CLUSTER_TYPE: structOrError = parseAggregateCluster(cluster); @@ -1494,7 +1483,7 @@ private static StructOrError parseAggregateCluster(Cluster cl } private static StructOrError parseNonAggregateCluster( - Cluster cluster, Set edsResources) { + Cluster cluster, Set edsResources, Set certProviderInstances) { String clusterName = cluster.getName(); String lrsServerName = null; Long maxConcurrentRequests = null; @@ -1517,6 +1506,10 @@ private static StructOrError parseNonAggregateCluster( } } } + if (cluster.getTransportSocketMatchesCount() > 0) { + return StructOrError.fromError("Cluster " + clusterName + + ": transport-socket-matches not supported."); + } if (cluster.hasTransportSocket()) { if (!TRANSPORT_SOCKET_NAME_TLS.equals(cluster.getTransportSocket().getName())) { return StructOrError.fromError("transport-socket with name " @@ -1527,7 +1520,8 @@ private static StructOrError parseNonAggregateCluster( validateUpstreamTlsContext( unpackCompatibleType(cluster.getTransportSocket().getTypedConfig(), io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext.class, - TYPE_URL_UPSTREAM_TLS_CONTEXT, TYPE_URL_UPSTREAM_TLS_CONTEXT_V2))); + TYPE_URL_UPSTREAM_TLS_CONTEXT, TYPE_URL_UPSTREAM_TLS_CONTEXT_V2), + certProviderInstances)); } catch (InvalidProtocolBufferException | ResourceInvalidException e) { return StructOrError.fromError( "Cluster " + clusterName + ": malformed UpstreamTlsContext: " + e); diff --git a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderClientSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderClientSslContextProvider.java index 1dc7be1be33..2ee21e7db6a 100644 --- a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderClientSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderClientSslContextProvider.java @@ -32,6 +32,7 @@ import java.security.cert.CertStoreException; import java.security.cert.X509Certificate; import java.util.Map; +import javax.annotation.Nullable; /** A client SslContext provider using CertificateProviderInstance to fetch secrets. */ @Internal @@ -39,7 +40,7 @@ public final class CertProviderClientSslContextProvider extends CertProviderSslC private CertProviderClientSslContextProvider( Node node, - Map certProviders, + @Nullable Map certProviders, CommonTlsContext.CertificateProviderInstance certInstance, CommonTlsContext.CertificateProviderInstance rootCertInstance, CertificateValidationContext staticCertValidationContext, @@ -90,7 +91,7 @@ public static Factory getInstance() { public CertProviderClientSslContextProvider getProvider( UpstreamTlsContext upstreamTlsContext, Node node, - Map certProviders) { + @Nullable Map certProviders) { checkNotNull(upstreamTlsContext, "upstreamTlsContext"); CommonTlsContext commonTlsContext = upstreamTlsContext.getCommonTlsContext(); CommonTlsContext.CertificateProviderInstance rootCertInstance = null; diff --git a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderServerSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderServerSslContextProvider.java index 78e825f60fd..1f33e1de789 100644 --- a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderServerSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderServerSslContextProvider.java @@ -35,6 +35,7 @@ import java.security.cert.CertificateException; import java.security.cert.X509Certificate; import java.util.Map; +import javax.annotation.Nullable; /** A server SslContext provider using CertificateProviderInstance to fetch secrets. */ @Internal @@ -42,7 +43,7 @@ public final class CertProviderServerSslContextProvider extends CertProviderSslC private CertProviderServerSslContextProvider( Node node, - Map certProviders, + @Nullable Map certProviders, CommonTlsContext.CertificateProviderInstance certInstance, CommonTlsContext.CertificateProviderInstance rootCertInstance, CertificateValidationContext staticCertValidationContext, @@ -93,7 +94,7 @@ public static Factory getInstance() { public CertProviderServerSslContextProvider getProvider( DownstreamTlsContext downstreamTlsContext, Node node, - Map certProviders) { + @Nullable Map certProviders) { checkNotNull(downstreamTlsContext, "downstreamTlsContext"); CommonTlsContext commonTlsContext = downstreamTlsContext.getCommonTlsContext(); CommonTlsContext.CertificateProviderInstance rootCertInstance = null; diff --git a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderSslContextProvider.java index eef5ee551e7..1af9e1670d3 100644 --- a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderSslContextProvider.java @@ -42,7 +42,7 @@ abstract class CertProviderSslContextProvider extends DynamicSslContextProvider protected CertProviderSslContextProvider( Node node, - Map certProviders, + @Nullable Map certProviders, CertificateProviderInstance certInstance, CertificateProviderInstance rootCertInstance, CertificateValidationContext staticCertValidationContext, @@ -56,8 +56,8 @@ protected CertProviderSslContextProvider( certInstanceName = certInstance.getInstanceName(); CertificateProviderInfo certProviderInstanceConfig = getCertProviderConfig(certProviders, certInstanceName); - certHandle = - certificateProviderStore.createOrGetProvider( + certHandle = certProviderInstanceConfig == null ? null + : certificateProviderStore.createOrGetProvider( certInstance.getCertificateName(), certProviderInstanceConfig.getPluginName(), certProviderInstanceConfig.getConfig(), @@ -71,8 +71,8 @@ protected CertProviderSslContextProvider( && !rootCertInstance.getInstanceName().equals(certInstanceName)) { CertificateProviderInfo certProviderInstanceConfig = getCertProviderConfig(certProviders, rootCertInstance.getInstanceName()); - rootCertHandle = - certificateProviderStore.createOrGetProvider( + rootCertHandle = certProviderInstanceConfig == null ? null + : certificateProviderStore.createOrGetProvider( rootCertInstance.getCertificateName(), certProviderInstanceConfig.getPluginName(), certProviderInstanceConfig.getConfig(), @@ -84,8 +84,8 @@ protected CertProviderSslContextProvider( } private static CertificateProviderInfo getCertProviderConfig( - Map certProviders, String pluginInstanceName) { - return certProviders.get(pluginInstanceName); + @Nullable Map certProviders, String pluginInstanceName) { + return certProviders != null ? certProviders.get(pluginInstanceName) : null; } @Override diff --git a/xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java b/xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java index 60ce2befe45..77320be9d20 100644 --- a/xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java +++ b/xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java @@ -19,9 +19,9 @@ import static com.google.common.truth.Truth.assertThat; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import com.google.protobuf.Any; import com.google.protobuf.BoolValue; -import com.google.protobuf.Duration; import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.StringValue; import com.google.protobuf.UInt32Value; @@ -46,7 +46,6 @@ import io.envoyproxy.envoy.config.core.v3.TrafficDirection; import io.envoyproxy.envoy.config.core.v3.TransportSocket; import io.envoyproxy.envoy.config.core.v3.TypedExtensionConfig; -import io.envoyproxy.envoy.config.core.v3.WatchedDirectory; import io.envoyproxy.envoy.config.endpoint.v3.Endpoint; import io.envoyproxy.envoy.config.listener.v3.Filter; import io.envoyproxy.envoy.config.listener.v3.FilterChain; @@ -86,7 +85,6 @@ import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.SdsSecretConfig; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.TlsCertificate; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.TlsParameters; -import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.TlsSessionTicketKeys; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext; import io.envoyproxy.envoy.type.matcher.v3.RegexMatchAndSubstitute; import io.envoyproxy.envoy.type.matcher.v3.RegexMatcher; @@ -1130,7 +1128,7 @@ public void parseCluster_ringHashLbPolicy_defaultLbConfig() throws ResourceInval .setLbPolicy(LbPolicy.RING_HASH) .build(); - CdsUpdate update = ClientXdsClient.parseCluster(cluster, new HashSet()); + CdsUpdate update = ClientXdsClient.parseCluster(cluster, new HashSet(), null); assertThat(update.lbPolicy()).isEqualTo(CdsUpdate.LbPolicy.RING_HASH); assertThat(update.minRingSize()) .isEqualTo(ClientXdsClient.DEFAULT_RING_HASH_LB_POLICY_MIN_RING_SIZE); @@ -1138,6 +1136,28 @@ public void parseCluster_ringHashLbPolicy_defaultLbConfig() throws ResourceInval .isEqualTo(ClientXdsClient.DEFAULT_RING_HASH_LB_POLICY_MAX_RING_SIZE); } + @Test + public void parseCluster_transportSocketMatches_exception() throws ResourceInvalidException { + Cluster cluster = Cluster.newBuilder() + .setName("cluster-foo.googleapis.com") + .setType(DiscoveryType.EDS) + .setEdsClusterConfig( + EdsClusterConfig.newBuilder() + .setEdsConfig( + ConfigSource.newBuilder() + .setAds(AggregatedConfigSource.getDefaultInstance())) + .setServiceName("service-foo.googleapis.com")) + .setLbPolicy(LbPolicy.ROUND_ROBIN) + .addTransportSocketMatches( + Cluster.TransportSocketMatch.newBuilder().setName("match1").build()) + .build(); + + thrown.expect(ResourceInvalidException.class); + thrown.expectMessage( + "Cluster cluster-foo.googleapis.com: transport-socket-matches not supported."); + ClientXdsClient.parseCluster(cluster, new HashSet(), null); + } + @Test public void parseCluster_ringHashLbPolicy_invalidRingSizeConfig_minGreaterThanMax() throws ResourceInvalidException { @@ -1160,7 +1180,7 @@ public void parseCluster_ringHashLbPolicy_invalidRingSizeConfig_minGreaterThanMa thrown.expect(ResourceInvalidException.class); thrown.expectMessage("Cluster cluster-foo.googleapis.com: invalid ring_hash_lb_config"); - ClientXdsClient.parseCluster(cluster, new HashSet()); + ClientXdsClient.parseCluster(cluster, new HashSet(), null); } @Test @@ -1187,7 +1207,7 @@ public void parseCluster_ringHashLbPolicy_invalidRingSizeConfig_tooLargeRingSize thrown.expect(ResourceInvalidException.class); thrown.expectMessage("Cluster cluster-foo.googleapis.com: invalid ring_hash_lb_config"); - ClientXdsClient.parseCluster(cluster, new HashSet()); + ClientXdsClient.parseCluster(cluster, new HashSet(), null); } @Test @@ -1200,7 +1220,7 @@ public void parseServerSideListener_invalidTrafficDirection() throws ResourceInv thrown.expect(ResourceInvalidException.class); thrown.expectMessage("Listener listener1 with invalid traffic direction: OUTBOUND"); ClientXdsClient.parseServerSideListener( - listener, new HashSet(), null, filterRegistry, true /* does not matter */); + listener, new HashSet(), null, filterRegistry, null, true /* does not matter */); } @Test @@ -1214,7 +1234,7 @@ public void parseServerSideListener_listenerFiltersPresent() throws ResourceInva thrown.expect(ResourceInvalidException.class); thrown.expectMessage("Listener listener1 cannot have listener_filters"); ClientXdsClient.parseServerSideListener( - listener, new HashSet(), null, filterRegistry, true /* does not matter */); + listener, new HashSet(), null, filterRegistry, null, true /* does not matter */); } @Test @@ -1228,7 +1248,7 @@ public void parseServerSideListener_useOriginalDst() throws ResourceInvalidExcep thrown.expect(ResourceInvalidException.class); thrown.expectMessage("Listener listener1 cannot have use_original_dst set to true"); ClientXdsClient.parseServerSideListener( - listener, new HashSet(), null, filterRegistry, true /* does not matter */); + listener, new HashSet(), null, filterRegistry, null, true /* does not matter */); } @Test @@ -1275,7 +1295,7 @@ public void parseServerSideListener_nonUniqueFilterChainMatch() throws ResourceI thrown.expect(ResourceInvalidException.class); thrown.expectMessage("Found duplicate matcher:"); ClientXdsClient.parseServerSideListener( - listener, new HashSet(), null, filterRegistry, true /* does not matter */); + listener, new HashSet(), null, filterRegistry, null, true /* does not matter */); } @Test @@ -1322,7 +1342,7 @@ public void parseServerSideListener_nonUniqueFilterChainMatch_sameFilter() thrown.expect(ResourceInvalidException.class); thrown.expectMessage("Found duplicate matcher:"); ClientXdsClient.parseServerSideListener( - listener, new HashSet(), null, filterRegistry, true /* does not matter */); + listener, new HashSet(), null, filterRegistry, null, true /* does not matter */); } @Test @@ -1369,7 +1389,7 @@ public void parseServerSideListener_uniqueFilterChainMatch() throws ResourceInva .addAllFilterChains(Arrays.asList(filterChain1, filterChain2)) .build(); ClientXdsClient.parseServerSideListener( - listener, new HashSet(), null, filterRegistry, true /* does not matter */); + listener, new HashSet(), null, filterRegistry, null, true /* does not matter */); } @Test @@ -1384,7 +1404,8 @@ public void parseFilterChain_noHcm() throws ResourceInvalidException { thrown.expectMessage( "FilterChain filter-chain-foo missing required HttpConnectionManager filter"); ClientXdsClient.parseFilterChain( - filterChain, new HashSet(), null, filterRegistry, null, true /* does not matter */); + filterChain, new HashSet(), null, filterRegistry, null, null, + true /* does not matter */); } @Test @@ -1402,7 +1423,8 @@ public void parseFilterChain_duplicateFilter() throws ResourceInvalidException { thrown.expectMessage( "FilterChain filter-chain-foo with duplicated filter: envoy.http_connection_manager"); ClientXdsClient.parseFilterChain( - filterChain, new HashSet(), null, filterRegistry, null, true /* does not matter */); + filterChain, new HashSet(), null, filterRegistry, null, null, + true /* does not matter */); } @Test @@ -1420,7 +1442,8 @@ public void parseFilterChain_filterMissingTypedConfig() throws ResourceInvalidEx "FilterChain filter-chain-foo contains filter envoy.http_connection_manager " + "without typed_config"); ClientXdsClient.parseFilterChain( - filterChain, new HashSet(), null, filterRegistry, null, true /* does not matter */); + filterChain, new HashSet(), null, filterRegistry, null, null, + true /* does not matter */); } @Test @@ -1442,7 +1465,8 @@ public void parseFilterChain_unsupportedFilter() throws ResourceInvalidException "FilterChain filter-chain-foo contains filter unsupported with unsupported " + "typed_config type unsupported-type-url"); ClientXdsClient.parseFilterChain( - filterChain, new HashSet(), null, filterRegistry, null, true /* does not matter */); + filterChain, new HashSet(), null, filterRegistry, null, null, + true /* does not matter */); } @Test @@ -1468,10 +1492,10 @@ public void parseFilterChain_noName_generatedUuid() throws ResourceInvalidExcept EnvoyServerProtoData.FilterChain parsedFilterChain1 = ClientXdsClient.parseFilterChain( filterChain1, new HashSet(), null, filterRegistry, null, - true /* does not matter */); + null, true /* does not matter */); EnvoyServerProtoData.FilterChain parsedFilterChain2 = ClientXdsClient.parseFilterChain( filterChain2, new HashSet(), null, filterRegistry, null, - true /* does not matter */); + null, true /* does not matter */); assertThat(parsedFilterChain1.getName()).isNotEqualTo(parsedFilterChain2.getName()); } @@ -1483,7 +1507,7 @@ public void validateCommonTlsContext_tlsParams() throws ResourceInvalidException .build(); thrown.expect(ResourceInvalidException.class); thrown.expectMessage("common-tls-context with tls_params is not supported"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, false); + ClientXdsClient.validateCommonTlsContext(commonTlsContext, null, false); } @Test @@ -1493,7 +1517,7 @@ public void validateCommonTlsContext_customHandshaker() throws ResourceInvalidEx .build(); thrown.expect(ResourceInvalidException.class); thrown.expectMessage("common-tls-context with custom_handshaker is not supported"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, false); + ClientXdsClient.validateCommonTlsContext(commonTlsContext, null, false); } @Test @@ -1503,7 +1527,7 @@ public void validateCommonTlsContext_validationContext() throws ResourceInvalidE .build(); thrown.expect(ResourceInvalidException.class); thrown.expectMessage("common-tls-context with validation_context is not supported"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, false); + ClientXdsClient.validateCommonTlsContext(commonTlsContext, null, false); } @Test @@ -1515,7 +1539,7 @@ public void validateCommonTlsContext_validationContextSdsSecretConfig() thrown.expect(ResourceInvalidException.class); thrown.expectMessage( "common-tls-context with validation_context_sds_secret_config is not supported"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, false); + ClientXdsClient.validateCommonTlsContext(commonTlsContext, null, false); } @Test @@ -1528,7 +1552,7 @@ public void validateCommonTlsContext_validationContextCertificateProvider() thrown.expect(ResourceInvalidException.class); thrown.expectMessage( "common-tls-context with validation_context_certificate_provider is not supported"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, false); + ClientXdsClient.validateCommonTlsContext(commonTlsContext, null, false); } @Test @@ -1542,7 +1566,7 @@ public void validateCommonTlsContext_validationContextCertificateProviderInstanc thrown.expectMessage( "common-tls-context with validation_context_certificate_provider_instance is not " + "supported"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, false); + ClientXdsClient.validateCommonTlsContext(commonTlsContext, null, false); } @Test @@ -1553,9 +1577,66 @@ public void validateCommonTlsContext_tlsCertificateProviderInstance_isRequiredFo thrown.expect(ResourceInvalidException.class); thrown.expectMessage( "tls_certificate_certificate_provider_instance is required in downstream-tls-context"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, true); + ClientXdsClient.validateCommonTlsContext(commonTlsContext, null, true); + } + + @Test + public void validateCommonTlsContext_tlsCertificateProviderInstance() + throws ResourceInvalidException { + CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() + .setTlsCertificateCertificateProviderInstance( + CertificateProviderInstance.newBuilder().setInstanceName("name1").build()) + .build(); + ClientXdsClient + .validateCommonTlsContext(commonTlsContext, ImmutableSet.of("name1", "name2"), true); + } + + @Test + public void validateCommonTlsContext_tlsCertificateProviderInstance_absentInBootstrapFile() + throws ResourceInvalidException { + CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() + .setTlsCertificateCertificateProviderInstance( + CertificateProviderInstance.newBuilder().setInstanceName("bad-name").build()) + .build(); + thrown.expect(ResourceInvalidException.class); + thrown.expectMessage( + "CertificateProvider instance name 'bad-name' not defined in the bootstrap file."); + ClientXdsClient + .validateCommonTlsContext(commonTlsContext, ImmutableSet.of("name1", "name2"), true); + } + + @Test + public void validateCommonTlsContext_validationContextProviderInstance() + throws ResourceInvalidException { + CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() + .setCombinedValidationContext( + CommonTlsContext.CombinedCertificateValidationContext.newBuilder() + .setValidationContextCertificateProviderInstance( + CertificateProviderInstance.newBuilder().setInstanceName("name1").build()) + .build()) + .build(); + ClientXdsClient + .validateCommonTlsContext(commonTlsContext, ImmutableSet.of("name1", "name2"), false); + } + + @Test + public void validateCommonTlsContext_validationContextProviderInstance_absentInBootstrapFile() + throws ResourceInvalidException { + CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() + .setCombinedValidationContext( + CommonTlsContext.CombinedCertificateValidationContext.newBuilder() + .setValidationContextCertificateProviderInstance( + CertificateProviderInstance.newBuilder().setInstanceName("bad-name").build()) + .build()) + .build(); + thrown.expect(ResourceInvalidException.class); + thrown.expectMessage( + "ValidationContextProvider instance name 'bad-name' not defined in the bootstrap file."); + ClientXdsClient + .validateCommonTlsContext(commonTlsContext, ImmutableSet.of("name1", "name2"), false); } + @Test public void validateCommonTlsContext_tlsCertificatesCount() throws ResourceInvalidException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() @@ -1563,7 +1644,7 @@ public void validateCommonTlsContext_tlsCertificatesCount() throws ResourceInval .build(); thrown.expect(ResourceInvalidException.class); thrown.expectMessage("common-tls-context with tls_certificates is not supported"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, false); + ClientXdsClient.validateCommonTlsContext(commonTlsContext, null, false); } @Test @@ -1575,7 +1656,7 @@ public void validateCommonTlsContext_tlsCertificateSdsSecretConfigsCount() thrown.expect(ResourceInvalidException.class); thrown.expectMessage( "common-tls-context with tls_certificate_sds_secret_configs is not supported"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, false); + ClientXdsClient.validateCommonTlsContext(commonTlsContext, null, false); } @Test @@ -1588,7 +1669,7 @@ public void validateCommonTlsContext_tlsCertificateCertificateProvider() thrown.expect(ResourceInvalidException.class); thrown.expectMessage( "common-tls-context with tls_certificate_certificate_provider is not supported"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, false); + ClientXdsClient.validateCommonTlsContext(commonTlsContext, null, false); } @Test @@ -1598,7 +1679,7 @@ public void validateCommonTlsContext_combinedValidationContext_isRequiredForClie .build(); thrown.expect(ResourceInvalidException.class); thrown.expectMessage("combined_validation_context is required in upstream-tls-context"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, false); + ClientXdsClient.validateCommonTlsContext(commonTlsContext, null, false); } @Test @@ -1612,7 +1693,7 @@ public void validateCommonTlsContext_combinedValidationContextWithoutCertProvide thrown.expectMessage( "validation_context_certificate_provider_instance is required in " + "combined_validation_context"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, false); + ClientXdsClient.validateCommonTlsContext(commonTlsContext, null, false); } @Test @@ -1631,43 +1712,7 @@ public void validateCommonTlsContext_combinedValContextWithDefaultValContextForS .build(); thrown.expect(ResourceInvalidException.class); thrown.expectMessage("match_subject_alt_names only allowed in upstream_tls_context"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, true); - } - - @Test - public void validateCommonTlsContext_combinedValContextWithDefaultValidationContextTrustedCa() - throws ResourceInvalidException { - CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() - .setCombinedValidationContext( - CommonTlsContext.CombinedCertificateValidationContext.newBuilder() - .setValidationContextCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance()) - .setDefaultValidationContext(CertificateValidationContext.newBuilder() - .setTrustedCa(DataSource.getDefaultInstance()))) - .setTlsCertificateCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance()) - .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("trusted_ca in default_validation_context is not supported"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, false); - } - - @Test - public void validateCommonTlsContext_combinedValContextWithDefaultValContextWatchedDirectory() - throws ResourceInvalidException { - CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() - .setCombinedValidationContext( - CommonTlsContext.CombinedCertificateValidationContext.newBuilder() - .setValidationContextCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance()) - .setDefaultValidationContext(CertificateValidationContext.newBuilder() - .setWatchedDirectory(WatchedDirectory.getDefaultInstance()))) - .setTlsCertificateCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance()) - .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("watched_directory in default_validation_context is not supported"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, false); + ClientXdsClient.validateCommonTlsContext(commonTlsContext, ImmutableSet.of(""), true); } @Test @@ -1686,7 +1731,7 @@ public void validateCommonTlsContext_combinedValContextWithDefaultValContextVeri thrown.expect(ResourceInvalidException.class); thrown.expectMessage("verify_certificate_spki in default_validation_context is not " + "supported"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, false); + ClientXdsClient.validateCommonTlsContext(commonTlsContext, ImmutableSet.of(""), false); } @Test @@ -1705,7 +1750,7 @@ public void validateCommonTlsContext_combinedValContextWithDefaultValContextVeri thrown.expect(ResourceInvalidException.class); thrown.expectMessage("verify_certificate_hash in default_validation_context is not " + "supported"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, false); + ClientXdsClient.validateCommonTlsContext(commonTlsContext, ImmutableSet.of(""), false); } @Test @@ -1725,7 +1770,7 @@ public void validateCommonTlsContext_combinedValContextDfltValContextRequireSign thrown.expectMessage( "require_signed_certificate_timestamp in default_validation_context is not " + "supported"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, false); + ClientXdsClient.validateCommonTlsContext(commonTlsContext, ImmutableSet.of(""), false); } @Test @@ -1743,46 +1788,7 @@ public void validateCommonTlsContext_combinedValidationContextWithDefaultValidat .build(); thrown.expect(ResourceInvalidException.class); thrown.expectMessage("crl in default_validation_context is not supported"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, false); - } - - @Test - public void validateCommonTlsContext_combinedValContextWithDefaultValContextAllowExpiredCert() - throws ResourceInvalidException { - CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() - .setCombinedValidationContext( - CommonTlsContext.CombinedCertificateValidationContext.newBuilder() - .setValidationContextCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance()) - .setDefaultValidationContext( - CertificateValidationContext.newBuilder().setAllowExpiredCertificate(true))) - .setTlsCertificateCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance()) - .build(); - thrown.expect(ResourceInvalidException.class); - thrown - .expectMessage("allow_expired_certificate in default_validation_context is not " - + "supported"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, false); - } - - @Test - public void validateCommonTlsContext_combinedValContextWithDfltValContextTrustChainVerification() - throws ResourceInvalidException { - CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() - .setCombinedValidationContext( - CommonTlsContext.CombinedCertificateValidationContext.newBuilder() - .setValidationContextCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance()) - .setDefaultValidationContext(CertificateValidationContext.newBuilder() - .setTrustChainVerification( - CertificateValidationContext.TrustChainVerification.ACCEPT_UNTRUSTED))) - .setTlsCertificateCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance()) - .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("Only VERIFY_TRUST_CHAIN for trust_chain_verification supported"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, false); + ClientXdsClient.validateCommonTlsContext(commonTlsContext, ImmutableSet.of(""), false); } @Test @@ -1801,7 +1807,7 @@ public void validateCommonTlsContext_combinedValContextWithDfltValContextCustomV thrown.expect(ResourceInvalidException.class); thrown.expectMessage("custom_validator_config in default_validation_context is not " + "supported"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, false); + ClientXdsClient.validateCommonTlsContext(commonTlsContext, ImmutableSet.of(""), false); } @Test @@ -1809,7 +1815,7 @@ public void validateDownstreamTlsContext_noCommonTlsContext() throws ResourceInv DownstreamTlsContext downstreamTlsContext = DownstreamTlsContext.getDefaultInstance(); thrown.expect(ResourceInvalidException.class); thrown.expectMessage("common-tls-context is required in downstream-tls-context"); - ClientXdsClient.validateDownstreamTlsContext(downstreamTlsContext); + ClientXdsClient.validateDownstreamTlsContext(downstreamTlsContext, null); } @Test @@ -1828,87 +1834,7 @@ public void validateDownstreamTlsContext_hasRequireSni() throws ResourceInvalidE .build(); thrown.expect(ResourceInvalidException.class); thrown.expectMessage("downstream-tls-context with require-sni is not supported"); - ClientXdsClient.validateDownstreamTlsContext(downstreamTlsContext); - } - - @Test - public void validateDownstreamTlsContext_hasSessionTikcetKeys() throws ResourceInvalidException { - CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() - .setCombinedValidationContext( - CommonTlsContext.CombinedCertificateValidationContext.newBuilder() - .setValidationContextCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance())) - .setTlsCertificateCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance()) - .build(); - DownstreamTlsContext downstreamTlsContext = DownstreamTlsContext.newBuilder() - .setCommonTlsContext(commonTlsContext) - .setSessionTicketKeys(TlsSessionTicketKeys.getDefaultInstance()) - .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("downstream-tls-context with session_ticket_keys is not supported"); - ClientXdsClient.validateDownstreamTlsContext(downstreamTlsContext); - } - - @Test - public void validateDownstreamTlsContext_hasSessionTikcetKeysSdsSecretConfig() - throws ResourceInvalidException { - CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() - .setCombinedValidationContext( - CommonTlsContext.CombinedCertificateValidationContext.newBuilder() - .setValidationContextCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance())) - .setTlsCertificateCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance()) - .build(); - DownstreamTlsContext downstreamTlsContext = DownstreamTlsContext.newBuilder() - .setCommonTlsContext(commonTlsContext) - .setSessionTicketKeysSdsSecretConfig(SdsSecretConfig.getDefaultInstance()) - .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage( - "downstream-tls-context with session_ticket_keys_sds_secret_config is not supported"); - ClientXdsClient.validateDownstreamTlsContext(downstreamTlsContext); - } - - @Test - public void validateDownstreamTlsContext_hasDisableStatelessSessionResumption() - throws ResourceInvalidException { - CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() - .setCombinedValidationContext( - CommonTlsContext.CombinedCertificateValidationContext.newBuilder() - .setValidationContextCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance())) - .setTlsCertificateCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance()) - .build(); - DownstreamTlsContext downstreamTlsContext = DownstreamTlsContext.newBuilder() - .setCommonTlsContext(commonTlsContext) - .setDisableStatelessSessionResumption(true) - .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage( - "downstream-tls-context with disable_stateless_session_resumption is not supported"); - ClientXdsClient.validateDownstreamTlsContext(downstreamTlsContext); - } - - @Test - public void validateDownstreamTlsContext_hasSessionTimeout() throws ResourceInvalidException { - CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() - .setCombinedValidationContext( - CommonTlsContext.CombinedCertificateValidationContext.newBuilder() - .setValidationContextCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance())) - .setTlsCertificateCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance()) - .build(); - DownstreamTlsContext downstreamTlsContext = DownstreamTlsContext.newBuilder() - .setCommonTlsContext(commonTlsContext) - .setSessionTimeout(Duration.getDefaultInstance()) - .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("downstream-tls-context with session_timeout is not supported"); - ClientXdsClient.validateDownstreamTlsContext(downstreamTlsContext); + ClientXdsClient.validateDownstreamTlsContext(downstreamTlsContext, ImmutableSet.of("")); } @Test @@ -1928,7 +1854,7 @@ public void validateDownstreamTlsContext_hasOcspStaplePolicy() throws ResourceIn thrown.expect(ResourceInvalidException.class); thrown.expectMessage( "downstream-tls-context with ocsp_staple_policy value STRICT_STAPLING is not supported"); - ClientXdsClient.validateDownstreamTlsContext(downstreamTlsContext); + ClientXdsClient.validateDownstreamTlsContext(downstreamTlsContext, ImmutableSet.of("")); } @Test @@ -1936,58 +1862,7 @@ public void validateUpstreamTlsContext_noCommonTlsContext() throws ResourceInval UpstreamTlsContext upstreamTlsContext = UpstreamTlsContext.getDefaultInstance(); thrown.expect(ResourceInvalidException.class); thrown.expectMessage("common-tls-context is required in upstream-tls-context"); - ClientXdsClient.validateUpstreamTlsContext(upstreamTlsContext); - } - - @Test - public void validateUpstreamTlsContext_sni() throws ResourceInvalidException { - CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() - .setCombinedValidationContext( - CommonTlsContext.CombinedCertificateValidationContext.newBuilder() - .setValidationContextCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance())) - .build(); - UpstreamTlsContext upstreamTlsContext = UpstreamTlsContext.newBuilder() - .setCommonTlsContext(commonTlsContext) - .setSni("foo") - .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("upstream-tls-context with sni is not supported"); - ClientXdsClient.validateUpstreamTlsContext(upstreamTlsContext); - } - - @Test - public void validateUpstreamTlsContext_allowRenegotiation() throws ResourceInvalidException { - CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() - .setCombinedValidationContext( - CommonTlsContext.CombinedCertificateValidationContext.newBuilder() - .setValidationContextCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance())) - .build(); - UpstreamTlsContext upstreamTlsContext = UpstreamTlsContext.newBuilder() - .setCommonTlsContext(commonTlsContext) - .setAllowRenegotiation(true) - .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("upstream-tls-context with allow_renegotiation is not supported"); - ClientXdsClient.validateUpstreamTlsContext(upstreamTlsContext); - } - - @Test - public void validateUpstreamTlsContext_maxSessionKeys() throws ResourceInvalidException { - CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() - .setCombinedValidationContext( - CommonTlsContext.CombinedCertificateValidationContext.newBuilder() - .setValidationContextCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance())) - .build(); - UpstreamTlsContext upstreamTlsContext = UpstreamTlsContext.newBuilder() - .setCommonTlsContext(commonTlsContext) - .setMaxSessionKeys(UInt32Value.getDefaultInstance()) - .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("upstream-tls-context with max_session_keys is not supported"); - ClientXdsClient.validateUpstreamTlsContext(upstreamTlsContext); + ClientXdsClient.validateUpstreamTlsContext(upstreamTlsContext, null); } private static Filter buildHttpConnectionManagerFilter(HttpFilter... httpFilters) { diff --git a/xds/src/test/java/io/grpc/xds/ClientXdsClientTestBase.java b/xds/src/test/java/io/grpc/xds/ClientXdsClientTestBase.java index 913eb208ad7..78892772992 100644 --- a/xds/src/test/java/io/grpc/xds/ClientXdsClientTestBase.java +++ b/xds/src/test/java/io/grpc/xds/ClientXdsClientTestBase.java @@ -56,6 +56,7 @@ import io.grpc.internal.TimeProvider; import io.grpc.testing.GrpcCleanupRule; import io.grpc.xds.AbstractXdsClient.ResourceType; +import io.grpc.xds.Bootstrapper.CertificateProviderInfo; import io.grpc.xds.Endpoints.DropOverload; import io.grpc.xds.Endpoints.LbEndpoint; import io.grpc.xds.Endpoints.LocalityLbEndpoints; @@ -273,7 +274,8 @@ public void setUp() throws IOException { new Bootstrapper.ServerInfo( SERVER_URI, InsecureChannelCredentials.create(), useProtocolV3())), EnvoyProtoData.Node.newBuilder().build(), - null, + ImmutableMap.of("cert-instance-name", + new CertificateProviderInfo("file-watcher", ImmutableMap.of())), null); xdsClient = new ClientXdsClient( @@ -1327,7 +1329,8 @@ public void cdsResponseWithUpstreamTlsContext() { Any clusterEds = Any.pack(mf.buildEdsCluster(CDS_RESOURCE, "eds-cluster-foo.googleapis.com", "round_robin", null, true, - mf.buildUpstreamTlsContext("secret1", "cert1"), "envoy.transport_sockets.tls", null)); + mf.buildUpstreamTlsContext("cert-instance-name", "cert1"), + "envoy.transport_sockets.tls", null)); List clusters = ImmutableList.of( Any.pack(mf.buildLogicalDnsCluster("cluster-bar.googleapis.com", "dns-service-bar.googleapis.com", 443, "round_robin", null, false, null, null)), @@ -1343,7 +1346,7 @@ public void cdsResponseWithUpstreamTlsContext() { CommonTlsContext.CertificateProviderInstance certificateProviderInstance = cdsUpdate.upstreamTlsContext().getCommonTlsContext().getCombinedValidationContext() .getValidationContextCertificateProviderInstance(); - assertThat(certificateProviderInstance.getInstanceName()).isEqualTo("secret1"); + assertThat(certificateProviderInstance.getInstanceName()).isEqualTo("cert-instance-name"); assertThat(certificateProviderInstance.getCertificateName()).isEqualTo("cert1"); verifyResourceMetadataAcked(CDS, CDS_RESOURCE, clusterEds, VERSION_1, TIME_INCREMENT); verifySubscribedResourcesMetadataSizes(0, 1, 0, 0); diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/CommonTlsContextTestsUtil.java b/xds/src/test/java/io/grpc/xds/internal/sds/CommonTlsContextTestsUtil.java index 2914e5f3937..2918ce56224 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/CommonTlsContextTestsUtil.java +++ b/xds/src/test/java/io/grpc/xds/internal/sds/CommonTlsContextTestsUtil.java @@ -146,7 +146,7 @@ public static DownstreamTlsContext buildTestDownstreamTlsContext( if (certName != null || validationContextCertName != null || useSans) { commonTlsContext = buildCommonTlsContextWithAdditionalValues( "cert-instance-name", certName, - "val-cert-instance-name", validationContextCertName, + "cert-instance-name", validationContextCertName, useSans ? Arrays.asList( StringMatcher.newBuilder() .setExact("spiffe://grpc-sds-testing.svc.id.goog/ns/default/sa/bob") From 522b37bc3b2acb419cda1fe19ac9f2794a599d2e Mon Sep 17 00:00:00 2001 From: Anuraag Agrawal Date: Fri, 3 Sep 2021 00:56:56 +0900 Subject: [PATCH 56/82] Fix drift in MessageFramer comment (#8427) --- core/src/main/java/io/grpc/internal/MessageFramer.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/java/io/grpc/internal/MessageFramer.java b/core/src/main/java/io/grpc/internal/MessageFramer.java index 83592e691a9..2042bddca03 100644 --- a/core/src/main/java/io/grpc/internal/MessageFramer.java +++ b/core/src/main/java/io/grpc/internal/MessageFramer.java @@ -267,7 +267,7 @@ private static int writeToOutputStream(InputStream message, OutputStream outputS return ((Drainable) message).drainTo(outputStream); } else { // This makes an unnecessary copy of the bytes when bytebuf supports array(). However, we - // expect performance-critical code to support flushTo(). + // expect performance-critical code to support drainTo(). @SuppressWarnings("BetaApi") // ByteStreams is not Beta in v27 long written = ByteStreams.copy(message, outputStream); checkArgument(written <= Integer.MAX_VALUE, "Message size overflow: %s", written); From 2faa74879772199d9c32d6095492474ee502d290 Mon Sep 17 00:00:00 2001 From: ZHANG Dapeng Date: Thu, 2 Sep 2021 10:24:22 -0700 Subject: [PATCH 57/82] census: Fix retry stats data race (#8459) There is data race in `CensusStatsModule. CallAttemptsTracerFactory`: If client call is cancelled while an active stream on the transport is not committed, then a [noop substream](https://github.com/grpc/grpc-java/blob/v1.40.0/core/src/main/java/io/grpc/internal/RetriableStream.java#L486) will be committed and the active stream will be cancelled. Because the active stream cancellation triggers the stream listener closed() on the _transport_ thread, the closed() method can be invoked concurrently with the call listener onClose(). Therefore, one `CallAttemptsTracerFactory.attemptEnded()` can be called concurrently with `CallAttemptsTracerFactory.callEnded()`, and there could be data race on RETRY_DELAY_PER_CALL. See also the regression test added. The same data race can happen in hedging case when one of hedges is committed and completes the call, other uncommitted hedges would cancel themselves and trigger their stream listeners closed() on the transport_thread concurrently. Fixing the race by recording RETRY_DELAY_PER_CALL once both the conditions are met: - callEnded is true - number of active streams is 0. --- .../io/grpc/census/CensusStatsModule.java | 111 ++++++++++-------- .../grpc/testing/integration/RetryTest.java | 64 ++++++++++ 2 files changed, 123 insertions(+), 52 deletions(-) diff --git a/census/src/main/java/io/grpc/census/CensusStatsModule.java b/census/src/main/java/io/grpc/census/CensusStatsModule.java index 6faeb575ccc..6f8acdb71e9 100644 --- a/census/src/main/java/io/grpc/census/CensusStatsModule.java +++ b/census/src/main/java/io/grpc/census/CensusStatsModule.java @@ -55,13 +55,13 @@ import io.opencensus.tags.unsafe.ContextUtils; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicLongFieldUpdater; import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; +import javax.annotation.concurrent.GuardedBy; /** * Provides factories for {@link StreamTracer} that records stats to Census. @@ -356,12 +356,12 @@ public void streamClosed(Status status) { if (module.recordFinishedRpcs) { // Stream is closed early. So no need to record metrics for any inbound events after this // point. - recordFinishedRpc(); + recordFinishedAttempt(); } } // Otherwise will report stats in callEnded() to guarantee all inbound metrics are recorded. } - void recordFinishedRpc() { + void recordFinishedAttempt() { MeasureMap measureMap = module.statsRecorder.newMeasureMap() // TODO(songya): remove the deprecated measure constants once they are completed removed. .put(DeprecatedCensusConstants.RPC_CLIENT_FINISHED_COUNT, 1) @@ -405,30 +405,11 @@ static final class CallAttemptsTracerFactory extends Measure.MeasureDouble.create( "grpc.io/client/retry_delay_per_call", "Retry delay per call", "ms"); - @Nullable - private static final AtomicIntegerFieldUpdater callEndedUpdater; - - /** - * When using Atomic*FieldUpdater, some Samsung Android 5.0.x devices encounter a bug in their - * JDK reflection API that triggers a NoSuchFieldException. When this occurs, we fallback to - * (potentially racy) direct updates of the volatile variables. - */ - static { - AtomicIntegerFieldUpdater tmpCallEndedUpdater; - try { - tmpCallEndedUpdater = - AtomicIntegerFieldUpdater.newUpdater(CallAttemptsTracerFactory.class, "callEnded"); - } catch (Throwable t) { - logger.log(Level.SEVERE, "Creating atomic field updaters failed", t); - tmpCallEndedUpdater = null; - } - callEndedUpdater = tmpCallEndedUpdater; - } - ClientTracer inboundMetricTracer; private final CensusStatsModule module; private final Stopwatch stopwatch; - private volatile int callEnded; + @GuardedBy("lock") + private boolean callEnded; private final TagContext parentCtx; private final TagContext startCtx; private final String fullMethodName; @@ -436,17 +417,22 @@ static final class CallAttemptsTracerFactory extends // TODO(zdapeng): optimize memory allocation using AtomicFieldUpdater. private final AtomicLong attemptsPerCall = new AtomicLong(); private final AtomicLong transparentRetriesPerCall = new AtomicLong(); - private final AtomicLong retryDelayNanos = new AtomicLong(); - private final AtomicLong lastInactiveTimeStamp = new AtomicLong(); - private final AtomicInteger activeStreams = new AtomicInteger(); - private final AtomicBoolean activated = new AtomicBoolean(); + // write happens before read + private Status status; + private final Object lock = new Object(); + // write @GuardedBy("lock") and happens before read + private long retryDelayNanos; + @GuardedBy("lock") + private int activeStreams; + @GuardedBy("lock") + private boolean finishedCallToBeRecorded; CallAttemptsTracerFactory( CensusStatsModule module, TagContext parentCtx, String fullMethodName) { this.module = checkNotNull(module, "module"); this.parentCtx = checkNotNull(parentCtx, "parentCtx"); this.fullMethodName = checkNotNull(fullMethodName, "fullMethodName"); - this.stopwatch = module.stopwatchSupplier.get().start(); + this.stopwatch = module.stopwatchSupplier.get(); TagValue methodTag = TagValue.create(fullMethodName); startCtx = module.tagger.toBuilder(parentCtx) .putLocal(RpcMeasureConstants.GRPC_CLIENT_METHOD, methodTag) @@ -461,10 +447,14 @@ static final class CallAttemptsTracerFactory extends @Override public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata metadata) { - ClientTracer tracer = new ClientTracer(this, module, parentCtx, startCtx, info); - if (activeStreams.incrementAndGet() == 1) { - if (!activated.compareAndSet(false, true)) { - retryDelayNanos.addAndGet(stopwatch.elapsed(TimeUnit.NANOSECONDS)); + synchronized (lock) { + if (finishedCallToBeRecorded) { + // This can be the case when the called is cancelled but a retry attempt is created. + return new ClientStreamTracer() {}; + } + if (++activeStreams == 1 && stopwatch.isRunning()) { + stopwatch.stop(); + retryDelayNanos = stopwatch.elapsed(TimeUnit.NANOSECONDS); } } if (module.recordStartedRpcs && attemptsPerCall.get() > 0) { @@ -477,42 +467,59 @@ public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata metada } else { attemptsPerCall.incrementAndGet(); } - return tracer; + return new ClientTracer(this, module, parentCtx, startCtx, info); } // Called whenever each attempt is ended. void attemptEnded() { - if (activeStreams.decrementAndGet() == 0) { - // Race condition between two extremely close events does not matter because the difference - // in the result would be very small. - long lastInactiveTimeStamp = - this.lastInactiveTimeStamp.getAndSet(stopwatch.elapsed(TimeUnit.NANOSECONDS)); - retryDelayNanos.addAndGet(-lastInactiveTimeStamp); + if (!module.recordFinishedRpcs) { + return; + } + boolean shouldRecordFinishedCall = false; + synchronized (lock) { + if (--activeStreams == 0) { + stopwatch.start(); + if (callEnded && !finishedCallToBeRecorded) { + shouldRecordFinishedCall = true; + finishedCallToBeRecorded = true; + } + } + } + if (shouldRecordFinishedCall) { + recordFinishedCall(); } } void callEnded(Status status) { - if (callEndedUpdater != null) { - if (callEndedUpdater.getAndSet(this, 1) != 0) { + if (!module.recordFinishedRpcs) { + return; + } + this.status = status; + boolean shouldRecordFinishedCall = false; + synchronized (lock) { + if (callEnded) { + // FIXME(https://github.com/grpc/grpc-java/issues/7921): this shouldn't happen return; } - } else { - if (callEnded != 0) { - return; + callEnded = true; + if (activeStreams == 0 && !finishedCallToBeRecorded) { + shouldRecordFinishedCall = true; + finishedCallToBeRecorded = true; } - callEnded = 1; } - if (!module.recordFinishedRpcs) { - return; + if (shouldRecordFinishedCall) { + recordFinishedCall(); } - stopwatch.stop(); + } + + void recordFinishedCall() { if (attemptsPerCall.get() == 0) { ClientTracer tracer = new ClientTracer(this, module, parentCtx, startCtx, null); tracer.roundtripNanos = stopwatch.elapsed(TimeUnit.NANOSECONDS); tracer.statusCode = status.getCode(); - tracer.recordFinishedRpc(); + tracer.recordFinishedAttempt(); } else if (inboundMetricTracer != null) { - inboundMetricTracer.recordFinishedRpc(); + inboundMetricTracer.recordFinishedAttempt(); } long retriesPerCall = 0; @@ -523,7 +530,7 @@ void callEnded(Status status) { MeasureMap measureMap = module.statsRecorder.newMeasureMap() .put(RETRIES_PER_CALL, retriesPerCall) .put(TRANSPARENT_RETRIES_PER_CALL, transparentRetriesPerCall.get()) - .put(RETRY_DELAY_PER_CALL, retryDelayNanos.get() / NANOS_PER_MILLI); + .put(RETRY_DELAY_PER_CALL, retryDelayNanos / NANOS_PER_MILLI); TagValue methodTag = TagValue.create(fullMethodName); TagValue statusTag = TagValue.create(status.getCode().toString()); measureMap.record( diff --git a/interop-testing/src/test/java/io/grpc/testing/integration/RetryTest.java b/interop-testing/src/test/java/io/grpc/testing/integration/RetryTest.java index bdf39e8546a..eb815501d5c 100644 --- a/interop-testing/src/test/java/io/grpc/testing/integration/RetryTest.java +++ b/interop-testing/src/test/java/io/grpc/testing/integration/RetryTest.java @@ -356,6 +356,70 @@ public void statsRecorded() throws Exception { assertRetryStatsRecorded(1, 0, 10_000); } + @Test + public void statsRecorde_callCancelledBeforeCommit() throws Exception { + startNewServer(); + retryPolicy = ImmutableMap.builder() + .put("maxAttempts", 4D) + .put("initialBackoff", "10s") + .put("maxBackoff", "10s") + .put("backoffMultiplier", 1D) + .put("retryableStatusCodes", Arrays.asList("UNAVAILABLE")) + .build(); + createNewChannel(); + + // We will have streamClosed return at a particular moment that we want. + final CountDownLatch streamClosedLatch = new CountDownLatch(1); + ClientStreamTracer.Factory streamTracerFactory = new ClientStreamTracer.Factory() { + @Override + public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata headers) { + return new ClientStreamTracer() { + @Override + public void streamClosed(Status status) { + if (status.getCode().equals(Code.CANCELLED)) { + try { + streamClosedLatch.await(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new AssertionError("streamClosedLatch interrupted", e); + } + } + } + }; + } + }; + ClientCall call = channel.newCall( + clientStreamingMethod, CallOptions.DEFAULT.withStreamTracerFactory(streamTracerFactory)); + call.start(mockCallListener, new Metadata()); + assertRpcStartedRecorded(); + fakeClock.forwardTime(5, SECONDS); + String message = "String of length 20."; + call.sendMessage(message); + assertOutboundMessageRecorded(); + ServerCall serverCall = serverCalls.poll(5, SECONDS); + serverCall.request(2); + assertOutboundWireSizeRecorded(message.length()); + // trigger retry + serverCall.close( + Status.UNAVAILABLE.withDescription("original attempt failed"), + new Metadata()); + assertRpcStatusRecorded(Code.UNAVAILABLE, 5000, 1); + elapseBackoff(10, SECONDS); + assertRpcStartedRecorded(); + assertOutboundMessageRecorded(); + serverCall = serverCalls.poll(5, SECONDS); + serverCall.request(2); + assertOutboundWireSizeRecorded(message.length()); + fakeClock.forwardTime(7, SECONDS); + call.cancel("Cancelled before commit", null); // A noop substream will commit. + // The call listener is closed, but the netty substream listener is not yet closed. + verify(mockCallListener, timeout(5000)).onClose(any(Status.class), any(Metadata.class)); + // Let the netty substream listener be closed. + streamClosedLatch.countDown(); + assertRetryStatsRecorded(1, 0, 10_000); + assertRpcStatusRecorded(Code.CANCELLED, 7_000, 1); + } + @Test public void serverCancelledAndClientDeadlineExceeded() throws Exception { startNewServer(); From 07747c59a2f21190b9fe07120d5d1a57eaaebcb0 Mon Sep 17 00:00:00 2001 From: ZHANG Dapeng Date: Thu, 2 Sep 2021 10:25:15 -0700 Subject: [PATCH 58/82] xds: Fix WeakReference bug in SharedCallCounterMap (#8466) Fixes #8397. #8397 is caused by mistakenly clearing up a map entry right after the entry is recreated after gc. Reproduced in regression test. --- .../java/io/grpc/xds/SharedCallCounterMap.java | 13 +++++++++++-- .../io/grpc/xds/SharedCallCounterMapTest.java | 18 ++++++++++++++++++ 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/SharedCallCounterMap.java b/xds/src/main/java/io/grpc/xds/SharedCallCounterMap.java index 71cff0cedf3..7aa55c27429 100644 --- a/xds/src/main/java/io/grpc/xds/SharedCallCounterMap.java +++ b/xds/src/main/java/io/grpc/xds/SharedCallCounterMap.java @@ -58,8 +58,14 @@ public synchronized AtomicLong getOrCreate(String cluster, @Nullable String edsS counters.put(cluster, clusterCounters); } CounterReference ref = clusterCounters.get(edsServiceName); - AtomicLong counter; - if (ref == null || (counter = ref.get()) == null) { + AtomicLong counter = null; + if (ref != null) { + counter = ref.get(); + if (counter == null) { + ref.enqueue(); + } + } + if (counter == null) { counter = new AtomicLong(); ref = new CounterReference(counter, refQueue, cluster, edsServiceName); clusterCounters.put(edsServiceName, ref); @@ -73,6 +79,9 @@ void cleanQueue() { CounterReference ref; while ((ref = (CounterReference) refQueue.poll()) != null) { Map clusterCounter = counters.get(ref.cluster); + if (clusterCounter.get(ref.edsServiceName) != ref) { + continue; + } clusterCounter.remove(ref.edsServiceName); if (clusterCounter.isEmpty()) { counters.remove(ref.cluster); diff --git a/xds/src/test/java/io/grpc/xds/SharedCallCounterMapTest.java b/xds/src/test/java/io/grpc/xds/SharedCallCounterMapTest.java index 9f2293d3c53..3051a021870 100644 --- a/xds/src/test/java/io/grpc/xds/SharedCallCounterMapTest.java +++ b/xds/src/test/java/io/grpc/xds/SharedCallCounterMapTest.java @@ -62,4 +62,22 @@ public boolean isDone() { map.cleanQueue(); assertThat(counters).isEmpty(); } + + @Test + public void gcAndRecreate() { + @SuppressWarnings("UnusedVariable") // assign to null for GC only + AtomicLong counter = map.getOrCreate(CLUSTER, EDS_SERVICE_NAME); + final CounterReference ref = counters.get(CLUSTER).get(EDS_SERVICE_NAME); + assertThat(counter.get()).isEqualTo(0); + counter = null; + GcFinalization.awaitDone(new FinalizationPredicate() { + @Override + public boolean isDone() { + return ref.isEnqueued(); + } + }); + map.getOrCreate(CLUSTER, EDS_SERVICE_NAME); + assertThat(counters.get(CLUSTER)).isNotNull(); + assertThat(counters.get(CLUSTER).get(EDS_SERVICE_NAME)).isNotNull(); + } } From 0838b736748b398292a696d2444f9aa0f6778081 Mon Sep 17 00:00:00 2001 From: zpencer Date: Thu, 2 Sep 2021 12:01:44 -0700 Subject: [PATCH 59/82] netty: remove unneeded TransportTracer null checks --- .../io/grpc/netty/NettyClientHandler.java | 13 +------ .../io/grpc/netty/NettyServerHandler.java | 36 +++++-------------- netty/src/main/java/io/grpc/netty/Utils.java | 23 ++++++++++++ 3 files changed, 33 insertions(+), 39 deletions(-) diff --git a/netty/src/main/java/io/grpc/netty/NettyClientHandler.java b/netty/src/main/java/io/grpc/netty/NettyClientHandler.java index 22d8fcadb75..6dde8c825ef 100644 --- a/netty/src/main/java/io/grpc/netty/NettyClientHandler.java +++ b/netty/src/main/java/io/grpc/netty/NettyClientHandler.java @@ -62,7 +62,6 @@ import io.netty.handler.codec.http2.Http2ConnectionEncoder; import io.netty.handler.codec.http2.Http2Error; import io.netty.handler.codec.http2.Http2Exception; -import io.netty.handler.codec.http2.Http2FlowController; import io.netty.handler.codec.http2.Http2FrameAdapter; import io.netty.handler.codec.http2.Http2FrameLogger; import io.netty.handler.codec.http2.Http2FrameReader; @@ -217,17 +216,7 @@ static NettyClientHandler newHandler( Http2ConnectionDecoder decoder = new DefaultHttp2ConnectionDecoder(connection, encoder, frameReader); - transportTracer.setFlowControlWindowReader(new TransportTracer.FlowControlReader() { - final Http2FlowController local = connection.local().flowController(); - final Http2FlowController remote = connection.remote().flowController(); - - @Override - public TransportTracer.FlowControlWindows read() { - return new TransportTracer.FlowControlWindows( - local.windowSize(connection.connectionStream()), - remote.windowSize(connection.connectionStream())); - } - }); + transportTracer.setFlowControlWindowReader(new Utils.FlowControlReader(connection)); Http2Settings settings = new Http2Settings(); settings.pushEnabled(false); diff --git a/netty/src/main/java/io/grpc/netty/NettyServerHandler.java b/netty/src/main/java/io/grpc/netty/NettyServerHandler.java index 01ed7c0c373..6fca656e795 100644 --- a/netty/src/main/java/io/grpc/netty/NettyServerHandler.java +++ b/netty/src/main/java/io/grpc/netty/NettyServerHandler.java @@ -69,7 +69,6 @@ import io.netty.handler.codec.http2.Http2Error; import io.netty.handler.codec.http2.Http2Exception; import io.netty.handler.codec.http2.Http2Exception.StreamException; -import io.netty.handler.codec.http2.Http2FlowController; import io.netty.handler.codec.http2.Http2FrameAdapter; import io.netty.handler.codec.http2.Http2FrameLogger; import io.netty.handler.codec.http2.Http2FrameReader; @@ -367,23 +366,8 @@ public void run() { keepAliveManager.onTransportStarted(); } - - if (transportTracer != null) { - assert encoder().connection().equals(decoder().connection()); - final Http2Connection connection = encoder().connection(); - transportTracer.setFlowControlWindowReader(new TransportTracer.FlowControlReader() { - private final Http2FlowController local = connection.local().flowController(); - private final Http2FlowController remote = connection.remote().flowController(); - - @Override - public TransportTracer.FlowControlWindows read() { - assert ctx.executor().inEventLoop(); - return new TransportTracer.FlowControlWindows( - local.windowSize(connection.connectionStream()), - remote.windowSize(connection.connectionStream())); - } - }); - } + assert encoder().connection().equals(decoder().connection()); + transportTracer.setFlowControlWindowReader(new Utils.FlowControlReader(encoder().connection())); super.handlerAdded(ctx); } @@ -895,16 +879,14 @@ public void ping() { ChannelFuture pingFuture = encoder().writePing( ctx, false /* isAck */, KEEPALIVE_PING, ctx.newPromise()); ctx.flush(); - if (transportTracer != null) { - pingFuture.addListener(new ChannelFutureListener() { - @Override - public void operationComplete(ChannelFuture future) throws Exception { - if (future.isSuccess()) { - transportTracer.reportKeepAliveSent(); - } + pingFuture.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + if (future.isSuccess()) { + transportTracer.reportKeepAliveSent(); } - }); - } + } + }); } @Override diff --git a/netty/src/main/java/io/grpc/netty/Utils.java b/netty/src/main/java/io/grpc/netty/Utils.java index 082ce63dd54..c2f2fa4a7bf 100644 --- a/netty/src/main/java/io/grpc/netty/Utils.java +++ b/netty/src/main/java/io/grpc/netty/Utils.java @@ -32,6 +32,7 @@ import io.grpc.Status; import io.grpc.internal.GrpcUtil; import io.grpc.internal.SharedResourceHolder.Resource; +import io.grpc.internal.TransportTracer; import io.grpc.netty.GrpcHttp2HeadersUtils.GrpcHttp2InboundHeaders; import io.grpc.netty.NettySocketSupport.NativeSocketOptions; import io.netty.buffer.ByteBufAllocator; @@ -47,8 +48,11 @@ import io.netty.channel.socket.nio.NioServerSocketChannel; import io.netty.channel.socket.nio.NioSocketChannel; import io.netty.handler.codec.DecoderException; +import io.netty.handler.codec.http2.Http2Connection; import io.netty.handler.codec.http2.Http2Exception; +import io.netty.handler.codec.http2.Http2FlowController; import io.netty.handler.codec.http2.Http2Headers; +import io.netty.handler.codec.http2.Http2Stream; import io.netty.util.AsciiString; import io.netty.util.NettyRuntime; import io.netty.util.concurrent.DefaultThreadFactory; @@ -441,6 +445,25 @@ public String toString() { } } + static final class FlowControlReader implements TransportTracer.FlowControlReader { + private final Http2Stream connectionStream; + private final Http2FlowController local; + private final Http2FlowController remote; + + FlowControlReader(Http2Connection connection) { + local = connection.local().flowController(); + remote = connection.remote().flowController(); + connectionStream = connection.connectionStream(); + } + + @Override + public TransportTracer.FlowControlWindows read() { + return new TransportTracer.FlowControlWindows( + local.windowSize(connectionStream), + remote.windowSize(connectionStream)); + } + } + static InternalChannelz.SocketOptions getSocketOptions(Channel channel) { ChannelConfig config = channel.config(); InternalChannelz.SocketOptions.Builder b = new InternalChannelz.SocketOptions.Builder(); From ffebe231c0d8eaf380e05c18c62ce6090cc283f1 Mon Sep 17 00:00:00 2001 From: Daniel Zou Date: Thu, 2 Sep 2021 22:12:10 +0000 Subject: [PATCH 60/82] netty-shaded: Rename the directory of netty shaded resources to avoid collisions --- netty/shaded/build.gradle | 3 ++- .../src/testShadow/java/io/grpc/netty/shaded/ShadingTest.java | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/netty/shaded/build.gradle b/netty/shaded/build.gradle index 9cb3de9a252..6b1dad644d1 100644 --- a/netty/shaded/build.gradle +++ b/netty/shaded/build.gradle @@ -110,8 +110,9 @@ class NettyResourceTransformer implements Transformer { @Override void transform(TransformerContext context) { + String updatedPath = context.path.replace("io.netty", "io.grpc.netty.shaded.io.netty") String updatedContent = context.is.getText().replace("io.netty", "io.grpc.netty.shaded.io.netty") - resources.put(context.path, updatedContent) + resources.put(updatedPath, updatedContent) } @Override diff --git a/netty/shaded/src/testShadow/java/io/grpc/netty/shaded/ShadingTest.java b/netty/shaded/src/testShadow/java/io/grpc/netty/shaded/ShadingTest.java index d3bdc4394ca..5c2ff317ccd 100644 --- a/netty/shaded/src/testShadow/java/io/grpc/netty/shaded/ShadingTest.java +++ b/netty/shaded/src/testShadow/java/io/grpc/netty/shaded/ShadingTest.java @@ -78,7 +78,8 @@ public void noNormalNetty() throws Exception { @Test public void nettyResourcesUpdated() throws IOException { InputStream inputStream = NettyChannelBuilder.class.getClassLoader() - .getResourceAsStream("META-INF/native-image/io.netty/transport/reflection-config.json"); + .getResourceAsStream( + "META-INF/native-image/io.grpc.netty.shaded.io.netty/transport/reflection-config.json"); assertThat(inputStream).isNotNull(); Scanner s = new Scanner(inputStream, StandardCharsets.UTF_8.name()).useDelimiter("\\A"); From 62fafe7edae5540022c917ac0fa04ad761acab51 Mon Sep 17 00:00:00 2001 From: Brice Jaglin Date: Fri, 16 Apr 2021 12:43:55 +0200 Subject: [PATCH 61/82] core: clarify exception message Reformulate message to highlight that SizeEnforcingInputStream is applied on the message size of the message after decompression. --- core/src/main/java/io/grpc/internal/MessageDeframer.java | 4 ++-- .../src/test/java/io/grpc/internal/MessageDeframerTest.java | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/core/src/main/java/io/grpc/internal/MessageDeframer.java b/core/src/main/java/io/grpc/internal/MessageDeframer.java index 9a523746e50..534398315e8 100644 --- a/core/src/main/java/io/grpc/internal/MessageDeframer.java +++ b/core/src/main/java/io/grpc/internal/MessageDeframer.java @@ -517,8 +517,8 @@ private void reportCount() { private void verifySize() { if (count > maxMessageSize) { throw Status.RESOURCE_EXHAUSTED.withDescription(String.format( - "Compressed gRPC message exceeds maximum size %d: %d bytes read", - maxMessageSize, count)).asRuntimeException(); + "Decompressed gRPC message exceeds maximum size %d", + maxMessageSize)).asRuntimeException(); } } } diff --git a/core/src/test/java/io/grpc/internal/MessageDeframerTest.java b/core/src/test/java/io/grpc/internal/MessageDeframerTest.java index c1907e51703..5edf64ef85d 100644 --- a/core/src/test/java/io/grpc/internal/MessageDeframerTest.java +++ b/core/src/test/java/io/grpc/internal/MessageDeframerTest.java @@ -378,7 +378,7 @@ public void sizeEnforcingInputStream_readByteAboveLimit() throws IOException { try { thrown.expect(StatusRuntimeException.class); - thrown.expectMessage("RESOURCE_EXHAUSTED: Compressed gRPC message exceeds"); + thrown.expectMessage("RESOURCE_EXHAUSTED: Decompressed gRPC message exceeds"); while (stream.read() != -1) { } @@ -424,7 +424,7 @@ public void sizeEnforcingInputStream_readAboveLimit() throws IOException { try { thrown.expect(StatusRuntimeException.class); - thrown.expectMessage("RESOURCE_EXHAUSTED: Compressed gRPC message exceeds"); + thrown.expectMessage("RESOURCE_EXHAUSTED: Decompressed gRPC message exceeds"); stream.read(buf, 0, buf.length); } finally { @@ -467,7 +467,7 @@ public void sizeEnforcingInputStream_skipAboveLimit() throws IOException { try { thrown.expect(StatusRuntimeException.class); - thrown.expectMessage("RESOURCE_EXHAUSTED: Compressed gRPC message exceeds"); + thrown.expectMessage("RESOURCE_EXHAUSTED: Decompressed gRPC message exceeds"); stream.skip(4); } finally { From a91cc85dfd2f1f8cb749f6e8b5984a61da294902 Mon Sep 17 00:00:00 2001 From: Sergii Tkachenko Date: Thu, 2 Sep 2021 20:20:20 -0400 Subject: [PATCH 62/82] Revert "core/auth: Remove CallCredentials2 (#8464)" This reverts commit 7cde473efa4ffa3f530baf9e9a07298c80ed543f. --- .../main/java/io/grpc/CallCredentials2.java | 73 ++++ .../GoogleAuthLibraryCallCredentials.java | 6 +- .../CallCredentials2ApplyingTest.java | 351 ++++++++++++++++++ .../internal/CallCredentialsApplyingTest.java | 45 --- 4 files changed, 428 insertions(+), 47 deletions(-) create mode 100644 api/src/main/java/io/grpc/CallCredentials2.java create mode 100644 core/src/test/java/io/grpc/internal/CallCredentials2ApplyingTest.java diff --git a/api/src/main/java/io/grpc/CallCredentials2.java b/api/src/main/java/io/grpc/CallCredentials2.java new file mode 100644 index 00000000000..fdb7f51070a --- /dev/null +++ b/api/src/main/java/io/grpc/CallCredentials2.java @@ -0,0 +1,73 @@ +/* + * Copyright 2016 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import java.util.concurrent.Executor; + +/** + * The new interface for {@link CallCredentials}. + * + *

THIS CLASS NAME IS TEMPORARY and is part of a migration. This class will BE DELETED as it + * replaces {@link CallCredentials} in short-term. THIS CLASS IS ONLY REFERENCED BY IMPLEMENTIONS. + * All consumers should be always referencing {@link CallCredentials}. + * + * @deprecated the new interface has been promoted into {@link CallCredentials}. Implementations + * should switch back to "{@code extends CallCredentials}". + */ +@Deprecated +@ExperimentalApi("https://github.com/grpc/grpc-java/issues/4901") +public abstract class CallCredentials2 extends CallCredentials { + /** + * Pass the credential data to the given {@link MetadataApplier}, which will propagate it to the + * request metadata. + * + *

It is called for each individual RPC, within the {@link Context} of the call, before the + * stream is about to be created on a transport. Implementations should not block in this + * method. If metadata is not immediately available, e.g., needs to be fetched from network, the + * implementation may give the {@code applier} to an asynchronous task which will eventually call + * the {@code applier}. The RPC proceeds only after the {@code applier} is called. + * + * @param requestInfo request-related information + * @param appExecutor The application thread-pool. It is provided to the implementation in case it + * needs to perform blocking operations. + * @param applier The outlet of the produced headers. It can be called either before or after this + * method returns. + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/1914") + public abstract void applyRequestMetadata( + RequestInfo requestInfo, Executor appExecutor, MetadataApplier applier); + + @Override + public final void applyRequestMetadata( + RequestInfo requestInfo, Executor appExecutor, + final CallCredentials.MetadataApplier applier) { + applyRequestMetadata(requestInfo, appExecutor, new MetadataApplier() { + @Override + public void apply(Metadata headers) { + applier.apply(headers); + } + + @Override + public void fail(Status status) { + applier.fail(status); + } + }); + } + + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/1914") + public abstract static class MetadataApplier extends CallCredentials.MetadataApplier {} +} diff --git a/auth/src/main/java/io/grpc/auth/GoogleAuthLibraryCallCredentials.java b/auth/src/main/java/io/grpc/auth/GoogleAuthLibraryCallCredentials.java index 4b95a6c7f4d..852fba73b20 100644 --- a/auth/src/main/java/io/grpc/auth/GoogleAuthLibraryCallCredentials.java +++ b/auth/src/main/java/io/grpc/auth/GoogleAuthLibraryCallCredentials.java @@ -42,9 +42,11 @@ import javax.annotation.Nullable; /** - * Wraps {@link Credentials} as a {@link io.grpc.CallCredentials}. + * Wraps {@link Credentials} as a {@link CallCredentials}. */ -final class GoogleAuthLibraryCallCredentials extends io.grpc.CallCredentials { +// TODO(zhangkun83): remove the suppression after we change the base class to CallCredential +@SuppressWarnings("deprecation") +final class GoogleAuthLibraryCallCredentials extends io.grpc.CallCredentials2 { private static final Logger log = Logger.getLogger(GoogleAuthLibraryCallCredentials.class.getName()); private static final JwtHelper jwtHelper diff --git a/core/src/test/java/io/grpc/internal/CallCredentials2ApplyingTest.java b/core/src/test/java/io/grpc/internal/CallCredentials2ApplyingTest.java new file mode 100644 index 00000000000..963a586319b --- /dev/null +++ b/core/src/test/java/io/grpc/internal/CallCredentials2ApplyingTest.java @@ -0,0 +1,351 @@ +/* + * Copyright 2016 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.same; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import io.grpc.Attributes; +import io.grpc.CallCredentials.MetadataApplier; +import io.grpc.CallCredentials.RequestInfo; +import io.grpc.CallOptions; +import io.grpc.ChannelLogger; +import io.grpc.ClientStreamTracer; +import io.grpc.IntegerMarshaller; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.SecurityLevel; +import io.grpc.Status; +import io.grpc.StringMarshaller; +import java.net.SocketAddress; +import java.util.concurrent.Executor; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.ArgumentMatchers; +import org.mockito.Mock; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; +import org.mockito.stubbing.Answer; + +/** + * Unit test for {@link CallCredentials2} applying functionality implemented by {@link + * CallCredentialsApplyingTransportFactory} and {@link MetadataApplierImpl}. + */ +@SuppressWarnings("deprecation") +@RunWith(JUnit4.class) +public class CallCredentials2ApplyingTest { + @Rule + public final MockitoRule mocks = MockitoJUnit.rule(); + + @Mock + private ClientTransportFactory mockTransportFactory; + + @Mock + private ConnectionClientTransport mockTransport; + + @Mock + private ClientStream mockStream; + + @Mock + private io.grpc.CallCredentials2 mockCreds; + + @Mock + private Executor mockExecutor; + + @Mock + private SocketAddress address; + + // Noop logger; + @Mock + private ChannelLogger channelLogger; + + private static final String AUTHORITY = "testauthority"; + private static final String USER_AGENT = "testuseragent"; + private static final Attributes.Key ATTR_KEY = Attributes.Key.create("somekey"); + private static final String ATTR_VALUE = "somevalue"; + private static final MethodDescriptor method = + MethodDescriptor.newBuilder() + .setType(MethodDescriptor.MethodType.UNKNOWN) + .setFullMethodName("service/method") + .setRequestMarshaller(new StringMarshaller()) + .setResponseMarshaller(new IntegerMarshaller()) + .build(); + private static final Metadata.Key ORIG_HEADER_KEY = + Metadata.Key.of("header1", Metadata.ASCII_STRING_MARSHALLER); + private static final String ORIG_HEADER_VALUE = "some original header value"; + private static final Metadata.Key CREDS_KEY = + Metadata.Key.of("test-creds", Metadata.ASCII_STRING_MARSHALLER); + private static final String CREDS_VALUE = "some credentials"; + private static final ClientStreamTracer[] tracers = new ClientStreamTracer[] { + new ClientStreamTracer() {} + }; + + private final Metadata origHeaders = new Metadata(); + private ForwardingConnectionClientTransport transport; + private CallOptions callOptions; + + @Before + public void setUp() { + ClientTransportFactory.ClientTransportOptions clientTransportOptions = + new ClientTransportFactory.ClientTransportOptions() + .setAuthority(AUTHORITY) + .setUserAgent(USER_AGENT); + + origHeaders.put(ORIG_HEADER_KEY, ORIG_HEADER_VALUE); + when(mockTransportFactory.newClientTransport(address, clientTransportOptions, channelLogger)) + .thenReturn(mockTransport); + when(mockTransport.newStream( + same(method), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any())) + .thenReturn(mockStream); + ClientTransportFactory transportFactory = new CallCredentialsApplyingTransportFactory( + mockTransportFactory, null, mockExecutor); + transport = (ForwardingConnectionClientTransport) + transportFactory.newClientTransport(address, clientTransportOptions, channelLogger); + callOptions = CallOptions.DEFAULT.withCallCredentials(mockCreds); + verify(mockTransportFactory).newClientTransport(address, clientTransportOptions, channelLogger); + assertSame(mockTransport, transport.delegate()); + } + + @Test + public void parameterPropagation_base() { + Attributes transportAttrs = Attributes.newBuilder().set(ATTR_KEY, ATTR_VALUE).build(); + when(mockTransport.getAttributes()).thenReturn(transportAttrs); + + transport.newStream(method, origHeaders, callOptions, tracers); + + ArgumentCaptor infoCaptor = ArgumentCaptor.forClass(null); + verify(mockCreds).applyRequestMetadata( + infoCaptor.capture(), same(mockExecutor), + any(io.grpc.CallCredentials2.MetadataApplier.class)); + RequestInfo info = infoCaptor.getValue(); + assertSame(method, info.getMethodDescriptor()); + assertSame(ATTR_VALUE, info.getTransportAttrs().get(ATTR_KEY)); + assertSame(AUTHORITY, info.getAuthority()); + assertSame(SecurityLevel.NONE, info.getSecurityLevel()); + } + + @Test + public void parameterPropagation_transportSetSecurityLevel() { + Attributes transportAttrs = Attributes.newBuilder() + .set(ATTR_KEY, ATTR_VALUE) + .set(GrpcAttributes.ATTR_SECURITY_LEVEL, SecurityLevel.INTEGRITY) + .build(); + when(mockTransport.getAttributes()).thenReturn(transportAttrs); + + transport.newStream(method, origHeaders, callOptions, tracers); + + ArgumentCaptor infoCaptor = ArgumentCaptor.forClass(null); + verify(mockCreds).applyRequestMetadata( + infoCaptor.capture(), same(mockExecutor), + any(io.grpc.CallCredentials2.MetadataApplier.class)); + RequestInfo info = infoCaptor.getValue(); + assertSame(method, info.getMethodDescriptor()); + assertSame(ATTR_VALUE, info.getTransportAttrs().get(ATTR_KEY)); + assertSame(AUTHORITY, info.getAuthority()); + assertSame(SecurityLevel.INTEGRITY, info.getSecurityLevel()); + } + + @Test + public void parameterPropagation_callOptionsSetAuthority() { + Attributes transportAttrs = Attributes.newBuilder() + .set(ATTR_KEY, ATTR_VALUE) + .build(); + when(mockTransport.getAttributes()).thenReturn(transportAttrs); + Executor anotherExecutor = mock(Executor.class); + + transport.newStream( + method, origHeaders, + callOptions.withAuthority("calloptions-authority").withExecutor(anotherExecutor), + tracers); + + ArgumentCaptor infoCaptor = ArgumentCaptor.forClass(null); + verify(mockCreds).applyRequestMetadata( + infoCaptor.capture(), same(anotherExecutor), + any(io.grpc.CallCredentials2.MetadataApplier.class)); + RequestInfo info = infoCaptor.getValue(); + assertSame(method, info.getMethodDescriptor()); + assertSame(ATTR_VALUE, info.getTransportAttrs().get(ATTR_KEY)); + assertEquals("calloptions-authority", info.getAuthority()); + assertSame(SecurityLevel.NONE, info.getSecurityLevel()); + } + + @Test + public void credentialThrows() { + final RuntimeException ex = new RuntimeException(); + when(mockTransport.getAttributes()).thenReturn(Attributes.EMPTY); + doThrow(ex).when(mockCreds).applyRequestMetadata( + any(RequestInfo.class), same(mockExecutor), + any(io.grpc.CallCredentials2.MetadataApplier.class)); + + FailingClientStream stream = + (FailingClientStream) transport.newStream(method, origHeaders, callOptions, tracers); + + verify(mockTransport, never()).newStream( + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); + assertEquals(Status.Code.UNAUTHENTICATED, stream.getError().getCode()); + assertSame(ex, stream.getError().getCause()); + transport.shutdown(Status.UNAVAILABLE); + assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) + instanceof FailingClientStream); + verify(mockTransport).shutdown(Status.UNAVAILABLE); + } + + @Test + public void applyMetadata_inline() { + when(mockTransport.getAttributes()).thenReturn(Attributes.EMPTY); + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocation) throws Throwable { + MetadataApplier applier = (MetadataApplier) invocation.getArguments()[2]; + Metadata headers = new Metadata(); + headers.put(CREDS_KEY, CREDS_VALUE); + applier.apply(headers); + return null; + } + }).when(mockCreds).applyRequestMetadata( + any(RequestInfo.class), same(mockExecutor), + any(io.grpc.CallCredentials2.MetadataApplier.class)); + + ClientStream stream = transport.newStream(method, origHeaders, callOptions, tracers); + + verify(mockTransport).newStream(method, origHeaders, callOptions, tracers); + assertSame(mockStream, stream); + assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY)); + assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY)); + transport.shutdown(Status.UNAVAILABLE); + assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) + instanceof FailingClientStream); + verify(mockTransport).shutdown(Status.UNAVAILABLE); + } + + @Test + public void fail_inline() { + final Status error = Status.FAILED_PRECONDITION.withDescription("channel not secure for creds"); + when(mockTransport.getAttributes()).thenReturn(Attributes.EMPTY); + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocation) throws Throwable { + MetadataApplier applier = (MetadataApplier) invocation.getArguments()[2]; + applier.fail(error); + return null; + } + }).when(mockCreds).applyRequestMetadata( + any(RequestInfo.class), same(mockExecutor), + any(io.grpc.CallCredentials2.MetadataApplier.class)); + + FailingClientStream stream = + (FailingClientStream) transport.newStream(method, origHeaders, callOptions, tracers); + + verify(mockTransport, never()).newStream( + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); + assertSame(error, stream.getError()); + transport.shutdownNow(Status.UNAVAILABLE); + assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) + instanceof FailingClientStream); + verify(mockTransport).shutdownNow(Status.UNAVAILABLE); + } + + @Test + public void applyMetadata_delayed() { + when(mockTransport.getAttributes()).thenReturn(Attributes.EMPTY); + + // Will call applyRequestMetadata(), which is no-op. + DelayedStream stream = (DelayedStream) transport.newStream( + method, origHeaders, callOptions, tracers); + + ArgumentCaptor applierCaptor = ArgumentCaptor.forClass(null); + verify(mockCreds).applyRequestMetadata( + any(RequestInfo.class), same(mockExecutor), applierCaptor.capture()); + verify(mockTransport, never()).newStream( + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); + + transport.shutdown(Status.UNAVAILABLE); + verify(mockTransport, never()).shutdown(Status.UNAVAILABLE); + + Metadata headers = new Metadata(); + headers.put(CREDS_KEY, CREDS_VALUE); + applierCaptor.getValue().apply(headers); + + verify(mockTransport).newStream(method, origHeaders, callOptions, tracers); + assertSame(mockStream, stream.getRealStream()); + assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY)); + assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY)); + assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) + instanceof FailingClientStream); + verify(mockTransport).shutdown(Status.UNAVAILABLE); + } + + @Test + public void fail_delayed() { + when(mockTransport.getAttributes()).thenReturn(Attributes.EMPTY); + + // Will call applyRequestMetadata(), which is no-op. + DelayedStream stream = (DelayedStream) transport.newStream( + method, origHeaders, callOptions, tracers); + + ArgumentCaptor applierCaptor = ArgumentCaptor.forClass(null); + verify(mockCreds).applyRequestMetadata( + any(RequestInfo.class), same(mockExecutor), applierCaptor.capture()); + + Status error = Status.FAILED_PRECONDITION.withDescription("channel not secure for creds"); + applierCaptor.getValue().fail(error); + + verify(mockTransport, never()).newStream( + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); + FailingClientStream failingStream = (FailingClientStream) stream.getRealStream(); + assertSame(error, failingStream.getError()); + transport.shutdown(Status.UNAVAILABLE); + assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) + instanceof FailingClientStream); + verify(mockTransport).shutdown(Status.UNAVAILABLE); + } + + @Test + public void noCreds() { + callOptions = callOptions.withCallCredentials(null); + ClientStream stream = transport.newStream(method, origHeaders, callOptions, tracers); + + verify(mockTransport).newStream(method, origHeaders, callOptions, tracers); + assertSame(mockStream, stream); + assertNull(origHeaders.get(CREDS_KEY)); + assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY)); + transport.shutdown(Status.UNAVAILABLE); + assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) + instanceof FailingClientStream); + verify(mockTransport).shutdown(Status.UNAVAILABLE); + } +} diff --git a/core/src/test/java/io/grpc/internal/CallCredentialsApplyingTest.java b/core/src/test/java/io/grpc/internal/CallCredentialsApplyingTest.java index 2f0ce1070b1..ef49e66bf2d 100644 --- a/core/src/test/java/io/grpc/internal/CallCredentialsApplyingTest.java +++ b/core/src/test/java/io/grpc/internal/CallCredentialsApplyingTest.java @@ -176,51 +176,6 @@ public void parameterPropagation_overrideByCallOptions() { assertSame(SecurityLevel.INTEGRITY, info.getSecurityLevel()); } - @Test - public void parameterPropagation_transportSetSecurityLevel() { - Attributes transportAttrs = Attributes.newBuilder() - .set(ATTR_KEY, ATTR_VALUE) - .set(GrpcAttributes.ATTR_SECURITY_LEVEL, SecurityLevel.INTEGRITY) - .build(); - when(mockTransport.getAttributes()).thenReturn(transportAttrs); - - transport.newStream(method, origHeaders, callOptions, tracers); - - ArgumentCaptor infoCaptor = ArgumentCaptor.forClass(null); - verify(mockCreds).applyRequestMetadata( - infoCaptor.capture(), same(mockExecutor), - any(io.grpc.CallCredentials.MetadataApplier.class)); - RequestInfo info = infoCaptor.getValue(); - assertSame(method, info.getMethodDescriptor()); - assertSame(ATTR_VALUE, info.getTransportAttrs().get(ATTR_KEY)); - assertSame(AUTHORITY, info.getAuthority()); - assertSame(SecurityLevel.INTEGRITY, info.getSecurityLevel()); - } - - @Test - public void parameterPropagation_callOptionsSetAuthority() { - Attributes transportAttrs = Attributes.newBuilder() - .set(ATTR_KEY, ATTR_VALUE) - .build(); - when(mockTransport.getAttributes()).thenReturn(transportAttrs); - Executor anotherExecutor = mock(Executor.class); - - transport.newStream( - method, origHeaders, - callOptions.withAuthority("calloptions-authority").withExecutor(anotherExecutor), - tracers); - - ArgumentCaptor infoCaptor = ArgumentCaptor.forClass(null); - verify(mockCreds).applyRequestMetadata( - infoCaptor.capture(), same(anotherExecutor), - any(io.grpc.CallCredentials.MetadataApplier.class)); - RequestInfo info = infoCaptor.getValue(); - assertSame(method, info.getMethodDescriptor()); - assertSame(ATTR_VALUE, info.getTransportAttrs().get(ATTR_KEY)); - assertEquals("calloptions-authority", info.getAuthority()); - assertSame(SecurityLevel.NONE, info.getSecurityLevel()); - } - @Test public void credentialThrows() { final RuntimeException ex = new RuntimeException(); From 4828698bec2c44005a3f0ad4f70b70ceeeae9049 Mon Sep 17 00:00:00 2001 From: sanjaypujare Date: Fri, 3 Sep 2021 12:38:26 -0700 Subject: [PATCH 63/82] xds: enable PSM security by default (#8478) --- .../main/java/io/grpc/xds/ClusterImplLoadBalancer.java | 3 ++- .../java/io/grpc/xds/ClusterImplLoadBalancerTest.java | 10 +++++----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java b/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java index dffbe3dade7..d95361935a7 100644 --- a/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java @@ -69,7 +69,8 @@ final class ClusterImplLoadBalancer extends LoadBalancer { || Boolean.parseBoolean(System.getenv("GRPC_XDS_EXPERIMENTAL_CIRCUIT_BREAKING")); @VisibleForTesting static boolean enableSecurity = - Boolean.parseBoolean(System.getenv("GRPC_XDS_EXPERIMENTAL_SECURITY_SUPPORT")); + Strings.isNullOrEmpty(System.getenv("GRPC_XDS_EXPERIMENTAL_SECURITY_SUPPORT")) + || Boolean.parseBoolean(System.getenv("GRPC_XDS_EXPERIMENTAL_SECURITY_SUPPORT")); private static final Attributes.Key ATTR_CLUSTER_LOCALITY_STATS = Attributes.Key.create("io.grpc.xds.ClusterImplLoadBalancer.clusterLocalityStats"); diff --git a/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java index 3b2a54c2c25..dfcf101fcf5 100644 --- a/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java @@ -480,16 +480,16 @@ public void endpointAddressesAttachedWithClusterName() { } @Test - public void endpointAddressesAttachedWithTlsConfig_enableSecurity() { + public void endpointAddressesAttachedWithTlsConfig_disableSecurity() { boolean originalEnableSecurity = ClusterImplLoadBalancer.enableSecurity; - ClusterImplLoadBalancer.enableSecurity = true; - subtest_endpointAddressesAttachedWithTlsConfig(true); + ClusterImplLoadBalancer.enableSecurity = false; + subtest_endpointAddressesAttachedWithTlsConfig(false); ClusterImplLoadBalancer.enableSecurity = originalEnableSecurity; } @Test - public void endpointAddressesAttachedWithTlsConfig_securityDisabledByDefault() { - subtest_endpointAddressesAttachedWithTlsConfig(false); + public void endpointAddressesAttachedWithTlsConfig_securityEnabledByDefault() { + subtest_endpointAddressesAttachedWithTlsConfig(true); } private void subtest_endpointAddressesAttachedWithTlsConfig(boolean enableSecurity) { From 5475cf12bbca8a4b3cbff522e8ee8abc54f545aa Mon Sep 17 00:00:00 2001 From: ZHANG Dapeng Date: Fri, 3 Sep 2021 12:47:38 -0700 Subject: [PATCH 64/82] xds: fix parsing retryOn values (#8477) - Envoy ignores white spaces in `retryOn` field https://github.com/envoyproxy/envoy/blob/v1.19.1/source/common/router/retry_state_impl.cc#L166 We should do the same. - Envoy ignores unsupported values https://github.com/envoyproxy/envoy/blob/v1.19.1/source/common/router/config_impl.cc#L89-L90 and we should do the same. --- .../java/io/grpc/xds/ClientXdsClient.java | 9 ++++---- .../io/grpc/xds/ClientXdsClientDataTest.java | 23 +++++++++++++++++-- 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/ClientXdsClient.java b/xds/src/main/java/io/grpc/xds/ClientXdsClient.java index 47e73d11266..83845515978 100644 --- a/xds/src/main/java/io/grpc/xds/ClientXdsClient.java +++ b/xds/src/main/java/io/grpc/xds/ClientXdsClient.java @@ -1237,21 +1237,20 @@ private static StructOrError parseRetryPolicy( maxBackoff = Durations.fromNanos(Durations.toNanos(initialBackoff) * 10); } } - Iterable retryOns = Splitter.on(',').split(retryPolicyProto.getRetryOn()); + Iterable retryOns = + Splitter.on(',').omitEmptyStrings().trimResults().split(retryPolicyProto.getRetryOn()); ImmutableList.Builder retryableStatusCodesBuilder = ImmutableList.builder(); for (String retryOn : retryOns) { Code code; try { code = Code.valueOf(retryOn.toUpperCase(Locale.US).replace('-', '_')); } catch (IllegalArgumentException e) { - // TODO(zdapeng): TBD // unsupported value, such as "5xx" - return null; + continue; } if (!SUPPORTED_RETRYABLE_CODES.contains(code)) { - // TODO(zdapeng): TBD // unsupported value - return null; + continue; } retryableStatusCodesBuilder.add(code); } diff --git a/xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java b/xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java index 77320be9d20..2f73c6e2ea4 100644 --- a/xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java +++ b/xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java @@ -644,7 +644,8 @@ public void parseRouteAction_withRetryPolicy() { .setRetryPolicy(builder) .build(); struct = ClientXdsClient.parseRouteAction(proto, filterRegistry, false); - assertThat(struct.getStruct().retryPolicy()).isNull(); + assertThat(struct.getStruct().retryPolicy().retryableStatusCodes()) + .containsExactly(Code.CANCELLED); // unsupported retry_on code builder = RetryPolicy.newBuilder() @@ -660,7 +661,25 @@ public void parseRouteAction_withRetryPolicy() { .setRetryPolicy(builder) .build(); struct = ClientXdsClient.parseRouteAction(proto, filterRegistry, false); - assertThat(struct.getStruct().retryPolicy()).isNull(); + assertThat(struct.getStruct().retryPolicy().retryableStatusCodes()) + .containsExactly(Code.CANCELLED); + + // whitespace in retry_on + builder = RetryPolicy.newBuilder() + .setNumRetries(UInt32Value.of(3)) + .setRetryBackOff( + RetryBackOff.newBuilder() + .setBaseInterval(Durations.fromMillis(500)) + .setMaxInterval(Durations.fromMillis(600))) + .setPerTryTimeout(Durations.fromMillis(300)) + .setRetryOn("abort, , cancelled , "); + proto = io.envoyproxy.envoy.config.route.v3.RouteAction.newBuilder() + .setCluster("cluster-foo") + .setRetryPolicy(builder) + .build(); + struct = ClientXdsClient.parseRouteAction(proto, filterRegistry, false); + assertThat(struct.getStruct().retryPolicy().retryableStatusCodes()) + .containsExactly(Code.CANCELLED); } @Test From 6cd911757a4b0478eed20c6a77bc36cadbfa7d3b Mon Sep 17 00:00:00 2001 From: Sergii Tkachenko Date: Fri, 3 Sep 2021 16:26:43 -0400 Subject: [PATCH 65/82] census: make internal linter happy TODO is preferred to FIXME. --- census/src/main/java/io/grpc/census/CensusStatsModule.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/census/src/main/java/io/grpc/census/CensusStatsModule.java b/census/src/main/java/io/grpc/census/CensusStatsModule.java index 6f8acdb71e9..de860d0854c 100644 --- a/census/src/main/java/io/grpc/census/CensusStatsModule.java +++ b/census/src/main/java/io/grpc/census/CensusStatsModule.java @@ -498,7 +498,7 @@ void callEnded(Status status) { boolean shouldRecordFinishedCall = false; synchronized (lock) { if (callEnded) { - // FIXME(https://github.com/grpc/grpc-java/issues/7921): this shouldn't happen + // TODO(https://github.com/grpc/grpc-java/issues/7921): this shouldn't happen return; } callEnded = true; From 5dc6e0ca5479042d5c976c278c56e3d3b1f06452 Mon Sep 17 00:00:00 2001 From: sanjaypujare Date: Tue, 7 Sep 2021 14:27:49 -0700 Subject: [PATCH 66/82] xds: update Envoy protos to a later revision for the new CertificateProvider definitions (#8490) --- .../io/grpc/xds/ClientXdsClientDataTest.java | 21 ++++++ .../io/grpc/xds/ClientXdsClientTestBase.java | 1 + .../io/grpc/xds/ClientXdsClientV3Test.java | 1 + ...rChainMatchingProtocolNegotiatorsTest.java | 1 + .../test/java/io/grpc/xds/RbacFilterTest.java | 2 +- .../ClientSslContextProviderFactoryTest.java | 1 + .../sds/CommonTlsContextTestsUtil.java | 3 + xds/third_party/envoy/import.sh | 2 +- .../api/v2/listener/listener_components.proto | 2 +- .../envoy/config/accesslog/v3/accesslog.proto | 1 + .../envoy/config/bootstrap/v3/bootstrap.proto | 66 +++++++++++++++++- .../envoy/config/cluster/v3/cluster.proto | 60 +++++++++++----- .../proto/envoy/config/core/v3/protocol.proto | 18 ++++- .../endpoint/v3/endpoint_components.proto | 34 ++++++++- .../config/listener/v3/api_listener.proto | 1 + .../envoy/config/listener/v3/listener.proto | 29 +++++--- .../listener/v3/listener_components.proto | 20 ++++-- .../envoy/config/overload/v3/overload.proto | 23 +++++++ .../proto/envoy/config/rbac/v3/rbac.proto | 11 ++- .../proto/envoy/config/route/v3/route.proto | 15 +++- .../config/route/v3/route_components.proto | 55 ++++++++++++--- .../filters/common/fault/v3/fault.proto | 9 ++- .../v3/http_connection_manager.proto | 41 ++++++----- .../transport_sockets/tls/v3/common.proto | 38 +++++++++- .../transport_sockets/tls/v3/tls.proto | 69 +++++++++++++------ .../proto/envoy/type/matcher/v3/string.proto | 4 +- 26 files changed, 427 insertions(+), 101 deletions(-) diff --git a/xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java b/xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java index 2f73c6e2ea4..fb5c349f123 100644 --- a/xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java +++ b/xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java @@ -256,6 +256,7 @@ public void parseRoute_skipRouteWithUnsupportedAction() { } @Test + @SuppressWarnings("deprecation") public void parseRouteMatch_withHeaderMatcher() { io.envoyproxy.envoy.config.route.v3.RouteMatch proto = io.envoyproxy.envoy.config.route.v3.RouteMatch.newBuilder() @@ -336,6 +337,7 @@ public void parsePathMatcher_withSafeRegEx() { } @Test + @SuppressWarnings("deprecation") public void parseHeaderMatcher_withExactMatch() { io.envoyproxy.envoy.config.route.v3.HeaderMatcher proto = io.envoyproxy.envoy.config.route.v3.HeaderMatcher.newBuilder() @@ -349,6 +351,7 @@ public void parseHeaderMatcher_withExactMatch() { } @Test + @SuppressWarnings("deprecation") public void parseHeaderMatcher_withSafeRegExMatch() { io.envoyproxy.envoy.config.route.v3.HeaderMatcher proto = io.envoyproxy.envoy.config.route.v3.HeaderMatcher.newBuilder() @@ -388,6 +391,7 @@ public void parseHeaderMatcher_withPresentMatch() { } @Test + @SuppressWarnings("deprecation") public void parseHeaderMatcher_withPrefixMatch() { io.envoyproxy.envoy.config.route.v3.HeaderMatcher proto = io.envoyproxy.envoy.config.route.v3.HeaderMatcher.newBuilder() @@ -401,6 +405,7 @@ public void parseHeaderMatcher_withPrefixMatch() { } @Test + @SuppressWarnings("deprecation") public void parseHeaderMatcher_withSuffixMatch() { io.envoyproxy.envoy.config.route.v3.HeaderMatcher proto = io.envoyproxy.envoy.config.route.v3.HeaderMatcher.newBuilder() @@ -414,6 +419,7 @@ public void parseHeaderMatcher_withSuffixMatch() { } @Test + @SuppressWarnings("deprecation") public void parseHeaderMatcher_malformedRegExPattern() { io.envoyproxy.envoy.config.route.v3.HeaderMatcher proto = io.envoyproxy.envoy.config.route.v3.HeaderMatcher.newBuilder() @@ -1562,6 +1568,7 @@ public void validateCommonTlsContext_validationContextSdsSecretConfig() } @Test + @SuppressWarnings("deprecation") public void validateCommonTlsContext_validationContextCertificateProvider() throws ResourceInvalidException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() @@ -1575,6 +1582,7 @@ public void validateCommonTlsContext_validationContextCertificateProvider() } @Test + @SuppressWarnings("deprecation") public void validateCommonTlsContext_validationContextCertificateProviderInstance() throws ResourceInvalidException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() @@ -1600,6 +1608,7 @@ public void validateCommonTlsContext_tlsCertificateProviderInstance_isRequiredFo } @Test + @SuppressWarnings("deprecation") public void validateCommonTlsContext_tlsCertificateProviderInstance() throws ResourceInvalidException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() @@ -1611,6 +1620,7 @@ public void validateCommonTlsContext_tlsCertificateProviderInstance() } @Test + @SuppressWarnings("deprecation") public void validateCommonTlsContext_tlsCertificateProviderInstance_absentInBootstrapFile() throws ResourceInvalidException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() @@ -1625,6 +1635,7 @@ public void validateCommonTlsContext_tlsCertificateProviderInstance_absentInBoot } @Test + @SuppressWarnings("deprecation") public void validateCommonTlsContext_validationContextProviderInstance() throws ResourceInvalidException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() @@ -1639,6 +1650,7 @@ public void validateCommonTlsContext_validationContextProviderInstance() } @Test + @SuppressWarnings("deprecation") public void validateCommonTlsContext_validationContextProviderInstance_absentInBootstrapFile() throws ResourceInvalidException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() @@ -1679,6 +1691,7 @@ public void validateCommonTlsContext_tlsCertificateSdsSecretConfigsCount() } @Test + @SuppressWarnings("deprecation") public void validateCommonTlsContext_tlsCertificateCertificateProvider() throws ResourceInvalidException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() @@ -1716,6 +1729,7 @@ public void validateCommonTlsContext_combinedValidationContextWithoutCertProvide } @Test + @SuppressWarnings("deprecation") public void validateCommonTlsContext_combinedValContextWithDefaultValContextForServer() throws ResourceInvalidException, InvalidProtocolBufferException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() @@ -1735,6 +1749,7 @@ public void validateCommonTlsContext_combinedValContextWithDefaultValContextForS } @Test + @SuppressWarnings("deprecation") public void validateCommonTlsContext_combinedValContextWithDefaultValContextVerifyCertSpki() throws ResourceInvalidException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() @@ -1754,6 +1769,7 @@ public void validateCommonTlsContext_combinedValContextWithDefaultValContextVeri } @Test + @SuppressWarnings("deprecation") public void validateCommonTlsContext_combinedValContextWithDefaultValContextVerifyCertHash() throws ResourceInvalidException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() @@ -1773,6 +1789,7 @@ public void validateCommonTlsContext_combinedValContextWithDefaultValContextVeri } @Test + @SuppressWarnings("deprecation") public void validateCommonTlsContext_combinedValContextDfltValContextRequireSignedCertTimestamp() throws ResourceInvalidException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() @@ -1793,6 +1810,7 @@ public void validateCommonTlsContext_combinedValContextDfltValContextRequireSign } @Test + @SuppressWarnings("deprecation") public void validateCommonTlsContext_combinedValidationContextWithDefaultValidationContextCrl() throws ResourceInvalidException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() @@ -1811,6 +1829,7 @@ public void validateCommonTlsContext_combinedValidationContextWithDefaultValidat } @Test + @SuppressWarnings("deprecation") public void validateCommonTlsContext_combinedValContextWithDfltValContextCustomValidatorConfig() throws ResourceInvalidException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() @@ -1838,6 +1857,7 @@ public void validateDownstreamTlsContext_noCommonTlsContext() throws ResourceInv } @Test + @SuppressWarnings("deprecation") public void validateDownstreamTlsContext_hasRequireSni() throws ResourceInvalidException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() .setCombinedValidationContext( @@ -1857,6 +1877,7 @@ public void validateDownstreamTlsContext_hasRequireSni() throws ResourceInvalidE } @Test + @SuppressWarnings("deprecation") public void validateDownstreamTlsContext_hasOcspStaplePolicy() throws ResourceInvalidException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() .setCombinedValidationContext( diff --git a/xds/src/test/java/io/grpc/xds/ClientXdsClientTestBase.java b/xds/src/test/java/io/grpc/xds/ClientXdsClientTestBase.java index 78892772992..9e4d92fb344 100644 --- a/xds/src/test/java/io/grpc/xds/ClientXdsClientTestBase.java +++ b/xds/src/test/java/io/grpc/xds/ClientXdsClientTestBase.java @@ -1321,6 +1321,7 @@ public void cdsResponseWithCircuitBreakers() { * CDS response containing UpstreamTlsContext for a cluster. */ @Test + @SuppressWarnings("deprecation") public void cdsResponseWithUpstreamTlsContext() { Assume.assumeTrue(useProtocolV3()); DiscoveryRpcCall call = startResourceWatcher(CDS, CDS_RESOURCE, cdsResourceWatcher); diff --git a/xds/src/test/java/io/grpc/xds/ClientXdsClientV3Test.java b/xds/src/test/java/io/grpc/xds/ClientXdsClientV3Test.java index 0da6bf7bde5..eddba1040d4 100644 --- a/xds/src/test/java/io/grpc/xds/ClientXdsClientV3Test.java +++ b/xds/src/test/java/io/grpc/xds/ClientXdsClientV3Test.java @@ -535,6 +535,7 @@ protected Message buildRingHashLbConfig(String hashFunction, long minRingSize, } @Override + @SuppressWarnings("deprecation") protected Message buildUpstreamTlsContext(String instanceName, String certName) { CommonTlsContext.Builder commonTlsContextBuilder = CommonTlsContext.newBuilder(); if (instanceName != null && certName != null) { diff --git a/xds/src/test/java/io/grpc/xds/FilterChainMatchingProtocolNegotiatorsTest.java b/xds/src/test/java/io/grpc/xds/FilterChainMatchingProtocolNegotiatorsTest.java index a7a1a7c62e3..d79785c9f32 100644 --- a/xds/src/test/java/io/grpc/xds/FilterChainMatchingProtocolNegotiatorsTest.java +++ b/xds/src/test/java/io/grpc/xds/FilterChainMatchingProtocolNegotiatorsTest.java @@ -1087,6 +1087,7 @@ public void filterChain_5stepMatch() throws Exception { } @Test + @SuppressWarnings("deprecation") public void filterChainMatch_unsupportedMatchers() throws Exception { EnvoyServerProtoData.DownstreamTlsContext tlsContext1 = CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "ROOTCA"); diff --git a/xds/src/test/java/io/grpc/xds/RbacFilterTest.java b/xds/src/test/java/io/grpc/xds/RbacFilterTest.java index c5fe7b3d1bd..d8f1d8aa825 100644 --- a/xds/src/test/java/io/grpc/xds/RbacFilterTest.java +++ b/xds/src/test/java/io/grpc/xds/RbacFilterTest.java @@ -155,7 +155,7 @@ public void authenticatedParser() throws Exception { } @Test - @SuppressWarnings("unchecked") + @SuppressWarnings({"unchecked", "deprecation"}) public void headerParser() { HeaderMatcher headerMatcher = HeaderMatcher.newBuilder() .setName("party").setExactMatch("win").build(); diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactoryTest.java b/xds/src/test/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactoryTest.java index dfadee957c1..06a3198b263 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactoryTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactoryTest.java @@ -258,6 +258,7 @@ static void verifyWatcher( .isSameInstanceAs(sslContextProvider); } + @SuppressWarnings("deprecation") static CommonTlsContext.Builder addFilenames( CommonTlsContext.Builder builder, String certChain, String privateKey, String trustCa) { TlsCertificate tlsCert = diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/CommonTlsContextTestsUtil.java b/xds/src/test/java/io/grpc/xds/internal/sds/CommonTlsContextTestsUtil.java index 2918ce56224..81fbda9bde4 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/CommonTlsContextTestsUtil.java +++ b/xds/src/test/java/io/grpc/xds/internal/sds/CommonTlsContextTestsUtil.java @@ -62,6 +62,7 @@ public class CommonTlsContextTestsUtil { public static final String BAD_CLIENT_KEY_FILE = "badclient.key"; /** takes additional values and creates CombinedCertificateValidationContext as needed. */ + @SuppressWarnings("deprecation") static CommonTlsContext buildCommonTlsContextWithAdditionalValues( String certInstanceName, String certName, String validationContextCertInstanceName, String validationContextCertName, @@ -208,6 +209,7 @@ public static String getResourceContents(String resourceName) throws IOException return text; } + @SuppressWarnings("deprecation") private static CommonTlsContext buildCommonTlsContextForCertProviderInstance( String certInstanceName, String certName, @@ -232,6 +234,7 @@ private static CommonTlsContext buildCommonTlsContextForCertProviderInstance( return builder.build(); } + @SuppressWarnings("deprecation") private static CommonTlsContext.Builder addCertificateValidationContext( CommonTlsContext.Builder builder, String rootInstanceName, diff --git a/xds/third_party/envoy/import.sh b/xds/third_party/envoy/import.sh index 4c3fd1b3c70..c77ee9272e0 100755 --- a/xds/third_party/envoy/import.sh +++ b/xds/third_party/envoy/import.sh @@ -18,7 +18,7 @@ set -e BRANCH=main # import VERSION from one of the google internal CLs -VERSION=62ca8bd2b5960ed1c6ce2be97d3120cee719ecab +VERSION=c223756b0856f734a6a5cff2d0b95388cd2583d4 GIT_REPO="https://github.com/envoyproxy/envoy.git" GIT_BASE_DIR=envoy SOURCE_PROTO_BASE_DIR=envoy/api diff --git a/xds/third_party/envoy/src/main/proto/envoy/api/v2/listener/listener_components.proto b/xds/third_party/envoy/src/main/proto/envoy/api/v2/listener/listener_components.proto index a6791c86cd0..08738962c5e 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/api/v2/listener/listener_components.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/api/v2/listener/listener_components.proto @@ -230,7 +230,7 @@ message FilterChain { // rules: // - destination_port_range: // start: 3306 -// end: 3306 +// end: 3307 // - destination_port_range: // start: 15000 // end: 15001 diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/accesslog/v3/accesslog.proto b/xds/third_party/envoy/src/main/proto/envoy/config/accesslog/v3/accesslog.proto index ad129a3ed64..bb53286380c 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/accesslog/v3/accesslog.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/accesslog/v3/accesslog.proto @@ -246,6 +246,7 @@ message ResponseFlagFilter { in: "DT" in: "UPE" in: "NC" + in: "OM" } } }]; diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/bootstrap/v3/bootstrap.proto b/xds/third_party/envoy/src/main/proto/envoy/config/bootstrap/v3/bootstrap.proto index 431b45b6617..0e8de366333 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/bootstrap/v3/bootstrap.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/bootstrap/v3/bootstrap.proto @@ -40,7 +40,7 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // ` for more detail. // Bootstrap :ref:`configuration overview `. -// [#next-free-field: 31] +// [#next-free-field: 33] message Bootstrap { option (udpa.annotations.versioning).previous_message_type = "envoy.config.bootstrap.v2.Bootstrap"; @@ -260,8 +260,25 @@ message Bootstrap { // This may be overridden on a per-cluster basis in cds_config, when // :ref:`dns_resolution_config ` // is specified. + // *dns_resolution_config* will be deprecated once + // :ref:'typed_dns_resolver_config ' + // is fully supported. core.v3.DnsResolutionConfig dns_resolution_config = 30; + // DNS resolver type configuration extension. This extension can be used to configure c-ares, apple, + // or any other DNS resolver types and the related parameters. + // For example, an object of :ref:`DnsResolutionConfig ` + // can be packed into this *typed_dns_resolver_config*. This configuration will replace the + // :ref:'dns_resolution_config ' + // configuration eventually. + // TODO(yanjunxiang): Investigate the deprecation plan for *dns_resolution_config*. + // During the transition period when both *dns_resolution_config* and *typed_dns_resolver_config* exists, + // this configuration is optional. + // When *typed_dns_resolver_config* is in place, Envoy will use it and ignore *dns_resolution_config*. + // When *typed_dns_resolver_config* is missing, the default behavior is in place. + // [#not-implemented-hide:] + core.v3.TypedExtensionConfig typed_dns_resolver_config = 31; + // Specifies optional bootstrap extensions to be instantiated at startup time. // Each item contains extension specific configuration. // [#extension-category: envoy.bootstrap] @@ -305,6 +322,13 @@ message Bootstrap { // field. // [#not-implemented-hide:] map certificate_provider_instances = 25; + + // Specifies a set of headers that need to be registered as inline header. This configuration + // allows users to customize the inline headers on-demand at Envoy startup without modifying + // Envoy's source code. + // + // Note that the 'set-cookie' header cannot be registered as inline header. + repeated CustomInlineHeader inline_headers = 32; } // Administration interface :ref:`operations documentation @@ -578,3 +602,43 @@ message LayeredRuntime { // such that later layers in the list overlay earlier entries. repeated RuntimeLayer layers = 1; } + +// Used to specify the header that needs to be registered as an inline header. +// +// If request or response contain multiple headers with the same name and the header +// name is registered as an inline header. Then multiple headers will be folded +// into one, and multiple header values will be concatenated by a suitable delimiter. +// The delimiter is generally a comma. +// +// For example, if 'foo' is registered as an inline header, and the headers contains +// the following two headers: +// +// .. code-block:: text +// +// foo: bar +// foo: eep +// +// Then they will eventually be folded into: +// +// .. code-block:: text +// +// foo: bar, eep +// +// Inline headers provide O(1) search performance, but each inline header imposes +// an additional memory overhead on all instances of the corresponding type of +// HeaderMap or TrailerMap. +message CustomInlineHeader { + enum InlineHeaderType { + REQUEST_HEADER = 0; + REQUEST_TRAILER = 1; + RESPONSE_HEADER = 2; + RESPONSE_TRAILER = 3; + } + + // The name of the header that is expected to be set as the inline header. + string inline_header_name = 1 + [(validate.rules).string = {min_len: 1 well_known_regex: HTTP_HEADER_NAME strict: false}]; + + // The type of the header that is expected to be set as the inline header. + InlineHeaderType inline_header_type = 2 [(validate.rules).enum = {defined_only: true}]; +} diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/cluster/v3/cluster.proto b/xds/third_party/envoy/src/main/proto/envoy/config/cluster/v3/cluster.proto index 5470b1807d4..d6213d6fe94 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/cluster/v3/cluster.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/cluster/v3/cluster.proto @@ -43,7 +43,7 @@ message ClusterCollection { } // Configuration for a single upstream cluster. -// [#next-free-field: 54] +// [#next-free-field: 56] message Cluster { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.Cluster"; @@ -110,7 +110,7 @@ message Cluster { // this option or not. CLUSTER_PROVIDED = 6; - // [#not-implemented-hide:] Use the new :ref:`load_balancing_policy + // Use the new :ref:`load_balancing_policy // ` field to determine the LB policy. // [#next-major-version: In the v3 API, we should consider deprecating the lb_policy field // and instead using the new load_balancing_policy field as the one and only mechanism for @@ -413,8 +413,8 @@ message Cluster { // The table size for Maglev hashing. The Maglev aims for ‘minimal disruption’ rather than an absolute guarantee. // Minimal disruption means that when the set of upstreams changes, a connection will likely be sent to the same // upstream as it was before. Increasing the table size reduces the amount of disruption. - // The table size must be prime number. If it is not specified, the default is 65537. - google.protobuf.UInt64Value table_size = 1; + // The table size must be prime number limited to 5000011. If it is not specified, the default is 65537. + google.protobuf.UInt64Value table_size = 1 [(validate.rules).uint64 = {lte: 5000011}]; } // Specific configuration for the @@ -720,8 +720,7 @@ message Cluster { // The :ref:`load balancer type ` to use // when picking a host in the cluster. - // [#comment:TODO: Remove enum constraint :ref:`LOAD_BALANCING_POLICY_CONFIG` when implemented.] - LbPolicy lb_policy = 6 [(validate.rules).enum = {defined_only: true not_in: 7}]; + LbPolicy lb_policy = 6 [(validate.rules).enum = {defined_only: true}]; // Setting this is required for specifying members of // :ref:`STATIC`, @@ -746,7 +745,11 @@ message Cluster { // is respected by both the HTTP/1.1 and HTTP/2 connection pool // implementations. If not specified, there is no limit. Setting this // parameter to 1 will effectively disable keep alive. - google.protobuf.UInt32Value max_requests_per_connection = 9; + // + // .. attention:: + // This field has been deprecated in favor of the :ref:`max_requests_per_connection ` field. + google.protobuf.UInt32Value max_requests_per_connection = 9 + [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; // Optional :ref:`circuit breaking ` for the cluster. CircuitBreakers circuit_breakers = 10; @@ -778,7 +781,7 @@ message Cluster { [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; // Additional options when handling HTTP1 requests. - // This has been deprecated in favor of http_protocol_options fields in the in the + // This has been deprecated in favor of http_protocol_options fields in the // :ref:`http_protocol_options ` message. // http_protocol_options can be set via the cluster's // :ref:`extension_protocol_options`. @@ -794,7 +797,7 @@ message Cluster { // supports prior knowledge for upstream connections. Even if TLS is used // with ALPN, `http2_protocol_options` must be specified. As an aside this allows HTTP/2 // connections to happen over plain text. - // This has been deprecated in favor of http2_protocol_options fields in the in the + // This has been deprecated in favor of http2_protocol_options fields in the // :ref:`http_protocol_options ` // message. http2_protocol_options can be set via the cluster's // :ref:`extension_protocol_options`. @@ -874,8 +877,32 @@ message Cluster { [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; // DNS resolution configuration which includes the underlying dns resolver addresses and options. + // *dns_resolution_config* will be deprecated once + // :ref:'typed_dns_resolver_config ' + // is fully supported. core.v3.DnsResolutionConfig dns_resolution_config = 53; + // DNS resolver type configuration extension. This extension can be used to configure c-ares, apple, + // or any other DNS resolver types and the related parameters. + // For example, an object of :ref:`DnsResolutionConfig ` + // can be packed into this *typed_dns_resolver_config*. This configuration will replace the + // :ref:'dns_resolution_config ' + // configuration eventually. + // TODO(yanjunxiang): Investigate the deprecation plan for *dns_resolution_config*. + // During the transition period when both *dns_resolution_config* and *typed_dns_resolver_config* exists, + // this configuration is optional. + // When *typed_dns_resolver_config* is in place, Envoy will use it and ignore *dns_resolution_config*. + // When *typed_dns_resolver_config* is missing, the default behavior is in place. + // [#not-implemented-hide:] + core.v3.TypedExtensionConfig typed_dns_resolver_config = 55; + + // Optional configuration for having cluster readiness block on warm-up. Currently, only applicable for + // :ref:`STRICT_DNS`, + // or :ref:`LOGICAL_DNS`. + // If true, cluster readiness blocks on warm-up. If false, the cluster will complete + // initialization whether or not warm-up has completed. Defaults to true. + google.protobuf.BoolValue wait_for_warm_on_init = 54; + // If specified, outlier detection will be enabled for this upstream cluster. // Each of the configuration values can be overridden via // :ref:`runtime values `. @@ -930,7 +957,7 @@ message Cluster { CommonLbConfig common_lb_config = 27; // Optional custom transport socket implementation to use for upstream connections. - // To setup TLS, set a transport socket with name `tls` and + // To setup TLS, set a transport socket with name `envoy.transport_sockets.tls` and // :ref:`UpstreamTlsContexts ` in the `typed_config`. // If no transport socket configuration is specified, new connections // will be set up with plaintext. @@ -980,7 +1007,7 @@ message Cluster { // servers of this cluster. repeated Filter filters = 40; - // [#not-implemented-hide:] New mechanism for LB policy configuration. Used only if the + // New mechanism for LB policy configuration. Used only if the // :ref:`lb_policy` field has the value // :ref:`LOAD_BALANCING_POLICY_CONFIG`. LoadBalancingPolicy load_balancing_policy = 41; @@ -1045,7 +1072,7 @@ message Cluster { bool connection_pool_per_downstream_connection = 51; } -// [#not-implemented-hide:] Extensible load balancing policy configuration. +// Extensible load balancing policy configuration. // // Every LB policy defined via this mechanism will be identified via a unique name using reverse // DNS notation. If the policy needs configuration parameters, it must define a message for its @@ -1071,14 +1098,11 @@ message LoadBalancingPolicy { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.LoadBalancingPolicy.Policy"; - reserved 2; - - reserved "config"; + reserved 2, 1, 3; - // Required. The name of the LB policy. - string name = 1; + reserved "config", "name", "typed_config"; - google.protobuf.Any typed_config = 3; + core.v3.TypedExtensionConfig typed_extension_config = 4; } // Each client will iterate over the list in order and stop at the first policy that it diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/protocol.proto b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/protocol.proto index cf98e537261..8f2347eb551 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/protocol.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/protocol.proto @@ -73,7 +73,7 @@ message UpstreamHttpProtocolOptions { // Configures the alternate protocols cache which tracks alternate protocols that can be used to // make an HTTP connection to an origin server. See https://tools.ietf.org/html/rfc7838 for -// HTTP Alternate Services and https://datatracker.ietf.org/doc/html/draft-ietf-dnsop-svcb-https-04 +// HTTP Alternative Services and https://datatracker.ietf.org/doc/html/draft-ietf-dnsop-svcb-https-04 // for the "HTTPS" DNS resource record. message AlternateProtocolsCacheOptions { // The name of the cache. Multiple named caches allow independent alternate protocols cache @@ -93,7 +93,7 @@ message AlternateProtocolsCacheOptions { google.protobuf.UInt32Value max_entries = 2 [(validate.rules).uint32 = {gt: 0}]; } -// [#next-free-field: 6] +// [#next-free-field: 7] message HttpProtocolOptions { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.core.HttpProtocolOptions"; @@ -157,6 +157,12 @@ message HttpProtocolOptions { // If this setting is not specified, the value defaults to ALLOW. // Note: upstream responses are not affected by this setting. HeadersWithUnderscoresAction headers_with_underscores_action = 5; + + // Optional maximum requests for both upstream and downstream connections. + // If not specified, there is no limit. + // Setting this parameter to 1 will effectively disable keep alive. + // For HTTP/2 and HTTP/3, due to concurrent stream processing, the limit is approximate. + google.protobuf.UInt32Value max_requests_per_connection = 6; } // [#next-free-field: 8] @@ -478,3 +484,11 @@ message Http3ProtocolOptions { // `. google.protobuf.BoolValue override_stream_error_on_invalid_http_message = 2; } + +// A message to control transformations to the :scheme header +message SchemeHeaderTransformation { + oneof transformation { + // Overwrite any Scheme header with the contents of this string. + string scheme_to_overwrite = 1 [(validate.rules).string = {in: "http" in: "https"}]; + } +} diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/endpoint/v3/endpoint_components.proto b/xds/third_party/envoy/src/main/proto/envoy/config/endpoint/v3/endpoint_components.proto index 0e10ac3b2fc..0a9aac105e7 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/endpoint/v3/endpoint_components.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/endpoint/v3/endpoint_components.proto @@ -4,10 +4,12 @@ package envoy.config.endpoint.v3; import "envoy/config/core/v3/address.proto"; import "envoy/config/core/v3/base.proto"; +import "envoy/config/core/v3/config_source.proto"; import "envoy/config/core/v3/health_check.proto"; import "google/protobuf/wrappers.proto"; +import "udpa/annotations/migrate.proto"; import "udpa/annotations/status.proto"; import "udpa/annotations/versioning.proto"; import "validate/validate.proto"; @@ -108,21 +110,51 @@ message LbEndpoint { google.protobuf.UInt32Value load_balancing_weight = 4 [(validate.rules).uint32 = {gte: 1}]; } +// [#not-implemented-hide:] +// A configuration for a LEDS collection. +message LedsClusterLocalityConfig { + // Configuration for the source of LEDS updates for a Locality. + core.v3.ConfigSource leds_config = 1; + + // The xDS transport protocol glob collection resource name. + // The service is only supported in delta xDS (incremental) mode. + string leds_collection_name = 2; +} + // A group of endpoints belonging to a Locality. // One can have multiple LocalityLbEndpoints for a locality, but this is // generally only done if the different groups need to have different load // balancing weights or different priorities. -// [#next-free-field: 7] +// [#next-free-field: 9] message LocalityLbEndpoints { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.endpoint.LocalityLbEndpoints"; + // [#not-implemented-hide:] + // A list of endpoints of a specific locality. + message LbEndpointList { + repeated LbEndpoint lb_endpoints = 1; + } + // Identifies location of where the upstream hosts run. core.v3.Locality locality = 1; // The group of endpoints belonging to the locality specified. + // [#comment:TODO(adisuissa): Once LEDS is implemented this field needs to be + // deprecated and replaced by *load_balancer_endpoints*.] repeated LbEndpoint lb_endpoints = 2; + // [#not-implemented-hide:] + oneof lb_config { + // The group of endpoints belonging to the locality. + // [#comment:TODO(adisuissa): Once LEDS is implemented the *lb_endpoints* field + // needs to be deprecated.] + LbEndpointList load_balancer_endpoints = 7; + + // LEDS Configuration for the current locality. + LedsClusterLocalityConfig leds_cluster_locality_config = 8; + } + // Optional: Per priority/region/zone/sub_zone weight; at least 1. The load // balancing weight for a locality is divided by the sum of the weights of all // localities at the same priority level to produce the effective percentage diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/api_listener.proto b/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/api_listener.proto index 1dc94edc74b..77db7caaff5 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/api_listener.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/api_listener.proto @@ -23,6 +23,7 @@ message ApiListener { // The type in this field determines the type of API listener. At present, the following // types are supported: // envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager (HTTP) + // envoy.extensions.filters.network.http_connection_manager.v3.EnvoyMobileHttpConnectionManager (HTTP) // [#next-major-version: In the v3 API, replace this Any field with a oneof containing the // specific config message for each type of API listener. We could not do this in v2 because // it would have caused circular dependencies for go protos: lds.proto depends on this file, diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/listener.proto b/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/listener.proto index b5bda9562ce..a5cd4bfe976 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/listener.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/listener.proto @@ -35,7 +35,7 @@ message ListenerCollection { repeated xds.core.v3.CollectionEntry entries = 1; } -// [#next-free-field: 29] +// [#next-free-field: 30] message Listener { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.Listener"; @@ -255,17 +255,30 @@ message Listener { // enable the balance config in Y1 and Y2 to balance the connections among the workers. ConnectionBalanceConfig connection_balance_config = 20; + // Deprecated. Use `enable_reuse_port` instead. + bool reuse_port = 21 [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; + // When this flag is set to true, listeners set the *SO_REUSEPORT* socket option and // create one socket for each worker thread. This makes inbound connections // distribute among worker threads roughly evenly in cases where there are a high number - // of connections. When this flag is set to false, all worker threads share one socket. + // of connections. When this flag is set to false, all worker threads share one socket. This field + // defaults to true. + // + // .. attention:: + // + // Although this field defaults to true, it has different behavior on different platforms. See + // the following text for more information. // - // Before Linux v4.19-rc1, new TCP connections may be rejected during hot restart - // (see `3rd paragraph in 'soreuseport' commit message - // `_). - // This issue was fixed by `tcp: Avoid TCP syncookie rejected by SO_REUSEPORT socket - // `_. - bool reuse_port = 21; + // * On Linux, reuse_port is respected for both TCP and UDP listeners. It also works correctly + // with hot restart. + // * On macOS, reuse_port for TCP does not do what it does on Linux. Instead of load balancing, + // the last socket wins and receives all connections/packets. For TCP, reuse_port is force + // disabled and the user is warned. For UDP, it is enabled, but only one worker will receive + // packets. For QUIC/H3, SW routing will send packets to other workers. For "raw" UDP, only + // a single worker will currently receive packets. + // * On Windows, reuse_port for TCP has undefined behavior. It is force disabled and the user + // is warned similar to macOS. It is left enabled for UDP with undefined behavior currently. + google.protobuf.BoolValue enable_reuse_port = 29; // Configuration for :ref:`access logs ` // emitted by this listener. diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/listener_components.proto b/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/listener_components.proto index e6d73b791c2..e737b14b174 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/listener_components.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/listener_components.proto @@ -64,9 +64,12 @@ message Filter { // 3. Server name (e.g. SNI for TLS protocol), // 4. Transport protocol. // 5. Application protocols (e.g. ALPN for TLS protocol). -// 6. Source type (e.g. any, local or external network). -// 7. Source IP address. -// 8. Source port. +// 6. Directly connected source IP address (this will only be different from the source IP address +// when using a listener filter that overrides the source address, such as the :ref:`Proxy Protocol +// listener filter `). +// 7. Source type (e.g. any, local or external network). +// 8. Source IP address. +// 9. Source port. // // For criteria that allow ranges or wildcards, the most specific value in any // of the configured filter chains that matches the incoming connection is going @@ -90,7 +93,7 @@ message Filter { // listed at the end, because that's how we want to list them in the docs. // // [#comment:TODO(PiotrSikora): Add support for configurable precedence of the rules] -// [#next-free-field: 13] +// [#next-free-field: 14] message FilterChainMatch { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.listener.FilterChainMatch"; @@ -124,6 +127,11 @@ message FilterChainMatch { // [#not-implemented-hide:] google.protobuf.UInt32Value suffix_len = 5; + // The criteria is satisfied if the directly connected source IP address of the downstream + // connection is contained in at least one of the specified subnets. If the parameter is not + // specified or the list is empty, the directly connected source IP address is ignored. + repeated core.v3.CidrRange direct_source_prefix_ranges = 13; + // Specifies the connection source IP match type. Can be any, local or external network. ConnectionSourceType source_type = 12 [(validate.rules).enum = {defined_only: true}]; @@ -238,7 +246,7 @@ message FilterChain { core.v3.Metadata metadata = 5; // Optional custom transport socket implementation to use for downstream connections. - // To setup TLS, set a transport socket with name `tls` and + // To setup TLS, set a transport socket with name `envoy.transport_sockets.tls` and // :ref:`DownstreamTlsContext ` in the `typed_config`. // If no transport socket configuration is specified, new connections // will be set up with plaintext. @@ -282,7 +290,7 @@ message FilterChain { // rules: // - destination_port_range: // start: 3306 -// end: 3306 +// end: 3307 // - destination_port_range: // start: 15000 // end: 15001 diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/overload/v3/overload.proto b/xds/third_party/envoy/src/main/proto/envoy/config/overload/v3/overload.proto index 4445af63211..85fa761dbdd 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/overload/v3/overload.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/overload/v3/overload.proto @@ -141,6 +141,26 @@ message OverloadAction { google.protobuf.Any typed_config = 3; } +// Configuration for which accounts the WatermarkBuffer Factories should +// track. +message BufferFactoryConfig { + // The minimum power of two at which Envoy starts tracking an account. + // + // Envoy has 8 power of two buckets starting with the provided exponent below. + // Concretely the 1st bucket contains accounts for streams that use + // [2^minimum_account_to_track_power_of_two, + // 2^(minimum_account_to_track_power_of_two + 1)) bytes. + // With the 8th bucket tracking accounts + // >= 128 * 2^minimum_account_to_track_power_of_two. + // + // The maximum value is 56, since we're using uint64_t for bytes counting, + // and that's the last value that would use the 8 buckets. In practice, + // we don't expect the proxy to be holding 2^56 bytes. + // + // If omitted, Envoy should not do any tracking. + uint32 minimum_account_to_track_power_of_two = 1 [(validate.rules).uint32 = {lte: 56 gte: 10}]; +} + message OverloadManager { option (udpa.annotations.versioning).previous_message_type = "envoy.config.overload.v2alpha.OverloadManager"; @@ -153,4 +173,7 @@ message OverloadManager { // The set of overload actions. repeated OverloadAction actions = 3; + + // Configuration for buffer factory. + BufferFactoryConfig buffer_factory_config = 4; } diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/rbac/v3/rbac.proto b/xds/third_party/envoy/src/main/proto/envoy/config/rbac/v3/rbac.proto index 3b7f79d605d..d66f9be2b49 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/rbac/v3/rbac.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/rbac/v3/rbac.proto @@ -7,6 +7,7 @@ import "envoy/config/route/v3/route_components.proto"; import "envoy/type/matcher/v3/metadata.proto"; import "envoy/type/matcher/v3/path.proto"; import "envoy/type/matcher/v3/string.proto"; +import "envoy/type/v3/range.proto"; import "google/api/expr/v1alpha1/checked.proto"; import "google/api/expr/v1alpha1/syntax.proto"; @@ -60,7 +61,10 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // permissions: // - and_rules: // rules: -// - header: { name: ":method", exact_match: "GET" } +// - header: +// name: ":method" +// string_match: +// exact: "GET" // - url_path: // path: { prefix: "/products" } // - or_rules: @@ -142,7 +146,7 @@ message Policy { } // Permission defines an action (or actions) that a principal can take. -// [#next-free-field: 11] +// [#next-free-field: 12] message Permission { option (udpa.annotations.versioning).previous_message_type = "envoy.config.rbac.v2.Permission"; @@ -182,6 +186,9 @@ message Permission { // A port number that describes the destination port connecting to. uint32 destination_port = 6 [(validate.rules).uint32 = {lte: 65535}]; + // A port number range that describes a range of destination ports connecting to. + type.v3.Int32Range destination_port_range = 11; + // Metadata that describes additional information about the action. type.matcher.v3.MetadataMatcher metadata = 7; diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/route/v3/route.proto b/xds/third_party/envoy/src/main/proto/envoy/config/route/v3/route.proto index 80956fdeb4e..e2bf52165be 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/route/v3/route.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/route/v3/route.proto @@ -4,6 +4,7 @@ package envoy.config.route.v3; import "envoy/config/core/v3/base.proto"; import "envoy/config/core/v3/config_source.proto"; +import "envoy/config/core/v3/extension.proto"; import "envoy/config/route/v3/route_components.proto"; import "google/protobuf/wrappers.proto"; @@ -21,7 +22,7 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // * Routing :ref:`architecture overview ` // * HTTP :ref:`router filter ` -// [#next-free-field: 12] +// [#next-free-field: 13] message RouteConfiguration { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.RouteConfiguration"; @@ -119,6 +120,18 @@ message RouteConfiguration { // is not subject to data plane buffering controls. // google.protobuf.UInt32Value max_direct_response_body_size_bytes = 11; + + // [#not-implemented-hide:] + // A list of plugins and their configurations which may be used by a + // :ref:`envoy_v3_api_field_config.route.v3.RouteAction.cluster_specifier_plugin` + // within the route. All *extension.name* fields in this list must be unique. + repeated ClusterSpecifierPlugin cluster_specifier_plugins = 12; +} + +// Configuration for a cluster specifier plugin. +message ClusterSpecifierPlugin { + // The name of the plugin and its opaque configuration. + core.v3.TypedExtensionConfig extension = 1; } message Vhds { diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/route/v3/route_components.proto b/xds/third_party/envoy/src/main/proto/envoy/config/route/v3/route_components.proto index ee82e8f7322..dfb8b8ed1a1 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/route/v3/route_components.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/route/v3/route_components.proto @@ -311,7 +311,7 @@ message Route { message WeightedCluster { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.route.WeightedCluster"; - // [#next-free-field: 11] + // [#next-free-field: 12] message ClusterWeight { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.route.WeightedCluster.ClusterWeight"; @@ -378,6 +378,13 @@ message WeightedCluster { // :ref:`FilterConfig` // message to specify additional options.] map typed_per_filter_config = 10; + + oneof host_rewrite_specifier { + // Indicates that during forwarding, the host header will be swapped with + // this value. + string host_rewrite_literal = 11 + [(validate.rules).string = {well_known_regex: HTTP_HEADER_VALUE strict: false}]; + } } // Specifies one or more upstream clusters associated with the route. @@ -466,7 +473,7 @@ message RouteMatch { } // Indicates that prefix/path matching should be case sensitive. The default - // is true. + // is true. Ignored for safe_regex matching. google.protobuf.BoolValue case_sensitive = 4; // Indicates that the route should additionally match on a runtime key. Every time the route @@ -563,7 +570,7 @@ message CorsPolicy { core.v3.RuntimeFractionalPercent shadow_enabled = 10; } -// [#next-free-field: 37] +// [#next-free-field: 38] message RouteAction { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.route.RouteAction"; @@ -839,6 +846,14 @@ message RouteAction { // :ref:`traffic splitting ` // for additional documentation. WeightedCluster weighted_clusters = 3; + + // [#not-implemented-hide:] + // Name of the cluster specifier plugin to use to determine the cluster for + // requests on this route. The plugin name must be defined in the associated + // :ref:`envoy_v3_api_field_config.route.v3.RouteConfiguration.cluster_specifier_plugins` + // in the + // :ref:`envoy_v3_api_field_config.core.v3.TypedExtensionConfig.name` field. + string cluster_specifier_plugin = 37; } // The HTTP status code to use when configured cluster is not found. @@ -1850,7 +1865,7 @@ message RateLimit { // value. // // [#next-major-version: HeaderMatcher should be refactored to use StringMatcher.] -// [#next-free-field: 13] +// [#next-free-field: 14] message HeaderMatcher { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.route.HeaderMatcher"; @@ -1865,12 +1880,16 @@ message HeaderMatcher { // Specifies how the header match will be performed to route the request. oneof header_match_specifier { // If specified, header match will be performed based on the value of the header. - string exact_match = 4; + // This field is deprecated. Please use :ref:`string_match `. + string exact_match = 4 + [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; // If specified, this regex string is a regular expression rule which implies the entire request // header value must match the regex. The rule will not match if only a subsequence of the // request header value matches the regex. - type.matcher.v3.RegexMatcher safe_regex_match = 11; + // This field is deprecated. Please use :ref:`string_match `. + type.matcher.v3.RegexMatcher safe_regex_match = 11 + [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; // If specified, header match will be performed based on range. // The rule will match if the request header value is within this range. @@ -1891,28 +1910,46 @@ message HeaderMatcher { // If specified, header match will be performed based on the prefix of the header value. // Note: empty prefix is not allowed, please use present_match instead. + // This field is deprecated. Please use :ref:`string_match `. // // Examples: // // * The prefix *abcd* matches the value *abcdxyz*, but not for *abcxyz*. - string prefix_match = 9 [(validate.rules).string = {min_len: 1}]; + string prefix_match = 9 [ + deprecated = true, + (validate.rules).string = {min_len: 1}, + (envoy.annotations.deprecated_at_minor_version) = "3.0" + ]; // If specified, header match will be performed based on the suffix of the header value. // Note: empty suffix is not allowed, please use present_match instead. + // This field is deprecated. Please use :ref:`string_match `. // // Examples: // // * The suffix *abcd* matches the value *xyzabcd*, but not for *xyzbcd*. - string suffix_match = 10 [(validate.rules).string = {min_len: 1}]; + string suffix_match = 10 [ + deprecated = true, + (validate.rules).string = {min_len: 1}, + (envoy.annotations.deprecated_at_minor_version) = "3.0" + ]; // If specified, header match will be performed based on whether the header value contains // the given value or not. // Note: empty contains match is not allowed, please use present_match instead. + // This field is deprecated. Please use :ref:`string_match `. // // Examples: // // * The value *abcd* matches the value *xyzabcdpqr*, but not for *xyzbcdpqr*. - string contains_match = 12 [(validate.rules).string = {min_len: 1}]; + string contains_match = 12 [ + deprecated = true, + (validate.rules).string = {min_len: 1}, + (envoy.annotations.deprecated_at_minor_version) = "3.0" + ]; + + // If specified, header match will be performed based on the string match of the header value. + type.matcher.v3.StringMatcher string_match = 13; } // If specified, the match result will be inverted before checking. Defaults to false. diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/common/fault/v3/fault.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/common/fault/v3/fault.proto index b5b1dbd463f..62da059e264 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/common/fault/v3/fault.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/common/fault/v3/fault.proto @@ -18,7 +18,7 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Common fault injection types] // Delay specification is used to inject latency into the -// HTTP/gRPC/Mongo/Redis operation or delay proxying of TCP connections. +// HTTP/Mongo operation. // [#next-free-field: 6] message FaultDelay { option (udpa.annotations.versioning).previous_message_type = @@ -46,10 +46,9 @@ message FaultDelay { // Add a fixed delay before forwarding the operation upstream. See // https://developers.google.com/protocol-buffers/docs/proto3#json for - // the JSON/YAML Duration mapping. For HTTP/Mongo/Redis, the specified - // delay will be injected before a new request/operation. For TCP - // connections, the proxying of the connection upstream will be delayed - // for the specified period. This is required if type is FIXED. + // the JSON/YAML Duration mapping. For HTTP/Mongo, the specified + // delay will be injected before a new request/operation. + // This is required if type is FIXED. google.protobuf.Duration fixed_delay = 3 [(validate.rules).duration = {gt {}}]; // Fault delays are controlled via an HTTP header (if applicable). diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/network/http_connection_manager/v3/http_connection_manager.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/network/http_connection_manager/v3/http_connection_manager.proto index 856249c2a25..3fb4bfa09e2 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/network/http_connection_manager/v3/http_connection_manager.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/network/http_connection_manager/v3/http_connection_manager.proto @@ -19,7 +19,6 @@ import "google/protobuf/any.proto"; import "google/protobuf/duration.proto"; import "google/protobuf/wrappers.proto"; -import "envoy/annotations/deprecation.proto"; import "udpa/annotations/migrate.proto"; import "udpa/annotations/security.proto"; import "udpa/annotations/status.proto"; @@ -35,7 +34,7 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // HTTP connection manager :ref:`configuration overview `. // [#extension: envoy.filters.network.http_connection_manager] -// [#next-free-field: 48] +// [#next-free-field: 49] message HttpConnectionManager { option (udpa.annotations.versioning).previous_message_type = "envoy.config.filter.network.http_connection_manager.v2.HttpConnectionManager"; @@ -371,6 +370,11 @@ message HttpConnectionManager { ServerHeaderTransformation server_header_transformation = 34 [(validate.rules).enum = {defined_only: true}]; + // Allows for explicit transformation of the :scheme header on the request path. + // If not set, Envoy's default :ref:`scheme ` + // handling applies. + config.core.v3.SchemeHeaderTransformation scheme_header_transformation = 48; + // The maximum request headers size for incoming connections. // If unconfigured, the default max request headers allowed is 60 KiB. // Requests that exceed this limit will receive a 431 response. @@ -496,23 +500,7 @@ message HttpConnectionManager { // determining the origin client's IP address. The default is zero if this option // is not specified. See the documentation for // :ref:`config_http_conn_man_headers_x-forwarded-for` for more information. - // - // .. note:: - // This field is deprecated and instead :ref:`original_ip_detection_extensions - // ` - // should be used to configure the :ref:`xff extension ` - // to configure IP detection using the :ref:`config_http_conn_man_headers_x-forwarded-for` header. To replace - // this field use a config like the following: - // - // .. code-block:: yaml - // - // original_ip_detection_extensions: - // typed_config: - // "@type": type.googleapis.com/envoy.extensions.http.original_ip_detection.xff.v3.XffConfig - // xff_num_trusted_hops: 1 - // - uint32 xff_num_trusted_hops = 19 - [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; + uint32 xff_num_trusted_hops = 19; // The configuration for the original IP detection extensions. // @@ -524,6 +512,12 @@ message HttpConnectionManager { // the request. If the request isn't rejected nor any extension succeeds, the HCM will // fallback to using the remote address. // + // .. WARNING:: + // Extensions cannot be used in conjunction with :ref:`use_remote_address + // ` + // nor :ref:`xff_num_trusted_hops + // `. + // // [#extension-category: envoy.http.original_ip_detection] repeated config.core.v3.TypedExtensionConfig original_ip_detection_extensions = 46; @@ -1000,3 +994,12 @@ message RequestIDExtension { // Request ID extension specific configuration. google.protobuf.Any typed_config = 1; } + +// [#protodoc-title: Envoy Mobile HTTP connection manager] +// HTTP connection manager for use in Envoy mobile. +// [#extension: envoy.filters.network.envoy_mobile_http_connection_manager] +message EnvoyMobileHttpConnectionManager { + // The configuration for the underlying HttpConnectionManager which will be + // instantiated for Envoy mobile. + HttpConnectionManager config = 1; +} diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/tls/v3/common.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/tls/v3/common.proto index aa05a31f23d..82dcb37cd7c 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/tls/v3/common.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/tls/v3/common.proto @@ -9,6 +9,7 @@ import "envoy/type/matcher/v3/string.proto"; import "google/protobuf/any.proto"; import "google/protobuf/wrappers.proto"; +import "udpa/annotations/migrate.proto"; import "udpa/annotations/sensitive.proto"; import "udpa/annotations/status.proto"; import "udpa/annotations/versioning.proto"; @@ -232,7 +233,27 @@ message TlsSessionTicketKeys { [(validate.rules).repeated = {min_items: 1}, (udpa.annotations.sensitive) = true]; } -// [#next-free-field: 13] +// Indicates a certificate to be obtained from a named CertificateProvider plugin instance. +// The plugin instances are defined in the client's bootstrap file. +// The plugin allows certificates to be fetched/refreshed over the network asynchronously with +// respect to the TLS handshake. +// [#not-implemented-hide:] +message CertificateProviderPluginInstance { + // Provider instance name. If not present, defaults to "default". + // + // Instance names should generally be defined not in terms of the underlying provider + // implementation (e.g., "file_watcher") but rather in terms of the function of the + // certificates (e.g., "foo_deployment_identity"). + string instance_name = 1; + + // Opaque name used to specify certificate instances or types. For example, "ROOTCA" to specify + // a root-certificate (validation context) or "example.com" to specify a certificate for a + // particular domain. Not all provider instances will actually use this field, so the value + // defaults to the empty string. + string certificate_name = 2; +} + +// [#next-free-field: 14] message CertificateValidationContext { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.auth.CertificateValidationContext"; @@ -279,7 +300,20 @@ message CertificateValidationContext { // directory for any file moves to support rotation. This currently only // applies to dynamic secrets, when the *CertificateValidationContext* is // delivered via SDS. - config.core.v3.DataSource trusted_ca = 1; + // + // Only one of *trusted_ca* and *ca_certificate_provider_instance* may be specified. + // + // [#next-major-version: This field and watched_directory below should ideally be moved into a + // separate sub-message, since there's no point in specifying the latter field without this one.] + config.core.v3.DataSource trusted_ca = 1 + [(udpa.annotations.field_migrate).oneof_promotion = "ca_cert_source"]; + + // Certificate provider instance for fetching TLS certificates. + // + // Only one of *trusted_ca* and *ca_certificate_provider_instance* may be specified. + // [#not-implemented-hide:] + CertificateProviderPluginInstance ca_certificate_provider_instance = 13 + [(udpa.annotations.field_migrate).oneof_promotion = "ca_cert_source"]; // If specified, updates of a file-based *trusted_ca* source will be triggered // by this watch. This allows explicit control over the path watched, by diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/tls/v3/tls.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/tls/v3/tls.proto index 02287de5875..f680207955a 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/tls/v3/tls.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/tls/v3/tls.proto @@ -9,7 +9,7 @@ import "envoy/extensions/transport_sockets/tls/v3/secret.proto"; import "google/protobuf/duration.proto"; import "google/protobuf/wrappers.proto"; -import "udpa/annotations/migrate.proto"; +import "envoy/annotations/deprecation.proto"; import "udpa/annotations/status.proto"; import "udpa/annotations/versioning.proto"; import "validate/validate.proto"; @@ -125,12 +125,18 @@ message DownstreamTlsContext { } // TLS context shared by both client and server TLS contexts. -// [#next-free-field: 14] +// [#next-free-field: 15] message CommonTlsContext { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.auth.CommonTlsContext"; // Config for Certificate provider to get certificates. This provider should allow certificates to be // fetched/refreshed over the network asynchronously with respect to the TLS handshake. + // + // DEPRECATED: This message is not currently used, but if we ever do need it, we will want to + // move it out of CommonTlsContext and into common.proto, similar to the existing + // CertificateProviderPluginInstance message. + // + // [#not-implemented-hide:] message CertificateProvider { // opaque name used to specify certificate instances or types. For example, "ROOTCA" to specify // a root-certificate (validation context) or "TLS" to specify a new tls-certificate. @@ -151,6 +157,11 @@ message CommonTlsContext { // Similar to CertificateProvider above, but allows the provider instances to be configured on // the client side instead of being sent from the control plane. + // + // DEPRECATED: This message was moved outside of CommonTlsContext + // and now lives in common.proto. + // + // [#not-implemented-hide:] message CertificateProviderInstance { // Provider instance name. This name must be defined in the client's configuration (e.g., a // bootstrap file) to correspond to a provider instance (i.e., the same data in the typed_config @@ -179,26 +190,20 @@ message CommonTlsContext { // Config for fetching validation context via SDS API. Note SDS API allows certificates to be // fetched/refreshed over the network asynchronously with respect to the TLS handshake. - // Only one of validation_context_sds_secret_config, validation_context_certificate_provider, - // or validation_context_certificate_provider_instance may be used. - SdsSecretConfig validation_context_sds_secret_config = 2 [ - (validate.rules).message = {required: true}, - (udpa.annotations.field_migrate).oneof_promotion = "dynamic_validation_context" - ]; + SdsSecretConfig validation_context_sds_secret_config = 2 + [(validate.rules).message = {required: true}]; - // Certificate provider for fetching validation context. - // Only one of validation_context_sds_secret_config, validation_context_certificate_provider, - // or validation_context_certificate_provider_instance may be used. + // Certificate provider for fetching CA certs. This will populate the + // *default_validation_context.trusted_ca* field. // [#not-implemented-hide:] CertificateProvider validation_context_certificate_provider = 3 - [(udpa.annotations.field_migrate).oneof_promotion = "dynamic_validation_context"]; + [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; - // Certificate provider instance for fetching validation context. - // Only one of validation_context_sds_secret_config, validation_context_certificate_provider, - // or validation_context_certificate_provider_instance may be used. + // Certificate provider instance for fetching CA certs. This will populate the + // *default_validation_context.trusted_ca* field. // [#not-implemented-hide:] CertificateProviderInstance validation_context_certificate_provider_instance = 4 - [(udpa.annotations.field_migrate).oneof_promotion = "dynamic_validation_context"]; + [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; } reserved 5; @@ -212,6 +217,12 @@ message CommonTlsContext { // Only a single TLS certificate is supported in client contexts. In server contexts, the first // RSA certificate is used for clients that only support RSA and the first ECDSA certificate is // used for clients that support ECDSA. + // + // Only one of *tls_certificates*, *tls_certificate_sds_secret_configs*, + // and *tls_certificate_provider_instance* may be used. + // [#next-major-version: These mutually exclusive fields should ideally be in a oneof, but it's + // not legal to put a repeated field in a oneof. In the next major version, we should rework + // this to avoid this problem.] repeated TlsCertificate tls_certificates = 2; // Configs for fetching TLS certificates via SDS API. Note SDS API allows certificates to be @@ -220,18 +231,30 @@ message CommonTlsContext { // The same number and types of certificates as :ref:`tls_certificates ` // are valid in the the certificates fetched through this setting. // - // If :ref:`tls_certificates ` - // is non-empty, this field is ignored. + // Only one of *tls_certificates*, *tls_certificate_sds_secret_configs*, + // and *tls_certificate_provider_instance* may be used. + // [#next-major-version: These mutually exclusive fields should ideally be in a oneof, but it's + // not legal to put a repeated field in a oneof. In the next major version, we should rework + // this to avoid this problem.] repeated SdsSecretConfig tls_certificate_sds_secret_configs = 6 [(validate.rules).repeated = {max_items: 2}]; + // Certificate provider instance for fetching TLS certs. + // + // Only one of *tls_certificates*, *tls_certificate_sds_secret_configs*, + // and *tls_certificate_provider_instance* may be used. + // [#not-implemented-hide:] + CertificateProviderPluginInstance tls_certificate_provider_instance = 14; + // Certificate provider for fetching TLS certificates. // [#not-implemented-hide:] - CertificateProvider tls_certificate_certificate_provider = 9; + CertificateProvider tls_certificate_certificate_provider = 9 + [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; // Certificate provider instance for fetching TLS certificates. // [#not-implemented-hide:] - CertificateProviderInstance tls_certificate_certificate_provider_instance = 11; + CertificateProviderInstance tls_certificate_certificate_provider_instance = 11 + [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; oneof validation_context_type { // How to validate peer certificates. @@ -252,11 +275,13 @@ message CommonTlsContext { // Certificate provider for fetching validation context. // [#not-implemented-hide:] - CertificateProvider validation_context_certificate_provider = 10; + CertificateProvider validation_context_certificate_provider = 10 + [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; // Certificate provider instance for fetching validation context. // [#not-implemented-hide:] - CertificateProviderInstance validation_context_certificate_provider_instance = 12; + CertificateProviderInstance validation_context_certificate_provider_instance = 12 + [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; } // Supplies the list of ALPN protocols that the listener should expose. In diff --git a/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/string.proto b/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/string.proto index 78e1572bf8c..c64edde142f 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/string.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/string.proto @@ -62,8 +62,8 @@ message StringMatcher { string contains = 7 [(validate.rules).string = {min_len: 1}]; } - // If true, indicates the exact/prefix/suffix matching should be case insensitive. This has no - // effect for the safe_regex match. + // If true, indicates the exact/prefix/suffix/contains matching should be case insensitive. This + // has no effect for the safe_regex match. // For example, the matcher *data* will match both input string *Data* and *data* if set to true. bool ignore_case = 6; } From cd346832babd623df3b2c7c2e705f3da2285bc19 Mon Sep 17 00:00:00 2001 From: ZHANG Dapeng Date: Tue, 7 Sep 2021 21:32:33 -0700 Subject: [PATCH 67/82] rls: migrate deprecated server/path to extraKeys (#8469) The [`server` and `path` fields](https://github.com/grpc/grpc-java/blob/v1.40.1/rls/src/main/proto/grpc/lookup/v1/rls.proto#L25-L32) in `RouteLookupRequest` are deprecated. Instead, we will send the server/path information in side of [`key_map`](https://github.com/grpc/grpc-java/blob/v1.40.1/rls/src/main/proto/grpc/lookup/v1/rls.proto#L45). The keys for the server, service and method in the `key_map` will be the _values_ of `host`, `service`, `method` fields respectively in [`extraKeys`](https://github.com/grpc/grpc-java/blob/v1.40.1/rls/src/main/proto/grpc/lookup/v1/rls_config.proto#L69) in RlsConfig. We will also include all entries in the [`constantKey`](https://github.com/grpc/grpc-java/blob/v1.40.1/rls/src/main/proto/grpc/lookup/v1/rls_config.proto#L80) in RlsConfig into `RouteLookupRequest`. Other changes: - Add AutoValue library for ExtraKeys class, just like data classes used in grpc-xds. Will migrate other data classes to AutoValue as well. - Not to keep `targetType` field in the route lookup request data class, because we always use "grpc" as targetType. --- repositories.bzl | 2 + rls/BUILD.bazel | 21 +++++ rls/build.gradle | 13 +++ .../java/io/grpc/rls/RlsProtoConverters.java | 29 +++--- .../main/java/io/grpc/rls/RlsProtoData.java | 88 +++++++++---------- .../java/io/grpc/rls/RlsRequestFactory.java | 64 ++++++++------ .../io/grpc/rls/CachingRlsLbClientTest.java | 40 ++++----- .../java/io/grpc/rls/RlsLoadBalancerTest.java | 21 +++-- .../io/grpc/rls/RlsProtoConvertersTest.java | 38 ++++---- .../io/grpc/rls/RlsRequestFactoryTest.java | 54 +++++++----- 10 files changed, 217 insertions(+), 153 deletions(-) diff --git a/repositories.bzl b/repositories.bzl index 6cf75aa7bb6..0d6e9ab2f74 100644 --- a/repositories.bzl +++ b/repositories.bzl @@ -16,6 +16,8 @@ IO_GRPC_GRPC_JAVA_ARTIFACTS = [ "com.google.auth:google-auth-library-oauth2-http:0.22.0", "com.google.code.findbugs:jsr305:3.0.2", "com.google.code.gson:gson:jar:2.8.6", + "com.google.auto.value:auto-value:1.7.4", + "com.google.auto.value:auto-value-annotations:1.7.4", "com.google.errorprone:error_prone_annotations:2.9.0", "com.google.guava:failureaccess:1.0.1", "com.google.guava:guava:30.1-android", diff --git a/rls/BUILD.bazel b/rls/BUILD.bazel index 0da03fc924b..4daa7029560 100644 --- a/rls/BUILD.bazel +++ b/rls/BUILD.bazel @@ -7,12 +7,14 @@ java_library( ]), visibility = ["//visibility:public"], deps = [ + ":autovalue", ":rls_java_grpc", "//api", "//core", "//core:internal", "//core:util", "//stub", + "@com_google_auto_value_auto_value_annotations//jar", "@com_google_code_findbugs_jsr305//jar", "@com_google_guava_guava//jar", "@io_grpc_grpc_proto//:rls_java_proto", @@ -20,6 +22,25 @@ java_library( ], ) +java_plugin( + name = "autovalue_plugin", + processor_class = "com.google.auto.value.processor.AutoValueProcessor", + deps = [ + "@com_google_auto_value_auto_value//jar", + ], +) + +java_library( + name = "autovalue", + exported_plugins = [ + ":autovalue_plugin", + ], + neverlink = 1, + exports = [ + "@com_google_auto_value_auto_value//jar", + ], +) + java_grpc_library( name = "rls_java_grpc", srcs = ["@io_grpc_grpc_proto//:rls_proto"], diff --git a/rls/build.gradle b/rls/build.gradle index a2ebf2a62ef..45f17fb71c3 100644 --- a/rls/build.gradle +++ b/rls/build.gradle @@ -14,7 +14,9 @@ dependencies { implementation project(':grpc-core'), project(':grpc-protobuf'), project(':grpc-stub'), + libraries.autovalue_annotation, libraries.guava + annotationProcessor libraries.autovalue compileOnly libraries.javax_annotation testImplementation libraries.truth, project(':grpc-grpclb'), @@ -24,6 +26,17 @@ dependencies { signature "org.codehaus.mojo.signature:java17:1.0@signature" } +[compileJava].each() { + it.options.compilerArgs += [ + // only has AutoValue annotation processor + "-Xlint:-processing", + ] + appendToProperty( + it.options.errorprone.excludedPaths, + ".*/build/generated/sources/annotationProcessor/java/.*", + "|") +} + javadoc { // Do not publish javadoc since currently there is no public API. failOnError false // no public or protected classes found to document diff --git a/rls/src/main/java/io/grpc/rls/RlsProtoConverters.java b/rls/src/main/java/io/grpc/rls/RlsProtoConverters.java index 32df13c4262..ce89def4467 100644 --- a/rls/src/main/java/io/grpc/rls/RlsProtoConverters.java +++ b/rls/src/main/java/io/grpc/rls/RlsProtoConverters.java @@ -20,9 +20,11 @@ import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.base.Converter; +import com.google.common.collect.ImmutableMap; import io.grpc.internal.JsonUtil; import io.grpc.lookup.v1.RouteLookupRequest; import io.grpc.lookup.v1.RouteLookupResponse; +import io.grpc.rls.RlsProtoData.ExtraKeys; import io.grpc.rls.RlsProtoData.GrpcKeyBuilder; import io.grpc.rls.RlsProtoData.GrpcKeyBuilder.Name; import io.grpc.rls.RlsProtoData.NameMatcher; @@ -46,25 +48,16 @@ final class RlsProtoConverters { static final class RouteLookupRequestConverter extends Converter { - @SuppressWarnings("deprecation") @Override protected RlsProtoData.RouteLookupRequest doForward(RouteLookupRequest routeLookupRequest) { - return - new RlsProtoData.RouteLookupRequest( - /* server= */ routeLookupRequest.getServer(), - /* path= */ routeLookupRequest.getPath(), - /* targetType= */ routeLookupRequest.getTargetType(), - routeLookupRequest.getKeyMapMap()); + return new RlsProtoData.RouteLookupRequest(routeLookupRequest.getKeyMapMap()); } - @SuppressWarnings("deprecation") @Override protected RouteLookupRequest doBackward(RlsProtoData.RouteLookupRequest routeLookupRequest) { return RouteLookupRequest.newBuilder() - .setServer(routeLookupRequest.getServer()) - .setPath(routeLookupRequest.getPath()) - .setTargetType(routeLookupRequest.getTargetType()) + .setTargetType("grpc") .putAllKeyMap(routeLookupRequest.getKeyMap()) .build(); } @@ -183,7 +176,19 @@ static GrpcKeyBuilder convert(Map keyBuilder) { matcher.isOptional(), "NameMatcher for GrpcKeyBuilders shouldn't be required"); nameMatchers.add(matcher); } - return new GrpcKeyBuilder(names, nameMatchers); + ExtraKeys extraKeys = ExtraKeys.DEFAULT; + Map rawExtraKeys = + (Map) JsonUtil.getObject(keyBuilder, "extraKeys"); + if (rawExtraKeys != null) { + extraKeys = ExtraKeys.create( + rawExtraKeys.get("host"), rawExtraKeys.get("service"), rawExtraKeys.get("method")); + } + Map constantKeys = + (Map) JsonUtil.getObject(keyBuilder, "constantKeys"); + if (constantKeys == null) { + constantKeys = ImmutableMap.of(); + } + return new GrpcKeyBuilder(names, nameMatchers, extraKeys, constantKeys); } } diff --git a/rls/src/main/java/io/grpc/rls/RlsProtoData.java b/rls/src/main/java/io/grpc/rls/RlsProtoData.java index fbcb6feb21c..3556ca609be 100644 --- a/rls/src/main/java/io/grpc/rls/RlsProtoData.java +++ b/rls/src/main/java/io/grpc/rls/RlsProtoData.java @@ -20,6 +20,7 @@ import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; +import com.google.auto.value.AutoValue; import com.google.common.base.MoreObjects; import com.google.common.base.Objects; import com.google.common.collect.ImmutableList; @@ -40,46 +41,12 @@ final class RlsProtoData { @Immutable static final class RouteLookupRequest { - private final String server; - - private final String path; - - private final String targetType; - private final ImmutableMap keyMap; - RouteLookupRequest( - String server, String path, String targetType, Map keyMap) { - this.server = checkNotNull(server, "server"); - this.path = checkNotNull(path, "path"); - this.targetType = checkNotNull(targetType, "targetName"); + RouteLookupRequest(Map keyMap) { this.keyMap = ImmutableMap.copyOf(checkNotNull(keyMap, "keyMap")); } - /** - * Returns a full host name of the target server, {@literal e.g.} firestore.googleapis.com. Only - * set for gRPC requests; HTTP requests must use key_map explicitly. - */ - String getServer() { - return server; - } - - /** - * Returns a full path of the request, {@literal i.e.} "/service/method". Only set for gRPC - * requests; HTTP requests must use key_map explicitly. - */ - String getPath() { - return path; - } - - /** - * Returns the target type allows the client to specify what kind of target format it would like - * from RLS to allow it to find the regional server, {@literal e.g.} "grpc". - */ - String getTargetType() { - return targetType; - } - /** Returns a map of key values extracted via key builders for the gRPC or HTTP request. */ ImmutableMap getKeyMap() { return keyMap; @@ -94,23 +61,17 @@ public boolean equals(Object o) { return false; } RouteLookupRequest that = (RouteLookupRequest) o; - return Objects.equal(server, that.server) - && Objects.equal(path, that.path) - && Objects.equal(targetType, that.targetType) - && Objects.equal(keyMap, that.keyMap); + return Objects.equal(keyMap, that.keyMap); } @Override public int hashCode() { - return Objects.hashCode(server, path, targetType, keyMap); + return Objects.hashCode(keyMap); } @Override public String toString() { return MoreObjects.toStringHelper(this) - .add("server", server) - .add("path", path) - .add("targetName", targetType) .add("keyMap", keyMap) .toString(); } @@ -300,6 +261,7 @@ ImmutableList getValidTargets() { * error. Note that requests can be routed only to a subdomain of the original target, * {@literal e.g.} "us_east_1.cloudbigtable.googleapis.com". */ + @Nullable String getDefaultTarget() { return defaultTarget; } @@ -431,12 +393,18 @@ static final class GrpcKeyBuilder { private final ImmutableList names; private final ImmutableList headers; + private final ExtraKeys extraKeys; + private final ImmutableMap constantKeys; - public GrpcKeyBuilder(List names, List headers) { + public GrpcKeyBuilder( + List names, List headers, ExtraKeys extraKeys, + Map constantKeys) { checkState(names != null && !names.isEmpty(), "names cannot be empty"); this.names = ImmutableList.copyOf(names); checkUniqueKey(checkNotNull(headers, "headers")); this.headers = ImmutableList.copyOf(headers); + this.extraKeys = checkNotNull(extraKeys, "extraKeys"); + this.constantKeys = ImmutableMap.copyOf(checkNotNull(constantKeys, "constantKeys")); } private static void checkUniqueKey(List headers) { @@ -464,6 +432,14 @@ ImmutableList getHeaders() { return headers; } + ExtraKeys getExtraKeys() { + return extraKeys; + } + + ImmutableMap getConstantKeys() { + return constantKeys; + } + @Override public boolean equals(Object o) { if (this == o) { @@ -473,12 +449,14 @@ public boolean equals(Object o) { return false; } GrpcKeyBuilder that = (GrpcKeyBuilder) o; - return Objects.equal(names, that.names) && Objects.equal(headers, that.headers); + return Objects.equal(names, that.names) && Objects.equal(headers, that.headers) + && Objects.equal(extraKeys, that.extraKeys) + && Objects.equal(constantKeys, that.constantKeys); } @Override public int hashCode() { - return Objects.hashCode(names, headers); + return Objects.hashCode(names, headers, extraKeys, constantKeys); } @Override @@ -486,6 +464,8 @@ public String toString() { return MoreObjects.toStringHelper(this) .add("names", names) .add("headers", headers) + .add("extraKeys", extraKeys) + .add("constantKeys", constantKeys) .toString(); } @@ -548,4 +528,20 @@ public String toString() { } } } + + @AutoValue + abstract static class ExtraKeys { + static final ExtraKeys DEFAULT = create(null, null, null); + + @Nullable abstract String host(); + + @Nullable abstract String service(); + + @Nullable abstract String method(); + + static ExtraKeys create( + @Nullable String host, @Nullable String service, @Nullable String method) { + return new AutoValue_RlsProtoData_ExtraKeys(host, service, method); + } + } } diff --git a/rls/src/main/java/io/grpc/rls/RlsRequestFactory.java b/rls/src/main/java/io/grpc/rls/RlsRequestFactory.java index b9bcb037cf5..e181d64833d 100644 --- a/rls/src/main/java/io/grpc/rls/RlsRequestFactory.java +++ b/rls/src/main/java/io/grpc/rls/RlsRequestFactory.java @@ -19,17 +19,18 @@ import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.base.MoreObjects; -import com.google.common.collect.HashBasedTable; -import com.google.common.collect.Table; +import com.google.common.collect.ImmutableMap; import io.grpc.Metadata; import io.grpc.Status; import io.grpc.StatusRuntimeException; +import io.grpc.rls.RlsProtoData.ExtraKeys; import io.grpc.rls.RlsProtoData.GrpcKeyBuilder; import io.grpc.rls.RlsProtoData.GrpcKeyBuilder.Name; import io.grpc.rls.RlsProtoData.NameMatcher; import io.grpc.rls.RlsProtoData.RouteLookupConfig; import io.grpc.rls.RlsProtoData.RouteLookupRequest; import java.util.HashMap; +import java.util.List; import java.util.Map; import javax.annotation.CheckReturnValue; @@ -40,10 +41,7 @@ final class RlsRequestFactory { private final String target; - /** - * schema: Path(/serviceName/methodName or /serviceName/*), rls request headerName, header fields. - */ - private final Table keyBuilderTable; + private final Map keyBuilderTable; RlsRequestFactory(RouteLookupConfig rlsConfig, String target) { checkNotNull(rlsConfig, "rlsConfig"); @@ -51,18 +49,15 @@ final class RlsRequestFactory { this.keyBuilderTable = createKeyBuilderTable(rlsConfig); } - private static Table createKeyBuilderTable( + private static Map createKeyBuilderTable( RouteLookupConfig config) { - Table table = HashBasedTable.create(); + Map table = new HashMap<>(); for (GrpcKeyBuilder grpcKeyBuilder : config.getGrpcKeyBuilders()) { - for (NameMatcher nameMatcher : grpcKeyBuilder.getHeaders()) { - for (Name name : grpcKeyBuilder.getNames()) { - String method = - name.getMethod() == null || name.getMethod().isEmpty() - ? "*" : name.getMethod(); - String path = "/" + name.getService() + "/" + method; - table.put(path, nameMatcher.getKey(), nameMatcher); - } + for (Name name : grpcKeyBuilder.getNames()) { + boolean hasMethod = name.getMethod() == null || name.getMethod().isEmpty(); + String method = hasMethod ? "*" : name.getMethod(); + String path = "/" + name.getService() + "/" + method; + table.put(path, grpcKeyBuilder); } } return table; @@ -74,20 +69,35 @@ RouteLookupRequest create(String service, String method, Metadata metadata) { checkNotNull(service, "service"); checkNotNull(method, "method"); String path = "/" + service + "/" + method; - Map keyBuilder = keyBuilderTable.row(path); - // if no matching keyBuilder found, fall back to wildcard match (ServiceName/*) - if (keyBuilder.isEmpty()) { - keyBuilder = keyBuilderTable.row("/" + service + "/*"); + GrpcKeyBuilder grpcKeyBuilder = keyBuilderTable.get(path); + if (grpcKeyBuilder == null) { + // if no matching keyBuilder found, fall back to wildcard match (ServiceName/*) + grpcKeyBuilder = keyBuilderTable.get("/" + service + "/*"); + } + if (grpcKeyBuilder == null) { + return new RouteLookupRequest(ImmutableMap.of()); + } + Map rlsRequestHeaders = + createRequestHeaders(metadata, grpcKeyBuilder.getHeaders()); + ExtraKeys extraKeys = grpcKeyBuilder.getExtraKeys(); + Map constantKeys = grpcKeyBuilder.getConstantKeys(); + if (extraKeys.host() != null) { + rlsRequestHeaders.put(extraKeys.host(), target); + } + if (extraKeys.service() != null) { + rlsRequestHeaders.put(extraKeys.service(), service); + } + if (extraKeys.method() != null) { + rlsRequestHeaders.put(extraKeys.method(), method); } - Map rlsRequestHeaders = createRequestHeaders(metadata, keyBuilder); - return new RouteLookupRequest(target, path, "grpc", rlsRequestHeaders); + rlsRequestHeaders.putAll(constantKeys); + return new RouteLookupRequest(rlsRequestHeaders); } private Map createRequestHeaders( - Metadata metadata, Map keyBuilder) { + Metadata metadata, List keyBuilder) { Map rlsRequestHeaders = new HashMap<>(); - for (Map.Entry entry : keyBuilder.entrySet()) { - NameMatcher nameMatcher = entry.getValue(); + for (NameMatcher nameMatcher : keyBuilder) { String value = null; for (String requestHeaderName : nameMatcher.names()) { value = metadata.get(Metadata.Key.of(requestHeaderName, Metadata.ASCII_STRING_MARSHALLER)); @@ -96,11 +106,11 @@ private Map createRequestHeaders( } } if (value != null) { - rlsRequestHeaders.put(entry.getKey(), value); + rlsRequestHeaders.put(nameMatcher.getKey(), value); } else if (!nameMatcher.isOptional()) { throw new StatusRuntimeException( Status.INVALID_ARGUMENT.withDescription( - String.format("Missing mandatory metadata(%s) not found", entry.getKey()))); + String.format("Missing mandatory metadata(%s) not found", nameMatcher.getKey()))); } } return rlsRequestHeaders; diff --git a/rls/src/test/java/io/grpc/rls/CachingRlsLbClientTest.java b/rls/src/test/java/io/grpc/rls/CachingRlsLbClientTest.java index c2c221800ab..aa64ec890b6 100644 --- a/rls/src/test/java/io/grpc/rls/CachingRlsLbClientTest.java +++ b/rls/src/test/java/io/grpc/rls/CachingRlsLbClientTest.java @@ -65,6 +65,7 @@ import io.grpc.rls.LruCache.EvictionListener; import io.grpc.rls.LruCache.EvictionType; import io.grpc.rls.RlsProtoConverters.RouteLookupResponseConverter; +import io.grpc.rls.RlsProtoData.ExtraKeys; import io.grpc.rls.RlsProtoData.GrpcKeyBuilder; import io.grpc.rls.RlsProtoData.GrpcKeyBuilder.Name; import io.grpc.rls.RlsProtoData.NameMatcher; @@ -191,9 +192,8 @@ public void run() { public void get_noError_lifeCycle() throws Exception { setUpRlsLbClient(); InOrder inOrder = inOrder(evictionListener); - RouteLookupRequest routeLookupRequest = - new RouteLookupRequest( - "bigtable.googleapis.com", "/foo/bar", "grpc", ImmutableMap.of()); + RouteLookupRequest routeLookupRequest = new RouteLookupRequest(ImmutableMap.of( + "server", "bigtable.googleapis.com", "service-key", "foo", "method-key", "bar")); rlsServerImpl.setLookupTable( ImmutableMap.of( routeLookupRequest, @@ -242,9 +242,8 @@ public void get_noError_lifeCycle() throws Exception { public void rls_overDirectPath() throws Exception { CachingRlsLbClient.enableOobChannelDirectPath = true; setUpRlsLbClient(); - RouteLookupRequest routeLookupRequest = - new RouteLookupRequest( - "bigtable.googleapis.com", "/foo/bar", "grpc", ImmutableMap.of()); + RouteLookupRequest routeLookupRequest = new RouteLookupRequest(ImmutableMap.of( + "server", "bigtable.googleapis.com", "service-key", "foo", "method-key", "bar")); rlsServerImpl.setLookupTable( ImmutableMap.of( routeLookupRequest, @@ -276,8 +275,8 @@ public void rls_overDirectPath() throws Exception { @Test public void get_throttledAndRecover() throws Exception { setUpRlsLbClient(); - RouteLookupRequest routeLookupRequest = - new RouteLookupRequest("server", "/foo/bar", "grpc", ImmutableMap.of()); + RouteLookupRequest routeLookupRequest = new RouteLookupRequest(ImmutableMap.of( + "server", "bigtable.googleapis.com", "service-key", "foo", "method-key", "bar")); rlsServerImpl.setLookupTable( ImmutableMap.of( routeLookupRequest, @@ -319,9 +318,8 @@ public void get_throttledAndRecover() throws Exception { public void get_updatesLbState() throws Exception { setUpRlsLbClient(); InOrder inOrder = inOrder(helper); - RouteLookupRequest routeLookupRequest = - new RouteLookupRequest( - "bigtable.googleapis.com", "/foo/bar", "grpc", ImmutableMap.of()); + RouteLookupRequest routeLookupRequest = new RouteLookupRequest(ImmutableMap.of( + "server", "bigtable.googleapis.com", "service-key", "service1", "method-key", "create")); rlsServerImpl.setLookupTable( ImmutableMap.of( routeLookupRequest, @@ -349,7 +347,8 @@ public void get_updatesLbState() throws Exception { Metadata headers = new Metadata(); PickResult pickResult = pickerCaptor.getValue().pickSubchannel( new PickSubchannelArgsImpl( - TestMethodDescriptors.voidMethod().toBuilder().setFullMethodName("foo/bar").build(), + TestMethodDescriptors.voidMethod().toBuilder().setFullMethodName("service1/create") + .build(), headers, CallOptions.DEFAULT)); assertThat(pickResult.getStatus().isOk()).isTrue(); @@ -360,8 +359,7 @@ public void get_updatesLbState() throws Exception { fakeBackoffProvider.nextPolicy = createBackoffPolicy(100, TimeUnit.MILLISECONDS); // try to get invalid RouteLookupRequest invalidRouteLookupRequest = - new RouteLookupRequest( - "bigtable.googleapis.com", "/doesn/exists", "grpc", ImmutableMap.of()); + new RouteLookupRequest(ImmutableMap.of()); CachedRouteLookupResponse errorResp = getInSyncContext(invalidRouteLookupRequest); assertThat(errorResp.isPending()).isTrue(); fakeTimeProvider.forwardTime(SERVER_LATENCY_MILLIS, TimeUnit.MILLISECONDS); @@ -369,7 +367,7 @@ public void get_updatesLbState() throws Exception { errorResp = getInSyncContext(invalidRouteLookupRequest); assertThat(errorResp.hasError()).isTrue(); - // Channel is still READY because the subchannel for method /foo/bar is still READY. + // Channel is still READY because the subchannel for method /service1/create is still READY. // Method /doesn/exists will use fallback child balancer and fail immediately. inOrder.verify(helper) .updateBalancingState(eq(ConnectivityState.READY), pickerCaptor.capture()); @@ -387,10 +385,10 @@ public void get_updatesLbState() throws Exception { @Test public void get_childPolicyWrapper_reusedForSameTarget() throws Exception { setUpRlsLbClient(); - RouteLookupRequest routeLookupRequest = - new RouteLookupRequest("server", "/foo/bar", "grpc", ImmutableMap.of()); - RouteLookupRequest routeLookupRequest2 = - new RouteLookupRequest("server", "/foo/baz", "grpc", ImmutableMap.of()); + RouteLookupRequest routeLookupRequest = new RouteLookupRequest(ImmutableMap.of( + "server", "bigtable.googleapis.com", "service-key", "foo", "method-key", "bar")); + RouteLookupRequest routeLookupRequest2 = new RouteLookupRequest(ImmutableMap.of( + "server", "bigtable.googleapis.com", "service-key", "foo", "method-key", "baz")); rlsServerImpl.setLookupTable( ImmutableMap.of( routeLookupRequest, new RouteLookupResponse(ImmutableList.of("target"), "header"), @@ -426,7 +424,9 @@ private static RouteLookupConfig getRouteLookupConfig() { ImmutableList.of(new Name("service1", "create")), ImmutableList.of( new NameMatcher("user", ImmutableList.of("User", "Parent"), true), - new NameMatcher("id", ImmutableList.of("X-Google-Id"), true)))), + new NameMatcher("id", ImmutableList.of("X-Google-Id"), true)), + ExtraKeys.create("server", "service-key", "method-key"), + ImmutableMap.of())), /* lookupService= */ "service1", /* lookupServiceTimeoutInMillis= */ TimeUnit.SECONDS.toMillis(2), /* maxAgeInMillis= */ TimeUnit.SECONDS.toMillis(300), diff --git a/rls/src/test/java/io/grpc/rls/RlsLoadBalancerTest.java b/rls/src/test/java/io/grpc/rls/RlsLoadBalancerTest.java index 3b7e84bd543..fba295f98f0 100644 --- a/rls/src/test/java/io/grpc/rls/RlsLoadBalancerTest.java +++ b/rls/src/test/java/io/grpc/rls/RlsLoadBalancerTest.java @@ -144,13 +144,15 @@ public void setUp() throws Exception { .build(); fakeRlsServerImpl.setLookupTable( ImmutableMap.of( - new RouteLookupRequest( - "fake-bigtable.googleapis.com", "/com.google/Search", "grpc", - ImmutableMap.of()), + new RouteLookupRequest(ImmutableMap.of( + "server", "fake-bigtable.googleapis.com", + "service-key", "com.google", + "method-key", "Search")), new RouteLookupResponse(ImmutableList.of("wilderness"), "where are you?"), - new RouteLookupRequest( - "fake-bigtable.googleapis.com", "/com.google/Rescue", "grpc", - ImmutableMap.of()), + new RouteLookupRequest(ImmutableMap.of( + "server", "fake-bigtable.googleapis.com", + "service-key", "com.google", + "method-key", "Rescue")), new RouteLookupResponse(ImmutableList.of("civilization"), "you are safe"))); rlsLb = (RlsLoadBalancer) provider.newLoadBalancer(helper); @@ -409,7 +411,12 @@ private String getRlsConfigJsonStr() { + " \"names\": [\"PermitId\"],\n" + " \"optional\": true\n" + " }\n" - + " ]\n" + + " ],\n" + + " \"extraKeys\": {\n" + + " \"host\": \"server\",\n" + + " \"service\": \"service-key\",\n" + + " \"method\": \"method-key\"\n" + + " }\n" + " }\n" + " ],\n" + " \"lookupService\": \"localhost:8972\",\n" diff --git a/rls/src/test/java/io/grpc/rls/RlsProtoConvertersTest.java b/rls/src/test/java/io/grpc/rls/RlsProtoConvertersTest.java index a50cdeb9f68..bfb331e6cc8 100644 --- a/rls/src/test/java/io/grpc/rls/RlsProtoConvertersTest.java +++ b/rls/src/test/java/io/grpc/rls/RlsProtoConvertersTest.java @@ -27,6 +27,7 @@ import io.grpc.rls.RlsProtoConverters.RouteLookupConfigConverter; import io.grpc.rls.RlsProtoConverters.RouteLookupRequestConverter; import io.grpc.rls.RlsProtoConverters.RouteLookupResponseConverter; +import io.grpc.rls.RlsProtoData.ExtraKeys; import io.grpc.rls.RlsProtoData.GrpcKeyBuilder; import io.grpc.rls.RlsProtoData.GrpcKeyBuilder.Name; import io.grpc.rls.RlsProtoData.NameMatcher; @@ -41,40 +42,29 @@ @RunWith(JUnit4.class) public class RlsProtoConvertersTest { - @SuppressWarnings("deprecation") @Test public void convert_toRequestProto() { Converter converter = new RouteLookupRequestConverter(); RouteLookupRequest proto = RouteLookupRequest.newBuilder() - .setServer("server") - .setPath("path") - .setTargetType("target") .putKeyMap("key1", "val1") .build(); RlsProtoData.RouteLookupRequest object = converter.convert(proto); - assertThat(object.getServer()).isEqualTo("server"); - assertThat(object.getPath()).isEqualTo("path"); - assertThat(object.getTargetType()).isEqualTo("target"); assertThat(object.getKeyMap()).containsExactly("key1", "val1"); } - @SuppressWarnings("deprecation") @Test public void convert_toRequestObject() { Converter converter = new RouteLookupRequestConverter().reverse(); RlsProtoData.RouteLookupRequest requestObject = - new RlsProtoData.RouteLookupRequest( - "server", "path", "target", ImmutableMap.of("key1", "val1")); + new RlsProtoData.RouteLookupRequest(ImmutableMap.of("key1", "val1")); RouteLookupRequest proto = converter.convert(requestObject); - assertThat(proto.getServer()).isEqualTo("server"); - assertThat(proto.getPath()).isEqualTo("path"); - assertThat(proto.getTargetType()).isEqualTo("target"); + assertThat(proto.getTargetType()).isEqualTo("grpc"); assertThat(proto.getKeyMapMap()).containsExactly("key1", "val1"); } @@ -164,7 +154,15 @@ public void convert_jsonRlsConfig() throws IOException { + " \"names\": [\"User\", \"Parent\"],\n" + " \"optional\": true\n" + " }\n" - + " ]\n" + + " ],\n" + + " \"extraKeys\": {\n" + + " \"host\": \"host-key\",\n" + + " \"service\": \"service-key\",\n" + + " \"method\": \"method-key\"\n" + + " }, \n" + + " \"constantKeys\": {\n" + + " \"constKey1\": \"value1\"\n" + + " }\n" + " }\n" + " ],\n" + " \"lookupService\": \"service1\",\n" @@ -183,16 +181,22 @@ public void convert_jsonRlsConfig() throws IOException { ImmutableList.of(new Name("service1", "create")), ImmutableList.of( new NameMatcher("user", ImmutableList.of("User", "Parent"), true), - new NameMatcher("id", ImmutableList.of("X-Google-Id"), true))), + new NameMatcher("id", ImmutableList.of("X-Google-Id"), true)), + ExtraKeys.DEFAULT, + ImmutableMap.of()), new GrpcKeyBuilder( ImmutableList.of(new Name("service1")), ImmutableList.of( new NameMatcher("user", ImmutableList.of("User", "Parent"), true), - new NameMatcher("password", ImmutableList.of("Password"), true))), + new NameMatcher("password", ImmutableList.of("Password"), true)), + ExtraKeys.DEFAULT, + ImmutableMap.of()), new GrpcKeyBuilder( ImmutableList.of(new Name("service3")), ImmutableList.of( - new NameMatcher("user", ImmutableList.of("User", "Parent"), true)))), + new NameMatcher("user", ImmutableList.of("User", "Parent"), true)), + ExtraKeys.create("host-key", "service-key", "method-key"), + ImmutableMap.of("constKey1", "value1"))), /* lookupService= */ "service1", /* lookupServiceTimeoutInMillis= */ TimeUnit.SECONDS.toMillis(2), /* maxAgeInMillis= */ TimeUnit.SECONDS.toMillis(300), diff --git a/rls/src/test/java/io/grpc/rls/RlsRequestFactoryTest.java b/rls/src/test/java/io/grpc/rls/RlsRequestFactoryTest.java index 8661e023346..b0d197ff525 100644 --- a/rls/src/test/java/io/grpc/rls/RlsRequestFactoryTest.java +++ b/rls/src/test/java/io/grpc/rls/RlsRequestFactoryTest.java @@ -20,9 +20,11 @@ import static org.junit.Assert.fail; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import io.grpc.Metadata; import io.grpc.Status.Code; import io.grpc.StatusRuntimeException; +import io.grpc.rls.RlsProtoData.ExtraKeys; import io.grpc.rls.RlsProtoData.GrpcKeyBuilder; import io.grpc.rls.RlsProtoData.GrpcKeyBuilder.Name; import io.grpc.rls.RlsProtoData.NameMatcher; @@ -43,21 +45,29 @@ public class RlsRequestFactoryTest { ImmutableList.of(new Name("com.google.service1", "Create")), ImmutableList.of( new NameMatcher("user", ImmutableList.of("User", "Parent"), true), - new NameMatcher("id", ImmutableList.of("X-Google-Id"), true))), + new NameMatcher("id", ImmutableList.of("X-Google-Id"), true)), + ExtraKeys.create("server-1", null, null), + ImmutableMap.of("const-key-1", "const-value-1")), new GrpcKeyBuilder( ImmutableList.of(new Name("com.google.service1")), ImmutableList.of( new NameMatcher("user", ImmutableList.of("User", "Parent"), true), - new NameMatcher("password", ImmutableList.of("Password"), true))), + new NameMatcher("password", ImmutableList.of("Password"), true)), + ExtraKeys.create(null, "service-2", null), + ImmutableMap.of("const-key-2", "const-value-2")), new GrpcKeyBuilder( ImmutableList.of(new Name("com.google.service2")), ImmutableList.of( new NameMatcher("user", ImmutableList.of("User", "Parent"), false), - new NameMatcher("password", ImmutableList.of("Password"), true))), + new NameMatcher("password", ImmutableList.of("Password"), true)), + ExtraKeys.create(null, "service-3", "method-3"), + ImmutableMap.of()), new GrpcKeyBuilder( ImmutableList.of(new Name("com.google.service3")), ImmutableList.of( - new NameMatcher("user", ImmutableList.of("User", "Parent"), true)))), + new NameMatcher("user", ImmutableList.of("User", "Parent"), true)), + ExtraKeys.create(null, null, null), + ImmutableMap.of("const-key-4", "const-value-4"))), /* lookupService= */ "bigtable-rls.googleapis.com", /* lookupServiceTimeoutInMillis= */ TimeUnit.SECONDS.toMillis(2), /* maxAgeInMillis= */ TimeUnit.SECONDS.toMillis(300), @@ -77,10 +87,11 @@ public void create_pathMatches() { metadata.put(Metadata.Key.of("foo", Metadata.ASCII_STRING_MARSHALLER), "bar"); RouteLookupRequest request = factory.create("com.google.service1", "Create", metadata); - assertThat(request.getTargetType()).isEqualTo("grpc"); - assertThat(request.getPath()).isEqualTo("/com.google.service1/Create"); - assertThat(request.getServer()).isEqualTo("bigtable.googleapis.com"); - assertThat(request.getKeyMap()).containsExactly("user", "test", "id", "123"); + assertThat(request.getKeyMap()).containsExactly( + "user", "test", + "id", "123", + "server-1", "bigtable.googleapis.com", + "const-key-1", "const-value-1"); } @Test @@ -106,10 +117,11 @@ public void create_pathFallbackMatches() { RouteLookupRequest request = factory.create("com.google.service1" , "Update", metadata); - assertThat(request.getTargetType()).isEqualTo("grpc"); - assertThat(request.getPath()).isEqualTo("/com.google.service1/Update"); - assertThat(request.getServer()).isEqualTo("bigtable.googleapis.com"); - assertThat(request.getKeyMap()).containsExactly("user", "test", "password", "hunter2"); + assertThat(request.getKeyMap()).containsExactly( + "user", "test", + "password", "hunter2", + "service-2", "com.google.service1", + "const-key-2", "const-value-2"); } @Test @@ -121,10 +133,10 @@ public void create_pathFallbackMatches_optionalHeaderMissing() { RouteLookupRequest request = factory.create("com.google.service1", "Update", metadata); - assertThat(request.getTargetType()).isEqualTo("grpc"); - assertThat(request.getPath()).isEqualTo("/com.google.service1/Update"); - assertThat(request.getServer()).isEqualTo("bigtable.googleapis.com"); - assertThat(request.getKeyMap()).containsExactly("user", "test"); + assertThat(request.getKeyMap()).containsExactly( + "user", "test", + "service-2", "com.google.service1", + "const-key-2", "const-value-2"); } @Test @@ -135,10 +147,6 @@ public void create_unknownPath() { metadata.put(Metadata.Key.of("foo", Metadata.ASCII_STRING_MARSHALLER), "bar"); RouteLookupRequest request = factory.create("abc.def.service999", "Update", metadata); - - assertThat(request.getTargetType()).isEqualTo("grpc"); - assertThat(request.getPath()).isEqualTo("/abc.def.service999/Update"); - assertThat(request.getServer()).isEqualTo("bigtable.googleapis.com"); assertThat(request.getKeyMap()).isEmpty(); } @@ -151,9 +159,7 @@ public void create_noMethodInRlsConfig() { RouteLookupRequest request = factory.create("com.google.service3", "Update", metadata); - assertThat(request.getTargetType()).isEqualTo("grpc"); - assertThat(request.getPath()).isEqualTo("/com.google.service3/Update"); - assertThat(request.getServer()).isEqualTo("bigtable.googleapis.com"); - assertThat(request.getKeyMap()).containsExactly("user", "test"); + assertThat(request.getKeyMap()).containsExactly( + "user", "test", "const-key-4", "const-value-4"); } } From d4b96c6b1a0db2c17ff315b3d807663cdd931fb7 Mon Sep 17 00:00:00 2001 From: sanjaypujare Date: Wed, 8 Sep 2021 23:15:02 +0000 Subject: [PATCH 68/82] xds: remove hashCode() and equals() for SslContextProviderSupplier (#8497) --- .../sds/SslContextProviderSupplier.java | 21 ------------------- 1 file changed, 21 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProviderSupplier.java b/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProviderSupplier.java index 3300c22b2bf..17fc442e7da 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProviderSupplier.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProviderSupplier.java @@ -25,7 +25,6 @@ import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; import io.grpc.xds.TlsContextManager; import io.netty.handler.ssl.SslContext; -import java.util.Objects; /** * Enables Client or server side to initialize this object with the received {@link BaseTlsContext} @@ -119,26 +118,6 @@ public synchronized void close() { shutdown = true; } - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - SslContextProviderSupplier that = (SslContextProviderSupplier) o; - return shutdown == that.shutdown - && Objects.equals(tlsContext, that.tlsContext) - && Objects.equals(tlsContextManager, that.tlsContextManager) - && Objects.equals(sslContextProvider, that.sslContextProvider); - } - - @Override - public int hashCode() { - return Objects.hash(tlsContext, tlsContextManager, sslContextProvider, shutdown); - } - @Override public String toString() { return MoreObjects.toStringHelper(this) From cba012ef121e2c5976daf5ff5c9b196e618ce19a Mon Sep 17 00:00:00 2001 From: sanjaypujare Date: Wed, 8 Sep 2021 23:43:17 +0000 Subject: [PATCH 69/82] xds: use the new cert-provider instances if present (#8494) (#8498) --- .../java/io/grpc/xds/ClientXdsClient.java | 95 +++++++++++-------- .../CertProviderClientSslContextProvider.java | 28 ++---- .../CertProviderServerSslContextProvider.java | 28 ++---- .../CertProviderSslContextProvider.java | 48 ++++++++++ .../internal/sds/CommonTlsContextUtil.java | 14 ++- .../io/grpc/xds/ClientXdsClientDataTest.java | 32 +++++-- .../io/grpc/xds/ClientXdsClientTestBase.java | 41 +++++++- .../io/grpc/xds/ClientXdsClientV2Test.java | 6 ++ .../io/grpc/xds/ClientXdsClientV3Test.java | 16 ++++ ...tProviderClientSslContextProviderTest.java | 84 ++++++++++++++++ ...tProviderServerSslContextProviderTest.java | 94 ++++++++++++++++++ .../sds/CommonTlsContextTestsUtil.java | 83 ++++++++++++++++ 12 files changed, 474 insertions(+), 95 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/ClientXdsClient.java b/xds/src/main/java/io/grpc/xds/ClientXdsClient.java index 83845515978..21cf78b1269 100644 --- a/xds/src/main/java/io/grpc/xds/ClientXdsClient.java +++ b/xds/src/main/java/io/grpc/xds/ClientXdsClient.java @@ -456,10 +456,6 @@ static void validateCommonTlsContext( if (commonTlsContext.hasTlsParams()) { throw new ResourceInvalidException("common-tls-context with tls_params is not supported"); } - if (commonTlsContext.hasValidationContext()) { - throw new ResourceInvalidException( - "common-tls-context with validation_context is not supported"); - } if (commonTlsContext.hasValidationContextSdsSecretConfig()) { throw new ResourceInvalidException( "common-tls-context with validation_context_sds_secret_config is not supported"); @@ -473,54 +469,50 @@ static void validateCommonTlsContext( "common-tls-context with validation_context_certificate_provider_instance is not" + " supported"); } - String certInstanceName = null; - if (!commonTlsContext.hasTlsCertificateCertificateProviderInstance()) { + String certInstanceName = getIdentityCertInstanceName(commonTlsContext); + if (certInstanceName == null) { if (server) { throw new ResourceInvalidException( - "tls_certificate_certificate_provider_instance is required in downstream-tls-context"); + "tls_certificate_provider_instance is required in downstream-tls-context"); } if (commonTlsContext.getTlsCertificatesCount() > 0) { throw new ResourceInvalidException( - "common-tls-context with tls_certificates is not supported"); + "tls_certificate_provider_instance is unset"); } if (commonTlsContext.getTlsCertificateSdsSecretConfigsCount() > 0) { throw new ResourceInvalidException( - "common-tls-context with tls_certificate_sds_secret_configs is not supported"); + "tls_certificate_provider_instance is unset"); } if (commonTlsContext.hasTlsCertificateCertificateProvider()) { throw new ResourceInvalidException( - "common-tls-context with tls_certificate_certificate_provider is not supported"); - } - } else { - certInstanceName = commonTlsContext.getTlsCertificateCertificateProviderInstance() - .getInstanceName(); - } - if (certInstanceName != null) { - if (certProviderInstances == null || !certProviderInstances.contains(certInstanceName)) { - throw new ResourceInvalidException( - "CertificateProvider instance name '" + certInstanceName - + "' not defined in the bootstrap file."); + "tls_certificate_provider_instance is unset"); } + } else if (certProviderInstances == null || !certProviderInstances.contains(certInstanceName)) { + throw new ResourceInvalidException( + "CertificateProvider instance name '" + certInstanceName + + "' not defined in the bootstrap file."); } - String rootCaInstanceName = null; - if (!commonTlsContext.hasCombinedValidationContext()) { + String rootCaInstanceName = getRootCertInstanceName(commonTlsContext); + if (rootCaInstanceName == null) { if (!server) { throw new ResourceInvalidException( - "combined_validation_context is required in upstream-tls-context"); + "ca_certificate_provider_instance is required in upstream-tls-context"); } } else { - CommonTlsContext.CombinedCertificateValidationContext combinedCertificateValidationContext - = commonTlsContext.getCombinedValidationContext(); - if (!combinedCertificateValidationContext.hasValidationContextCertificateProviderInstance()) { + if (certProviderInstances == null || !certProviderInstances.contains(rootCaInstanceName)) { throw new ResourceInvalidException( - "validation_context_certificate_provider_instance is required in" - + " combined_validation_context"); - } - rootCaInstanceName = combinedCertificateValidationContext - .getValidationContextCertificateProviderInstance().getInstanceName(); - if (combinedCertificateValidationContext.hasDefaultValidationContext()) { - CertificateValidationContext certificateValidationContext - = combinedCertificateValidationContext.getDefaultValidationContext(); + "ca_certificate_provider_instance name '" + rootCaInstanceName + + "' not defined in the bootstrap file."); + } + CertificateValidationContext certificateValidationContext = null; + if (commonTlsContext.hasValidationContext()) { + certificateValidationContext = commonTlsContext.getValidationContext(); + } else if (commonTlsContext.hasCombinedValidationContext() && commonTlsContext + .getCombinedValidationContext().hasDefaultValidationContext()) { + certificateValidationContext = commonTlsContext.getCombinedValidationContext() + .getDefaultValidationContext(); + } + if (certificateValidationContext != null) { if (certificateValidationContext.getMatchSubjectAltNamesCount() > 0 && server) { throw new ResourceInvalidException( "match_subject_alt_names only allowed in upstream_tls_context"); @@ -547,13 +539,38 @@ static void validateCommonTlsContext( } } } - if (rootCaInstanceName != null) { - if (certProviderInstances == null || !certProviderInstances.contains(rootCaInstanceName)) { - throw new ResourceInvalidException( - "ValidationContextProvider instance name '" + rootCaInstanceName - + "' not defined in the bootstrap file."); + } + + private static String getIdentityCertInstanceName(CommonTlsContext commonTlsContext) { + if (commonTlsContext.hasTlsCertificateProviderInstance()) { + return commonTlsContext.getTlsCertificateProviderInstance().getInstanceName(); + } else if (commonTlsContext.hasTlsCertificateCertificateProviderInstance()) { + return commonTlsContext.getTlsCertificateCertificateProviderInstance().getInstanceName(); + } + return null; + } + + private static String getRootCertInstanceName(CommonTlsContext commonTlsContext) { + if (commonTlsContext.hasValidationContext()) { + if (commonTlsContext.getValidationContext().hasCaCertificateProviderInstance()) { + return commonTlsContext.getValidationContext().getCaCertificateProviderInstance() + .getInstanceName(); + } + } else if (commonTlsContext.hasCombinedValidationContext()) { + CommonTlsContext.CombinedCertificateValidationContext combinedCertificateValidationContext + = commonTlsContext.getCombinedValidationContext(); + if (combinedCertificateValidationContext.hasDefaultValidationContext() + && combinedCertificateValidationContext.getDefaultValidationContext() + .hasCaCertificateProviderInstance()) { + return combinedCertificateValidationContext.getDefaultValidationContext() + .getCaCertificateProviderInstance().getInstanceName(); + } else if (combinedCertificateValidationContext + .hasValidationContextCertificateProviderInstance()) { + return combinedCertificateValidationContext + .getValidationContextCertificateProviderInstance().getInstanceName(); } } + return null; } private static void checkForUniqueness(Set uniqueSet, diff --git a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderClientSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderClientSslContextProvider.java index 2ee21e7db6a..ce9ef3de680 100644 --- a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderClientSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderClientSslContextProvider.java @@ -22,7 +22,6 @@ import io.envoyproxy.envoy.config.core.v3.Node; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; -import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.CombinedCertificateValidationContext; import io.grpc.Internal; import io.grpc.netty.GrpcSslContexts; import io.grpc.xds.Bootstrapper.CertificateProviderInfo; @@ -94,27 +93,12 @@ public CertProviderClientSslContextProvider getProvider( @Nullable Map certProviders) { checkNotNull(upstreamTlsContext, "upstreamTlsContext"); CommonTlsContext commonTlsContext = upstreamTlsContext.getCommonTlsContext(); - CommonTlsContext.CertificateProviderInstance rootCertInstance = null; - CertificateValidationContext staticCertValidationContext = null; - if (commonTlsContext.hasCombinedValidationContext()) { - CombinedCertificateValidationContext combinedValidationContext = - commonTlsContext.getCombinedValidationContext(); - if (combinedValidationContext.hasValidationContextCertificateProviderInstance()) { - rootCertInstance = - combinedValidationContext.getValidationContextCertificateProviderInstance(); - } - if (combinedValidationContext.hasDefaultValidationContext()) { - staticCertValidationContext = combinedValidationContext.getDefaultValidationContext(); - } - } else if (commonTlsContext.hasValidationContextCertificateProviderInstance()) { - rootCertInstance = commonTlsContext.getValidationContextCertificateProviderInstance(); - } else if (commonTlsContext.hasValidationContext()) { - staticCertValidationContext = commonTlsContext.getValidationContext(); - } - CommonTlsContext.CertificateProviderInstance certInstance = null; - if (commonTlsContext.hasTlsCertificateCertificateProviderInstance()) { - certInstance = commonTlsContext.getTlsCertificateCertificateProviderInstance(); - } + CertificateValidationContext staticCertValidationContext = getStaticValidationContext( + commonTlsContext); + CommonTlsContext.CertificateProviderInstance rootCertInstance = getRootCertProviderInstance( + commonTlsContext); + CommonTlsContext.CertificateProviderInstance certInstance = getCertProviderInstance( + commonTlsContext); return new CertProviderClientSslContextProvider( node, certProviders, diff --git a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderServerSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderServerSslContextProvider.java index 1f33e1de789..a7f0849d00b 100644 --- a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderServerSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderServerSslContextProvider.java @@ -22,7 +22,6 @@ import io.envoyproxy.envoy.config.core.v3.Node; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; -import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.CombinedCertificateValidationContext; import io.grpc.Internal; import io.grpc.netty.GrpcSslContexts; import io.grpc.xds.Bootstrapper.CertificateProviderInfo; @@ -97,27 +96,12 @@ public CertProviderServerSslContextProvider getProvider( @Nullable Map certProviders) { checkNotNull(downstreamTlsContext, "downstreamTlsContext"); CommonTlsContext commonTlsContext = downstreamTlsContext.getCommonTlsContext(); - CommonTlsContext.CertificateProviderInstance rootCertInstance = null; - CertificateValidationContext staticCertValidationContext = null; - if (commonTlsContext.hasCombinedValidationContext()) { - CombinedCertificateValidationContext combinedValidationContext = - commonTlsContext.getCombinedValidationContext(); - if (combinedValidationContext.hasValidationContextCertificateProviderInstance()) { - rootCertInstance = - combinedValidationContext.getValidationContextCertificateProviderInstance(); - } - if (combinedValidationContext.hasDefaultValidationContext()) { - staticCertValidationContext = combinedValidationContext.getDefaultValidationContext(); - } - } else if (commonTlsContext.hasValidationContextCertificateProviderInstance()) { - rootCertInstance = commonTlsContext.getValidationContextCertificateProviderInstance(); - } else if (commonTlsContext.hasValidationContext()) { - staticCertValidationContext = commonTlsContext.getValidationContext(); - } - CommonTlsContext.CertificateProviderInstance certInstance = null; - if (commonTlsContext.hasTlsCertificateCertificateProviderInstance()) { - certInstance = commonTlsContext.getTlsCertificateCertificateProviderInstance(); - } + CertificateValidationContext staticCertValidationContext = getStaticValidationContext( + commonTlsContext); + CommonTlsContext.CertificateProviderInstance rootCertInstance = getRootCertProviderInstance( + commonTlsContext); + CommonTlsContext.CertificateProviderInstance certInstance = getCertProviderInstance( + commonTlsContext); return new CertProviderServerSslContextProvider( node, certProviders, diff --git a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderSslContextProvider.java index 1af9e1670d3..1ec58764196 100644 --- a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderSslContextProvider.java @@ -18,9 +18,11 @@ import io.envoyproxy.envoy.config.core.v3.Node; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; +import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.CertificateProviderInstance; import io.grpc.xds.Bootstrapper.CertificateProviderInfo; import io.grpc.xds.EnvoyServerProtoData.BaseTlsContext; +import io.grpc.xds.internal.sds.CommonTlsContextUtil; import io.grpc.xds.internal.sds.DynamicSslContextProvider; import java.security.PrivateKey; import java.security.cert.X509Certificate; @@ -88,6 +90,52 @@ private static CertificateProviderInfo getCertProviderConfig( return certProviders != null ? certProviders.get(pluginInstanceName) : null; } + @Nullable + protected static CertificateProviderInstance getCertProviderInstance( + CommonTlsContext commonTlsContext) { + if (commonTlsContext.hasTlsCertificateProviderInstance()) { + return CommonTlsContextUtil.convert(commonTlsContext.getTlsCertificateProviderInstance()); + } else if (commonTlsContext.hasTlsCertificateCertificateProviderInstance()) { + return commonTlsContext.getTlsCertificateCertificateProviderInstance(); + } + return null; + } + + @Nullable + protected static CertificateValidationContext getStaticValidationContext( + CommonTlsContext commonTlsContext) { + if (commonTlsContext.hasValidationContext()) { + return commonTlsContext.getValidationContext(); + } else if (commonTlsContext.hasCombinedValidationContext()) { + CommonTlsContext.CombinedCertificateValidationContext combinedValidationContext = + commonTlsContext.getCombinedValidationContext(); + if (combinedValidationContext.hasDefaultValidationContext()) { + return combinedValidationContext.getDefaultValidationContext(); + } + } + return null; + } + + @Nullable + protected static CommonTlsContext.CertificateProviderInstance getRootCertProviderInstance( + CommonTlsContext commonTlsContext) { + CertificateValidationContext certValidationContext = getStaticValidationContext( + commonTlsContext); + if (certValidationContext != null && certValidationContext.hasCaCertificateProviderInstance()) { + return CommonTlsContextUtil.convert(certValidationContext.getCaCertificateProviderInstance()); + } + if (commonTlsContext.hasCombinedValidationContext()) { + CommonTlsContext.CombinedCertificateValidationContext combinedValidationContext = + commonTlsContext.getCombinedValidationContext(); + if (combinedValidationContext.hasValidationContextCertificateProviderInstance()) { + return combinedValidationContext.getValidationContextCertificateProviderInstance(); + } + } else if (commonTlsContext.hasValidationContextCertificateProviderInstance()) { + return commonTlsContext.getValidationContextCertificateProviderInstance(); + } + return null; + } + @Override public final void updateCertificate(PrivateKey key, List certChain) { savedKey = key; diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/CommonTlsContextUtil.java b/xds/src/main/java/io/grpc/xds/internal/sds/CommonTlsContextUtil.java index 234989ad115..0c28c79ee22 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/CommonTlsContextUtil.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/CommonTlsContextUtil.java @@ -16,11 +16,12 @@ package io.grpc.xds.internal.sds; +import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateProviderPluginInstance; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.CombinedCertificateValidationContext; /** Class for utility functions for {@link CommonTlsContext}. */ -final class CommonTlsContextUtil { +public final class CommonTlsContextUtil { private CommonTlsContextUtil() {} @@ -38,4 +39,15 @@ private static boolean hasCertProviderValidationContext(CommonTlsContext commonT } return commonTlsContext.hasValidationContextCertificateProviderInstance(); } + + /** + * Converts {@link CertificateProviderPluginInstance} to + * {@link CommonTlsContext.CertificateProviderInstance}. + */ + public static CommonTlsContext.CertificateProviderInstance convert( + CertificateProviderPluginInstance pluginInstance) { + return CommonTlsContext.CertificateProviderInstance.newBuilder() + .setInstanceName(pluginInstance.getInstanceName()) + .setCertificateName(pluginInstance.getCertificateName()).build(); + } } diff --git a/xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java b/xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java index fb5c349f123..80cd2a8046e 100644 --- a/xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java +++ b/xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java @@ -77,6 +77,7 @@ import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager; import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.HttpFilter; import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.Rds; +import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateProviderPluginInstance; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.CertificateProviderInstance; @@ -1551,7 +1552,7 @@ public void validateCommonTlsContext_validationContext() throws ResourceInvalidE .setValidationContext(CertificateValidationContext.getDefaultInstance()) .build(); thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("common-tls-context with validation_context is not supported"); + thrown.expectMessage("ca_certificate_provider_instance is required in upstream-tls-context"); ClientXdsClient.validateCommonTlsContext(commonTlsContext, null, false); } @@ -1603,14 +1604,26 @@ public void validateCommonTlsContext_tlsCertificateProviderInstance_isRequiredFo .build(); thrown.expect(ResourceInvalidException.class); thrown.expectMessage( - "tls_certificate_certificate_provider_instance is required in downstream-tls-context"); + "tls_certificate_provider_instance is required in downstream-tls-context"); ClientXdsClient.validateCommonTlsContext(commonTlsContext, null, true); } + @Test + @SuppressWarnings("deprecation") + public void validateCommonTlsContext_tlsNewCertificateProviderInstance() + throws ResourceInvalidException { + CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() + .setTlsCertificateProviderInstance( + CertificateProviderPluginInstance.newBuilder().setInstanceName("name1").build()) + .build(); + ClientXdsClient + .validateCommonTlsContext(commonTlsContext, ImmutableSet.of("name1", "name2"), true); + } + @Test @SuppressWarnings("deprecation") public void validateCommonTlsContext_tlsCertificateProviderInstance() - throws ResourceInvalidException { + throws ResourceInvalidException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() .setTlsCertificateCertificateProviderInstance( CertificateProviderInstance.newBuilder().setInstanceName("name1").build()) @@ -1662,7 +1675,7 @@ public void validateCommonTlsContext_validationContextProviderInstance_absentInB .build(); thrown.expect(ResourceInvalidException.class); thrown.expectMessage( - "ValidationContextProvider instance name 'bad-name' not defined in the bootstrap file."); + "ca_certificate_provider_instance name 'bad-name' not defined in the bootstrap file."); ClientXdsClient .validateCommonTlsContext(commonTlsContext, ImmutableSet.of("name1", "name2"), false); } @@ -1674,7 +1687,7 @@ public void validateCommonTlsContext_tlsCertificatesCount() throws ResourceInval .addTlsCertificates(TlsCertificate.getDefaultInstance()) .build(); thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("common-tls-context with tls_certificates is not supported"); + thrown.expectMessage("tls_certificate_provider_instance is unset"); ClientXdsClient.validateCommonTlsContext(commonTlsContext, null, false); } @@ -1686,7 +1699,7 @@ public void validateCommonTlsContext_tlsCertificateSdsSecretConfigsCount() .build(); thrown.expect(ResourceInvalidException.class); thrown.expectMessage( - "common-tls-context with tls_certificate_sds_secret_configs is not supported"); + "tls_certificate_provider_instance is unset"); ClientXdsClient.validateCommonTlsContext(commonTlsContext, null, false); } @@ -1700,7 +1713,7 @@ public void validateCommonTlsContext_tlsCertificateCertificateProvider() .build(); thrown.expect(ResourceInvalidException.class); thrown.expectMessage( - "common-tls-context with tls_certificate_certificate_provider is not supported"); + "tls_certificate_provider_instance is unset"); ClientXdsClient.validateCommonTlsContext(commonTlsContext, null, false); } @@ -1710,7 +1723,7 @@ public void validateCommonTlsContext_combinedValidationContext_isRequiredForClie CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() .build(); thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("combined_validation_context is required in upstream-tls-context"); + thrown.expectMessage("ca_certificate_provider_instance is required in upstream-tls-context"); ClientXdsClient.validateCommonTlsContext(commonTlsContext, null, false); } @@ -1723,8 +1736,7 @@ public void validateCommonTlsContext_combinedValidationContextWithoutCertProvide .build(); thrown.expect(ResourceInvalidException.class); thrown.expectMessage( - "validation_context_certificate_provider_instance is required in " - + "combined_validation_context"); + "ca_certificate_provider_instance is required in upstream-tls-context"); ClientXdsClient.validateCommonTlsContext(commonTlsContext, null, false); } diff --git a/xds/src/test/java/io/grpc/xds/ClientXdsClientTestBase.java b/xds/src/test/java/io/grpc/xds/ClientXdsClientTestBase.java index 9e4d92fb344..55bd6ba3e9d 100644 --- a/xds/src/test/java/io/grpc/xds/ClientXdsClientTestBase.java +++ b/xds/src/test/java/io/grpc/xds/ClientXdsClientTestBase.java @@ -39,6 +39,7 @@ import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.Message; import io.envoyproxy.envoy.config.route.v3.FilterConfig; +import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateProviderPluginInstance; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; import io.grpc.BindableService; import io.grpc.Context; @@ -1353,6 +1354,42 @@ public void cdsResponseWithUpstreamTlsContext() { verifySubscribedResourcesMetadataSizes(0, 1, 0, 0); } + /** + * CDS response containing new UpstreamTlsContext for a cluster. + */ + @Test + @SuppressWarnings("deprecation") + public void cdsResponseWithNewUpstreamTlsContext() { + Assume.assumeTrue(useProtocolV3()); + DiscoveryRpcCall call = startResourceWatcher(CDS, CDS_RESOURCE, cdsResourceWatcher); + + // Management server sends back CDS response with UpstreamTlsContext. + Any clusterEds = + Any.pack(mf.buildEdsCluster(CDS_RESOURCE, "eds-cluster-foo.googleapis.com", "round_robin", + null, true, + mf.buildNewUpstreamTlsContext("cert-instance-name", "cert1"), + "envoy.transport_sockets.tls", null)); + List clusters = ImmutableList.of( + Any.pack(mf.buildLogicalDnsCluster("cluster-bar.googleapis.com", + "dns-service-bar.googleapis.com", 443, "round_robin", null, false, null, null)), + clusterEds, + Any.pack(mf.buildEdsCluster("cluster-baz.googleapis.com", null, "round_robin", null, false, + null, "envoy.transport_sockets.tls", null))); + call.sendResponse(CDS, clusters, VERSION_1, "0000"); + + // Client sent an ACK CDS request. + call.verifyRequest(CDS, CDS_RESOURCE, VERSION_1, "0000", NODE); + verify(cdsResourceWatcher, times(1)).onChanged(cdsUpdateCaptor.capture()); + CdsUpdate cdsUpdate = cdsUpdateCaptor.getValue(); + CertificateProviderPluginInstance certificateProviderInstance = + cdsUpdate.upstreamTlsContext().getCommonTlsContext().getValidationContext() + .getCaCertificateProviderInstance(); + assertThat(certificateProviderInstance.getInstanceName()).isEqualTo("cert-instance-name"); + assertThat(certificateProviderInstance.getCertificateName()).isEqualTo("cert1"); + verifyResourceMetadataAcked(CDS, CDS_RESOURCE, clusterEds, VERSION_1, TIME_INCREMENT); + verifySubscribedResourcesMetadataSizes(0, 1, 0, 0); + } + /** * CDS response containing bad UpstreamTlsContext for a cluster. */ @@ -1373,7 +1410,7 @@ public void cdsResponseErrorHandling_badUpstreamTlsContext() { "CDS response Cluster 'cluster.googleapis.com' validation error: " + "Cluster cluster.googleapis.com: malformed UpstreamTlsContext: " + "io.grpc.xds.ClientXdsClient$ResourceInvalidException: " - + "combined_validation_context is required in upstream-tls-context")); + + "ca_certificate_provider_instance is required in upstream-tls-context")); verifyNoInteractions(cdsResourceWatcher); } @@ -2400,6 +2437,8 @@ protected abstract Message buildRingHashLbConfig(String hashFunction, long minRi protected abstract Message buildUpstreamTlsContext(String instanceName, String certName); + protected abstract Message buildNewUpstreamTlsContext(String instanceName, String certName); + protected abstract Message buildCircuitBreakers(int highPriorityMaxRequests, int defaultPriorityMaxRequests); diff --git a/xds/src/test/java/io/grpc/xds/ClientXdsClientV2Test.java b/xds/src/test/java/io/grpc/xds/ClientXdsClientV2Test.java index 409613aecf7..39f5d1a1a2e 100644 --- a/xds/src/test/java/io/grpc/xds/ClientXdsClientV2Test.java +++ b/xds/src/test/java/io/grpc/xds/ClientXdsClientV2Test.java @@ -515,6 +515,12 @@ protected Message buildUpstreamTlsContext(String instanceName, String certName) .build(); } + @Override + protected Message buildNewUpstreamTlsContext(String instanceName, String certName) { + return buildUpstreamTlsContext(instanceName, certName); + } + + @Override protected Message buildCircuitBreakers(int highPriorityMaxRequests, int defaultPriorityMaxRequests) { diff --git a/xds/src/test/java/io/grpc/xds/ClientXdsClientV3Test.java b/xds/src/test/java/io/grpc/xds/ClientXdsClientV3Test.java index eddba1040d4..dfd407ef016 100644 --- a/xds/src/test/java/io/grpc/xds/ClientXdsClientV3Test.java +++ b/xds/src/test/java/io/grpc/xds/ClientXdsClientV3Test.java @@ -77,6 +77,8 @@ import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager; import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.HttpFilter; import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.Rds; +import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateProviderPluginInstance; +import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext; import io.envoyproxy.envoy.service.discovery.v3.AggregatedDiscoveryServiceGrpc.AggregatedDiscoveryServiceImplBase; @@ -555,6 +557,20 @@ protected Message buildUpstreamTlsContext(String instanceName, String certName) .build(); } + @Override + protected Message buildNewUpstreamTlsContext(String instanceName, String certName) { + CommonTlsContext.Builder commonTlsContextBuilder = CommonTlsContext.newBuilder(); + if (instanceName != null && certName != null) { + commonTlsContextBuilder.setValidationContext(CertificateValidationContext.newBuilder() + .setCaCertificateProviderInstance( + CertificateProviderPluginInstance.newBuilder().setInstanceName(instanceName) + .setCertificateName(certName).build())); + } + return UpstreamTlsContext.newBuilder() + .setCommonTlsContext(commonTlsContextBuilder) + .build(); + } + @Override protected Message buildCircuitBreakers(int highPriorityMaxRequests, int defaultPriorityMaxRequests) { diff --git a/xds/src/test/java/io/grpc/xds/internal/certprovider/CertProviderClientSslContextProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/certprovider/CertProviderClientSslContextProviderTest.java index 00b29014648..1eed5488aa0 100644 --- a/xds/src/test/java/io/grpc/xds/internal/certprovider/CertProviderClientSslContextProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/certprovider/CertProviderClientSslContextProviderTest.java @@ -87,6 +87,27 @@ private CertProviderClientSslContextProvider getSslContextProvider( bootstrapInfo.getCertProviders()); } + /** Helper method to build CertProviderClientSslContextProvider. */ + private CertProviderClientSslContextProvider getNewSslContextProvider( + String certInstanceName, + String rootInstanceName, + Bootstrapper.BootstrapInfo bootstrapInfo, + Iterable alpnProtocols, + CertificateValidationContext staticCertValidationContext) { + EnvoyServerProtoData.UpstreamTlsContext upstreamTlsContext = + CommonTlsContextTestsUtil.buildNewUpstreamTlsContextForCertProviderInstance( + certInstanceName, + "cert-default", + rootInstanceName, + "root-default", + alpnProtocols, + staticCertValidationContext); + return certProviderClientSslContextProviderFactory.getProvider( + upstreamTlsContext, + bootstrapInfo.getNode().toEnvoyProtoNode(), + bootstrapInfo.getCertProviders()); + } + @Test public void testProviderForClient_mtls() throws Exception { final CertificateProvider.DistributorWatcher[] watcherCaptor = @@ -150,6 +171,69 @@ public void testProviderForClient_mtls() throws Exception { assertThat(testCallback1.updatedSslContext).isNotSameInstanceAs(testCallback.updatedSslContext); } + @Test + public void testProviderForClient_mtls_newXds() throws Exception { + final CertificateProvider.DistributorWatcher[] watcherCaptor = + new CertificateProvider.DistributorWatcher[1]; + TestCertificateProvider.createAndRegisterProviderProvider( + certificateProviderRegistry, watcherCaptor, "testca", 0); + CertProviderClientSslContextProvider provider = + getNewSslContextProvider( + "gcp_id", + "gcp_id", + CommonBootstrapperTestUtils.getTestBootstrapInfo(), + /* alpnProtocols= */ null, + /* staticCertValidationContext= */ null); + + assertThat(provider.savedKey).isNull(); + assertThat(provider.savedCertChain).isNull(); + assertThat(provider.savedTrustedRoots).isNull(); + assertThat(provider.getSslContext()).isNull(); + + // now generate cert update + watcherCaptor[0].updateCertificate( + CommonCertProviderTestUtils.getPrivateKey(CLIENT_KEY_FILE), + ImmutableList.of(getCertFromResourceName(CLIENT_PEM_FILE))); + assertThat(provider.savedKey).isNotNull(); + assertThat(provider.savedCertChain).isNotNull(); + assertThat(provider.getSslContext()).isNull(); + + // now generate root cert update + watcherCaptor[0].updateTrustedRoots(ImmutableList.of(getCertFromResourceName(CA_PEM_FILE))); + assertThat(provider.getSslContext()).isNotNull(); + assertThat(provider.savedKey).isNull(); + assertThat(provider.savedCertChain).isNull(); + assertThat(provider.savedTrustedRoots).isNull(); + + TestCallback testCallback = + CommonTlsContextTestsUtil.getValueThruCallback(provider); + + doChecksOnSslContext(false, testCallback.updatedSslContext, /* expectedApnProtos= */ null); + TestCallback testCallback1 = + CommonTlsContextTestsUtil.getValueThruCallback(provider); + assertThat(testCallback1.updatedSslContext).isSameInstanceAs(testCallback.updatedSslContext); + + // just do root cert update: sslContext should still be the same + watcherCaptor[0].updateTrustedRoots( + ImmutableList.of(getCertFromResourceName(SERVER_0_PEM_FILE))); + assertThat(provider.savedKey).isNull(); + assertThat(provider.savedCertChain).isNull(); + assertThat(provider.savedTrustedRoots).isNotNull(); + testCallback1 = CommonTlsContextTestsUtil.getValueThruCallback(provider); + assertThat(testCallback1.updatedSslContext).isSameInstanceAs(testCallback.updatedSslContext); + + // now update id cert: sslContext should be updated i.e.different from the previous one + watcherCaptor[0].updateCertificate( + CommonCertProviderTestUtils.getPrivateKey(SERVER_1_KEY_FILE), + ImmutableList.of(getCertFromResourceName(SERVER_1_PEM_FILE))); + assertThat(provider.savedKey).isNull(); + assertThat(provider.savedCertChain).isNull(); + assertThat(provider.savedTrustedRoots).isNull(); + assertThat(provider.getSslContext()).isNotNull(); + testCallback1 = CommonTlsContextTestsUtil.getValueThruCallback(provider); + assertThat(testCallback1.updatedSslContext).isNotSameInstanceAs(testCallback.updatedSslContext); + } + @Test public void testProviderForClient_queueExecutor() throws Exception { final CertificateProvider.DistributorWatcher[] watcherCaptor = diff --git a/xds/src/test/java/io/grpc/xds/internal/certprovider/CertProviderServerSslContextProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/certprovider/CertProviderServerSslContextProviderTest.java index ef801ccc2c1..783ce2b11f7 100644 --- a/xds/src/test/java/io/grpc/xds/internal/certprovider/CertProviderServerSslContextProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/certprovider/CertProviderServerSslContextProviderTest.java @@ -31,12 +31,14 @@ import com.google.common.util.concurrent.MoreExecutors; import io.envoyproxy.envoy.config.core.v3.DataSource; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; +import io.envoyproxy.envoy.type.matcher.v3.StringMatcher; import io.grpc.xds.Bootstrapper; import io.grpc.xds.CommonBootstrapperTestUtils; import io.grpc.xds.EnvoyServerProtoData; import io.grpc.xds.internal.certprovider.CertProviderClientSslContextProviderTest.QueuedExecutor; import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil; import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.TestCallback; +import java.util.Arrays; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -81,6 +83,30 @@ private CertProviderServerSslContextProvider getSslContextProvider( bootstrapInfo.getCertProviders()); } + /** Helper method to build CertProviderServerSslContextProvider. */ + private CertProviderServerSslContextProvider getNewSslContextProvider( + String certInstanceName, + String rootInstanceName, + Bootstrapper.BootstrapInfo bootstrapInfo, + Iterable alpnProtocols, + CertificateValidationContext staticCertValidationContext, + boolean requireClientCert) { + EnvoyServerProtoData.DownstreamTlsContext downstreamTlsContext = + CommonTlsContextTestsUtil.buildNewDownstreamTlsContextForCertProviderInstance( + certInstanceName, + "cert-default", + rootInstanceName, + "root-default", + alpnProtocols, + staticCertValidationContext, + requireClientCert); + return certProviderServerSslContextProviderFactory.getProvider( + downstreamTlsContext, + bootstrapInfo.getNode().toEnvoyProtoNode(), + bootstrapInfo.getCertProviders()); + } + + @Test public void testProviderForServer_mtls() throws Exception { final CertificateProvider.DistributorWatcher[] watcherCaptor = @@ -145,6 +171,74 @@ public void testProviderForServer_mtls() throws Exception { assertThat(testCallback1.updatedSslContext).isNotSameInstanceAs(testCallback.updatedSslContext); } + @Test + public void testProviderForServer_mtls_newXds() throws Exception { + final CertificateProvider.DistributorWatcher[] watcherCaptor = + new CertificateProvider.DistributorWatcher[1]; + TestCertificateProvider.createAndRegisterProviderProvider( + certificateProviderRegistry, watcherCaptor, "testca", 0); + CertificateValidationContext staticCertValidationContext = + CertificateValidationContext.newBuilder().addAllMatchSubjectAltNames(Arrays + .asList(StringMatcher.newBuilder().setExact("foo.com").build(), + StringMatcher.newBuilder().setExact("bar.com").build())).build(); + CertProviderServerSslContextProvider provider = + getNewSslContextProvider( + "gcp_id", + "gcp_id", + CommonBootstrapperTestUtils.getTestBootstrapInfo(), + /* alpnProtocols= */ null, + staticCertValidationContext, + /* requireClientCert= */ true); + + assertThat(provider.savedKey).isNull(); + assertThat(provider.savedCertChain).isNull(); + assertThat(provider.savedTrustedRoots).isNull(); + assertThat(provider.getSslContext()).isNull(); + + // now generate cert update + watcherCaptor[0].updateCertificate( + CommonCertProviderTestUtils.getPrivateKey(SERVER_0_KEY_FILE), + ImmutableList.of(getCertFromResourceName(SERVER_0_PEM_FILE))); + assertThat(provider.savedKey).isNotNull(); + assertThat(provider.savedCertChain).isNotNull(); + assertThat(provider.getSslContext()).isNull(); + + // now generate root cert update + watcherCaptor[0].updateTrustedRoots(ImmutableList.of(getCertFromResourceName(CA_PEM_FILE))); + assertThat(provider.getSslContext()).isNotNull(); + assertThat(provider.savedKey).isNull(); + assertThat(provider.savedCertChain).isNull(); + assertThat(provider.savedTrustedRoots).isNull(); + + TestCallback testCallback = + CommonTlsContextTestsUtil.getValueThruCallback(provider); + + doChecksOnSslContext(true, testCallback.updatedSslContext, /* expectedApnProtos= */ null); + TestCallback testCallback1 = + CommonTlsContextTestsUtil.getValueThruCallback(provider); + assertThat(testCallback1.updatedSslContext).isSameInstanceAs(testCallback.updatedSslContext); + + // just do root cert update: sslContext should still be the same + watcherCaptor[0].updateTrustedRoots( + ImmutableList.of(getCertFromResourceName(CLIENT_PEM_FILE))); + assertThat(provider.savedKey).isNull(); + assertThat(provider.savedCertChain).isNull(); + assertThat(provider.savedTrustedRoots).isNotNull(); + testCallback1 = CommonTlsContextTestsUtil.getValueThruCallback(provider); + assertThat(testCallback1.updatedSslContext).isSameInstanceAs(testCallback.updatedSslContext); + + // now update id cert: sslContext should be updated i.e.different from the previous one + watcherCaptor[0].updateCertificate( + CommonCertProviderTestUtils.getPrivateKey(SERVER_1_KEY_FILE), + ImmutableList.of(getCertFromResourceName(SERVER_1_PEM_FILE))); + assertThat(provider.savedKey).isNull(); + assertThat(provider.savedCertChain).isNull(); + assertThat(provider.savedTrustedRoots).isNull(); + assertThat(provider.getSslContext()).isNotNull(); + testCallback1 = CommonTlsContextTestsUtil.getValueThruCallback(provider); + assertThat(testCallback1.updatedSslContext).isNotSameInstanceAs(testCallback.updatedSslContext); + } + @Test public void testProviderForServer_queueExecutor() throws Exception { final CertificateProvider.DistributorWatcher[] watcherCaptor = diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/CommonTlsContextTestsUtil.java b/xds/src/test/java/io/grpc/xds/internal/sds/CommonTlsContextTestsUtil.java index 81fbda9bde4..840cced424f 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/CommonTlsContextTestsUtil.java +++ b/xds/src/test/java/io/grpc/xds/internal/sds/CommonTlsContextTestsUtil.java @@ -22,6 +22,7 @@ import com.google.common.io.CharStreams; import com.google.common.util.concurrent.MoreExecutors; import com.google.protobuf.BoolValue; +import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateProviderPluginInstance; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.CertificateProviderInstance; @@ -234,6 +235,30 @@ private static CommonTlsContext buildCommonTlsContextForCertProviderInstance( return builder.build(); } + private static CommonTlsContext buildNewCommonTlsContextForCertProviderInstance( + String certInstanceName, + String certName, + String rootInstanceName, + String rootCertName, + Iterable alpnProtocols, + CertificateValidationContext staticCertValidationContext) { + CommonTlsContext.Builder builder = CommonTlsContext.newBuilder(); + if (certInstanceName != null) { + builder = + builder.setTlsCertificateProviderInstance( + CertificateProviderPluginInstance.newBuilder() + .setInstanceName(certInstanceName) + .setCertificateName(certName)); + } + builder = + addNewCertificateValidationContext( + builder, rootInstanceName, rootCertName, staticCertValidationContext); + if (alpnProtocols != null) { + builder.addAllAlpnProtocols(alpnProtocols); + } + return builder.build(); + } + @SuppressWarnings("deprecation") private static CommonTlsContext.Builder addCertificateValidationContext( CommonTlsContext.Builder builder, @@ -259,6 +284,26 @@ private static CommonTlsContext.Builder addCertificateValidationContext( return builder; } + private static CommonTlsContext.Builder addNewCertificateValidationContext( + CommonTlsContext.Builder builder, + String rootInstanceName, + String rootCertName, + CertificateValidationContext staticCertValidationContext) { + if (rootInstanceName != null) { + CertificateProviderPluginInstance providerInstance = + CertificateProviderPluginInstance.newBuilder() + .setInstanceName(rootInstanceName) + .setCertificateName(rootCertName) + .build(); + CertificateValidationContext.Builder validationContextBuilder = + staticCertValidationContext != null ? staticCertValidationContext.toBuilder() + : CertificateValidationContext.newBuilder(); + return builder.setValidationContext( + validationContextBuilder.setCaCertificateProviderInstance(providerInstance)); + } + return builder; + } + /** Helper method to build UpstreamTlsContext for CertProvider tests. */ public static EnvoyServerProtoData.UpstreamTlsContext buildUpstreamTlsContextForCertProviderInstance( @@ -278,6 +323,25 @@ private static CommonTlsContext.Builder addCertificateValidationContext( staticCertValidationContext)); } + /** Helper method to build UpstreamTlsContext for CertProvider tests. */ + public static EnvoyServerProtoData.UpstreamTlsContext + buildNewUpstreamTlsContextForCertProviderInstance( + @Nullable String certInstanceName, + @Nullable String certName, + @Nullable String rootInstanceName, + @Nullable String rootCertName, + Iterable alpnProtocols, + CertificateValidationContext staticCertValidationContext) { + return buildUpstreamTlsContext( + buildNewCommonTlsContextForCertProviderInstance( + certInstanceName, + certName, + rootInstanceName, + rootCertName, + alpnProtocols, + staticCertValidationContext)); + } + /** Helper method to build DownstreamTlsContext for CertProvider tests. */ public static EnvoyServerProtoData.DownstreamTlsContext buildDownstreamTlsContextForCertProviderInstance( @@ -298,6 +362,25 @@ private static CommonTlsContext.Builder addCertificateValidationContext( staticCertValidationContext), requireClientCert); } + /** Helper method to build DownstreamTlsContext for CertProvider tests. */ + public static EnvoyServerProtoData.DownstreamTlsContext + buildNewDownstreamTlsContextForCertProviderInstance( + @Nullable String certInstanceName, + @Nullable String certName, + @Nullable String rootInstanceName, + @Nullable String rootCertName, + Iterable alpnProtocols, + CertificateValidationContext staticCertValidationContext, + boolean requireClientCert) { + return buildInternalDownstreamTlsContext( + buildNewCommonTlsContextForCertProviderInstance( + certInstanceName, + certName, + rootInstanceName, + rootCertName, + alpnProtocols, + staticCertValidationContext), requireClientCert); + } /** Perform some simple checks on sslContext. */ public static void doChecksOnSslContext(boolean server, SslContext sslContext, From 36abf2501bf70a4f5b609a4a73e83716aed71687 Mon Sep 17 00:00:00 2001 From: yifeizhuang Date: Thu, 9 Sep 2021 10:01:07 -0700 Subject: [PATCH 70/82] xds: referenciate server routing config (#8491) (#8500) --- .../java/io/grpc/xds/XdsServerWrapper.java | 68 +++---- ...rChainMatchingProtocolNegotiatorsTest.java | 51 ++--- .../io/grpc/xds/XdsServerWrapperTest.java | 184 +++++++++++------- 3 files changed, 150 insertions(+), 153 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java b/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java index fa8ecf8a822..faa6e9d34b2 100644 --- a/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java +++ b/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java @@ -396,7 +396,7 @@ public void run() { } routeDiscoveryStates.keySet().retainAll(allRds); if (pendingRds.isEmpty()) { - updateSelector(true); + updateSelector(); } } }); @@ -450,14 +450,7 @@ private void shutdown() { releaseSuppliersInFlight(); } - /** - * Use firstTimeNoPendingRds to indicate that the previous SslContextProviderSuppliers in - * filterChainSelectorRef should be released. Call updateSelector(true) when all routing are - * just complete and the newest filter chain is ready to be applied to the - * filterChainSelectorRef. Call updateSelector(false) for subsequent routing update - * corresponding to the same filter chain list. - */ - private void updateSelector(boolean firstTimeNoPendingRds) { + private void updateSelector() { Map filterChainRouting = new HashMap<>(); for (FilterChain filterChain: filterChains) { filterChainRouting.put(filterChain, generateRoutingConfig(filterChain)); @@ -466,10 +459,7 @@ private void updateSelector(boolean firstTimeNoPendingRds) { Collections.unmodifiableMap(filterChainRouting), defaultFilterChain == null ? null : defaultFilterChain.getSslContextProviderSupplier(), defaultFilterChain == null ? null : generateRoutingConfig(defaultFilterChain)); - List toRelease = Collections.emptyList(); - if (firstTimeNoPendingRds) { - toRelease = getSuppliersInUse(); - } + List toRelease = getSuppliersInUse(); filterChainSelectorRef.set(selector); for (SslContextProviderSupplier e: toRelease) { e.close(); @@ -480,14 +470,12 @@ private void updateSelector(boolean firstTimeNoPendingRds) { private ServerRoutingConfig generateRoutingConfig(FilterChain filterChain) { HttpConnectionManager hcm = filterChain.getHttpConnectionManager(); if (hcm.virtualHosts() != null) { - return ServerRoutingConfig.create(hcm.httpFilterConfigs(), hcm.virtualHosts()); + return ServerRoutingConfig.create(hcm.httpFilterConfigs(), + new AtomicReference<>(hcm.virtualHosts())); } else { RouteDiscoveryState rds = routeDiscoveryStates.get(hcm.rdsName()); - if (rds != null && rds.savedVirtualHosts != null) { - return ServerRoutingConfig.create(hcm.httpFilterConfigs(), rds.savedVirtualHosts); - } else { - return ServerRoutingConfig.FAILING_ROUTING_CONFIG; - } + checkNotNull(rds, "rds"); + return ServerRoutingConfig.create(hcm.httpFilterConfigs(), rds.savedVirtualHosts); } } @@ -555,8 +543,8 @@ private void releaseSuppliersInFlight() { private final class RouteDiscoveryState implements RdsResourceWatcher { private final String resourceName; - @Nullable - private List savedVirtualHosts; + private AtomicReference> savedVirtualHosts = + new AtomicReference<>(); private boolean isPending = true; private RouteDiscoveryState(String resourceName) { @@ -571,7 +559,7 @@ public void run() { if (!routeDiscoveryStates.containsKey(resourceName)) { return; } - savedVirtualHosts = update.virtualHosts; + savedVirtualHosts.set(ImmutableList.copyOf(update.virtualHosts)); maybeUpdateSelector(); } }); @@ -586,7 +574,7 @@ public void run() { return; } logger.log(Level.WARNING, "Rds {0} unavailable", resourceName); - savedVirtualHosts = null; + savedVirtualHosts.set(null); maybeUpdateSelector(); } }); @@ -608,13 +596,13 @@ public void run() { } // Update the selector to use the most recently updated configs only after all rds have been - // discovered, i.e. pendingRds is empty. Do the updateSelector even after rds are already - // fully discovered and new change comes. + // discovered for the first time. Later changes on rds will be applied through virtual host + // list atomic ref. private void maybeUpdateSelector() { isPending = false; - boolean isLastPending = pendingRds.remove(resourceName); - if (pendingRds.isEmpty()) { - updateSelector(isLastPending); + boolean isLastPending = pendingRds.remove(resourceName) && pendingRds.isEmpty(); + if (isLastPending) { + updateSelector(); } } } @@ -644,15 +632,19 @@ public Listener interceptCall(ServerCall call, public Listener interceptCall(ServerCall call, Metadata headers, ServerCallHandler next) { ServerRoutingConfig routingConfig = call.getAttributes().get(ATTR_SERVER_ROUTING_CONFIG); - if (routingConfig == null - || routingConfig.equals(ServerRoutingConfig.FAILING_ROUTING_CONFIG)) { - String errorMsg = "Missing xDS routing config. " + (routingConfig == null ? "" : - "RDS config unavailable."); + if (routingConfig == null) { + String errorMsg = "Missing xDS routing config."; + call.close(Status.UNAVAILABLE.withDescription(errorMsg), new Metadata()); + return new Listener() {}; + } + List virtualHosts = routingConfig.virtualHosts().get(); + if (virtualHosts == null) { + String errorMsg = "Missing xDS routing config VirtualHosts due to RDS config unavailable."; call.close(Status.UNAVAILABLE.withDescription(errorMsg), new Metadata()); return new Listener() {}; } VirtualHost virtualHost = RoutingUtils.findVirtualHostForHostName( - routingConfig.virtualHosts(), call.getAuthority()); + virtualHosts, call.getAuthority()); if (virtualHost == null) { call.close( Status.UNAVAILABLE.withDescription("Could not find xDS virtual host matching RPC"), @@ -727,24 +719,20 @@ public Listener interceptCall(ServerCall call, */ @AutoValue abstract static class ServerRoutingConfig { - private static final ServerRoutingConfig FAILING_ROUTING_CONFIG = - new AutoValue_XdsServerWrapper_ServerRoutingConfig( - ImmutableList.of(), ImmutableList.of()); - // Top level http filter configs. abstract ImmutableList httpFilterConfigs(); - abstract ImmutableList virtualHosts(); + abstract AtomicReference> virtualHosts(); /** * Server routing configuration. * */ public static ServerRoutingConfig create(List httpFilterConfigs, - List virtualHosts) { + AtomicReference> virtualHosts) { checkNotNull(httpFilterConfigs, "httpFilterConfigs"); checkNotNull(virtualHosts, "virtualHosts"); return new AutoValue_XdsServerWrapper_ServerRoutingConfig( - ImmutableList.copyOf(httpFilterConfigs), ImmutableList.copyOf(virtualHosts)); + ImmutableList.copyOf(httpFilterConfigs), virtualHosts); } } } diff --git a/xds/src/test/java/io/grpc/xds/FilterChainMatchingProtocolNegotiatorsTest.java b/xds/src/test/java/io/grpc/xds/FilterChainMatchingProtocolNegotiatorsTest.java index d79785c9f32..167f3f03c6b 100644 --- a/xds/src/test/java/io/grpc/xds/FilterChainMatchingProtocolNegotiatorsTest.java +++ b/xds/src/test/java/io/grpc/xds/FilterChainMatchingProtocolNegotiatorsTest.java @@ -23,6 +23,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.util.concurrent.SettableFuture; import io.grpc.internal.TestUtils.NoopChannelLogger; @@ -62,6 +63,8 @@ import java.util.Collections; import java.util.HashMap; import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; + import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; @@ -88,6 +91,8 @@ public class FilterChainMatchingProtocolNegotiatorsTest { private static final String LOCAL_IP = "10.1.2.3"; // dest private static final String REMOTE_IP = "10.4.2.3"; // source private static final int PORT = 7000; + private final ServerRoutingConfig noopConfig = ServerRoutingConfig.create( + new ArrayList(), new AtomicReference>()); @Test public void nofilterChainMatch_defaultSslContext() throws Exception { @@ -98,8 +103,6 @@ public void nofilterChainMatch_defaultSslContext() throws Exception { SslContextProviderSupplier defaultSsl = new SslContextProviderSupplier(createTls(), tlsContextManager); - ServerRoutingConfig noopConfig = ServerRoutingConfig.create( - new ArrayList(), new ArrayList()); FilterChainSelector selector = new FilterChainSelector( new HashMap(), defaultSsl, noopConfig); FilterChainMatchingHandler filterChainMatchingHandler = @@ -154,8 +157,6 @@ public void singleFilterChainWithoutAlpn() throws Exception { "filter-chain-foo", filterChainMatch, HTTP_CONNECTION_MANAGER, tlsContext, tlsContextManager); - ServerRoutingConfig noopConfig = ServerRoutingConfig.create( - new ArrayList(), new ArrayList()); FilterChainSelector selector = new FilterChainSelector(ImmutableMap.of(filterChain, noopConfig), null, null); FilterChainMatchingHandler filterChainMatchingHandler = @@ -195,8 +196,6 @@ public void singleFilterChainWithAlpn() throws Exception { EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( "filter-chain-bar", null, HTTP_CONNECTION_MANAGER, defaultTlsContext, tlsContextManager); - ServerRoutingConfig noopConfig = ServerRoutingConfig.create( - new ArrayList(), new ArrayList()); FilterChainSelector selector = new FilterChainSelector( ImmutableMap.of(filterChain, randomConfig("no-match")), defaultFilterChain.getSslContextProviderSupplier(), noopConfig); @@ -241,12 +240,11 @@ public void destPortFails_returnDefaultFilterChain() throws Exception { tlsContextForDefaultFilterChain, tlsContextManager); ServerRoutingConfig routingConfig = ServerRoutingConfig.create( - new ArrayList(), Arrays.asList(createVirtualHost("virtual"))); - ServerRoutingConfig defaultRoutingConfig = ServerRoutingConfig.create( - new ArrayList(), new ArrayList()); + new ArrayList(), new AtomicReference<>( + ImmutableList.of(createVirtualHost("virtual")))); FilterChainSelector selector = new FilterChainSelector( ImmutableMap.of(filterChainWithDestPort, routingConfig), - defaultFilterChain.getSslContextProviderSupplier(), defaultRoutingConfig); + defaultFilterChain.getSslContextProviderSupplier(), noopConfig); FilterChainMatchingHandler filterChainMatchingHandler = new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); @@ -259,7 +257,7 @@ public void destPortFails_returnDefaultFilterChain() throws Exception { pipeline.fireUserEventTriggered(event); channel.runPendingTasks(); assertThat(sslSet.get()).isEqualTo(defaultFilterChain.getSslContextProviderSupplier()); - assertThat(routingSettable.get()).isEqualTo(defaultRoutingConfig); + assertThat(routingSettable.get()).isEqualTo(noopConfig); assertThat(sslSet.get().getTlsContext()) .isSameInstanceAs(tlsContextForDefaultFilterChain); } @@ -287,8 +285,6 @@ public void destPrefixRangeMatch() throws Exception { "filter-chain-bar", null, HTTP_CONNECTION_MANAGER, tlsContextForDefaultFilterChain, tlsContextManager); - ServerRoutingConfig noopConfig = ServerRoutingConfig.create( - new ArrayList(), new ArrayList()); FilterChainSelector selector = new FilterChainSelector( ImmutableMap.of(filterChainWithMatch, noopConfig), defaultFilterChain.getSslContextProviderSupplier(), randomConfig("no-match")); @@ -333,8 +329,6 @@ public void destPrefixRangeMismatch_returnDefaultFilterChain() EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( "filter-chain-bar", null, HTTP_CONNECTION_MANAGER, tlsContextForDefaultFilterChain, tlsContextManager); - ServerRoutingConfig noopConfig = ServerRoutingConfig.create( - new ArrayList(), new ArrayList()); FilterChainSelector selector = new FilterChainSelector( ImmutableMap.of(filterChainWithMismatch, randomConfig("no-match")), defaultFilterChain.getSslContextProviderSupplier(), noopConfig); @@ -380,8 +374,6 @@ public void dest0LengthPrefixRange() "filter-chain-bar", null, HTTP_CONNECTION_MANAGER, tlsContextForDefaultFilterChain, tlsContextManager); - ServerRoutingConfig noopConfig = ServerRoutingConfig.create( - new ArrayList(), new ArrayList()); FilterChainSelector selector = new FilterChainSelector( ImmutableMap.of(filterChain0Length, noopConfig), defaultFilterChain.getSslContextProviderSupplier(), null); @@ -439,8 +431,6 @@ public void destPrefixRange_moreSpecificWins() tlsContextManager); EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( "filter-chain-baz", null, HTTP_CONNECTION_MANAGER, null, tlsContextManager); - ServerRoutingConfig noopConfig = ServerRoutingConfig.create( - new ArrayList(), new ArrayList()); FilterChainSelector selector = new FilterChainSelector( ImmutableMap.of(filterChainLessSpecific, randomConfig("no-match"), filterChainMoreSpecific, noopConfig), @@ -500,8 +490,6 @@ public void destPrefixRange_emptyListLessSpecific() tlsContextManager); EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( "filter-chain-baz", null, HTTP_CONNECTION_MANAGER, null, tlsContextManager); - ServerRoutingConfig noopConfig = ServerRoutingConfig.create( - new ArrayList(), new ArrayList()); FilterChainSelector selector = new FilterChainSelector( ImmutableMap.of(filterChainLessSpecific, randomConfig("no-match"), filterChainMoreSpecific, noopConfig), @@ -559,8 +547,6 @@ public void destPrefixRangeIpv6_moreSpecificWins() tlsContextMoreSpecific, tlsContextManager); EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( "filter-chain-baz", null, HTTP_CONNECTION_MANAGER, null, tlsContextManager); - ServerRoutingConfig noopConfig = ServerRoutingConfig.create( - new ArrayList(), new ArrayList()); FilterChainSelector selector = new FilterChainSelector( ImmutableMap.of(filterChainLessSpecific, randomConfig("no-match"), filterChainMoreSpecific, noopConfig), @@ -624,8 +610,6 @@ public void destPrefixRange_moreSpecificWith2Wins() EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( "filter-chain-baz", null, HTTP_CONNECTION_MANAGER, null, tlsContextManager); - ServerRoutingConfig noopConfig = ServerRoutingConfig.create( - new ArrayList(), new ArrayList()); FilterChainSelector selector = new FilterChainSelector( ImmutableMap.of(filterChainMoreSpecificWith2, noopConfig, filterChainLessSpecific, randomConfig("no-match")), @@ -669,8 +653,6 @@ public void sourceTypeMismatch_returnDefaultFilterChain() throws Exception { EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( "filter-chain-bar", null, HTTP_CONNECTION_MANAGER,tlsContextForDefaultFilterChain, tlsContextManager); - ServerRoutingConfig noopConfig = ServerRoutingConfig.create( - new ArrayList(), new ArrayList()); FilterChainSelector selector = new FilterChainSelector( ImmutableMap.of(filterChainWithMismatch, randomConfig("no-match")), defaultFilterChain.getSslContextProviderSupplier(), noopConfig); @@ -716,8 +698,6 @@ public void sourceTypeLocal() throws Exception { "filter-chain-bar", null, HTTP_CONNECTION_MANAGER, tlsContextForDefaultFilterChain, tlsContextManager); - ServerRoutingConfig noopConfig = ServerRoutingConfig.create( - new ArrayList(), new ArrayList()); FilterChainSelector selector = new FilterChainSelector( ImmutableMap.of(filterChainWithMatch, noopConfig), defaultFilterChain.getSslContextProviderSupplier(), randomConfig("default")); @@ -777,8 +757,6 @@ public void sourcePrefixRange_moreSpecificWith2Wins() EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( "filter-chain-baz", null, HTTP_CONNECTION_MANAGER, null, tlsContextManager); - ServerRoutingConfig noopConfig = ServerRoutingConfig.create( - new ArrayList(), new ArrayList()); FilterChainSelector selector = new FilterChainSelector( ImmutableMap.of(filterChainMoreSpecificWith2, noopConfig, filterChainLessSpecific, randomConfig("no-match")), @@ -845,8 +823,6 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( "filter-chain-baz", null, HTTP_CONNECTION_MANAGER, null, null); - ServerRoutingConfig noopConfig = ServerRoutingConfig.create( - new ArrayList(), new ArrayList()); FilterChainSelector selector = new FilterChainSelector( ImmutableMap.of(filterChain1, noopConfig, filterChain2, noopConfig), defaultFilterChain.getSslContextProviderSupplier(), noopConfig); @@ -908,8 +884,6 @@ public void sourcePortMatch_exactMatchWinsOverEmptyList() throws Exception { EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( "filter-chain-baz", null, HTTP_CONNECTION_MANAGER, null, tlsContextManager); - ServerRoutingConfig noopConfig = ServerRoutingConfig.create( - new ArrayList(), new ArrayList()); FilterChainSelector selector = new FilterChainSelector( ImmutableMap.of(filterChainEmptySourcePorts, randomConfig("no-match"), filterChainSourcePortMatch, noopConfig), @@ -1059,8 +1033,6 @@ public void filterChain_5stepMatch() throws Exception { EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( "filter-chain-7", null, HTTP_CONNECTION_MANAGER, null, tlsContextManager); - ServerRoutingConfig noopConfig = ServerRoutingConfig.create( - new ArrayList(), new ArrayList()); Map map = new HashMap<>(); map.put(filterChain1, randomConfig("1")); map.put(filterChain2, randomConfig("2")); @@ -1142,8 +1114,6 @@ public void filterChainMatch_unsupportedMatchers() throws Exception { "filter-chain-baz", defaultFilterChainMatch, HTTP_CONNECTION_MANAGER, tlsContext3, mock(TlsContextManager.class)); - ServerRoutingConfig noopConfig = ServerRoutingConfig.create( - new ArrayList(), new ArrayList()); FilterChainSelector selector = new FilterChainSelector( ImmutableMap.of(filterChain1, randomConfig("1"), filterChain2, randomConfig("2")), defaultFilterChain.getSslContextProviderSupplier(), noopConfig); @@ -1177,7 +1147,8 @@ private static VirtualHost createVirtualHost(String name) { private static ServerRoutingConfig randomConfig(String domain) { return ServerRoutingConfig.create( - new ArrayList(), Arrays.asList(createVirtualHost(domain))); + new ArrayList(), new AtomicReference<>( + ImmutableList.of(createVirtualHost(domain)))); } private EnvoyServerProtoData.DownstreamTlsContext createTls() { diff --git a/xds/src/test/java/io/grpc/xds/XdsServerWrapperTest.java b/xds/src/test/java/io/grpc/xds/XdsServerWrapperTest.java index 4c91d5758f9..876b0913742 100644 --- a/xds/src/test/java/io/grpc/xds/XdsServerWrapperTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsServerWrapperTest.java @@ -376,12 +376,11 @@ public void run() { tlsContextManager); xdsClient.deliverLdsUpdate(Collections.singletonList(filterChain), null); start.get(5000, TimeUnit.MILLISECONDS); - FilterChainSelector selector = selectorRef.get(); assertThat(ldsWatched).isEqualTo("grpc/server?udpa.resource.listening_address=0.0.0.0:1"); - assertThat(selector.getRoutingConfigs()).isEqualTo(ImmutableMap.of( - filterChain, ServerRoutingConfig.create(httpConnectionManager.httpFilterConfigs(), - httpConnectionManager.virtualHosts()) - )); + assertThat(selectorRef.get().getRoutingConfigs().size()).isEqualTo(1); + ServerRoutingConfig realConfig = selectorRef.get().getRoutingConfigs().get(filterChain); + assertThat(realConfig.virtualHosts().get()).isEqualTo(httpConnectionManager.virtualHosts()); + assertThat(realConfig.httpFilterConfigs()).isEqualTo(httpConnectionManager.httpFilterConfigs()); verify(listener).onServing(); verify(mockServer).start(); } @@ -427,15 +426,22 @@ public void run() { Collections.singletonList(createVirtualHost("virtual-host-2"))); start.get(5000, TimeUnit.MILLISECONDS); verify(mockServer).start(); - assertThat(selectorRef.get().getRoutingConfigs()).isEqualTo(ImmutableMap.of( - f0, ServerRoutingConfig.create( - hcmVirtual.httpFilterConfigs(), hcmVirtual.virtualHosts()), - f2, ServerRoutingConfig.create(f2.getHttpConnectionManager().httpFilterConfigs(), - Collections.singletonList(createVirtualHost("virtual-host-1"))) - )); - assertThat(selectorRef.get().getDefaultRoutingConfig()).isEqualTo( - ServerRoutingConfig.create(f3.getHttpConnectionManager().httpFilterConfigs(), - Collections.singletonList(createVirtualHost("virtual-host-2")))); + ServerRoutingConfig realConfig = selectorRef.get().getRoutingConfigs().get(f0); + assertThat(realConfig.virtualHosts().get()).isEqualTo( + Collections.singletonList(createVirtualHost("virtual-host-0"))); + assertThat(realConfig.httpFilterConfigs()).isEqualTo( + f0.getHttpConnectionManager().httpFilterConfigs()); + assertThat(selectorRef.get().getRoutingConfigs().size()).isEqualTo(2); + realConfig = selectorRef.get().getRoutingConfigs().get(f2); + assertThat(realConfig.virtualHosts().get()).isEqualTo( + Collections.singletonList(createVirtualHost("virtual-host-1"))); + assertThat(realConfig.httpFilterConfigs()).isEqualTo( + f2.getHttpConnectionManager().httpFilterConfigs()); + realConfig = selectorRef.get().getDefaultRoutingConfig(); + assertThat(realConfig.virtualHosts().get()).isEqualTo( + Collections.singletonList(createVirtualHost("virtual-host-2"))); + assertThat(realConfig.httpFilterConfigs()).isEqualTo( + f3.getHttpConnectionManager().httpFilterConfigs()); assertThat(selectorRef.get().getDefaultSslContextProviderSupplier()).isEqualTo( f3.getSslContextProviderSupplier()); } @@ -469,38 +475,53 @@ public void run() { Collections.singletonList(createVirtualHost("virtual-host-0"))); start.get(5000, TimeUnit.MILLISECONDS); verify(mockServer, times(1)).start(); - assertThat(selectorRef.get().getRoutingConfigs()).isEqualTo(ImmutableMap.of( - f0, ServerRoutingConfig.create( - f0.getHttpConnectionManager().httpFilterConfigs(), - Collections.singletonList(createVirtualHost("virtual-host-0"))), - f1, ServerRoutingConfig.create(f1.getHttpConnectionManager().httpFilterConfigs(), - Collections.singletonList(createVirtualHost("virtual-host-0"))) - )); - assertThat(selectorRef.get().getDefaultRoutingConfig()).isEqualTo( - ServerRoutingConfig.create(f2.getHttpConnectionManager().httpFilterConfigs(), - Collections.singletonList(createVirtualHost("virtual-host-0")))); + ServerRoutingConfig realConfig = selectorRef.get().getRoutingConfigs().get(f0); + assertThat(realConfig.virtualHosts().get()).isEqualTo( + Collections.singletonList(createVirtualHost("virtual-host-0"))); + assertThat(realConfig.httpFilterConfigs()).isEqualTo( + f0.getHttpConnectionManager().httpFilterConfigs()); + realConfig = selectorRef.get().getRoutingConfigs().get(f1); + assertThat(realConfig.virtualHosts().get()).isEqualTo( + Collections.singletonList(createVirtualHost("virtual-host-0"))); + assertThat(realConfig.httpFilterConfigs()).isEqualTo( + f1.getHttpConnectionManager().httpFilterConfigs()); + + realConfig = selectorRef.get().getDefaultRoutingConfig(); + assertThat(realConfig.virtualHosts().get()).isEqualTo( + Collections.singletonList(createVirtualHost("virtual-host-0"))); + assertThat(realConfig.httpFilterConfigs()).isEqualTo( + f2.getHttpConnectionManager().httpFilterConfigs()); assertThat(selectorRef.get().getDefaultSslContextProviderSupplier()).isSameInstanceAs( f2.getSslContextProviderSupplier()); EnvoyServerProtoData.FilterChain f3 = createFilterChain("filter-chain-3", createRds("r0")); EnvoyServerProtoData.FilterChain f4 = createFilterChain("filter-chain-4", createRds("r1")); + EnvoyServerProtoData.FilterChain f5 = createFilterChain("filter-chain-4", createRds("r1")); xdsClient.rdsCount = new CountDownLatch(1); - xdsClient.deliverLdsUpdate(Arrays.asList(f1, f3), f4); + xdsClient.deliverLdsUpdate(Arrays.asList(f5, f3), f4); xdsClient.rdsCount.await(5, TimeUnit.SECONDS); xdsClient.deliverRdsUpdate("r1", Collections.singletonList(createVirtualHost("virtual-host-1"))); xdsClient.deliverRdsUpdate("r0", Collections.singletonList(createVirtualHost("virtual-host-0"))); - assertThat(selectorRef.get().getRoutingConfigs()).isEqualTo(ImmutableMap.of( - f1, ServerRoutingConfig.create( - f1.getHttpConnectionManager().httpFilterConfigs(), - Collections.singletonList(createVirtualHost("virtual-host-0"))), - f3, ServerRoutingConfig.create(f3.getHttpConnectionManager().httpFilterConfigs(), - Collections.singletonList(createVirtualHost("virtual-host-0"))) - )); - assertThat(selectorRef.get().getDefaultRoutingConfig()).isEqualTo( - ServerRoutingConfig.create(f4.getHttpConnectionManager().httpFilterConfigs(), - Collections.singletonList(createVirtualHost("virtual-host-1")))); + + assertThat(selectorRef.get().getRoutingConfigs().size()).isEqualTo(2); + realConfig = selectorRef.get().getRoutingConfigs().get(f5); + assertThat(realConfig.virtualHosts().get()).isEqualTo( + Collections.singletonList(createVirtualHost("virtual-host-1"))); + assertThat(realConfig.httpFilterConfigs()).isEqualTo( + f5.getHttpConnectionManager().httpFilterConfigs()); + realConfig = selectorRef.get().getRoutingConfigs().get(f3); + assertThat(realConfig.virtualHosts().get()).isEqualTo( + Collections.singletonList(createVirtualHost("virtual-host-0"))); + assertThat(realConfig.httpFilterConfigs()).isEqualTo( + f3.getHttpConnectionManager().httpFilterConfigs()); + + realConfig = selectorRef.get().getDefaultRoutingConfig(); + assertThat(realConfig.virtualHosts().get()).isEqualTo( + Collections.singletonList(createVirtualHost("virtual-host-1"))); + assertThat(realConfig.httpFilterConfigs()).isEqualTo( + f4.getHttpConnectionManager().httpFilterConfigs()); assertThat(selectorRef.get().getDefaultSslContextProviderSupplier()).isSameInstanceAs( f4.getSslContextProviderSupplier()); verify(mockServer, times(1)).start(); @@ -535,27 +556,38 @@ public void run() { xdsClient.rdsCount.await(); xdsClient.rdsWatchers.get("r0").onError(Status.CANCELLED); start.get(5000, TimeUnit.MILLISECONDS); - assertThat(selectorRef.get().getRoutingConfigs().get(f1)).isEqualTo(ServerRoutingConfig.create( - ImmutableList.of(), ImmutableList.of()) - ); + assertThat(selectorRef.get().getRoutingConfigs().size()).isEqualTo(2); + ServerRoutingConfig realConfig = selectorRef.get().getRoutingConfigs().get(f1); + assertThat(realConfig.virtualHosts().get()).isNull(); + assertThat(realConfig.httpFilterConfigs()).isEqualTo( + f1.getHttpConnectionManager().httpFilterConfigs()); + realConfig = selectorRef.get().getRoutingConfigs().get(f0); + assertThat(realConfig.virtualHosts().get()).isEqualTo(hcmVirtual.virtualHosts()); + assertThat(realConfig.httpFilterConfigs()).isEqualTo( + f0.getHttpConnectionManager().httpFilterConfigs()); + xdsClient.deliverRdsUpdate("r0", Collections.singletonList(createVirtualHost("virtual-host-1"))); - assertThat(selectorRef.get().getRoutingConfigs().get(f1)).isEqualTo( - ServerRoutingConfig.create(f1.getHttpConnectionManager().httpFilterConfigs(), - Collections.singletonList(createVirtualHost("virtual-host-1")))); + realConfig = selectorRef.get().getRoutingConfigs().get(f1); + assertThat(realConfig.virtualHosts().get()).isEqualTo( + Collections.singletonList(createVirtualHost("virtual-host-1"))); + assertThat(realConfig.httpFilterConfigs()).isEqualTo( + f1.getHttpConnectionManager().httpFilterConfigs()); xdsClient.rdsWatchers.get("r0").onError(Status.CANCELLED); - assertThat(selectorRef.get().getRoutingConfigs().get(f1)).isEqualTo( - ServerRoutingConfig.create(f1.getHttpConnectionManager().httpFilterConfigs(), - Collections.singletonList(createVirtualHost("virtual-host-1")))); + realConfig = selectorRef.get().getRoutingConfigs().get(f1); + assertThat(realConfig.virtualHosts().get()).isEqualTo( + Collections.singletonList(createVirtualHost("virtual-host-1"))); + assertThat(realConfig.httpFilterConfigs()).isEqualTo( + f1.getHttpConnectionManager().httpFilterConfigs()); xdsClient.rdsWatchers.get("r0").onResourceDoesNotExist("r0"); - assertThat(selectorRef.get().getRoutingConfigs().get(f1)).isEqualTo(ServerRoutingConfig.create( - ImmutableList.of(), ImmutableList.of()) - ); + realConfig = selectorRef.get().getRoutingConfigs().get(f1); + assertThat(realConfig.virtualHosts().get()).isNull(); + assertThat(realConfig.httpFilterConfigs()).isEqualTo( + f1.getHttpConnectionManager().httpFilterConfigs()); } - @Test public void error() throws Exception { final SettableFuture start = SettableFuture.create(); @@ -602,11 +634,12 @@ public void run() { verify(mockBuilder, times(1)).build(); verify(mockServer, times(2)).start(); verify(listener, times(1)).onServing(); - assertThat(selectorRef.get().getRoutingConfigs()).isEqualTo(ImmutableMap.of( - filterChain1, ServerRoutingConfig.create( - filterChain1.getHttpConnectionManager().httpFilterConfigs(), - Collections.singletonList(createVirtualHost("virtual-host-1"))) - )); + assertThat(selectorRef.get().getRoutingConfigs().size()).isEqualTo(1); + ServerRoutingConfig realConfig = selectorRef.get().getRoutingConfigs().get(filterChain1); + assertThat(realConfig.virtualHosts().get()).isEqualTo( + Collections.singletonList(createVirtualHost("virtual-host-1"))); + assertThat(realConfig.httpFilterConfigs()).isEqualTo( + filterChain1.getHttpConnectionManager().httpFilterConfigs()); // xds update after start xdsClient.deliverRdsUpdate("rds", Collections.singletonList(createVirtualHost("virtual-host-2"))); @@ -615,11 +648,12 @@ public void run() { verify(mockBuilder, times(1)).build(); verify(mockServer, times(2)).start(); verify(listener, times(2)).onNotServing(any(StatusException.class)); - assertThat(selectorRef.get().getRoutingConfigs()).isEqualTo(ImmutableMap.of( - filterChain1, ServerRoutingConfig.create( - filterChain1.getHttpConnectionManager().httpFilterConfigs(), - Collections.singletonList(createVirtualHost("virtual-host-2"))) - )); + assertThat(selectorRef.get().getRoutingConfigs().size()).isEqualTo(1); + realConfig = selectorRef.get().getRoutingConfigs().get(filterChain1); + assertThat(realConfig.virtualHosts().get()).isEqualTo( + Collections.singletonList(createVirtualHost("virtual-host-2"))); + assertThat(realConfig.httpFilterConfigs()).isEqualTo( + filterChain1.getHttpConnectionManager().httpFilterConfigs()); assertThat(sslSupplier1.isShutdown()).isFalse(); // not serving after serving @@ -652,11 +686,12 @@ public void run() { verify(mockServer, times(3)).start(); verify(listener, times(1)).onServing(); verify(listener, times(3)).onNotServing(any(StatusException.class)); - assertThat(selectorRef.get().getRoutingConfigs()).isEqualTo(ImmutableMap.of( - filterChain2, ServerRoutingConfig.create( - filterChain2.getHttpConnectionManager().httpFilterConfigs(), - Collections.singletonList(createVirtualHost("virtual-host-1"))) - )); + assertThat(selectorRef.get().getRoutingConfigs().size()).isEqualTo(1); + realConfig = selectorRef.get().getRoutingConfigs().get(filterChain2); + assertThat(realConfig.virtualHosts().get()).isEqualTo( + Collections.singletonList(createVirtualHost("virtual-host-1"))); + assertThat(realConfig.httpFilterConfigs()).isEqualTo( + filterChain2.getHttpConnectionManager().httpFilterConfigs()); assertThat(executor.numPendingTasks()).isEqualTo(1); xdsClient.ldsWatcher.onResourceDoesNotExist(ldsResource); verify(mockServer, times(4)).shutdown(); @@ -676,11 +711,13 @@ public void run() { verify(listener, times(1)).onServing(); when(mockServer.isShutdown()).thenReturn(false); verify(listener, times(4)).onNotServing(any(StatusException.class)); - assertThat(selectorRef.get().getRoutingConfigs()).isEqualTo(ImmutableMap.of( - filterChain3, ServerRoutingConfig.create( - filterChain3.getHttpConnectionManager().httpFilterConfigs(), - Collections.singletonList(createVirtualHost("virtual-host-1"))) - )); + + assertThat(selectorRef.get().getRoutingConfigs().size()).isEqualTo(1); + realConfig = selectorRef.get().getRoutingConfigs().get(filterChain3); + assertThat(realConfig.virtualHosts().get()).isEqualTo( + Collections.singletonList(createVirtualHost("virtual-host-1"))); + assertThat(realConfig.httpFilterConfigs()).isEqualTo( + filterChain3.getHttpConnectionManager().httpFilterConfigs()); xdsServerWrapper.shutdown(); verify(mockServer, times(5)).shutdown(); assertThat(sslSupplier3.isShutdown()).isTrue(); @@ -828,7 +865,8 @@ public void run() { verify(mockBuilder).intercept(interceptorCaptor.capture()); ConfigApplyingInterceptor interceptor = interceptorCaptor.getValue(); ServerRoutingConfig failingConfig = ServerRoutingConfig.create( - ImmutableList.of(), ImmutableList.of()); + ImmutableList.of(), new AtomicReference>() + ); ServerCall serverCall = mock(ServerCall.class); when(serverCall.getAttributes()).thenReturn( Attributes.newBuilder().set(ATTR_SERVER_ROUTING_CONFIG, failingConfig).build()); @@ -841,7 +879,7 @@ public void run() { Status status = statusCaptor.getValue(); assertThat(status.getCode()).isEqualTo(Status.UNAVAILABLE.getCode()); assertThat(status.getDescription()).isEqualTo( - "Missing xDS routing config. RDS config unavailable."); + "Missing xDS routing config VirtualHosts due to RDS config unavailable."); } @Test @@ -900,7 +938,7 @@ public ServerCall.Listener interceptCall(ServerCall(ImmutableList.of(virtualHost)) ); ServerCall serverCall = mock(ServerCall.class); ServerCallHandler mockNext = mock(ServerCallHandler.class); @@ -984,8 +1022,8 @@ private static ServerRoutingConfig createRoutingConfig(String path, String domai FilterConfig f0 = mock(FilterConfig.class); when(f0.typeUrl()).thenReturn(filterType); return ServerRoutingConfig.create( - Arrays.asList(new NamedFilterConfig("filter-config-name-0", f0)), - Collections.singletonList(virtualHost) + Arrays.asList(new NamedFilterConfig("filter-config-name-0", f0)), + new AtomicReference<>(ImmutableList.of(virtualHost)) ); } From c72bb4d29f01a64a522368442d68d2966deac785 Mon Sep 17 00:00:00 2001 From: yifeizhuang Date: Thu, 9 Sep 2021 13:19:40 -0700 Subject: [PATCH 71/82] fix header matcher for null value (#8504) --- .../java/io/grpc/xds/XdsNameResolver.java | 32 +------------------ .../java/io/grpc/xds/internal/Matchers.java | 9 ++---- .../io/grpc/xds/internal/MatcherTest.java | 13 ++++++++ 3 files changed, 17 insertions(+), 37 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/XdsNameResolver.java b/xds/src/main/java/io/grpc/xds/XdsNameResolver.java index 787336e4b54..b5e3ae813e1 100644 --- a/xds/src/main/java/io/grpc/xds/XdsNameResolver.java +++ b/xds/src/main/java/io/grpc/xds/XdsNameResolver.java @@ -594,7 +594,7 @@ static boolean matchRoute(RouteMatch routeMatch, String fullMethodName, return false; } for (HeaderMatcher headerMatcher : routeMatch.headerMatchers()) { - if (!matchHeader(headerMatcher, getHeaderValue(headers, headerMatcher.name()))) { + if (!headerMatcher.matches(getHeaderValue(headers, headerMatcher.name()))) { return false; } } @@ -615,36 +615,6 @@ private static boolean matchPath(PathMatcher pathMatcher, String fullMethodName) return pathMatcher.regEx().matches(fullMethodName); } - // TODO(zivy): consider reuse Matchers.HeaderMatcher.matches() - private static boolean matchHeader(HeaderMatcher headerMatcher, @Nullable String value) { - if (headerMatcher.present() != null) { - return (value == null) == headerMatcher.present().equals(headerMatcher.inverted()); - } - if (value == null) { - return false; - } - boolean baseMatch; - if (headerMatcher.exactValue() != null) { - baseMatch = headerMatcher.exactValue().equals(value); - } else if (headerMatcher.safeRegEx() != null) { - baseMatch = headerMatcher.safeRegEx().matches(value); - } else if (headerMatcher.range() != null) { - long numValue; - try { - numValue = Long.parseLong(value); - baseMatch = numValue >= headerMatcher.range().start() - && numValue <= headerMatcher.range().end(); - } catch (NumberFormatException ignored) { - baseMatch = false; - } - } else if (headerMatcher.prefix() != null) { - baseMatch = value.startsWith(headerMatcher.prefix()); - } else { - baseMatch = value.endsWith(headerMatcher.suffix()); - } - return baseMatch != headerMatcher.inverted(); - } - @Nullable private static String getHeaderValue(Metadata headers, String headerName) { if (headerName.endsWith(Metadata.BINARY_HEADER_SUFFIX)) { diff --git a/xds/src/main/java/io/grpc/xds/internal/Matchers.java b/xds/src/main/java/io/grpc/xds/internal/Matchers.java index 28ec8418297..3bf7b7723e2 100644 --- a/xds/src/main/java/io/grpc/xds/internal/Matchers.java +++ b/xds/src/main/java/io/grpc/xds/internal/Matchers.java @@ -117,13 +117,8 @@ private static HeaderMatcher create(String name, @Nullable String exactValue, /** Returns the matching result. */ public boolean matches(@Nullable String value) { - if (present() != null) { - return (value == null) == present().equals(inverted()); - } - // FIXME(zivy@): invert result for null value. - // https://github.com/envoyproxy/envoy/blob/0fae6970ddaf93f024908ba304bbd2b34e997a51/source/common/http/header_utility.cc#L130 if (value == null) { - return false; + return present() != null && present() == inverted(); } boolean baseMatch; if (exactValue() != null) { @@ -141,6 +136,8 @@ public boolean matches(@Nullable String value) { } } else if (prefix() != null) { baseMatch = value.startsWith(prefix()); + } else if (present() != null) { + baseMatch = present(); } else { baseMatch = value.endsWith(suffix()); } diff --git a/xds/src/test/java/io/grpc/xds/internal/MatcherTest.java b/xds/src/test/java/io/grpc/xds/internal/MatcherTest.java index 4fb4acc41f6..93a9b7087d6 100644 --- a/xds/src/test/java/io/grpc/xds/internal/MatcherTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/MatcherTest.java @@ -127,45 +127,58 @@ public void headerMatcher() { HeaderMatcher matcher = HeaderMatcher.forExactValue("version", "v1", false); assertThat(matcher.matches("v1")).isTrue(); assertThat(matcher.matches("v2")).isFalse(); + assertThat(matcher.matches(null)).isFalse(); matcher = HeaderMatcher.forExactValue("version", "v1", true); assertThat(matcher.matches("v1")).isFalse(); assertThat(matcher.matches( "v2")).isTrue(); + assertThat(matcher.matches(null)).isFalse(); matcher = HeaderMatcher.forPresent("version", true, false); assertThat(matcher.matches("any")).isTrue(); assertThat(matcher.matches(null)).isFalse(); matcher = HeaderMatcher.forPresent("version", true, true); assertThat(matcher.matches("version")).isFalse(); + assertThat(matcher.matches(null)).isTrue(); matcher = HeaderMatcher.forPresent("version", false, true); assertThat(matcher.matches("tag")).isTrue(); + assertThat(matcher.matches(null)).isFalse(); matcher = HeaderMatcher.forPresent("version", false, false); assertThat(matcher.matches("tag")).isFalse(); + assertThat(matcher.matches(null)).isTrue(); matcher = HeaderMatcher.forPrefix("version", "v2", false); assertThat(matcher.matches("v22")).isTrue(); + assertThat(matcher.matches(null)).isFalse(); matcher = HeaderMatcher.forPrefix("version", "v2", true); assertThat(matcher.matches("v22")).isFalse(); + assertThat(matcher.matches(null)).isFalse(); matcher = HeaderMatcher.forSuffix("version", "v1", false); assertThat(matcher.matches("xv1")).isTrue(); assertThat(matcher.matches("v1x")).isFalse(); + assertThat(matcher.matches(null)).isFalse(); matcher = HeaderMatcher.forSuffix("version", "v2", true); assertThat(matcher.matches("xv1")).isTrue(); assertThat(matcher.matches("1v2")).isFalse(); + assertThat(matcher.matches(null)).isFalse(); matcher = HeaderMatcher.forSafeRegEx("version", Pattern.compile("v2.*"), false); assertThat(matcher.matches("v2..")).isTrue(); assertThat(matcher.matches("v1")).isFalse(); + assertThat(matcher.matches(null)).isFalse(); matcher = HeaderMatcher.forSafeRegEx("version", Pattern.compile("v1\\..*"), true); assertThat(matcher.matches("v1.43")).isFalse(); assertThat(matcher.matches("v2")).isTrue(); + assertThat(matcher.matches(null)).isFalse(); matcher = HeaderMatcher.forRange("version", Range.create(8080L, 8090L), false); assertThat(matcher.matches("8080")).isTrue(); assertThat(matcher.matches("1")).isFalse(); + assertThat(matcher.matches(null)).isFalse(); matcher = HeaderMatcher.forRange("version", Range.create(8080L, 8090L), true); assertThat(matcher.matches("1")).isTrue(); assertThat(matcher.matches("8080")).isFalse(); + assertThat(matcher.matches(null)).isFalse(); } } From 17c2460ffc782d28b6677803f6188a18810b276b Mon Sep 17 00:00:00 2001 From: yifeizhuang Date: Fri, 10 Sep 2021 12:10:22 -0700 Subject: [PATCH 72/82] xds: add terminal http filter verification, remove lame route filter, add hcm as terminal network filter verification (#8342) (#8507) --- .../java/io/grpc/xds/ClientXdsClient.java | 79 +++++++----- xds/src/main/java/io/grpc/xds/LameFilter.java | 121 ------------------ .../java/io/grpc/xds/XdsNameResolver.java | 42 +----- .../io/grpc/xds/ClientXdsClientDataTest.java | 94 ++++++++++++-- .../io/grpc/xds/ClientXdsClientTestBase.java | 27 ++-- .../io/grpc/xds/ClientXdsClientV2Test.java | 5 + .../io/grpc/xds/ClientXdsClientV3Test.java | 21 +++ .../java/io/grpc/xds/XdsNameResolverTest.java | 24 ---- 8 files changed, 174 insertions(+), 239 deletions(-) delete mode 100644 xds/src/main/java/io/grpc/xds/LameFilter.java diff --git a/xds/src/main/java/io/grpc/xds/ClientXdsClient.java b/xds/src/main/java/io/grpc/xds/ClientXdsClient.java index 21cf78b1269..693848897c6 100644 --- a/xds/src/main/java/io/grpc/xds/ClientXdsClient.java +++ b/xds/src/main/java/io/grpc/xds/ClientXdsClient.java @@ -334,41 +334,32 @@ static FilterChain parseFilterChain( TlsContextManager tlsContextManager, FilterRegistry filterRegistry, Set uniqueSet, Set certProviderInstances, boolean parseHttpFilters) throws ResourceInvalidException { - io.grpc.xds.HttpConnectionManager httpConnectionManager = null; - HashSet uniqueNames = new HashSet<>(); - for (io.envoyproxy.envoy.config.listener.v3.Filter filter : proto.getFiltersList()) { - if (!uniqueNames.add(filter.getName())) { - throw new ResourceInvalidException( - "FilterChain " + proto.getName() + " with duplicated filter: " + filter.getName()); - } - if (!filter.hasTypedConfig()) { - throw new ResourceInvalidException( - "FilterChain " + proto.getName() + " contains filter " + filter.getName() - + " without typed_config"); - } - Any any = filter.getTypedConfig(); - // HttpConnectionManager is the only supported network filter at the moment. - if (!any.getTypeUrl().equals(TYPE_URL_HTTP_CONNECTION_MANAGER)) { - throw new ResourceInvalidException( - "FilterChain " + proto.getName() + " contains filter " + filter.getName() - + " with unsupported typed_config type " + any.getTypeUrl()); - } - if (httpConnectionManager == null) { - HttpConnectionManager hcmProto; - try { - hcmProto = any.unpack(HttpConnectionManager.class); - } catch (InvalidProtocolBufferException e) { - throw new ResourceInvalidException("FilterChain " + proto.getName() + " with filter " - + filter.getName() + " failed to unpack message", e); - } - httpConnectionManager = parseHttpConnectionManager( - hcmProto, rdsResources, filterRegistry, parseHttpFilters, false /* isForClient */); - } - } - if (httpConnectionManager == null) { + if (proto.getFiltersCount() != 1) { throw new ResourceInvalidException("FilterChain " + proto.getName() - + " missing required HttpConnectionManager filter"); + + " should contain exact one HttpConnectionManager filter"); } + io.envoyproxy.envoy.config.listener.v3.Filter filter = proto.getFiltersList().get(0); + if (!filter.hasTypedConfig()) { + throw new ResourceInvalidException( + "FilterChain " + proto.getName() + " contains filter " + filter.getName() + + " without typed_config"); + } + Any any = filter.getTypedConfig(); + // HttpConnectionManager is the only supported network filter at the moment. + if (!any.getTypeUrl().equals(TYPE_URL_HTTP_CONNECTION_MANAGER)) { + throw new ResourceInvalidException( + "FilterChain " + proto.getName() + " contains filter " + filter.getName() + + " with unsupported typed_config type " + any.getTypeUrl()); + } + HttpConnectionManager hcmProto; + try { + hcmProto = any.unpack(HttpConnectionManager.class); + } catch (InvalidProtocolBufferException e) { + throw new ResourceInvalidException("FilterChain " + proto.getName() + " with filter " + + filter.getName() + " failed to unpack message", e); + } + io.grpc.xds.HttpConnectionManager httpConnectionManager = parseHttpConnectionManager( + hcmProto, rdsResources, filterRegistry, parseHttpFilters, false /* isForClient */); EnvoyServerProtoData.DownstreamTlsContext downstreamTlsContext = null; if (proto.hasTransportSocket()) { @@ -762,10 +753,14 @@ static io.grpc.xds.HttpConnectionManager parseHttpConnectionManager( // Parse http filters. List filterConfigs = null; if (parseHttpFilter) { + if (proto.getHttpFiltersList().isEmpty()) { + throw new ResourceInvalidException("Missing HttpFilter in HttpConnectionManager."); + } filterConfigs = new ArrayList<>(); Set names = new HashSet<>(); - for (io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.HttpFilter - httpFilter : proto.getHttpFiltersList()) { + for (int i = 0; i < proto.getHttpFiltersCount(); i++) { + io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.HttpFilter + httpFilter = proto.getHttpFiltersList().get(i); String filterName = httpFilter.getName(); if (!names.add(filterName)) { throw new ResourceInvalidException( @@ -773,6 +768,11 @@ static io.grpc.xds.HttpConnectionManager parseHttpConnectionManager( } StructOrError filterConfig = parseHttpFilter(httpFilter, filterRegistry, isForClient); + if ((i == proto.getHttpFiltersCount() - 1) + && (filterConfig == null || !isTerminalFilter(filterConfig.struct))) { + throw new ResourceInvalidException("The last HttpFilter must be a terminal filter: " + + filterName); + } if (filterConfig == null) { continue; } @@ -781,6 +781,10 @@ static io.grpc.xds.HttpConnectionManager parseHttpConnectionManager( "HttpConnectionManager contains invalid HttpFilter: " + filterConfig.getErrorDetail()); } + if ((i < proto.getHttpFiltersCount() - 1) && isTerminalFilter(filterConfig.getStruct())) { + throw new ResourceInvalidException("A terminal HttpFilter must be the last filter: " + + filterName); + } filterConfigs.add(new NamedFilterConfig(filterName, filterConfig.struct)); } } @@ -821,6 +825,11 @@ static io.grpc.xds.HttpConnectionManager parseHttpConnectionManager( "HttpConnectionManager neither has inlined route_config nor RDS"); } + // hard-coded: currently router config is the only terminal filter. + private static boolean isTerminalFilter(FilterConfig filterConfig) { + return RouterFilter.ROUTER_CONFIG.equals(filterConfig); + } + @VisibleForTesting @Nullable // Returns null if the filter is optional but not supported. static StructOrError parseHttpFilter( diff --git a/xds/src/main/java/io/grpc/xds/LameFilter.java b/xds/src/main/java/io/grpc/xds/LameFilter.java deleted file mode 100644 index 4dd1d3c96ed..00000000000 --- a/xds/src/main/java/io/grpc/xds/LameFilter.java +++ /dev/null @@ -1,121 +0,0 @@ -/* - * Copyright 2021 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.xds; - -import com.google.common.util.concurrent.MoreExecutors; -import com.google.protobuf.Message; -import io.grpc.CallOptions; -import io.grpc.Channel; -import io.grpc.ClientCall; -import io.grpc.ClientInterceptor; -import io.grpc.Context; -import io.grpc.LoadBalancer.PickSubchannelArgs; -import io.grpc.Metadata; -import io.grpc.MethodDescriptor; -import io.grpc.Status; -import io.grpc.xds.Filter.ClientInterceptorBuilder; -import java.util.concurrent.Executor; -import java.util.concurrent.ScheduledExecutorService; -import javax.annotation.Nullable; - -/** - * A filter that fails all RPCs. To be added to the end of filter chain if RouterFilter is absent. - */ -enum LameFilter implements Filter, ClientInterceptorBuilder { - INSTANCE; - - static final FilterConfig LAME_CONFIG = new FilterConfig() { - @Override - public String typeUrl() { - throw new UnsupportedOperationException("shouldn't be called"); - } - - @Override - public String toString() { - return "LAME_CONFIG"; - } - }; - - @Override - public String[] typeUrls() { - return new String[0]; - } - - @Override - public ConfigOrError parseFilterConfig(Message rawProtoMessage) { - throw new UnsupportedOperationException(); - } - - @Override - public ConfigOrError parseFilterConfigOverride(Message rawProtoMessage) { - throw new UnsupportedOperationException(); - } - - @Nullable - @Override - public ClientInterceptor buildClientInterceptor( - FilterConfig config, @Nullable FilterConfig overrideConfig, PickSubchannelArgs args, - ScheduledExecutorService scheduler) { - class LameInterceptor implements ClientInterceptor { - - @Override - public ClientCall interceptCall( - MethodDescriptor method, final CallOptions callOptions, Channel next) { - final Context context = Context.current(); - return new ClientCall() { - @Override - public void start(final Listener listener, Metadata headers) { - Executor callExecutor = callOptions.getExecutor(); - if (callExecutor == null) { // This should never happen in practice because - // ManagedChannelImpl.ConfigSelectingClientCall always provides CallOptions with - // a callExecutor. - // TODO(https://github.com/grpc/grpc-java/issues/7868) - callExecutor = MoreExecutors.directExecutor(); - } - callExecutor.execute( - new Runnable() { - @Override - public void run() { - Context previous = context.attach(); - try { - listener.onClose( - Status.UNAVAILABLE.withDescription("No router filter"), new Metadata()); - } finally { - context.detach(previous); - } - } - }); - } - - @Override - public void request(int numMessages) {} - - @Override - public void cancel(@Nullable String message, @Nullable Throwable cause) {} - - @Override - public void halfClose() {} - - @Override - public void sendMessage(ReqT message) {} - }; - } - } - - return new LameInterceptor(); - } -} diff --git a/xds/src/main/java/io/grpc/xds/XdsNameResolver.java b/xds/src/main/java/io/grpc/xds/XdsNameResolver.java index b5e3ae813e1..e2905dee26e 100644 --- a/xds/src/main/java/io/grpc/xds/XdsNameResolver.java +++ b/xds/src/main/java/io/grpc/xds/XdsNameResolver.java @@ -23,7 +23,6 @@ import com.google.common.base.Joiner; import com.google.common.base.Strings; import com.google.common.collect.ImmutableMap; -import com.google.common.collect.Iterables; import com.google.common.collect.Sets; import com.google.gson.Gson; import com.google.protobuf.util.Durations; @@ -95,8 +94,6 @@ final class XdsNameResolver extends NameResolver { CallOptions.Key.create("io.grpc.xds.CLUSTER_SELECTION_KEY"); static final CallOptions.Key RPC_HASH_KEY = CallOptions.Key.create("io.grpc.xds.RPC_HASH_KEY"); - private static final NamedFilterConfig LAME_FILTER = - new NamedFilterConfig(null, LameFilter.LAME_CONFIG); @VisibleForTesting static boolean enableTimeout = Strings.isNullOrEmpty(System.getenv("GRPC_XDS_EXPERIMENTAL_ENABLE_TIMEOUT")) @@ -374,10 +371,6 @@ public Result selectConfig(PickSubchannelArgs args) { do { routingCfg = routingConfig; selectedOverrideConfigs = new HashMap<>(routingCfg.virtualHostOverrideConfig); - if (routingCfg.filterChain != null - && Iterables.getLast(routingCfg.filterChain).equals(LAME_FILTER)) { - break; - } for (Route route : routingCfg.routes) { if (matchRoute(route.routeMatch(), "/" + args.getMethodDescriptor().getFullMethodName(), headers, random)) { @@ -442,12 +435,7 @@ public Result selectConfig(PickSubchannelArgs args) { if (routingCfg.filterChain != null) { for (NamedFilterConfig namedFilter : routingCfg.filterChain) { FilterConfig filterConfig = namedFilter.filterConfig; - Filter filter; - if (namedFilter.equals(LAME_FILTER)) { - filter = LameFilter.INSTANCE; - } else { - filter = filterRegistry.get(filterConfig.typeUrl()); - } + Filter filter = filterRegistry.get(filterConfig.typeUrl()); if (filter instanceof ClientInterceptorBuilder) { ClientInterceptor interceptor = ((ClientInterceptorBuilder) filter) .buildClientInterceptor( @@ -458,12 +446,6 @@ public Result selectConfig(PickSubchannelArgs args) { } } } - if (Iterables.getLast(routingCfg.filterChain).equals(LAME_FILTER)) { - return Result.newBuilder() - .setConfig(config) - .setInterceptor(combineInterceptors(filterInterceptors)) - .build(); - } } final String finalCluster = cluster; final long hash = generateHash(selectedRoute.routeAction().hashPolicies(), headers); @@ -724,27 +706,7 @@ private void updateRoutes(List virtualHosts, long httpMaxStreamDura return; } - // A router filter is required for request routing. For backward compatibility, routing - // is always enabled for gRPC clients without HttpFilter support. List routes = virtualHost.routes(); - List filterChain = null; - if (filterConfigs != null) { - boolean hasRouter = false; - filterChain = new ArrayList<>(filterConfigs.size()); - for (NamedFilterConfig namedFilter : filterConfigs) { - filterChain.add(namedFilter); - if (namedFilter.filterConfig.equals(RouterFilter.ROUTER_CONFIG)) { - hasRouter = true; - break; - } - } - if (!hasRouter) { - // Fail all RPCs if a router filter is not present. Reference counts for all currently - // selectable clusters should be reclaimed. - filterChain.add(LAME_FILTER); - routes = Collections.emptyList(); - } - } // Populate all clusters to which requests can be routed to through the virtual host. Set clusters = new HashSet<>(); @@ -785,7 +747,7 @@ private void updateRoutes(List virtualHosts, long httpMaxStreamDura // selectable. routingConfig = new RoutingConfig( - httpMaxStreamDurationNano, routes, filterChain, + httpMaxStreamDurationNano, routes, filterConfigs, virtualHost.filterConfigOverrides()); shouldUpdateResult = false; for (String cluster : deletedClusters) { diff --git a/xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java b/xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java index 80cd2a8046e..807f512b0f4 100644 --- a/xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java +++ b/xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java @@ -129,7 +129,7 @@ public class ClientXdsClientDataTest { @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 @Rule public final ExpectedException thrown = ExpectedException.none(); - private final FilterRegistry filterRegistry = FilterRegistry.newRegistry(); + private final FilterRegistry filterRegistry = FilterRegistry.getDefaultRegistry(); private boolean originalEnableRetry; @Before @@ -1132,6 +1132,9 @@ public void parseHttpConnectionManager_duplicateHttpFilters() throws ResourceInv HttpFilter.newBuilder().setName("envoy.filter.foo").setIsOptional(true)) .addHttpFilters( HttpFilter.newBuilder().setName("envoy.filter.foo").setIsOptional(true)) + .addHttpFilters( + HttpFilter.newBuilder().setName("terminal").setTypedConfig( + Any.pack(Router.newBuilder().build())).setIsOptional(true)) .build(); thrown.expect(ResourceInvalidException.class); thrown.expectMessage("HttpConnectionManager contains duplicate HttpFilter: envoy.filter.foo"); @@ -1140,6 +1143,70 @@ public void parseHttpConnectionManager_duplicateHttpFilters() throws ResourceInv true /* does not matter */); } + @Test + public void parseHttpConnectionManager_lastNotTerminal() throws ResourceInvalidException { + filterRegistry.register(FaultFilter.INSTANCE); + HttpConnectionManager hcm = + HttpConnectionManager.newBuilder() + .addHttpFilters( + HttpFilter.newBuilder().setName("envoy.filter.foo").setIsOptional(true)) + .addHttpFilters( + HttpFilter.newBuilder().setName("envoy.filter.bar").setIsOptional(true) + .setTypedConfig(Any.pack(HTTPFault.newBuilder().build()))) + .build(); + thrown.expect(ResourceInvalidException.class); + thrown.expectMessage("The last HttpFilter must be a terminal filter: envoy.filter.bar"); + ClientXdsClient.parseHttpConnectionManager( + hcm, new HashSet(), filterRegistry, true /* parseHttpFilter */, + true /* does not matter */); + } + + @Test + public void parseHttpConnectionManager_terminalNotLast() throws ResourceInvalidException { + filterRegistry.register(RouterFilter.INSTANCE); + HttpConnectionManager hcm = + HttpConnectionManager.newBuilder() + .addHttpFilters( + HttpFilter.newBuilder().setName("terminal").setTypedConfig( + Any.pack(Router.newBuilder().build())).setIsOptional(true)) + .addHttpFilters( + HttpFilter.newBuilder().setName("envoy.filter.foo").setIsOptional(true)) + .build(); + thrown.expect(ResourceInvalidException.class); + thrown.expectMessage("A terminal HttpFilter must be the last filter: terminal"); + ClientXdsClient.parseHttpConnectionManager( + hcm, new HashSet(), filterRegistry, true /* parseHttpFilter */, + true); + } + + @Test + public void parseHttpConnectionManager_unknownFilters() throws ResourceInvalidException { + HttpConnectionManager hcm = + HttpConnectionManager.newBuilder() + .addHttpFilters( + HttpFilter.newBuilder().setName("envoy.filter.foo").setIsOptional(true)) + .addHttpFilters( + HttpFilter.newBuilder().setName("envoy.filter.bar").setIsOptional(true)) + .build(); + thrown.expect(ResourceInvalidException.class); + thrown.expectMessage("The last HttpFilter must be a terminal filter: envoy.filter.bar"); + ClientXdsClient.parseHttpConnectionManager( + hcm, new HashSet(), filterRegistry, true /* parseHttpFilter */, + true /* does not matter */); + } + + @Test + public void parseHttpConnectionManager_emptyFilters() throws ResourceInvalidException { + HttpConnectionManager hcm = + HttpConnectionManager.newBuilder() + .build(); + thrown.expect(ResourceInvalidException.class); + thrown.expectMessage("Missing HttpFilter in HttpConnectionManager."); + ClientXdsClient.parseHttpConnectionManager( + hcm, new HashSet(), filterRegistry, true /* parseHttpFilter */, + true /* does not matter */); + } + @Test public void parseCluster_ringHashLbPolicy_defaultLbConfig() throws ResourceInvalidException { Cluster cluster = Cluster.newBuilder() @@ -1280,7 +1347,8 @@ public void parseServerSideListener_useOriginalDst() throws ResourceInvalidExcep @Test public void parseServerSideListener_nonUniqueFilterChainMatch() throws ResourceInvalidException { Filter filter1 = buildHttpConnectionManagerFilter( - HttpFilter.newBuilder().setName("http-filter-1").setIsOptional(true).build()); + HttpFilter.newBuilder().setName("http-filter-1").setTypedConfig( + Any.pack(Router.newBuilder().build())).setIsOptional(true).build()); FilterChainMatch filterChainMatch1 = FilterChainMatch.newBuilder() .addAllSourcePorts(Arrays.asList(80, 8080)) @@ -1296,7 +1364,8 @@ public void parseServerSideListener_nonUniqueFilterChainMatch() throws ResourceI .addFilters(filter1) .build(); Filter filter2 = buildHttpConnectionManagerFilter( - HttpFilter.newBuilder().setName("http-filter-2").setIsOptional(true).build()); + HttpFilter.newBuilder().setName("http-filter-2").setTypedConfig( + Any.pack(Router.newBuilder().build())).setIsOptional(true).build()); FilterChainMatch filterChainMatch2 = FilterChainMatch.newBuilder() .addAllSourcePorts(Arrays.asList(443, 8080)) @@ -1328,7 +1397,8 @@ public void parseServerSideListener_nonUniqueFilterChainMatch() throws ResourceI public void parseServerSideListener_nonUniqueFilterChainMatch_sameFilter() throws ResourceInvalidException { Filter filter1 = buildHttpConnectionManagerFilter( - HttpFilter.newBuilder().setName("http-filter-1").setIsOptional(true).build()); + HttpFilter.newBuilder().setName("http-filter-1").setTypedConfig( + Any.pack(Router.newBuilder().build())).setIsOptional(true).build()); FilterChainMatch filterChainMatch1 = FilterChainMatch.newBuilder() .addAllSourcePorts(Arrays.asList(80, 8080)) @@ -1343,7 +1413,8 @@ public void parseServerSideListener_nonUniqueFilterChainMatch_sameFilter() .addFilters(filter1) .build(); Filter filter2 = buildHttpConnectionManagerFilter( - HttpFilter.newBuilder().setName("http-filter-2").setIsOptional(true).build()); + HttpFilter.newBuilder().setName("http-filter-2").setTypedConfig( + Any.pack(Router.newBuilder().build())).setIsOptional(true).build()); FilterChainMatch filterChainMatch2 = FilterChainMatch.newBuilder() .addAllSourcePorts(Arrays.asList(443, 8080)) @@ -1374,7 +1445,8 @@ public void parseServerSideListener_nonUniqueFilterChainMatch_sameFilter() @Test public void parseServerSideListener_uniqueFilterChainMatch() throws ResourceInvalidException { Filter filter1 = buildHttpConnectionManagerFilter( - HttpFilter.newBuilder().setName("http-filter-1").setIsOptional(true).build()); + HttpFilter.newBuilder().setName("http-filter-1").setTypedConfig( + Any.pack(Router.newBuilder().build())).setIsOptional(true).build()); FilterChainMatch filterChainMatch1 = FilterChainMatch.newBuilder() .addAllSourcePorts(Arrays.asList(80, 8080)) @@ -1391,7 +1463,8 @@ public void parseServerSideListener_uniqueFilterChainMatch() throws ResourceInva .addFilters(filter1) .build(); Filter filter2 = buildHttpConnectionManagerFilter( - HttpFilter.newBuilder().setName("http-filter-2").setIsOptional(true).build()); + HttpFilter.newBuilder().setName("http-filter-2").setTypedConfig( + Any.pack(Router.newBuilder().build())).setIsOptional(true).build()); FilterChainMatch filterChainMatch2 = FilterChainMatch.newBuilder() .addAllSourcePorts(Arrays.asList(443, 8080)) @@ -1428,7 +1501,7 @@ public void parseFilterChain_noHcm() throws ResourceInvalidException { .build(); thrown.expect(ResourceInvalidException.class); thrown.expectMessage( - "FilterChain filter-chain-foo missing required HttpConnectionManager filter"); + "FilterChain filter-chain-foo should contain exact one HttpConnectionManager filter"); ClientXdsClient.parseFilterChain( filterChain, new HashSet(), null, filterRegistry, null, null, true /* does not matter */); @@ -1447,7 +1520,7 @@ public void parseFilterChain_duplicateFilter() throws ResourceInvalidException { .build(); thrown.expect(ResourceInvalidException.class); thrown.expectMessage( - "FilterChain filter-chain-foo with duplicated filter: envoy.http_connection_manager"); + "FilterChain filter-chain-foo should contain exact one HttpConnectionManager filter"); ClientXdsClient.parseFilterChain( filterChain, new HashSet(), null, filterRegistry, null, null, true /* does not matter */); @@ -1504,6 +1577,7 @@ public void parseFilterChain_noName_generatedUuid() throws ResourceInvalidExcept HttpFilter.newBuilder() .setName("http-filter-foo") .setIsOptional(true) + .setTypedConfig(Any.pack(Router.newBuilder().build())) .build())) .build(); FilterChain filterChain2 = @@ -1512,6 +1586,7 @@ public void parseFilterChain_noName_generatedUuid() throws ResourceInvalidExcept .addFilters(buildHttpConnectionManagerFilter( HttpFilter.newBuilder() .setName("http-filter-bar") + .setTypedConfig(Any.pack(Router.newBuilder().build())) .setIsOptional(true) .build())) .build(); @@ -1525,7 +1600,6 @@ public void parseFilterChain_noName_generatedUuid() throws ResourceInvalidExcept assertThat(parsedFilterChain1.getName()).isNotEqualTo(parsedFilterChain2.getName()); } - @Test public void validateCommonTlsContext_tlsParams() throws ResourceInvalidException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() diff --git a/xds/src/test/java/io/grpc/xds/ClientXdsClientTestBase.java b/xds/src/test/java/io/grpc/xds/ClientXdsClientTestBase.java index 55bd6ba3e9d..a2a29ffe989 100644 --- a/xds/src/test/java/io/grpc/xds/ClientXdsClientTestBase.java +++ b/xds/src/test/java/io/grpc/xds/ClientXdsClientTestBase.java @@ -39,6 +39,7 @@ import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.Message; import io.envoyproxy.envoy.config.route.v3.FilterConfig; +import io.envoyproxy.envoy.extensions.filters.http.router.v3.Router; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateProviderPluginInstance; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; import io.grpc.BindableService; @@ -658,7 +659,8 @@ public void ldsResourceUpdate_withFaultInjection() { mf.buildHttpFaultTypedConfig( 1L, 2, "cluster1", ImmutableList.of(), 3, null, null, null), - false)))); + false), + mf.buildHttpFilter("terminal", Any.pack(Router.newBuilder().build()), true)))); call.sendResponse(LDS, listener, VERSION_1, "0000"); // Client sends an ACK LDS request. @@ -993,7 +995,7 @@ public void rdsResourcesDeletedByLdsTcpListener() { verifySubscribedResourcesMetadataSizes(1, 0, 1, 0); Message hcmFilter = mf.buildHttpConnectionManagerFilter( - RDS_RESOURCE, null, Collections.emptyList()); + RDS_RESOURCE, null, Collections.singletonList(mf.buildTerminalFilter())); Message downstreamTlsContext = CommonTlsContextTestsUtil.buildTestDownstreamTlsContext( "google-sds-config-default", "ROOTCA", false); Message filterChain = mf.buildFilterChain( @@ -1028,7 +1030,7 @@ public void rdsResourcesDeletedByLdsTcpListener() { null, mf.buildRouteConfiguration( "route-bar.googleapis.com", mf.buildOpaqueVirtualHosts(VHOST_SIZE)), - Collections.emptyList()); + Collections.singletonList(mf.buildTerminalFilter())); filterChain = mf.buildFilterChain( Collections.emptyList(), downstreamTlsContext, "envoy.transport_sockets.tls", hcmFilter); @@ -2203,7 +2205,8 @@ public void serverSideListenerFound() { ClientXdsClientTestBase.DiscoveryRpcCall call = startResourceWatcher(LDS, LISTENER_RESOURCE, ldsResourceWatcher); Message hcmFilter = mf.buildHttpConnectionManagerFilter( - "route-foo.googleapis.com", null, Collections.emptyList()); + "route-foo.googleapis.com", null, + Collections.singletonList(mf.buildTerminalFilter())); Message downstreamTlsContext = CommonTlsContextTestsUtil.buildTestDownstreamTlsContext( "google-sds-config-default", "ROOTCA", false); Message filterChain = mf.buildFilterChain( @@ -2226,7 +2229,8 @@ public void serverSideListenerFound() { assertThat(parsedFilterChain.getFilterChainMatch().getApplicationProtocols()).isEmpty(); assertThat(parsedFilterChain.getHttpConnectionManager().rdsName()) .isEqualTo("route-foo.googleapis.com"); - assertThat(parsedFilterChain.getHttpConnectionManager().httpFilterConfigs()).isEmpty(); + assertThat(parsedFilterChain.getHttpConnectionManager().httpFilterConfigs().get(0).filterConfig) + .isEqualTo(RouterFilter.ROUTER_CONFIG); assertThat(fakeClock.getPendingTasks(LDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)).isEmpty(); } @@ -2237,7 +2241,8 @@ public void serverSideListenerNotFound() { ClientXdsClientTestBase.DiscoveryRpcCall call = startResourceWatcher(LDS, LISTENER_RESOURCE, ldsResourceWatcher); Message hcmFilter = mf.buildHttpConnectionManagerFilter( - "route-foo.googleapis.com", null, Collections.emptyList()); + "route-foo.googleapis.com", null, + Collections.singletonList(mf.buildTerminalFilter())); Message downstreamTlsContext = CommonTlsContextTestsUtil.buildTestDownstreamTlsContext( "google-sds-config-default", "ROOTCA", false); Message filterChain = mf.buildFilterChain( @@ -2263,7 +2268,8 @@ public void serverSideListenerResponseErrorHandling_badDownstreamTlsContext() { ClientXdsClientTestBase.DiscoveryRpcCall call = startResourceWatcher(LDS, LISTENER_RESOURCE, ldsResourceWatcher); Message hcmFilter = mf.buildHttpConnectionManagerFilter( - "route-foo.googleapis.com", null, Collections.emptyList()); + "route-foo.googleapis.com", null, + Collections.singletonList(mf.buildTerminalFilter())); Message downstreamTlsContext = CommonTlsContextTestsUtil.buildTestDownstreamTlsContext( null, null,false); Message filterChain = mf.buildFilterChain( @@ -2286,7 +2292,8 @@ public void serverSideListenerResponseErrorHandling_badTransportSocketName() { ClientXdsClientTestBase.DiscoveryRpcCall call = startResourceWatcher(LDS, LISTENER_RESOURCE, ldsResourceWatcher); Message hcmFilter = mf.buildHttpConnectionManagerFilter( - "route-foo.googleapis.com", null, Collections.emptyList()); + "route-foo.googleapis.com", null, + Collections.singletonList(mf.buildTerminalFilter())); Message downstreamTlsContext = CommonTlsContextTestsUtil.buildTestDownstreamTlsContext( "cert1", "cert2",false); Message filterChain = mf.buildFilterChain( @@ -2385,7 +2392,7 @@ protected abstract static class MessageFactory { /** Throws {@link InvalidProtocolBufferException} on {@link Any#unpack(Class)}. */ protected static final Any FAILING_ANY = Any.newBuilder().setTypeUrl("fake").build(); - protected final Message buildListenerWithApiListener(String name, Message routeConfiguration) { + protected Message buildListenerWithApiListener(String name, Message routeConfiguration) { return buildListenerWithApiListener( name, routeConfiguration, Collections.emptyList()); } @@ -2470,5 +2477,7 @@ protected abstract Message buildListenerWithFilterChain( protected abstract Message buildHttpConnectionManagerFilter( @Nullable String rdsName, @Nullable Message routeConfig, List httpFilters); + + protected abstract Message buildTerminalFilter(); } } diff --git a/xds/src/test/java/io/grpc/xds/ClientXdsClientV2Test.java b/xds/src/test/java/io/grpc/xds/ClientXdsClientV2Test.java index 39f5d1a1a2e..1a69b6fc650 100644 --- a/xds/src/test/java/io/grpc/xds/ClientXdsClientV2Test.java +++ b/xds/src/test/java/io/grpc/xds/ClientXdsClientV2Test.java @@ -641,6 +641,11 @@ protected Message buildHttpConnectionManagerFilter( @Nullable String rdsName, @Nullable Message routeConfig, List httpFilters) { throw new UnsupportedOperationException(); } + + @Override + protected Message buildTerminalFilter() { + throw new UnsupportedOperationException(); + } } /** diff --git a/xds/src/test/java/io/grpc/xds/ClientXdsClientV3Test.java b/xds/src/test/java/io/grpc/xds/ClientXdsClientV3Test.java index dfd407ef016..6df36e1c31e 100644 --- a/xds/src/test/java/io/grpc/xds/ClientXdsClientV3Test.java +++ b/xds/src/test/java/io/grpc/xds/ClientXdsClientV3Test.java @@ -74,6 +74,7 @@ import io.envoyproxy.envoy.extensions.filters.http.fault.v3.FaultAbort; import io.envoyproxy.envoy.extensions.filters.http.fault.v3.FaultAbort.HeaderAbort; import io.envoyproxy.envoy.extensions.filters.http.fault.v3.HTTPFault; +import io.envoyproxy.envoy.extensions.filters.http.router.v3.Router; import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager; import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.HttpFilter; import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.Rds; @@ -276,6 +277,15 @@ protected Message buildListenerWithApiListener( .build(); } + @Override + protected Message buildListenerWithApiListener(String name, Message routeConfiguration) { + return buildListenerWithApiListener(name, routeConfiguration, Arrays.asList( + HttpFilter.newBuilder() + .setName("terminal") + .setTypedConfig(Any.pack(Router.newBuilder().build())).build() + )); + } + @Override protected Message buildListenerWithApiListenerForRds(String name, String rdsResourceName) { return Listener.newBuilder() @@ -291,6 +301,10 @@ protected Message buildListenerWithApiListenerForRds(String name, String rdsReso .setConfigSource( ConfigSource.newBuilder() .setAds(AggregatedConfigSource.getDefaultInstance()))) + .addHttpFilters( + HttpFilter.newBuilder() + .setName("terminal") + .setTypedConfig(Any.pack(Router.newBuilder().build()))) .build()))) .build(); } @@ -742,6 +756,13 @@ protected Message buildHttpConnectionManagerFilter( Any.pack(hcmBuilder.build(), "type.googleapis.com")) .build(); } + + @Override + protected Message buildTerminalFilter() { + return HttpFilter.newBuilder() + .setName("terminal") + .setTypedConfig(Any.pack(Router.newBuilder().build())).build(); + } } /** diff --git a/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java b/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java index 22d7302f207..7a8fec5f74a 100644 --- a/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java @@ -1384,20 +1384,6 @@ public long nanoTime() { + " Deadline exceeded after 0.000004000s. ")); } - @Test - public void resolved_withNoRouterFilter() { - resolver.start(mockListener); - FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); - xdsClient.deliverLdsUpdateWithNoRouterFilter(); - verify(mockListener).onResult(resolutionResultCaptor.capture()); - ResolutionResult result = resolutionResultCaptor.getValue(); - InternalConfigSelector configSelector = result.getAttributes().get(InternalConfigSelector.KEY); - ClientCall.Listener observer = startNewCall( - TestMethodDescriptors.voidMethod(), configSelector, Collections.emptyMap(), - CallOptions.DEFAULT); - verifyRpcFailed(observer, Status.UNAVAILABLE.withDescription("No router filter")); - } - @Test public void resolved_faultAbortAndDelayInLdsUpdateInLdsUpdate() { resolver.start(mockListener); @@ -1826,16 +1812,6 @@ void deliverLdsUpdateWithFaultInjection( 0L, Collections.singletonList(virtualHost), filterChain))); } - void deliverLdsUpdateWithNoRouterFilter() { - VirtualHost virtualHost = VirtualHost.create( - "virtual-host", - Collections.singletonList(AUTHORITY), - Collections.emptyList(), - Collections.emptyMap()); - ldsWatcher.onChanged(LdsUpdate.forApiListener(HttpConnectionManager.forVirtualHosts( - 0L, Collections.singletonList(virtualHost), ImmutableList.of()))); - } - void deliverLdsUpdateForRdsNameWithFaultInjection( final String rdsName, @Nullable FaultConfig httpFilterFaultConfig) { if (httpFilterFaultConfig == null) { From dbf92027b6559480d2ce796ef805d1db373359f7 Mon Sep 17 00:00:00 2001 From: ZHANG Dapeng Date: Mon, 13 Sep 2021 08:31:00 -0700 Subject: [PATCH 73/82] xds: populate envoy RetryPolicy with no retryOn to resolver (#8511) Envoy RetryPolicy with empty retryOn should not be ignored as no retry config when selecting Route config. Therefore, if xDS update for a route contains a RetryPolicy that has no RetryOn value that we support, but the virtual host config does, xds client should choose the Envoy RetryPolicy from the route (even with no RetryOn), rather than choosing the one from virtual host, and try to convert it into grpc RetryPolicy, and end up with no retry. --- xds/src/main/java/io/grpc/xds/ClientXdsClient.java | 11 ++++------- xds/src/main/java/io/grpc/xds/XdsNameResolver.java | 5 +++-- .../java/io/grpc/xds/ClientXdsClientDataTest.java | 3 ++- .../test/java/io/grpc/xds/XdsNameResolverTest.java | 14 +++++++++++++- 4 files changed, 22 insertions(+), 11 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/ClientXdsClient.java b/xds/src/main/java/io/grpc/xds/ClientXdsClient.java index 693848897c6..f79b0c4e9f6 100644 --- a/xds/src/main/java/io/grpc/xds/ClientXdsClient.java +++ b/xds/src/main/java/io/grpc/xds/ClientXdsClient.java @@ -1281,13 +1281,10 @@ private static StructOrError parseRetryPolicy( retryableStatusCodesBuilder.add(code); } List retryableStatusCodes = retryableStatusCodesBuilder.build(); - if (!retryableStatusCodes.isEmpty()) { - return StructOrError.fromStruct( - RetryPolicy.create( - maxAttempts, retryableStatusCodes, initialBackoff, maxBackoff, - /* perAttemptRecvTimeout= */ null)); - } - return null; + return StructOrError.fromStruct( + RetryPolicy.create( + maxAttempts, retryableStatusCodes, initialBackoff, maxBackoff, + /* perAttemptRecvTimeout= */ null)); } @VisibleForTesting diff --git a/xds/src/main/java/io/grpc/xds/XdsNameResolver.java b/xds/src/main/java/io/grpc/xds/XdsNameResolver.java index e2905dee26e..4cd52c8b3f9 100644 --- a/xds/src/main/java/io/grpc/xds/XdsNameResolver.java +++ b/xds/src/main/java/io/grpc/xds/XdsNameResolver.java @@ -182,13 +182,14 @@ public void shutdown() { @VisibleForTesting static Map generateServiceConfigWithMethodConfig( @Nullable Long timeoutNano, @Nullable RetryPolicy retryPolicy) { - if (timeoutNano == null && retryPolicy == null) { + if (timeoutNano == null + && (retryPolicy == null || retryPolicy.retryableStatusCodes().isEmpty())) { return Collections.emptyMap(); } ImmutableMap.Builder methodConfig = ImmutableMap.builder(); methodConfig.put( "name", Collections.singletonList(Collections.emptyMap())); - if (retryPolicy != null) { + if (retryPolicy != null && !retryPolicy.retryableStatusCodes().isEmpty()) { ImmutableMap.Builder rawRetryPolicy = ImmutableMap.builder(); rawRetryPolicy.put("maxAttempts", (double) retryPolicy.maxAttempts()); rawRetryPolicy.put("initialBackoff", Durations.toString(retryPolicy.initialBackoff())); diff --git a/xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java b/xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java index 807f512b0f4..876615d0b39 100644 --- a/xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java +++ b/xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java @@ -552,7 +552,8 @@ public void parseRouteAction_withRetryPolicy() { .setRetryPolicy(builder.build()) .build(); struct = ClientXdsClient.parseRouteAction(proto, filterRegistry, false); - assertThat(struct.getStruct().retryPolicy()).isNull(); + assertThat(struct.getStruct().retryPolicy()).isNotNull(); + assertThat(struct.getStruct().retryPolicy().retryableStatusCodes()).isEmpty(); // base_interval unset builder diff --git a/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java b/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java index 7a8fec5f74a..babaa2b3034 100644 --- a/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java @@ -988,6 +988,8 @@ public void generateServiceConfig_forPerMethodConfig() throws IOException { RetryPolicy retryPolicy = RetryPolicy.create( 4, ImmutableList.of(Code.UNAVAILABLE, Code.CANCELLED), Durations.fromMillis(100), Durations.fromMillis(200), null); + RetryPolicy retryPolicyWithEmptyStatusCodes = RetryPolicy.create( + 4, ImmutableList.of(), Durations.fromMillis(100), Durations.fromMillis(200), null); // timeout only String expectedServiceConfigJson = "{\n" @@ -1001,6 +1003,11 @@ public void generateServiceConfig_forPerMethodConfig() throws IOException { assertThat(XdsNameResolver.generateServiceConfigWithMethodConfig(timeoutNano, null)) .isEqualTo(expectedServiceConfig); + // timeout and retry with empty retriable status codes + assertThat(XdsNameResolver.generateServiceConfigWithMethodConfig( + timeoutNano, retryPolicyWithEmptyStatusCodes)) + .isEqualTo(expectedServiceConfig); + // retry only expectedServiceConfigJson = "{\n" + " \"methodConfig\": [{\n" @@ -1021,6 +1028,7 @@ public void generateServiceConfig_forPerMethodConfig() throws IOException { assertThat(XdsNameResolver.generateServiceConfigWithMethodConfig(null, retryPolicy)) .isEqualTo(expectedServiceConfig); + // timeout and retry expectedServiceConfigJson = "{\n" + " \"methodConfig\": [{\n" @@ -1043,12 +1051,16 @@ public void generateServiceConfig_forPerMethodConfig() throws IOException { .isEqualTo(expectedServiceConfig); // no timeout and no retry - // timeout and retry expectedServiceConfigJson = "{}"; expectedServiceConfig = (Map) JsonParser.parse(expectedServiceConfigJson); assertThat(XdsNameResolver.generateServiceConfigWithMethodConfig(null, null)) .isEqualTo(expectedServiceConfig); + + // retry with emtry retriable status codes only + assertThat(XdsNameResolver.generateServiceConfigWithMethodConfig( + null, retryPolicyWithEmptyStatusCodes)) + .isEqualTo(expectedServiceConfig); } @Test From 0741c4e328cc5f36e8a0aff0be7e75867d536c0f Mon Sep 17 00:00:00 2001 From: ZHANG Dapeng Date: Sat, 11 Sep 2021 21:57:47 -0700 Subject: [PATCH 74/82] xds: apply valid resources while NACKing update (#8506) Implementing [gRFC A46](https://github.com/grpc/proposal/pull/260) --- .../java/io/grpc/xds/ClientXdsClient.java | 135 +++++------ .../io/grpc/xds/ClientXdsClientTestBase.java | 209 +++++++++++++++--- 2 files changed, 248 insertions(+), 96 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/ClientXdsClient.java b/xds/src/main/java/io/grpc/xds/ClientXdsClient.java index f79b0c4e9f6..d490c9861b9 100644 --- a/xds/src/main/java/io/grpc/xds/ClientXdsClient.java +++ b/xds/src/main/java/io/grpc/xds/ClientXdsClient.java @@ -190,6 +190,7 @@ final class ClientXdsClient extends AbstractXdsClient { protected void handleLdsResponse(String versionInfo, List resources, String nonce) { Map parsedResources = new HashMap<>(resources.size()); Set unpackedResources = new HashSet<>(resources.size()); + Set invalidResources = new HashSet<>(); List errors = new ArrayList<>(); Set retainedRdsResources = new HashSet<>(); @@ -222,6 +223,7 @@ protected void handleLdsResponse(String versionInfo, List resources, String } catch (ResourceInvalidException e) { errors.add( "LDS response Listener '" + listenerName + "' validation error: " + e.getMessage()); + invalidResources.add(listenerName); continue; } @@ -231,19 +233,9 @@ protected void handleLdsResponse(String versionInfo, List resources, String getLogger().log(XdsLogLevel.INFO, "Received LDS Response version {0} nonce {1}. Parsed resources: {2}", versionInfo, nonce, unpackedResources); - - if (!errors.isEmpty()) { - handleResourcesRejected(ResourceType.LDS, unpackedResources, versionInfo, nonce, errors); - return; - } - - handleResourcesAccepted(ResourceType.LDS, parsedResources, versionInfo, nonce); - for (String resource : rdsResourceSubscribers.keySet()) { - if (!retainedRdsResources.contains(resource)) { - ResourceSubscriber subscriber = rdsResourceSubscribers.get(resource); - subscriber.onAbsent(); - } - } + handleResourceUpdate( + ResourceType.LDS, parsedResources, invalidResources, retainedRdsResources, versionInfo, + nonce, errors); } private LdsUpdate processClientSideListener( @@ -1310,6 +1302,7 @@ static StructOrError parseClusterWeight( protected void handleRdsResponse(String versionInfo, List resources, String nonce) { Map parsedResources = new HashMap<>(resources.size()); Set unpackedResources = new HashSet<>(resources.size()); + Set invalidResources = new HashSet<>(); List errors = new ArrayList<>(); for (int i = 0; i < resources.size(); i++) { @@ -1337,6 +1330,7 @@ protected void handleRdsResponse(String versionInfo, List resources, String errors.add( "RDS response RouteConfiguration '" + routeConfigName + "' validation error: " + e .getMessage()); + invalidResources.add(routeConfigName); continue; } @@ -1345,12 +1339,9 @@ protected void handleRdsResponse(String versionInfo, List resources, String getLogger().log(XdsLogLevel.INFO, "Received RDS Response version {0} nonce {1}. Parsed resources: {2}", versionInfo, nonce, unpackedResources); - - if (!errors.isEmpty()) { - handleResourcesRejected(ResourceType.RDS, unpackedResources, versionInfo, nonce, errors); - } else { - handleResourcesAccepted(ResourceType.RDS, parsedResources, versionInfo, nonce); - } + handleResourceUpdate( + ResourceType.RDS, parsedResources, invalidResources, Collections.emptySet(), + versionInfo, nonce, errors); } private static RdsUpdate processRouteConfiguration( @@ -1374,6 +1365,7 @@ private static RdsUpdate processRouteConfiguration( protected void handleCdsResponse(String versionInfo, List resources, String nonce) { Map parsedResources = new HashMap<>(resources.size()); Set unpackedResources = new HashSet<>(resources.size()); + Set invalidResources = new HashSet<>(); List errors = new ArrayList<>(); Set retainedEdsResources = new HashSet<>(); @@ -1410,6 +1402,7 @@ protected void handleCdsResponse(String versionInfo, List resources, String } catch (ResourceInvalidException e) { errors.add( "CDS response Cluster '" + clusterName + "' validation error: " + e.getMessage()); + invalidResources.add(clusterName); continue; } parsedResources.put(clusterName, new ParsedResource(cdsUpdate, resource)); @@ -1417,21 +1410,9 @@ protected void handleCdsResponse(String versionInfo, List resources, String getLogger().log(XdsLogLevel.INFO, "Received CDS Response version {0} nonce {1}. Parsed resources: {2}", versionInfo, nonce, unpackedResources); - - if (!errors.isEmpty()) { - handleResourcesRejected(ResourceType.CDS, unpackedResources, versionInfo, nonce, errors); - return; - } - - handleResourcesAccepted(ResourceType.CDS, parsedResources, versionInfo, nonce); - // CDS responses represents the state of the world, EDS resources not referenced in CDS - // resources should be deleted. - for (String resource : edsResourceSubscribers.keySet()) { - ResourceSubscriber subscriber = edsResourceSubscribers.get(resource); - if (!retainedEdsResources.contains(resource)) { - subscriber.onAbsent(); - } - } + handleResourceUpdate( + ResourceType.CDS, parsedResources, invalidResources, retainedEdsResources, versionInfo, + nonce, errors); } @VisibleForTesting @@ -1612,6 +1593,7 @@ private static StructOrError parseNonAggregateCluster( protected void handleEdsResponse(String versionInfo, List resources, String nonce) { Map parsedResources = new HashMap<>(resources.size()); Set unpackedResources = new HashSet<>(resources.size()); + Set invalidResources = new HashSet<>(); List errors = new ArrayList<>(); for (int i = 0; i < resources.size(); i++) { @@ -1646,16 +1628,17 @@ protected void handleEdsResponse(String versionInfo, List resources, String } catch (ResourceInvalidException e) { errors.add("EDS response ClusterLoadAssignment '" + clusterName + "' validation error: " + e.getMessage()); + invalidResources.add(clusterName); continue; } parsedResources.put(clusterName, new ParsedResource(edsUpdate, resource)); } - - if (!errors.isEmpty()) { - handleResourcesRejected(ResourceType.EDS, unpackedResources, versionInfo, nonce, errors); - } else { - handleResourcesAccepted(ResourceType.EDS, parsedResources, versionInfo, nonce); - } + getLogger().log( + XdsLogLevel.INFO, "Received EDS Response version {0} nonce {1}. Parsed resources: {2}", + versionInfo, nonce, unpackedResources); + handleResourceUpdate( + ResourceType.EDS, parsedResources, invalidResources, Collections.emptySet(), + versionInfo, nonce, errors); } private static EdsUpdate processClusterLoadAssignment(ClusterLoadAssignment assignment) @@ -2045,43 +2028,67 @@ private void cleanUpResourceTimers() { } } - private void handleResourcesAccepted( - ResourceType type, Map parsedResources, String version, - String nonce) { - ackResponse(type, version, nonce); - + private void handleResourceUpdate( + ResourceType type, Map parsedResources, Set invalidResources, + Set retainedResources, String version, String nonce, List errors) { + String errorDetail = null; + if (errors.isEmpty()) { + checkArgument(invalidResources.isEmpty(), "found invalid resources but missing errors"); + ackResponse(type, version, nonce); + } else { + errorDetail = Joiner.on('\n').join(errors); + getLogger().log(XdsLogLevel.WARNING, + "Failed processing {0} Response version {1} nonce {2}. Errors:\n{3}", + type, version, nonce, errorDetail); + nackResponse(type, nonce, errorDetail); + } long updateTime = timeProvider.currentTimeNanos(); for (Map.Entry entry : getSubscribedResourcesMap(type).entrySet()) { String resourceName = entry.getKey(); ResourceSubscriber subscriber = entry.getValue(); + // Attach error details to the subscribed resources that included in the ADS update. + if (invalidResources.contains(resourceName)) { + subscriber.onRejected(version, updateTime, errorDetail); + } // Notify the watchers. if (parsedResources.containsKey(resourceName)) { subscriber.onData(parsedResources.get(resourceName), version, updateTime); } else if (type == ResourceType.LDS || type == ResourceType.CDS) { + if (subscriber.data != null && invalidResources.contains(resourceName)) { + // Update is rejected but keep using the cached data. + if (type == ResourceType.LDS) { + LdsUpdate ldsUpdate = (LdsUpdate) subscriber.data; + io.grpc.xds.HttpConnectionManager hcm = ldsUpdate.httpConnectionManager(); + if (hcm != null) { + String rdsName = hcm.rdsName(); + if (rdsName != null) { + retainedResources.add(rdsName); + } + } + } else { + CdsUpdate cdsUpdate = (CdsUpdate) subscriber.data; + String edsName = cdsUpdate.edsServiceName(); + if (edsName == null) { + edsName = cdsUpdate.clusterName(); + } + retainedResources.add(edsName); + } + continue; + } // For State of the World services, notify watchers when their watched resource is missing // from the ADS update. subscriber.onAbsent(); } } - } - - private void handleResourcesRejected( - ResourceType type, Set unpackedResourceNames, String version, - String nonce, List errors) { - String errorDetail = Joiner.on('\n').join(errors); - getLogger().log(XdsLogLevel.WARNING, - "Failed processing {0} Response version {1} nonce {2}. Errors:\n{3}", - type, version, nonce, errorDetail); - nackResponse(type, nonce, errorDetail); - - long updateTime = timeProvider.currentTimeNanos(); - for (Map.Entry entry : getSubscribedResourcesMap(type).entrySet()) { - String resourceName = entry.getKey(); - ResourceSubscriber subscriber = entry.getValue(); - - // Attach error details to the subscribed resources that included in the ADS update. - if (unpackedResourceNames.contains(resourceName)) { - subscriber.onRejected(version, updateTime, errorDetail); + // LDS/CDS responses represents the state of the world, RDS/EDS resources not referenced in + // LDS/CDS resources should be deleted. + if (type == ResourceType.LDS || type == ResourceType.CDS) { + Map dependentSubscribers = + type == ResourceType.LDS ? rdsResourceSubscribers : edsResourceSubscribers; + for (String resource : dependentSubscribers.keySet()) { + if (!retainedResources.contains(resource)) { + dependentSubscribers.get(resource).onAbsent(); + } } } } diff --git a/xds/src/test/java/io/grpc/xds/ClientXdsClientTestBase.java b/xds/src/test/java/io/grpc/xds/ClientXdsClientTestBase.java index a2a29ffe989..e66c73163be 100644 --- a/xds/src/test/java/io/grpc/xds/ClientXdsClientTestBase.java +++ b/xds/src/test/java/io/grpc/xds/ClientXdsClientTestBase.java @@ -473,11 +473,11 @@ public void ldsResponseErrorHandling_someResourcesFailedUnpack() { List errors = ImmutableList.of( "LDS response Resource index 0 - can't decode Listener: ", "LDS response Resource index 2 - can't decode Listener: "); - verifyResourceMetadataNacked(LDS, LDS_RESOURCE, null, "", 0, VERSION_1, TIME_INCREMENT, errors); + verifyResourceMetadataAcked(LDS, LDS_RESOURCE, testListenerRds, VERSION_1, TIME_INCREMENT); verifySubscribedResourcesMetadataSizes(1, 0, 0, 0); // The response is NACKed with the same error message. call.verifyRequestNack(LDS, LDS_RESOURCE, "", "0000", NODE, errors); - verifyNoInteractions(ldsResourceWatcher); + verify(ldsResourceWatcher).onChanged(any(LdsUpdate.class)); } /** @@ -517,14 +517,14 @@ public void ldsResponseErrorHandling_subscribedResourceInvalid() { "A", Any.pack(mf.buildListenerWithApiListenerForRds("A", "A.2")), "B", Any.pack(mf.buildListenerWithApiListenerInvalid("B"))); call.sendResponse(LDS, resourcesV2.values().asList(), VERSION_2, "0001"); - // {A, B} -> NACK, version 1, rejected version 2, rejected reason: Failed to parse B - // {C} -> ACK, version 1 + // {A} -> ACK, version 2 + // {B} -> NACK, version 1, rejected version 2, rejected reason: Failed to parse B + // {C} -> does not exist List errorsV2 = ImmutableList.of("LDS response Listener 'B' validation error: "); - verifyResourceMetadataNacked(LDS, "A", resourcesV1.get("A"), VERSION_1, TIME_INCREMENT, - VERSION_2, TIME_INCREMENT * 2, errorsV2); + verifyResourceMetadataAcked(LDS, "A", resourcesV2.get("A"), VERSION_2, TIME_INCREMENT * 2); verifyResourceMetadataNacked(LDS, "B", resourcesV1.get("B"), VERSION_1, TIME_INCREMENT, VERSION_2, TIME_INCREMENT * 2, errorsV2); - verifyResourceMetadataAcked(LDS, "C", resourcesV1.get("C"), VERSION_1, TIME_INCREMENT); + verifyResourceMetadataDoesNotExist(LDS, "C"); call.verifyRequestNack(LDS, subscribedResourceNames, VERSION_1, "0001", NODE, errorsV2); // LDS -> {B, C} version 3 @@ -532,7 +532,7 @@ public void ldsResponseErrorHandling_subscribedResourceInvalid() { "B", Any.pack(mf.buildListenerWithApiListenerForRds("B", "B.3")), "C", Any.pack(mf.buildListenerWithApiListenerForRds("C", "C.3"))); call.sendResponse(LDS, resourcesV3.values().asList(), VERSION_3, "0002"); - // {A} -> NACK, version 1, rejected version 2, rejected reason: Failed to parse B + // {A} -> does not exist // {B, C} -> ACK, version 3 verifyResourceMetadataDoesNotExist(LDS, "A"); verifyResourceMetadataAcked(LDS, "B", resourcesV3.get("B"), VERSION_3, TIME_INCREMENT * 3); @@ -541,6 +541,73 @@ public void ldsResponseErrorHandling_subscribedResourceInvalid() { verifySubscribedResourcesMetadataSizes(3, 0, 0, 0); } + @Test + public void ldsResponseErrorHandling_subscribedResourceInvalid_withRdsSubscriptioin() { + List subscribedResourceNames = ImmutableList.of("A", "B", "C"); + xdsClient.watchLdsResource("A", ldsResourceWatcher); + xdsClient.watchRdsResource("A.1", rdsResourceWatcher); + xdsClient.watchLdsResource("B", ldsResourceWatcher); + xdsClient.watchRdsResource("B.1", rdsResourceWatcher); + xdsClient.watchLdsResource("C", ldsResourceWatcher); + xdsClient.watchRdsResource("C.1", rdsResourceWatcher); + DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); + assertThat(call).isNotNull(); + verifyResourceMetadataRequested(LDS, "A"); + verifyResourceMetadataRequested(LDS, "B"); + verifyResourceMetadataRequested(LDS, "C"); + verifyResourceMetadataRequested(RDS, "A.1"); + verifyResourceMetadataRequested(RDS, "B.1"); + verifyResourceMetadataRequested(RDS, "C.1"); + verifySubscribedResourcesMetadataSizes(3, 0, 3, 0); + + // LDS -> {A, B, C}, version 1 + ImmutableMap resourcesV1 = ImmutableMap.of( + "A", Any.pack(mf.buildListenerWithApiListenerForRds("A", "A.1")), + "B", Any.pack(mf.buildListenerWithApiListenerForRds("B", "B.1")), + "C", Any.pack(mf.buildListenerWithApiListenerForRds("C", "C.1"))); + call.sendResponse(LDS, resourcesV1.values().asList(), VERSION_1, "0000"); + // {A, B, C} -> ACK, version 1 + verifyResourceMetadataAcked(LDS, "A", resourcesV1.get("A"), VERSION_1, TIME_INCREMENT); + verifyResourceMetadataAcked(LDS, "B", resourcesV1.get("B"), VERSION_1, TIME_INCREMENT); + verifyResourceMetadataAcked(LDS, "C", resourcesV1.get("C"), VERSION_1, TIME_INCREMENT); + call.verifyRequest(LDS, subscribedResourceNames, VERSION_1, "0000", NODE); + + // RDS -> {A.1, B.1, C.1}, version 1 + List vhostsV1 = mf.buildOpaqueVirtualHosts(1); + ImmutableMap resourcesV11 = ImmutableMap.of( + "A.1", Any.pack(mf.buildRouteConfiguration("A.1", vhostsV1)), + "B.1", Any.pack(mf.buildRouteConfiguration("B.1", vhostsV1)), + "C.1", Any.pack(mf.buildRouteConfiguration("C.1", vhostsV1))); + call.sendResponse(RDS, resourcesV11.values().asList(), VERSION_1, "0000"); + // {A.1, B.1, C.1} -> ACK, version 1 + verifyResourceMetadataAcked(RDS, "A.1", resourcesV11.get("A.1"), VERSION_1, TIME_INCREMENT * 2); + verifyResourceMetadataAcked(RDS, "B.1", resourcesV11.get("B.1"), VERSION_1, TIME_INCREMENT * 2); + verifyResourceMetadataAcked(RDS, "C.1", resourcesV11.get("C.1"), VERSION_1, TIME_INCREMENT * 2); + + // LDS -> {A, B}, version 2 + // Failed to parse endpoint B + ImmutableMap resourcesV2 = ImmutableMap.of( + "A", Any.pack(mf.buildListenerWithApiListenerForRds("A", "A.2")), + "B", Any.pack(mf.buildListenerWithApiListenerInvalid("B"))); + call.sendResponse(LDS, resourcesV2.values().asList(), VERSION_2, "0001"); + // {A} -> ACK, version 2 + // {B} -> NACK, version 1, rejected version 2, rejected reason: Failed to parse B + // {C} -> does not exist + List errorsV2 = ImmutableList.of("LDS response Listener 'B' validation error: "); + verifyResourceMetadataAcked(LDS, "A", resourcesV2.get("A"), VERSION_2, TIME_INCREMENT * 3); + verifyResourceMetadataNacked( + LDS, "B", resourcesV1.get("B"), VERSION_1, TIME_INCREMENT, VERSION_2, TIME_INCREMENT * 3, + errorsV2); + verifyResourceMetadataDoesNotExist(LDS, "C"); + call.verifyRequestNack(LDS, subscribedResourceNames, VERSION_1, "0001", NODE, errorsV2); + // {A.1} -> does not exist + // {B.1} -> version 1 + // {C.1} -> does not exist + verifyResourceMetadataDoesNotExist(RDS, "A.1"); + verifyResourceMetadataAcked(RDS, "B.1", resourcesV11.get("B.1"), VERSION_1, TIME_INCREMENT * 2); + verifyResourceMetadataDoesNotExist(RDS, "C.1"); + } + @Test public void ldsResourceFound_containsVirtualHosts() { DiscoveryRpcCall call = startResourceWatcher(LDS, LDS_RESOURCE, ldsResourceWatcher); @@ -807,11 +874,11 @@ public void rdsResponseErrorHandling_someResourcesFailedUnpack() { List errors = ImmutableList.of( "RDS response Resource index 0 - can't decode RouteConfiguration: ", "RDS response Resource index 2 - can't decode RouteConfiguration: "); - verifyResourceMetadataNacked(RDS, RDS_RESOURCE, null, "", 0, VERSION_1, TIME_INCREMENT, errors); + verifyResourceMetadataAcked(RDS, RDS_RESOURCE, testRouteConfig, VERSION_1, TIME_INCREMENT); verifySubscribedResourcesMetadataSizes(0, 0, 1, 0); // The response is NACKed with the same error message. call.verifyRequestNack(RDS, RDS_RESOURCE, "", "0000", NODE, errors); - verifyNoInteractions(rdsResourceWatcher); + verify(rdsResourceWatcher).onChanged(any(RdsUpdate.class)); } /** @@ -852,12 +919,12 @@ public void rdsResponseErrorHandling_subscribedResourceInvalid() { "A", Any.pack(mf.buildRouteConfiguration("A", mf.buildOpaqueVirtualHosts(2))), "B", Any.pack(mf.buildRouteConfigurationInvalid("B"))); call.sendResponse(RDS, resourcesV2.values().asList(), VERSION_2, "0001"); - // {A, B} -> NACK, version 1, rejected version 2, rejected reason: Failed to parse B + // {A} -> ACK, version 2 + // {B} -> NACK, version 1, rejected version 2, rejected reason: Failed to parse B // {C} -> ACK, version 1 List errorsV2 = ImmutableList.of("RDS response RouteConfiguration 'B' validation error: "); - verifyResourceMetadataNacked(RDS, "A", resourcesV1.get("A"), VERSION_1, TIME_INCREMENT, - VERSION_2, TIME_INCREMENT * 2, errorsV2); + verifyResourceMetadataAcked(RDS, "A", resourcesV2.get("A"), VERSION_2, TIME_INCREMENT * 2); verifyResourceMetadataNacked(RDS, "B", resourcesV1.get("B"), VERSION_1, TIME_INCREMENT, VERSION_2, TIME_INCREMENT * 2, errorsV2); verifyResourceMetadataAcked(RDS, "C", resourcesV1.get("C"), VERSION_1, TIME_INCREMENT); @@ -869,10 +936,9 @@ public void rdsResponseErrorHandling_subscribedResourceInvalid() { "B", Any.pack(mf.buildRouteConfiguration("B", vhostsV3)), "C", Any.pack(mf.buildRouteConfiguration("C", vhostsV3))); call.sendResponse(RDS, resourcesV3.values().asList(), VERSION_3, "0002"); - // {A} -> NACK, version 1, rejected version 2, rejected reason: Failed to parse B + // {A} -> ACK, version 2 // {B, C} -> ACK, version 3 - verifyResourceMetadataNacked(RDS, "A", resourcesV1.get("A"), VERSION_1, TIME_INCREMENT, - VERSION_2, TIME_INCREMENT * 2, errorsV2); + verifyResourceMetadataAcked(RDS, "A", resourcesV2.get("A"), VERSION_2, TIME_INCREMENT * 2); verifyResourceMetadataAcked(RDS, "B", resourcesV3.get("B"), VERSION_3, TIME_INCREMENT * 3); verifyResourceMetadataAcked(RDS, "C", resourcesV3.get("C"), VERSION_3, TIME_INCREMENT * 3); call.verifyRequest(RDS, subscribedResourceNames, VERSION_3, "0002", NODE); @@ -1146,11 +1212,12 @@ public void cdsResponseErrorHandling_someResourcesFailedUnpack() { List errors = ImmutableList.of( "CDS response Resource index 0 - can't decode Cluster: ", "CDS response Resource index 2 - can't decode Cluster: "); - verifyResourceMetadataNacked(CDS, CDS_RESOURCE, null, "", 0, VERSION_1, TIME_INCREMENT, errors); + verifyResourceMetadataAcked( + CDS, CDS_RESOURCE, testClusterRoundRobin, VERSION_1, TIME_INCREMENT); verifySubscribedResourcesMetadataSizes(0, 1, 0, 0); // The response is NACKed with the same error message. call.verifyRequestNack(CDS, CDS_RESOURCE, "", "0000", NODE, errors); - verifyNoInteractions(cdsResourceWatcher); + verify(cdsResourceWatcher).onChanged(any(CdsUpdate.class)); } /** @@ -1198,14 +1265,14 @@ public void cdsResponseErrorHandling_subscribedResourceInvalid() { )), "B", Any.pack(mf.buildClusterInvalid("B"))); call.sendResponse(CDS, resourcesV2.values().asList(), VERSION_2, "0001"); - // {A, B} -> NACK, version 1, rejected version 2, rejected reason: Failed to parse B - // {C} -> ACK, version 1 + // {A} -> ACK, version 2 + // {B} -> NACK, version 1, rejected version 2, rejected reason: Failed to parse B + // {C} -> does not exist List errorsV2 = ImmutableList.of("CDS response Cluster 'B' validation error: "); - verifyResourceMetadataNacked(CDS, "A", resourcesV1.get("A"), VERSION_1, TIME_INCREMENT, - VERSION_2, TIME_INCREMENT * 2, errorsV2); + verifyResourceMetadataAcked(CDS, "A", resourcesV2.get("A"), VERSION_2, TIME_INCREMENT * 2); verifyResourceMetadataNacked(CDS, "B", resourcesV1.get("B"), VERSION_1, TIME_INCREMENT, VERSION_2, TIME_INCREMENT * 2, errorsV2); - verifyResourceMetadataAcked(CDS, "C", resourcesV1.get("C"), VERSION_1, TIME_INCREMENT); + verifyResourceMetadataDoesNotExist(CDS, "C"); call.verifyRequestNack(CDS, subscribedResourceNames, VERSION_1, "0001", NODE, errorsV2); // CDS -> {B, C} version 3 @@ -1217,7 +1284,7 @@ public void cdsResponseErrorHandling_subscribedResourceInvalid() { "envoy.transport_sockets.tls", null ))); call.sendResponse(CDS, resourcesV3.values().asList(), VERSION_3, "0002"); - // {A} -> NACK, version 1, rejected version 2, rejected reason: Failed to parse B + // {A} -> does not exit // {B, C} -> ACK, version 3 verifyResourceMetadataDoesNotExist(CDS, "A"); verifyResourceMetadataAcked(CDS, "B", resourcesV3.get("B"), VERSION_3, TIME_INCREMENT * 3); @@ -1225,6 +1292,82 @@ public void cdsResponseErrorHandling_subscribedResourceInvalid() { call.verifyRequest(CDS, subscribedResourceNames, VERSION_3, "0002", NODE); } + @Test + public void cdsResponseErrorHandling_subscribedResourceInvalid_withEdsSubscription() { + List subscribedResourceNames = ImmutableList.of("A", "B", "C"); + xdsClient.watchCdsResource("A", cdsResourceWatcher); + xdsClient.watchEdsResource("A.1", edsResourceWatcher); + xdsClient.watchCdsResource("B", cdsResourceWatcher); + xdsClient.watchEdsResource("B.1", edsResourceWatcher); + xdsClient.watchCdsResource("C", cdsResourceWatcher); + xdsClient.watchEdsResource("C.1", edsResourceWatcher); + DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); + assertThat(call).isNotNull(); + verifyResourceMetadataRequested(CDS, "A"); + verifyResourceMetadataRequested(CDS, "B"); + verifyResourceMetadataRequested(CDS, "C"); + verifyResourceMetadataRequested(EDS, "A.1"); + verifyResourceMetadataRequested(EDS, "B.1"); + verifyResourceMetadataRequested(EDS, "C.1"); + verifySubscribedResourcesMetadataSizes(0, 3, 0, 3); + + // CDS -> {A, B, C}, version 1 + ImmutableMap resourcesV1 = ImmutableMap.of( + "A", Any.pack(mf.buildEdsCluster("A", "A.1", "round_robin", null, false, null, + "envoy.transport_sockets.tls", null + )), + "B", Any.pack(mf.buildEdsCluster("B", "B.1", "round_robin", null, false, null, + "envoy.transport_sockets.tls", null + )), + "C", Any.pack(mf.buildEdsCluster("C", "C.1", "round_robin", null, false, null, + "envoy.transport_sockets.tls", null + ))); + call.sendResponse(CDS, resourcesV1.values().asList(), VERSION_1, "0000"); + // {A, B, C} -> ACK, version 1 + verifyResourceMetadataAcked(CDS, "A", resourcesV1.get("A"), VERSION_1, TIME_INCREMENT); + verifyResourceMetadataAcked(CDS, "B", resourcesV1.get("B"), VERSION_1, TIME_INCREMENT); + verifyResourceMetadataAcked(CDS, "C", resourcesV1.get("C"), VERSION_1, TIME_INCREMENT); + call.verifyRequest(CDS, subscribedResourceNames, VERSION_1, "0000", NODE); + + // EDS -> {A.1, B.1, C.1}, version 1 + List dropOverloads = ImmutableList.of(); + List endpointsV1 = ImmutableList.of(lbEndpointHealthy); + ImmutableMap resourcesV11 = ImmutableMap.of( + "A.1", Any.pack(mf.buildClusterLoadAssignment("A.1", endpointsV1, dropOverloads)), + "B.1", Any.pack(mf.buildClusterLoadAssignment("B.1", endpointsV1, dropOverloads)), + "C.1", Any.pack(mf.buildClusterLoadAssignment("C.1", endpointsV1, dropOverloads))); + call.sendResponse(EDS, resourcesV11.values().asList(), VERSION_1, "0000"); + // {A.1, B.1, C.1} -> ACK, version 1 + verifyResourceMetadataAcked(EDS, "A.1", resourcesV11.get("A.1"), VERSION_1, TIME_INCREMENT * 2); + verifyResourceMetadataAcked(EDS, "B.1", resourcesV11.get("B.1"), VERSION_1, TIME_INCREMENT * 2); + verifyResourceMetadataAcked(EDS, "C.1", resourcesV11.get("C.1"), VERSION_1, TIME_INCREMENT * 2); + + // CDS -> {A, B}, version 2 + // Failed to parse endpoint B + ImmutableMap resourcesV2 = ImmutableMap.of( + "A", Any.pack(mf.buildEdsCluster("A", "A.2", "round_robin", null, false, null, + "envoy.transport_sockets.tls", null + )), + "B", Any.pack(mf.buildClusterInvalid("B"))); + call.sendResponse(CDS, resourcesV2.values().asList(), VERSION_2, "0001"); + // {A} -> ACK, version 2 + // {B} -> NACK, version 1, rejected version 2, rejected reason: Failed to parse B + // {C} -> does not exist + List errorsV2 = ImmutableList.of("CDS response Cluster 'B' validation error: "); + verifyResourceMetadataAcked(CDS, "A", resourcesV2.get("A"), VERSION_2, TIME_INCREMENT * 3); + verifyResourceMetadataNacked( + CDS, "B", resourcesV1.get("B"), VERSION_1, TIME_INCREMENT, VERSION_2, TIME_INCREMENT * 3, + errorsV2); + verifyResourceMetadataDoesNotExist(CDS, "C"); + call.verifyRequestNack(CDS, subscribedResourceNames, VERSION_1, "0001", NODE, errorsV2); + // {A.1} -> does not exist + // {B.1} -> version 1 + // {C.1} -> does not exist + verifyResourceMetadataDoesNotExist(EDS, "A.1"); + verifyResourceMetadataAcked(EDS, "B.1", resourcesV11.get("B.1"), VERSION_1, TIME_INCREMENT * 2); + verifyResourceMetadataDoesNotExist(EDS, "C.1"); + } + @Test public void cdsResourceFound() { DiscoveryRpcCall call = startResourceWatcher(CDS, CDS_RESOURCE, cdsResourceWatcher); @@ -1666,11 +1809,14 @@ public void edsResponseErrorHandling_someResourcesFailedUnpack() { List errors = ImmutableList.of( "EDS response Resource index 0 - can't decode ClusterLoadAssignment: ", "EDS response Resource index 2 - can't decode ClusterLoadAssignment: "); - verifyResourceMetadataNacked(EDS, EDS_RESOURCE, null, "", 0, VERSION_1, TIME_INCREMENT, errors); + verifyResourceMetadataAcked( + EDS, EDS_RESOURCE, testClusterLoadAssignment, VERSION_1, TIME_INCREMENT); verifySubscribedResourcesMetadataSizes(0, 0, 0, 1); // The response is NACKed with the same error message. call.verifyRequestNack(EDS, EDS_RESOURCE, "", "0000", NODE, errors); - verifyNoInteractions(edsResourceWatcher); + verify(edsResourceWatcher).onChanged(edsUpdateCaptor.capture()); + EdsUpdate edsUpdate = edsUpdateCaptor.getValue(); + assertThat(edsUpdate.clusterName).isEqualTo(EDS_RESOURCE); } /** @@ -1713,12 +1859,12 @@ public void edsResponseErrorHandling_subscribedResourceInvalid() { "A", Any.pack(mf.buildClusterLoadAssignment("A", endpointsV2, dropOverloads)), "B", Any.pack(mf.buildClusterLoadAssignmentInvalid("B"))); call.sendResponse(EDS, resourcesV2.values().asList(), VERSION_2, "0001"); - // {A, B} -> NACK, version 1, rejected version 2, rejected reason: Failed to parse B + // {A} -> ACK, version 2 + // {B} -> NACK, version 1, rejected version 2, rejected reason: Failed to parse B // {C} -> ACK, version 1 List errorsV2 = ImmutableList.of("EDS response ClusterLoadAssignment 'B' validation error: "); - verifyResourceMetadataNacked(EDS, "A", resourcesV1.get("A"), VERSION_1, TIME_INCREMENT, - VERSION_2, TIME_INCREMENT * 2, errorsV2); + verifyResourceMetadataAcked(EDS, "A", resourcesV2.get("A"), VERSION_2, TIME_INCREMENT * 2); verifyResourceMetadataNacked(EDS, "B", resourcesV1.get("B"), VERSION_1, TIME_INCREMENT, VERSION_2, TIME_INCREMENT * 2, errorsV2); verifyResourceMetadataAcked(EDS, "C", resourcesV1.get("C"), VERSION_1, TIME_INCREMENT); @@ -1731,10 +1877,9 @@ public void edsResponseErrorHandling_subscribedResourceInvalid() { "B", Any.pack(mf.buildClusterLoadAssignment("B", endpointsV3, dropOverloads)), "C", Any.pack(mf.buildClusterLoadAssignment("C", endpointsV3, dropOverloads))); call.sendResponse(EDS, resourcesV3.values().asList(), VERSION_3, "0002"); - // {A} -> NACK, version 1, rejected version 2, rejected reason: Failed to parse B + // {A} -> ACK, version 2 // {B, C} -> ACK, version 3 - verifyResourceMetadataNacked(EDS, "A", resourcesV1.get("A"), VERSION_1, TIME_INCREMENT, - VERSION_2, TIME_INCREMENT * 2, errorsV2); + verifyResourceMetadataAcked(EDS, "A", resourcesV2.get("A"), VERSION_2, TIME_INCREMENT * 2); verifyResourceMetadataAcked(EDS, "B", resourcesV3.get("B"), VERSION_3, TIME_INCREMENT * 3); verifyResourceMetadataAcked(EDS, "C", resourcesV3.get("C"), VERSION_3, TIME_INCREMENT * 3); call.verifyRequest(EDS, subscribedResourceNames, VERSION_3, "0002", NODE); From cc7c55e84314b538bac63510f53aeb72368d7b8f Mon Sep 17 00:00:00 2001 From: sanjaypujare Date: Wed, 15 Sep 2021 16:23:04 -0700 Subject: [PATCH 75/82] xds: The xdsClient uses equals() for duplicate detection/suppression so a proper equals() (and also hashCode()) is needed. (#8529) sslContextProvider and shutdown are excluded because these are implementation instances instead of pure data. This will be replaced by a more permanent solution later (such as removing the construction of implementation instances by XdsClient) --- .../sds/SslContextProviderSupplier.java | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProviderSupplier.java b/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProviderSupplier.java index 17fc442e7da..664b4881bc2 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProviderSupplier.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProviderSupplier.java @@ -25,6 +25,7 @@ import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; import io.grpc.xds.TlsContextManager; import io.netty.handler.ssl.SslContext; +import java.util.Objects; /** * Enables Client or server side to initialize this object with the received {@link BaseTlsContext} @@ -118,6 +119,24 @@ public synchronized void close() { shutdown = true; } + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + SslContextProviderSupplier that = (SslContextProviderSupplier) o; + return Objects.equals(tlsContext, that.tlsContext) + && Objects.equals(tlsContextManager, that.tlsContextManager); + } + + @Override + public int hashCode() { + return Objects.hash(tlsContext, tlsContextManager); + } + @Override public String toString() { return MoreObjects.toStringHelper(this) From 8b6e0e5f051d5beef2cd6e3c3f5a7423fee006f6 Mon Sep 17 00:00:00 2001 From: Eric Anderson Date: Tue, 14 Sep 2021 10:49:48 -0700 Subject: [PATCH 76/82] netty: Allow protocol negotiators to shut down transport, with grace period This will be used for draining old connections when xDS configuration changes. --- .../netty/GracefulServerCloseCommand.java | 53 +++++++++++++++++++ .../InternalGracefulServerCloseCommand.java | 36 +++++++++++++ .../io/grpc/netty/NettyServerHandler.java | 24 +++++++-- .../WriteBufferingAndExceptionHandler.java | 2 + .../io/grpc/netty/NettyServerHandlerTest.java | 50 +++++++++++++++++ 5 files changed, 160 insertions(+), 5 deletions(-) create mode 100644 netty/src/main/java/io/grpc/netty/GracefulServerCloseCommand.java create mode 100644 netty/src/main/java/io/grpc/netty/InternalGracefulServerCloseCommand.java diff --git a/netty/src/main/java/io/grpc/netty/GracefulServerCloseCommand.java b/netty/src/main/java/io/grpc/netty/GracefulServerCloseCommand.java new file mode 100644 index 00000000000..97904687548 --- /dev/null +++ b/netty/src/main/java/io/grpc/netty/GracefulServerCloseCommand.java @@ -0,0 +1,53 @@ +/* + * Copyright 2021 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.netty; + +import com.google.common.base.Preconditions; +import java.util.concurrent.TimeUnit; + +/** + * A command to trigger close and allow streams naturally close. + */ +class GracefulServerCloseCommand extends WriteQueue.AbstractQueuedCommand { + private final String goAwayDebugString; + private final long graceTime; + private final TimeUnit graceTimeUnit; + + public GracefulServerCloseCommand(String goAwayDebugString) { + this(goAwayDebugString, -1, null); + } + + public GracefulServerCloseCommand( + String goAwayDebugString, long graceTime, TimeUnit graceTimeUnit) { + this.goAwayDebugString = Preconditions.checkNotNull(goAwayDebugString, "goAwayDebugString"); + this.graceTime = graceTime; + this.graceTimeUnit = graceTimeUnit; + } + + public String getGoAwayDebugString() { + return goAwayDebugString; + } + + /** Has no meaning if {@code getGraceTimeUnit() == null}. */ + public long getGraceTime() { + return graceTime; + } + + public TimeUnit getGraceTimeUnit() { + return graceTimeUnit; + } +} diff --git a/netty/src/main/java/io/grpc/netty/InternalGracefulServerCloseCommand.java b/netty/src/main/java/io/grpc/netty/InternalGracefulServerCloseCommand.java new file mode 100644 index 00000000000..deb72373ac7 --- /dev/null +++ b/netty/src/main/java/io/grpc/netty/InternalGracefulServerCloseCommand.java @@ -0,0 +1,36 @@ +/* + * Copyright 2021 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.netty; + +import io.grpc.Internal; +import java.util.concurrent.TimeUnit; + +/** + * Internal accessor for {@link GracefulServerCloseCommand}. + */ +@Internal +public final class InternalGracefulServerCloseCommand { + private InternalGracefulServerCloseCommand() {} + + public static Object create(String goAwayDebugString) { + return new GracefulServerCloseCommand(goAwayDebugString); + } + + public static Object create(String goAwayDebugString, long graceTime, TimeUnit graceTimeUnit) { + return new GracefulServerCloseCommand(goAwayDebugString, graceTime, graceTimeUnit); + } +} diff --git a/netty/src/main/java/io/grpc/netty/NettyServerHandler.java b/netty/src/main/java/io/grpc/netty/NettyServerHandler.java index 6fca656e795..0a34644267f 100644 --- a/netty/src/main/java/io/grpc/netty/NettyServerHandler.java +++ b/netty/src/main/java/io/grpc/netty/NettyServerHandler.java @@ -618,6 +618,8 @@ public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) sendResponseHeaders(ctx, (SendResponseHeadersCommand) msg, promise); } else if (msg instanceof CancelServerStreamCommand) { cancelStream(ctx, (CancelServerStreamCommand) msg, promise); + } else if (msg instanceof GracefulServerCloseCommand) { + gracefulClose(ctx, (GracefulServerCloseCommand) msg, promise); } else if (msg instanceof ForcefulCloseCommand) { forcefulClose(ctx, (ForcefulCloseCommand) msg, promise); } else { @@ -631,11 +633,8 @@ public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) @Override public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { - if (gracefulShutdown == null) { - gracefulShutdown = new GracefulShutdown("app_requested", null); - gracefulShutdown.start(ctx); - ctx.flush(); - } + gracefulClose(ctx, new GracefulServerCloseCommand("app_requested"), promise); + ctx.flush(); } /** @@ -716,6 +715,21 @@ private void cancelStream(ChannelHandlerContext ctx, CancelServerStreamCommand c } } + private void gracefulClose(final ChannelHandlerContext ctx, final GracefulServerCloseCommand msg, + ChannelPromise promise) throws Exception { + // Ideally we'd adjust a pre-existing graceful shutdown's grace period to at least what is + // requested here. But that's an edge case and seems bug-prone. + if (gracefulShutdown == null) { + Long graceTimeInNanos = null; + if (msg.getGraceTimeUnit() != null) { + graceTimeInNanos = msg.getGraceTimeUnit().toNanos(msg.getGraceTime()); + } + gracefulShutdown = new GracefulShutdown(msg.getGoAwayDebugString(), graceTimeInNanos); + gracefulShutdown.start(ctx); + } + promise.setSuccess(); + } + private void forcefulClose(final ChannelHandlerContext ctx, final ForcefulCloseCommand msg, ChannelPromise promise) throws Exception { super.close(ctx, promise); diff --git a/netty/src/main/java/io/grpc/netty/WriteBufferingAndExceptionHandler.java b/netty/src/main/java/io/grpc/netty/WriteBufferingAndExceptionHandler.java index 9521fc93889..100367625fa 100644 --- a/netty/src/main/java/io/grpc/netty/WriteBufferingAndExceptionHandler.java +++ b/netty/src/main/java/io/grpc/netty/WriteBufferingAndExceptionHandler.java @@ -124,6 +124,8 @@ public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) promise.setFailure(failCause); ReferenceCountUtil.release(msg); } else { + // Do not special case GracefulServerCloseCommand, as we don't want to cause handshake + // failures. if (msg instanceof GracefulCloseCommand || msg instanceof ForcefulCloseCommand) { // No point in continuing negotiation ctx.close(); diff --git a/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java b/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java index 961f983d9cd..8c44088afa7 100644 --- a/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java @@ -350,6 +350,56 @@ public void closeShouldGracefullyCloseChannel() throws Exception { assertFalse(channel().isOpen()); } + @Test + public void gracefulCloseShouldGracefullyCloseChannel() throws Exception { + manualSetUp(); + handler() + .write(ctx(), new GracefulServerCloseCommand("test", 1, TimeUnit.MINUTES), newPromise()); + + verifyWrite().writeGoAway(eq(ctx()), eq(Integer.MAX_VALUE), eq(Http2Error.NO_ERROR.code()), + isA(ByteBuf.class), any(ChannelPromise.class)); + verifyWrite().writePing( + eq(ctx()), + eq(false), + eq(NettyServerHandler.GRACEFUL_SHUTDOWN_PING), + isA(ChannelPromise.class)); + channelRead(pingFrame(/*ack=*/ true , NettyServerHandler.GRACEFUL_SHUTDOWN_PING)); + + verifyWrite().writeGoAway(eq(ctx()), eq(0), eq(Http2Error.NO_ERROR.code()), + isA(ByteBuf.class), any(ChannelPromise.class)); + + // Verify that the channel was closed. + assertFalse(channel().isOpen()); + } + + @Test + public void secondGracefulCloseIsSafe() throws Exception { + manualSetUp(); + handler().write(ctx(), new GracefulServerCloseCommand("test"), newPromise()); + + verifyWrite().writeGoAway(eq(ctx()), eq(Integer.MAX_VALUE), eq(Http2Error.NO_ERROR.code()), + isA(ByteBuf.class), any(ChannelPromise.class)); + verifyWrite().writePing( + eq(ctx()), + eq(false), + eq(NettyServerHandler.GRACEFUL_SHUTDOWN_PING), + isA(ChannelPromise.class)); + + handler().write(ctx(), new GracefulServerCloseCommand("test2"), newPromise()); + + channel().runPendingTasks(); + // No additional GOAWAYs. + verifyWrite().writeGoAway(any(ChannelHandlerContext.class), any(Integer.class), any(Long.class), + any(ByteBuf.class), any(ChannelPromise.class)); + channel().checkException(); + assertTrue(channel().isOpen()); + + channelRead(pingFrame(/*ack=*/ true , NettyServerHandler.GRACEFUL_SHUTDOWN_PING)); + verifyWrite().writeGoAway(eq(ctx()), eq(0), eq(Http2Error.NO_ERROR.code()), + isA(ByteBuf.class), any(ChannelPromise.class)); + assertFalse(channel().isOpen()); + } + @Test public void exceptionCaughtShouldCloseConnection() throws Exception { manualSetUp(); From 4d5a19c3873ce68972beaa2307ee94727059d7be Mon Sep 17 00:00:00 2001 From: Eric Anderson Date: Tue, 14 Sep 2021 16:40:14 -0700 Subject: [PATCH 77/82] xds: Drain old server connections on Listener updates This is necessary to make sure all connections are using the new configuration. --- ...ilterChainMatchingProtocolNegotiators.java | 75 +++++++-- .../grpc/xds/FilterChainSelectorManager.java | 95 +++++++++++ .../io/grpc/xds/InternalXdsAttributes.java | 13 +- .../java/io/grpc/xds/XdsServerBuilder.java | 44 ++++- .../java/io/grpc/xds/XdsServerWrapper.java | 25 +-- ...rChainMatchingProtocolNegotiatorsTest.java | 153 +++++++++++------- .../xds/FilterChainSelectorManagerTest.java | 107 ++++++++++++ .../XdsClientWrapperForServerSdsTestMisc.java | 33 ++-- .../io/grpc/xds/XdsServerBuilderTest.java | 11 ++ .../io/grpc/xds/XdsServerWrapperTest.java | 96 ++++++----- 10 files changed, 508 insertions(+), 144 deletions(-) create mode 100644 xds/src/main/java/io/grpc/xds/FilterChainSelectorManager.java create mode 100644 xds/src/test/java/io/grpc/xds/FilterChainSelectorManagerTest.java diff --git a/xds/src/main/java/io/grpc/xds/FilterChainMatchingProtocolNegotiators.java b/xds/src/main/java/io/grpc/xds/FilterChainMatchingProtocolNegotiators.java index 0c8780fe744..24cd4e9ae7e 100644 --- a/xds/src/main/java/io/grpc/xds/FilterChainMatchingProtocolNegotiators.java +++ b/xds/src/main/java/io/grpc/xds/FilterChainMatchingProtocolNegotiators.java @@ -17,7 +17,8 @@ package io.grpc.xds; import static com.google.common.base.Preconditions.checkNotNull; -import static io.grpc.xds.InternalXdsAttributes.ATTR_FILTER_CHAIN_SELECTOR_REF; +import static io.grpc.xds.InternalXdsAttributes.ATTR_DRAIN_GRACE_NANOS; +import static io.grpc.xds.InternalXdsAttributes.ATTR_FILTER_CHAIN_SELECTOR_MANAGER; import static io.grpc.xds.XdsServerWrapper.ATTR_SERVER_ROUTING_CONFIG; import static io.grpc.xds.internal.sds.SdsProtocolNegotiators.ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER; @@ -28,6 +29,7 @@ import io.grpc.Attributes; import io.grpc.internal.ObjectPool; import io.grpc.netty.GrpcHttp2ConnectionHandler; +import io.grpc.netty.InternalGracefulServerCloseCommand; import io.grpc.netty.InternalProtocolNegotiationEvent; import io.grpc.netty.InternalProtocolNegotiator; import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator; @@ -40,6 +42,8 @@ import io.grpc.xds.XdsServerWrapper.ServerRoutingConfig; import io.grpc.xds.internal.Matchers.CidrMatcher; import io.grpc.xds.internal.sds.SslContextProviderSupplier; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; @@ -54,7 +58,7 @@ import java.util.List; import java.util.Map; import java.util.concurrent.Executor; -import java.util.concurrent.atomic.AtomicReference; +import java.util.concurrent.TimeUnit; import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; @@ -77,14 +81,16 @@ private FilterChainMatchingProtocolNegotiators() { static final class FilterChainMatchingHandler extends ChannelInboundHandlerAdapter { private final GrpcHttp2ConnectionHandler grpcHandler; - private final FilterChainSelector selector; + private final FilterChainSelectorManager filterChainSelectorManager; private final ProtocolNegotiator delegate; FilterChainMatchingHandler( - GrpcHttp2ConnectionHandler grpcHandler, FilterChainSelector selector, + GrpcHttp2ConnectionHandler grpcHandler, + FilterChainSelectorManager filterChainSelectorManager, ProtocolNegotiator delegate) { this.grpcHandler = checkNotNull(grpcHandler, "grpcHandler"); - this.selector = checkNotNull(selector, "selector"); + this.filterChainSelectorManager = + checkNotNull(filterChainSelectorManager, "filterChainSelectorManager"); this.delegate = checkNotNull(delegate, "delegate"); } @@ -94,6 +100,19 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc super.userEventTriggered(ctx, evt); return; } + long drainGraceTime = 0; + TimeUnit drainGraceTimeUnit = null; + Long drainGraceNanosObj = grpcHandler.getEagAttributes().get(ATTR_DRAIN_GRACE_NANOS); + if (drainGraceNanosObj != null) { + drainGraceTime = drainGraceNanosObj; + drainGraceTimeUnit = TimeUnit.NANOSECONDS; + } + FilterChainSelectorManager.Closer closer = new FilterChainSelectorManager.Closer( + new GracefullyShutdownChannelRunnable(ctx.channel(), drainGraceTime, drainGraceTimeUnit)); + FilterChainSelector selector = filterChainSelectorManager.register(closer); + ctx.channel().closeFuture().addListener( + new FilterChainSelectorManagerDeregister(filterChainSelectorManager, closer)); + checkNotNull(selector, "selector"); SelectedConfig config = selector.select( (InetSocketAddress) ctx.channel().localAddress(), (InetSocketAddress) ctx.channel().remoteAddress()); @@ -354,10 +373,10 @@ public AsciiString scheme() { @Override public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { - AtomicReference filterChainSelectorRef = - grpcHandler.getEagAttributes().get(ATTR_FILTER_CHAIN_SELECTOR_REF); - checkNotNull(filterChainSelectorRef, "filterChainSelectorRef"); - return new FilterChainMatchingHandler(grpcHandler, filterChainSelectorRef.get(), + FilterChainSelectorManager filterChainSelectorManager = + grpcHandler.getEagAttributes().get(ATTR_FILTER_CHAIN_SELECTOR_MANAGER); + checkNotNull(filterChainSelectorManager, "filterChainSelectorManager"); + return new FilterChainMatchingHandler(grpcHandler, filterChainSelectorManager, delegate.newNegotiator(offloadExecutorPool)); } @@ -384,4 +403,42 @@ private SelectedConfig(ServerRoutingConfig routingConfig, this.sslContextProviderSupplier = sslContextProviderSupplier; } } + + private static class FilterChainSelectorManagerDeregister implements ChannelFutureListener { + private final FilterChainSelectorManager filterChainSelectorManager; + private final FilterChainSelectorManager.Closer closer; + + public FilterChainSelectorManagerDeregister( + FilterChainSelectorManager filterChainSelectorManager, + FilterChainSelectorManager.Closer closer) { + this.filterChainSelectorManager = + checkNotNull(filterChainSelectorManager, "filterChainSelectorManager"); + this.closer = checkNotNull(closer, "closer"); + } + + @Override public void operationComplete(ChannelFuture future) throws Exception { + filterChainSelectorManager.deregister(closer); + } + } + + private static class GracefullyShutdownChannelRunnable implements Runnable { + private final Channel channel; + private final long drainGraceTime; + @Nullable + private final TimeUnit drainGraceTimeUnit; + + public GracefullyShutdownChannelRunnable( + Channel channel, long drainGraceTime, @Nullable TimeUnit drainGraceTimeUnit) { + this.channel = checkNotNull(channel, "channel"); + this.drainGraceTime = drainGraceTime; + this.drainGraceTimeUnit = drainGraceTimeUnit; + } + + @Override public void run() { + Object gracefulCloseCommand = InternalGracefulServerCloseCommand.create( + "xds_drain", drainGraceTime, drainGraceTimeUnit); + channel.writeAndFlush(gracefulCloseCommand) + .addListener(ChannelFutureListener.CLOSE_ON_FAILURE); + } + } } diff --git a/xds/src/main/java/io/grpc/xds/FilterChainSelectorManager.java b/xds/src/main/java/io/grpc/xds/FilterChainSelectorManager.java new file mode 100644 index 00000000000..4295d75f59b --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/FilterChainSelectorManager.java @@ -0,0 +1,95 @@ +/* + * Copyright 2021 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; +import io.grpc.xds.FilterChainMatchingProtocolNegotiators.FilterChainMatchingHandler.FilterChainSelector; +import java.util.Comparator; +import java.util.TreeSet; +import java.util.concurrent.atomic.AtomicLong; +import javax.annotation.concurrent.GuardedBy; + +/** + * Maintains the current xDS selector and any resources using that selector. When the selector + * changes, old resources are closed to avoid old config usages. + */ +final class FilterChainSelectorManager { + private static final AtomicLong closerId = new AtomicLong(); + + private final Object lock = new Object(); + @GuardedBy("lock") + private FilterChainSelector selector; + // Avoid HashSet since it does not decrease in size, forming a high water mark. + @GuardedBy("lock") + private TreeSet closers = new TreeSet(new CloserComparator()); + + public FilterChainSelector register(Closer closer) { + synchronized (lock) { + Preconditions.checkState(closers.add(closer), "closer already registered"); + return selector; + } + } + + public void deregister(Closer closer) { + synchronized (lock) { + closers.remove(closer); + } + } + + /** Only safe to be called by code that is responsible for updating the selector. */ + public FilterChainSelector getSelectorToUpdateSelector() { + synchronized (lock) { + return selector; + } + } + + public void updateSelector(FilterChainSelector newSelector) { + TreeSet oldClosers; + synchronized (lock) { + oldClosers = closers; + closers = new TreeSet(closers.comparator()); + selector = newSelector; + } + for (Closer closer : oldClosers) { + closer.closer.run(); + } + } + + @VisibleForTesting + int getRegisterCount() { + synchronized (lock) { + return closers.size(); + } + } + + public static final class Closer { + private final long id = closerId.getAndIncrement(); + private final Runnable closer; + + /** {@code closer} may be run multiple times. */ + public Closer(Runnable closer) { + this.closer = Preconditions.checkNotNull(closer, "closer"); + } + } + + private static class CloserComparator implements Comparator { + @Override public int compare(Closer c1, Closer c2) { + return Long.compare(c1.id, c2.id); + } + } +} diff --git a/xds/src/main/java/io/grpc/xds/InternalXdsAttributes.java b/xds/src/main/java/io/grpc/xds/InternalXdsAttributes.java index 82eddd355af..410a64df9ca 100644 --- a/xds/src/main/java/io/grpc/xds/InternalXdsAttributes.java +++ b/xds/src/main/java/io/grpc/xds/InternalXdsAttributes.java @@ -22,10 +22,8 @@ import io.grpc.Internal; import io.grpc.NameResolver; import io.grpc.internal.ObjectPool; -import io.grpc.xds.FilterChainMatchingProtocolNegotiators.FilterChainMatchingHandler.FilterChainSelector; import io.grpc.xds.XdsNameResolverProvider.CallCounterProvider; import io.grpc.xds.internal.sds.SslContextProviderSupplier; -import java.util.concurrent.atomic.AtomicReference; /** * Internal attributes used for xDS implementation. Do not use. @@ -81,9 +79,14 @@ public final class InternalXdsAttributes { * Filter chain match for network filters. */ @Grpc.TransportAttr - static final Attributes.Key> - ATTR_FILTER_CHAIN_SELECTOR_REF = Attributes.Key.create( - "io.grpc.xds.InternalXdsAttributes.filterChainSelectorRef"); + static final Attributes.Key + ATTR_FILTER_CHAIN_SELECTOR_MANAGER = Attributes.Key.create( + "io.grpc.xds.InternalXdsAttributes.filterChainSelectorManager"); + + /** Grace time to use when draining. Null for an infinite grace time. */ + @Grpc.TransportAttr + static final Attributes.Key ATTR_DRAIN_GRACE_NANOS = + Attributes.Key.create("io.grpc.xds.InternalXdsAttributes.drainGraceTime"); private InternalXdsAttributes() {} } diff --git a/xds/src/main/java/io/grpc/xds/XdsServerBuilder.java b/xds/src/main/java/io/grpc/xds/XdsServerBuilder.java index 34879fd8cd0..c95c1e6d48f 100644 --- a/xds/src/main/java/io/grpc/xds/XdsServerBuilder.java +++ b/xds/src/main/java/io/grpc/xds/XdsServerBuilder.java @@ -16,9 +16,11 @@ package io.grpc.xds; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; -import static io.grpc.xds.InternalXdsAttributes.ATTR_FILTER_CHAIN_SELECTOR_REF; +import static io.grpc.xds.InternalXdsAttributes.ATTR_DRAIN_GRACE_NANOS; +import static io.grpc.xds.InternalXdsAttributes.ATTR_FILTER_CHAIN_SELECTOR_MANAGER; import com.google.common.annotations.VisibleForTesting; import com.google.errorprone.annotations.DoNotCall; @@ -33,11 +35,10 @@ import io.grpc.netty.InternalNettyServerCredentials; import io.grpc.netty.InternalProtocolNegotiator; import io.grpc.netty.NettyServerBuilder; -import io.grpc.xds.FilterChainMatchingProtocolNegotiators.FilterChainMatchingHandler.FilterChainSelector; import io.grpc.xds.FilterChainMatchingProtocolNegotiators.FilterChainMatchingNegotiatorServerFactory; import io.grpc.xds.XdsNameResolverProvider.XdsClientPoolFactory; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicReference; import java.util.logging.Logger; /** @@ -45,6 +46,8 @@ */ @ExperimentalApi("https://github.com/grpc/grpc-java/issues/7514") public final class XdsServerBuilder extends ForwardingServerBuilder { + private static final long AS_LARGE_AS_INFINITE = TimeUnit.DAYS.toNanos(1000); + private final NettyServerBuilder delegate; private final int port; private XdsServingStatusListener xdsServingStatusListener; @@ -52,6 +55,8 @@ public final class XdsServerBuilder extends ForwardingServerBuilder= 0, "drain grace time must be non-negative: %s", + drainGraceTime); + checkNotNull(drainGraceTimeUnit, "drainGraceTimeUnit"); + if (drainGraceTimeUnit.toNanos(drainGraceTime) >= AS_LARGE_AS_INFINITE) { + drainGraceTimeUnit = null; + } + this.drainGraceTime = drainGraceTime; + this.drainGraceTimeUnit = drainGraceTimeUnit; + return this; + } + @DoNotCall("Unsupported. Use forPort(int, ServerCredentials) instead") public static ServerBuilder forPort(int port) { throw new UnsupportedOperationException( @@ -94,12 +119,15 @@ public static XdsServerBuilder forPort(int port, ServerCredentials serverCredent @Override public Server build() { checkState(isServerBuilt.compareAndSet(false, true), "Server already built!"); - AtomicReference filterChainSelectorRef = new AtomicReference<>(); - InternalNettyServerBuilder.eagAttributes(delegate, Attributes.newBuilder() - .set(ATTR_FILTER_CHAIN_SELECTOR_REF, filterChainSelectorRef) - .build()); + FilterChainSelectorManager filterChainSelectorManager = new FilterChainSelectorManager(); + Attributes.Builder builder = Attributes.newBuilder() + .set(ATTR_FILTER_CHAIN_SELECTOR_MANAGER, filterChainSelectorManager); + if (drainGraceTimeUnit != null) { + builder.set(ATTR_DRAIN_GRACE_NANOS, drainGraceTimeUnit.toNanos(drainGraceTime)); + } + InternalNettyServerBuilder.eagAttributes(delegate, builder.build()); return new XdsServerWrapper("0.0.0.0:" + port, delegate, xdsServingStatusListener, - filterChainSelectorRef, xdsClientPoolFactory, filterRegistry); + filterChainSelectorManager, xdsClientPoolFactory, filterRegistry); } @VisibleForTesting diff --git a/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java b/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java index faa6e9d34b2..29821f2cba8 100644 --- a/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java +++ b/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java @@ -100,7 +100,7 @@ public void uncaughtException(Thread t, Throwable e) { private final ThreadSafeRandom random = ThreadSafeRandomImpl.instance; private final XdsClientPoolFactory xdsClientPoolFactory; private final XdsServingStatusListener listener; - private final AtomicReference filterChainSelectorRef; + private final FilterChainSelectorManager filterChainSelectorManager; private final AtomicBoolean started = new AtomicBoolean(false); private final AtomicBoolean shutdown = new AtomicBoolean(false); private boolean isServing; @@ -117,11 +117,11 @@ public void uncaughtException(Thread t, Throwable e) { String listenerAddress, ServerBuilder delegateBuilder, XdsServingStatusListener listener, - AtomicReference filterChainSelectorRef, + FilterChainSelectorManager filterChainSelectorManager, XdsClientPoolFactory xdsClientPoolFactory, FilterRegistry filterRegistry) { - this(listenerAddress, delegateBuilder, listener, filterChainSelectorRef, xdsClientPoolFactory, - filterRegistry, SharedResourceHolder.get(GrpcUtil.TIMER_SERVICE)); + this(listenerAddress, delegateBuilder, listener, filterChainSelectorManager, + xdsClientPoolFactory, filterRegistry, SharedResourceHolder.get(GrpcUtil.TIMER_SERVICE)); sharedTimeService = true; } @@ -130,7 +130,7 @@ public void uncaughtException(Thread t, Throwable e) { String listenerAddress, ServerBuilder delegateBuilder, XdsServingStatusListener listener, - AtomicReference filterChainSelectorRef, + FilterChainSelectorManager filterChainSelectorManager, XdsClientPoolFactory xdsClientPoolFactory, FilterRegistry filterRegistry, ScheduledExecutorService timeService) { @@ -138,7 +138,8 @@ public void uncaughtException(Thread t, Throwable e) { this.delegateBuilder = checkNotNull(delegateBuilder, "delegateBuilder"); this.delegateBuilder.intercept(new ConfigApplyingInterceptor()); this.listener = checkNotNull(listener, "listener"); - this.filterChainSelectorRef = checkNotNull(filterChainSelectorRef, "filterChainSelectorRef"); + this.filterChainSelectorManager + = checkNotNull(filterChainSelectorManager, "filterChainSelectorManager"); this.xdsClientPoolFactory = checkNotNull(xdsClientPoolFactory, "xdsClientPoolFactory"); this.timeService = checkNotNull(timeService, "timeService"); this.filterRegistry = checkNotNull(filterRegistry,"filterRegistry"); @@ -361,8 +362,8 @@ public void run() { } checkNotNull(update.listener(), "update"); if (!pendingRds.isEmpty()) { - // filter chain state has not yet been applied to filterChainSelectorRef and there are - // two sets of sslContextProviderSuppliers, so we release the old ones. + // filter chain state has not yet been applied to filterChainSelectorManager and there + // are two sets of sslContextProviderSuppliers, so we release the old ones. releaseSuppliersInFlight(); pendingRds.clear(); } @@ -443,7 +444,7 @@ private void shutdown() { logger.log(Level.FINE, "Stop watching LDS resource {0}", resourceName); xdsClient.cancelLdsResourceWatch(resourceName, this); List toRelease = getSuppliersInUse(); - filterChainSelectorRef.set(FilterChainSelector.NO_FILTER_CHAIN); + filterChainSelectorManager.updateSelector(FilterChainSelector.NO_FILTER_CHAIN); for (SslContextProviderSupplier s: toRelease) { s.close(); } @@ -460,7 +461,7 @@ private void updateSelector() { defaultFilterChain == null ? null : defaultFilterChain.getSslContextProviderSupplier(), defaultFilterChain == null ? null : generateRoutingConfig(defaultFilterChain)); List toRelease = getSuppliersInUse(); - filterChainSelectorRef.set(selector); + filterChainSelectorManager.updateSelector(selector); for (SslContextProviderSupplier e: toRelease) { e.close(); } @@ -482,7 +483,7 @@ private ServerRoutingConfig generateRoutingConfig(FilterChain filterChain) { private void handleConfigNotFound(StatusException exception) { cleanUpRouteDiscoveryStates(); List toRelease = getSuppliersInUse(); - filterChainSelectorRef.set(FilterChainSelector.NO_FILTER_CHAIN); + filterChainSelectorManager.updateSelector(FilterChainSelector.NO_FILTER_CHAIN); for (SslContextProviderSupplier s: toRelease) { s.close(); } @@ -511,7 +512,7 @@ private void cleanUpRouteDiscoveryStates() { private List getSuppliersInUse() { List toRelease = new ArrayList<>(); - FilterChainSelector selector = filterChainSelectorRef.get(); + FilterChainSelector selector = filterChainSelectorManager.getSelectorToUpdateSelector(); if (selector != null) { for (FilterChain f: selector.getRoutingConfigs().keySet()) { if (f.getSslContextProviderSupplier() != null) { diff --git a/xds/src/test/java/io/grpc/xds/FilterChainMatchingProtocolNegotiatorsTest.java b/xds/src/test/java/io/grpc/xds/FilterChainMatchingProtocolNegotiatorsTest.java index 167f3f03c6b..891dec322c0 100644 --- a/xds/src/test/java/io/grpc/xds/FilterChainMatchingProtocolNegotiatorsTest.java +++ b/xds/src/test/java/io/grpc/xds/FilterChainMatchingProtocolNegotiatorsTest.java @@ -65,6 +65,7 @@ import java.util.Map; import java.util.concurrent.atomic.AtomicReference; +import org.junit.After; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; @@ -87,6 +88,7 @@ public class FilterChainMatchingProtocolNegotiatorsTest { private ChannelHandlerContext channelHandlerCtx; @Mock private ProtocolNegotiator mockDelegate; + private FilterChainSelectorManager selectorManager = new FilterChainSelectorManager(); private static final HttpConnectionManager HTTP_CONNECTION_MANAGER = createRds("routing-config"); private static final String LOCAL_IP = "10.1.2.3"; // dest private static final String REMOTE_IP = "10.4.2.3"; // source @@ -94,6 +96,16 @@ public class FilterChainMatchingProtocolNegotiatorsTest { private final ServerRoutingConfig noopConfig = ServerRoutingConfig.create( new ArrayList(), new AtomicReference>()); + @After + @SuppressWarnings("FutureReturnValueIgnored") + public void tearDown() { + if (channel.isActive()) { + channel.close(); + channel.runPendingTasks(); + } + assertThat(selectorManager.getRegisterCount()).isEqualTo(0); + } + @Test public void nofilterChainMatch_defaultSslContext() throws Exception { final SettableFuture sslSet = SettableFuture.create(); @@ -103,10 +115,10 @@ public void nofilterChainMatch_defaultSslContext() throws Exception { SslContextProviderSupplier defaultSsl = new SslContextProviderSupplier(createTls(), tlsContextManager); - FilterChainSelector selector = new FilterChainSelector( - new HashMap(), defaultSsl, noopConfig); + selectorManager.updateSelector(new FilterChainSelector( + new HashMap(), defaultSsl, noopConfig)); FilterChainMatchingHandler filterChainMatchingHandler = - new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); + new FilterChainMatchingHandler(grpcHandler, selectorManager, mockDelegate); setupChannel("172.168.1.1", "172.168.1.2", 80, filterChainMatchingHandler); ChannelHandlerContext channelHandlerCtx = pipeline.context(filterChainMatchingHandler); assertThat(channelHandlerCtx).isNotNull(); @@ -125,10 +137,10 @@ public void nofilterChainMatch_defaultSslContext() throws Exception { @Test public void noFilterChainMatch_noDefaultSslContext() { - FilterChainSelector selector = new FilterChainSelector( - new HashMap(), null, null); + selectorManager.updateSelector(new FilterChainSelector( + new HashMap(), null, null)); FilterChainMatchingHandler filterChainMatchingHandler = - new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); + new FilterChainMatchingHandler(grpcHandler, selectorManager, mockDelegate); setupChannel("172.168.1.1", "172.168.2.2", 90, filterChainMatchingHandler); channelHandlerCtx = pipeline.context(filterChainMatchingHandler); assertThat(channelHandlerCtx).isNotNull(); @@ -139,6 +151,33 @@ public void noFilterChainMatch_noDefaultSslContext() { assertThat(channel.closeFuture().isDone()).isTrue(); } + @Test + public void filterSelectorChange_drainsConnection() { + ChannelHandler next = new ChannelInboundHandlerAdapter(); + when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); + selectorManager.updateSelector(new FilterChainSelector( + new HashMap(), null, noopConfig)); + FilterChainMatchingHandler filterChainMatchingHandler = + new FilterChainMatchingHandler(grpcHandler, selectorManager, mockDelegate); + setupChannel("172.168.1.1", "172.168.2.2", 90, filterChainMatchingHandler); + channelHandlerCtx = pipeline.context(filterChainMatchingHandler); + assertThat(channelHandlerCtx).isNotNull(); + + pipeline.fireUserEventTriggered(event); + channelHandlerCtx = pipeline.context(filterChainMatchingHandler); + assertThat(channelHandlerCtx).isNull(); + + channel.runPendingTasks(); + channelHandlerCtx = pipeline.context(next); + assertThat(channelHandlerCtx).isNotNull(); + assertThat(channel.readOutbound()).isNull(); + + selectorManager.updateSelector(new FilterChainSelector( + new HashMap(), null, noopConfig)); + assertThat(channel.readOutbound().getClass().getName()) + .isEqualTo("io.grpc.netty.GracefulServerCloseCommand"); + } + @Test public void singleFilterChainWithoutAlpn() throws Exception { EnvoyServerProtoData.FilterChainMatch filterChainMatch = @@ -157,10 +196,10 @@ public void singleFilterChainWithoutAlpn() throws Exception { "filter-chain-foo", filterChainMatch, HTTP_CONNECTION_MANAGER, tlsContext, tlsContextManager); - FilterChainSelector selector = new FilterChainSelector(ImmutableMap.of(filterChain, noopConfig), - null, null); + selectorManager.updateSelector(new FilterChainSelector(ImmutableMap.of(filterChain, noopConfig), + null, null)); FilterChainMatchingHandler filterChainMatchingHandler = - new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); + new FilterChainMatchingHandler(grpcHandler, selectorManager, mockDelegate); final SettableFuture sslSet = SettableFuture.create(); final SettableFuture routingSettable = SettableFuture.create(); ChannelHandler next = captureAttrHandler(sslSet, routingSettable); @@ -196,11 +235,11 @@ public void singleFilterChainWithAlpn() throws Exception { EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( "filter-chain-bar", null, HTTP_CONNECTION_MANAGER, defaultTlsContext, tlsContextManager); - FilterChainSelector selector = new FilterChainSelector( + selectorManager.updateSelector(new FilterChainSelector( ImmutableMap.of(filterChain, randomConfig("no-match")), - defaultFilterChain.getSslContextProviderSupplier(), noopConfig); + defaultFilterChain.getSslContextProviderSupplier(), noopConfig)); FilterChainMatchingHandler filterChainMatchingHandler = - new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); + new FilterChainMatchingHandler(grpcHandler, selectorManager, mockDelegate); final SettableFuture sslSet = SettableFuture.create(); final SettableFuture routingSettable = SettableFuture.create(); @@ -242,12 +281,12 @@ public void destPortFails_returnDefaultFilterChain() throws Exception { ServerRoutingConfig routingConfig = ServerRoutingConfig.create( new ArrayList(), new AtomicReference<>( ImmutableList.of(createVirtualHost("virtual")))); - FilterChainSelector selector = new FilterChainSelector( + selectorManager.updateSelector(new FilterChainSelector( ImmutableMap.of(filterChainWithDestPort, routingConfig), - defaultFilterChain.getSslContextProviderSupplier(), noopConfig); + defaultFilterChain.getSslContextProviderSupplier(), noopConfig)); FilterChainMatchingHandler filterChainMatchingHandler = - new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); + new FilterChainMatchingHandler(grpcHandler, selectorManager, mockDelegate); final SettableFuture sslSet = SettableFuture.create(); final SettableFuture routingSettable = SettableFuture.create(); @@ -285,12 +324,12 @@ public void destPrefixRangeMatch() throws Exception { "filter-chain-bar", null, HTTP_CONNECTION_MANAGER, tlsContextForDefaultFilterChain, tlsContextManager); - FilterChainSelector selector = new FilterChainSelector( + selectorManager.updateSelector(new FilterChainSelector( ImmutableMap.of(filterChainWithMatch, noopConfig), - defaultFilterChain.getSslContextProviderSupplier(), randomConfig("no-match")); + defaultFilterChain.getSslContextProviderSupplier(), randomConfig("no-match"))); FilterChainMatchingHandler filterChainMatchingHandler = - new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); + new FilterChainMatchingHandler(grpcHandler, selectorManager, mockDelegate); final SettableFuture sslSet = SettableFuture.create(); final SettableFuture routingSettable = SettableFuture.create(); @@ -329,12 +368,12 @@ public void destPrefixRangeMismatch_returnDefaultFilterChain() EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( "filter-chain-bar", null, HTTP_CONNECTION_MANAGER, tlsContextForDefaultFilterChain, tlsContextManager); - FilterChainSelector selector = new FilterChainSelector( + selectorManager.updateSelector(new FilterChainSelector( ImmutableMap.of(filterChainWithMismatch, randomConfig("no-match")), - defaultFilterChain.getSslContextProviderSupplier(), noopConfig); + defaultFilterChain.getSslContextProviderSupplier(), noopConfig)); FilterChainMatchingHandler filterChainMatchingHandler = - new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); + new FilterChainMatchingHandler(grpcHandler, selectorManager, mockDelegate); final SettableFuture sslSet = SettableFuture.create(); final SettableFuture routingSettable = SettableFuture.create(); @@ -374,11 +413,11 @@ public void dest0LengthPrefixRange() "filter-chain-bar", null, HTTP_CONNECTION_MANAGER, tlsContextForDefaultFilterChain, tlsContextManager); - FilterChainSelector selector = new FilterChainSelector( + selectorManager.updateSelector(new FilterChainSelector( ImmutableMap.of(filterChain0Length, noopConfig), - defaultFilterChain.getSslContextProviderSupplier(), null); + defaultFilterChain.getSslContextProviderSupplier(), null)); FilterChainMatchingHandler filterChainMatchingHandler = - new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); + new FilterChainMatchingHandler(grpcHandler, selectorManager, mockDelegate); final SettableFuture sslSet = SettableFuture.create(); final SettableFuture routingSettable = SettableFuture.create(); @@ -431,13 +470,13 @@ public void destPrefixRange_moreSpecificWins() tlsContextManager); EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( "filter-chain-baz", null, HTTP_CONNECTION_MANAGER, null, tlsContextManager); - FilterChainSelector selector = new FilterChainSelector( + selectorManager.updateSelector(new FilterChainSelector( ImmutableMap.of(filterChainLessSpecific, randomConfig("no-match"), filterChainMoreSpecific, noopConfig), - defaultFilterChain.getSslContextProviderSupplier(), randomConfig("default")); + defaultFilterChain.getSslContextProviderSupplier(), randomConfig("default"))); FilterChainMatchingHandler filterChainMatchingHandler = - new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); + new FilterChainMatchingHandler(grpcHandler, selectorManager, mockDelegate); final SettableFuture sslSet = SettableFuture.create(); final SettableFuture routingSettable = SettableFuture.create(); @@ -490,12 +529,12 @@ public void destPrefixRange_emptyListLessSpecific() tlsContextManager); EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( "filter-chain-baz", null, HTTP_CONNECTION_MANAGER, null, tlsContextManager); - FilterChainSelector selector = new FilterChainSelector( + selectorManager.updateSelector(new FilterChainSelector( ImmutableMap.of(filterChainLessSpecific, randomConfig("no-match"), filterChainMoreSpecific, noopConfig), - defaultFilterChain.getSslContextProviderSupplier(), randomConfig("default")); + defaultFilterChain.getSslContextProviderSupplier(), randomConfig("default"))); FilterChainMatchingHandler filterChainMatchingHandler = - new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); + new FilterChainMatchingHandler(grpcHandler, selectorManager, mockDelegate); final SettableFuture sslSet = SettableFuture.create(); final SettableFuture routingSettable = SettableFuture.create(); @@ -547,13 +586,13 @@ public void destPrefixRangeIpv6_moreSpecificWins() tlsContextMoreSpecific, tlsContextManager); EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( "filter-chain-baz", null, HTTP_CONNECTION_MANAGER, null, tlsContextManager); - FilterChainSelector selector = new FilterChainSelector( + selectorManager.updateSelector(new FilterChainSelector( ImmutableMap.of(filterChainLessSpecific, randomConfig("no-match"), filterChainMoreSpecific, noopConfig), - defaultFilterChain.getSslContextProviderSupplier(), randomConfig("default")); + defaultFilterChain.getSslContextProviderSupplier(), randomConfig("default"))); FilterChainMatchingHandler filterChainMatchingHandler = - new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); + new FilterChainMatchingHandler(grpcHandler, selectorManager, mockDelegate); final SettableFuture sslSet = SettableFuture.create(); final SettableFuture routingSettable = SettableFuture.create(); @@ -610,12 +649,12 @@ public void destPrefixRange_moreSpecificWith2Wins() EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( "filter-chain-baz", null, HTTP_CONNECTION_MANAGER, null, tlsContextManager); - FilterChainSelector selector = new FilterChainSelector( + selectorManager.updateSelector(new FilterChainSelector( ImmutableMap.of(filterChainMoreSpecificWith2, noopConfig, filterChainLessSpecific, randomConfig("no-match")), - defaultFilterChain.getSslContextProviderSupplier(), randomConfig("default")); + defaultFilterChain.getSslContextProviderSupplier(), randomConfig("default"))); FilterChainMatchingHandler filterChainMatchingHandler = - new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); + new FilterChainMatchingHandler(grpcHandler, selectorManager, mockDelegate); final SettableFuture sslSet = SettableFuture.create(); final SettableFuture routingSettable = SettableFuture.create(); @@ -653,11 +692,11 @@ public void sourceTypeMismatch_returnDefaultFilterChain() throws Exception { EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( "filter-chain-bar", null, HTTP_CONNECTION_MANAGER,tlsContextForDefaultFilterChain, tlsContextManager); - FilterChainSelector selector = new FilterChainSelector( + selectorManager.updateSelector(new FilterChainSelector( ImmutableMap.of(filterChainWithMismatch, randomConfig("no-match")), - defaultFilterChain.getSslContextProviderSupplier(), noopConfig); + defaultFilterChain.getSslContextProviderSupplier(), noopConfig)); FilterChainMatchingHandler filterChainMatchingHandler = - new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); + new FilterChainMatchingHandler(grpcHandler, selectorManager, mockDelegate); final SettableFuture sslSet = SettableFuture.create(); final SettableFuture routingSettable = SettableFuture.create(); @@ -698,11 +737,11 @@ public void sourceTypeLocal() throws Exception { "filter-chain-bar", null, HTTP_CONNECTION_MANAGER, tlsContextForDefaultFilterChain, tlsContextManager); - FilterChainSelector selector = new FilterChainSelector( + selectorManager.updateSelector(new FilterChainSelector( ImmutableMap.of(filterChainWithMatch, noopConfig), - defaultFilterChain.getSslContextProviderSupplier(), randomConfig("default")); + defaultFilterChain.getSslContextProviderSupplier(), randomConfig("default"))); FilterChainMatchingHandler filterChainMatchingHandler = - new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); + new FilterChainMatchingHandler(grpcHandler, selectorManager, mockDelegate); setupChannel(LOCAL_IP, LOCAL_IP, 15000, filterChainMatchingHandler); pipeline.fireUserEventTriggered(event); channel.runPendingTasks(); @@ -757,13 +796,13 @@ public void sourcePrefixRange_moreSpecificWith2Wins() EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( "filter-chain-baz", null, HTTP_CONNECTION_MANAGER, null, tlsContextManager); - FilterChainSelector selector = new FilterChainSelector( + selectorManager.updateSelector(new FilterChainSelector( ImmutableMap.of(filterChainMoreSpecificWith2, noopConfig, filterChainLessSpecific, randomConfig("no-match")), - defaultFilterChain.getSslContextProviderSupplier(), randomConfig("default")); + defaultFilterChain.getSslContextProviderSupplier(), randomConfig("default"))); FilterChainMatchingHandler filterChainMatchingHandler = - new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); + new FilterChainMatchingHandler(grpcHandler, selectorManager, mockDelegate); setupChannel(LOCAL_IP, REMOTE_IP, 15000, filterChainMatchingHandler); pipeline.fireUserEventTriggered(event); channel.runPendingTasks(); @@ -823,12 +862,12 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( "filter-chain-baz", null, HTTP_CONNECTION_MANAGER, null, null); - FilterChainSelector selector = new FilterChainSelector( + selectorManager.updateSelector(new FilterChainSelector( ImmutableMap.of(filterChain1, noopConfig, filterChain2, noopConfig), - defaultFilterChain.getSslContextProviderSupplier(), noopConfig); + defaultFilterChain.getSslContextProviderSupplier(), noopConfig)); FilterChainMatchingHandler filterChainMatchingHandler = - new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); + new FilterChainMatchingHandler(grpcHandler, selectorManager, mockDelegate); setupChannel(LOCAL_IP, REMOTE_IP, 15000, filterChainMatchingHandler); pipeline.fireUserEventTriggered(event); channel.runPendingTasks(); @@ -884,13 +923,13 @@ public void sourcePortMatch_exactMatchWinsOverEmptyList() throws Exception { EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( "filter-chain-baz", null, HTTP_CONNECTION_MANAGER, null, tlsContextManager); - FilterChainSelector selector = new FilterChainSelector( + selectorManager.updateSelector(new FilterChainSelector( ImmutableMap.of(filterChainEmptySourcePorts, randomConfig("no-match"), filterChainSourcePortMatch, noopConfig), - defaultFilterChain.getSslContextProviderSupplier(), randomConfig("default")); + defaultFilterChain.getSslContextProviderSupplier(), randomConfig("default"))); FilterChainMatchingHandler filterChainMatchingHandler = - new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); + new FilterChainMatchingHandler(grpcHandler, selectorManager, mockDelegate); final SettableFuture sslSet = SettableFuture.create(); final SettableFuture routingSettable = SettableFuture.create(); ChannelHandler next = captureAttrHandler(sslSet, routingSettable); @@ -1040,11 +1079,11 @@ public void filterChain_5stepMatch() throws Exception { map.put(filterChain4, randomConfig("4")); map.put(filterChain5, noopConfig); map.put(filterChain6, randomConfig("6")); - FilterChainSelector selector = new FilterChainSelector( - map, defaultFilterChain.getSslContextProviderSupplier(), randomConfig("default")); + selectorManager.updateSelector(new FilterChainSelector( + map, defaultFilterChain.getSslContextProviderSupplier(), randomConfig("default"))); FilterChainMatchingHandler filterChainMatchingHandler = - new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); + new FilterChainMatchingHandler(grpcHandler, selectorManager, mockDelegate); final SettableFuture sslSet = SettableFuture.create(); final SettableFuture routingSettable = SettableFuture.create(); @@ -1114,12 +1153,12 @@ public void filterChainMatch_unsupportedMatchers() throws Exception { "filter-chain-baz", defaultFilterChainMatch, HTTP_CONNECTION_MANAGER, tlsContext3, mock(TlsContextManager.class)); - FilterChainSelector selector = new FilterChainSelector( + selectorManager.updateSelector(new FilterChainSelector( ImmutableMap.of(filterChain1, randomConfig("1"), filterChain2, randomConfig("2")), - defaultFilterChain.getSslContextProviderSupplier(), noopConfig); + defaultFilterChain.getSslContextProviderSupplier(), noopConfig)); FilterChainMatchingHandler filterChainMatchingHandler = - new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); + new FilterChainMatchingHandler(grpcHandler, selectorManager, mockDelegate); final SettableFuture sslSet = SettableFuture.create(); final SettableFuture routingSettable = SettableFuture.create(); ChannelHandler next = captureAttrHandler(sslSet, routingSettable); diff --git a/xds/src/test/java/io/grpc/xds/FilterChainSelectorManagerTest.java b/xds/src/test/java/io/grpc/xds/FilterChainSelectorManagerTest.java new file mode 100644 index 00000000000..d7b883f1941 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/FilterChainSelectorManagerTest.java @@ -0,0 +1,107 @@ +/* + * Copyright 2021 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.truth.Truth.assertThat; + +import com.google.common.collect.ImmutableList; +import io.grpc.xds.EnvoyServerProtoData.FilterChain; +import io.grpc.xds.Filter.NamedFilterConfig; +import io.grpc.xds.FilterChainMatchingProtocolNegotiators.FilterChainMatchingHandler.FilterChainSelector; +import io.grpc.xds.FilterChainSelectorManager.Closer; +import io.grpc.xds.XdsServerWrapper.ServerRoutingConfig; +import java.util.Collections; +import java.util.concurrent.atomic.AtomicReference; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class FilterChainSelectorManagerTest { + private FilterChainSelectorManager manager = new FilterChainSelectorManager(); + private ServerRoutingConfig noopConfig = ServerRoutingConfig.create( + Collections.emptyList(), + new AtomicReference>()); + private FilterChainSelector selector1 = new FilterChainSelector( + Collections.emptyMap(), null, null); + private FilterChainSelector selector2 = new FilterChainSelector( + Collections.emptyMap(), null, noopConfig); + private CounterRunnable runnable1 = new CounterRunnable(); + private CounterRunnable runnable2 = new CounterRunnable(); + + @Test + public void updateSelector_changesSelector() { + assertThat(manager.getSelectorToUpdateSelector()).isNull(); + assertThat(manager.register(new Closer(runnable1))).isNull(); + + manager.updateSelector(selector1); + + assertThat(runnable1.counter).isEqualTo(1); + assertThat(manager.getSelectorToUpdateSelector()).isSameInstanceAs(selector1); + assertThat(manager.register(new Closer(runnable2))).isSameInstanceAs(selector1); + assertThat(runnable2.counter).isEqualTo(0); + } + + @Test + public void updateSelector_callsCloserOnce() { + assertThat(manager.register(new Closer(runnable1))).isNull(); + + manager.updateSelector(selector1); + manager.updateSelector(selector2); + + assertThat(runnable1.counter).isEqualTo(1); + } + + @Test + public void deregister_removesCloser() { + Closer closer1 = new Closer(runnable1); + manager.updateSelector(selector1); + assertThat(manager.register(closer1)).isSameInstanceAs(selector1); + assertThat(manager.getRegisterCount()).isEqualTo(1); + + manager.deregister(closer1); + + assertThat(manager.getRegisterCount()).isEqualTo(0); + manager.updateSelector(selector2); + assertThat(runnable1.counter).isEqualTo(0); + } + + @Test + public void deregister_removesCorrectCloser() { + Closer closer1 = new Closer(runnable1); + Closer closer2 = new Closer(runnable2); + manager.updateSelector(selector1); + assertThat(manager.register(closer1)).isSameInstanceAs(selector1); + assertThat(manager.register(closer2)).isSameInstanceAs(selector1); + assertThat(manager.getRegisterCount()).isEqualTo(2); + + manager.deregister(closer1); + + assertThat(manager.getRegisterCount()).isEqualTo(1); + manager.updateSelector(selector2); + assertThat(runnable1.counter).isEqualTo(0); + assertThat(runnable2.counter).isEqualTo(1); + } + + private static class CounterRunnable implements Runnable { + int counter; + + @Override public void run() { + counter++; + } + } +} diff --git a/xds/src/test/java/io/grpc/xds/XdsClientWrapperForServerSdsTestMisc.java b/xds/src/test/java/io/grpc/xds/XdsClientWrapperForServerSdsTestMisc.java index 532cb282b26..1871cb79770 100644 --- a/xds/src/test/java/io/grpc/xds/XdsClientWrapperForServerSdsTestMisc.java +++ b/xds/src/test/java/io/grpc/xds/XdsClientWrapperForServerSdsTestMisc.java @@ -72,7 +72,6 @@ import java.util.concurrent.ExecutionException; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicReference; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -101,7 +100,7 @@ public class XdsClientWrapperForServerSdsTestMisc { @Mock private XdsServingStatusListener listener; private FakeXdsClient xdsClient = new FakeXdsClient(); - private AtomicReference selectorRef = new AtomicReference<>(); + private FilterChainSelectorManager selectorManager = new FilterChainSelectorManager(); private XdsServerWrapper xdsServerWrapper; @@ -117,13 +116,14 @@ public void setUp() { when(mockBuilder.build()).thenReturn(mockServer); when(mockServer.isShutdown()).thenReturn(false); xdsServerWrapper = new XdsServerWrapper("0.0.0.0:" + PORT, mockBuilder, listener, - selectorRef, new FakeXdsClientPoolFactory(xdsClient), FilterRegistry.newRegistry()); + selectorManager, new FakeXdsClientPoolFactory(xdsClient), FilterRegistry.newRegistry()); } @Test public void nonInetSocketAddress_expectNull() throws Exception { sendListenerUpdate(new InProcessSocketAddress("test1"), null, null, tlsContextManager); - assertThat(getSslContextProviderSupplier(selectorRef.get())).isNull(); + assertThat(getSslContextProviderSupplier(selectorManager.getSelectorToUpdateSelector())) + .isNull(); } @Test @@ -168,7 +168,7 @@ public void run() { LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(listener); xdsClient.ldsWatcher.onChanged(listenerUpdate); start.get(5, TimeUnit.SECONDS); - FilterChainSelector selector = selectorRef.get(); + FilterChainSelector selector = selectorManager.getSelectorToUpdateSelector(); assertThat(getSslContextProviderSupplier(selector)).isNull(); } @@ -193,7 +193,7 @@ public void run() { } catch (ExecutionException ex) { assertThat(ex.getCause()).isInstanceOf(IOException.class); } - assertThat(selectorRef.get()).isSameInstanceAs(NO_FILTER_CHAIN); + assertThat(selectorManager.getSelectorToUpdateSelector()).isSameInstanceAs(NO_FILTER_CHAIN); } @Test @@ -217,7 +217,7 @@ public void run() { } catch (ExecutionException ex) { assertThat(ex.getCause()).isInstanceOf(IOException.class); } - assertThat(selectorRef.get()).isSameInstanceAs(NO_FILTER_CHAIN); + assertThat(selectorManager.getSelectorToUpdateSelector()).isSameInstanceAs(NO_FILTER_CHAIN); } @Test @@ -241,7 +241,7 @@ public void run() { } catch (ExecutionException ex) { assertThat(ex.getCause()).isInstanceOf(IOException.class); } - assertThat(selectorRef.get()).isSameInstanceAs(NO_FILTER_CHAIN); + assertThat(selectorManager.getSelectorToUpdateSelector()).isSameInstanceAs(NO_FILTER_CHAIN); } @Test @@ -263,13 +263,14 @@ public void releaseOldSupplierOnChangedOnShutdown_verifyClose() throws Exception localAddress = new InetSocketAddress(ipLocalAddress, PORT); sendListenerUpdate(localAddress, tlsContext1, null, tlsContextManager); - SslContextProviderSupplier returnedSupplier = getSslContextProviderSupplier(selectorRef.get()); + SslContextProviderSupplier returnedSupplier = + getSslContextProviderSupplier(selectorManager.getSelectorToUpdateSelector()); assertThat(returnedSupplier.getTlsContext()).isSameInstanceAs(tlsContext1); callUpdateSslContext(returnedSupplier); XdsServerTestHelper .generateListenerUpdate(xdsClient, Arrays.asList(1234), tlsContext2, tlsContext3, tlsContextManager); - returnedSupplier = getSslContextProviderSupplier(selectorRef.get()); + returnedSupplier = getSslContextProviderSupplier(selectorManager.getSelectorToUpdateSelector()); assertThat(returnedSupplier.getTlsContext()).isSameInstanceAs(tlsContext2); verify(tlsContextManager, times(1)).releaseServerSslContextProvider(eq(sslContextProvider1)); reset(tlsContextManager); @@ -294,7 +295,7 @@ public SocketAddress remoteAddress() { } }; pipeline = channel.pipeline(); - returnedSupplier = getSslContextProviderSupplier(selectorRef.get()); + returnedSupplier = getSslContextProviderSupplier(selectorManager.getSelectorToUpdateSelector()); assertThat(returnedSupplier.getTlsContext()).isSameInstanceAs(tlsContext3); callUpdateSslContext(returnedSupplier); xdsServerWrapper.shutdown(); @@ -314,7 +315,7 @@ public void releaseOldSupplierOnNotFound_verifyClose() throws Exception { sendListenerUpdate(localAddress, tlsContext1, null, tlsContextManager); SslContextProviderSupplier returnedSupplier = - getSslContextProviderSupplier(selectorRef.get()); + getSslContextProviderSupplier(selectorManager.getSelectorToUpdateSelector()); assertThat(returnedSupplier.getTlsContext()).isSameInstanceAs(tlsContext1); callUpdateSslContext(returnedSupplier); xdsClient.ldsWatcher.onResourceDoesNotExist("not-found Error"); @@ -331,7 +332,7 @@ public void releaseOldSupplierOnPermDeniedError_verifyClose() throws Exception { sendListenerUpdate(localAddress, tlsContext1, null, tlsContextManager); SslContextProviderSupplier returnedSupplier = - getSslContextProviderSupplier(selectorRef.get()); + getSslContextProviderSupplier(selectorManager.getSelectorToUpdateSelector()); assertThat(returnedSupplier.getTlsContext()).isSameInstanceAs(tlsContext1); callUpdateSslContext(returnedSupplier); xdsClient.ldsWatcher.onError(Status.PERMISSION_DENIED); @@ -348,7 +349,7 @@ public void releaseOldSupplierOnTemporaryError_noClose() throws Exception { sendListenerUpdate(localAddress, tlsContext1, null, tlsContextManager); SslContextProviderSupplier returnedSupplier = - getSslContextProviderSupplier(selectorRef.get()); + getSslContextProviderSupplier(selectorManager.getSelectorToUpdateSelector()); assertThat(returnedSupplier.getTlsContext()).isSameInstanceAs(tlsContext1); callUpdateSslContext(returnedSupplier); xdsClient.ldsWatcher.onError(Status.CANCELLED); @@ -412,8 +413,10 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { ProtocolNegotiator mockDelegate = mock(ProtocolNegotiator.class); GrpcHttp2ConnectionHandler grpcHandler = FakeGrpcHttp2ConnectionHandler.newHandler(); when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); + FilterChainSelectorManager manager = new FilterChainSelectorManager(); + manager.updateSelector(selector); FilterChainMatchingHandler filterChainMatchingHandler = - new FilterChainMatchingHandler(grpcHandler, selector, mockDelegate); + new FilterChainMatchingHandler(grpcHandler, manager, mockDelegate); pipeline.addLast(filterChainMatchingHandler); ProtocolNegotiationEvent event = InternalProtocolNegotiationEvent.getDefault(); pipeline.fireUserEventTriggered(event); diff --git a/xds/src/test/java/io/grpc/xds/XdsServerBuilderTest.java b/xds/src/test/java/io/grpc/xds/XdsServerBuilderTest.java index 476dc10a16a..0d15c1f660e 100644 --- a/xds/src/test/java/io/grpc/xds/XdsServerBuilderTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsServerBuilderTest.java @@ -281,4 +281,15 @@ public void xdsServer_2ndSetter_expectException() throws IOException { assertThat(expected).hasMessageThat().contains("Server already built!"); } } + + @Test + public void drainGraceTime_negativeThrows() throws IOException { + buildBuilder(null); + try { + builder.drainGraceTime(-1, TimeUnit.SECONDS); + fail("exception expected"); + } catch (IllegalArgumentException expected) { + assertThat(expected).hasMessageThat().contains("drain grace time"); + } + } } diff --git a/xds/src/test/java/io/grpc/xds/XdsServerWrapperTest.java b/xds/src/test/java/io/grpc/xds/XdsServerWrapperTest.java index 876b0913742..d4421361158 100644 --- a/xds/src/test/java/io/grpc/xds/XdsServerWrapperTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsServerWrapperTest.java @@ -99,7 +99,7 @@ public class XdsServerWrapperTest { @Mock private XdsServingStatusListener listener; - private AtomicReference selectorRef = new AtomicReference<>(); + private FilterChainSelectorManager selectorManager = new FilterChainSelectorManager(); private FakeClock executor = new FakeClock(); private FakeXdsClient xdsClient = new FakeXdsClient(); private FilterRegistry filterRegistry = FilterRegistry.getDefaultRegistry(); @@ -109,7 +109,7 @@ public class XdsServerWrapperTest { public void setup() { when(mockBuilder.build()).thenReturn(mockServer); xdsServerWrapper = new XdsServerWrapper("0.0.0.0:1", mockBuilder, listener, - selectorRef, new FakeXdsClientPoolFactory(xdsClient), + selectorManager, new FakeXdsClientPoolFactory(xdsClient), filterRegistry, executor.getScheduledExecutorService()); } @@ -141,7 +141,7 @@ private void verifyBootstrapFail(Bootstrapper.BootstrapInfo b) throws Exception XdsClient xdsClient = mock(XdsClient.class); when(xdsClient.getBootstrapInfo()).thenReturn(b); xdsServerWrapper = new XdsServerWrapper("0.0.0.0:1", mockBuilder, listener, - selectorRef, new FakeXdsClientPoolFactory(xdsClient), filterRegistry); + selectorManager, new FakeXdsClientPoolFactory(xdsClient), filterRegistry); final SettableFuture start = SettableFuture.create(); Executors.newSingleThreadExecutor().execute(new Runnable() { @Override @@ -377,8 +377,10 @@ public void run() { xdsClient.deliverLdsUpdate(Collections.singletonList(filterChain), null); start.get(5000, TimeUnit.MILLISECONDS); assertThat(ldsWatched).isEqualTo("grpc/server?udpa.resource.listening_address=0.0.0.0:1"); - assertThat(selectorRef.get().getRoutingConfigs().size()).isEqualTo(1); - ServerRoutingConfig realConfig = selectorRef.get().getRoutingConfigs().get(filterChain); + assertThat(selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().size()) + .isEqualTo(1); + ServerRoutingConfig realConfig = + selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().get(filterChain); assertThat(realConfig.virtualHosts().get()).isEqualTo(httpConnectionManager.virtualHosts()); assertThat(realConfig.httpFilterConfigs()).isEqualTo(httpConnectionManager.httpFilterConfigs()); verify(listener).onServing(); @@ -408,7 +410,7 @@ public void run() { xdsClient.rdsCount = new CountDownLatch(3); xdsClient.deliverLdsUpdate(Arrays.asList(f0, f1), null); assertThat(start.isDone()).isFalse(); - assertThat(selectorRef.get()).isNull(); + assertThat(selectorManager.getSelectorToUpdateSelector()).isNull(); verify(mockServer, never()).start(); verify(listener, never()).onServing(); @@ -426,23 +428,26 @@ public void run() { Collections.singletonList(createVirtualHost("virtual-host-2"))); start.get(5000, TimeUnit.MILLISECONDS); verify(mockServer).start(); - ServerRoutingConfig realConfig = selectorRef.get().getRoutingConfigs().get(f0); + ServerRoutingConfig realConfig = + selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().get(f0); assertThat(realConfig.virtualHosts().get()).isEqualTo( Collections.singletonList(createVirtualHost("virtual-host-0"))); assertThat(realConfig.httpFilterConfigs()).isEqualTo( f0.getHttpConnectionManager().httpFilterConfigs()); - assertThat(selectorRef.get().getRoutingConfigs().size()).isEqualTo(2); - realConfig = selectorRef.get().getRoutingConfigs().get(f2); + assertThat(selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().size()) + .isEqualTo(2); + realConfig = selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().get(f2); assertThat(realConfig.virtualHosts().get()).isEqualTo( Collections.singletonList(createVirtualHost("virtual-host-1"))); assertThat(realConfig.httpFilterConfigs()).isEqualTo( f2.getHttpConnectionManager().httpFilterConfigs()); - realConfig = selectorRef.get().getDefaultRoutingConfig(); + realConfig = selectorManager.getSelectorToUpdateSelector().getDefaultRoutingConfig(); assertThat(realConfig.virtualHosts().get()).isEqualTo( Collections.singletonList(createVirtualHost("virtual-host-2"))); assertThat(realConfig.httpFilterConfigs()).isEqualTo( f3.getHttpConnectionManager().httpFilterConfigs()); - assertThat(selectorRef.get().getDefaultSslContextProviderSupplier()).isEqualTo( + assertThat(selectorManager.getSelectorToUpdateSelector().getDefaultSslContextProviderSupplier()) + .isEqualTo( f3.getSslContextProviderSupplier()); } @@ -468,31 +473,32 @@ public void run() { xdsClient.rdsCount = new CountDownLatch(1); xdsClient.deliverLdsUpdate(Arrays.asList(f0, f1), f2); assertThat(start.isDone()).isFalse(); - assertThat(selectorRef.get()).isNull(); + assertThat(selectorManager.getSelectorToUpdateSelector()).isNull(); xdsClient.rdsCount.await(5, TimeUnit.SECONDS); xdsClient.deliverRdsUpdate("r0", Collections.singletonList(createVirtualHost("virtual-host-0"))); start.get(5000, TimeUnit.MILLISECONDS); verify(mockServer, times(1)).start(); - ServerRoutingConfig realConfig = selectorRef.get().getRoutingConfigs().get(f0); + ServerRoutingConfig realConfig = + selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().get(f0); assertThat(realConfig.virtualHosts().get()).isEqualTo( Collections.singletonList(createVirtualHost("virtual-host-0"))); assertThat(realConfig.httpFilterConfigs()).isEqualTo( f0.getHttpConnectionManager().httpFilterConfigs()); - realConfig = selectorRef.get().getRoutingConfigs().get(f1); + realConfig = selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().get(f1); assertThat(realConfig.virtualHosts().get()).isEqualTo( Collections.singletonList(createVirtualHost("virtual-host-0"))); assertThat(realConfig.httpFilterConfigs()).isEqualTo( f1.getHttpConnectionManager().httpFilterConfigs()); - realConfig = selectorRef.get().getDefaultRoutingConfig(); + realConfig = selectorManager.getSelectorToUpdateSelector().getDefaultRoutingConfig(); assertThat(realConfig.virtualHosts().get()).isEqualTo( Collections.singletonList(createVirtualHost("virtual-host-0"))); assertThat(realConfig.httpFilterConfigs()).isEqualTo( f2.getHttpConnectionManager().httpFilterConfigs()); - assertThat(selectorRef.get().getDefaultSslContextProviderSupplier()).isSameInstanceAs( - f2.getSslContextProviderSupplier()); + assertThat(selectorManager.getSelectorToUpdateSelector().getDefaultSslContextProviderSupplier()) + .isSameInstanceAs(f2.getSslContextProviderSupplier()); EnvoyServerProtoData.FilterChain f3 = createFilterChain("filter-chain-3", createRds("r0")); EnvoyServerProtoData.FilterChain f4 = createFilterChain("filter-chain-4", createRds("r1")); @@ -505,24 +511,26 @@ public void run() { xdsClient.deliverRdsUpdate("r0", Collections.singletonList(createVirtualHost("virtual-host-0"))); - assertThat(selectorRef.get().getRoutingConfigs().size()).isEqualTo(2); - realConfig = selectorRef.get().getRoutingConfigs().get(f5); + assertThat(selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().size()) + .isEqualTo(2); + realConfig = selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().get(f5); assertThat(realConfig.virtualHosts().get()).isEqualTo( Collections.singletonList(createVirtualHost("virtual-host-1"))); assertThat(realConfig.httpFilterConfigs()).isEqualTo( f5.getHttpConnectionManager().httpFilterConfigs()); - realConfig = selectorRef.get().getRoutingConfigs().get(f3); + realConfig = selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().get(f3); assertThat(realConfig.virtualHosts().get()).isEqualTo( Collections.singletonList(createVirtualHost("virtual-host-0"))); assertThat(realConfig.httpFilterConfigs()).isEqualTo( f3.getHttpConnectionManager().httpFilterConfigs()); - realConfig = selectorRef.get().getDefaultRoutingConfig(); + realConfig = selectorManager.getSelectorToUpdateSelector().getDefaultRoutingConfig(); assertThat(realConfig.virtualHosts().get()).isEqualTo( Collections.singletonList(createVirtualHost("virtual-host-1"))); assertThat(realConfig.httpFilterConfigs()).isEqualTo( f4.getHttpConnectionManager().httpFilterConfigs()); - assertThat(selectorRef.get().getDefaultSslContextProviderSupplier()).isSameInstanceAs( + assertThat(selectorManager.getSelectorToUpdateSelector().getDefaultSslContextProviderSupplier()) + .isSameInstanceAs( f4.getSslContextProviderSupplier()); verify(mockServer, times(1)).start(); xdsServerWrapper.shutdown(); @@ -556,33 +564,35 @@ public void run() { xdsClient.rdsCount.await(); xdsClient.rdsWatchers.get("r0").onError(Status.CANCELLED); start.get(5000, TimeUnit.MILLISECONDS); - assertThat(selectorRef.get().getRoutingConfigs().size()).isEqualTo(2); - ServerRoutingConfig realConfig = selectorRef.get().getRoutingConfigs().get(f1); + assertThat(selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().size()) + .isEqualTo(2); + ServerRoutingConfig realConfig = + selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().get(f1); assertThat(realConfig.virtualHosts().get()).isNull(); assertThat(realConfig.httpFilterConfigs()).isEqualTo( f1.getHttpConnectionManager().httpFilterConfigs()); - realConfig = selectorRef.get().getRoutingConfigs().get(f0); + realConfig = selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().get(f0); assertThat(realConfig.virtualHosts().get()).isEqualTo(hcmVirtual.virtualHosts()); assertThat(realConfig.httpFilterConfigs()).isEqualTo( f0.getHttpConnectionManager().httpFilterConfigs()); xdsClient.deliverRdsUpdate("r0", Collections.singletonList(createVirtualHost("virtual-host-1"))); - realConfig = selectorRef.get().getRoutingConfigs().get(f1); + realConfig = selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().get(f1); assertThat(realConfig.virtualHosts().get()).isEqualTo( Collections.singletonList(createVirtualHost("virtual-host-1"))); assertThat(realConfig.httpFilterConfigs()).isEqualTo( f1.getHttpConnectionManager().httpFilterConfigs()); xdsClient.rdsWatchers.get("r0").onError(Status.CANCELLED); - realConfig = selectorRef.get().getRoutingConfigs().get(f1); + realConfig = selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().get(f1); assertThat(realConfig.virtualHosts().get()).isEqualTo( Collections.singletonList(createVirtualHost("virtual-host-1"))); assertThat(realConfig.httpFilterConfigs()).isEqualTo( f1.getHttpConnectionManager().httpFilterConfigs()); xdsClient.rdsWatchers.get("r0").onResourceDoesNotExist("r0"); - realConfig = selectorRef.get().getRoutingConfigs().get(f1); + realConfig = selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().get(f1); assertThat(realConfig.virtualHosts().get()).isNull(); assertThat(realConfig.httpFilterConfigs()).isEqualTo( f1.getHttpConnectionManager().httpFilterConfigs()); @@ -615,7 +625,8 @@ public void run() { SslContextProviderSupplier sslSupplier0 = filterChain0.getSslContextProviderSupplier(); xdsClient.deliverLdsUpdate(Collections.singletonList(filterChain0), null); xdsClient.ldsWatcher.onError(Status.INTERNAL); - assertThat(selectorRef.get()).isSameInstanceAs(FilterChainSelector.NO_FILTER_CHAIN); + assertThat(selectorManager.getSelectorToUpdateSelector()) + .isSameInstanceAs(FilterChainSelector.NO_FILTER_CHAIN); assertThat(xdsClient.rdsWatchers).isEmpty(); verify(mockBuilder, times(1)).build(); verify(listener, times(2)).onNotServing(any(StatusException.class)); @@ -634,8 +645,10 @@ public void run() { verify(mockBuilder, times(1)).build(); verify(mockServer, times(2)).start(); verify(listener, times(1)).onServing(); - assertThat(selectorRef.get().getRoutingConfigs().size()).isEqualTo(1); - ServerRoutingConfig realConfig = selectorRef.get().getRoutingConfigs().get(filterChain1); + assertThat(selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().size()) + .isEqualTo(1); + ServerRoutingConfig realConfig = + selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().get(filterChain1); assertThat(realConfig.virtualHosts().get()).isEqualTo( Collections.singletonList(createVirtualHost("virtual-host-1"))); assertThat(realConfig.httpFilterConfigs()).isEqualTo( @@ -648,8 +661,10 @@ public void run() { verify(mockBuilder, times(1)).build(); verify(mockServer, times(2)).start(); verify(listener, times(2)).onNotServing(any(StatusException.class)); - assertThat(selectorRef.get().getRoutingConfigs().size()).isEqualTo(1); - realConfig = selectorRef.get().getRoutingConfigs().get(filterChain1); + assertThat(selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().size()) + .isEqualTo(1); + realConfig = selectorManager.getSelectorToUpdateSelector().getRoutingConfigs() + .get(filterChain1); assertThat(realConfig.virtualHosts().get()).isEqualTo( Collections.singletonList(createVirtualHost("virtual-host-2"))); assertThat(realConfig.httpFilterConfigs()).isEqualTo( @@ -661,7 +676,8 @@ public void run() { assertThat(xdsClient.rdsWatchers).isEmpty(); verify(mockServer, times(3)).shutdown(); when(mockServer.isShutdown()).thenReturn(true); - assertThat(selectorRef.get()).isSameInstanceAs(FilterChainSelector.NO_FILTER_CHAIN); + assertThat(selectorManager.getSelectorToUpdateSelector()) + .isSameInstanceAs(FilterChainSelector.NO_FILTER_CHAIN); verify(listener, times(3)).onNotServing(any(StatusException.class)); assertThat(sslSupplier1.isShutdown()).isTrue(); // no op @@ -686,8 +702,10 @@ public void run() { verify(mockServer, times(3)).start(); verify(listener, times(1)).onServing(); verify(listener, times(3)).onNotServing(any(StatusException.class)); - assertThat(selectorRef.get().getRoutingConfigs().size()).isEqualTo(1); - realConfig = selectorRef.get().getRoutingConfigs().get(filterChain2); + assertThat(selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().size()) + .isEqualTo(1); + realConfig = selectorManager.getSelectorToUpdateSelector().getRoutingConfigs() + .get(filterChain2); assertThat(realConfig.virtualHosts().get()).isEqualTo( Collections.singletonList(createVirtualHost("virtual-host-1"))); assertThat(realConfig.httpFilterConfigs()).isEqualTo( @@ -712,8 +730,10 @@ public void run() { when(mockServer.isShutdown()).thenReturn(false); verify(listener, times(4)).onNotServing(any(StatusException.class)); - assertThat(selectorRef.get().getRoutingConfigs().size()).isEqualTo(1); - realConfig = selectorRef.get().getRoutingConfigs().get(filterChain3); + assertThat(selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().size()) + .isEqualTo(1); + realConfig = selectorManager.getSelectorToUpdateSelector().getRoutingConfigs() + .get(filterChain3); assertThat(realConfig.virtualHosts().get()).isEqualTo( Collections.singletonList(createVirtualHost("virtual-host-1"))); assertThat(realConfig.httpFilterConfigs()).isEqualTo( From 0b8b33da77050c8ecb333f1e1bd05b3c375f1642 Mon Sep 17 00:00:00 2001 From: yifeizhuang Date: Mon, 20 Sep 2021 16:14:43 -0700 Subject: [PATCH 78/82] xds, rbac: build per route serverInterceptor for httpConfig (#8524) (#8539) --- ...ilterChainMatchingProtocolNegotiators.java | 24 +- .../java/io/grpc/xds/XdsServerWrapper.java | 209 ++++++---- ...rChainMatchingProtocolNegotiatorsTest.java | 80 ++-- .../xds/FilterChainSelectorManagerTest.java | 15 +- .../io/grpc/xds/XdsServerWrapperTest.java | 383 +++++++++++------- 5 files changed, 414 insertions(+), 297 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/FilterChainMatchingProtocolNegotiators.java b/xds/src/main/java/io/grpc/xds/FilterChainMatchingProtocolNegotiators.java index 24cd4e9ae7e..b828b862454 100644 --- a/xds/src/main/java/io/grpc/xds/FilterChainMatchingProtocolNegotiators.java +++ b/xds/src/main/java/io/grpc/xds/FilterChainMatchingProtocolNegotiators.java @@ -59,6 +59,7 @@ import java.util.Map; import java.util.concurrent.Executor; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; @@ -135,28 +136,29 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc static final class FilterChainSelector { public static final FilterChainSelector NO_FILTER_CHAIN = new FilterChainSelector( - Collections.emptyMap(), null, null); - private final Map routingConfigs; + Collections.>emptyMap(), + null, new AtomicReference()); + private final Map> routingConfigs; @Nullable private final SslContextProviderSupplier defaultSslContextProviderSupplier; @Nullable - private final ServerRoutingConfig defaultRoutingConfig; + private final AtomicReference defaultRoutingConfig; - FilterChainSelector(Map routingConfigs, + FilterChainSelector(Map> routingConfigs, @Nullable SslContextProviderSupplier defaultSslContextProviderSupplier, - @Nullable ServerRoutingConfig defaultRoutingConfig) { + @Nullable AtomicReference defaultRoutingConfig) { this.routingConfigs = checkNotNull(routingConfigs, "routingConfigs"); this.defaultSslContextProviderSupplier = defaultSslContextProviderSupplier; - this.defaultRoutingConfig = defaultRoutingConfig; + this.defaultRoutingConfig = checkNotNull(defaultRoutingConfig, "defaultRoutingConfig"); } @VisibleForTesting - Map getRoutingConfigs() { + Map> getRoutingConfigs() { return routingConfigs; } @VisibleForTesting - ServerRoutingConfig getDefaultRoutingConfig() { + AtomicReference getDefaultRoutingConfig() { return defaultRoutingConfig; } @@ -189,7 +191,7 @@ SelectedConfig select(InetSocketAddress localAddr, InetSocketAddress remoteAddr) return new SelectedConfig( routingConfigs.get(selected), selected.getSslContextProviderSupplier()); } - if (defaultRoutingConfig != null) { + if (defaultRoutingConfig.get() != null) { return new SelectedConfig(defaultRoutingConfig, defaultSslContextProviderSupplier); } return null; @@ -393,11 +395,11 @@ public void close() { * The FilterChain level configuration. */ private static final class SelectedConfig { - private final ServerRoutingConfig routingConfig; + private final AtomicReference routingConfig; @Nullable private final SslContextProviderSupplier sslContextProviderSupplier; - private SelectedConfig(ServerRoutingConfig routingConfig, + private SelectedConfig(AtomicReference routingConfig, @Nullable SslContextProviderSupplier sslContextProviderSupplier) { this.routingConfig = checkNotNull(routingConfig, "routingConfig"); this.sslContextProviderSupplier = sslContextProviderSupplier; diff --git a/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java b/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java index 29821f2cba8..e7301500e0e 100644 --- a/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java +++ b/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java @@ -22,6 +22,7 @@ import com.google.auto.value.AutoValue; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.util.concurrent.SettableFuture; import io.grpc.Attributes; import io.grpc.InternalServerInterceptors; @@ -87,8 +88,9 @@ public void uncaughtException(Thread t, Throwable e) { } }); - public static final Attributes.Key ATTR_SERVER_ROUTING_CONFIG = - Attributes.Key.create("io.grpc.xds.ServerWrapper.serverRoutingConfig"); + public static final Attributes.Key> + ATTR_SERVER_ROUTING_CONFIG = + Attributes.Key.create("io.grpc.xds.ServerWrapper.serverRoutingConfig"); @VisibleForTesting static final long RETRY_DELAY_NANOS = TimeUnit.MINUTES.toNanos(1); @@ -346,6 +348,15 @@ private final class DiscoveryState implements LdsResourceWatcher { @Nullable private FilterChain defaultFilterChain; private boolean stopped; + private final Map> savedRdsRoutingConfigRef + = new HashMap<>(); + private final ServerInterceptor noopInterceptor = new ServerInterceptor() { + @Override + public Listener interceptCall(ServerCall call, + Metadata headers, ServerCallHandler next) { + return next.startCall(call, headers); + } + }; private DiscoveryState(String resourceName) { this.resourceName = checkNotNull(resourceName, "resourceName"); @@ -452,14 +463,16 @@ private void shutdown() { } private void updateSelector() { - Map filterChainRouting = new HashMap<>(); + Map> filterChainRouting = new HashMap<>(); + savedRdsRoutingConfigRef.clear(); for (FilterChain filterChain: filterChains) { filterChainRouting.put(filterChain, generateRoutingConfig(filterChain)); } FilterChainSelector selector = new FilterChainSelector( Collections.unmodifiableMap(filterChainRouting), defaultFilterChain == null ? null : defaultFilterChain.getSslContextProviderSupplier(), - defaultFilterChain == null ? null : generateRoutingConfig(defaultFilterChain)); + defaultFilterChain == null ? new AtomicReference() : + generateRoutingConfig(defaultFilterChain)); List toRelease = getSuppliersInUse(); filterChainSelectorManager.updateSelector(selector); for (SslContextProviderSupplier e: toRelease) { @@ -468,18 +481,84 @@ private void updateSelector() { startDelegateServer(); } - private ServerRoutingConfig generateRoutingConfig(FilterChain filterChain) { + private AtomicReference generateRoutingConfig(FilterChain filterChain) { HttpConnectionManager hcm = filterChain.getHttpConnectionManager(); if (hcm.virtualHosts() != null) { - return ServerRoutingConfig.create(hcm.httpFilterConfigs(), - new AtomicReference<>(hcm.virtualHosts())); + ImmutableMap interceptors = generatePerRouteInterceptors( + hcm.httpFilterConfigs(), hcm.virtualHosts()); + return new AtomicReference<>(ServerRoutingConfig.create(hcm.virtualHosts(),interceptors)); } else { RouteDiscoveryState rds = routeDiscoveryStates.get(hcm.rdsName()); checkNotNull(rds, "rds"); - return ServerRoutingConfig.create(hcm.httpFilterConfigs(), rds.savedVirtualHosts); + AtomicReference serverRoutingConfigRef = new AtomicReference<>(); + if (rds.savedVirtualHosts != null) { + ImmutableMap interceptors = generatePerRouteInterceptors( + hcm.httpFilterConfigs(), rds.savedVirtualHosts); + ServerRoutingConfig serverRoutingConfig = + ServerRoutingConfig.create(rds.savedVirtualHosts, interceptors); + serverRoutingConfigRef.set(serverRoutingConfig); + } else { + serverRoutingConfigRef.set(ServerRoutingConfig.FAILING_ROUTING_CONFIG); + } + savedRdsRoutingConfigRef.put(filterChain, serverRoutingConfigRef); + return serverRoutingConfigRef; } } + private ImmutableMap generatePerRouteInterceptors( + List namedFilterConfigs, List virtualHosts) { + ImmutableMap.Builder perRouteInterceptors = + new ImmutableMap.Builder<>(); + for (VirtualHost virtualHost : virtualHosts) { + for (Route route : virtualHost.routes()) { + List filterInterceptors = new ArrayList<>(); + Map selectedOverrideConfigs = + new HashMap<>(virtualHost.filterConfigOverrides()); + selectedOverrideConfigs.putAll(route.filterConfigOverrides()); + for (NamedFilterConfig namedFilterConfig : namedFilterConfigs) { + FilterConfig filterConfig = namedFilterConfig.filterConfig; + Filter filter = filterRegistry.get(filterConfig.typeUrl()); + if (filter instanceof ServerInterceptorBuilder) { + ServerInterceptor interceptor = + ((ServerInterceptorBuilder) filter).buildServerInterceptor( + filterConfig, selectedOverrideConfigs.get(namedFilterConfig.name)); + if (interceptor != null) { + filterInterceptors.add(interceptor); + } + } else { + logger.log(Level.WARNING, "HttpFilterConfig(type URL: " + + filterConfig.typeUrl() + ") is not supported on server-side. " + + "Probably a bug at ClientXdsClient verification."); + } + } + ServerInterceptor interceptor = combineInterceptors(filterInterceptors); + perRouteInterceptors.put(route, interceptor); + } + } + return perRouteInterceptors.build(); + } + + private ServerInterceptor combineInterceptors(final List interceptors) { + if (interceptors.isEmpty()) { + return noopInterceptor; + } + if (interceptors.size() == 1) { + return interceptors.get(0); + } + return new ServerInterceptor() { + @Override + public Listener interceptCall(ServerCall call, + Metadata headers, ServerCallHandler next) { + // intercept forward + for (int i = interceptors.size() - 1; i >= 0; i--) { + next = InternalServerInterceptors.interceptCallHandlerCreate( + interceptors.get(i), next); + } + return next.startCall(call, headers); + } + }; + } + private void handleConfigNotFound(StatusException exception) { cleanUpRouteDiscoveryStates(); List toRelease = getSuppliersInUse(); @@ -508,6 +587,7 @@ private void cleanUpRouteDiscoveryStates() { xdsClient.cancelRdsResourceWatch(rdsName, rdsState); } routeDiscoveryStates.clear(); + savedRdsRoutingConfigRef.clear(); } private List getSuppliersInUse() { @@ -544,8 +624,7 @@ private void releaseSuppliersInFlight() { private final class RouteDiscoveryState implements RdsResourceWatcher { private final String resourceName; - private AtomicReference> savedVirtualHosts = - new AtomicReference<>(); + private ImmutableList savedVirtualHosts; private boolean isPending = true; private RouteDiscoveryState(String resourceName) { @@ -560,7 +639,8 @@ public void run() { if (!routeDiscoveryStates.containsKey(resourceName)) { return; } - savedVirtualHosts.set(ImmutableList.copyOf(update.virtualHosts)); + savedVirtualHosts = ImmutableList.copyOf(update.virtualHosts); + updateRdsRoutingConfig(); maybeUpdateSelector(); } }); @@ -575,7 +655,8 @@ public void run() { return; } logger.log(Level.WARNING, "Rds {0} unavailable", resourceName); - savedVirtualHosts.set(null); + savedVirtualHosts = null; + updateRdsRoutingConfig(); maybeUpdateSelector(); } }); @@ -596,6 +677,25 @@ public void run() { }); } + private void updateRdsRoutingConfig() { + for (FilterChain filterChain : savedRdsRoutingConfigRef.keySet()) { + if (resourceName.equals(filterChain.getHttpConnectionManager().rdsName())) { + ServerRoutingConfig updatedRoutingConfig; + if (savedVirtualHosts == null) { + updatedRoutingConfig = ServerRoutingConfig.FAILING_ROUTING_CONFIG; + } else { + ImmutableMap updatedInterceptors = + generatePerRouteInterceptors( + filterChain.getHttpConnectionManager().httpFilterConfigs(), + savedVirtualHosts); + updatedRoutingConfig = ServerRoutingConfig.create(savedVirtualHosts, + updatedInterceptors); + } + savedRdsRoutingConfigRef.get(filterChain).set(updatedRoutingConfig); + } + } + } + // Update the selector to use the most recently updated configs only after all rds have been // discovered for the first time. Later changes on rds will be applied through virtual host // list atomic ref. @@ -632,18 +732,16 @@ public Listener interceptCall(ServerCall call, @Override public Listener interceptCall(ServerCall call, Metadata headers, ServerCallHandler next) { - ServerRoutingConfig routingConfig = call.getAttributes().get(ATTR_SERVER_ROUTING_CONFIG); - if (routingConfig == null) { - String errorMsg = "Missing xDS routing config."; - call.close(Status.UNAVAILABLE.withDescription(errorMsg), new Metadata()); - return new Listener() {}; - } - List virtualHosts = routingConfig.virtualHosts().get(); - if (virtualHosts == null) { - String errorMsg = "Missing xDS routing config VirtualHosts due to RDS config unavailable."; + AtomicReference routingConfigRef = + call.getAttributes().get(ATTR_SERVER_ROUTING_CONFIG); + ServerRoutingConfig routingConfig = routingConfigRef == null ? null : + routingConfigRef.get(); + if (routingConfig == null || routingConfig == ServerRoutingConfig.FAILING_ROUTING_CONFIG) { + String errorMsg = "Missing or broken xDS routing config: RDS config unavailable."; call.close(Status.UNAVAILABLE.withDescription(errorMsg), new Metadata()); return new Listener() {}; } + List virtualHosts = routingConfig.virtualHosts(); VirtualHost virtualHost = RoutingUtils.findVirtualHostForHostName( virtualHosts, call.getAuthority()); if (virtualHost == null) { @@ -653,14 +751,11 @@ public Listener interceptCall(ServerCall call, return new Listener() {}; } Route selectedRoute = null; - Map selectedOverrideConfigs = - new HashMap<>(virtualHost.filterConfigOverrides()); MethodDescriptor method = call.getMethodDescriptor(); for (Route route : virtualHost.routes()) { if (RoutingUtils.matchRoute( route.routeMatch(), "/" + method.getFullMethodName(), headers, random)) { selectedRoute = route; - selectedOverrideConfigs.putAll(route.filterConfigOverrides()); break; } } @@ -670,48 +765,12 @@ public Listener interceptCall(ServerCall call, new Metadata()); return new ServerCall.Listener() {}; } - List filterInterceptors = new ArrayList<>(); - for (NamedFilterConfig namedFilterConfig : routingConfig.httpFilterConfigs()) { - FilterConfig filterConfig = namedFilterConfig.filterConfig; - Filter filter = filterRegistry.get(filterConfig.typeUrl()); - if (filter instanceof ServerInterceptorBuilder) { - ServerInterceptor interceptor = - ((ServerInterceptorBuilder) filter).buildServerInterceptor( - filterConfig, selectedOverrideConfigs.get(namedFilterConfig.name)); - if (interceptor != null) { - filterInterceptors.add(interceptor); - } - } else { - call.close( - Status.UNAVAILABLE.withDescription("HttpFilterConfig(type URL: " - + filterConfig.typeUrl() + ") is not supported on server-side."), - new Metadata()); - return new Listener() {}; - } + ServerInterceptor routeInterceptor = noopInterceptor; + Map perRouteInterceptors = routingConfig.interceptors(); + if (perRouteInterceptors != null && perRouteInterceptors.get(selectedRoute) != null) { + routeInterceptor = perRouteInterceptors.get(selectedRoute); } - ServerInterceptor interceptor = combineInterceptors(filterInterceptors); - return interceptor.interceptCall(call, headers, next); - } - - private ServerInterceptor combineInterceptors(final List interceptors) { - if (interceptors.isEmpty()) { - return noopInterceptor; - } - if (interceptors.size() == 1) { - return interceptors.get(0); - } - return new ServerInterceptor() { - @Override - public Listener interceptCall(ServerCall call, - Metadata headers, ServerCallHandler next) { - // intercept forward - for (int i = interceptors.size() - 1; i >= 0; i--) { - next = InternalServerInterceptors.interceptCallHandlerCreate( - interceptors.get(i), next); - } - return next.startCall(call, headers); - } - }; + return routeInterceptor.interceptCall(call, headers, next); } } @@ -720,20 +779,24 @@ public Listener interceptCall(ServerCall call, */ @AutoValue abstract static class ServerRoutingConfig { - // Top level http filter configs. - abstract ImmutableList httpFilterConfigs(); + @VisibleForTesting + static final ServerRoutingConfig FAILING_ROUTING_CONFIG = ServerRoutingConfig.create( + ImmutableList.of(), ImmutableMap.of()); + + abstract ImmutableList virtualHosts(); - abstract AtomicReference> virtualHosts(); + // Prebuilt per route server interceptors from http filter configs. + abstract ImmutableMap interceptors(); /** * Server routing configuration. * */ - public static ServerRoutingConfig create(List httpFilterConfigs, - AtomicReference> virtualHosts) { - checkNotNull(httpFilterConfigs, "httpFilterConfigs"); + public static ServerRoutingConfig create( + ImmutableList virtualHosts, + ImmutableMap interceptors) { checkNotNull(virtualHosts, "virtualHosts"); - return new AutoValue_XdsServerWrapper_ServerRoutingConfig( - ImmutableList.copyOf(httpFilterConfigs), virtualHosts); + checkNotNull(interceptors, "interceptors"); + return new AutoValue_XdsServerWrapper_ServerRoutingConfig(virtualHosts, interceptors); } } } diff --git a/xds/src/test/java/io/grpc/xds/FilterChainMatchingProtocolNegotiatorsTest.java b/xds/src/test/java/io/grpc/xds/FilterChainMatchingProtocolNegotiatorsTest.java index 891dec322c0..b223516465f 100644 --- a/xds/src/test/java/io/grpc/xds/FilterChainMatchingProtocolNegotiatorsTest.java +++ b/xds/src/test/java/io/grpc/xds/FilterChainMatchingProtocolNegotiatorsTest.java @@ -26,6 +26,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.util.concurrent.SettableFuture; +import io.grpc.ServerInterceptor; import io.grpc.internal.TestUtils.NoopChannelLogger; import io.grpc.netty.GrpcHttp2ConnectionHandler; import io.grpc.netty.InternalProtocolNegotiationEvent; @@ -93,8 +94,12 @@ public class FilterChainMatchingProtocolNegotiatorsTest { private static final String LOCAL_IP = "10.1.2.3"; // dest private static final String REMOTE_IP = "10.4.2.3"; // source private static final int PORT = 7000; - private final ServerRoutingConfig noopConfig = ServerRoutingConfig.create( - new ArrayList(), new AtomicReference>()); + private final AtomicReference noopConfig = new AtomicReference<>( + ServerRoutingConfig.create(ImmutableList.of(), + ImmutableMap.of())); + final SettableFuture sslSet = SettableFuture.create(); + final SettableFuture> routingSettable = + SettableFuture.create(); @After @SuppressWarnings("FutureReturnValueIgnored") @@ -108,15 +113,14 @@ public void tearDown() { @Test public void nofilterChainMatch_defaultSslContext() throws Exception { - final SettableFuture sslSet = SettableFuture.create(); - final SettableFuture routingSettable = SettableFuture.create(); ChannelHandler next = captureAttrHandler(sslSet, routingSettable); when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); SslContextProviderSupplier defaultSsl = new SslContextProviderSupplier(createTls(), tlsContextManager); selectorManager.updateSelector(new FilterChainSelector( - new HashMap(), defaultSsl, noopConfig)); + new HashMap>(), + defaultSsl, noopConfig)); FilterChainMatchingHandler filterChainMatchingHandler = new FilterChainMatchingHandler(grpcHandler, selectorManager, mockDelegate); setupChannel("172.168.1.1", "172.168.1.2", 80, filterChainMatchingHandler); @@ -138,7 +142,8 @@ public void nofilterChainMatch_defaultSslContext() throws Exception { @Test public void noFilterChainMatch_noDefaultSslContext() { selectorManager.updateSelector(new FilterChainSelector( - new HashMap(), null, null)); + new HashMap>(), + null, new AtomicReference())); FilterChainMatchingHandler filterChainMatchingHandler = new FilterChainMatchingHandler(grpcHandler, selectorManager, mockDelegate); setupChannel("172.168.1.1", "172.168.2.2", 90, filterChainMatchingHandler); @@ -156,7 +161,7 @@ public void filterSelectorChange_drainsConnection() { ChannelHandler next = new ChannelInboundHandlerAdapter(); when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); selectorManager.updateSelector(new FilterChainSelector( - new HashMap(), null, noopConfig)); + new HashMap>(), null, noopConfig)); FilterChainMatchingHandler filterChainMatchingHandler = new FilterChainMatchingHandler(grpcHandler, selectorManager, mockDelegate); setupChannel("172.168.1.1", "172.168.2.2", 90, filterChainMatchingHandler); @@ -173,7 +178,7 @@ public void filterSelectorChange_drainsConnection() { assertThat(channel.readOutbound()).isNull(); selectorManager.updateSelector(new FilterChainSelector( - new HashMap(), null, noopConfig)); + new HashMap>(), null, noopConfig)); assertThat(channel.readOutbound().getClass().getName()) .isEqualTo("io.grpc.netty.GracefulServerCloseCommand"); } @@ -197,11 +202,9 @@ public void singleFilterChainWithoutAlpn() throws Exception { tlsContextManager); selectorManager.updateSelector(new FilterChainSelector(ImmutableMap.of(filterChain, noopConfig), - null, null)); + null, new AtomicReference())); FilterChainMatchingHandler filterChainMatchingHandler = new FilterChainMatchingHandler(grpcHandler, selectorManager, mockDelegate); - final SettableFuture sslSet = SettableFuture.create(); - final SettableFuture routingSettable = SettableFuture.create(); ChannelHandler next = captureAttrHandler(sslSet, routingSettable); when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); setupChannel(LOCAL_IP, REMOTE_IP, 15000, filterChainMatchingHandler); @@ -241,8 +244,6 @@ public void singleFilterChainWithAlpn() throws Exception { FilterChainMatchingHandler filterChainMatchingHandler = new FilterChainMatchingHandler(grpcHandler, selectorManager, mockDelegate); - final SettableFuture sslSet = SettableFuture.create(); - final SettableFuture routingSettable = SettableFuture.create(); ChannelHandler next = captureAttrHandler(sslSet, routingSettable); when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); setupChannel(LOCAL_IP, REMOTE_IP, 15000, filterChainMatchingHandler); @@ -279,17 +280,16 @@ public void destPortFails_returnDefaultFilterChain() throws Exception { tlsContextForDefaultFilterChain, tlsContextManager); ServerRoutingConfig routingConfig = ServerRoutingConfig.create( - new ArrayList(), new AtomicReference<>( - ImmutableList.of(createVirtualHost("virtual")))); + ImmutableList.of(createVirtualHost("virtual")), + ImmutableMap.of()); selectorManager.updateSelector(new FilterChainSelector( - ImmutableMap.of(filterChainWithDestPort, routingConfig), + ImmutableMap.of(filterChainWithDestPort, + new AtomicReference(routingConfig)), defaultFilterChain.getSslContextProviderSupplier(), noopConfig)); FilterChainMatchingHandler filterChainMatchingHandler = new FilterChainMatchingHandler(grpcHandler, selectorManager, mockDelegate); - final SettableFuture sslSet = SettableFuture.create(); - final SettableFuture routingSettable = SettableFuture.create(); ChannelHandler next = captureAttrHandler(sslSet, routingSettable); when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); setupChannel(LOCAL_IP, REMOTE_IP, 15000, filterChainMatchingHandler); @@ -331,8 +331,6 @@ public void destPrefixRangeMatch() throws Exception { FilterChainMatchingHandler filterChainMatchingHandler = new FilterChainMatchingHandler(grpcHandler, selectorManager, mockDelegate); - final SettableFuture sslSet = SettableFuture.create(); - final SettableFuture routingSettable = SettableFuture.create(); ChannelHandler next = captureAttrHandler(sslSet, routingSettable); when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); setupChannel(LOCAL_IP, REMOTE_IP, 15000, filterChainMatchingHandler); @@ -375,8 +373,6 @@ public void destPrefixRangeMismatch_returnDefaultFilterChain() FilterChainMatchingHandler filterChainMatchingHandler = new FilterChainMatchingHandler(grpcHandler, selectorManager, mockDelegate); - final SettableFuture sslSet = SettableFuture.create(); - final SettableFuture routingSettable = SettableFuture.create(); ChannelHandler next = captureAttrHandler(sslSet, routingSettable); when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); setupChannel(LOCAL_IP, REMOTE_IP, 15000, filterChainMatchingHandler); @@ -415,12 +411,11 @@ public void dest0LengthPrefixRange() selectorManager.updateSelector(new FilterChainSelector( ImmutableMap.of(filterChain0Length, noopConfig), - defaultFilterChain.getSslContextProviderSupplier(), null)); + defaultFilterChain.getSslContextProviderSupplier(), + new AtomicReference())); FilterChainMatchingHandler filterChainMatchingHandler = new FilterChainMatchingHandler(grpcHandler, selectorManager, mockDelegate); - final SettableFuture sslSet = SettableFuture.create(); - final SettableFuture routingSettable = SettableFuture.create(); ChannelHandler next = captureAttrHandler(sslSet, routingSettable); when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); setupChannel(LOCAL_IP, REMOTE_IP, 15000, filterChainMatchingHandler); @@ -478,8 +473,6 @@ public void destPrefixRange_moreSpecificWins() FilterChainMatchingHandler filterChainMatchingHandler = new FilterChainMatchingHandler(grpcHandler, selectorManager, mockDelegate); - final SettableFuture sslSet = SettableFuture.create(); - final SettableFuture routingSettable = SettableFuture.create(); ChannelHandler next = captureAttrHandler(sslSet, routingSettable); when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); setupChannel(LOCAL_IP, REMOTE_IP, 15000, filterChainMatchingHandler); @@ -536,8 +529,6 @@ public void destPrefixRange_emptyListLessSpecific() FilterChainMatchingHandler filterChainMatchingHandler = new FilterChainMatchingHandler(grpcHandler, selectorManager, mockDelegate); - final SettableFuture sslSet = SettableFuture.create(); - final SettableFuture routingSettable = SettableFuture.create(); ChannelHandler next = captureAttrHandler(sslSet, routingSettable); when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); setupChannel(LOCAL_IP, REMOTE_IP, 15000, filterChainMatchingHandler); @@ -594,8 +585,6 @@ public void destPrefixRangeIpv6_moreSpecificWins() FilterChainMatchingHandler filterChainMatchingHandler = new FilterChainMatchingHandler(grpcHandler, selectorManager, mockDelegate); - final SettableFuture sslSet = SettableFuture.create(); - final SettableFuture routingSettable = SettableFuture.create(); ChannelHandler next = captureAttrHandler(sslSet, routingSettable); when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); @@ -656,8 +645,6 @@ filterChainLessSpecific, randomConfig("no-match")), FilterChainMatchingHandler filterChainMatchingHandler = new FilterChainMatchingHandler(grpcHandler, selectorManager, mockDelegate); - final SettableFuture sslSet = SettableFuture.create(); - final SettableFuture routingSettable = SettableFuture.create(); ChannelHandler next = captureAttrHandler(sslSet, routingSettable); when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); setupChannel(LOCAL_IP, REMOTE_IP, 15000, filterChainMatchingHandler); @@ -698,8 +685,7 @@ public void sourceTypeMismatch_returnDefaultFilterChain() throws Exception { FilterChainMatchingHandler filterChainMatchingHandler = new FilterChainMatchingHandler(grpcHandler, selectorManager, mockDelegate); - final SettableFuture sslSet = SettableFuture.create(); - final SettableFuture routingSettable = SettableFuture.create(); + ChannelHandler next = captureAttrHandler(sslSet, routingSettable); when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); setupChannel(LOCAL_IP, REMOTE_IP, 15000, filterChainMatchingHandler); @@ -712,8 +698,6 @@ public void sourceTypeMismatch_returnDefaultFilterChain() throws Exception { @Test public void sourceTypeLocal() throws Exception { - final SettableFuture sslSet = SettableFuture.create(); - final SettableFuture routingSettable = SettableFuture.create(); ChannelHandler next = captureAttrHandler(sslSet, routingSettable); when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); EnvoyServerProtoData.DownstreamTlsContext tlsContextMatch = @@ -753,8 +737,6 @@ public void sourceTypeLocal() throws Exception { @Test public void sourcePrefixRange_moreSpecificWith2Wins() throws Exception { - final SettableFuture sslSet = SettableFuture.create(); - final SettableFuture routingSettable = SettableFuture.create(); ChannelHandler next = captureAttrHandler(sslSet, routingSettable); when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); @@ -815,7 +797,6 @@ filterChainLessSpecific, randomConfig("no-match")), @Test public void sourcePrefixRange_2Matchers_expectException() throws UnknownHostException { - final SettableFuture sslSet = SettableFuture.create(); ChannelHandler next = new ChannelInboundHandlerAdapter() { @Override public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { @@ -930,8 +911,6 @@ public void sourcePortMatch_exactMatchWinsOverEmptyList() throws Exception { FilterChainMatchingHandler filterChainMatchingHandler = new FilterChainMatchingHandler(grpcHandler, selectorManager, mockDelegate); - final SettableFuture sslSet = SettableFuture.create(); - final SettableFuture routingSettable = SettableFuture.create(); ChannelHandler next = captureAttrHandler(sslSet, routingSettable); when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); setupChannel(LOCAL_IP, REMOTE_IP, 15000, filterChainMatchingHandler); @@ -1072,7 +1051,7 @@ public void filterChain_5stepMatch() throws Exception { EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( "filter-chain-7", null, HTTP_CONNECTION_MANAGER, null, tlsContextManager); - Map map = new HashMap<>(); + Map> map = new HashMap<>(); map.put(filterChain1, randomConfig("1")); map.put(filterChain2, randomConfig("2")); map.put(filterChain3, randomConfig("3")); @@ -1085,8 +1064,6 @@ public void filterChain_5stepMatch() throws Exception { FilterChainMatchingHandler filterChainMatchingHandler = new FilterChainMatchingHandler(grpcHandler, selectorManager, mockDelegate); - final SettableFuture sslSet = SettableFuture.create(); - final SettableFuture routingSettable = SettableFuture.create(); ChannelHandler next = captureAttrHandler(sslSet, routingSettable); when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); setupChannel(LOCAL_IP, REMOTE_IP, 15000, filterChainMatchingHandler); @@ -1159,8 +1136,6 @@ public void filterChainMatch_unsupportedMatchers() throws Exception { FilterChainMatchingHandler filterChainMatchingHandler = new FilterChainMatchingHandler(grpcHandler, selectorManager, mockDelegate); - final SettableFuture sslSet = SettableFuture.create(); - final SettableFuture routingSettable = SettableFuture.create(); ChannelHandler next = captureAttrHandler(sslSet, routingSettable); when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); setupChannel(LOCAL_IP, REMOTE_IP, 15000, filterChainMatchingHandler); @@ -1184,10 +1159,11 @@ private static VirtualHost createVirtualHost(String name) { ImmutableMap.of()); } - private static ServerRoutingConfig randomConfig(String domain) { - return ServerRoutingConfig.create( - new ArrayList(), new AtomicReference<>( - ImmutableList.of(createVirtualHost(domain)))); + private static AtomicReference randomConfig(String domain) { + return new AtomicReference<>( + ServerRoutingConfig.create(ImmutableList.of(createVirtualHost(domain)), + ImmutableMap.of()) + ); } private EnvoyServerProtoData.DownstreamTlsContext createTls() { @@ -1216,7 +1192,7 @@ public SocketAddress remoteAddress() { private static ChannelHandler captureAttrHandler( final SettableFuture sslSet, - final SettableFuture routingSettable) { + final SettableFuture> routingSettable) { return new ChannelInboundHandlerAdapter() { @Override public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { diff --git a/xds/src/test/java/io/grpc/xds/FilterChainSelectorManagerTest.java b/xds/src/test/java/io/grpc/xds/FilterChainSelectorManagerTest.java index d7b883f1941..a3a2218d4c3 100644 --- a/xds/src/test/java/io/grpc/xds/FilterChainSelectorManagerTest.java +++ b/xds/src/test/java/io/grpc/xds/FilterChainSelectorManagerTest.java @@ -19,8 +19,9 @@ import static com.google.common.truth.Truth.assertThat; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.grpc.ServerInterceptor; import io.grpc.xds.EnvoyServerProtoData.FilterChain; -import io.grpc.xds.Filter.NamedFilterConfig; import io.grpc.xds.FilterChainMatchingProtocolNegotiators.FilterChainMatchingHandler.FilterChainSelector; import io.grpc.xds.FilterChainSelectorManager.Closer; import io.grpc.xds.XdsServerWrapper.ServerRoutingConfig; @@ -33,13 +34,15 @@ @RunWith(JUnit4.class) public final class FilterChainSelectorManagerTest { private FilterChainSelectorManager manager = new FilterChainSelectorManager(); - private ServerRoutingConfig noopConfig = ServerRoutingConfig.create( - Collections.emptyList(), - new AtomicReference>()); + private AtomicReference noopConfig = new AtomicReference<>( + ServerRoutingConfig.create(ImmutableList.of(), + ImmutableMap.of())); private FilterChainSelector selector1 = new FilterChainSelector( - Collections.emptyMap(), null, null); + Collections.>emptyMap(), + null, new AtomicReference()); private FilterChainSelector selector2 = new FilterChainSelector( - Collections.emptyMap(), null, noopConfig); + Collections.>emptyMap(), + null, noopConfig); private CounterRunnable runnable1 = new CounterRunnable(); private CounterRunnable runnable2 = new CounterRunnable(); diff --git a/xds/src/test/java/io/grpc/xds/XdsServerWrapperTest.java b/xds/src/test/java/io/grpc/xds/XdsServerWrapperTest.java index d4421361158..c109bb44a13 100644 --- a/xds/src/test/java/io/grpc/xds/XdsServerWrapperTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsServerWrapperTest.java @@ -104,6 +104,8 @@ public class XdsServerWrapperTest { private FakeXdsClient xdsClient = new FakeXdsClient(); private FilterRegistry filterRegistry = FilterRegistry.getDefaultRegistry(); private XdsServerWrapper xdsServerWrapper; + private ServerRoutingConfig noopConfig = ServerRoutingConfig.create( + ImmutableList.of(), ImmutableMap.of()); @Before public void setup() { @@ -380,9 +382,9 @@ public void run() { assertThat(selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().size()) .isEqualTo(1); ServerRoutingConfig realConfig = - selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().get(filterChain); - assertThat(realConfig.virtualHosts().get()).isEqualTo(httpConnectionManager.virtualHosts()); - assertThat(realConfig.httpFilterConfigs()).isEqualTo(httpConnectionManager.httpFilterConfigs()); + selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().get(filterChain).get(); + assertThat(realConfig.virtualHosts()).isEqualTo(httpConnectionManager.virtualHosts()); + assertThat(realConfig.interceptors()).isEqualTo(ImmutableMap.of()); verify(listener).onServing(); verify(mockServer).start(); } @@ -429,26 +431,21 @@ public void run() { start.get(5000, TimeUnit.MILLISECONDS); verify(mockServer).start(); ServerRoutingConfig realConfig = - selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().get(f0); - assertThat(realConfig.virtualHosts().get()).isEqualTo( + selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().get(f0).get(); + assertThat(realConfig.virtualHosts()).isEqualTo( Collections.singletonList(createVirtualHost("virtual-host-0"))); - assertThat(realConfig.httpFilterConfigs()).isEqualTo( - f0.getHttpConnectionManager().httpFilterConfigs()); + assertThat(realConfig.interceptors()).isEqualTo(ImmutableMap.of()); assertThat(selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().size()) .isEqualTo(2); - realConfig = selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().get(f2); - assertThat(realConfig.virtualHosts().get()).isEqualTo( + realConfig = selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().get(f2).get(); + assertThat(realConfig.virtualHosts()).isEqualTo( Collections.singletonList(createVirtualHost("virtual-host-1"))); - assertThat(realConfig.httpFilterConfigs()).isEqualTo( - f2.getHttpConnectionManager().httpFilterConfigs()); - realConfig = selectorManager.getSelectorToUpdateSelector().getDefaultRoutingConfig(); - assertThat(realConfig.virtualHosts().get()).isEqualTo( + assertThat(realConfig.interceptors()).isEqualTo(ImmutableMap.of()); + realConfig = selectorManager.getSelectorToUpdateSelector().getDefaultRoutingConfig().get(); + assertThat(realConfig.virtualHosts()).isEqualTo( Collections.singletonList(createVirtualHost("virtual-host-2"))); - assertThat(realConfig.httpFilterConfigs()).isEqualTo( - f3.getHttpConnectionManager().httpFilterConfigs()); assertThat(selectorManager.getSelectorToUpdateSelector().getDefaultSslContextProviderSupplier()) - .isEqualTo( - f3.getSslContextProviderSupplier()); + .isEqualTo(f3.getSslContextProviderSupplier()); } @Test @@ -481,22 +478,20 @@ public void run() { start.get(5000, TimeUnit.MILLISECONDS); verify(mockServer, times(1)).start(); ServerRoutingConfig realConfig = - selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().get(f0); - assertThat(realConfig.virtualHosts().get()).isEqualTo( + selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().get(f0).get(); + assertThat(realConfig.virtualHosts()).isEqualTo( Collections.singletonList(createVirtualHost("virtual-host-0"))); - assertThat(realConfig.httpFilterConfigs()).isEqualTo( - f0.getHttpConnectionManager().httpFilterConfigs()); - realConfig = selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().get(f1); - assertThat(realConfig.virtualHosts().get()).isEqualTo( + assertThat(realConfig.interceptors()).isEqualTo(ImmutableMap.of()); + + realConfig = selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().get(f1).get(); + assertThat(realConfig.virtualHosts()).isEqualTo( Collections.singletonList(createVirtualHost("virtual-host-0"))); - assertThat(realConfig.httpFilterConfigs()).isEqualTo( - f1.getHttpConnectionManager().httpFilterConfigs()); + assertThat(realConfig.interceptors()).isEqualTo(ImmutableMap.of()); - realConfig = selectorManager.getSelectorToUpdateSelector().getDefaultRoutingConfig(); - assertThat(realConfig.virtualHosts().get()).isEqualTo( + realConfig = selectorManager.getSelectorToUpdateSelector().getDefaultRoutingConfig().get(); + assertThat(realConfig.virtualHosts()).isEqualTo( Collections.singletonList(createVirtualHost("virtual-host-0"))); - assertThat(realConfig.httpFilterConfigs()).isEqualTo( - f2.getHttpConnectionManager().httpFilterConfigs()); + assertThat(realConfig.interceptors()).isEqualTo(ImmutableMap.of()); assertThat(selectorManager.getSelectorToUpdateSelector().getDefaultSslContextProviderSupplier()) .isSameInstanceAs(f2.getSslContextProviderSupplier()); @@ -513,25 +508,22 @@ public void run() { assertThat(selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().size()) .isEqualTo(2); - realConfig = selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().get(f5); - assertThat(realConfig.virtualHosts().get()).isEqualTo( + realConfig = selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().get(f5).get(); + assertThat(realConfig.virtualHosts()).isEqualTo( Collections.singletonList(createVirtualHost("virtual-host-1"))); - assertThat(realConfig.httpFilterConfigs()).isEqualTo( - f5.getHttpConnectionManager().httpFilterConfigs()); - realConfig = selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().get(f3); - assertThat(realConfig.virtualHosts().get()).isEqualTo( + assertThat(realConfig.interceptors()).isEqualTo(ImmutableMap.of()); + realConfig = selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().get(f3).get(); + assertThat(realConfig.virtualHosts()).isEqualTo( Collections.singletonList(createVirtualHost("virtual-host-0"))); - assertThat(realConfig.httpFilterConfigs()).isEqualTo( - f3.getHttpConnectionManager().httpFilterConfigs()); + assertThat(realConfig.interceptors()).isEqualTo(ImmutableMap.of()); - realConfig = selectorManager.getSelectorToUpdateSelector().getDefaultRoutingConfig(); - assertThat(realConfig.virtualHosts().get()).isEqualTo( + realConfig = selectorManager.getSelectorToUpdateSelector().getDefaultRoutingConfig().get(); + assertThat(realConfig.virtualHosts()).isEqualTo( Collections.singletonList(createVirtualHost("virtual-host-1"))); - assertThat(realConfig.httpFilterConfigs()).isEqualTo( - f4.getHttpConnectionManager().httpFilterConfigs()); + assertThat(realConfig.interceptors()).isEqualTo(ImmutableMap.of()); + assertThat(selectorManager.getSelectorToUpdateSelector().getDefaultSslContextProviderSupplier()) - .isSameInstanceAs( - f4.getSslContextProviderSupplier()); + .isSameInstanceAs(f4.getSslContextProviderSupplier()); verify(mockServer, times(1)).start(); xdsServerWrapper.shutdown(); verify(mockServer, times(1)).shutdown(); @@ -567,35 +559,31 @@ public void run() { assertThat(selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().size()) .isEqualTo(2); ServerRoutingConfig realConfig = - selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().get(f1); - assertThat(realConfig.virtualHosts().get()).isNull(); - assertThat(realConfig.httpFilterConfigs()).isEqualTo( - f1.getHttpConnectionManager().httpFilterConfigs()); - realConfig = selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().get(f0); - assertThat(realConfig.virtualHosts().get()).isEqualTo(hcmVirtual.virtualHosts()); - assertThat(realConfig.httpFilterConfigs()).isEqualTo( - f0.getHttpConnectionManager().httpFilterConfigs()); + selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().get(f1).get(); + assertThat(realConfig.virtualHosts()).isEmpty(); + assertThat(realConfig.interceptors()).isEmpty(); + + realConfig = selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().get(f0).get(); + assertThat(realConfig.virtualHosts()).isEqualTo(hcmVirtual.virtualHosts()); + assertThat(realConfig.interceptors()).isEqualTo(ImmutableMap.of()); xdsClient.deliverRdsUpdate("r0", Collections.singletonList(createVirtualHost("virtual-host-1"))); - realConfig = selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().get(f1); - assertThat(realConfig.virtualHosts().get()).isEqualTo( + realConfig = selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().get(f1).get(); + assertThat(realConfig.virtualHosts()).isEqualTo( Collections.singletonList(createVirtualHost("virtual-host-1"))); - assertThat(realConfig.httpFilterConfigs()).isEqualTo( - f1.getHttpConnectionManager().httpFilterConfigs()); + assertThat(realConfig.interceptors()).isEqualTo(ImmutableMap.of()); xdsClient.rdsWatchers.get("r0").onError(Status.CANCELLED); - realConfig = selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().get(f1); - assertThat(realConfig.virtualHosts().get()).isEqualTo( + realConfig = selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().get(f1).get(); + assertThat(realConfig.virtualHosts()).isEqualTo( Collections.singletonList(createVirtualHost("virtual-host-1"))); - assertThat(realConfig.httpFilterConfigs()).isEqualTo( - f1.getHttpConnectionManager().httpFilterConfigs()); + assertThat(realConfig.interceptors()).isEqualTo(ImmutableMap.of()); xdsClient.rdsWatchers.get("r0").onResourceDoesNotExist("r0"); - realConfig = selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().get(f1); - assertThat(realConfig.virtualHosts().get()).isNull(); - assertThat(realConfig.httpFilterConfigs()).isEqualTo( - f1.getHttpConnectionManager().httpFilterConfigs()); + realConfig = selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().get(f1).get(); + assertThat(realConfig.virtualHosts()).isEmpty(); + assertThat(realConfig.interceptors()).isEmpty(); } @Test @@ -648,11 +636,11 @@ public void run() { assertThat(selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().size()) .isEqualTo(1); ServerRoutingConfig realConfig = - selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().get(filterChain1); - assertThat(realConfig.virtualHosts().get()).isEqualTo( + selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().get(filterChain1).get(); + assertThat(realConfig.virtualHosts()).isEqualTo( Collections.singletonList(createVirtualHost("virtual-host-1"))); - assertThat(realConfig.httpFilterConfigs()).isEqualTo( - filterChain1.getHttpConnectionManager().httpFilterConfigs()); + assertThat(realConfig.interceptors()).isEqualTo(ImmutableMap.of()); + // xds update after start xdsClient.deliverRdsUpdate("rds", Collections.singletonList(createVirtualHost("virtual-host-2"))); @@ -664,11 +652,11 @@ public void run() { assertThat(selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().size()) .isEqualTo(1); realConfig = selectorManager.getSelectorToUpdateSelector().getRoutingConfigs() - .get(filterChain1); - assertThat(realConfig.virtualHosts().get()).isEqualTo( + .get(filterChain1).get(); + assertThat(realConfig.virtualHosts()).isEqualTo( Collections.singletonList(createVirtualHost("virtual-host-2"))); - assertThat(realConfig.httpFilterConfigs()).isEqualTo( - filterChain1.getHttpConnectionManager().httpFilterConfigs()); + assertThat(realConfig.interceptors()).isEqualTo(ImmutableMap.of()); + assertThat(sslSupplier1.isShutdown()).isFalse(); // not serving after serving @@ -705,11 +693,11 @@ public void run() { assertThat(selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().size()) .isEqualTo(1); realConfig = selectorManager.getSelectorToUpdateSelector().getRoutingConfigs() - .get(filterChain2); - assertThat(realConfig.virtualHosts().get()).isEqualTo( + .get(filterChain2).get(); + assertThat(realConfig.virtualHosts()).isEqualTo( Collections.singletonList(createVirtualHost("virtual-host-1"))); - assertThat(realConfig.httpFilterConfigs()).isEqualTo( - filterChain2.getHttpConnectionManager().httpFilterConfigs()); + assertThat(realConfig.interceptors()).isEqualTo(ImmutableMap.of()); + assertThat(executor.numPendingTasks()).isEqualTo(1); xdsClient.ldsWatcher.onResourceDoesNotExist(ldsResource); verify(mockServer, times(4)).shutdown(); @@ -733,11 +721,11 @@ public void run() { assertThat(selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().size()) .isEqualTo(1); realConfig = selectorManager.getSelectorToUpdateSelector().getRoutingConfigs() - .get(filterChain3); - assertThat(realConfig.virtualHosts().get()).isEqualTo( + .get(filterChain3).get(); + assertThat(realConfig.virtualHosts()).isEqualTo( Collections.singletonList(createVirtualHost("virtual-host-1"))); - assertThat(realConfig.httpFilterConfigs()).isEqualTo( - filterChain3.getHttpConnectionManager().httpFilterConfigs()); + assertThat(realConfig.interceptors()).isEqualTo(ImmutableMap.of()); + xdsServerWrapper.shutdown(); verify(mockServer, times(5)).shutdown(); assertThat(sslSupplier3.isShutdown()).isTrue(); @@ -747,9 +735,9 @@ public void run() { @Test @SuppressWarnings("unchecked") - public void interceptor_notServerInterceptor() throws Exception { + public void interceptor_success() throws Exception { ArgumentCaptor interceptorCaptor = - ArgumentCaptor.forClass(ConfigApplyingInterceptor.class); + ArgumentCaptor.forClass(ConfigApplyingInterceptor.class); final SettableFuture start = SettableFuture.create(); Executors.newSingleThreadExecutor().execute(new Runnable() { @Override @@ -764,26 +752,36 @@ public void run() { xdsClient.ldsResource.get(5, TimeUnit.SECONDS); verify(mockBuilder).intercept(interceptorCaptor.capture()); ConfigApplyingInterceptor interceptor = interceptorCaptor.getValue(); - ServerRoutingConfig routingConfig = createRoutingConfig("/FooService/barMethod", - "foo.google.com", "filter-type-url"); + RouteMatch routeMatch = + RouteMatch.create( + PathMatcher.fromPath("/FooService/barMethod", true), + Collections.emptyList(), null); + Route route = Route.forAction(routeMatch, null, + ImmutableMap.of()); + VirtualHost virtualHost = VirtualHost.create( + "v1", Collections.singletonList("foo.google.com"), Arrays.asList(route), + ImmutableMap.of()); + final List interceptorTrace = new ArrayList<>(); + ServerInterceptor interceptor0 = new ServerInterceptor() { + @Override + public ServerCall.Listener interceptCall(ServerCall call, + Metadata headers, ServerCallHandler next) { + interceptorTrace.add(0); + return next.startCall(call, headers); + } + }; + ServerRoutingConfig realConfig = ServerRoutingConfig.create( + ImmutableList.of(virtualHost), ImmutableMap.of(route, interceptor0)); ServerCall serverCall = mock(ServerCall.class); - when(serverCall.getAttributes()).thenReturn( - Attributes.newBuilder().set(ATTR_SERVER_ROUTING_CONFIG, routingConfig).build()); when(serverCall.getMethodDescriptor()).thenReturn(createMethod("FooService/barMethod")); + when(serverCall.getAttributes()).thenReturn( + Attributes.newBuilder().set(ATTR_SERVER_ROUTING_CONFIG, + new AtomicReference<>(realConfig)).build()); when(serverCall.getAuthority()).thenReturn("foo.google.com"); - - Filter filter = mock(Filter.class); - when(filter.typeUrls()).thenReturn(new String[]{"filter-type-url"}); - filterRegistry.register(filter); ServerCallHandler next = mock(ServerCallHandler.class); interceptor.interceptCall(serverCall, new Metadata(), next); - verify(next, never()).startCall(any(ServerCall.class), any(Metadata.class)); - ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); - verify(serverCall).close(statusCaptor.capture(), any(Metadata.class)); - Status status = statusCaptor.getValue(); - assertThat(status.getCode()).isEqualTo(Status.UNAVAILABLE.getCode()); - assertThat(status.getDescription()).isEqualTo( - "HttpFilterConfig(type URL: filter-type-url) is not supported on server-side."); + verify(next).startCall(eq(serverCall), any(Metadata.class)); + assertThat(interceptorTrace).isEqualTo(Arrays.asList(0)); } @Test @@ -809,7 +807,8 @@ public void run() { "foo.google.com", "filter-type-url"); ServerCall serverCall = mock(ServerCall.class); when(serverCall.getAttributes()).thenReturn( - Attributes.newBuilder().set(ATTR_SERVER_ROUTING_CONFIG, routingConfig).build()); + Attributes.newBuilder().set(ATTR_SERVER_ROUTING_CONFIG, + new AtomicReference<>(routingConfig)).build()); when(serverCall.getAuthority()).thenReturn("not-match.google.com"); Filter filter = mock(Filter.class); @@ -848,7 +847,8 @@ public void run() { "foo.google.com", "filter-type-url"); ServerCall serverCall = mock(ServerCall.class); when(serverCall.getAttributes()).thenReturn( - Attributes.newBuilder().set(ATTR_SERVER_ROUTING_CONFIG, routingConfig).build()); + Attributes.newBuilder() + .set(ATTR_SERVER_ROUTING_CONFIG, new AtomicReference<>(routingConfig)).build()); when(serverCall.getMethodDescriptor()).thenReturn(createMethod("NotMatchMethod")); when(serverCall.getAuthority()).thenReturn("foo.google.com"); @@ -884,12 +884,11 @@ public void run() { xdsClient.ldsResource.get(5, TimeUnit.SECONDS); verify(mockBuilder).intercept(interceptorCaptor.capture()); ConfigApplyingInterceptor interceptor = interceptorCaptor.getValue(); - ServerRoutingConfig failingConfig = ServerRoutingConfig.create( - ImmutableList.of(), new AtomicReference>() - ); ServerCall serverCall = mock(ServerCall.class); + when(serverCall.getAttributes()).thenReturn( - Attributes.newBuilder().set(ATTR_SERVER_ROUTING_CONFIG, failingConfig).build()); + Attributes.newBuilder().set(ATTR_SERVER_ROUTING_CONFIG, + new AtomicReference<>(ServerRoutingConfig.FAILING_ROUTING_CONFIG)).build()); ServerCallHandler next = mock(ServerCallHandler.class); interceptor.interceptCall(serverCall, new Metadata(), next); @@ -899,14 +898,12 @@ public void run() { Status status = statusCaptor.getValue(); assertThat(status.getCode()).isEqualTo(Status.UNAVAILABLE.getCode()); assertThat(status.getDescription()).isEqualTo( - "Missing xDS routing config VirtualHosts due to RDS config unavailable."); + "Missing or broken xDS routing config: RDS config unavailable."); } @Test @SuppressWarnings("unchecked") - public void interceptors() throws Exception { - ArgumentCaptor interceptorCaptor = - ArgumentCaptor.forClass(ConfigApplyingInterceptor.class); + public void buildInterceptor_inline() throws Exception { final SettableFuture start = SettableFuture.create(); Executors.newSingleThreadExecutor().execute(new Runnable() { @Override @@ -919,14 +916,12 @@ public void run() { } }); xdsClient.ldsResource.get(5, TimeUnit.SECONDS); - verify(mockBuilder).intercept(interceptorCaptor.capture()); - final ConfigApplyingInterceptor interceptor = interceptorCaptor.getValue(); RouteMatch routeMatch = - RouteMatch.create( - PathMatcher.fromPath("/FooService/barMethod", true), - Collections.emptyList(), null); + RouteMatch.create( + PathMatcher.fromPath("/FooService/barMethod", true), + Collections.emptyList(), null); Filter filter = mock(Filter.class, withSettings() - .extraInterfaces(ServerInterceptorBuilder.class)); + .extraInterfaces(ServerInterceptorBuilder.class)); when(filter.typeUrls()).thenReturn(new String[]{"filter-type-url"}); filterRegistry.register(filter); FilterConfig f0 = mock(FilterConfig.class); @@ -936,7 +931,7 @@ public void run() { ServerInterceptor interceptor0 = new ServerInterceptor() { @Override public ServerCall.Listener interceptCall(ServerCall call, - Metadata headers, ServerCallHandler next) { + Metadata headers, ServerCallHandler next) { interceptorTrace.add(0); return next.startCall(call, headers); } @@ -949,55 +944,130 @@ public ServerCall.Listener interceptCall(ServerCallof()); VirtualHost virtualHost = VirtualHost.create( - "v1", Collections.singletonList("foo.google.com"), - Arrays.asList(Route.forAction(routeMatch, null, - ImmutableMap.of())), - ImmutableMap.of("filter-config-name-0", f0Override)); - ServerRoutingConfig routingConfig = ServerRoutingConfig.create( - Arrays.asList(new NamedFilterConfig("filter-config-name-0", f0), - new NamedFilterConfig("filter-config-name-1", f0)), - new AtomicReference<>(ImmutableList.of(virtualHost)) - ); + "v1", Collections.singletonList("foo.google.com"), Arrays.asList(route), + ImmutableMap.of("filter-config-name-0", f0Override)); + HttpConnectionManager hcmVirtual = HttpConnectionManager.forVirtualHosts( + 0L, Collections.singletonList(virtualHost), + Arrays.asList(new NamedFilterConfig("filter-config-name-0", f0), + new NamedFilterConfig("filter-config-name-1", f0))); + EnvoyServerProtoData.FilterChain filterChain = createFilterChain("filter-chain-0", hcmVirtual); + xdsClient.deliverLdsUpdate(Collections.singletonList(filterChain), null); + start.get(5000, TimeUnit.MILLISECONDS); + verify(mockServer).start(); + assertThat(selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().size()) + .isEqualTo(1); + ServerInterceptor realInterceptor = selectorManager.getSelectorToUpdateSelector() + .getRoutingConfigs().get(filterChain).get().interceptors().get(route); + assertThat(realInterceptor).isNotNull(); + ServerCall serverCall = mock(ServerCall.class); ServerCallHandler mockNext = mock(ServerCallHandler.class); final ServerCall.Listener listener = new ServerCall.Listener() {}; when(mockNext.startCall(any(ServerCall.class), any(Metadata.class))).thenReturn(listener); - when(serverCall.getAttributes()).thenReturn( - Attributes.newBuilder().set(ATTR_SERVER_ROUTING_CONFIG, routingConfig).build()); - when(serverCall.getMethodDescriptor()).thenReturn(createMethod("FooService/barMethod")); - when(serverCall.getAuthority()).thenReturn("foo.google.com"); + realInterceptor.interceptCall(serverCall, new Metadata(), mockNext); + assertThat(interceptorTrace).isEqualTo(Arrays.asList(1, 0)); + verify(mockNext).startCall(eq(serverCall), any(Metadata.class)); + } - when(((ServerInterceptorBuilder)filter).buildServerInterceptor(f0, f0Override)) - .thenReturn(null); + @Test + @SuppressWarnings("unchecked") + public void buildInterceptor_rds() throws Exception { + final SettableFuture start = SettableFuture.create(); + Executors.newSingleThreadExecutor().execute(new Runnable() { + @Override + public void run() { + try { + start.set(xdsServerWrapper.start()); + } catch (Exception ex) { + start.setException(ex); + } + } + }); + xdsClient.ldsResource.get(5, TimeUnit.SECONDS); + + Filter filter = mock(Filter.class, withSettings() + .extraInterfaces(ServerInterceptorBuilder.class)); + when(filter.typeUrls()).thenReturn(new String[]{"filter-type-url"}); + filterRegistry.register(filter); + FilterConfig f0 = mock(FilterConfig.class); + FilterConfig f0Override = mock(FilterConfig.class); + when(f0.typeUrl()).thenReturn("filter-type-url"); + final List interceptorTrace = new ArrayList<>(); + ServerInterceptor interceptor0 = new ServerInterceptor() { + @Override + public ServerCall.Listener interceptCall(ServerCall call, + Metadata headers, ServerCallHandler next) { + interceptorTrace.add(0); + return next.startCall(call, headers); + } + }; + ServerInterceptor interceptor1 = new ServerInterceptor() { + @Override + public ServerCall.Listener interceptCall(ServerCall call, + Metadata headers, ServerCallHandler next) { + interceptorTrace.add(1); + return next.startCall(call, headers); + } + }; when(((ServerInterceptorBuilder)filter).buildServerInterceptor(f0, null)) - .thenReturn(null); - ServerCall.Listener configApplyingInterceptorListener = - interceptor.interceptCall(serverCall, new Metadata(), mockNext); - assertThat(configApplyingInterceptorListener).isSameInstanceAs(listener); + .thenReturn(interceptor0); + when(((ServerInterceptorBuilder)filter).buildServerInterceptor(f0, f0Override)) + .thenReturn(interceptor1); + RouteMatch routeMatch = + RouteMatch.create( + PathMatcher.fromPath("/FooService/barMethod", true), + Collections.emptyList(), null); + + HttpConnectionManager rdsHcm = HttpConnectionManager.forRdsName(0L, "r0", + Arrays.asList(new NamedFilterConfig("filter-config-name-0", f0), + new NamedFilterConfig("filter-config-name-1", f0))); + EnvoyServerProtoData.FilterChain filterChain = createFilterChain("filter-chain-0", rdsHcm); + xdsClient.deliverLdsUpdate(Collections.singletonList(filterChain), null); + Route route = Route.forAction(routeMatch, null, + ImmutableMap.of()); + VirtualHost virtualHost = VirtualHost.create( + "v1", Collections.singletonList("foo.google.com"), Arrays.asList(route), + ImmutableMap.of("filter-config-name-0", f0Override)); + xdsClient.rdsCount.await(5, TimeUnit.SECONDS); + xdsClient.deliverRdsUpdate("r0", Collections.singletonList(virtualHost)); + start.get(5000, TimeUnit.MILLISECONDS); + verify(mockServer).start(); + assertThat(selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().size()) + .isEqualTo(1); + ServerInterceptor realInterceptor = selectorManager.getSelectorToUpdateSelector() + .getRoutingConfigs().get(filterChain).get().interceptors().get(route); + assertThat(realInterceptor).isNotNull(); + + ServerCall serverCall = mock(ServerCall.class); + ServerCallHandler mockNext = mock(ServerCallHandler.class); + final ServerCall.Listener listener = new ServerCall.Listener() {}; + when(mockNext.startCall(any(ServerCall.class), any(Metadata.class))).thenReturn(listener); + realInterceptor.interceptCall(serverCall, new Metadata(), mockNext); + assertThat(interceptorTrace).isEqualTo(Arrays.asList(1, 0)); verify(mockNext).startCall(eq(serverCall), any(Metadata.class)); - assertThat(interceptorTrace).isEqualTo(Arrays.asList()); - when(((ServerInterceptorBuilder)filter).buildServerInterceptor(f0, f0Override)) - .thenReturn(null); - when(((ServerInterceptorBuilder)filter).buildServerInterceptor(f0, null)) - .thenReturn(interceptor0); - configApplyingInterceptorListener = interceptor.interceptCall( - serverCall, new Metadata(), mockNext); - assertThat(configApplyingInterceptorListener).isSameInstanceAs(listener); + virtualHost = VirtualHost.create( + "v1", Collections.singletonList("foo.google.com"), Arrays.asList(route), + ImmutableMap.of()); + xdsClient.deliverRdsUpdate("r0", Collections.singletonList(virtualHost)); + realInterceptor = selectorManager.getSelectorToUpdateSelector().getRoutingConfigs() + .get(filterChain).get().interceptors().get(route); + assertThat(realInterceptor).isNotNull(); + interceptorTrace.clear(); + realInterceptor.interceptCall(serverCall, new Metadata(), mockNext); + assertThat(interceptorTrace).isEqualTo(Arrays.asList(0, 0)); verify(mockNext, times(2)).startCall(eq(serverCall), any(Metadata.class)); - assertThat(interceptorTrace).isEqualTo(Arrays.asList(0)); - when(((ServerInterceptorBuilder)filter).buildServerInterceptor(f0, f0Override)) - .thenReturn(interceptor0); - when(((ServerInterceptorBuilder)filter).buildServerInterceptor(f0, null)) - .thenReturn(interceptor1); - configApplyingInterceptorListener = interceptor.interceptCall( - serverCall, new Metadata(), mockNext); - assertThat(configApplyingInterceptorListener).isSameInstanceAs(listener); - verify(mockNext, times(3)).startCall(eq(serverCall), any(Metadata.class)); - assertThat(interceptorTrace).isEqualTo(Arrays.asList(0, 0, 1)); + xdsClient.rdsWatchers.get("r0").onResourceDoesNotExist("r0"); + assertThat(selectorManager.getSelectorToUpdateSelector().getRoutingConfigs() + .get(filterChain).get()).isEqualTo(noopConfig); } private static FilterChain createFilterChain(String name, HttpConnectionManager hcm) { @@ -1012,8 +1082,12 @@ private static VirtualHost createVirtualHost(String name) { } private static HttpConnectionManager createRds(String name) { + return createRds(name, null); + } + + private static HttpConnectionManager createRds(String name, FilterConfig filterConfig) { return HttpConnectionManager.forRdsName(0L, name, - Arrays.asList(new NamedFilterConfig("named-config-" + name, null))); + Arrays.asList(new NamedFilterConfig("named-config-" + name, filterConfig))); } private static EnvoyServerProtoData.FilterChainMatch createMatch() { @@ -1041,9 +1115,8 @@ private static ServerRoutingConfig createRoutingConfig(String path, String domai Collections.emptyMap()); FilterConfig f0 = mock(FilterConfig.class); when(f0.typeUrl()).thenReturn(filterType); - return ServerRoutingConfig.create( - Arrays.asList(new NamedFilterConfig("filter-config-name-0", f0)), - new AtomicReference<>(ImmutableList.of(virtualHost)) + return ServerRoutingConfig.create(ImmutableList.of(virtualHost), + ImmutableMap.of() ); } From 733ab98294dc261cf5947461e22848b32f13ed15 Mon Sep 17 00:00:00 2001 From: yifeizhuang Date: Mon, 20 Sep 2021 16:43:03 -0700 Subject: [PATCH 79/82] xds: implement RBAC gRFC misc cases (1.41.x backport) (#8540) --- .../java/io/grpc/xds/ClientXdsClient.java | 12 +++- xds/src/main/java/io/grpc/xds/RbacFilter.java | 16 +++++ .../java/io/grpc/xds/XdsServerWrapper.java | 15 +++-- .../rbac/engine/GrpcAuthorizationEngine.java | 56 +++++++++++++++- .../io/grpc/xds/ClientXdsClientDataTest.java | 17 +++++ .../test/java/io/grpc/xds/RbacFilterTest.java | 43 ++++++++++++ .../io/grpc/xds/XdsServerWrapperTest.java | 63 +++++++++++++++-- .../engine/GrpcAuthorizationEngineTest.java | 67 +++++++++++++++++++ 8 files changed, 275 insertions(+), 14 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/ClientXdsClient.java b/xds/src/main/java/io/grpc/xds/ClientXdsClient.java index d490c9861b9..3bef4c416e0 100644 --- a/xds/src/main/java/io/grpc/xds/ClientXdsClient.java +++ b/xds/src/main/java/io/grpc/xds/ClientXdsClient.java @@ -136,6 +136,10 @@ final class ClientXdsClient extends AbstractXdsClient { static boolean enableRetry = Strings.isNullOrEmpty(System.getenv("GRPC_XDS_EXPERIMENTAL_ENABLE_RETRY")) || Boolean.parseBoolean(System.getenv("GRPC_XDS_EXPERIMENTAL_ENABLE_RETRY")); + @VisibleForTesting + static boolean enableRbac = + Strings.isNullOrEmpty(System.getenv("GRPC_XDS_EXPERIMENTAL_RBAC")) + || Boolean.parseBoolean(System.getenv("GRPC_XDS_EXPERIMENTAL_RBAC")); private static final String TYPE_URL_HTTP_CONNECTION_MANAGER_V2 = "type.googleapis.com/envoy.config.filter.network.http_connection_manager.v2" @@ -218,7 +222,7 @@ protected void handleLdsResponse(String versionInfo, List resources, String listener, retainedRdsResources, enableFaultInjection && isResourceV3); } else { ldsUpdate = processServerSideListener( - listener, retainedRdsResources, enableFaultInjection && isResourceV3); + listener, retainedRdsResources, enableRbac); } } catch (ResourceInvalidException e) { errors.add( @@ -729,10 +733,14 @@ private static FilterChainMatch parseFilterChainMatch( static io.grpc.xds.HttpConnectionManager parseHttpConnectionManager( HttpConnectionManager proto, Set rdsResources, FilterRegistry filterRegistry, boolean parseHttpFilter, boolean isForClient) throws ResourceInvalidException { - if (proto.getXffNumTrustedHops() != 0) { + if (enableRbac && proto.getXffNumTrustedHops() != 0) { throw new ResourceInvalidException( "HttpConnectionManager with xff_num_trusted_hops unsupported"); } + if (enableRbac && !proto.getOriginalIpDetectionExtensionsList().isEmpty()) { + throw new ResourceInvalidException("HttpConnectionManager with " + + "original_ip_detection_extensions unsupported"); + } // Obtain max_stream_duration from Http Protocol Options. long maxStreamDuration = 0; if (proto.hasCommonHttpProtocolOptions()) { diff --git a/xds/src/main/java/io/grpc/xds/RbacFilter.java b/xds/src/main/java/io/grpc/xds/RbacFilter.java index 48b4954767a..39f91b475ae 100644 --- a/xds/src/main/java/io/grpc/xds/RbacFilter.java +++ b/xds/src/main/java/io/grpc/xds/RbacFilter.java @@ -28,6 +28,7 @@ import io.envoyproxy.envoy.config.rbac.v3.Principal; import io.envoyproxy.envoy.extensions.filters.http.rbac.v3.RBAC; import io.envoyproxy.envoy.extensions.filters.http.rbac.v3.RBACPerRoute; +import io.envoyproxy.envoy.type.v3.Int32Range; import io.grpc.Metadata; import io.grpc.ServerCall; import io.grpc.ServerCallHandler; @@ -45,6 +46,7 @@ import io.grpc.xds.internal.rbac.engine.GrpcAuthorizationEngine.AuthenticatedMatcher; import io.grpc.xds.internal.rbac.engine.GrpcAuthorizationEngine.DestinationIpMatcher; import io.grpc.xds.internal.rbac.engine.GrpcAuthorizationEngine.DestinationPortMatcher; +import io.grpc.xds.internal.rbac.engine.GrpcAuthorizationEngine.DestinationPortRangeMatcher; import io.grpc.xds.internal.rbac.engine.GrpcAuthorizationEngine.InvertMatcher; import io.grpc.xds.internal.rbac.engine.GrpcAuthorizationEngine.Matcher; import io.grpc.xds.internal.rbac.engine.GrpcAuthorizationEngine.OrMatcher; @@ -216,6 +218,8 @@ private static Matcher parsePermission(Permission permission) { return createDestinationIpMatcher(permission.getDestinationIp()); case DESTINATION_PORT: return createDestinationPortMatcher(permission.getDestinationPort()); + case DESTINATION_PORT_RANGE: + return parseDestinationPortRangeMatcher(permission.getDestinationPortRange()); case NOT_RULE: return new InvertMatcher(parsePermission(permission.getNotRule())); case METADATA: // hard coded, never match. @@ -291,6 +295,14 @@ private static RequestedServerNameMatcher parseRequestedServerNameMatcher( private static AuthHeaderMatcher parseHeaderMatcher( io.envoyproxy.envoy.config.route.v3.HeaderMatcher proto) { + if (proto.getName().startsWith("grpc-")) { + throw new IllegalArgumentException("Invalid header matcher config: [grpc-] prefixed " + + "header name is not allowed."); + } + if (":scheme".equals(proto.getName())) { + throw new IllegalArgumentException("Invalid header matcher config: header name [:scheme] " + + "is not allowed."); + } return new AuthHeaderMatcher(MatcherParser.parseHeaderMatcher(proto)); } @@ -304,6 +316,10 @@ private static DestinationPortMatcher createDestinationPortMatcher(int port) { return new DestinationPortMatcher(port); } + private static DestinationPortRangeMatcher parseDestinationPortRangeMatcher(Int32Range range) { + return new DestinationPortRangeMatcher(range.getStart(), range.getEnd()); + } + private static DestinationIpMatcher createDestinationIpMatcher(CidrRange cidrRange) { return new DestinationIpMatcher(Matchers.CidrMatcher.create( resolve(cidrRange), cidrRange.getPrefixLen().getValue())); diff --git a/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java b/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java index e7301500e0e..fdc5f099bfe 100644 --- a/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java +++ b/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java @@ -639,6 +639,9 @@ public void run() { if (!routeDiscoveryStates.containsKey(resourceName)) { return; } + if (savedVirtualHosts == null && !isPending) { + logger.log(Level.WARNING, "Received valid Rds {0} configuration.", resourceName); + } savedVirtualHosts = ImmutableList.copyOf(update.virtualHosts); updateRdsRoutingConfig(); maybeUpdateSelector(); @@ -746,8 +749,8 @@ public Listener interceptCall(ServerCall call, virtualHosts, call.getAuthority()); if (virtualHost == null) { call.close( - Status.UNAVAILABLE.withDescription("Could not find xDS virtual host matching RPC"), - new Metadata()); + Status.UNAVAILABLE.withDescription("Could not find xDS virtual host matching RPC"), + new Metadata()); return new Listener() {}; } Route selectedRoute = null; @@ -760,11 +763,15 @@ public Listener interceptCall(ServerCall call, } } if (selectedRoute == null) { - call.close( - Status.UNAVAILABLE.withDescription("Could not find xDS route matching RPC"), + call.close(Status.UNAVAILABLE.withDescription("Could not find xDS route matching RPC"), new Metadata()); return new ServerCall.Listener() {}; } + if (selectedRoute.routeAction() != null) { + call.close(Status.UNAVAILABLE.withDescription("Invalid xDS route action for matching " + + "route: only Route.non_forwarding_action should be allowed."), new Metadata()); + return new ServerCall.Listener() {}; + } ServerInterceptor routeInterceptor = noopInterceptor; Map perRouteInterceptors = routingConfig.interceptors(); if (perRouteInterceptors != null && perRouteInterceptors.get(selectedRoute) != null) { diff --git a/xds/src/main/java/io/grpc/xds/internal/rbac/engine/GrpcAuthorizationEngine.java b/xds/src/main/java/io/grpc/xds/internal/rbac/engine/GrpcAuthorizationEngine.java index 6d275d322a2..bb911461a27 100644 --- a/xds/src/main/java/io/grpc/xds/internal/rbac/engine/GrpcAuthorizationEngine.java +++ b/xds/src/main/java/io/grpc/xds/internal/rbac/engine/GrpcAuthorizationEngine.java @@ -20,6 +20,7 @@ import com.google.auto.value.AutoValue; import com.google.common.base.Joiner; +import com.google.common.io.BaseEncoding; import io.grpc.Grpc; import io.grpc.Metadata; import io.grpc.ServerCall; @@ -35,6 +36,7 @@ import java.util.Collection; import java.util.Collections; import java.util.List; +import java.util.Locale; import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; @@ -234,6 +236,23 @@ public boolean matches(EvaluateArgs args) { } } + public static final class DestinationPortRangeMatcher implements Matcher { + private final int start; + private final int end; + + /** Start of the range is inclusive. End of the range is exclusive.*/ + public DestinationPortRangeMatcher(int start, int end) { + this.start = start; + this.end = end; + } + + @Override + public boolean matches(EvaluateArgs args) { + int port = args.getDestinationPort(); + return port >= start && port < end; + } + } + public static final class RequestedServerNameMatcher implements Matcher { private final Matchers.StringMatcher delegate; @@ -316,9 +335,44 @@ private Collection getPrincipalNames() { @Nullable private String getHeader(String headerName) { - if (headerName.endsWith(Metadata.BINARY_HEADER_SUFFIX)) { + headerName = headerName.toLowerCase(Locale.ROOT); + if ("te".equals(headerName)) { return null; } + if (":authority".equals(headerName)) { + headerName = "host"; + } + if ("host".equals(headerName)) { + return serverCall.getAuthority(); + } + if (":path".equals(headerName)) { + return getPath(); + } + if (":method".equals(headerName)) { + return "POST"; + } + return deserializeHeader(headerName); + } + + @Nullable + private String deserializeHeader(String headerName) { + if (headerName.endsWith(Metadata.BINARY_HEADER_SUFFIX)) { + Metadata.Key key; + try { + key = Metadata.Key.of(headerName, Metadata.BINARY_BYTE_MARSHALLER); + } catch (IllegalArgumentException e) { + return null; + } + Iterable values = metadata.getAll(key); + if (values == null) { + return null; + } + List encoded = new ArrayList<>(); + for (byte[] v : values) { + encoded.add(BaseEncoding.base64().omitPadding().encode(v)); + } + return Joiner.on(",").join(encoded); + } Metadata.Key key; try { key = Metadata.Key.of(headerName, Metadata.ASCII_STRING_MARSHALLER); diff --git a/xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java b/xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java index 876615d0b39..597ca7df2d9 100644 --- a/xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java +++ b/xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java @@ -131,16 +131,20 @@ public class ClientXdsClientDataTest { public final ExpectedException thrown = ExpectedException.none(); private final FilterRegistry filterRegistry = FilterRegistry.getDefaultRegistry(); private boolean originalEnableRetry; + private boolean originalEnableRbac; @Before public void setUp() { originalEnableRetry = ClientXdsClient.enableRetry; assertThat(originalEnableRetry).isTrue(); + originalEnableRbac = ClientXdsClient.enableRbac; + assertThat(originalEnableRbac).isTrue(); } @After public void tearDown() { ClientXdsClient.enableRetry = originalEnableRetry; + ClientXdsClient.enableRbac = originalEnableRbac; } @Test @@ -1108,6 +1112,19 @@ public void parseHttpConnectionManager_xffNumTrustedHopsUnsupported() hcm, new HashSet(), filterRegistry, false /* does not matter */, true /* does not matter */); } + + @Test + public void parseHttpConnectionManager_OriginalIpDetectionExtensionsMustEmpty() + throws ResourceInvalidException { + @SuppressWarnings("deprecation") + HttpConnectionManager hcm = HttpConnectionManager.newBuilder() + .addOriginalIpDetectionExtensions(TypedExtensionConfig.newBuilder().build()) + .build(); + thrown.expect(ResourceInvalidException.class); + thrown.expectMessage("HttpConnectionManager with original_ip_detection_extensions unsupported"); + ClientXdsClient.parseHttpConnectionManager( + hcm, new HashSet(), filterRegistry, false /* does not matter */, false); + } @Test public void parseHttpConnectionManager_missingRdsAndInlinedRouteConfiguration() diff --git a/xds/src/test/java/io/grpc/xds/RbacFilterTest.java b/xds/src/test/java/io/grpc/xds/RbacFilterTest.java index d8f1d8aa825..082c49ef665 100644 --- a/xds/src/test/java/io/grpc/xds/RbacFilterTest.java +++ b/xds/src/test/java/io/grpc/xds/RbacFilterTest.java @@ -41,6 +41,7 @@ import io.envoyproxy.envoy.type.matcher.v3.MetadataMatcher; import io.envoyproxy.envoy.type.matcher.v3.PathMatcher; import io.envoyproxy.envoy.type.matcher.v3.StringMatcher; +import io.envoyproxy.envoy.type.v3.Int32Range; import io.grpc.Attributes; import io.grpc.Grpc; import io.grpc.Metadata; @@ -109,6 +110,33 @@ public void ipPortParser() { assertThat(decision.decision()).isEqualTo(GrpcAuthorizationEngine.Action.DENY); } + @Test + @SuppressWarnings({"unchecked", "deprecation"}) + public void portRangeParser() { + List permissionList = Arrays.asList( + Permission.newBuilder().setDestinationPortRange( + Int32Range.newBuilder().setStart(1010).setEnd(65535).build() + ).build()); + List principalList = Arrays.asList( + Principal.newBuilder().setRemoteIp( + CidrRange.newBuilder().setAddressPrefix("10.10.10.0") + .setPrefixLen(UInt32Value.of(24)).build() + ).build()); + ConfigOrError result = parse(permissionList, principalList); + assertThat(result.errorDetail).isNull(); + ServerCall serverCall = mock(ServerCall.class); + Attributes attributes = Attributes.newBuilder() + .set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, new InetSocketAddress("10.10.10.0", 1)) + .set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, new InetSocketAddress("10.10.10.0",9090)) + .build(); + when(serverCall.getAttributes()).thenReturn(attributes); + when(serverCall.getMethodDescriptor()).thenReturn(method().build()); + GrpcAuthorizationEngine engine = + new GrpcAuthorizationEngine(((RbacConfig)result.config).authConfig()); + AuthDecision decision = engine.evaluate(new Metadata(), serverCall); + assertThat(decision.decision()).isEqualTo(GrpcAuthorizationEngine.Action.DENY); + } + @Test @SuppressWarnings("unchecked") public void pathParser() { @@ -172,6 +200,21 @@ public void headerParser() { assertThat(decision.decision()).isEqualTo(GrpcAuthorizationEngine.Action.DENY); } + @Test + @SuppressWarnings("deprecation") + public void headerParser_headerName() { + HeaderMatcher headerMatcher = HeaderMatcher.newBuilder() + .setName("grpc--feature").setExactMatch("win").build(); + List permissionList = Arrays.asList( + Permission.newBuilder().setHeader(headerMatcher).build()); + HeaderMatcher headerMatcher2 = HeaderMatcher.newBuilder() + .setName(":scheme").setExactMatch("win").build(); + List principalList = Arrays.asList( + Principal.newBuilder().setHeader(headerMatcher2).build()); + ConfigOrError result = parseOverride(permissionList, principalList); + assertThat(result.errorDetail).isNotNull(); + } + @Test @SuppressWarnings("unchecked") public void compositeRules() { diff --git a/xds/src/test/java/io/grpc/xds/XdsServerWrapperTest.java b/xds/src/test/java/io/grpc/xds/XdsServerWrapperTest.java index c109bb44a13..f2b6e9e4790 100644 --- a/xds/src/test/java/io/grpc/xds/XdsServerWrapperTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsServerWrapperTest.java @@ -865,6 +865,50 @@ public void run() { assertThat(status.getDescription()).isEqualTo("Could not find xDS route matching RPC"); } + @Test + @SuppressWarnings("unchecked") + public void interceptor_invalidRouteAction() throws Exception { + ArgumentCaptor interceptorCaptor = + ArgumentCaptor.forClass(ConfigApplyingInterceptor.class); + final SettableFuture start = SettableFuture.create(); + Executors.newSingleThreadExecutor().execute(new Runnable() { + @Override + public void run() { + try { + start.set(xdsServerWrapper.start()); + } catch (Exception ex) { + start.setException(ex); + } + } + }); + xdsClient.ldsResource.get(5, TimeUnit.SECONDS); + verify(mockBuilder).intercept(interceptorCaptor.capture()); + ConfigApplyingInterceptor interceptor = interceptorCaptor.getValue(); + ServerRoutingConfig routingConfig = createRoutingConfig("/FooService/barMethod", + "foo.google.com", "filter-type-url", Route.RouteAction.forCluster( + "cluster", Collections.emptyList(), null, null + )); + ServerCall serverCall = mock(ServerCall.class); + when(serverCall.getAttributes()).thenReturn( + Attributes.newBuilder() + .set(ATTR_SERVER_ROUTING_CONFIG, new AtomicReference<>(routingConfig)).build()); + when(serverCall.getMethodDescriptor()).thenReturn(createMethod("FooService/barMethod")); + when(serverCall.getAuthority()).thenReturn("foo.google.com"); + + Filter filter = mock(Filter.class); + when(filter.typeUrls()).thenReturn(new String[]{"filter-type-url"}); + filterRegistry.register(filter); + ServerCallHandler next = mock(ServerCallHandler.class); + interceptor.interceptCall(serverCall, new Metadata(), next); + verify(next, never()).startCall(any(ServerCall.class), any(Metadata.class)); + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); + verify(serverCall).close(statusCaptor.capture(), any(Metadata.class)); + Status status = statusCaptor.getValue(); + assertThat(status.getCode()).isEqualTo(Status.UNAVAILABLE.getCode()); + assertThat(status.getDescription()).isEqualTo("Invalid xDS route action for matching " + + "route: only Route.non_forwarding_action should be allowed."); + } + @Test @SuppressWarnings("unchecked") public void interceptor_failingRouterConfig() throws Exception { @@ -1104,15 +1148,20 @@ private static EnvoyServerProtoData.FilterChainMatch createMatch() { private static ServerRoutingConfig createRoutingConfig(String path, String domain, String filterType) { + return createRoutingConfig(path, domain, filterType, null); + } + + private static ServerRoutingConfig createRoutingConfig(String path, String domain, + String filterType, Route.RouteAction action) { RouteMatch routeMatch = - RouteMatch.create( - PathMatcher.fromPath(path, true), - Collections.emptyList(), null); + RouteMatch.create( + PathMatcher.fromPath(path, true), + Collections.emptyList(), null); VirtualHost virtualHost = VirtualHost.create( - "v1", Collections.singletonList(domain), - Arrays.asList(Route.forAction(routeMatch, null, - ImmutableMap.of())), - Collections.emptyMap()); + "v1", Collections.singletonList(domain), + Arrays.asList(Route.forAction(routeMatch, action, + ImmutableMap.of())), + Collections.emptyMap()); FilterConfig f0 = mock(FilterConfig.class); when(f0.typeUrl()).thenReturn(filterType); return ServerRoutingConfig.create(ImmutableList.of(virtualHost), diff --git a/xds/src/test/java/io/grpc/xds/internal/rbac/engine/GrpcAuthorizationEngineTest.java b/xds/src/test/java/io/grpc/xds/internal/rbac/engine/GrpcAuthorizationEngineTest.java index 504c9e8df2a..626a4cfc275 100644 --- a/xds/src/test/java/io/grpc/xds/internal/rbac/engine/GrpcAuthorizationEngineTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/rbac/engine/GrpcAuthorizationEngineTest.java @@ -16,12 +16,14 @@ package io.grpc.xds.internal.rbac.engine; +import static com.google.common.base.Charsets.US_ASCII; import static com.google.common.truth.Truth.assertThat; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; import com.google.common.collect.ImmutableList; +import com.google.common.io.BaseEncoding; import io.grpc.Attributes; import io.grpc.Grpc; import io.grpc.Metadata; @@ -177,6 +179,71 @@ public void headerMatcher() { assertThat(decision.decision()).isEqualTo(Action.DENY); } + @Test + public void headerMatcher_binaryHeader() { + AuthHeaderMatcher headerMatcher = new AuthHeaderMatcher(Matchers.HeaderMatcher + .forExactValue(HEADER_KEY + Metadata.BINARY_HEADER_SUFFIX, + BaseEncoding.base64().omitPadding().encode(HEADER_VALUE.getBytes(US_ASCII)), false)); + OrMatcher principal = OrMatcher.create(headerMatcher); + OrMatcher permission = OrMatcher.create( + new InvertMatcher(new DestinationPortMatcher(PORT + 1))); + PolicyMatcher policyMatcher = new PolicyMatcher(POLICY_NAME, permission, principal); + GrpcAuthorizationEngine engine = new GrpcAuthorizationEngine( + new AuthConfig(Collections.singletonList(policyMatcher), Action.ALLOW)); + Metadata metadata = new Metadata(); + metadata.put(Metadata.Key.of(HEADER_KEY + Metadata.BINARY_HEADER_SUFFIX, + Metadata.BINARY_BYTE_MARSHALLER), HEADER_VALUE.getBytes(US_ASCII)); + AuthDecision decision = engine.evaluate(metadata, serverCall); + assertThat(decision.decision()).isEqualTo(Action.ALLOW); + assertThat(decision.matchingPolicyName()).isEqualTo(POLICY_NAME); + } + + @Test + public void headerMatcher_hardcodePostMethod() { + AuthHeaderMatcher headerMatcher = new AuthHeaderMatcher(Matchers.HeaderMatcher + .forExactValue(":method", "POST", false)); + OrMatcher principal = OrMatcher.create(headerMatcher); + OrMatcher permission = OrMatcher.create( + new InvertMatcher(new DestinationPortMatcher(PORT + 1))); + PolicyMatcher policyMatcher = new PolicyMatcher(POLICY_NAME, permission, principal); + GrpcAuthorizationEngine engine = new GrpcAuthorizationEngine( + new AuthConfig(Collections.singletonList(policyMatcher), Action.ALLOW)); + AuthDecision decision = engine.evaluate(new Metadata(), serverCall); + assertThat(decision.decision()).isEqualTo(Action.ALLOW); + assertThat(decision.matchingPolicyName()).isEqualTo(POLICY_NAME); + } + + @Test + public void headerMatcher_pathHeader() { + AuthHeaderMatcher headerMatcher = new AuthHeaderMatcher(Matchers.HeaderMatcher + .forExactValue(":path", "/" + PATH, false)); + OrMatcher principal = OrMatcher.create(headerMatcher); + OrMatcher permission = OrMatcher.create( + new InvertMatcher(new DestinationPortMatcher(PORT + 1))); + PolicyMatcher policyMatcher = new PolicyMatcher(POLICY_NAME, permission, principal); + GrpcAuthorizationEngine engine = new GrpcAuthorizationEngine( + new AuthConfig(Collections.singletonList(policyMatcher), Action.ALLOW)); + AuthDecision decision = engine.evaluate(HEADER, serverCall); + assertThat(decision.decision()).isEqualTo(Action.ALLOW); + assertThat(decision.matchingPolicyName()).isEqualTo(POLICY_NAME); + } + + @Test + public void headerMatcher_aliasAuthorityAndHost() { + AuthHeaderMatcher headerMatcher = new AuthHeaderMatcher(Matchers.HeaderMatcher + .forExactValue("Host", "google.com", false)); + OrMatcher principal = OrMatcher.create(headerMatcher); + OrMatcher permission = OrMatcher.create( + new InvertMatcher(new DestinationPortMatcher(PORT + 1))); + PolicyMatcher policyMatcher = new PolicyMatcher(POLICY_NAME, permission, principal); + GrpcAuthorizationEngine engine = new GrpcAuthorizationEngine( + new AuthConfig(Collections.singletonList(policyMatcher), Action.ALLOW)); + when(serverCall.getAuthority()).thenReturn("google.com"); + AuthDecision decision = engine.evaluate(new Metadata(), serverCall); + assertThat(decision.decision()).isEqualTo(Action.ALLOW); + assertThat(decision.matchingPolicyName()).isEqualTo(POLICY_NAME); + } + @Test public void pathMatcher() { PathMatcher pathMatcher = new PathMatcher(STRING_MATCHER); From 695104802d128011bf6a65bd5ccf64416bdf0807 Mon Sep 17 00:00:00 2001 From: yifeizhuang Date: Tue, 21 Sep 2021 10:01:46 -0700 Subject: [PATCH 80/82] xds: disable rbac by default (#8541) --- .../java/io/grpc/xds/ClientXdsClient.java | 6 ++-- .../java/io/grpc/xds/XdsServerWrapper.java | 28 ++++++++++--------- .../io/grpc/xds/ClientXdsClientDataTest.java | 3 +- .../io/grpc/xds/ClientXdsClientTestBase.java | 5 ++++ 4 files changed, 25 insertions(+), 17 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/ClientXdsClient.java b/xds/src/main/java/io/grpc/xds/ClientXdsClient.java index 3bef4c416e0..f39992c24ac 100644 --- a/xds/src/main/java/io/grpc/xds/ClientXdsClient.java +++ b/xds/src/main/java/io/grpc/xds/ClientXdsClient.java @@ -138,8 +138,8 @@ final class ClientXdsClient extends AbstractXdsClient { || Boolean.parseBoolean(System.getenv("GRPC_XDS_EXPERIMENTAL_ENABLE_RETRY")); @VisibleForTesting static boolean enableRbac = - Strings.isNullOrEmpty(System.getenv("GRPC_XDS_EXPERIMENTAL_RBAC")) - || Boolean.parseBoolean(System.getenv("GRPC_XDS_EXPERIMENTAL_RBAC")); + !Strings.isNullOrEmpty(System.getenv("GRPC_XDS_EXPERIMENTAL_RBAC")) + && Boolean.parseBoolean(System.getenv("GRPC_XDS_EXPERIMENTAL_RBAC")); private static final String TYPE_URL_HTTP_CONNECTION_MANAGER_V2 = "type.googleapis.com/envoy.config.filter.network.http_connection_manager.v2" @@ -222,7 +222,7 @@ protected void handleLdsResponse(String versionInfo, List resources, String listener, retainedRdsResources, enableFaultInjection && isResourceV3); } else { ldsUpdate = processServerSideListener( - listener, retainedRdsResources, enableRbac); + listener, retainedRdsResources, enableRbac && isResourceV3); } } catch (ResourceInvalidException e) { errors.add( diff --git a/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java b/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java index fdc5f099bfe..5f7cc43d670 100644 --- a/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java +++ b/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java @@ -515,20 +515,22 @@ private ImmutableMap generatePerRouteInterceptors( Map selectedOverrideConfigs = new HashMap<>(virtualHost.filterConfigOverrides()); selectedOverrideConfigs.putAll(route.filterConfigOverrides()); - for (NamedFilterConfig namedFilterConfig : namedFilterConfigs) { - FilterConfig filterConfig = namedFilterConfig.filterConfig; - Filter filter = filterRegistry.get(filterConfig.typeUrl()); - if (filter instanceof ServerInterceptorBuilder) { - ServerInterceptor interceptor = - ((ServerInterceptorBuilder) filter).buildServerInterceptor( - filterConfig, selectedOverrideConfigs.get(namedFilterConfig.name)); - if (interceptor != null) { - filterInterceptors.add(interceptor); + if (namedFilterConfigs != null) { + for (NamedFilterConfig namedFilterConfig : namedFilterConfigs) { + FilterConfig filterConfig = namedFilterConfig.filterConfig; + Filter filter = filterRegistry.get(filterConfig.typeUrl()); + if (filter instanceof ServerInterceptorBuilder) { + ServerInterceptor interceptor = + ((ServerInterceptorBuilder) filter).buildServerInterceptor( + filterConfig, selectedOverrideConfigs.get(namedFilterConfig.name)); + if (interceptor != null) { + filterInterceptors.add(interceptor); + } + } else { + logger.log(Level.WARNING, "HttpFilterConfig(type URL: " + + filterConfig.typeUrl() + ") is not supported on server-side. " + + "Probably a bug at ClientXdsClient verification."); } - } else { - logger.log(Level.WARNING, "HttpFilterConfig(type URL: " - + filterConfig.typeUrl() + ") is not supported on server-side. " - + "Probably a bug at ClientXdsClient verification."); } } ServerInterceptor interceptor = combineInterceptors(filterInterceptors); diff --git a/xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java b/xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java index 597ca7df2d9..17ca907da42 100644 --- a/xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java +++ b/xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java @@ -138,7 +138,8 @@ public void setUp() { originalEnableRetry = ClientXdsClient.enableRetry; assertThat(originalEnableRetry).isTrue(); originalEnableRbac = ClientXdsClient.enableRbac; - assertThat(originalEnableRbac).isTrue(); + assertThat(originalEnableRbac).isFalse(); + ClientXdsClient.enableRbac = true; } @After diff --git a/xds/src/test/java/io/grpc/xds/ClientXdsClientTestBase.java b/xds/src/test/java/io/grpc/xds/ClientXdsClientTestBase.java index e66c73163be..3a9ab23aa74 100644 --- a/xds/src/test/java/io/grpc/xds/ClientXdsClientTestBase.java +++ b/xds/src/test/java/io/grpc/xds/ClientXdsClientTestBase.java @@ -246,6 +246,7 @@ public long currentTimeNanos() { private ManagedChannel channel; private ClientXdsClient xdsClient; private boolean originalEnableFaultInjection; + private boolean originalEnableRbac; @Before public void setUp() throws IOException { @@ -258,6 +259,9 @@ public void setUp() throws IOException { // Start the server and the client. originalEnableFaultInjection = ClientXdsClient.enableFaultInjection; ClientXdsClient.enableFaultInjection = true; + originalEnableRbac = ClientXdsClient.enableRbac; + assertThat(originalEnableRbac).isFalse(); + ClientXdsClient.enableRbac = true; final String serverName = InProcessServerBuilder.generateName(); cleanupRule.register( InProcessServerBuilder @@ -297,6 +301,7 @@ public void setUp() throws IOException { @After public void tearDown() { ClientXdsClient.enableFaultInjection = originalEnableFaultInjection; + ClientXdsClient.enableRbac = originalEnableRbac; xdsClient.shutdown(); channel.shutdown(); // channel not owned by XdsClient assertThat(adsEnded.get()).isTrue(); From 227961e628bc7cbb18cd6a736ac4eb954b545be2 Mon Sep 17 00:00:00 2001 From: Eric Anderson Date: Wed, 22 Sep 2021 10:07:29 -0700 Subject: [PATCH 81/82] Update README etc to reference 1.41.0 --- README.md | 30 ++++++++++++------------ cronet/README.md | 2 +- documentation/android-channel-builder.md | 4 ++-- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/README.md b/README.md index 6611b0ef1af..e6c06bde236 100644 --- a/README.md +++ b/README.md @@ -31,8 +31,8 @@ For a guided tour, take a look at the [quick start guide](https://grpc.io/docs/languages/java/quickstart) or the more explanatory [gRPC basics](https://grpc.io/docs/languages/java/basics). -The [examples](https://github.com/grpc/grpc-java/tree/v1.40.1/examples) and the -[Android example](https://github.com/grpc/grpc-java/tree/v1.40.1/examples/android) +The [examples](https://github.com/grpc/grpc-java/tree/v1.41.0/examples) and the +[Android example](https://github.com/grpc/grpc-java/tree/v1.41.0/examples/android) are standalone projects that showcase the usage of gRPC. Download @@ -43,17 +43,17 @@ Download [the JARs][]. Or for Maven with non-Android, add to your `pom.xml`: io.grpc grpc-netty-shaded - 1.40.1 + 1.41.0 io.grpc grpc-protobuf - 1.40.1 + 1.41.0 io.grpc grpc-stub - 1.40.1 + 1.41.0 org.apache.tomcat @@ -65,23 +65,23 @@ Download [the JARs][]. Or for Maven with non-Android, add to your `pom.xml`: Or for Gradle with non-Android, add to your dependencies: ```gradle -implementation 'io.grpc:grpc-netty-shaded:1.40.1' -implementation 'io.grpc:grpc-protobuf:1.40.1' -implementation 'io.grpc:grpc-stub:1.40.1' +implementation 'io.grpc:grpc-netty-shaded:1.41.0' +implementation 'io.grpc:grpc-protobuf:1.41.0' +implementation 'io.grpc:grpc-stub:1.41.0' compileOnly 'org.apache.tomcat:annotations-api:6.0.53' // necessary for Java 9+ ``` For Android client, use `grpc-okhttp` instead of `grpc-netty-shaded` and `grpc-protobuf-lite` instead of `grpc-protobuf`: ```gradle -implementation 'io.grpc:grpc-okhttp:1.40.1' -implementation 'io.grpc:grpc-protobuf-lite:1.40.1' -implementation 'io.grpc:grpc-stub:1.40.1' +implementation 'io.grpc:grpc-okhttp:1.41.0' +implementation 'io.grpc:grpc-protobuf-lite:1.41.0' +implementation 'io.grpc:grpc-stub:1.41.0' compileOnly 'org.apache.tomcat:annotations-api:6.0.53' // necessary for Java 9+ ``` [the JARs]: -https://search.maven.org/search?q=g:io.grpc%20AND%20v:1.40.1 +https://search.maven.org/search?q=g:io.grpc%20AND%20v:1.41.0 Development snapshots are available in [Sonatypes's snapshot repository](https://oss.sonatype.org/content/repositories/snapshots/). @@ -113,7 +113,7 @@ For protobuf-based codegen integrated with the Maven build system, you can use com.google.protobuf:protoc:3.17.3:exe:${os.detected.classifier} grpc-java - io.grpc:protoc-gen-grpc-java:1.40.1:exe:${os.detected.classifier} + io.grpc:protoc-gen-grpc-java:1.41.0:exe:${os.detected.classifier} @@ -143,7 +143,7 @@ protobuf { } plugins { grpc { - artifact = 'io.grpc:protoc-gen-grpc-java:1.40.1' + artifact = 'io.grpc:protoc-gen-grpc-java:1.41.0' } } generateProtoTasks { @@ -176,7 +176,7 @@ protobuf { } plugins { grpc { - artifact = 'io.grpc:protoc-gen-grpc-java:1.40.1' + artifact = 'io.grpc:protoc-gen-grpc-java:1.41.0' } } generateProtoTasks { diff --git a/cronet/README.md b/cronet/README.md index c982604bdac..8b220bd606d 100644 --- a/cronet/README.md +++ b/cronet/README.md @@ -26,7 +26,7 @@ In your app module's `build.gradle` file, include a dependency on both `grpc-cro Google Play Services Client Library for Cronet ``` -implementation 'io.grpc:grpc-cronet:1.40.1' +implementation 'io.grpc:grpc-cronet:1.41.0' implementation 'com.google.android.gms:play-services-cronet:16.0.0' ``` diff --git a/documentation/android-channel-builder.md b/documentation/android-channel-builder.md index 113b20159b9..60e3bb35a85 100644 --- a/documentation/android-channel-builder.md +++ b/documentation/android-channel-builder.md @@ -36,8 +36,8 @@ In your `build.gradle` file, include a dependency on both `grpc-android` and `grpc-okhttp`: ``` -implementation 'io.grpc:grpc-android:1.40.1' -implementation 'io.grpc:grpc-okhttp:1.40.1' +implementation 'io.grpc:grpc-android:1.41.0' +implementation 'io.grpc:grpc-okhttp:1.41.0' ``` You also need permission to access the device's network state in your From d291594443db7d43a6e984a1ae0c58883c7a1921 Mon Sep 17 00:00:00 2001 From: Eric Anderson Date: Wed, 22 Sep 2021 10:16:50 -0700 Subject: [PATCH 82/82] Bump version to 1.41.0 --- build.gradle | 2 +- .../src/test/golden/TestDeprecatedService.java.txt | 2 +- compiler/src/test/golden/TestService.java.txt | 2 +- .../src/testLite/golden/TestDeprecatedService.java.txt | 2 +- compiler/src/testLite/golden/TestService.java.txt | 2 +- core/src/main/java/io/grpc/internal/GrpcUtil.java | 2 +- examples/android/clientcache/app/build.gradle | 10 +++++----- examples/android/helloworld/app/build.gradle | 8 ++++---- examples/android/routeguide/app/build.gradle | 8 ++++---- examples/android/strictmode/app/build.gradle | 8 ++++---- examples/build.gradle | 2 +- examples/example-alts/build.gradle | 2 +- examples/example-gauth/build.gradle | 2 +- examples/example-gauth/pom.xml | 4 ++-- examples/example-hostname/build.gradle | 2 +- examples/example-hostname/pom.xml | 4 ++-- examples/example-jwt-auth/build.gradle | 2 +- examples/example-jwt-auth/pom.xml | 4 ++-- examples/example-tls/build.gradle | 2 +- examples/example-tls/pom.xml | 4 ++-- examples/example-xds/build.gradle | 2 +- examples/pom.xml | 4 ++-- 22 files changed, 40 insertions(+), 40 deletions(-) diff --git a/build.gradle b/build.gradle index 6c099e0cf39..001ebd148c0 100644 --- a/build.gradle +++ b/build.gradle @@ -18,7 +18,7 @@ subprojects { apply plugin: "net.ltgt.errorprone" group = "io.grpc" - version = "1.41.0-SNAPSHOT" // CURRENT_GRPC_VERSION + version = "1.41.0" // CURRENT_GRPC_VERSION repositories { maven { // The google mirror is less flaky than mavenCentral() diff --git a/compiler/src/test/golden/TestDeprecatedService.java.txt b/compiler/src/test/golden/TestDeprecatedService.java.txt index 018849586de..2beed7b2b7f 100644 --- a/compiler/src/test/golden/TestDeprecatedService.java.txt +++ b/compiler/src/test/golden/TestDeprecatedService.java.txt @@ -8,7 +8,7 @@ import static io.grpc.MethodDescriptor.generateFullMethodName; * */ @javax.annotation.Generated( - value = "by gRPC proto compiler (version 1.41.0-SNAPSHOT)", + value = "by gRPC proto compiler (version 1.41.0)", comments = "Source: grpc/testing/compiler/test.proto") @io.grpc.stub.annotations.GrpcGenerated @java.lang.Deprecated diff --git a/compiler/src/test/golden/TestService.java.txt b/compiler/src/test/golden/TestService.java.txt index 18af83a9119..ba2c37f4b81 100644 --- a/compiler/src/test/golden/TestService.java.txt +++ b/compiler/src/test/golden/TestService.java.txt @@ -8,7 +8,7 @@ import static io.grpc.MethodDescriptor.generateFullMethodName; * */ @javax.annotation.Generated( - value = "by gRPC proto compiler (version 1.41.0-SNAPSHOT)", + value = "by gRPC proto compiler (version 1.41.0)", comments = "Source: grpc/testing/compiler/test.proto") @io.grpc.stub.annotations.GrpcGenerated public final class TestServiceGrpc { diff --git a/compiler/src/testLite/golden/TestDeprecatedService.java.txt b/compiler/src/testLite/golden/TestDeprecatedService.java.txt index 30f22366765..72d1b428efb 100644 --- a/compiler/src/testLite/golden/TestDeprecatedService.java.txt +++ b/compiler/src/testLite/golden/TestDeprecatedService.java.txt @@ -8,7 +8,7 @@ import static io.grpc.MethodDescriptor.generateFullMethodName; * */ @javax.annotation.Generated( - value = "by gRPC proto compiler (version 1.41.0-SNAPSHOT)", + value = "by gRPC proto compiler (version 1.41.0)", comments = "Source: grpc/testing/compiler/test.proto") @io.grpc.stub.annotations.GrpcGenerated @java.lang.Deprecated diff --git a/compiler/src/testLite/golden/TestService.java.txt b/compiler/src/testLite/golden/TestService.java.txt index 38626900571..bc1d50acecc 100644 --- a/compiler/src/testLite/golden/TestService.java.txt +++ b/compiler/src/testLite/golden/TestService.java.txt @@ -8,7 +8,7 @@ import static io.grpc.MethodDescriptor.generateFullMethodName; * */ @javax.annotation.Generated( - value = "by gRPC proto compiler (version 1.41.0-SNAPSHOT)", + value = "by gRPC proto compiler (version 1.41.0)", comments = "Source: grpc/testing/compiler/test.proto") @io.grpc.stub.annotations.GrpcGenerated public final class TestServiceGrpc { diff --git a/core/src/main/java/io/grpc/internal/GrpcUtil.java b/core/src/main/java/io/grpc/internal/GrpcUtil.java index 54f6d2f41d5..12ae8954ce5 100644 --- a/core/src/main/java/io/grpc/internal/GrpcUtil.java +++ b/core/src/main/java/io/grpc/internal/GrpcUtil.java @@ -202,7 +202,7 @@ public byte[] parseAsciiString(byte[] serialized) { public static final Splitter ACCEPT_ENCODING_SPLITTER = Splitter.on(',').trimResults(); - private static final String IMPLEMENTATION_VERSION = "1.41.0-SNAPSHOT"; // CURRENT_GRPC_VERSION + private static final String IMPLEMENTATION_VERSION = "1.41.0"; // CURRENT_GRPC_VERSION /** * The default timeout in nanos for a keepalive ping request. diff --git a/examples/android/clientcache/app/build.gradle b/examples/android/clientcache/app/build.gradle index fab0934405e..9b5ef9d448e 100644 --- a/examples/android/clientcache/app/build.gradle +++ b/examples/android/clientcache/app/build.gradle @@ -34,7 +34,7 @@ android { protobuf { protoc { artifact = 'com.google.protobuf:protoc:3.17.2' } plugins { - grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.41.0-SNAPSHOT' // CURRENT_GRPC_VERSION + grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.41.0' // CURRENT_GRPC_VERSION } } generateProtoTasks { @@ -54,12 +54,12 @@ dependencies { implementation 'com.android.support:appcompat-v7:27.0.2' // You need to build grpc-java to obtain these libraries below. - implementation 'io.grpc:grpc-okhttp:1.41.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-protobuf-lite:1.41.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-stub:1.41.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-okhttp:1.41.0' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-protobuf-lite:1.41.0' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-stub:1.41.0' // CURRENT_GRPC_VERSION implementation 'org.apache.tomcat:annotations-api:6.0.53' testImplementation 'junit:junit:4.12' testImplementation 'com.google.truth:truth:1.0.1' - testImplementation 'io.grpc:grpc-testing:1.41.0-SNAPSHOT' // CURRENT_GRPC_VERSION + testImplementation 'io.grpc:grpc-testing:1.41.0' // CURRENT_GRPC_VERSION } diff --git a/examples/android/helloworld/app/build.gradle b/examples/android/helloworld/app/build.gradle index 1fddbd6d481..8a3ed89300d 100644 --- a/examples/android/helloworld/app/build.gradle +++ b/examples/android/helloworld/app/build.gradle @@ -32,7 +32,7 @@ android { protobuf { protoc { artifact = 'com.google.protobuf:protoc:3.17.2' } plugins { - grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.41.0-SNAPSHOT' // CURRENT_GRPC_VERSION + grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.41.0' // CURRENT_GRPC_VERSION } } generateProtoTasks { @@ -52,8 +52,8 @@ dependencies { implementation 'com.android.support:appcompat-v7:27.0.2' // You need to build grpc-java to obtain these libraries below. - implementation 'io.grpc:grpc-okhttp:1.41.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-protobuf-lite:1.41.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-stub:1.41.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-okhttp:1.41.0' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-protobuf-lite:1.41.0' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-stub:1.41.0' // CURRENT_GRPC_VERSION implementation 'org.apache.tomcat:annotations-api:6.0.53' } diff --git a/examples/android/routeguide/app/build.gradle b/examples/android/routeguide/app/build.gradle index 250b10c3653..4b2d3989e2c 100644 --- a/examples/android/routeguide/app/build.gradle +++ b/examples/android/routeguide/app/build.gradle @@ -32,7 +32,7 @@ android { protobuf { protoc { artifact = 'com.google.protobuf:protoc:3.17.2' } plugins { - grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.41.0-SNAPSHOT' // CURRENT_GRPC_VERSION + grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.41.0' // CURRENT_GRPC_VERSION } } generateProtoTasks { @@ -52,8 +52,8 @@ dependencies { implementation 'com.android.support:appcompat-v7:27.0.2' // You need to build grpc-java to obtain these libraries below. - implementation 'io.grpc:grpc-okhttp:1.41.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-protobuf-lite:1.41.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-stub:1.41.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-okhttp:1.41.0' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-protobuf-lite:1.41.0' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-stub:1.41.0' // CURRENT_GRPC_VERSION implementation 'org.apache.tomcat:annotations-api:6.0.53' } diff --git a/examples/android/strictmode/app/build.gradle b/examples/android/strictmode/app/build.gradle index f68e8584ef0..e18273cbadc 100644 --- a/examples/android/strictmode/app/build.gradle +++ b/examples/android/strictmode/app/build.gradle @@ -33,7 +33,7 @@ android { protobuf { protoc { artifact = 'com.google.protobuf:protoc:3.17.2' } plugins { - grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.41.0-SNAPSHOT' // CURRENT_GRPC_VERSION + grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.41.0' // CURRENT_GRPC_VERSION } } generateProtoTasks { @@ -53,8 +53,8 @@ dependencies { implementation 'com.android.support:appcompat-v7:28.0.0' // You need to build grpc-java to obtain these libraries below. - implementation 'io.grpc:grpc-okhttp:1.41.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-protobuf-lite:1.41.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-stub:1.41.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-okhttp:1.41.0' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-protobuf-lite:1.41.0' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-stub:1.41.0' // CURRENT_GRPC_VERSION implementation 'org.apache.tomcat:annotations-api:6.0.53' } diff --git a/examples/build.gradle b/examples/build.gradle index 03967a41c0d..b21a15f8443 100644 --- a/examples/build.gradle +++ b/examples/build.gradle @@ -22,7 +22,7 @@ targetCompatibility = 1.7 // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.41.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.41.0' // CURRENT_GRPC_VERSION def protobufVersion = '3.17.2' def protocVersion = protobufVersion diff --git a/examples/example-alts/build.gradle b/examples/example-alts/build.gradle index b1265e89440..e08d8fd1ce5 100644 --- a/examples/example-alts/build.gradle +++ b/examples/example-alts/build.gradle @@ -23,7 +23,7 @@ targetCompatibility = 1.7 // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.41.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.41.0' // CURRENT_GRPC_VERSION def protocVersion = '3.17.2' dependencies { diff --git a/examples/example-gauth/build.gradle b/examples/example-gauth/build.gradle index f7f332e6962..0e7d7ece1f0 100644 --- a/examples/example-gauth/build.gradle +++ b/examples/example-gauth/build.gradle @@ -23,7 +23,7 @@ targetCompatibility = 1.7 // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.41.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.41.0' // CURRENT_GRPC_VERSION def protobufVersion = '3.17.2' def protocVersion = protobufVersion diff --git a/examples/example-gauth/pom.xml b/examples/example-gauth/pom.xml index 1cc4c3ba6a7..91849437181 100644 --- a/examples/example-gauth/pom.xml +++ b/examples/example-gauth/pom.xml @@ -6,13 +6,13 @@ jar - 1.41.0-SNAPSHOT + 1.41.0 example-gauth https://github.com/grpc/grpc-java UTF-8 - 1.41.0-SNAPSHOT + 1.41.0 3.17.2 1.7 diff --git a/examples/example-hostname/build.gradle b/examples/example-hostname/build.gradle index 44048a78e34..23779a52a2b 100644 --- a/examples/example-hostname/build.gradle +++ b/examples/example-hostname/build.gradle @@ -21,7 +21,7 @@ targetCompatibility = 1.7 // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.41.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.41.0' // CURRENT_GRPC_VERSION def protobufVersion = '3.17.2' dependencies { diff --git a/examples/example-hostname/pom.xml b/examples/example-hostname/pom.xml index 9af512c3952..1df16a9acbe 100644 --- a/examples/example-hostname/pom.xml +++ b/examples/example-hostname/pom.xml @@ -6,13 +6,13 @@ jar - 1.41.0-SNAPSHOT + 1.41.0 example-hostname https://github.com/grpc/grpc-java UTF-8 - 1.41.0-SNAPSHOT + 1.41.0 3.17.2 1.7 diff --git a/examples/example-jwt-auth/build.gradle b/examples/example-jwt-auth/build.gradle index cf59da51d4a..762d8e08f3e 100644 --- a/examples/example-jwt-auth/build.gradle +++ b/examples/example-jwt-auth/build.gradle @@ -22,7 +22,7 @@ targetCompatibility = 1.7 // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.41.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.41.0' // CURRENT_GRPC_VERSION def protobufVersion = '3.17.2' def protocVersion = protobufVersion diff --git a/examples/example-jwt-auth/pom.xml b/examples/example-jwt-auth/pom.xml index c4ae09cda90..bbe496f0c90 100644 --- a/examples/example-jwt-auth/pom.xml +++ b/examples/example-jwt-auth/pom.xml @@ -7,13 +7,13 @@ jar - 1.41.0-SNAPSHOT + 1.41.0 example-jwt-auth https://github.com/grpc/grpc-java UTF-8 - 1.41.0-SNAPSHOT + 1.41.0 3.17.2 3.17.2 diff --git a/examples/example-tls/build.gradle b/examples/example-tls/build.gradle index 61f13e050de..f60b54c146a 100644 --- a/examples/example-tls/build.gradle +++ b/examples/example-tls/build.gradle @@ -23,7 +23,7 @@ targetCompatibility = 1.7 // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.41.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.41.0' // CURRENT_GRPC_VERSION def protocVersion = '3.17.2' dependencies { diff --git a/examples/example-tls/pom.xml b/examples/example-tls/pom.xml index d83a0937725..67d6ff1a507 100644 --- a/examples/example-tls/pom.xml +++ b/examples/example-tls/pom.xml @@ -6,13 +6,13 @@ jar - 1.41.0-SNAPSHOT + 1.41.0 example-tls https://github.com/grpc/grpc-java UTF-8 - 1.41.0-SNAPSHOT + 1.41.0 3.17.2 2.0.34.Final diff --git a/examples/example-xds/build.gradle b/examples/example-xds/build.gradle index 01ef4ba9266..47151427c8a 100644 --- a/examples/example-xds/build.gradle +++ b/examples/example-xds/build.gradle @@ -22,7 +22,7 @@ targetCompatibility = 1.7 // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.41.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.41.0' // CURRENT_GRPC_VERSION def nettyTcNativeVersion = '2.0.31.Final' def protocVersion = '3.17.2' diff --git a/examples/pom.xml b/examples/pom.xml index 156b11fb7ac..a73da5d3bdf 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -6,13 +6,13 @@ jar - 1.41.0-SNAPSHOT + 1.41.0 examples https://github.com/grpc/grpc-java UTF-8 - 1.41.0-SNAPSHOT + 1.41.0 3.17.2 3.17.2