diff --git a/MAINTAINERS.md b/MAINTAINERS.md index 5ff2c5157b5..5426f83f90b 100644 --- a/MAINTAINERS.md +++ b/MAINTAINERS.md @@ -8,11 +8,10 @@ See [CONTRIBUTING.md](https://github.com/grpc/grpc-community/blob/master/CONTRIB for general contribution guidelines. ## Maintainers (in alphabetical order) -- [creamsoup](https://github.com/creamsoup), Google LLC + - [dapengzhang0](https://github.com/dapengzhang0), Google LLC - [ejona86](https://github.com/ejona86), Google LLC - [ericgribkoff](https://github.com/ericgribkoff), Google LLC -- [jiangtaoli2016](https://github.com/jiangtaoli2016), Google LLC - [ran-su](https://github.com/ran-su), Google LLC - [sanjaypujare](https://github.com/sanjaypujare), Google LLC - [sergiitk](https://github.com/sergiitk), Google LLC @@ -22,6 +21,8 @@ for general contribution guidelines. ## Emeritus Maintainers (in alphabetical order) - [carl-mastrangelo](https://github.com/carl-mastrangelo), Google LLC +- [creamsoup](https://github.com/creamsoup), Google LLC +- [jiangtaoli2016](https://github.com/jiangtaoli2016), Google LLC - [jtattermusch](https://github.com/jtattermusch), Google LLC - [louiscryan](https://github.com/louiscryan), Google LLC - [nicolasnoble](https://github.com/nicolasnoble), Google LLC diff --git a/README.md b/README.md index cde4222ba7b..e6c06bde236 100644 --- a/README.md +++ b/README.md @@ -31,8 +31,8 @@ For a guided tour, take a look at the [quick start guide](https://grpc.io/docs/languages/java/quickstart) or the more explanatory [gRPC basics](https://grpc.io/docs/languages/java/basics). -The [examples](https://github.com/grpc/grpc-java/tree/v1.39.0/examples) and the -[Android example](https://github.com/grpc/grpc-java/tree/v1.39.0/examples/android) +The [examples](https://github.com/grpc/grpc-java/tree/v1.41.0/examples) and the +[Android example](https://github.com/grpc/grpc-java/tree/v1.41.0/examples/android) are standalone projects that showcase the usage of gRPC. Download @@ -43,17 +43,17 @@ Download [the JARs][]. Or for Maven with non-Android, add to your `pom.xml`: io.grpc grpc-netty-shaded - 1.39.0 + 1.41.0 io.grpc grpc-protobuf - 1.39.0 + 1.41.0 io.grpc grpc-stub - 1.39.0 + 1.41.0 org.apache.tomcat @@ -65,23 +65,23 @@ Download [the JARs][]. Or for Maven with non-Android, add to your `pom.xml`: Or for Gradle with non-Android, add to your dependencies: ```gradle -implementation 'io.grpc:grpc-netty-shaded:1.39.0' -implementation 'io.grpc:grpc-protobuf:1.39.0' -implementation 'io.grpc:grpc-stub:1.39.0' +implementation 'io.grpc:grpc-netty-shaded:1.41.0' +implementation 'io.grpc:grpc-protobuf:1.41.0' +implementation 'io.grpc:grpc-stub:1.41.0' compileOnly 'org.apache.tomcat:annotations-api:6.0.53' // necessary for Java 9+ ``` For Android client, use `grpc-okhttp` instead of `grpc-netty-shaded` and `grpc-protobuf-lite` instead of `grpc-protobuf`: ```gradle -implementation 'io.grpc:grpc-okhttp:1.39.0' -implementation 'io.grpc:grpc-protobuf-lite:1.39.0' -implementation 'io.grpc:grpc-stub:1.39.0' +implementation 'io.grpc:grpc-okhttp:1.41.0' +implementation 'io.grpc:grpc-protobuf-lite:1.41.0' +implementation 'io.grpc:grpc-stub:1.41.0' compileOnly 'org.apache.tomcat:annotations-api:6.0.53' // necessary for Java 9+ ``` [the JARs]: -https://search.maven.org/search?q=g:io.grpc%20AND%20v:1.39.0 +https://search.maven.org/search?q=g:io.grpc%20AND%20v:1.41.0 Development snapshots are available in [Sonatypes's snapshot repository](https://oss.sonatype.org/content/repositories/snapshots/). @@ -111,9 +111,9 @@ For protobuf-based codegen integrated with the Maven build system, you can use protobuf-maven-plugin 0.6.1 - com.google.protobuf:protoc:3.17.2:exe:${os.detected.classifier} + com.google.protobuf:protoc:3.17.3:exe:${os.detected.classifier} grpc-java - io.grpc:protoc-gen-grpc-java:1.39.0:exe:${os.detected.classifier} + io.grpc:protoc-gen-grpc-java:1.41.0:exe:${os.detected.classifier} @@ -139,11 +139,11 @@ plugins { protobuf { protoc { - artifact = "com.google.protobuf:protoc:3.17.2" + artifact = "com.google.protobuf:protoc:3.17.3" } plugins { grpc { - artifact = 'io.grpc:protoc-gen-grpc-java:1.39.0' + artifact = 'io.grpc:protoc-gen-grpc-java:1.41.0' } } generateProtoTasks { @@ -172,11 +172,11 @@ plugins { protobuf { protoc { - artifact = "com.google.protobuf:protoc:3.17.2" + artifact = "com.google.protobuf:protoc:3.17.3" } plugins { grpc { - artifact = 'io.grpc:protoc-gen-grpc-java:1.39.0' + artifact = 'io.grpc:protoc-gen-grpc-java:1.41.0' } } generateProtoTasks { diff --git a/android-interop-testing/build.gradle b/android-interop-testing/build.gradle index f1c1d491233..bbf1fcfe99e 100644 --- a/android-interop-testing/build.gradle +++ b/android-interop-testing/build.gradle @@ -62,6 +62,7 @@ dependencies { project(':grpc-protobuf-lite'), project(':grpc-stub'), project(':grpc-testing'), + libraries.hdrhistogram, libraries.junit, libraries.truth, libraries.opencensus_contrib_grpc_metrics diff --git a/api/src/main/java/io/grpc/ClientStreamTracer.java b/api/src/main/java/io/grpc/ClientStreamTracer.java index 6259522487a..bb836ac82e1 100644 --- a/api/src/main/java/io/grpc/ClientStreamTracer.java +++ b/api/src/main/java/io/grpc/ClientStreamTracer.java @@ -19,7 +19,6 @@ import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.base.MoreObjects; -import io.grpc.Grpc; import javax.annotation.concurrent.ThreadSafe; /** @@ -28,6 +27,18 @@ @ExperimentalApi("https://github.com/grpc/grpc-java/issues/2861") @ThreadSafe public abstract class ClientStreamTracer extends StreamTracer { + + /** + * The stream is being created on a ready transport. + * + * @param headers the mutable initial metadata. Modifications to it will be sent to the socket but + * not be seen by client interceptors and the application. + * + * @since 1.40.0 + */ + public void streamCreated(@Grpc.TransportAttr Attributes transportAttrs, Metadata headers) { + } + /** * Headers has been sent to the socket. */ @@ -54,22 +65,6 @@ public void inboundTrailers(Metadata trailers) { * Factory class for {@link ClientStreamTracer}. */ public abstract static class Factory { - /** - * Creates a {@link ClientStreamTracer} for a new client stream. - * - * @param callOptions the effective CallOptions of the call - * @param headers the mutable headers of the stream. It can be safely mutated within this - * method. It should not be saved because it is not safe for read or write after the - * method returns. - * - * @deprecated use {@link - * #newClientStreamTracer(io.grpc.ClientStreamTracer.StreamInfo, io.grpc.Metadata)} instead. - */ - @Deprecated - public ClientStreamTracer newClientStreamTracer(CallOptions callOptions, Metadata headers) { - throw new UnsupportedOperationException("Not implemented"); - } - /** * Creates a {@link ClientStreamTracer} for a new client stream. This is called inside the * transport when it's creating the stream. @@ -81,12 +76,15 @@ public ClientStreamTracer newClientStreamTracer(CallOptions callOptions, Metadat * * @since 1.20.0 */ - @SuppressWarnings("deprecation") public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata headers) { - return newClientStreamTracer(info.getCallOptions(), headers); + throw new UnsupportedOperationException("Not implemented"); } } + /** An abstract class for internal use only. */ + @Internal + public abstract static class InternalLimitedInfoFactory extends Factory {} + /** * Information about a stream. * @@ -99,15 +97,25 @@ public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata header public static final class StreamInfo { private final Attributes transportAttrs; private final CallOptions callOptions; + private final int previousAttempts; + private final boolean isTransparentRetry; - StreamInfo(Attributes transportAttrs, CallOptions callOptions) { + StreamInfo( + Attributes transportAttrs, CallOptions callOptions, int previousAttempts, + boolean isTransparentRetry) { this.transportAttrs = checkNotNull(transportAttrs, "transportAttrs"); this.callOptions = checkNotNull(callOptions, "callOptions"); + this.previousAttempts = previousAttempts; + this.isTransparentRetry = isTransparentRetry; } /** * Returns the attributes of the transport that this stream was created on. + * + * @deprecated Use {@link ClientStreamTracer#streamCreated(Attributes, Metadata)} to handle + * the transport Attributes instead. */ + @Deprecated @Grpc.TransportAttr public Attributes getTransportAttrs() { return transportAttrs; @@ -120,16 +128,35 @@ public CallOptions getCallOptions() { return callOptions; } + /** + * Returns the number of preceding attempts for the RPC. + * + * @since 1.40.0 + */ + public int getPreviousAttempts() { + return previousAttempts; + } + + /** + * Whether the stream is a transparent retry. + * + * @since 1.40.0 + */ + public boolean isTransparentRetry() { + return isTransparentRetry; + } + /** * Converts this StreamInfo into a new Builder. * * @since 1.21.0 */ public Builder toBuilder() { - Builder builder = new Builder(); - builder.setTransportAttrs(transportAttrs); - builder.setCallOptions(callOptions); - return builder; + return new Builder() + .setCallOptions(callOptions) + .setTransportAttrs(transportAttrs) + .setPreviousAttempts(previousAttempts) + .setIsTransparentRetry(isTransparentRetry); } /** @@ -146,6 +173,8 @@ public String toString() { return MoreObjects.toStringHelper(this) .add("transportAttrs", transportAttrs) .add("callOptions", callOptions) + .add("previousAttempts", previousAttempts) + .add("isTransparentRetry", isTransparentRetry) .toString(); } @@ -157,6 +186,8 @@ public String toString() { public static final class Builder { private Attributes transportAttrs = Attributes.EMPTY; private CallOptions callOptions = CallOptions.DEFAULT; + private int previousAttempts; + private boolean isTransparentRetry; Builder() { } @@ -164,9 +195,12 @@ public static final class Builder { /** * Sets the attributes of the transport that this stream was created on. This field is * optional. + * + * @deprecated Use {@link ClientStreamTracer#streamCreated(Attributes, Metadata)} to handle + * the transport Attributes instead. */ - @Grpc.TransportAttr - public Builder setTransportAttrs(Attributes transportAttrs) { + @Deprecated + public Builder setTransportAttrs(@Grpc.TransportAttr Attributes transportAttrs) { this.transportAttrs = checkNotNull(transportAttrs, "transportAttrs cannot be null"); return this; } @@ -179,11 +213,31 @@ public Builder setCallOptions(CallOptions callOptions) { return this; } + /** + * Set the number of preceding attempts of the RPC. + * + * @since 1.40.0 + */ + public Builder setPreviousAttempts(int previousAttempts) { + this.previousAttempts = previousAttempts; + return this; + } + + /** + * Sets whether the stream is a transparent retry. + * + * @since 1.40.0 + */ + public Builder setIsTransparentRetry(boolean isTransparentRetry) { + this.isTransparentRetry = isTransparentRetry; + return this; + } + /** * Builds a new StreamInfo. */ public StreamInfo build() { - return new StreamInfo(transportAttrs, callOptions); + return new StreamInfo(transportAttrs, callOptions, previousAttempts, isTransparentRetry); } } } diff --git a/api/src/main/java/io/grpc/ManagedChannelBuilder.java b/api/src/main/java/io/grpc/ManagedChannelBuilder.java index e4a4611541d..98b22807ccc 100644 --- a/api/src/main/java/io/grpc/ManagedChannelBuilder.java +++ b/api/src/main/java/io/grpc/ManagedChannelBuilder.java @@ -467,7 +467,6 @@ public T perRpcBufferLimit(long bytes) { * @return this * @since 1.11.0 */ - @ExperimentalApi("https://github.com/grpc/grpc-java/issues/3982") public T disableRetry() { throw new UnsupportedOperationException(); } @@ -479,13 +478,9 @@ public T disableRetry() { * transparent retries, which are safe for non-idempotent RPCs. Service config is ideally provided * by the name resolver, but may also be specified via {@link #defaultServiceConfig}. * - *

For the current release, this method may have a side effect that disables Census stats and - * tracing. - * * @return this * @since 1.11.0 */ - @ExperimentalApi("https://github.com/grpc/grpc-java/issues/3982") public T enableRetry() { throw new UnsupportedOperationException(); } diff --git a/api/src/main/java/io/grpc/NameResolver.java b/api/src/main/java/io/grpc/NameResolver.java index cd3a137dfca..f4c05aa6a64 100644 --- a/api/src/main/java/io/grpc/NameResolver.java +++ b/api/src/main/java/io/grpc/NameResolver.java @@ -50,7 +50,9 @@ *

Implementations don't need to be thread-safe. All methods are guaranteed to * be called sequentially. Additionally, all methods that have side-effects, i.e., * {@link #start(Listener2)}, {@link #shutdown} and {@link #refresh} are called from the same - * {@link SynchronizationContext} as returned by {@link Args#getSynchronizationContext}. + * {@link SynchronizationContext} as returned by {@link Args#getSynchronizationContext}. Do + * not block within the synchronization context; blocking I/O and time-consuming tasks + * should be offloaded to a separate thread, generally {@link Args#getOffloadExecutor}. * * @since 1.0.0 */ diff --git a/api/src/main/java/io/grpc/ProxyDetector.java b/api/src/main/java/io/grpc/ProxyDetector.java index 5202516bca7..7c04329c079 100644 --- a/api/src/main/java/io/grpc/ProxyDetector.java +++ b/api/src/main/java/io/grpc/ProxyDetector.java @@ -32,7 +32,7 @@ * underlying transport need to work together. * *

The {@link NameResolver} should invoke the {@link ProxyDetector} retrieved from the {@link - * NameResolver.Helper#getProxyDetector}, and pass the returned {@link ProxiedSocketAddress} to + * NameResolver.Args#getProxyDetector}, and pass the returned {@link ProxiedSocketAddress} to * {@link NameResolver.Listener#onAddresses}. The DNS name resolver shipped with gRPC is already * doing so. * diff --git a/api/src/main/java/io/grpc/Server.java b/api/src/main/java/io/grpc/Server.java index 781455b18ee..31e0a6478ed 100644 --- a/api/src/main/java/io/grpc/Server.java +++ b/api/src/main/java/io/grpc/Server.java @@ -43,7 +43,7 @@ public abstract class Server { * listening socket(s). * * @return {@code this} object - * @throws IllegalStateException if already started + * @throws IllegalStateException if already started or shut down * @throws IOException if unable to bind * @since 1.0.0 */ @@ -119,6 +119,9 @@ public List getMutableServices() { * {@link #awaitTermination()} or {@link #awaitTermination(long, TimeUnit)} needs to be called to * wait for existing calls to finish. * + *

Calling this method before {@code start()} will shut down and terminate the server like + * normal, but prevents starting the server in the future. + * * @return {@code this} object * @since 1.0.0 */ @@ -130,6 +133,9 @@ public List getMutableServices() { * return {@code false} immediately after this method returns. After this call returns, this * server has released the listening socket(s) and may be reused by another server. * + *

Calling this method before {@code start()} will shut down and terminate the server like + * normal, but prevents starting the server in the future. + * * @return {@code this} object * @since 1.0.0 */ @@ -157,6 +163,9 @@ public List getMutableServices() { /** * Waits for the server to become terminated, giving up if the timeout is reached. * + *

Calling this method before {@code start()} or {@code shutdown()} is permitted and does not + * change its behavior. + * * @return whether the server is terminated, as would be done by {@link #isTerminated()}. */ public abstract boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException; @@ -164,6 +173,9 @@ public List getMutableServices() { /** * Waits for the server to become terminated. * + *

Calling this method before {@code start()} or {@code shutdown()} is permitted and does not + * change its behavior. + * * @since 1.0.0 */ public abstract void awaitTermination() throws InterruptedException; diff --git a/api/src/test/java/io/grpc/CallOptionsTest.java b/api/src/test/java/io/grpc/CallOptionsTest.java index 31861306891..0bc0d357358 100644 --- a/api/src/test/java/io/grpc/CallOptionsTest.java +++ b/api/src/test/java/io/grpc/CallOptionsTest.java @@ -30,6 +30,7 @@ import static org.mockito.Mockito.mock; import com.google.common.base.Objects; +import io.grpc.ClientStreamTracer.StreamInfo; import io.grpc.internal.SerializingExecutor; import java.util.concurrent.Executor; import java.util.concurrent.TimeUnit; @@ -271,7 +272,7 @@ public void increment(long period, TimeUnit unit) { } } - private static class FakeTracerFactory extends ClientStreamTracer.Factory { + private static class FakeTracerFactory extends ClientStreamTracer.InternalLimitedInfoFactory { final String name; FakeTracerFactory(String name) { @@ -279,8 +280,7 @@ private static class FakeTracerFactory extends ClientStreamTracer.Factory { } @Override - public ClientStreamTracer newClientStreamTracer( - ClientStreamTracer.StreamInfo info, Metadata headers) { + public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata headers) { return new ClientStreamTracer() {}; } diff --git a/binder/build.gradle b/binder/build.gradle index 537c23a0092..c5bb9885623 100644 --- a/binder/build.gradle +++ b/binder/build.gradle @@ -13,6 +13,8 @@ android { srcDirs += "${projectDir}/../core/src/test/java/" setIncludes(["io/grpc/internal/FakeClock.java", "io/grpc/binder/**"]) + exclude 'io/grpc/binder/ServerSecurityPolicyTest.java' + exclude 'io/grpc/binder/SecurityPoliciesTest.java' } } androidTest { diff --git a/binder/src/androidTest/java/io/grpc/binder/internal/BinderClientTransportTest.java b/binder/src/androidTest/java/io/grpc/binder/internal/BinderClientTransportTest.java index 0d3c3bf4b51..b99114bb501 100644 --- a/binder/src/androidTest/java/io/grpc/binder/internal/BinderClientTransportTest.java +++ b/binder/src/androidTest/java/io/grpc/binder/internal/BinderClientTransportTest.java @@ -30,6 +30,7 @@ import com.google.protobuf.Empty; import io.grpc.Attributes; import io.grpc.CallOptions; +import io.grpc.ClientStreamTracer; import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.Server; @@ -68,6 +69,10 @@ */ @RunWith(AndroidJUnit4.class) public final class BinderClientTransportTest { + private static final ClientStreamTracer[] tracers = new ClientStreamTracer[] { + new ClientStreamTracer() {} + }; + private final Context appContext = ApplicationProvider.getApplicationContext(); MethodDescriptor.Marshaller marshaller = @@ -155,7 +160,8 @@ public void tearDown() throws Exception { @Test public void testShutdownBeforeStreamStart_b153326034() throws Exception { - ClientStream stream = transport.newStream(methodDesc, new Metadata(), CallOptions.DEFAULT); + ClientStream stream = transport.newStream( + methodDesc, new Metadata(), CallOptions.DEFAULT, tracers); transport.shutdownNow(Status.UNKNOWN.withDescription("reasons")); // This shouldn't throw an exception. @@ -165,7 +171,7 @@ public void testShutdownBeforeStreamStart_b153326034() throws Exception { @Test public void testRequestWhileStreamIsWaitingOnCall_b154088869() throws Exception { ClientStream stream = - transport.newStream(streamingMethodDesc, new Metadata(), CallOptions.DEFAULT); + transport.newStream(streamingMethodDesc, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(streamListener); stream.writeMessage(marshaller.stream(Empty.getDefaultInstance())); @@ -183,7 +189,7 @@ public void testRequestWhileStreamIsWaitingOnCall_b154088869() throws Exception @Test public void testTransactionForDiscardedCall_b155244043() throws Exception { ClientStream stream = - transport.newStream(streamingMethodDesc, new Metadata(), CallOptions.DEFAULT); + transport.newStream(streamingMethodDesc, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(streamListener); stream.writeMessage(marshaller.stream(Empty.getDefaultInstance())); @@ -201,7 +207,7 @@ public void testTransactionForDiscardedCall_b155244043() throws Exception { @Test public void testBadTransactionStreamThroughput_b163053382() throws Exception { ClientStream stream = - transport.newStream(streamingMethodDesc, new Metadata(), CallOptions.DEFAULT); + transport.newStream(streamingMethodDesc, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(streamListener); stream.writeMessage(marshaller.stream(Empty.getDefaultInstance())); @@ -220,7 +226,7 @@ public void testBadTransactionStreamThroughput_b163053382() throws Exception { @Test public void testMessageProducerClosedAfterStream_b169313545() { ClientStream stream = - transport.newStream(methodDesc, new Metadata(), CallOptions.DEFAULT); + transport.newStream(methodDesc, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(streamListener); stream.writeMessage(marshaller.stream(Empty.getDefaultInstance())); diff --git a/binder/src/main/java/io/grpc/binder/internal/BinderTransport.java b/binder/src/main/java/io/grpc/binder/internal/BinderTransport.java index 04070ddfcef..b132844069c 100644 --- a/binder/src/main/java/io/grpc/binder/internal/BinderTransport.java +++ b/binder/src/main/java/io/grpc/binder/internal/BinderTransport.java @@ -32,6 +32,7 @@ import com.google.common.util.concurrent.ListenableFuture; import io.grpc.Attributes; import io.grpc.CallOptions; +import io.grpc.ClientStreamTracer; import io.grpc.Grpc; import io.grpc.Internal; import io.grpc.InternalChannelz.SocketStats; @@ -632,28 +633,28 @@ public synchronized Runnable start(ManagedClientTransport.Listener clientTranspo public synchronized ClientStream newStream( final MethodDescriptor method, final Metadata headers, - final CallOptions callOptions) { + final CallOptions callOptions, + ClientStreamTracer[] tracers) { if (isShutdown()) { - return newFailingClientStream(shutdownStatus, callOptions, attributes, headers); + return newFailingClientStream(shutdownStatus, attributes, headers, tracers); } else { int callId = latestCallId++; if (latestCallId == LAST_CALL_ID) { latestCallId = FIRST_CALL_ID; } + StatsTraceContext statsTraceContext = + StatsTraceContext.newClientContext(tracers, attributes, headers); Inbound.ClientInbound inbound = new Inbound.ClientInbound( this, attributes, callId, GrpcUtil.shouldBeCountedForInUse(callOptions)); if (ongoingCalls.putIfAbsent(callId, inbound) != null) { Status failure = Status.INTERNAL.withDescription("Clashing call IDs"); shutdownInternal(failure, true); - return newFailingClientStream(failure, callOptions, attributes, headers); + return newFailingClientStream(failure, attributes, headers, tracers); } else { if (inbound.countsForInUse() && numInUseStreams.getAndIncrement() == 0) { clientTransportListener.transportInUse(true); } - StatsTraceContext statsTraceContext = - StatsTraceContext.newClientContext(callOptions, attributes, headers); - Outbound.ClientOutbound outbound = new Outbound.ClientOutbound(this, callId, method, headers, statsTraceContext); if (method.getType().clientSendsOneMessage()) { @@ -763,12 +764,12 @@ protected void handlePingResponse(Parcel parcel) { } private static ClientStream newFailingClientStream( - Status failure, CallOptions callOptions, Attributes attributes, Metadata headers) { + Status failure, Attributes attributes, Metadata headers, + ClientStreamTracer[] tracers) { StatsTraceContext statsTraceContext = - StatsTraceContext.newClientContext(callOptions, attributes, headers); + StatsTraceContext.newClientContext(tracers, attributes, headers); statsTraceContext.clientOutboundHeaders(); - statsTraceContext.streamClosed(failure); - return new FailingClientStream(failure); + return new FailingClientStream(failure, tracers); } private static InternalLogId buildLogId( diff --git a/build.gradle b/build.gradle index 94590cfa97c..001ebd148c0 100644 --- a/build.gradle +++ b/build.gradle @@ -18,7 +18,7 @@ subprojects { apply plugin: "net.ltgt.errorprone" group = "io.grpc" - version = "1.40.0-SNAPSHOT" // CURRENT_GRPC_VERSION + version = "1.41.0" // CURRENT_GRPC_VERSION repositories { maven { // The google mirror is less flaky than mavenCentral() @@ -145,9 +145,9 @@ subprojects { animalsniffer_annotations: "org.codehaus.mojo:animal-sniffer-annotations:1.19", autovalue: "com.google.auto.value:auto-value:${autovalueVersion}", autovalue_annotation: "com.google.auto.value:auto-value-annotations:${autovalueVersion}", - errorprone: "com.google.errorprone:error_prone_annotations:2.4.0", - cronet_api: 'org.chromium.net:cronet-api:76.3809.111', - cronet_embedded: 'org.chromium.net:cronet-embedded:66.3359.158', + errorprone: "com.google.errorprone:error_prone_annotations:2.9.0", + cronet_api: 'org.chromium.net:cronet-api:92.4515.131', + cronet_embedded: 'org.chromium.net:cronet-embedded:92.4515.131', gson: "com.google.code.gson:gson:2.8.6", guava: "com.google.guava:guava:${guavaVersion}", javax_annotation: 'org.apache.tomcat:annotations-api:6.0.53', @@ -269,11 +269,7 @@ subprojects { jar.manifest { attributes('Implementation-Title': name, - 'Implementation-Version': version, - 'Built-By': System.getProperty('user.name'), - 'Built-JDK': System.getProperty('java.version'), - 'Source-Compatibility': sourceCompatibility, - 'Target-Compatibility': targetCompatibility) + 'Implementation-Version': version) } javadoc.options { diff --git a/buildscripts/kokoro/android-interop.sh b/buildscripts/kokoro/android-interop.sh index 8a8a2bc7bc5..5d9774bb12f 100755 --- a/buildscripts/kokoro/android-interop.sh +++ b/buildscripts/kokoro/android-interop.sh @@ -38,18 +38,3 @@ gcloud firebase test android run \ --device model=Nexus6P,version=23,locale=en,orientation=portrait \ --device model=Nexus6,version=22,locale=en,orientation=portrait \ --device model=Nexus6,version=21,locale=en,orientation=portrait - -# Build and run binder transport instrumentation tests on Firebase Test Lab -cd ../binder -../gradlew assembleDebug -../gradlew assembleDebugAndroidTest -gcloud firebase test android run \ - --type instrumentation \ - --test build/outputs/apk/androidTest/debug/grpc-binder-debug-androidTest.apk \ - --device model=Nexus6P,version=27,locale=en,orientation=portrait \ - --device model=Nexus6P,version=26,locale=en,orientation=portrait \ - --device model=Nexus6P,version=25,locale=en,orientation=portrait \ - --device model=Nexus6P,version=24,locale=en,orientation=portrait \ - --device model=Nexus6P,version=23,locale=en,orientation=portrait \ - --device model=Nexus6,version=22,locale=en,orientation=portrait \ - --device model=Nexus6,version=21,locale=en,orientation=portrait diff --git a/buildscripts/kokoro/android.sh b/buildscripts/kokoro/android.sh index 50337ef878b..7b9e7f53885 100755 --- a/buildscripts/kokoro/android.sh +++ b/buildscripts/kokoro/android.sh @@ -18,8 +18,9 @@ export OS_NAME=$(uname) cat <> gradle.properties # defaults to -Xmx512m -XX:MaxMetaspaceSize=256m # https://docs.gradle.org/current/userguide/build_environment.html#sec:configuring_jvm_memory -# Increased due to java.lang.OutOfMemoryError: Metaspace failures -org.gradle.jvmargs=-Xmx512m -XX:MaxMetaspaceSize=512m +# Increased due to java.lang.OutOfMemoryError: Metaspace failures, "JVM heap +# space is exhausted", and to increase build speed +org.gradle.jvmargs=-Xmx2048m -XX:MaxMetaspaceSize=512m EOF echo y | ${ANDROID_HOME}/tools/bin/sdkmanager "build-tools;28.0.3" @@ -31,6 +32,8 @@ buildscripts/make_dependencies.sh :grpc-android-interop-testing:build \ :grpc-android:build \ :grpc-cronet:build \ + :grpc-binder:build \ + assembleAndroidTest \ publishToMavenLocal if [[ ! -z $(git status --porcelain) ]]; then diff --git a/buildscripts/kokoro/xds-k8s.sh b/buildscripts/kokoro/xds-k8s.sh index cafd884ccaf..0a234f2c6ef 100755 --- a/buildscripts/kokoro/xds-k8s.sh +++ b/buildscripts/kokoro/xds-k8s.sh @@ -54,10 +54,16 @@ build_test_app_docker_images() { cp -v "${docker_dir}/"*.Dockerfile "${build_dir}" cp -v "${docker_dir}/"*.properties "${build_dir}" cp -rv "${SRC_DIR}/${BUILD_APP_PATH}" "${build_dir}" + # Pick a branch name for the built image + if [[ -n $KOKORO_JOB_NAME ]]; then + branch_name=$(echo "$KOKORO_JOB_NAME" | sed -E 's|^grpc/java/([^/]+)/.*|\1|') + else + branch_name='experimental' + fi # Run Google Cloud Build gcloud builds submit "${build_dir}" \ --config "${docker_dir}/cloudbuild.yaml" \ - --substitutions "_SERVER_IMAGE_NAME=${SERVER_IMAGE_NAME},_CLIENT_IMAGE_NAME=${CLIENT_IMAGE_NAME},COMMIT_SHA=${GIT_COMMIT}" + --substitutions "_SERVER_IMAGE_NAME=${SERVER_IMAGE_NAME},_CLIENT_IMAGE_NAME=${CLIENT_IMAGE_NAME},COMMIT_SHA=${GIT_COMMIT},BRANCH_NAME=${branch_name}" # TODO(sergiitk): extra "cosmetic" tags for versioned branches, e.g. v1.34.x # TODO(sergiitk): do this when adding support for custom configs per version } diff --git a/buildscripts/kokoro/xds_url_map.cfg b/buildscripts/kokoro/xds_url_map.cfg index 36ff8398b0c..4b5be84f880 100644 --- a/buildscripts/kokoro/xds_url_map.cfg +++ b/buildscripts/kokoro/xds_url_map.cfg @@ -2,7 +2,7 @@ # Location of the continuous shell script in repository. build_file: "grpc-java/buildscripts/kokoro/xds_url_map.sh" -timeout_mins: 60 +timeout_mins: 90 action { define_artifacts { diff --git a/buildscripts/kokoro/xds_url_map.sh b/buildscripts/kokoro/xds_url_map.sh index cbb1552835b..d8487582980 100755 --- a/buildscripts/kokoro/xds_url_map.sh +++ b/buildscripts/kokoro/xds_url_map.sh @@ -4,8 +4,8 @@ set -eo pipefail # Constants readonly GITHUB_REPOSITORY_NAME="grpc-java" # GKE Cluster -readonly GKE_CLUSTER_NAME="interop-test-psm-sec-v2-us-central1-a" -readonly GKE_CLUSTER_ZONE="us-central1-a" +readonly GKE_CLUSTER_NAME="interop-test-psm-basic" +readonly GKE_CLUSTER_ZONE="us-central1-c" ## xDS test client Docker images readonly CLIENT_IMAGE_NAME="gcr.io/grpc-testing/xds-interop/java-client" readonly FORCE_IMAGE_BUILD="${FORCE_IMAGE_BUILD:-0}" @@ -54,7 +54,7 @@ build_test_app_docker_images() { # Run Google Cloud Build gcloud builds submit "${build_dir}" \ --config "${docker_dir}/cloudbuild.yaml" \ - --substitutions "_CLIENT_IMAGE_NAME=${CLIENT_IMAGE_NAME},COMMIT_SHA=${GIT_COMMIT}" + --substitutions "_CLIENT_IMAGE_NAME=${CLIENT_IMAGE_NAME},COMMIT_SHA=${GIT_COMMIT},BRANCH_NAME=experimental" # TODO(sergiitk): extra "cosmetic" tags for versioned branches, e.g. v1.34.x # TODO(sergiitk): do this when adding support for custom configs per version } diff --git a/buildscripts/xds-k8s/cloudbuild.yaml b/buildscripts/xds-k8s/cloudbuild.yaml index 03c57489214..577ed73ce58 100644 --- a/buildscripts/xds-k8s/cloudbuild.yaml +++ b/buildscripts/xds-k8s/cloudbuild.yaml @@ -3,6 +3,7 @@ steps: args: - 'build' - '--tag=${_SERVER_IMAGE_NAME}:${COMMIT_SHA}' + - '--tag=${_SERVER_IMAGE_NAME}:${BRANCH_NAME}' - '--file=test-server.Dockerfile' - '.' @@ -10,6 +11,7 @@ steps: args: - 'build' - '--tag=${_CLIENT_IMAGE_NAME}:${COMMIT_SHA}' + - '--tag=${_CLIENT_IMAGE_NAME}:${BRANCH_NAME}' - '--file=test-client.Dockerfile' - '.' @@ -19,4 +21,6 @@ substitutions: images: - '${_SERVER_IMAGE_NAME}:${COMMIT_SHA}' + - '${_SERVER_IMAGE_NAME}:${BRANCH_NAME}' - '${_CLIENT_IMAGE_NAME}:${COMMIT_SHA}' + - '${_CLIENT_IMAGE_NAME}:${BRANCH_NAME}' diff --git a/census/src/main/java/io/grpc/census/CensusStatsModule.java b/census/src/main/java/io/grpc/census/CensusStatsModule.java index d625a6f5c6f..de860d0854c 100644 --- a/census/src/main/java/io/grpc/census/CensusStatsModule.java +++ b/census/src/main/java/io/grpc/census/CensusStatsModule.java @@ -17,26 +17,30 @@ package io.grpc.census; import static com.google.common.base.Preconditions.checkNotNull; -import static com.google.common.base.Preconditions.checkState; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Stopwatch; import com.google.common.base.Supplier; +import io.grpc.Attributes; import io.grpc.CallOptions; import io.grpc.Channel; import io.grpc.ClientCall; import io.grpc.ClientInterceptor; import io.grpc.ClientStreamTracer; +import io.grpc.ClientStreamTracer.StreamInfo; import io.grpc.Context; +import io.grpc.Deadline; import io.grpc.ForwardingClientCall.SimpleForwardingClientCall; import io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener; import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.ServerStreamTracer; import io.grpc.Status; +import io.grpc.Status.Code; import io.grpc.StreamTracer; import io.grpc.census.internal.DeprecatedCensusConstants; import io.opencensus.contrib.grpc.metrics.RpcMeasureConstants; +import io.opencensus.stats.Measure; import io.opencensus.stats.Measure.MeasureDouble; import io.opencensus.stats.Measure.MeasureLong; import io.opencensus.stats.MeasureMap; @@ -50,19 +54,22 @@ import io.opencensus.tags.propagation.TagContextSerializationException; import io.opencensus.tags.unsafe.ContextUtils; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicLongFieldUpdater; -import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; +import javax.annotation.concurrent.GuardedBy; /** * Provides factories for {@link StreamTracer} that records stats to Census. * - *

On the client-side, a factory is created for each call, because ClientCall starts earlier than - * the ClientStream, and in some cases may even not create a ClientStream at all. Therefore, it's - * the factory that reports the summary to Census. + *

On the client-side, a factory is created for each call, and the factory creates a stream + * tracer for each attempt. If there is no stream created when the call is ended, we still create a + * tracer. It's the tracer that reports per-attempt stats, and the factory that reports the stats + * of the overall RPC, such as RETRIES_PER_CALL, to Census. * *

On the server-side, there is only one ServerStream per each ServerCall, and ServerStream * starts earlier than the ServerCall. Therefore, only one tracer is created per stream/call and @@ -138,15 +145,6 @@ public TagContext parseBytes(byte[] serialized) { }); } - /** - * Creates a {@link ClientCallTracer} for a new call. - */ - @VisibleForTesting - ClientCallTracer newClientCallTracer( - TagContext parentCtx, String fullMethodName) { - return new ClientCallTracer(this, parentCtx, fullMethodName); - } - /** * Returns the server tracer factory. */ @@ -176,7 +174,6 @@ private void recordRealTimeMetric(TagContext ctx, MeasureLong measure, long valu } private static final class ClientTracer extends ClientStreamTracer { - @Nullable private static final AtomicLongFieldUpdater outboundMessageCountUpdater; @Nullable private static final AtomicLongFieldUpdater inboundMessageCountUpdater; @Nullable private static final AtomicLongFieldUpdater outboundWireSizeUpdater; @@ -230,19 +227,41 @@ private static final class ClientTracer extends ClientStreamTracer { inboundUncompressedSizeUpdater = tmpInboundUncompressedSizeUpdater; } - private final CensusStatsModule module; - private final TagContext startCtx; - + final Stopwatch stopwatch; + final CallAttemptsTracerFactory attemptsState; + final AtomicBoolean inboundReceivedOrClosed = new AtomicBoolean(); + final CensusStatsModule module; + final TagContext parentCtx; + final TagContext startCtx; + final StreamInfo info; volatile long outboundMessageCount; volatile long inboundMessageCount; volatile long outboundWireSize; volatile long inboundWireSize; volatile long outboundUncompressedSize; volatile long inboundUncompressedSize; + long roundtripNanos; + Code statusCode; + + ClientTracer( + CallAttemptsTracerFactory attemptsState, CensusStatsModule module, TagContext parentCtx, + TagContext startCtx, StreamInfo info) { + this.attemptsState = attemptsState; + this.module = module; + this.parentCtx = parentCtx; + this.startCtx = startCtx; + this.info = info; + this.stopwatch = module.stopwatchSupplier.get().start(); + } - ClientTracer(CensusStatsModule module, TagContext startCtx) { - this.module = checkNotNull(module, "module"); - this.startCtx = checkNotNull(startCtx, "startCtx"); + @Override + public void streamCreated(Attributes transportAttrs, Metadata headers) { + if (module.propagateTags) { + headers.discardAll(module.statsHeader); + if (!module.tagger.empty().equals(parentCtx)) { + headers.put(module.statsHeader, parentCtx); + } + } } @Override @@ -292,6 +311,11 @@ public void inboundUncompressedSize(long bytes) { @Override @SuppressWarnings("NonAtomicVolatileUpdate") public void inboundMessage(int seqNo) { + if (inboundReceivedOrClosed.compareAndSet(false, true)) { + // Because inboundUncompressedSize() might be called after streamClosed(), + // we will report stats in callEnded(). Note that this attempt is already committed. + attemptsState.inboundMetricTracer = this; + } if (inboundMessageCountUpdater != null) { inboundMessageCountUpdater.getAndIncrement(this); } else { @@ -312,55 +336,109 @@ public void outboundMessage(int seqNo) { module.recordRealTimeMetric( startCtx, RpcMeasureConstants.GRPC_CLIENT_SENT_MESSAGES_PER_METHOD, 1); } - } - @VisibleForTesting - static final class ClientCallTracer extends ClientStreamTracer.Factory { - @Nullable - private static final AtomicReferenceFieldUpdater - streamTracerUpdater; - - @Nullable private static final AtomicIntegerFieldUpdater callEndedUpdater; + @Override + public void streamClosed(Status status) { + attemptsState.attemptEnded(); + stopwatch.stop(); + roundtripNanos = stopwatch.elapsed(TimeUnit.NANOSECONDS); + Deadline deadline = info.getCallOptions().getDeadline(); + statusCode = status.getCode(); + if (statusCode == Status.Code.CANCELLED && deadline != null) { + // When the server's deadline expires, it can only reset the stream with CANCEL and no + // description. Since our timer may be delayed in firing, we double-check the deadline and + // turn the failure into the likely more helpful DEADLINE_EXCEEDED status. + if (deadline.isExpired()) { + statusCode = Code.DEADLINE_EXCEEDED; + } + } + if (inboundReceivedOrClosed.compareAndSet(false, true)) { + if (module.recordFinishedRpcs) { + // Stream is closed early. So no need to record metrics for any inbound events after this + // point. + recordFinishedAttempt(); + } + } // Otherwise will report stats in callEnded() to guarantee all inbound metrics are recorded. + } - /** - * When using Atomic*FieldUpdater, some Samsung Android 5.0.x devices encounter a bug in their - * JDK reflection API that triggers a NoSuchFieldException. When this occurs, we fallback to - * (potentially racy) direct updates of the volatile variables. - */ - static { - AtomicReferenceFieldUpdater tmpStreamTracerUpdater; - AtomicIntegerFieldUpdater tmpCallEndedUpdater; - try { - tmpStreamTracerUpdater = - AtomicReferenceFieldUpdater.newUpdater( - ClientCallTracer.class, ClientTracer.class, "streamTracer"); - tmpCallEndedUpdater = - AtomicIntegerFieldUpdater.newUpdater(ClientCallTracer.class, "callEnded"); - } catch (Throwable t) { - logger.log(Level.SEVERE, "Creating atomic field updaters failed", t); - tmpStreamTracerUpdater = null; - tmpCallEndedUpdater = null; + void recordFinishedAttempt() { + MeasureMap measureMap = module.statsRecorder.newMeasureMap() + // TODO(songya): remove the deprecated measure constants once they are completed removed. + .put(DeprecatedCensusConstants.RPC_CLIENT_FINISHED_COUNT, 1) + // The latency is double value + .put( + DeprecatedCensusConstants.RPC_CLIENT_ROUNDTRIP_LATENCY, + roundtripNanos / NANOS_PER_MILLI) + .put(DeprecatedCensusConstants.RPC_CLIENT_REQUEST_COUNT, outboundMessageCount) + .put(DeprecatedCensusConstants.RPC_CLIENT_RESPONSE_COUNT, inboundMessageCount) + .put(DeprecatedCensusConstants.RPC_CLIENT_REQUEST_BYTES, outboundWireSize) + .put(DeprecatedCensusConstants.RPC_CLIENT_RESPONSE_BYTES, inboundWireSize) + .put( + DeprecatedCensusConstants.RPC_CLIENT_UNCOMPRESSED_REQUEST_BYTES, + outboundUncompressedSize) + .put( + DeprecatedCensusConstants.RPC_CLIENT_UNCOMPRESSED_RESPONSE_BYTES, + inboundUncompressedSize); + if (statusCode != Code.OK) { + measureMap.put(DeprecatedCensusConstants.RPC_CLIENT_ERROR_COUNT, 1); } - streamTracerUpdater = tmpStreamTracerUpdater; - callEndedUpdater = tmpCallEndedUpdater; + TagValue statusTag = TagValue.create(statusCode.toString()); + measureMap.record( + module + .tagger + .toBuilder(startCtx) + .putLocal(RpcMeasureConstants.GRPC_CLIENT_STATUS, statusTag) + .build()); } + } + @VisibleForTesting + static final class CallAttemptsTracerFactory extends + ClientStreamTracer.InternalLimitedInfoFactory { + static final MeasureLong RETRIES_PER_CALL = + Measure.MeasureLong.create( + "grpc.io/client/retries_per_call", "Number of retries per call", "1"); + static final MeasureLong TRANSPARENT_RETRIES_PER_CALL = + Measure.MeasureLong.create( + "grpc.io/client/transparent_retries_per_call", "Transparent retries per call", "1"); + static final MeasureDouble RETRY_DELAY_PER_CALL = + Measure.MeasureDouble.create( + "grpc.io/client/retry_delay_per_call", "Retry delay per call", "ms"); + + ClientTracer inboundMetricTracer; private final CensusStatsModule module; private final Stopwatch stopwatch; - private volatile ClientTracer streamTracer; - private volatile int callEnded; + @GuardedBy("lock") + private boolean callEnded; private final TagContext parentCtx; private final TagContext startCtx; - - ClientCallTracer(CensusStatsModule module, TagContext parentCtx, String fullMethodName) { - this.module = checkNotNull(module); - this.parentCtx = checkNotNull(parentCtx); + private final String fullMethodName; + + // TODO(zdapeng): optimize memory allocation using AtomicFieldUpdater. + private final AtomicLong attemptsPerCall = new AtomicLong(); + private final AtomicLong transparentRetriesPerCall = new AtomicLong(); + // write happens before read + private Status status; + private final Object lock = new Object(); + // write @GuardedBy("lock") and happens before read + private long retryDelayNanos; + @GuardedBy("lock") + private int activeStreams; + @GuardedBy("lock") + private boolean finishedCallToBeRecorded; + + CallAttemptsTracerFactory( + CensusStatsModule module, TagContext parentCtx, String fullMethodName) { + this.module = checkNotNull(module, "module"); + this.parentCtx = checkNotNull(parentCtx, "parentCtx"); + this.fullMethodName = checkNotNull(fullMethodName, "fullMethodName"); + this.stopwatch = module.stopwatchSupplier.get(); TagValue methodTag = TagValue.create(fullMethodName); - this.startCtx = module.tagger.toBuilder(parentCtx) + startCtx = module.tagger.toBuilder(parentCtx) .putLocal(RpcMeasureConstants.GRPC_CLIENT_METHOD, methodTag) .build(); - this.stopwatch = module.stopwatchSupplier.get().start(); if (module.recordStartedRpcs) { + // Record here in case newClientStreamTracer() would never be called. module.statsRecorder.newMeasureMap() .put(DeprecatedCensusConstants.RPC_CLIENT_STARTED_COUNT, 1) .record(startCtx); @@ -368,81 +446,97 @@ static final class ClientCallTracer extends ClientStreamTracer.Factory { } @Override - public ClientStreamTracer newClientStreamTracer( - ClientStreamTracer.StreamInfo info, Metadata headers) { - ClientTracer tracer = new ClientTracer(module, startCtx); - // TODO(zhangkun83): Once retry or hedging is implemented, a ClientCall may start more than - // one streams. We will need to update this file to support them. - if (streamTracerUpdater != null) { - checkState( - streamTracerUpdater.compareAndSet(this, null, tracer), - "Are you creating multiple streams per call? This class doesn't yet support this case"); + public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata metadata) { + synchronized (lock) { + if (finishedCallToBeRecorded) { + // This can be the case when the called is cancelled but a retry attempt is created. + return new ClientStreamTracer() {}; + } + if (++activeStreams == 1 && stopwatch.isRunning()) { + stopwatch.stop(); + retryDelayNanos = stopwatch.elapsed(TimeUnit.NANOSECONDS); + } + } + if (module.recordStartedRpcs && attemptsPerCall.get() > 0) { + module.statsRecorder.newMeasureMap() + .put(DeprecatedCensusConstants.RPC_CLIENT_STARTED_COUNT, 1) + .record(startCtx); + } + if (info.isTransparentRetry()) { + transparentRetriesPerCall.incrementAndGet(); } else { - checkState( - streamTracer == null, - "Are you creating multiple streams per call? This class doesn't yet support this case"); - streamTracer = tracer; + attemptsPerCall.incrementAndGet(); } - if (module.propagateTags) { - headers.discardAll(module.statsHeader); - if (!module.tagger.empty().equals(parentCtx)) { - headers.put(module.statsHeader, parentCtx); + return new ClientTracer(this, module, parentCtx, startCtx, info); + } + + // Called whenever each attempt is ended. + void attemptEnded() { + if (!module.recordFinishedRpcs) { + return; + } + boolean shouldRecordFinishedCall = false; + synchronized (lock) { + if (--activeStreams == 0) { + stopwatch.start(); + if (callEnded && !finishedCallToBeRecorded) { + shouldRecordFinishedCall = true; + finishedCallToBeRecorded = true; + } } } - return tracer; + if (shouldRecordFinishedCall) { + recordFinishedCall(); + } } - /** - * Record a finished call and mark the current time as the end time. - * - *

Can be called from any thread without synchronization. Calling it the second time or more - * is a no-op. - */ void callEnded(Status status) { - if (callEndedUpdater != null) { - if (callEndedUpdater.getAndSet(this, 1) != 0) { + if (!module.recordFinishedRpcs) { + return; + } + this.status = status; + boolean shouldRecordFinishedCall = false; + synchronized (lock) { + if (callEnded) { + // TODO(https://github.com/grpc/grpc-java/issues/7921): this shouldn't happen return; } - } else { - if (callEnded != 0) { - return; + callEnded = true; + if (activeStreams == 0 && !finishedCallToBeRecorded) { + shouldRecordFinishedCall = true; + finishedCallToBeRecorded = true; } - callEnded = 1; } - if (!module.recordFinishedRpcs) { - return; + if (shouldRecordFinishedCall) { + recordFinishedCall(); } - stopwatch.stop(); - long roundtripNanos = stopwatch.elapsed(TimeUnit.NANOSECONDS); - ClientTracer tracer = streamTracer; - if (tracer == null) { - tracer = new ClientTracer(module, startCtx); + } + + void recordFinishedCall() { + if (attemptsPerCall.get() == 0) { + ClientTracer tracer = new ClientTracer(this, module, parentCtx, startCtx, null); + tracer.roundtripNanos = stopwatch.elapsed(TimeUnit.NANOSECONDS); + tracer.statusCode = status.getCode(); + tracer.recordFinishedAttempt(); + } else if (inboundMetricTracer != null) { + inboundMetricTracer.recordFinishedAttempt(); } - MeasureMap measureMap = module.statsRecorder.newMeasureMap() - // TODO(songya): remove the deprecated measure constants once they are completed removed. - .put(DeprecatedCensusConstants.RPC_CLIENT_FINISHED_COUNT, 1) - // The latency is double value - .put( - DeprecatedCensusConstants.RPC_CLIENT_ROUNDTRIP_LATENCY, - roundtripNanos / NANOS_PER_MILLI) - .put(DeprecatedCensusConstants.RPC_CLIENT_REQUEST_COUNT, tracer.outboundMessageCount) - .put(DeprecatedCensusConstants.RPC_CLIENT_RESPONSE_COUNT, tracer.inboundMessageCount) - .put(DeprecatedCensusConstants.RPC_CLIENT_REQUEST_BYTES, tracer.outboundWireSize) - .put(DeprecatedCensusConstants.RPC_CLIENT_RESPONSE_BYTES, tracer.inboundWireSize) - .put( - DeprecatedCensusConstants.RPC_CLIENT_UNCOMPRESSED_REQUEST_BYTES, - tracer.outboundUncompressedSize) - .put( - DeprecatedCensusConstants.RPC_CLIENT_UNCOMPRESSED_RESPONSE_BYTES, - tracer.inboundUncompressedSize); - if (!status.isOk()) { - measureMap.put(DeprecatedCensusConstants.RPC_CLIENT_ERROR_COUNT, 1); + + long retriesPerCall = 0; + long attempts = attemptsPerCall.get(); + if (attempts > 0) { + retriesPerCall = attempts - 1; } + MeasureMap measureMap = module.statsRecorder.newMeasureMap() + .put(RETRIES_PER_CALL, retriesPerCall) + .put(TRANSPARENT_RETRIES_PER_CALL, transparentRetriesPerCall.get()) + .put(RETRY_DELAY_PER_CALL, retryDelayNanos / NANOS_PER_MILLI); + TagValue methodTag = TagValue.create(fullMethodName); TagValue statusTag = TagValue.create(status.getCode().toString()); measureMap.record( - module - .tagger - .toBuilder(startCtx) + module.tagger + .toBuilder(parentCtx) + .putLocal(RpcMeasureConstants.GRPC_CLIENT_METHOD, methodTag) .putLocal(RpcMeasureConstants.GRPC_CLIENT_STATUS, statusTag) .build()); } @@ -686,8 +780,8 @@ public ClientCall interceptCall( MethodDescriptor method, CallOptions callOptions, Channel next) { // New RPCs on client-side inherit the tag context from the current Context. TagContext parentCtx = tagger.getCurrentTagContext(); - final ClientCallTracer tracerFactory = - newClientCallTracer(parentCtx, method.getFullMethodName()); + final CallAttemptsTracerFactory tracerFactory = new CallAttemptsTracerFactory( + CensusStatsModule.this, parentCtx, method.getFullMethodName()); ClientCall call = next.newCall(method, callOptions.withStreamTracerFactory(tracerFactory)); return new SimpleForwardingClientCall(call) { diff --git a/census/src/main/java/io/grpc/census/CensusTracingModule.java b/census/src/main/java/io/grpc/census/CensusTracingModule.java index fc35d89db55..08d5fe3ca97 100644 --- a/census/src/main/java/io/grpc/census/CensusTracingModule.java +++ b/census/src/main/java/io/grpc/census/CensusTracingModule.java @@ -19,6 +19,7 @@ import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.annotations.VisibleForTesting; +import io.grpc.Attributes; import io.grpc.CallOptions; import io.grpc.Channel; import io.grpc.ClientCall; @@ -31,6 +32,7 @@ import io.grpc.MethodDescriptor; import io.grpc.ServerStreamTracer; import io.grpc.StreamTracer; +import io.opencensus.trace.AttributeValue; import io.opencensus.trace.BlankSpan; import io.opencensus.trace.EndSpanOptions; import io.opencensus.trace.MessageEvent; @@ -59,7 +61,8 @@ final class CensusTracingModule { private static final Logger logger = Logger.getLogger(CensusTracingModule.class.getName()); - @Nullable private static final AtomicIntegerFieldUpdater callEndedUpdater; + @Nullable + private static final AtomicIntegerFieldUpdater callEndedUpdater; @Nullable private static final AtomicIntegerFieldUpdater streamClosedUpdater; @@ -69,11 +72,11 @@ final class CensusTracingModule { * (potentially racy) direct updates of the volatile variables. */ static { - AtomicIntegerFieldUpdater tmpCallEndedUpdater; + AtomicIntegerFieldUpdater tmpCallEndedUpdater; AtomicIntegerFieldUpdater tmpStreamClosedUpdater; try { tmpCallEndedUpdater = - AtomicIntegerFieldUpdater.newUpdater(ClientCallTracer.class, "callEnded"); + AtomicIntegerFieldUpdater.newUpdater(CallAttemptsTracerFactory.class, "callEnded"); tmpStreamClosedUpdater = AtomicIntegerFieldUpdater.newUpdater(ServerTracer.class, "streamClosed"); } catch (Throwable t) { @@ -115,11 +118,12 @@ public SpanContext parseBytes(byte[] serialized) { } /** - * Creates a {@link ClientCallTracer} for a new call. + * Creates a {@link CallAttemptsTracerFactory} for a new call. */ @VisibleForTesting - ClientCallTracer newClientCallTracer(@Nullable Span parentSpan, MethodDescriptor method) { - return new ClientCallTracer(parentSpan, method); + CallAttemptsTracerFactory newClientCallTracer( + @Nullable Span parentSpan, MethodDescriptor method) { + return new CallAttemptsTracerFactory(parentSpan, method); } /** @@ -222,19 +226,21 @@ private static void recordMessageEvent( } @VisibleForTesting - final class ClientCallTracer extends ClientStreamTracer.Factory { + final class CallAttemptsTracerFactory extends ClientStreamTracer.InternalLimitedInfoFactory { volatile int callEnded; private final boolean isSampledToLocalTracing; private final Span span; + private final String fullMethodName; - ClientCallTracer(@Nullable Span parentSpan, MethodDescriptor method) { + CallAttemptsTracerFactory(@Nullable Span parentSpan, MethodDescriptor method) { checkNotNull(method, "method"); this.isSampledToLocalTracing = method.isSampledToLocalTracing(); + this.fullMethodName = method.getFullMethodName(); this.span = censusTracer .spanBuilderWithExplicitParent( - generateTraceSpanName(false, method.getFullMethodName()), + generateTraceSpanName(false, fullMethodName), parentSpan) .setRecordEvents(true) .startSpan(); @@ -243,11 +249,17 @@ final class ClientCallTracer extends ClientStreamTracer.Factory { @Override public ClientStreamTracer newClientStreamTracer( ClientStreamTracer.StreamInfo info, Metadata headers) { - if (span != BlankSpan.INSTANCE) { - headers.discardAll(tracingHeader); - headers.put(tracingHeader, span.getContext()); - } - return new ClientTracer(span); + Span attemptSpan = censusTracer + .spanBuilderWithExplicitParent( + "Attempt." + fullMethodName.replace('/', '.'), + span) + .setRecordEvents(true) + .startSpan(); + attemptSpan.putAttribute( + "previous-rpc-attempts", AttributeValue.longAttributeValue(info.getPreviousAttempts())); + attemptSpan.putAttribute( + "transparent-retry", AttributeValue.booleanAttributeValue(info.isTransparentRetry())); + return new ClientTracer(attemptSpan, tracingHeader, isSampledToLocalTracing); } /** @@ -273,9 +285,22 @@ void callEnded(io.grpc.Status status) { private static final class ClientTracer extends ClientStreamTracer { private final Span span; + final Metadata.Key tracingHeader; + final boolean isSampledToLocalTracing; - ClientTracer(Span span) { + ClientTracer( + Span span, Metadata.Key tracingHeader, boolean isSampledToLocalTracing) { this.span = checkNotNull(span, "span"); + this.tracingHeader = tracingHeader; + this.isSampledToLocalTracing = isSampledToLocalTracing; + } + + @Override + public void streamCreated(Attributes transportAtts, Metadata headers) { + if (span != BlankSpan.INSTANCE) { + headers.discardAll(tracingHeader); + headers.put(tracingHeader, span.getContext()); + } } @Override @@ -291,6 +316,11 @@ public void inboundMessageRead( recordMessageEvent( span, MessageEvent.Type.RECEIVED, seqNo, optionalWireSize, optionalUncompressedSize); } + + @Override + public void streamClosed(io.grpc.Status status) { + span.end(createEndSpanOptions(status, isSampledToLocalTracing)); + } } @@ -381,7 +411,7 @@ public ClientCall interceptCall( // Safe usage of the unsafe trace API because CONTEXT_SPAN_KEY.get() returns the same value // as Tracer.getCurrentSpan() except when no value available when the return value is null // for the direct access and BlankSpan when Tracer API is used. - final ClientCallTracer tracerFactory = + final CallAttemptsTracerFactory tracerFactory = newClientCallTracer(ContextUtils.getValue(Context.current()), method); ClientCall call = next.newCall( diff --git a/census/src/test/java/io/grpc/census/CensusModulesTest.java b/census/src/test/java/io/grpc/census/CensusModulesTest.java index fbbcd44150c..d285c8fe8c2 100644 --- a/census/src/test/java/io/grpc/census/CensusModulesTest.java +++ b/census/src/test/java/io/grpc/census/CensusModulesTest.java @@ -18,6 +18,9 @@ import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertWithMessage; +import static io.grpc.census.CensusStatsModule.CallAttemptsTracerFactory.RETRIES_PER_CALL; +import static io.grpc.census.CensusStatsModule.CallAttemptsTracerFactory.RETRY_DELAY_PER_CALL; +import static io.grpc.census.CensusStatsModule.CallAttemptsTracerFactory.TRANSPARENT_RETRIES_PER_CALL; import static java.util.concurrent.TimeUnit.MILLISECONDS; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; @@ -58,6 +61,7 @@ import io.grpc.ServerStreamTracer; import io.grpc.ServerStreamTracer.ServerCallInfo; import io.grpc.Status; +import io.grpc.census.CensusTracingModule.CallAttemptsTracerFactory; import io.grpc.census.internal.DeprecatedCensusConstants; import io.grpc.internal.FakeClock; import io.grpc.internal.testing.StatsTestUtils; @@ -81,6 +85,7 @@ import io.opencensus.stats.View; import io.opencensus.tags.TagContext; import io.opencensus.tags.TagValue; +import io.opencensus.trace.AttributeValue; import io.opencensus.trace.BlankSpan; import io.opencensus.trace.EndSpanOptions; import io.opencensus.trace.MessageEvent; @@ -173,10 +178,12 @@ public String parse(InputStream stream) { private final Random random = new Random(1234); private final Span fakeClientParentSpan = MockableSpan.generateRandomSpan(random); private final Span spyClientSpan = spy(MockableSpan.generateRandomSpan(random)); - private final SpanContext fakeClientSpanContext = spyClientSpan.getContext(); + private final Span spyAttemptSpan = spy(MockableSpan.generateRandomSpan(random)); + private final SpanContext fakeAttemptSpanContext = spyAttemptSpan.getContext(); private final Span spyServerSpan = spy(MockableSpan.generateRandomSpan(random)); private final byte[] binarySpanContext = new byte[]{3, 1, 5}; private final SpanBuilder spyClientSpanBuilder = spy(new MockableSpan.Builder()); + private final SpanBuilder spyAttemptSpanBuilder = spy(new MockableSpan.Builder()); private final SpanBuilder spyServerSpanBuilder = spy(new MockableSpan.Builder()); @Rule @@ -201,15 +208,20 @@ public String parse(InputStream stream) { @Before public void setUp() throws Exception { when(spyClientSpanBuilder.startSpan()).thenReturn(spyClientSpan); - when(tracer.spanBuilderWithExplicitParent(anyString(), ArgumentMatchers.any())) + when(spyAttemptSpanBuilder.startSpan()).thenReturn(spyAttemptSpan); + when(tracer.spanBuilderWithExplicitParent( + eq("Sent.package1.service2.method3"), ArgumentMatchers.any())) .thenReturn(spyClientSpanBuilder); + when(tracer.spanBuilderWithExplicitParent( + eq("Attempt.package1.service2.method3"), ArgumentMatchers.any())) + .thenReturn(spyAttemptSpanBuilder); when(spyServerSpanBuilder.startSpan()).thenReturn(spyServerSpan); when(tracer.spanBuilderWithRemoteParent(anyString(), ArgumentMatchers.any())) .thenReturn(spyServerSpanBuilder); when(mockTracingPropagationHandler.toByteArray(any(SpanContext.class))) .thenReturn(binarySpanContext); when(mockTracingPropagationHandler.fromByteArray(any(byte[].class))) - .thenReturn(fakeClientSpanContext); + .thenReturn(fakeAttemptSpanContext); censusStats = new CensusStatsModule( tagger, tagCtxSerializer, statsRecorder, fakeClock.getStopwatchSupplier(), @@ -292,10 +304,10 @@ public ClientCall interceptCall( assertEquals(2, capturedCallOptions.get().getStreamTracerFactories().size()); assertTrue( capturedCallOptions.get().getStreamTracerFactories().get(0) - instanceof CensusTracingModule.ClientCallTracer); + instanceof CallAttemptsTracerFactory); assertTrue( capturedCallOptions.get().getStreamTracerFactories().get(1) - instanceof CensusStatsModule.ClientCallTracer); + instanceof CensusStatsModule.CallAttemptsTracerFactory); // Make the call Metadata headers = new Metadata(); @@ -355,6 +367,7 @@ record = statsRecorder.pollRecord(); .setSampleToLocalSpanStore(false) .build()); verify(spyClientSpan, never()).end(); + assertZeroRetryRecorded(); } @Test @@ -388,11 +401,12 @@ private void subtestClientBasicStatsDefaultContext( new CensusStatsModule( tagger, tagCtxSerializer, statsRecorder, fakeClock.getStopwatchSupplier(), true, recordStarts, recordFinishes, recordRealTime); - CensusStatsModule.ClientCallTracer callTracer = - localCensusStats.newClientCallTracer( - tagger.empty(), method.getFullMethodName()); + CensusStatsModule.CallAttemptsTracerFactory callAttemptsTracerFactory = + new CensusStatsModule.CallAttemptsTracerFactory( + localCensusStats, tagger.empty(), method.getFullMethodName()); Metadata headers = new Metadata(); - ClientStreamTracer tracer = callTracer.newClientStreamTracer(STREAM_INFO, headers); + ClientStreamTracer tracer = + callAttemptsTracerFactory.newClientStreamTracer(STREAM_INFO, headers); if (recordStarts) { StatsTestUtils.MetricsRecord record = statsRecorder.pollRecord(); @@ -455,7 +469,7 @@ private void subtestClientBasicStatsDefaultContext( tracer.inboundUncompressedSize(552); tracer.streamClosed(Status.OK); - callTracer.callEnded(Status.OK); + callAttemptsTracerFactory.callEnded(Status.OK); if (recordFinishes) { StatsTestUtils.MetricsRecord record = statsRecorder.pollRecord(); @@ -488,11 +502,200 @@ private void subtestClientBasicStatsDefaultContext( DeprecatedCensusConstants.RPC_CLIENT_UNCOMPRESSED_RESPONSE_BYTES)); assertEquals(30 + 100 + 16 + 24, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_ROUNDTRIP_LATENCY)); + assertZeroRetryRecorded(); } else { assertNull(statsRecorder.pollRecord()); } } + // This test is only unit-testing the stat recording logic. The retry behavior is faked. + @Test + public void recordRetryStats() { + CensusStatsModule localCensusStats = + new CensusStatsModule( + tagger, tagCtxSerializer, statsRecorder, fakeClock.getStopwatchSupplier(), + true, true, true, true); + CensusStatsModule.CallAttemptsTracerFactory callAttemptsTracerFactory = + new CensusStatsModule.CallAttemptsTracerFactory( + localCensusStats, tagger.empty(), method.getFullMethodName()); + ClientStreamTracer tracer = + callAttemptsTracerFactory.newClientStreamTracer(STREAM_INFO, new Metadata()); + + StatsTestUtils.MetricsRecord record = statsRecorder.pollRecord(); + assertEquals(1, record.tags.size()); + TagValue methodTag = record.tags.get(RpcMeasureConstants.GRPC_CLIENT_METHOD); + assertEquals(method.getFullMethodName(), methodTag.asString()); + assertEquals( + 1, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_STARTED_COUNT)); + + fakeClock.forwardTime(30, MILLISECONDS); + tracer.outboundHeaders(); + fakeClock.forwardTime(100, MILLISECONDS); + tracer.outboundMessage(0); + assertRealTimeMetric( + RpcMeasureConstants.GRPC_CLIENT_SENT_MESSAGES_PER_METHOD, 1, true, true); + tracer.outboundMessage(1); + assertRealTimeMetric( + RpcMeasureConstants.GRPC_CLIENT_SENT_MESSAGES_PER_METHOD, 1, true, true); + tracer.outboundWireSize(1028); + assertRealTimeMetric( + RpcMeasureConstants.GRPC_CLIENT_SENT_BYTES_PER_METHOD, 1028, true, true); + tracer.outboundUncompressedSize(1128); + fakeClock.forwardTime(24, MILLISECONDS); + tracer.streamClosed(Status.UNAVAILABLE); + record = statsRecorder.pollRecord(); + methodTag = record.tags.get(RpcMeasureConstants.GRPC_CLIENT_METHOD); + assertEquals(method.getFullMethodName(), methodTag.asString()); + TagValue statusTag = record.tags.get(RpcMeasureConstants.GRPC_CLIENT_STATUS); + assertEquals(Status.Code.UNAVAILABLE.toString(), statusTag.asString()); + assertEquals( + 1, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_FINISHED_COUNT)); + assertEquals(1, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_ERROR_COUNT)); + assertEquals( + 2, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_REQUEST_COUNT)); + assertEquals( + 1028, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_REQUEST_BYTES)); + assertEquals( + 1128, + record.getMetricAsLongOrFail( + DeprecatedCensusConstants.RPC_CLIENT_UNCOMPRESSED_REQUEST_BYTES)); + assertEquals( + 30 + 100 + 24, + record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_ROUNDTRIP_LATENCY)); + + // faking retry + fakeClock.forwardTime(1000, MILLISECONDS); + tracer = callAttemptsTracerFactory.newClientStreamTracer(STREAM_INFO, new Metadata()); + record = statsRecorder.pollRecord(); + assertEquals(1, record.tags.size()); + methodTag = record.tags.get(RpcMeasureConstants.GRPC_CLIENT_METHOD); + assertEquals(method.getFullMethodName(), methodTag.asString()); + assertEquals( + 1, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_STARTED_COUNT)); + tracer.outboundHeaders(); + tracer.outboundMessage(0); + assertRealTimeMetric( + RpcMeasureConstants.GRPC_CLIENT_SENT_MESSAGES_PER_METHOD, 1, true, true); + tracer.outboundMessage(1); + assertRealTimeMetric( + RpcMeasureConstants.GRPC_CLIENT_SENT_MESSAGES_PER_METHOD, 1, true, true); + tracer.outboundWireSize(1028); + assertRealTimeMetric( + RpcMeasureConstants.GRPC_CLIENT_SENT_BYTES_PER_METHOD, 1028, true, true); + tracer.outboundUncompressedSize(1128); + fakeClock.forwardTime(100, MILLISECONDS); + tracer.streamClosed(Status.NOT_FOUND); + record = statsRecorder.pollRecord(); + methodTag = record.tags.get(RpcMeasureConstants.GRPC_CLIENT_METHOD); + assertEquals(method.getFullMethodName(), methodTag.asString()); + statusTag = record.tags.get(RpcMeasureConstants.GRPC_CLIENT_STATUS); + assertEquals(Status.Code.NOT_FOUND.toString(), statusTag.asString()); + assertEquals( + 1, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_FINISHED_COUNT)); + assertEquals(1, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_ERROR_COUNT)); + assertEquals( + 2, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_REQUEST_COUNT)); + assertEquals( + 1028, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_REQUEST_BYTES)); + assertEquals( + 1128, + record.getMetricAsLongOrFail( + DeprecatedCensusConstants.RPC_CLIENT_UNCOMPRESSED_REQUEST_BYTES)); + assertEquals( + 100 , + record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_ROUNDTRIP_LATENCY)); + + // fake transparent retry + fakeClock.forwardTime(10, MILLISECONDS); + tracer = callAttemptsTracerFactory.newClientStreamTracer( + STREAM_INFO.toBuilder().setIsTransparentRetry(true).build(), new Metadata()); + record = statsRecorder.pollRecord(); + assertEquals(1, record.tags.size()); + methodTag = record.tags.get(RpcMeasureConstants.GRPC_CLIENT_METHOD); + assertEquals(method.getFullMethodName(), methodTag.asString()); + assertEquals( + 1, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_STARTED_COUNT)); + tracer.streamClosed(Status.UNAVAILABLE); + record = statsRecorder.pollRecord(); + statusTag = record.tags.get(RpcMeasureConstants.GRPC_CLIENT_STATUS); + assertEquals(Status.Code.UNAVAILABLE.toString(), statusTag.asString()); + assertEquals( + 1, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_FINISHED_COUNT)); + assertEquals(1, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_ERROR_COUNT)); + assertEquals( + 0, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_REQUEST_COUNT)); + assertEquals( + 0, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_REQUEST_BYTES)); + + // fake another transparent retry + fakeClock.forwardTime(10, MILLISECONDS); + tracer = callAttemptsTracerFactory.newClientStreamTracer( + STREAM_INFO.toBuilder().setIsTransparentRetry(true).build(), new Metadata()); + record = statsRecorder.pollRecord(); + assertEquals(1, record.tags.size()); + assertEquals( + 1, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_STARTED_COUNT)); + tracer.outboundHeaders(); + tracer.outboundMessage(0); + assertRealTimeMetric( + RpcMeasureConstants.GRPC_CLIENT_SENT_MESSAGES_PER_METHOD, 1, true, true); + tracer.outboundMessage(1); + assertRealTimeMetric( + RpcMeasureConstants.GRPC_CLIENT_SENT_MESSAGES_PER_METHOD, 1, true, true); + tracer.outboundWireSize(1028); + assertRealTimeMetric( + RpcMeasureConstants.GRPC_CLIENT_SENT_BYTES_PER_METHOD, 1028, true, true); + tracer.outboundUncompressedSize(1128); + fakeClock.forwardTime(16, MILLISECONDS); + tracer.inboundMessage(0); + assertRealTimeMetric( + RpcMeasureConstants.GRPC_CLIENT_RECEIVED_MESSAGES_PER_METHOD, 1, true, true); + tracer.inboundWireSize(33); + assertRealTimeMetric( + RpcMeasureConstants.GRPC_CLIENT_RECEIVED_BYTES_PER_METHOD, 33, true, true); + tracer.inboundUncompressedSize(67); + fakeClock.forwardTime(24, MILLISECONDS); + // RPC succeeded + tracer.streamClosed(Status.OK); + callAttemptsTracerFactory.callEnded(Status.OK); + + record = statsRecorder.pollRecord(); + statusTag = record.tags.get(RpcMeasureConstants.GRPC_CLIENT_STATUS); + assertEquals(Status.Code.OK.toString(), statusTag.asString()); + assertEquals( + 1, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_FINISHED_COUNT)); + assertThat(record.getMetric(DeprecatedCensusConstants.RPC_CLIENT_ERROR_COUNT)).isNull(); + assertEquals( + 2, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_REQUEST_COUNT)); + assertEquals( + 1028, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_REQUEST_BYTES)); + assertEquals( + 1128, + record.getMetricAsLongOrFail( + DeprecatedCensusConstants.RPC_CLIENT_UNCOMPRESSED_REQUEST_BYTES)); + assertEquals( + 1, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_RESPONSE_COUNT)); + assertEquals( + 33, + record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_RESPONSE_BYTES)); + assertEquals( + 67, + record.getMetricAsLongOrFail( + DeprecatedCensusConstants.RPC_CLIENT_UNCOMPRESSED_RESPONSE_BYTES)); + assertEquals( + 16 + 24 , + record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_ROUNDTRIP_LATENCY)); + + record = statsRecorder.pollRecord(); + methodTag = record.tags.get(RpcMeasureConstants.GRPC_CLIENT_METHOD); + assertEquals(method.getFullMethodName(), methodTag.asString()); + statusTag = record.tags.get(RpcMeasureConstants.GRPC_CLIENT_STATUS); + assertEquals(Status.Code.OK.toString(), statusTag.asString()); + assertThat(record.getMetric(RETRIES_PER_CALL)).isEqualTo(1); + assertThat(record.getMetric(TRANSPARENT_RETRIES_PER_CALL)).isEqualTo(2); + assertThat(record.getMetric(RETRY_DELAY_PER_CALL)).isEqualTo(1000D + 10 + 10); + } + private void assertRealTimeMetric( Measure measure, long expectedValue, boolean recordRealTimeMetrics, boolean clientSide) { StatsTestUtils.MetricsRecord record = statsRecorder.pollRecord(); @@ -516,15 +719,28 @@ private void assertRealTimeMetric( assertEquals(expectedValue, record.getMetricAsLongOrFail(measure)); } + private void assertZeroRetryRecorded() { + StatsTestUtils.MetricsRecord record = statsRecorder.pollRecord(); + TagValue methodTag = record.tags.get(RpcMeasureConstants.GRPC_CLIENT_METHOD); + assertEquals(method.getFullMethodName(), methodTag.asString()); + assertThat(record.getMetric(RETRIES_PER_CALL)).isEqualTo(0); + assertThat(record.getMetric(TRANSPARENT_RETRIES_PER_CALL)).isEqualTo(0); + assertThat(record.getMetric(RETRY_DELAY_PER_CALL)).isEqualTo(0D); + } + @Test public void clientBasicTracingDefaultSpan() { - CensusTracingModule.ClientCallTracer callTracer = + CallAttemptsTracerFactory callTracer = censusTracing.newClientCallTracer(null, method); Metadata headers = new Metadata(); ClientStreamTracer clientStreamTracer = callTracer.newClientStreamTracer(STREAM_INFO, headers); + clientStreamTracer.streamCreated(Attributes.EMPTY, headers); verify(tracer).spanBuilderWithExplicitParent( eq("Sent.package1.service2.method3"), ArgumentMatchers.isNull()); + verify(tracer).spanBuilderWithExplicitParent( + eq("Attempt.package1.service2.method3"), eq(spyClientSpan)); verify(spyClientSpan, never()).end(any(EndSpanOptions.class)); + verify(spyAttemptSpan, never()).end(any(EndSpanOptions.class)); clientStreamTracer.outboundMessage(0); clientStreamTracer.outboundMessageSent(0, 882, -1); @@ -536,8 +752,12 @@ public void clientBasicTracingDefaultSpan() { clientStreamTracer.streamClosed(Status.OK); callTracer.callEnded(Status.OK); - InOrder inOrder = inOrder(spyClientSpan); - inOrder.verify(spyClientSpan, times(3)).addMessageEvent(messageEventCaptor.capture()); + InOrder inOrder = inOrder(spyClientSpan, spyAttemptSpan); + inOrder.verify(spyAttemptSpan) + .putAttribute("previous-rpc-attempts", AttributeValue.longAttributeValue(0)); + inOrder.verify(spyAttemptSpan) + .putAttribute("transparent-retry", AttributeValue.booleanAttributeValue(false)); + inOrder.verify(spyAttemptSpan, times(3)).addMessageEvent(messageEventCaptor.capture()); List events = messageEventCaptor.getAllValues(); assertEquals( MessageEvent.builder(MessageEvent.Type.SENT, 0).setCompressedMessageSize(882).build(), @@ -551,18 +771,23 @@ public void clientBasicTracingDefaultSpan() { .setUncompressedMessageSize(90) .build(), events.get(2)); + inOrder.verify(spyAttemptSpan).end( + EndSpanOptions.builder() + .setStatus(io.opencensus.trace.Status.OK) + .setSampleToLocalSpanStore(false) + .build()); inOrder.verify(spyClientSpan).end( EndSpanOptions.builder() .setStatus(io.opencensus.trace.Status.OK) .setSampleToLocalSpanStore(false) .build()); - verifyNoMoreInteractions(spyClientSpan); + inOrder.verifyNoMoreInteractions(); verifyNoMoreInteractions(tracer); } @Test public void clientTracingSampledToLocalSpanStore() { - CensusTracingModule.ClientCallTracer callTracer = + CallAttemptsTracerFactory callTracer = censusTracing.newClientCallTracer(null, sampledMethod); callTracer.callEnded(Status.OK); @@ -575,11 +800,15 @@ public void clientTracingSampledToLocalSpanStore() { @Test public void clientStreamNeverCreatedStillRecordStats() { - CensusStatsModule.ClientCallTracer callTracer = - censusStats.newClientCallTracer(tagger.empty(), method.getFullMethodName()); - + CensusStatsModule.CallAttemptsTracerFactory callAttemptsTracerFactory = + new CensusStatsModule.CallAttemptsTracerFactory( + censusStats, tagger.empty(), method.getFullMethodName()); + ClientStreamTracer streamTracer = + callAttemptsTracerFactory.newClientStreamTracer(STREAM_INFO, new Metadata()); fakeClock.forwardTime(3000, MILLISECONDS); - callTracer.callEnded(Status.DEADLINE_EXCEEDED.withDescription("3 seconds")); + Status status = Status.DEADLINE_EXCEEDED.withDescription("3 seconds"); + streamTracer.streamClosed(status); + callAttemptsTracerFactory.callEnded(status); // Upstart record StatsTestUtils.MetricsRecord record = statsRecorder.pollRecord(); @@ -625,11 +854,12 @@ record = statsRecorder.pollRecord(); 3000, record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_ROUNDTRIP_LATENCY)); assertNull(record.getMetric(DeprecatedCensusConstants.RPC_CLIENT_SERVER_ELAPSED_TIME)); + assertZeroRetryRecorded(); } @Test public void clientStreamNeverCreatedStillRecordTracing() { - CensusTracingModule.ClientCallTracer callTracer = + CallAttemptsTracerFactory callTracer = censusTracing.newClientCallTracer(fakeClientParentSpan, method); verify(tracer).spanBuilderWithExplicitParent( eq("Sent.package1.service2.method3"), same(fakeClientParentSpan)); @@ -680,10 +910,13 @@ private void subtestStatsHeadersPropagateTags(boolean propagate, boolean recordS fakeClock.getStopwatchSupplier(), propagate, recordStats, recordStats, recordStats); Metadata headers = new Metadata(); - CensusStatsModule.ClientCallTracer callTracer = - census.newClientCallTracer(clientCtx, method.getFullMethodName()); + CensusStatsModule.CallAttemptsTracerFactory callAttemptsTracerFactory = + new CensusStatsModule.CallAttemptsTracerFactory( + census, clientCtx, method.getFullMethodName()); // This propagates clientCtx to headers if propagates==true - callTracer.newClientStreamTracer(STREAM_INFO, headers); + ClientStreamTracer streamTracer = + callAttemptsTracerFactory.newClientStreamTracer(STREAM_INFO, headers); + streamTracer.streamCreated(Attributes.EMPTY, headers); if (recordStats) { // Client upstart record StatsTestUtils.MetricsRecord clientRecord = statsRecorder.pollRecord(); @@ -746,7 +979,8 @@ private void subtestStatsHeadersPropagateTags(boolean propagate, boolean recordS // Verifies that the client tracer factory uses clientCtx, which includes the custom tags, to // record stats. - callTracer.callEnded(Status.OK); + streamTracer.streamClosed(Status.OK); + callAttemptsTracerFactory.callEnded(Status.OK); if (recordStats) { // Client completion record @@ -760,6 +994,7 @@ private void subtestStatsHeadersPropagateTags(boolean propagate, boolean recordS assertNull(clientRecord.getMetric(DeprecatedCensusConstants.RPC_CLIENT_ERROR_COUNT)); TagValue clientPropagatedTag = clientRecord.tags.get(StatsTestUtils.EXTRA_TAG); assertEquals("extra-tag-value-897", clientPropagatedTag.asString()); + assertZeroRetryRecorded(); } if (!recordStats) { @@ -769,10 +1004,12 @@ private void subtestStatsHeadersPropagateTags(boolean propagate, boolean recordS @Test public void statsHeadersNotPropagateDefaultContext() { - CensusStatsModule.ClientCallTracer callTracer = - censusStats.newClientCallTracer(tagger.empty(), method.getFullMethodName()); + CensusStatsModule.CallAttemptsTracerFactory callAttemptsTracerFactory = + new CensusStatsModule.CallAttemptsTracerFactory( + censusStats, tagger.empty(), method.getFullMethodName()); Metadata headers = new Metadata(); - callTracer.newClientStreamTracer(STREAM_INFO, headers); + callAttemptsTracerFactory.newClientStreamTracer(STREAM_INFO, headers) + .streamCreated(Attributes.EMPTY, headers); assertFalse(headers.containsKey(censusStats.statsHeader)); // Clear recorded stats to satisfy the assertions in wrapUp() statsRecorder.rolloverRecords(); @@ -800,15 +1037,18 @@ public void statsHeaderMalformed() { @Test public void traceHeadersPropagateSpanContext() throws Exception { - CensusTracingModule.ClientCallTracer callTracer = + CallAttemptsTracerFactory callTracer = censusTracing.newClientCallTracer(fakeClientParentSpan, method); Metadata headers = new Metadata(); - callTracer.newClientStreamTracer(STREAM_INFO, headers); + ClientStreamTracer streamTracer = callTracer.newClientStreamTracer(STREAM_INFO, headers); + streamTracer.streamCreated(Attributes.EMPTY, headers); - verify(mockTracingPropagationHandler).toByteArray(same(fakeClientSpanContext)); + verify(mockTracingPropagationHandler).toByteArray(same(fakeAttemptSpanContext)); verifyNoMoreInteractions(mockTracingPropagationHandler); verify(tracer).spanBuilderWithExplicitParent( eq("Sent.package1.service2.method3"), same(fakeClientParentSpan)); + verify(tracer).spanBuilderWithExplicitParent( + eq("Attempt.package1.service2.method3"), same(spyClientSpan)); verify(spyClientSpanBuilder).setRecordEvents(eq(true)); verifyNoMoreInteractions(tracer); assertTrue(headers.containsKey(censusTracing.tracingHeader)); @@ -818,7 +1058,7 @@ public void traceHeadersPropagateSpanContext() throws Exception { method.getFullMethodName(), headers); verify(mockTracingPropagationHandler).fromByteArray(same(binarySpanContext)); verify(tracer).spanBuilderWithRemoteParent( - eq("Recv.package1.service2.method3"), same(spyClientSpan.getContext())); + eq("Recv.package1.service2.method3"), same(spyAttemptSpan.getContext())); verify(spyServerSpanBuilder).setRecordEvents(eq(true)); Context filteredContext = serverTracer.filterContext(Context.ROOT); @@ -827,11 +1067,12 @@ public void traceHeadersPropagateSpanContext() throws Exception { @Test public void traceHeaders_propagateSpanContext() throws Exception { - CensusTracingModule.ClientCallTracer callTracer = + CallAttemptsTracerFactory callTracer = censusTracing.newClientCallTracer(fakeClientParentSpan, method); Metadata headers = new Metadata(); - callTracer.newClientStreamTracer(STREAM_INFO, headers); + ClientStreamTracer streamTracer = callTracer.newClientStreamTracer(STREAM_INFO, headers); + streamTracer.streamCreated(Attributes.EMPTY, headers); assertThat(headers.keys()).isNotEmpty(); } @@ -840,12 +1081,14 @@ public void traceHeaders_propagateSpanContext() throws Exception { public void traceHeaders_missingCensusImpl_notPropagateSpanContext() throws Exception { reset(spyClientSpanBuilder); + reset(spyAttemptSpanBuilder); when(spyClientSpanBuilder.startSpan()).thenReturn(BlankSpan.INSTANCE); + when(spyAttemptSpanBuilder.startSpan()).thenReturn(BlankSpan.INSTANCE); Metadata headers = new Metadata(); - CensusTracingModule.ClientCallTracer callTracer = + CallAttemptsTracerFactory callTracer = censusTracing.newClientCallTracer(BlankSpan.INSTANCE, method); - callTracer.newClientStreamTracer(STREAM_INFO, headers); + callTracer.newClientStreamTracer(STREAM_INFO, headers).streamCreated(Attributes.EMPTY, headers); assertThat(headers.keys()).isEmpty(); } @@ -853,16 +1096,18 @@ public void traceHeaders_missingCensusImpl_notPropagateSpanContext() @Test public void traceHeaders_clientMissingCensusImpl_preservingHeaders() throws Exception { reset(spyClientSpanBuilder); + reset(spyAttemptSpanBuilder); when(spyClientSpanBuilder.startSpan()).thenReturn(BlankSpan.INSTANCE); + when(spyAttemptSpanBuilder.startSpan()).thenReturn(BlankSpan.INSTANCE); Metadata headers = new Metadata(); headers.put( Metadata.Key.of("never-used-key-bin", Metadata.BINARY_BYTE_MARSHALLER), new byte[] {}); Set originalHeaderKeys = new HashSet<>(headers.keys()); - CensusTracingModule.ClientCallTracer callTracer = + CallAttemptsTracerFactory callTracer = censusTracing.newClientCallTracer(BlankSpan.INSTANCE, method); - callTracer.newClientStreamTracer(STREAM_INFO, headers); + callTracer.newClientStreamTracer(STREAM_INFO, headers).streamCreated(Attributes.EMPTY, headers); assertThat(headers.keys()).containsExactlyElementsIn(originalHeaderKeys); } @@ -871,9 +1116,9 @@ public void traceHeaders_clientMissingCensusImpl_preservingHeaders() throws Exce public void traceHeaderMalformed() throws Exception { // As comparison, normal header parsing Metadata headers = new Metadata(); - headers.put(censusTracing.tracingHeader, fakeClientSpanContext); + headers.put(censusTracing.tracingHeader, fakeAttemptSpanContext); // mockTracingPropagationHandler was stubbed to always return fakeServerParentSpanContext - assertSame(spyClientSpan.getContext(), headers.get(censusTracing.tracingHeader)); + assertSame(spyAttemptSpan.getContext(), headers.get(censusTracing.tracingHeader)); // Make BinaryPropagationHandler always throw when parsing the header when(mockTracingPropagationHandler.fromByteArray(any(byte[].class))) @@ -881,7 +1126,7 @@ public void traceHeaderMalformed() throws Exception { headers = new Metadata(); assertNull(headers.get(censusTracing.tracingHeader)); - headers.put(censusTracing.tracingHeader, fakeClientSpanContext); + headers.put(censusTracing.tracingHeader, fakeAttemptSpanContext); assertSame(SpanContext.INVALID, headers.get(censusTracing.tracingHeader)); assertNotSame(spyClientSpan.getContext(), SpanContext.INVALID); @@ -1186,13 +1431,18 @@ public void newTagsPopulateOldViews() throws InterruptedException { tagger, tagCtxSerializer, localStats.getStatsRecorder(), fakeClock.getStopwatchSupplier(), false, false, true, false /* real-time */); - CensusStatsModule.ClientCallTracer callTracer = - localCensusStats.newClientCallTracer( - tagger.empty(), method.getFullMethodName()); + CensusStatsModule.CallAttemptsTracerFactory callAttemptsTracerFactory = + new CensusStatsModule.CallAttemptsTracerFactory( + localCensusStats, tagger.empty(), method.getFullMethodName()); - callTracer.newClientStreamTracer(STREAM_INFO, new Metadata()); + Metadata headers = new Metadata(); + ClientStreamTracer tracer = + callAttemptsTracerFactory.newClientStreamTracer(STREAM_INFO, headers); + tracer.streamCreated(Attributes.EMPTY, headers); fakeClock.forwardTime(30, MILLISECONDS); - callTracer.callEnded(Status.PERMISSION_DENIED.withDescription("No you don't")); + Status status = Status.PERMISSION_DENIED.withDescription("No you don't"); + tracer.streamClosed(status); + callAttemptsTracerFactory.callEnded(status); // Give OpenCensus a chance to update the views asynchronously. Thread.sleep(100); diff --git a/compiler/src/test/golden/TestDeprecatedService.java.txt b/compiler/src/test/golden/TestDeprecatedService.java.txt index ecf5e3889dd..2beed7b2b7f 100644 --- a/compiler/src/test/golden/TestDeprecatedService.java.txt +++ b/compiler/src/test/golden/TestDeprecatedService.java.txt @@ -8,7 +8,7 @@ import static io.grpc.MethodDescriptor.generateFullMethodName; * */ @javax.annotation.Generated( - value = "by gRPC proto compiler (version 1.40.0-SNAPSHOT)", + value = "by gRPC proto compiler (version 1.41.0)", comments = "Source: grpc/testing/compiler/test.proto") @io.grpc.stub.annotations.GrpcGenerated @java.lang.Deprecated diff --git a/compiler/src/test/golden/TestService.java.txt b/compiler/src/test/golden/TestService.java.txt index 6abbd4732fc..ba2c37f4b81 100644 --- a/compiler/src/test/golden/TestService.java.txt +++ b/compiler/src/test/golden/TestService.java.txt @@ -8,7 +8,7 @@ import static io.grpc.MethodDescriptor.generateFullMethodName; * */ @javax.annotation.Generated( - value = "by gRPC proto compiler (version 1.40.0-SNAPSHOT)", + value = "by gRPC proto compiler (version 1.41.0)", comments = "Source: grpc/testing/compiler/test.proto") @io.grpc.stub.annotations.GrpcGenerated public final class TestServiceGrpc { diff --git a/compiler/src/testLite/golden/TestDeprecatedService.java.txt b/compiler/src/testLite/golden/TestDeprecatedService.java.txt index 4de0f949c59..72d1b428efb 100644 --- a/compiler/src/testLite/golden/TestDeprecatedService.java.txt +++ b/compiler/src/testLite/golden/TestDeprecatedService.java.txt @@ -8,7 +8,7 @@ import static io.grpc.MethodDescriptor.generateFullMethodName; * */ @javax.annotation.Generated( - value = "by gRPC proto compiler (version 1.40.0-SNAPSHOT)", + value = "by gRPC proto compiler (version 1.41.0)", comments = "Source: grpc/testing/compiler/test.proto") @io.grpc.stub.annotations.GrpcGenerated @java.lang.Deprecated diff --git a/compiler/src/testLite/golden/TestService.java.txt b/compiler/src/testLite/golden/TestService.java.txt index b2481063ca6..bc1d50acecc 100644 --- a/compiler/src/testLite/golden/TestService.java.txt +++ b/compiler/src/testLite/golden/TestService.java.txt @@ -8,7 +8,7 @@ import static io.grpc.MethodDescriptor.generateFullMethodName; * */ @javax.annotation.Generated( - value = "by gRPC proto compiler (version 1.40.0-SNAPSHOT)", + value = "by gRPC proto compiler (version 1.41.0)", comments = "Source: grpc/testing/compiler/test.proto") @io.grpc.stub.annotations.GrpcGenerated public final class TestServiceGrpc { diff --git a/core/BUILD.bazel b/core/BUILD.bazel index c50e86a511c..60a08798d58 100644 --- a/core/BUILD.bazel +++ b/core/BUILD.bazel @@ -60,6 +60,7 @@ java_library( "@com_google_code_findbugs_jsr305//jar", "@com_google_guava_guava//jar", "@com_google_j2objc_j2objc_annotations//jar", + "@org_codehaus_mojo_animal_sniffer_annotations//jar", ], ) diff --git a/core/src/jmh/java/io/grpc/internal/StatsTraceContextBenchmark.java b/core/src/jmh/java/io/grpc/internal/StatsTraceContextBenchmark.java index aec2659f024..4d4349eef1b 100644 --- a/core/src/jmh/java/io/grpc/internal/StatsTraceContextBenchmark.java +++ b/core/src/jmh/java/io/grpc/internal/StatsTraceContextBenchmark.java @@ -17,7 +17,7 @@ package io.grpc.internal; import io.grpc.Attributes; -import io.grpc.CallOptions; +import io.grpc.ClientStreamTracer; import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.ServerStreamTracer; @@ -50,7 +50,8 @@ public class StatsTraceContextBenchmark { @BenchmarkMode(Mode.SampleTime) @OutputTimeUnit(TimeUnit.NANOSECONDS) public StatsTraceContext newClientContext() { - return StatsTraceContext.newClientContext(CallOptions.DEFAULT, Attributes.EMPTY, emptyMetadata); + return StatsTraceContext.newClientContext( + new ClientStreamTracer[] { new ClientStreamTracer() {} }, Attributes.EMPTY, emptyMetadata); } /** diff --git a/core/src/main/java/io/grpc/inprocess/InProcessTransport.java b/core/src/main/java/io/grpc/inprocess/InProcessTransport.java index 58df4371e72..895b709559b 100644 --- a/core/src/main/java/io/grpc/inprocess/InProcessTransport.java +++ b/core/src/main/java/io/grpc/inprocess/InProcessTransport.java @@ -26,6 +26,7 @@ import com.google.common.util.concurrent.SettableFuture; import io.grpc.Attributes; import io.grpc.CallOptions; +import io.grpc.ClientStreamTracer; import io.grpc.Compressor; import io.grpc.Deadline; import io.grpc.Decompressor; @@ -205,10 +206,12 @@ public void run() { @Override public synchronized ClientStream newStream( - final MethodDescriptor method, final Metadata headers, final CallOptions callOptions) { + MethodDescriptor method, Metadata headers, CallOptions callOptions, + ClientStreamTracer[] tracers) { + StatsTraceContext statsTraceContext = + StatsTraceContext.newClientContext(tracers, getAttributes(), headers); if (shutdownStatus != null) { - return failedClientStream( - StatsTraceContext.newClientContext(callOptions, attributes, headers), shutdownStatus); + return failedClientStream(statsTraceContext, shutdownStatus); } headers.put(GrpcUtil.USER_AGENT_KEY, userAgent); @@ -226,12 +229,12 @@ public synchronized ClientStream newStream( "Request metadata larger than %d: %d", serverMaxInboundMetadataSize, metadataSize)); - return failedClientStream( - StatsTraceContext.newClientContext(callOptions, attributes, headers), status); + return failedClientStream(statsTraceContext, status); } } - return new InProcessStream(method, headers, callOptions, authority).clientStream; + return new InProcessStream(method, headers, callOptions, authority, statsTraceContext) + .clientStream; } private ClientStream failedClientStream( @@ -377,12 +380,12 @@ private class InProcessStream { private InProcessStream( MethodDescriptor method, Metadata headers, CallOptions callOptions, - String authority) { + String authority , StatsTraceContext statsTraceContext) { this.method = checkNotNull(method, "method"); this.headers = checkNotNull(headers, "headers"); this.callOptions = checkNotNull(callOptions, "callOptions"); this.authority = authority; - this.clientStream = new InProcessClientStream(callOptions, headers); + this.clientStream = new InProcessClientStream(callOptions, statsTraceContext); this.serverStream = new InProcessServerStream(method, headers); } @@ -673,9 +676,10 @@ private class InProcessClientStream implements ClientStream { @GuardedBy("this") private int outboundSeqNo; - InProcessClientStream(CallOptions callOptions, Metadata headers) { + InProcessClientStream( + CallOptions callOptions, StatsTraceContext statsTraceContext) { this.callOptions = callOptions; - statsTraceCtx = StatsTraceContext.newClientContext(callOptions, attributes, headers); + statsTraceCtx = statsTraceContext; } private synchronized void setListener(ServerStreamListener listener) { diff --git a/core/src/main/java/io/grpc/internal/CallCredentialsApplyingTransportFactory.java b/core/src/main/java/io/grpc/internal/CallCredentialsApplyingTransportFactory.java index 0b1ce3514a2..6b6472825d2 100644 --- a/core/src/main/java/io/grpc/internal/CallCredentialsApplyingTransportFactory.java +++ b/core/src/main/java/io/grpc/internal/CallCredentialsApplyingTransportFactory.java @@ -25,6 +25,7 @@ import io.grpc.CallOptions; import io.grpc.ChannelCredentials; import io.grpc.ChannelLogger; +import io.grpc.ClientStreamTracer; import io.grpc.CompositeCallCredentials; import io.grpc.Metadata; import io.grpc.MethodDescriptor; @@ -104,7 +105,8 @@ protected ConnectionClientTransport delegate() { @Override @SuppressWarnings("deprecation") public ClientStream newStream( - final MethodDescriptor method, Metadata headers, final CallOptions callOptions) { + final MethodDescriptor method, Metadata headers, final CallOptions callOptions, + ClientStreamTracer[] tracers) { CallCredentials creds = callOptions.getCredentials(); if (creds == null) { creds = channelCallCredentials; @@ -113,10 +115,10 @@ public ClientStream newStream( } if (creds != null) { MetadataApplierImpl applier = new MetadataApplierImpl( - delegate, method, headers, callOptions, applierListener); + delegate, method, headers, callOptions, applierListener, tracers); if (pendingApplier.incrementAndGet() > 0) { applierListener.onComplete(); - return new FailingClientStream(shutdownStatus); + return new FailingClientStream(shutdownStatus, tracers); } RequestInfo requestInfo = new RequestInfo() { @Override @@ -152,9 +154,9 @@ public Attributes getTransportAttrs() { return applier.returnStream(); } else { if (pendingApplier.get() >= 0) { - return new FailingClientStream(shutdownStatus); + return new FailingClientStream(shutdownStatus, tracers); } - return delegate.newStream(method, headers, callOptions); + return delegate.newStream(method, headers, callOptions, tracers); } } diff --git a/core/src/main/java/io/grpc/internal/ClientCallImpl.java b/core/src/main/java/io/grpc/internal/ClientCallImpl.java index c2e1bd2b1f2..dd17244e2a5 100644 --- a/core/src/main/java/io/grpc/internal/ClientCallImpl.java +++ b/core/src/main/java/io/grpc/internal/ClientCallImpl.java @@ -33,6 +33,7 @@ import io.grpc.Attributes; import io.grpc.CallOptions; import io.grpc.ClientCall; +import io.grpc.ClientStreamTracer; import io.grpc.Codec; import io.grpc.Compressor; import io.grpc.CompressorRegistry; @@ -254,9 +255,12 @@ public void runInContext() { effectiveDeadline, context.getDeadline(), callOptions.getDeadline()); stream = clientStreamProvider.newStream(method, callOptions, headers, context); } else { + ClientStreamTracer[] tracers = + GrpcUtil.getClientStreamTracers(callOptions, headers, 0, false); stream = new FailingClientStream( DEADLINE_EXCEEDED.withDescription( - "ClientCall started after deadline exceeded: " + effectiveDeadline)); + "ClientCall started after deadline exceeded: " + effectiveDeadline), + tracers); } if (callExecutorIsDirect) { diff --git a/core/src/main/java/io/grpc/internal/ClientTransport.java b/core/src/main/java/io/grpc/internal/ClientTransport.java index cc8471ab6a3..a569a7922df 100644 --- a/core/src/main/java/io/grpc/internal/ClientTransport.java +++ b/core/src/main/java/io/grpc/internal/ClientTransport.java @@ -17,6 +17,7 @@ package io.grpc.internal; import io.grpc.CallOptions; +import io.grpc.ClientStreamTracer; import io.grpc.InternalChannelz.SocketStats; import io.grpc.InternalInstrumented; import io.grpc.Metadata; @@ -46,10 +47,15 @@ public interface ClientTransport extends InternalInstrumented { * @param method the descriptor of the remote method to be called for this stream. * @param headers to send at the beginning of the call * @param callOptions runtime options of the call + * @param tracers a non-empty array of tracers. The last element in it is reserved to be set by + * the load balancer's pick result and otherwise is a no-op tracer. * @return the newly created stream. */ // TODO(nmittler): Consider also throwing for stopping. - ClientStream newStream(MethodDescriptor method, Metadata headers, CallOptions callOptions); + ClientStream newStream( + MethodDescriptor method, Metadata headers, CallOptions callOptions, + // Using array for tracers instead of a list or composition for better performance. + ClientStreamTracer[] tracers); /** * Pings a remote endpoint. When an acknowledgement is received, the given callback will be diff --git a/core/src/main/java/io/grpc/internal/DelayedClientTransport.java b/core/src/main/java/io/grpc/internal/DelayedClientTransport.java index 6a72eb7c21e..2b1145d1c4b 100644 --- a/core/src/main/java/io/grpc/internal/DelayedClientTransport.java +++ b/core/src/main/java/io/grpc/internal/DelayedClientTransport.java @@ -20,6 +20,7 @@ import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; import io.grpc.CallOptions; +import io.grpc.ClientStreamTracer; import io.grpc.Context; import io.grpc.InternalChannelz.SocketStats; import io.grpc.InternalLogId; @@ -133,7 +134,8 @@ public void run() { */ @Override public final ClientStream newStream( - MethodDescriptor method, Metadata headers, CallOptions callOptions) { + MethodDescriptor method, Metadata headers, CallOptions callOptions, + ClientStreamTracer[] tracers) { try { PickSubchannelArgs args = new PickSubchannelArgsImpl(method, headers, callOptions); SubchannelPicker picker = null; @@ -141,14 +143,14 @@ public final ClientStream newStream( while (true) { synchronized (lock) { if (shutdownStatus != null) { - return new FailingClientStream(shutdownStatus); + return new FailingClientStream(shutdownStatus, tracers); } if (lastPicker == null) { - return createPendingStream(args); + return createPendingStream(args, tracers); } // Check for second time through the loop, and whether anything changed if (picker != null && pickerVersion == lastPickerVersion) { - return createPendingStream(args); + return createPendingStream(args, tracers); } picker = lastPicker; pickerVersion = lastPickerVersion; @@ -158,7 +160,8 @@ public final ClientStream newStream( callOptions.isWaitForReady()); if (transport != null) { return transport.newStream( - args.getMethodDescriptor(), args.getHeaders(), args.getCallOptions()); + args.getMethodDescriptor(), args.getHeaders(), args.getCallOptions(), + tracers); } // This picker's conclusion is "buffer". If there hasn't been a newer picker set (possible // race with reprocess()), we will buffer it. Otherwise, will try with the new picker. @@ -173,8 +176,9 @@ public final ClientStream newStream( * schedule tasks on syncContext. */ @GuardedBy("lock") - private PendingStream createPendingStream(PickSubchannelArgs args) { - PendingStream pendingStream = new PendingStream(args); + private PendingStream createPendingStream( + PickSubchannelArgs args, ClientStreamTracer[] tracers) { + PendingStream pendingStream = new PendingStream(args, tracers); pendingStreams.add(pendingStream); if (getPendingStreamsCount() == 1) { syncContext.executeLater(reportTransportInUse); @@ -239,7 +243,8 @@ public final void shutdownNow(Status status) { } if (savedReportTransportTerminated != null) { for (PendingStream stream : savedPendingStreams) { - Runnable runnable = stream.setStream(new FailingClientStream(status, RpcProgress.REFUSED)); + Runnable runnable = stream.setStream( + new FailingClientStream(status, RpcProgress.REFUSED, stream.tracers)); if (runnable != null) { // Drain in-line instead of using an executor as failing stream just throws everything // away. This is essentially the same behavior as DelayedStream.cancel() but can be done @@ -346,9 +351,11 @@ public InternalLogId getLogId() { private class PendingStream extends DelayedStream { private final PickSubchannelArgs args; private final Context context = Context.current(); + private final ClientStreamTracer[] tracers; - private PendingStream(PickSubchannelArgs args) { + private PendingStream(PickSubchannelArgs args, ClientStreamTracer[] tracers) { this.args = args; + this.tracers = tracers; } /** Runnable may be null. */ @@ -357,7 +364,8 @@ private Runnable createRealStream(ClientTransport transport) { Context origContext = context.attach(); try { realStream = transport.newStream( - args.getMethodDescriptor(), args.getHeaders(), args.getCallOptions()); + args.getMethodDescriptor(), args.getHeaders(), args.getCallOptions(), + tracers); } finally { context.detach(origContext); } @@ -382,6 +390,13 @@ public void cancel(Status reason) { syncContext.drain(); } + @Override + protected void onEarlyCancellation(Status reason) { + for (ClientStreamTracer tracer : tracers) { + tracer.streamClosed(reason); + } + } + @Override public void appendTimeoutInsight(InsightBuilder insight) { if (args.getCallOptions().isWaitForReady()) { diff --git a/core/src/main/java/io/grpc/internal/DelayedStream.java b/core/src/main/java/io/grpc/internal/DelayedStream.java index f0a378e8124..28ce2764c75 100644 --- a/core/src/main/java/io/grpc/internal/DelayedStream.java +++ b/core/src/main/java/io/grpc/internal/DelayedStream.java @@ -324,11 +324,15 @@ public void run() { }); } else { drainPendingCalls(); + onEarlyCancellation(reason); // Note that listener is a DelayedStreamListener listener.closed(reason, RpcProgress.PROCESSED, new Metadata()); } } + protected void onEarlyCancellation(Status reason) { + } + @GuardedBy("this") private void setRealStream(ClientStream realStream) { checkState(this.realStream == null, "realStream already set to %s", this.realStream); diff --git a/core/src/main/java/io/grpc/internal/FailingClientStream.java b/core/src/main/java/io/grpc/internal/FailingClientStream.java index 6d368b6975f..6388ef8b6ee 100644 --- a/core/src/main/java/io/grpc/internal/FailingClientStream.java +++ b/core/src/main/java/io/grpc/internal/FailingClientStream.java @@ -18,6 +18,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; +import io.grpc.ClientStreamTracer; import io.grpc.Metadata; import io.grpc.Status; import io.grpc.internal.ClientStreamListener.RpcProgress; @@ -30,27 +31,33 @@ public final class FailingClientStream extends NoopClientStream { private boolean started; private final Status error; private final RpcProgress rpcProgress; + private final ClientStreamTracer[] tracers; /** * Creates a {@code FailingClientStream} that would fail with the given error. */ - public FailingClientStream(Status error) { - this(error, RpcProgress.PROCESSED); + public FailingClientStream(Status error, ClientStreamTracer[] tracers) { + this(error, RpcProgress.PROCESSED, tracers); } /** * Creates a {@code FailingClientStream} that would fail with the given error. */ - public FailingClientStream(Status error, RpcProgress rpcProgress) { + public FailingClientStream( + Status error, RpcProgress rpcProgress, ClientStreamTracer[] tracers) { Preconditions.checkArgument(!error.isOk(), "error must not be OK"); this.error = error; this.rpcProgress = rpcProgress; + this.tracers = tracers; } @Override public void start(ClientStreamListener listener) { Preconditions.checkState(!started, "already started"); started = true; + for (ClientStreamTracer tracer : tracers) { + tracer.streamClosed(error); + } listener.closed(error, rpcProgress, new Metadata()); } diff --git a/core/src/main/java/io/grpc/internal/FailingClientTransport.java b/core/src/main/java/io/grpc/internal/FailingClientTransport.java index 25d20017c92..5b31e6e5073 100644 --- a/core/src/main/java/io/grpc/internal/FailingClientTransport.java +++ b/core/src/main/java/io/grpc/internal/FailingClientTransport.java @@ -21,6 +21,7 @@ import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; import io.grpc.CallOptions; +import io.grpc.ClientStreamTracer; import io.grpc.InternalChannelz.SocketStats; import io.grpc.InternalLogId; import io.grpc.Metadata; @@ -45,8 +46,9 @@ class FailingClientTransport implements ClientTransport { @Override public ClientStream newStream( - MethodDescriptor method, Metadata headers, CallOptions callOptions) { - return new FailingClientStream(error, rpcProgress); + MethodDescriptor method, Metadata headers, CallOptions callOptions, + ClientStreamTracer[] tracers) { + return new FailingClientStream(error, rpcProgress, tracers); } @Override diff --git a/core/src/main/java/io/grpc/internal/ForwardingClientStreamTracer.java b/core/src/main/java/io/grpc/internal/ForwardingClientStreamTracer.java new file mode 100644 index 00000000000..fd03564d396 --- /dev/null +++ b/core/src/main/java/io/grpc/internal/ForwardingClientStreamTracer.java @@ -0,0 +1,101 @@ +/* + * Copyright 2021 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +import com.google.common.base.MoreObjects; +import io.grpc.Attributes; +import io.grpc.ClientStreamTracer; +import io.grpc.Metadata; +import io.grpc.Status; + +public abstract class ForwardingClientStreamTracer extends ClientStreamTracer { + + /** + * Returns the underlying {@code ClientStreamTracer}. + */ + protected abstract ClientStreamTracer delegate(); + + @Override + public void streamCreated(Attributes transportAttrs, Metadata headers) { + delegate().streamCreated(transportAttrs, headers); + } + + @Override + public void outboundHeaders() { + delegate().outboundHeaders(); + } + + @Override + public void inboundHeaders() { + delegate().inboundHeaders(); + } + + @Override + public void inboundTrailers(Metadata trailers) { + delegate().inboundTrailers(trailers); + } + + @Override + public void streamClosed(Status status) { + delegate().streamClosed(status); + } + + @Override + public void outboundMessage(int seqNo) { + delegate().outboundMessage(seqNo); + } + + @Override + public void inboundMessage(int seqNo) { + delegate().inboundMessage(seqNo); + } + + @Override + public void outboundMessageSent(int seqNo, long optionalWireSize, long optionalUncompressedSize) { + delegate().outboundMessageSent(seqNo, optionalWireSize, optionalUncompressedSize); + } + + @Override + public void inboundMessageRead(int seqNo, long optionalWireSize, long optionalUncompressedSize) { + delegate().inboundMessageRead(seqNo, optionalWireSize, optionalUncompressedSize); + } + + @Override + public void outboundWireSize(long bytes) { + delegate().outboundWireSize(bytes); + } + + @Override + public void outboundUncompressedSize(long bytes) { + delegate().outboundUncompressedSize(bytes); + } + + @Override + public void inboundWireSize(long bytes) { + delegate().inboundWireSize(bytes); + } + + @Override + public void inboundUncompressedSize(long bytes) { + delegate().inboundUncompressedSize(bytes); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this).add("delegate", delegate()).toString(); + } +} diff --git a/core/src/main/java/io/grpc/internal/ForwardingConnectionClientTransport.java b/core/src/main/java/io/grpc/internal/ForwardingConnectionClientTransport.java index e54f8b169d6..bfdccbe5d6a 100644 --- a/core/src/main/java/io/grpc/internal/ForwardingConnectionClientTransport.java +++ b/core/src/main/java/io/grpc/internal/ForwardingConnectionClientTransport.java @@ -20,6 +20,7 @@ import com.google.common.util.concurrent.ListenableFuture; import io.grpc.Attributes; import io.grpc.CallOptions; +import io.grpc.ClientStreamTracer; import io.grpc.InternalChannelz.SocketStats; import io.grpc.InternalLogId; import io.grpc.Metadata; @@ -45,8 +46,9 @@ public void shutdownNow(Status status) { @Override public ClientStream newStream( - MethodDescriptor method, Metadata headers, CallOptions callOptions) { - return delegate().newStream(method, headers, callOptions); + MethodDescriptor method, Metadata headers, CallOptions callOptions, + ClientStreamTracer[] tracers) { + return delegate().newStream(method, headers, callOptions, tracers); } @Override diff --git a/core/src/main/java/io/grpc/internal/GrpcUtil.java b/core/src/main/java/io/grpc/internal/GrpcUtil.java index 45c0fce7122..12ae8954ce5 100644 --- a/core/src/main/java/io/grpc/internal/GrpcUtil.java +++ b/core/src/main/java/io/grpc/internal/GrpcUtil.java @@ -17,6 +17,7 @@ package io.grpc.internal; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Objects; @@ -26,8 +27,11 @@ import com.google.common.base.Supplier; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.ThreadFactoryBuilder; +import io.grpc.Attributes; import io.grpc.CallOptions; import io.grpc.ClientStreamTracer; +import io.grpc.ClientStreamTracer.InternalLimitedInfoFactory; +import io.grpc.ClientStreamTracer.StreamInfo; import io.grpc.InternalChannelz.SocketStats; import io.grpc.InternalLogId; import io.grpc.InternalMetadata; @@ -54,6 +58,7 @@ import java.net.URISyntaxException; import java.nio.charset.Charset; import java.util.Collection; +import java.util.List; import java.util.concurrent.Executor; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @@ -197,7 +202,7 @@ public byte[] parseAsciiString(byte[] serialized) { public static final Splitter ACCEPT_ENCODING_SPLITTER = Splitter.on(',').trimResults(); - private static final String IMPLEMENTATION_VERSION = "1.40.0-SNAPSHOT"; // CURRENT_GRPC_VERSION + private static final String IMPLEMENTATION_VERSION = "1.41.0"; // CURRENT_GRPC_VERSION /** * The default timeout in nanos for a keepalive ping request. @@ -253,6 +258,8 @@ public ProxiedSocketAddress proxyFor(SocketAddress targetServerAddress) { public static final CallOptions.Key CALL_OPTIONS_RPC_OWNED_BY_BALANCER = CallOptions.Key.create("io.grpc.internal.CALL_OPTIONS_RPC_OWNED_BY_BALANCER"); + private static final ClientStreamTracer NOOP_TRACER = new ClientStreamTracer() {}; + /** * Returns true if an RPC with the given properties should be counted when calculating the * in-use state of a transport. @@ -711,9 +718,14 @@ static ClientTransport getTransportFromPickResult(PickResult result, boolean isW return new ClientTransport() { @Override public ClientStream newStream( - MethodDescriptor method, Metadata headers, CallOptions callOptions) { - return transport.newStream( - method, headers, callOptions.withStreamTracerFactory(streamTracerFactory)); + MethodDescriptor method, Metadata headers, CallOptions callOptions, + ClientStreamTracer[] tracers) { + StreamInfo info = StreamInfo.newBuilder().setCallOptions(callOptions).build(); + ClientStreamTracer streamTracer = + newClientStreamTracer(streamTracerFactory, info, headers); + checkState(tracers[tracers.length - 1] == NOOP_TRACER, "lb tracer already assigned"); + tracers[tracers.length - 1] = streamTracer; + return transport.newStream(method, headers, callOptions, tracers); } @Override @@ -743,6 +755,72 @@ public ListenableFuture getStats() { return null; } + /** Gets stream tracers based on CallOptions. */ + public static ClientStreamTracer[] getClientStreamTracers( + CallOptions callOptions, Metadata headers, int previousAttempts, boolean isTransparentRetry) { + List factories = callOptions.getStreamTracerFactories(); + ClientStreamTracer[] tracers = new ClientStreamTracer[factories.size() + 1]; + StreamInfo streamInfo = StreamInfo.newBuilder() + .setCallOptions(callOptions) + .setPreviousAttempts(previousAttempts) + .setIsTransparentRetry(isTransparentRetry) + .build(); + for (int i = 0; i < factories.size(); i++) { + tracers[i] = newClientStreamTracer(factories.get(i), streamInfo, headers); + } + // Reserved to be set later by the lb as per the API contract of ClientTransport.newStream(). + // See also GrpcUtil.getTransportFromPickResult() + tracers[tracers.length - 1] = NOOP_TRACER; + return tracers; + } + + // A util function for backward compatibility to support deprecated StreamInfo.getAttributes(). + @VisibleForTesting + static ClientStreamTracer newClientStreamTracer( + final ClientStreamTracer.Factory streamTracerFactory, final StreamInfo info, + final Metadata headers) { + ClientStreamTracer streamTracer; + if (streamTracerFactory instanceof InternalLimitedInfoFactory) { + streamTracer = streamTracerFactory.newClientStreamTracer(info, headers); + } else { + streamTracer = new ForwardingClientStreamTracer() { + final ClientStreamTracer noop = new ClientStreamTracer() {}; + volatile ClientStreamTracer delegate = noop; + + void maybeInit(StreamInfo info, Metadata headers) { + if (delegate != noop) { + return; + } + synchronized (this) { + if (delegate == noop) { + delegate = streamTracerFactory.newClientStreamTracer(info, headers); + } + } + } + + @Override + protected ClientStreamTracer delegate() { + return delegate; + } + + @SuppressWarnings("deprecation") + @Override + public void streamCreated(Attributes transportAttrs, Metadata headers) { + StreamInfo streamInfo = info.toBuilder().setTransportAttrs(transportAttrs).build(); + maybeInit(streamInfo, headers); + delegate().streamCreated(transportAttrs, headers); + } + + @Override + public void streamClosed(Status status) { + maybeInit(info, headers); + delegate().streamClosed(status); + } + }; + } + return streamTracer; + } + /** Quietly closes all messages in MessageProducer. */ static void closeQuietly(MessageProducer producer) { InputStream message; diff --git a/core/src/main/java/io/grpc/internal/InUseStateAggregator.java b/core/src/main/java/io/grpc/internal/InUseStateAggregator.java index f4f3a186d88..f3d870e8797 100644 --- a/core/src/main/java/io/grpc/internal/InUseStateAggregator.java +++ b/core/src/main/java/io/grpc/internal/InUseStateAggregator.java @@ -53,6 +53,21 @@ public final boolean isInUse() { return !inUseObjects.isEmpty(); } + /** + * Returns {@code true} if any of the given objects are in use. + * + * @param objects The objects to consider. + * @return {@code true} if any of the given objects are in use. + */ + public final boolean anyObjectInUse(Object... objects) { + for (Object object : objects) { + if (inUseObjects.contains(object)) { + return true; + } + } + return false; + } + /** * Called when the aggregated in-use state has changed to true, which means at least one object is * in use. diff --git a/core/src/main/java/io/grpc/internal/InternalSubchannel.java b/core/src/main/java/io/grpc/internal/InternalSubchannel.java index 331add6c8c4..fa2bf2e46bc 100644 --- a/core/src/main/java/io/grpc/internal/InternalSubchannel.java +++ b/core/src/main/java/io/grpc/internal/InternalSubchannel.java @@ -34,6 +34,7 @@ import io.grpc.CallOptions; import io.grpc.ChannelLogger; import io.grpc.ChannelLogger.ChannelLogLevel; +import io.grpc.ClientStreamTracer; import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; import io.grpc.EquivalentAddressGroup; @@ -667,8 +668,9 @@ protected ConnectionClientTransport delegate() { @Override public ClientStream newStream( - MethodDescriptor method, Metadata headers, CallOptions callOptions) { - final ClientStream streamDelegate = super.newStream(method, headers, callOptions); + MethodDescriptor method, Metadata headers, CallOptions callOptions, + ClientStreamTracer[] tracers) { + final ClientStream streamDelegate = super.newStream(method, headers, callOptions, tracers); return new ForwardingClientStream() { @Override protected ClientStream delegate() { diff --git a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java index a9d24cd247a..2e079078fc7 100644 --- a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java +++ b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java @@ -423,7 +423,10 @@ private void enterIdleMode() { delayedTransport.reprocess(null); channelLogger.log(ChannelLogLevel.INFO, "Entering IDLE state"); channelStateManager.gotoState(IDLE); - if (inUseStateAggregator.isInUse()) { + // If the inUseStateAggregator still considers pending calls to be queued up or the delayed + // transport to be holding some we need to exit idle mode to give these calls a chance to + // be processed. + if (inUseStateAggregator.anyObjectInUse(pendingCallsInUseObject, delayedTransport)) { exitIdleMode(); } } @@ -532,8 +535,10 @@ public ClientStream newStream( ClientTransport transport = getTransport(new PickSubchannelArgsImpl(method, headers, callOptions)); Context origContext = context.attach(); + ClientStreamTracer[] tracers = GrpcUtil.getClientStreamTracers( + callOptions, headers, 0, /* isTransparentRetry= */ false); try { - return transport.newStream(method, headers, callOptions); + return transport.newStream(method, headers, callOptions, tracers); } finally { context.detach(origContext); } @@ -569,13 +574,17 @@ void postCommit() { } @Override - ClientStream newSubstream(ClientStreamTracer.Factory tracerFactory, Metadata newHeaders) { - CallOptions newOptions = callOptions.withStreamTracerFactory(tracerFactory); + ClientStream newSubstream( + Metadata newHeaders, ClientStreamTracer.Factory factory, int previousAttempts, + boolean isTransparentRetry) { + CallOptions newOptions = callOptions.withStreamTracerFactory(factory); + ClientStreamTracer[] tracers = GrpcUtil.getClientStreamTracers( + newOptions, newHeaders, previousAttempts, isTransparentRetry); ClientTransport transport = getTransport(new PickSubchannelArgsImpl(method, newHeaders, newOptions)); Context origContext = context.attach(); try { - return transport.newStream(method, newHeaders, newOptions); + return transport.newStream(method, newHeaders, newOptions, tracers); } finally { context.detach(origContext); } @@ -619,7 +628,7 @@ ClientStream newSubstream(ClientStreamTracer.Factory tracerFactory, Metadata new channelLogger = new ChannelLoggerImpl(channelTracer, timeProvider); ProxyDetector proxyDetector = builder.proxyDetector != null ? builder.proxyDetector : GrpcUtil.DEFAULT_PROXY_DETECTOR; - this.retryEnabled = builder.retryEnabled && !builder.temporarilyDisableRetry; + this.retryEnabled = builder.retryEnabled; this.loadBalancerFactory = new AutoConfiguredLoadBalancerFactory(builder.defaultLbPolicy); this.offloadExecutorHolder = new ExecutorHolder( diff --git a/core/src/main/java/io/grpc/internal/ManagedChannelImplBuilder.java b/core/src/main/java/io/grpc/internal/ManagedChannelImplBuilder.java index d42b3832136..26c48fc8596 100644 --- a/core/src/main/java/io/grpc/internal/ManagedChannelImplBuilder.java +++ b/core/src/main/java/io/grpc/internal/ManagedChannelImplBuilder.java @@ -142,11 +142,7 @@ public static ManagedChannelBuilder forTarget(String target) { int maxHedgedAttempts = 5; long retryBufferSize = DEFAULT_RETRY_BUFFER_SIZE_IN_BYTES; long perRpcBufferLimit = DEFAULT_PER_RPC_BUFFER_LIMIT_IN_BYTES; - boolean retryEnabled = false; // TODO(zdapeng): default to true - // Temporarily disable retry when stats or tracing is enabled to avoid breakage, until we know - // what should be the desired behavior for retry + stats/tracing. - // TODO(zdapeng): delete me - boolean temporarilyDisableRetry; + boolean retryEnabled = true; InternalChannelz channelz = InternalChannelz.instance(); int maxTraceEvents; @@ -460,8 +456,6 @@ public ManagedChannelImplBuilder disableRetry() { @Override public ManagedChannelImplBuilder enableRetry() { retryEnabled = true; - statsEnabled = false; - tracingEnabled = false; return this; } @@ -592,9 +586,6 @@ public void setStatsRecordRealTimeMetrics(boolean value) { /** * Disable or enable tracing features. Enabled by default. - * - *

For the current release, calling {@code setTracingEnabled(true)} may have a side effect that - * disables retry. */ public void setTracingEnabled(boolean value) { tracingEnabled = value; @@ -642,9 +633,7 @@ public ManagedChannel build() { List getEffectiveInterceptors() { List effectiveInterceptors = new ArrayList<>(this.interceptors); - temporarilyDisableRetry = false; if (statsEnabled) { - temporarilyDisableRetry = true; ClientInterceptor statsInterceptor = null; try { Class censusStatsAccessor = @@ -679,7 +668,6 @@ List getEffectiveInterceptors() { } } if (tracingEnabled) { - temporarilyDisableRetry = true; ClientInterceptor tracingInterceptor = null; try { Class censusTracingAccessor = diff --git a/core/src/main/java/io/grpc/internal/MessageDeframer.java b/core/src/main/java/io/grpc/internal/MessageDeframer.java index 9a523746e50..534398315e8 100644 --- a/core/src/main/java/io/grpc/internal/MessageDeframer.java +++ b/core/src/main/java/io/grpc/internal/MessageDeframer.java @@ -517,8 +517,8 @@ private void reportCount() { private void verifySize() { if (count > maxMessageSize) { throw Status.RESOURCE_EXHAUSTED.withDescription(String.format( - "Compressed gRPC message exceeds maximum size %d: %d bytes read", - maxMessageSize, count)).asRuntimeException(); + "Decompressed gRPC message exceeds maximum size %d", + maxMessageSize)).asRuntimeException(); } } } diff --git a/core/src/main/java/io/grpc/internal/MessageFramer.java b/core/src/main/java/io/grpc/internal/MessageFramer.java index 83592e691a9..2042bddca03 100644 --- a/core/src/main/java/io/grpc/internal/MessageFramer.java +++ b/core/src/main/java/io/grpc/internal/MessageFramer.java @@ -267,7 +267,7 @@ private static int writeToOutputStream(InputStream message, OutputStream outputS return ((Drainable) message).drainTo(outputStream); } else { // This makes an unnecessary copy of the bytes when bytebuf supports array(). However, we - // expect performance-critical code to support flushTo(). + // expect performance-critical code to support drainTo(). @SuppressWarnings("BetaApi") // ByteStreams is not Beta in v27 long written = ByteStreams.copy(message, outputStream); checkArgument(written <= Integer.MAX_VALUE, "Message size overflow: %s", written); diff --git a/core/src/main/java/io/grpc/internal/MetadataApplierImpl.java b/core/src/main/java/io/grpc/internal/MetadataApplierImpl.java index 76d280b2d00..6893713c1d2 100644 --- a/core/src/main/java/io/grpc/internal/MetadataApplierImpl.java +++ b/core/src/main/java/io/grpc/internal/MetadataApplierImpl.java @@ -22,6 +22,7 @@ import io.grpc.CallCredentials.MetadataApplier; import io.grpc.CallOptions; +import io.grpc.ClientStreamTracer; import io.grpc.Context; import io.grpc.Metadata; import io.grpc.MethodDescriptor; @@ -36,7 +37,7 @@ final class MetadataApplierImpl extends MetadataApplier { private final CallOptions callOptions; private final Context ctx; private final MetadataApplierListener listener; - + private final ClientStreamTracer[] tracers; private final Object lock = new Object(); // null if neither apply() or returnStream() are called. @@ -52,13 +53,14 @@ final class MetadataApplierImpl extends MetadataApplier { MetadataApplierImpl( ClientTransport transport, MethodDescriptor method, Metadata origHeaders, - CallOptions callOptions, MetadataApplierListener listener) { + CallOptions callOptions, MetadataApplierListener listener, ClientStreamTracer[] tracers) { this.transport = transport; this.method = method; this.origHeaders = origHeaders; this.callOptions = callOptions; this.ctx = Context.current(); this.listener = listener; + this.tracers = tracers; } @Override @@ -69,7 +71,7 @@ public void apply(Metadata headers) { ClientStream realStream; Context origCtx = ctx.attach(); try { - realStream = transport.newStream(method, origHeaders, callOptions); + realStream = transport.newStream(method, origHeaders, callOptions, tracers); } finally { ctx.detach(origCtx); } @@ -80,7 +82,7 @@ public void apply(Metadata headers) { public void fail(Status status) { checkArgument(!status.isOk(), "Cannot fail with OK status"); checkState(!finalized, "apply() or fail() already called"); - finalizeWith(new FailingClientStream(status)); + finalizeWith(new FailingClientStream(status, tracers)); } private void finalizeWith(ClientStream stream) { diff --git a/core/src/main/java/io/grpc/internal/OobChannel.java b/core/src/main/java/io/grpc/internal/OobChannel.java index f69fd17e5c4..589824ae10e 100644 --- a/core/src/main/java/io/grpc/internal/OobChannel.java +++ b/core/src/main/java/io/grpc/internal/OobChannel.java @@ -26,6 +26,7 @@ import io.grpc.Attributes; import io.grpc.CallOptions; import io.grpc.ClientCall; +import io.grpc.ClientStreamTracer; import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; import io.grpc.Context; @@ -86,12 +87,14 @@ final class OobChannel extends ManagedChannel implements InternalInstrumented method, CallOptions callOptions, Metadata headers, Context context) { + ClientStreamTracer[] tracers = GrpcUtil.getClientStreamTracers( + callOptions, headers, 0, /* isTransparentRetry= */ false); Context origContext = context.attach(); // delayed transport's newStream() always acquires a lock, but concurrent performance doesn't // matter here because OOB communication should be sparse, and it's not on application RPC's // critical path. try { - return delayedTransport.newStream(method, headers, callOptions); + return delayedTransport.newStream(method, headers, callOptions, tracers); } finally { context.detach(origContext); } diff --git a/core/src/main/java/io/grpc/internal/RetriableStream.java b/core/src/main/java/io/grpc/internal/RetriableStream.java index 9d752b86576..1fb8d3c43bd 100644 --- a/core/src/main/java/io/grpc/internal/RetriableStream.java +++ b/core/src/main/java/io/grpc/internal/RetriableStream.java @@ -30,8 +30,10 @@ import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.Status; +import io.grpc.SynchronizationContext; import io.grpc.internal.ClientStreamListener.RpcProgress; import java.io.InputStream; +import java.lang.Thread.UncaughtExceptionHandler; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; @@ -64,6 +66,16 @@ abstract class RetriableStream implements ClientStream { private final MethodDescriptor method; private final Executor callExecutor; + private final Executor listenerSerializeExecutor = new SynchronizationContext( + new UncaughtExceptionHandler() { + @Override + public void uncaughtException(Thread t, Throwable e) { + throw Status.fromThrowable(e) + .withDescription("Uncaught exception in the SynchronizationContext. Re-thrown.") + .asRuntimeException(); + } + } + ); private final ScheduledExecutorService scheduledExecutorService; // Must not modify it. private final Metadata headers; @@ -104,6 +116,8 @@ abstract class RetriableStream implements ClientStream { @GuardedBy("lock") private FutureCanceller scheduledHedging; private long nextBackoffIntervalNanos; + private Status cancellationStatus; + private boolean isClosed; RetriableStream( MethodDescriptor method, Metadata headers, @@ -203,11 +217,11 @@ private void commitAndRun(Substream winningSubstream) { } } - private Substream createSubstream(int previousAttemptCount) { + private Substream createSubstream(int previousAttemptCount, boolean isTransparentRetry) { Substream sub = new Substream(previousAttemptCount); // one tracer per substream final ClientStreamTracer bufferSizeTracer = new BufferSizeTracer(sub); - ClientStreamTracer.Factory tracerFactory = new ClientStreamTracer.Factory() { + ClientStreamTracer.Factory tracerFactory = new ClientStreamTracer.InternalLimitedInfoFactory() { @Override public ClientStreamTracer newClientStreamTracer( ClientStreamTracer.StreamInfo info, Metadata headers) { @@ -217,7 +231,7 @@ public ClientStreamTracer newClientStreamTracer( Metadata newHeaders = updateHeaders(headers, previousAttemptCount); // NOTICE: This set _must_ be done before stream.start() and it actually is. - sub.stream = newSubstream(tracerFactory, newHeaders); + sub.stream = newSubstream(newHeaders, tracerFactory, previousAttemptCount, isTransparentRetry); return sub; } @@ -226,7 +240,8 @@ public ClientStreamTracer newClientStreamTracer( * Client stream is not yet started. */ abstract ClientStream newSubstream( - ClientStreamTracer.Factory tracerFactory, Metadata headers); + Metadata headers, ClientStreamTracer.Factory tracerFactory, int previousAttempts, + boolean isTransparentRetry); /** Adds grpc-previous-rpc-attempts in the headers of a retry/hedging RPC. */ @VisibleForTesting @@ -244,19 +259,37 @@ private void drain(Substream substream) { int index = 0; int chunk = 0x80; List list = null; + boolean streamStarted = false; + Runnable onReadyRunnable = null; while (true) { State savedState; synchronized (lock) { savedState = state; - if (savedState.winningSubstream != null && savedState.winningSubstream != substream) { - // committed but not me - break; + if (streamStarted) { + if (savedState.winningSubstream != null && savedState.winningSubstream != substream) { + // committed but not me, to be cancelled + break; + } + if (savedState.cancelled) { + break; + } } if (index == savedState.buffer.size()) { // I'm drained state = savedState.substreamDrained(substream); - return; + if (!isReady()) { + return; + } + onReadyRunnable = new Runnable() { + @Override + public void run() { + if (!isClosed) { + masterListener.onReady(); + } + } + }; + break; } if (substream.closed) { @@ -274,22 +307,30 @@ private void drain(Substream substream) { } for (BufferEntry bufferEntry : list) { - savedState = state; - if (savedState.winningSubstream != null && savedState.winningSubstream != substream) { - // committed but not me - break; + bufferEntry.runWith(substream); + if (bufferEntry instanceof RetriableStream.StartEntry) { + streamStarted = true; } - if (savedState.cancelled) { - checkState( - savedState.winningSubstream == substream, - "substream should be CANCELLED_BECAUSE_COMMITTED already"); - return; + if (streamStarted) { + savedState = state; + if (savedState.winningSubstream != null && savedState.winningSubstream != substream) { + // committed but not me, to be cancelled + break; + } + if (savedState.cancelled) { + break; + } } - bufferEntry.runWith(substream); } } - substream.stream.cancel(CANCELLED_BECAUSE_COMMITTED); + if (onReadyRunnable != null) { + listenerSerializeExecutor.execute(onReadyRunnable); + return; + } + + substream.stream.cancel( + state.winningSubstream == substream ? cancellationStatus : CANCELLED_BECAUSE_COMMITTED); } /** @@ -299,6 +340,13 @@ private void drain(Substream substream) { @Nullable abstract Status prestart(); + class StartEntry implements BufferEntry { + @Override + public void runWith(Substream substream) { + substream.stream.start(new Sublistener(substream)); + } + } + /** Starts the first PRC attempt. */ @Override public final void start(ClientStreamListener listener) { @@ -311,18 +359,11 @@ public final void start(ClientStreamListener listener) { return; } - class StartEntry implements BufferEntry { - @Override - public void runWith(Substream substream) { - substream.stream.start(new Sublistener(substream)); - } - } - synchronized (lock) { state.buffer.add(new StartEntry()); } - Substream substream = createSubstream(0); + Substream substream = createSubstream(0, false); if (isHedging) { FutureCanceller scheduledHedgingRef = null; @@ -399,7 +440,7 @@ public void run() { // If this run is not cancelled, the value of state.hedgingAttemptCount won't change // until state.addActiveHedge() is called subsequently, even the state could possibly // change. - Substream newSubstream = createSubstream(state.hedgingAttemptCount); + Substream newSubstream = createSubstream(state.hedgingAttemptCount, false); boolean cancelled = false; FutureCanceller future = null; @@ -439,22 +480,37 @@ public void run() { } @Override - public final void cancel(Status reason) { + public final void cancel(final Status reason) { Substream noopSubstream = new Substream(0 /* previousAttempts doesn't matter here */); noopSubstream.stream = new NoopClientStream(); Runnable runnable = commit(noopSubstream); if (runnable != null) { - masterListener.closed(reason, RpcProgress.PROCESSED, new Metadata()); runnable.run(); + listenerSerializeExecutor.execute( + new Runnable() { + @Override + public void run() { + isClosed = true; + masterListener.closed(reason, RpcProgress.PROCESSED, new Metadata()); + + } + }); return; } - state.winningSubstream.stream.cancel(reason); + Substream winningSubstreamToCancel = null; synchronized (lock) { - // This is not required, but causes a short-circuit in the draining process. + if (state.drainedSubstreams.contains(state.winningSubstream)) { + winningSubstreamToCancel = state.winningSubstream; + } else { // the winningSubstream will be cancelled while draining + cancellationStatus = reason; + } state = state.cancelled(); } + if (winningSubstreamToCancel != null) { + winningSubstreamToCancel.stream.cancel(reason); + } } private void delayOrExecute(BufferEntry bufferEntry) { @@ -753,18 +809,25 @@ private final class Sublistener implements ClientStreamListener { } @Override - public void headersRead(Metadata headers) { + public void headersRead(final Metadata headers) { commitAndRun(substream); if (state.winningSubstream == substream) { - masterListener.headersRead(headers); if (throttle != null) { throttle.onSuccess(); } + listenerSerializeExecutor.execute( + new Runnable() { + @Override + public void run() { + masterListener.headersRead(headers); + } + }); } } @Override - public void closed(Status status, RpcProgress rpcProgress, Metadata trailers) { + public void closed( + final Status status, final RpcProgress rpcProgress, final Metadata trailers) { synchronized (lock) { state = state.substreamClosed(substream); closedSubstreamsInsight.append(status.getCode()); @@ -775,7 +838,14 @@ public void closed(Status status, RpcProgress rpcProgress, Metadata trailers) { if (substream.bufferLimitExceeded) { commitAndRun(substream); if (state.winningSubstream == substream) { - masterListener.closed(status, rpcProgress, trailers); + listenerSerializeExecutor.execute( + new Runnable() { + @Override + public void run() { + isClosed = true; + masterListener.closed(status, rpcProgress, trailers); + } + }); } return; } @@ -784,8 +854,7 @@ public void closed(Status status, RpcProgress rpcProgress, Metadata trailers) { if (rpcProgress == RpcProgress.REFUSED && noMoreTransparentRetry.compareAndSet(false, true)) { // transparent retry - final Substream newSubstream = createSubstream( - substream.previousAttemptCount); + final Substream newSubstream = createSubstream(substream.previousAttemptCount, true); if (isHedging) { boolean commit = false; synchronized (lock) { @@ -853,23 +922,26 @@ public void run() { synchronized (lock) { scheduledRetry = scheduledRetryCopy = new FutureCanceller(lock); } - scheduledRetryCopy.setFuture( - scheduledExecutorService.schedule( + class RetryBackoffRunnable implements Runnable { + @Override + public void run() { + callExecutor.execute( new Runnable() { @Override public void run() { - callExecutor.execute( - new Runnable() { - @Override - public void run() { - // retry - Substream newSubstream = - createSubstream(substream.previousAttemptCount + 1); - drain(newSubstream); - } - }); + // retry + Substream newSubstream = createSubstream( + substream.previousAttemptCount + 1, + false); + drain(newSubstream); } - }, + }); + } + } + + scheduledRetryCopy.setFuture( + scheduledExecutorService.schedule( + new RetryBackoffRunnable(), retryPlan.backoffNanos, TimeUnit.NANOSECONDS)); return; @@ -880,7 +952,14 @@ public void run() { commitAndRun(substream); if (state.winningSubstream == substream) { - masterListener.closed(status, rpcProgress, trailers); + listenerSerializeExecutor.execute( + new Runnable() { + @Override + public void run() { + isClosed = true; + masterListener.closed(status, rpcProgress, trailers); + } + }); } } @@ -950,22 +1029,37 @@ private Integer getPushbackMills(Metadata trailer) { } @Override - public void messagesAvailable(MessageProducer producer) { + public void messagesAvailable(final MessageProducer producer) { State savedState = state; checkState( savedState.winningSubstream != null, "Headers should be received prior to messages."); if (savedState.winningSubstream != substream) { return; } - masterListener.messagesAvailable(producer); + listenerSerializeExecutor.execute( + new Runnable() { + @Override + public void run() { + masterListener.messagesAvailable(producer); + } + }); } @Override public void onReady() { // FIXME(#7089): hedging case is broken. - // TODO(zdapeng): optimization: if the substream is not drained yet, delay onReady() once - // drained and if is still ready. - masterListener.onReady(); + if (!isReady()) { + return; + } + listenerSerializeExecutor.execute( + new Runnable() { + @Override + public void run() { + if (!isClosed) { + masterListener.onReady(); + } + } + }); } } diff --git a/core/src/main/java/io/grpc/internal/ServerCallImpl.java b/core/src/main/java/io/grpc/internal/ServerCallImpl.java index 6f123e76678..f82d87cade0 100644 --- a/core/src/main/java/io/grpc/internal/ServerCallImpl.java +++ b/core/src/main/java/io/grpc/internal/ServerCallImpl.java @@ -279,7 +279,11 @@ public ServerStreamListenerImpl( new Context.CancellationListener() { @Override public void cancelled(Context context) { - ServerStreamListenerImpl.this.call.cancelled = true; + // If the context has a cancellation cause then something exceptional happened + // and we should also mark the call as cancelled. + if (context.cancellationCause() != null) { + ServerStreamListenerImpl.this.call.cancelled = true; + } } }, MoreExecutors.directExecutor()); @@ -355,6 +359,8 @@ private void closedInternal(Status status) { } finally { // Cancel context after delivering RPC closure notification to allow the application to // clean up and update any state based on whether onComplete or onCancel was called. + // Note that in failure situations JumpToApplicationThreadServerStreamListener has already + // closed the context. In these situations this cancel() call will be a no-op. context.cancel(null); } } diff --git a/core/src/main/java/io/grpc/internal/StatsTraceContext.java b/core/src/main/java/io/grpc/internal/StatsTraceContext.java index adb0b63ec8a..33e84e5a0b8 100644 --- a/core/src/main/java/io/grpc/internal/StatsTraceContext.java +++ b/core/src/main/java/io/grpc/internal/StatsTraceContext.java @@ -20,7 +20,6 @@ import com.google.common.annotations.VisibleForTesting; import io.grpc.Attributes; -import io.grpc.CallOptions; import io.grpc.ClientStreamTracer; import io.grpc.Context; import io.grpc.Metadata; @@ -48,21 +47,12 @@ public final class StatsTraceContext { * Factory method for the client-side. */ public static StatsTraceContext newClientContext( - final CallOptions callOptions, final Attributes transportAttrs, Metadata headers) { - List factories = callOptions.getStreamTracerFactories(); - if (factories.isEmpty()) { - return NOOP; + ClientStreamTracer[] tracers, Attributes transportAtts, Metadata headers) { + StatsTraceContext ctx = new StatsTraceContext(tracers); + for (ClientStreamTracer tracer : tracers) { + tracer.streamCreated(transportAtts, headers); } - ClientStreamTracer.StreamInfo info = - ClientStreamTracer.StreamInfo.newBuilder() - .setTransportAttrs(transportAttrs).setCallOptions(callOptions).build(); - // This array will be iterated multiple times per RPC. Use primitive array instead of Collection - // so that for-each doesn't create an Iterator every time. - StreamTracer[] tracers = new StreamTracer[factories.size()]; - for (int i = 0; i < tracers.length; i++) { - tracers[i] = factories.get(i).newClientStreamTracer(info, headers); - } - return new StatsTraceContext(tracers); + return ctx; } /** diff --git a/core/src/main/java/io/grpc/internal/SubchannelChannel.java b/core/src/main/java/io/grpc/internal/SubchannelChannel.java index 6c316e4f185..a1d454ed2fb 100644 --- a/core/src/main/java/io/grpc/internal/SubchannelChannel.java +++ b/core/src/main/java/io/grpc/internal/SubchannelChannel.java @@ -22,6 +22,7 @@ import io.grpc.CallOptions; import io.grpc.Channel; import io.grpc.ClientCall; +import io.grpc.ClientStreamTracer; import io.grpc.Context; import io.grpc.InternalConfigSelector; import io.grpc.Metadata; @@ -57,9 +58,11 @@ public ClientStream newStream(MethodDescriptor method, if (transport == null) { transport = notReadyTransport; } + ClientStreamTracer[] tracers = GrpcUtil.getClientStreamTracers( + callOptions, headers, 0, /* isTransparentRetry= */ false); Context origContext = context.attach(); try { - return transport.newStream(method, headers, callOptions); + return transport.newStream(method, headers, callOptions, tracers); } finally { context.detach(origContext); } diff --git a/core/src/main/java/io/grpc/util/AdvancedTlsX509KeyManager.java b/core/src/main/java/io/grpc/util/AdvancedTlsX509KeyManager.java new file mode 100644 index 00000000000..adaa1e6e69a --- /dev/null +++ b/core/src/main/java/io/grpc/util/AdvancedTlsX509KeyManager.java @@ -0,0 +1,234 @@ +/* + * Copyright 2021 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.util; + +import static com.google.common.base.Preconditions.checkNotNull; + +import io.grpc.ExperimentalApi; +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.net.Socket; +import java.security.NoSuchAlgorithmException; +import java.security.Principal; +import java.security.PrivateKey; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; +import java.security.spec.InvalidKeySpecException; +import java.util.Arrays; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.TimeUnit; +import java.util.logging.Level; +import java.util.logging.Logger; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.X509ExtendedKeyManager; + +/** + * AdvancedTlsX509KeyManager is an {@code X509ExtendedKeyManager} that allows users to configure + * advanced TLS features, such as private key and certificate chain reloading, etc. + */ +@ExperimentalApi("https://github.com/grpc/grpc-java/issues/8024") +public final class AdvancedTlsX509KeyManager extends X509ExtendedKeyManager { + private static final Logger log = Logger.getLogger(AdvancedTlsX509KeyManager.class.getName()); + + // The credential information sent to peers to prove our identity. + private volatile KeyInfo keyInfo; + + /** + * Constructs an AdvancedTlsX509KeyManager. + */ + public AdvancedTlsX509KeyManager() throws CertificateException { } + + @Override + public PrivateKey getPrivateKey(String alias) { + if (alias.equals("default")) { + return this.keyInfo.key; + } + return null; + } + + @Override + public X509Certificate[] getCertificateChain(String alias) { + if (alias.equals("default")) { + return Arrays.copyOf(this.keyInfo.certs, this.keyInfo.certs.length); + } + return null; + } + + @Override + public String[] getClientAliases(String keyType, Principal[] issuers) { + return new String[] {"default"}; + } + + @Override + public String chooseClientAlias(String[] keyType, Principal[] issuers, Socket socket) { + return "default"; + } + + @Override + public String chooseEngineClientAlias(String[] keyType, Principal[] issuers, SSLEngine engine) { + return "default"; + } + + @Override + public String[] getServerAliases(String keyType, Principal[] issuers) { + return new String[] {"default"}; + } + + @Override + public String chooseServerAlias(String keyType, Principal[] issuers, Socket socket) { + return "default"; + } + + @Override + public String chooseEngineServerAlias(String keyType, Principal[] issuers, + SSLEngine engine) { + return "default"; + } + + /** + * Updates the current cached private key and cert chains. + * + * @param key the private key that is going to be used + * @param certs the certificate chain that is going to be used + */ + public void updateIdentityCredentials(PrivateKey key, X509Certificate[] certs) + throws CertificateException { + // TODO(ZhenLian): explore possibilities to do a crypto check here. + this.keyInfo = new KeyInfo(checkNotNull(key, "key"), checkNotNull(certs, "certs")); + } + + /** + * Schedules a {@code ScheduledExecutorService} to read private key and certificate chains from + * the local file paths periodically, and update the cached identity credentials if they are both + * updated. + * + * @param keyFile the file on disk holding the private key + * @param certFile the file on disk holding the certificate chain + * @param period the period between successive read-and-update executions + * @param unit the time unit of the initialDelay and period parameters + * @param executor the execute service we use to read and update the credentials + * @return an object that caller should close when the file refreshes are not needed + */ + public Closeable updateIdentityCredentialsFromFile(File keyFile, File certFile, + long period, TimeUnit unit, ScheduledExecutorService executor) { + final ScheduledFuture future = + executor.scheduleWithFixedDelay( + new LoadFilePathExecution(keyFile, certFile), 0, period, unit); + return new Closeable() { + @Override public void close() { + future.cancel(false); + } + }; + } + + private static class KeyInfo { + // The private key and the cert chain we will use to send to peers to prove our identity. + final PrivateKey key; + final X509Certificate[] certs; + + public KeyInfo(PrivateKey key, X509Certificate[] certs) { + this.key = key; + this.certs = certs; + } + } + + private class LoadFilePathExecution implements Runnable { + File keyFile; + File certFile; + long currentKeyTime; + long currentCertTime; + + public LoadFilePathExecution(File keyFile, File certFile) { + this.keyFile = keyFile; + this.certFile = certFile; + this.currentKeyTime = 0; + this.currentCertTime = 0; + } + + @Override + public void run() { + try { + UpdateResult newResult = readAndUpdate(this.keyFile, this.certFile, this.currentKeyTime, + this.currentCertTime); + if (newResult.success) { + this.currentKeyTime = newResult.keyTime; + this.currentCertTime = newResult.certTime; + } + } catch (CertificateException | IOException | NoSuchAlgorithmException + | InvalidKeySpecException e) { + log.log(Level.SEVERE, "Failed refreshing private key and certificate chain from files. " + + "Using previous ones", e); + } + } + } + + private static class UpdateResult { + boolean success; + long keyTime; + long certTime; + + public UpdateResult(boolean success, long keyTime, long certTime) { + this.success = success; + this.keyTime = keyTime; + this.certTime = certTime; + } + } + + /** + * Reads the private key and certificates specified in the path locations. Updates {@code key} and + * {@code cert} if both of their modified time changed since last read. + * + * @param keyFile the file on disk holding the private key + * @param certFile the file on disk holding the certificate chain + * @param oldKeyTime the time when the private key file is modified during last execution + * @param oldCertTime the time when the certificate chain file is modified during last execution + * @return the result of this update execution + */ + private UpdateResult readAndUpdate(File keyFile, File certFile, long oldKeyTime, long oldCertTime) + throws IOException, CertificateException, NoSuchAlgorithmException, InvalidKeySpecException { + long newKeyTime = keyFile.lastModified(); + long newCertTime = certFile.lastModified(); + // We only update when both the key and the certs are updated. + if (newKeyTime != oldKeyTime && newCertTime != oldCertTime) { + FileInputStream keyInputStream = new FileInputStream(keyFile); + try { + PrivateKey key = CertificateUtils.getPrivateKey(keyInputStream); + FileInputStream certInputStream = new FileInputStream(certFile); + try { + X509Certificate[] certs = CertificateUtils.getX509Certificates(certInputStream); + updateIdentityCredentials(key, certs); + return new UpdateResult(true, newKeyTime, newCertTime); + } finally { + certInputStream.close(); + } + } finally { + keyInputStream.close(); + } + } + return new UpdateResult(false, oldKeyTime, oldCertTime); + } + + /** + * Mainly used to avoid throwing IO Exceptions in java.io.Closeable. + */ + public interface Closeable extends java.io.Closeable { + @Override public void close(); + } +} + diff --git a/core/src/main/java/io/grpc/util/AdvancedTlsX509TrustManager.java b/core/src/main/java/io/grpc/util/AdvancedTlsX509TrustManager.java new file mode 100644 index 00000000000..f6e366d3219 --- /dev/null +++ b/core/src/main/java/io/grpc/util/AdvancedTlsX509TrustManager.java @@ -0,0 +1,361 @@ +/* + * Copyright 2021 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.util; + +import io.grpc.ExperimentalApi; +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.net.Socket; +import java.security.KeyStore; +import java.security.KeyStoreException; +import java.security.NoSuchAlgorithmException; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.TimeUnit; +import java.util.logging.Level; +import java.util.logging.Logger; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLParameters; +import javax.net.ssl.SSLSocket; +import javax.net.ssl.TrustManager; +import javax.net.ssl.TrustManagerFactory; +import javax.net.ssl.X509ExtendedTrustManager; +import org.codehaus.mojo.animal_sniffer.IgnoreJRERequirement; + +/** + * AdvancedTlsX509TrustManager is an {@code X509ExtendedTrustManager} that allows users to configure + * advanced TLS features, such as root certificate reloading, peer cert custom verification, etc. + * For Android users: this class is only supported in API level 24 and above. + */ +@ExperimentalApi("https://github.com/grpc/grpc-java/issues/8024") +@IgnoreJRERequirement +public final class AdvancedTlsX509TrustManager extends X509ExtendedTrustManager { + private static final Logger log = Logger.getLogger(AdvancedTlsX509TrustManager.class.getName()); + + private final Verification verification; + private final SslSocketAndEnginePeerVerifier socketAndEnginePeerVerifier; + + // The delegated trust manager used to perform traditional certificate verification. + private volatile X509ExtendedTrustManager delegateManager = null; + + private AdvancedTlsX509TrustManager(Verification verification, + SslSocketAndEnginePeerVerifier socketAndEnginePeerVerifier) throws CertificateException { + this.verification = verification; + this.socketAndEnginePeerVerifier = socketAndEnginePeerVerifier; + } + + @Override + public void checkClientTrusted(X509Certificate[] chain, String authType) + throws CertificateException { + throw new CertificateException( + "Not enough information to validate peer. SSLEngine or Socket required."); + } + + @Override + public void checkClientTrusted(X509Certificate[] chain, String authType, Socket socket) + throws CertificateException { + checkTrusted(chain, authType, null, socket, false); + } + + @Override + public void checkClientTrusted(X509Certificate[] chain, String authType, SSLEngine engine) + throws CertificateException { + checkTrusted(chain, authType, engine, null, false); + } + + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType, SSLEngine engine) + throws CertificateException { + checkTrusted(chain, authType, engine, null, true); + } + + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType) + throws CertificateException { + throw new CertificateException( + "Not enough information to validate peer. SSLEngine or Socket required."); + } + + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType, Socket socket) + throws CertificateException { + checkTrusted(chain, authType, null, socket, true); + } + + @Override + public X509Certificate[] getAcceptedIssuers() { + if (this.delegateManager == null) { + return new X509Certificate[0]; + } + return this.delegateManager.getAcceptedIssuers(); + } + + /** + * Uses the default trust certificates stored on user's local system. + * After this is used, functions that will provide new credential + * data(e.g. updateTrustCredentials(), updateTrustCredentialsFromFile()) should not be called. + */ + public void useSystemDefaultTrustCerts() throws CertificateException, KeyStoreException, + NoSuchAlgorithmException { + // Passing a null value of KeyStore would make {@code TrustManagerFactory} attempt to use + // system-default trust CA certs. + this.delegateManager = createDelegateTrustManager(null); + } + + /** + * Updates the current cached trust certificates as well as the key store. + * + * @param trustCerts the trust certificates that are going to be used + */ + public void updateTrustCredentials(X509Certificate[] trustCerts) throws CertificateException, + KeyStoreException, NoSuchAlgorithmException, IOException { + KeyStore keyStore = KeyStore.getInstance(KeyStore.getDefaultType()); + keyStore.load(null, null); + int i = 1; + for (X509Certificate cert: trustCerts) { + String alias = Integer.toString(i); + keyStore.setCertificateEntry(alias, cert); + i++; + } + X509ExtendedTrustManager newDelegateManager = createDelegateTrustManager(keyStore); + this.delegateManager = newDelegateManager; + } + + private static X509ExtendedTrustManager createDelegateTrustManager(KeyStore keyStore) + throws CertificateException, KeyStoreException, NoSuchAlgorithmException { + TrustManagerFactory tmf = TrustManagerFactory.getInstance( + TrustManagerFactory.getDefaultAlgorithm()); + tmf.init(keyStore); + X509ExtendedTrustManager delegateManager = null; + TrustManager[] tms = tmf.getTrustManagers(); + // Iterate over the returned trust managers, looking for an instance of X509TrustManager. + // If found, use that as the delegate trust manager. + for (int j = 0; j < tms.length; j++) { + if (tms[j] instanceof X509ExtendedTrustManager) { + delegateManager = (X509ExtendedTrustManager) tms[j]; + break; + } + } + if (delegateManager == null) { + throw new CertificateException( + "Failed to find X509ExtendedTrustManager with default TrustManager algorithm " + + TrustManagerFactory.getDefaultAlgorithm()); + } + return delegateManager; + } + + private void checkTrusted(X509Certificate[] chain, String authType, SSLEngine sslEngine, + Socket socket, boolean checkingServer) throws CertificateException { + if (chain == null || chain.length == 0) { + throw new IllegalArgumentException( + "Want certificate verification but got null or empty certificates"); + } + if (sslEngine == null && socket == null) { + throw new CertificateException( + "Not enough information to validate peer. SSLEngine or Socket required."); + } + if (this.verification != Verification.INSECURELY_SKIP_ALL_VERIFICATION) { + X509ExtendedTrustManager currentDelegateManager = this.delegateManager; + if (currentDelegateManager == null) { + throw new CertificateException("No trust roots configured"); + } + if (checkingServer) { + String algorithm = this.verification == Verification.CERTIFICATE_AND_HOST_NAME_VERIFICATION + ? "HTTPS" : ""; + if (sslEngine != null) { + SSLParameters sslParams = sslEngine.getSSLParameters(); + sslParams.setEndpointIdentificationAlgorithm(algorithm); + sslEngine.setSSLParameters(sslParams); + currentDelegateManager.checkServerTrusted(chain, authType, sslEngine); + } else { + if (!(socket instanceof SSLSocket)) { + throw new CertificateException("socket is not a type of SSLSocket"); + } + SSLSocket sslSocket = (SSLSocket)socket; + SSLParameters sslParams = sslSocket.getSSLParameters(); + sslParams.setEndpointIdentificationAlgorithm(algorithm); + sslSocket.setSSLParameters(sslParams); + currentDelegateManager.checkServerTrusted(chain, authType, sslSocket); + } + } else { + currentDelegateManager.checkClientTrusted(chain, authType, sslEngine); + } + } + // Perform the additional peer cert check. + if (socketAndEnginePeerVerifier != null) { + if (sslEngine != null) { + socketAndEnginePeerVerifier.verifyPeerCertificate(chain, authType, sslEngine); + } else { + socketAndEnginePeerVerifier.verifyPeerCertificate(chain, authType, socket); + } + } + } + + /** + * Schedules a {@code ScheduledExecutorService} to read trust certificates from a local file path + * periodically, and update the cached trust certs if there is an update. + * + * @param trustCertFile the file on disk holding the trust certificates + * @param period the period between successive read-and-update executions + * @param unit the time unit of the initialDelay and period parameters + * @param executor the execute service we use to read and update the credentials + * @return an object that caller should close when the file refreshes are not needed + */ + public Closeable updateTrustCredentialsFromFile(File trustCertFile, long period, TimeUnit unit, + ScheduledExecutorService executor) { + final ScheduledFuture future = + executor.scheduleWithFixedDelay( + new LoadFilePathExecution(trustCertFile), 0, period, unit); + return new Closeable() { + @Override public void close() { + future.cancel(false); + } + }; + } + + private class LoadFilePathExecution implements Runnable { + File file; + long currentTime; + + public LoadFilePathExecution(File file) { + this.file = file; + this.currentTime = 0; + } + + @Override + public void run() { + try { + this.currentTime = readAndUpdate(this.file, this.currentTime); + } catch (CertificateException | IOException | KeyStoreException + | NoSuchAlgorithmException e) { + log.log(Level.SEVERE, "Failed refreshing trust CAs from file. Using previous CAs", e); + } + } + } + + /** + * Reads the trust certificates specified in the path location, and update the key store if the + * modified time has changed since last read. + * + * @param trustCertFile the file on disk holding the trust certificates + * @param oldTime the time when the trust file is modified during last execution + * @return oldTime if failed or the modified time is not changed, otherwise the new modified time + */ + private long readAndUpdate(File trustCertFile, long oldTime) + throws CertificateException, IOException, KeyStoreException, NoSuchAlgorithmException { + long newTime = trustCertFile.lastModified(); + if (newTime == oldTime) { + return oldTime; + } + FileInputStream inputStream = new FileInputStream(trustCertFile); + try { + X509Certificate[] certificates = CertificateUtils.getX509Certificates(inputStream); + updateTrustCredentials(certificates); + return newTime; + } finally { + inputStream.close(); + } + } + + // Mainly used to avoid throwing IO Exceptions in java.io.Closeable. + public interface Closeable extends java.io.Closeable { + @Override public void close(); + } + + public static Builder newBuilder() { + return new Builder(); + } + + // The verification mode when authenticating the peer certificate. + public enum Verification { + // This is the DEFAULT and RECOMMENDED mode for most applications. + // Setting this on the client side will do the certificate and hostname verification, while + // setting this on the server side will only do the certificate verification. + CERTIFICATE_AND_HOST_NAME_VERIFICATION, + // This SHOULD be chosen only when you know what the implication this will bring, and have a + // basic understanding about TLS. + // It SHOULD be accompanied with proper additional peer identity checks set through + // {@code PeerVerifier}(nit: why this @code not working?). Failing to do so will leave + // applications to MITM attack. + // Also note that this will only take effect if the underlying SDK implementation invokes + // checkClientTrusted/checkServerTrusted with the {@code SSLEngine} parameter while doing + // verification. + // Setting this on either side will only do the certificate verification. + CERTIFICATE_ONLY_VERIFICATION, + // Setting is very DANGEROUS. Please try to avoid this in a real production environment, unless + // you are a super advanced user intended to re-implement the whole verification logic on your + // own. A secure verification might include: + // 1. proper verification on the peer certificate chain + // 2. proper checks on the identity of the peer certificate + INSECURELY_SKIP_ALL_VERIFICATION, + } + + // Additional custom peer verification check. + // It will be used when checkClientTrusted/checkServerTrusted is called with the {@code Socket} or + // the {@code SSLEngine} parameter. + public interface SslSocketAndEnginePeerVerifier { + /** + * Verifies the peer certificate chain. For more information, please refer to + * {@code X509ExtendedTrustManager}. + * + * @param peerCertChain the certificate chain sent from the peer + * @param authType the key exchange algorithm used, e.g. "RSA", "DHE_DSS", etc + * @param socket the socket used for this connection. This parameter can be null, which + * indicates that implementations need not check the ssl parameters + */ + void verifyPeerCertificate(X509Certificate[] peerCertChain, String authType, Socket socket) + throws CertificateException; + + /** + * Verifies the peer certificate chain. For more information, please refer to + * {@code X509ExtendedTrustManager}. + * + * @param peerCertChain the certificate chain sent from the peer + * @param authType the key exchange algorithm used, e.g. "RSA", "DHE_DSS", etc + * @param engine the engine used for this connection. This parameter can be null, which + * indicates that implementations need not check the ssl parameters + */ + void verifyPeerCertificate(X509Certificate[] peerCertChain, String authType, SSLEngine engine) + throws CertificateException; + } + + public static final class Builder { + + private Verification verification = Verification.CERTIFICATE_AND_HOST_NAME_VERIFICATION; + private SslSocketAndEnginePeerVerifier socketAndEnginePeerVerifier; + + private Builder() {} + + public Builder setVerification(Verification verification) { + this.verification = verification; + return this; + } + + public Builder setSslSocketAndEnginePeerVerifier(SslSocketAndEnginePeerVerifier verifier) { + this.socketAndEnginePeerVerifier = verifier; + return this; + } + + public AdvancedTlsX509TrustManager build() throws CertificateException { + return new AdvancedTlsX509TrustManager(this.verification, this.socketAndEnginePeerVerifier); + } + } +} + diff --git a/core/src/main/java/io/grpc/util/CertificateUtils.java b/core/src/main/java/io/grpc/util/CertificateUtils.java index e8bbc90cb36..980862d3836 100644 --- a/core/src/main/java/io/grpc/util/CertificateUtils.java +++ b/core/src/main/java/io/grpc/util/CertificateUtils.java @@ -65,36 +65,24 @@ public static X509Certificate[] getX509Certificates(InputStream inputStream) public static PrivateKey getPrivateKey(InputStream inputStream) throws UnsupportedEncodingException, IOException, NoSuchAlgorithmException, InvalidKeySpecException { - InputStreamReader isr = null; - BufferedReader reader = null; - try { - isr = new InputStreamReader(inputStream, "UTF-8"); - reader = new BufferedReader(isr); - String line; - while ((line = reader.readLine()) != null) { - if ("-----BEGIN PRIVATE KEY-----".equals(line)) { - break; - } + BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, "UTF-8")); + String line; + while ((line = reader.readLine()) != null) { + if ("-----BEGIN PRIVATE KEY-----".equals(line)) { + break; } - StringBuilder keyContent = new StringBuilder(); - while ((line = reader.readLine()) != null) { - if ("-----END PRIVATE KEY-----".equals(line)) { - break; - } - keyContent.append(line); - } - byte[] decodedKeyBytes = BaseEncoding.base64().decode(keyContent.toString()); - KeyFactory keyFactory = KeyFactory.getInstance("RSA"); - PKCS8EncodedKeySpec keySpec = new PKCS8EncodedKeySpec(decodedKeyBytes); - return keyFactory.generatePrivate(keySpec); - } finally { - if (null != reader) { - reader.close(); - } - if (null != isr) { - isr.close(); + } + StringBuilder keyContent = new StringBuilder(); + while ((line = reader.readLine()) != null) { + if ("-----END PRIVATE KEY-----".equals(line)) { + break; } + keyContent.append(line); } + byte[] decodedKeyBytes = BaseEncoding.base64().decode(keyContent.toString()); + KeyFactory keyFactory = KeyFactory.getInstance("RSA"); + PKCS8EncodedKeySpec keySpec = new PKCS8EncodedKeySpec(decodedKeyBytes); + return keyFactory.generatePrivate(keySpec); } } diff --git a/core/src/main/java/io/grpc/util/ForwardingClientStreamTracer.java b/core/src/main/java/io/grpc/util/ForwardingClientStreamTracer.java index de7d12e397c..7bb9d8cf71a 100644 --- a/core/src/main/java/io/grpc/util/ForwardingClientStreamTracer.java +++ b/core/src/main/java/io/grpc/util/ForwardingClientStreamTracer.java @@ -17,6 +17,7 @@ package io.grpc.util; import com.google.common.base.MoreObjects; +import io.grpc.Attributes; import io.grpc.ClientStreamTracer; import io.grpc.ExperimentalApi; import io.grpc.Metadata; @@ -27,6 +28,11 @@ public abstract class ForwardingClientStreamTracer extends ClientStreamTracer { /** Returns the underlying {@code ClientStreamTracer}. */ protected abstract ClientStreamTracer delegate(); + @Override + public void streamCreated(Attributes transportAttrs, Metadata headers) { + delegate().streamCreated(transportAttrs, headers); + } + @Override public void outboundHeaders() { delegate().outboundHeaders(); diff --git a/core/src/test/java/io/grpc/ClientStreamTracerTest.java b/core/src/test/java/io/grpc/ClientStreamTracerTest.java index 2008a3de5c7..df450adc630 100644 --- a/core/src/test/java/io/grpc/ClientStreamTracerTest.java +++ b/core/src/test/java/io/grpc/ClientStreamTracerTest.java @@ -34,6 +34,7 @@ public class ClientStreamTracerTest { Attributes.newBuilder().set(TRANSPORT_ATTR_KEY, "value").build(); @Test + @SuppressWarnings("deprecation") // info.getTransportAttrs() public void streamInfo_empty() { StreamInfo info = StreamInfo.newBuilder().build(); assertThat(info.getCallOptions()).isSameInstanceAs(CallOptions.DEFAULT); @@ -41,6 +42,7 @@ public void streamInfo_empty() { } @Test + @SuppressWarnings("deprecation") // info.getTransportAttrs() public void streamInfo_withInfo() { StreamInfo info = StreamInfo.newBuilder() .setCallOptions(callOptions).setTransportAttrs(transportAttrs).build(); @@ -49,6 +51,7 @@ public void streamInfo_withInfo() { } @Test + @SuppressWarnings("deprecation") // info.setTransportAttrs() public void streamInfo_noEquality() { StreamInfo info1 = StreamInfo.newBuilder() .setCallOptions(callOptions).setTransportAttrs(transportAttrs).build(); @@ -60,6 +63,7 @@ public void streamInfo_noEquality() { } @Test + @SuppressWarnings("deprecation") // info.getTransportAttrs() public void streamInfo_toBuilder() { StreamInfo info1 = StreamInfo.newBuilder() .setCallOptions(callOptions).setTransportAttrs(transportAttrs).build(); diff --git a/core/src/test/java/io/grpc/internal/AbstractTransportTest.java b/core/src/test/java/io/grpc/internal/AbstractTransportTest.java index 091415efadc..cd522181311 100644 --- a/core/src/test/java/io/grpc/internal/AbstractTransportTest.java +++ b/core/src/test/java/io/grpc/internal/AbstractTransportTest.java @@ -48,7 +48,6 @@ import io.grpc.CallOptions; import io.grpc.ChannelLogger; import io.grpc.ClientStreamTracer; -import io.grpc.ClientStreamTracer.StreamInfo; import io.grpc.Grpc; import io.grpc.InternalChannelz.SocketStats; import io.grpc.InternalChannelz.TransportStats; @@ -172,7 +171,7 @@ public void log(ChannelLogLevel level, String messageFormat, Object... args) {} .setRequestMarshaller(StringMarshaller.INSTANCE) .setResponseMarshaller(StringMarshaller.INSTANCE) .build(); - private CallOptions callOptions; + private final CallOptions callOptions = CallOptions.DEFAULT; private Metadata.Key asciiKey = Metadata.Key.of( "ascii-key", Metadata.ASCII_STRING_MARSHALLER); @@ -186,24 +185,14 @@ public void log(ChannelLogLevel level, String messageFormat, Object... args) {} = mock(ManagedClientTransport.Listener.class); private MockServerListener serverListener = new MockServerListener(); private ArgumentCaptor throwableCaptor = ArgumentCaptor.forClass(Throwable.class); - private final TestClientStreamTracer clientStreamTracer1 = new TestClientStreamTracer(); - private final TestClientStreamTracer clientStreamTracer2 = new TestClientStreamTracer(); - private final ClientStreamTracer.Factory clientStreamTracerFactory = mock( - ClientStreamTracer.Factory.class, - delegatesTo(new ClientStreamTracer.Factory() { - final ArrayDeque tracers = - new ArrayDeque<>(Arrays.asList(clientStreamTracer1, clientStreamTracer2)); - - @Override - public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata metadata) { - metadata.put(tracerHeaderKey, tracerKeyValue); - TestClientStreamTracer tracer = tracers.poll(); - if (tracer != null) { - return tracer; - } - return new TestClientStreamTracer(); - } - })); + private final TestClientStreamTracer clientStreamTracer1 = new TestHeaderClientStreamTracer(); + private final TestClientStreamTracer clientStreamTracer2 = new TestHeaderClientStreamTracer(); + private final ClientStreamTracer[] tracers = new ClientStreamTracer[] { + clientStreamTracer1, clientStreamTracer2 + }; + private final ClientStreamTracer[] noopTracers = new ClientStreamTracer[] { + new ClientStreamTracer() {} + }; private final TestServerStreamTracer serverStreamTracer1 = new TestServerStreamTracer(); private final TestServerStreamTracer serverStreamTracer2 = new TestServerStreamTracer(); @@ -230,7 +219,6 @@ public ServerStreamTracer newServerStreamTracer(String fullMethodName, Metadata @Before public void setUp() { server = newServer(Arrays.asList(serverStreamTracerFactory)); - callOptions = CallOptions.DEFAULT.withStreamTracerFactory(clientStreamTracerFactory); } @After @@ -291,7 +279,8 @@ public void frameAfterRstStreamShouldNotBreakClientChannel() throws Exception { // after having sent a RST_STREAM to the server. Previously, this would have broken the // Netty channel. - ClientStream stream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream stream = client.newStream( + methodDescriptor, new Metadata(), callOptions, noopTracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); stream.start(clientStreamListener); StreamCreation serverStreamCreation @@ -314,7 +303,8 @@ public void frameAfterRstStreamShouldNotBreakClientChannel() throws Exception { // Test that the channel is still usable i.e. we can receive headers from the server on a // new stream. - stream = client.newStream(methodDescriptor, new Metadata(), callOptions); + stream = client.newStream( + methodDescriptor, new Metadata(), callOptions, noopTracers); stream.start(mockClientStreamListener2); serverStreamCreation = serverTransportListener.takeStreamOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); @@ -449,7 +439,8 @@ public void openStreamPreventsTermination() throws Exception { = serverListener.takeListenerOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); serverTransport = serverTransportListener.transport; - ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportInUse(true); @@ -501,7 +492,8 @@ public void shutdownNowKillsClientStream() throws Exception { = serverListener.takeListenerOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); serverTransport = serverTransportListener.transport; - ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportInUse(true); @@ -539,7 +531,8 @@ public void shutdownNowKillsServerStream() throws Exception { = serverListener.takeListenerOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); serverTransport = serverTransportListener.transport; - ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportInUse(true); @@ -594,7 +587,8 @@ public void ping_duringShutdown() throws Exception { client = newClientTransport(server); startTransport(client, mockClientTransportListener); // Stream prevents termination - ClientStream stream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream stream = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); stream.start(clientStreamListener); client.shutdown(Status.UNAVAILABLE); @@ -633,22 +627,19 @@ public void ping_afterTermination() throws Exception { @Test public void newStream_duringShutdown() throws Exception { - InOrder inOrder = inOrder(clientStreamTracerFactory); server.start(serverListener); client = newClientTransport(server); startTransport(client, mockClientTransportListener); // Stream prevents termination - ClientStream stream = client.newStream(methodDescriptor, new Metadata(), callOptions); - inOrder.verify(clientStreamTracerFactory).newClientStreamTracer( - any(ClientStreamTracer.StreamInfo.class), any(Metadata.class)); + ClientStream stream = client.newStream( + methodDescriptor, new Metadata(), callOptions, noopTracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); stream.start(clientStreamListener); client.shutdown(Status.UNAVAILABLE); verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportShutdown(any(Status.class)); - ClientStream stream2 = client.newStream(methodDescriptor, new Metadata(), callOptions); - inOrder.verify(clientStreamTracerFactory).newClientStreamTracer( - any(ClientStreamTracer.StreamInfo.class), any(Metadata.class)); + ClientStream stream2 = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener2 = new ClientStreamListenerBase(); stream2.start(clientStreamListener2); Status clientStreamStatus2 = @@ -683,15 +674,14 @@ public void newStream_afterTermination() throws Exception { client.shutdown(shutdownReason); verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportTerminated(); Thread.sleep(100); - ClientStream stream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream stream = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); stream.start(clientStreamListener); assertEquals( shutdownReason, clientStreamListener.status.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); assertNotNull(clientStreamListener.trailers.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)); verify(mockClientTransportListener, never()).transportInUse(anyBoolean()); - verify(clientStreamTracerFactory).newClientStreamTracer( - any(ClientStreamTracer.StreamInfo.class), any(Metadata.class)); assertNull(clientStreamTracer1.getInboundTrailers()); assertSame(shutdownReason, clientStreamTracer1.getStatus()); // Assert no interactions @@ -708,7 +698,8 @@ public void transportInUse_balancerRpcsNotCounted() throws Exception { // CALL_OPTIONS_RPC_OWNED_BY_BALANCER in CallOptions. It won't be counted for in-use signal. ClientStream stream1 = client.newStream( methodDescriptor, new Metadata(), - callOptions.withOption(GrpcUtil.CALL_OPTIONS_RPC_OWNED_BY_BALANCER, Boolean.TRUE)); + callOptions.withOption(GrpcUtil.CALL_OPTIONS_RPC_OWNED_BY_BALANCER, Boolean.TRUE), + noopTracers); ClientStreamListenerBase clientStreamListener1 = new ClientStreamListenerBase(); stream1.start(clientStreamListener1); MockServerTransportListener serverTransportListener @@ -717,7 +708,8 @@ methodDescriptor, new Metadata(), = serverTransportListener.takeStreamOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); // stream2 is the normal RPC, and will be counted for in-use - ClientStream stream2 = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream stream2 = client.newStream( + methodDescriptor, new Metadata(), callOptions, noopTracers); ClientStreamListenerBase clientStreamListener2 = new ClientStreamListenerBase(); stream2.start(clientStreamListener2); verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportInUse(true); @@ -743,7 +735,8 @@ public void transportInUse_normalClose() throws Exception { server.start(serverListener); client = newClientTransport(server); startTransport(client, mockClientTransportListener); - ClientStream stream1 = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream stream1 = client.newStream( + methodDescriptor, new Metadata(), callOptions, noopTracers); ClientStreamListenerBase clientStreamListener1 = new ClientStreamListenerBase(); stream1.start(clientStreamListener1); verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportInUse(true); @@ -751,7 +744,8 @@ public void transportInUse_normalClose() throws Exception { = serverListener.takeListenerOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); StreamCreation serverStreamCreation1 = serverTransportListener.takeStreamOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); - ClientStream stream2 = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream stream2 = client.newStream( + methodDescriptor, new Metadata(), callOptions, noopTracers); ClientStreamListenerBase clientStreamListener2 = new ClientStreamListenerBase(); stream2.start(clientStreamListener2); StreamCreation serverStreamCreation2 @@ -773,11 +767,13 @@ public void transportInUse_clientCancel() throws Exception { server.start(serverListener); client = newClientTransport(server); startTransport(client, mockClientTransportListener); - ClientStream stream1 = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream stream1 = client.newStream( + methodDescriptor, new Metadata(), callOptions, noopTracers); ClientStreamListenerBase clientStreamListener1 = new ClientStreamListenerBase(); stream1.start(clientStreamListener1); verify(mockClientTransportListener, timeout(TIMEOUT_MS)).transportInUse(true); - ClientStream stream2 = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream stream2 = client.newStream( + methodDescriptor, new Metadata(), callOptions, noopTracers); ClientStreamListenerBase clientStreamListener2 = new ClientStreamListenerBase(); stream2.start(clientStreamListener2); @@ -792,7 +788,6 @@ public void transportInUse_clientCancel() throws Exception { @Test public void basicStream() throws Exception { - InOrder clientInOrder = inOrder(clientStreamTracerFactory); InOrder serverInOrder = inOrder(serverStreamTracerFactory); server.start(serverListener); client = newClientTransport(server); @@ -816,14 +811,10 @@ public void basicStream() throws Exception { Metadata clientHeadersCopy = new Metadata(); clientHeadersCopy.merge(clientHeaders); - ClientStream clientStream = client.newStream(methodDescriptor, clientHeaders, callOptions); - ArgumentCaptor streamInfoCaptor = ArgumentCaptor.forClass(null); - clientInOrder.verify(clientStreamTracerFactory).newClientStreamTracer( - streamInfoCaptor.capture(), same(clientHeaders)); - ClientStreamTracer.StreamInfo streamInfo = streamInfoCaptor.getValue(); - assertThat(streamInfo.getTransportAttrs()).isSameInstanceAs( - ((ConnectionClientTransport) client).getAttributes()); - assertThat(streamInfo.getCallOptions()).isSameInstanceAs(callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, clientHeaders, callOptions, tracers); + assertThat(((TestHeaderClientStreamTracer) clientStreamTracer1).transportAttrs) + .isSameInstanceAs(((ConnectionClientTransport) client).getAttributes()); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); @@ -974,7 +965,8 @@ public void authorityPropagation() throws Exception { = serverListener.takeListenerOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); Metadata clientHeaders = new Metadata(); - ClientStream clientStream = client.newStream(methodDescriptor, clientHeaders, callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, clientHeaders, callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); StreamCreation serverStreamCreation @@ -1005,7 +997,8 @@ public void zeroMessageStream() throws Exception { = serverListener.takeListenerOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); serverTransport = serverTransportListener.transport; - ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); StreamCreation serverStreamCreation @@ -1044,7 +1037,8 @@ public void earlyServerClose_withServerHeaders() throws Exception { = serverListener.takeListenerOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); serverTransport = serverTransportListener.transport; - ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); StreamCreation serverStreamCreation @@ -1080,7 +1074,8 @@ public void earlyServerClose_noServerHeaders() throws Exception { = serverListener.takeListenerOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); serverTransport = serverTransportListener.transport; - ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); StreamCreation serverStreamCreation @@ -1122,7 +1117,8 @@ public void earlyServerClose_serverFailure() throws Exception { = serverListener.takeListenerOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); serverTransport = serverTransportListener.transport; - ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); StreamCreation serverStreamCreation @@ -1155,7 +1151,8 @@ public void earlyServerClose_serverFailure_withClientCancelOnListenerClosed() th serverTransport = serverTransportListener.transport; final ClientStream clientStream = - client.newStream(methodDescriptor, new Metadata(), callOptions); + client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase() { @Override @@ -1196,7 +1193,8 @@ public void clientCancel() throws Exception { = serverListener.takeListenerOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); serverTransport = serverTransportListener.transport; - ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); StreamCreation serverStreamCreation @@ -1230,7 +1228,8 @@ public void clientCancelFromWithinMessageRead() throws Exception { final SettableFuture closedCalled = SettableFuture.create(); final ClientStream clientStream = - client.newStream(methodDescriptor, new Metadata(), callOptions); + client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); final Status status = Status.CANCELLED.withDescription("nevermind"); clientStream.start(new ClientStreamListener() { private boolean messageReceived = false; @@ -1311,7 +1310,8 @@ public void serverCancel() throws Exception { = serverListener.takeListenerOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); serverTransport = serverTransportListener.transport; - ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); StreamCreation serverStreamCreation @@ -1331,8 +1331,6 @@ public void serverCancel() throws Exception { // Cause should not be transmitted between server and client assertNull(clientStreamStatus.getCause()); - verify(clientStreamTracerFactory).newClientStreamTracer( - any(ClientStreamTracer.StreamInfo.class), any(Metadata.class)); assertTrue(clientStreamTracer1.getOutboundHeaders()); assertNull(clientStreamTracer1.getInboundTrailers()); assertSame(clientStreamStatus, clientStreamTracer1.getStatus()); @@ -1353,7 +1351,8 @@ public void flowControlPushBack() throws Exception { serverListener.takeListenerOrFail(TIMEOUT_MS, TimeUnit.MILLISECONDS); serverTransport = serverTransportListener.transport; - ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); StreamCreation serverStreamCreation = @@ -1515,7 +1514,8 @@ public void interactionsAfterServerStreamCloseAreNoops() throws Exception { // boilerplate ClientStream clientStream = - client.newStream(methodDescriptor, new Metadata(), callOptions); + client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); StreamCreation server @@ -1547,7 +1547,8 @@ public void interactionsAfterClientStreamCancelAreNoops() throws Exception { // boilerplate ClientStream clientStream = - client.newStream(methodDescriptor, new Metadata(), callOptions); + client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListener clientListener = mock(ClientStreamListener.class); clientStream.start(clientListener); StreamCreation server @@ -1594,7 +1595,8 @@ public void transportTracer_streamStarted() throws Exception { assertEquals(0, clientBefore.streamsStarted); assertEquals(0, clientBefore.lastRemoteStreamCreatedTimeNanos); - ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); StreamCreation serverStreamCreation = serverTransportListener @@ -1624,7 +1626,8 @@ public void transportTracer_streamStarted() throws Exception { TransportStats clientBefore = getTransportStats(client); assertEquals(1, clientBefore.streamsStarted); - ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, new Metadata(), callOptions, noopTracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); StreamCreation serverStreamCreation = serverTransportListener @@ -1654,7 +1657,8 @@ public void transportTracer_server_streamEnded_ok() throws Exception { server.start(serverListener); client = newClientTransport(server); startTransport(client, mockClientTransportListener); - ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); MockServerTransportListener serverTransportListener @@ -1693,7 +1697,8 @@ public void transportTracer_server_streamEnded_nonOk() throws Exception { server.start(serverListener); client = newClientTransport(server); startTransport(client, mockClientTransportListener); - ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); MockServerTransportListener serverTransportListener @@ -1733,7 +1738,8 @@ public void transportTracer_client_streamEnded_nonOk() throws Exception { server.start(serverListener); client = newClientTransport(server); startTransport(client, mockClientTransportListener); - ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); MockServerTransportListener serverTransportListener = @@ -1768,7 +1774,8 @@ public void transportTracer_server_receive_msg() throws Exception { server.start(serverListener); client = newClientTransport(server); startTransport(client, mockClientTransportListener); - ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); MockServerTransportListener serverTransportListener @@ -1809,7 +1816,8 @@ public void transportTracer_server_send_msg() throws Exception { server.start(serverListener); client = newClientTransport(server); startTransport(client, mockClientTransportListener); - ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); MockServerTransportListener serverTransportListener @@ -1849,7 +1857,8 @@ public void socketStats() throws Exception { server.start(serverListener); ManagedClientTransport client = newClientTransport(server); startTransport(client, mockClientTransportListener); - ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); @@ -1896,8 +1905,8 @@ public void serverChecksInboundMetadataSize() throws Exception { Metadata.Key.of("foo-bin", Metadata.BINARY_BYTE_MARSHALLER), new byte[GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE]); - ClientStream clientStream = - client.newStream(methodDescriptor, tooLargeMetadata, callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, tooLargeMetadata, callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); @@ -1931,7 +1940,8 @@ public void clientChecksInboundMetadataSize_header() throws Exception { new byte[GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE]); ClientStream clientStream = - client.newStream(methodDescriptor, new Metadata(), callOptions); + client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); @@ -1975,7 +1985,8 @@ public void clientChecksInboundMetadataSize_trailer() throws Exception { new byte[GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE]); ClientStream clientStream = - client.newStream(methodDescriptor, new Metadata(), callOptions); + client.newStream( + methodDescriptor, new Metadata(), callOptions, tracers); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); @@ -2011,7 +2022,9 @@ private void doPingPong(MockServerListener serverListener) throws Exception { ManagedClientTransport client = newClientTransport(server); ManagedClientTransport.Listener listener = mock(ManagedClientTransport.Listener.class); startTransport(client, listener); - ClientStream clientStream = client.newStream(methodDescriptor, new Metadata(), callOptions); + ClientStream clientStream = client.newStream( + methodDescriptor, new Metadata(), callOptions, + new ClientStreamTracer[] { new ClientStreamTracer() {} }); ClientStreamListenerBase clientStreamListener = new ClientStreamListenerBase(); clientStream.start(clientStreamListener); @@ -2092,6 +2105,16 @@ private static void startTransport( verify(listener, timeout(TIMEOUT_MS)).transportReady(); } + private final class TestHeaderClientStreamTracer extends TestClientStreamTracer { + Attributes transportAttrs; + + @Override + public void streamCreated(Attributes transportAttrs, Metadata metadata) { + this.transportAttrs = transportAttrs; + metadata.put(tracerHeaderKey, tracerKeyValue); + } + } + private static class MockServerListener implements ServerListener { public final BlockingQueue listeners = new LinkedBlockingQueue<>(); diff --git a/core/src/test/java/io/grpc/internal/CallCredentials2ApplyingTest.java b/core/src/test/java/io/grpc/internal/CallCredentials2ApplyingTest.java index 7725c46726b..963a586319b 100644 --- a/core/src/test/java/io/grpc/internal/CallCredentials2ApplyingTest.java +++ b/core/src/test/java/io/grpc/internal/CallCredentials2ApplyingTest.java @@ -34,6 +34,7 @@ import io.grpc.CallCredentials.RequestInfo; import io.grpc.CallOptions; import io.grpc.ChannelLogger; +import io.grpc.ClientStreamTracer; import io.grpc.IntegerMarshaller; import io.grpc.Metadata; import io.grpc.MethodDescriptor; @@ -48,6 +49,7 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; +import org.mockito.ArgumentMatchers; import org.mockito.Mock; import org.mockito.invocation.InvocationOnMock; import org.mockito.junit.MockitoJUnit; @@ -103,6 +105,9 @@ public class CallCredentials2ApplyingTest { private static final Metadata.Key CREDS_KEY = Metadata.Key.of("test-creds", Metadata.ASCII_STRING_MARSHALLER); private static final String CREDS_VALUE = "some credentials"; + private static final ClientStreamTracer[] tracers = new ClientStreamTracer[] { + new ClientStreamTracer() {} + }; private final Metadata origHeaders = new Metadata(); private ForwardingConnectionClientTransport transport; @@ -118,7 +123,9 @@ public void setUp() { origHeaders.put(ORIG_HEADER_KEY, ORIG_HEADER_VALUE); when(mockTransportFactory.newClientTransport(address, clientTransportOptions, channelLogger)) .thenReturn(mockTransport); - when(mockTransport.newStream(same(method), any(Metadata.class), any(CallOptions.class))) + when(mockTransport.newStream( + same(method), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any())) .thenReturn(mockStream); ClientTransportFactory transportFactory = new CallCredentialsApplyingTransportFactory( mockTransportFactory, null, mockExecutor); @@ -134,7 +141,7 @@ public void parameterPropagation_base() { Attributes transportAttrs = Attributes.newBuilder().set(ATTR_KEY, ATTR_VALUE).build(); when(mockTransport.getAttributes()).thenReturn(transportAttrs); - transport.newStream(method, origHeaders, callOptions); + transport.newStream(method, origHeaders, callOptions, tracers); ArgumentCaptor infoCaptor = ArgumentCaptor.forClass(null); verify(mockCreds).applyRequestMetadata( @@ -155,7 +162,7 @@ public void parameterPropagation_transportSetSecurityLevel() { .build(); when(mockTransport.getAttributes()).thenReturn(transportAttrs); - transport.newStream(method, origHeaders, callOptions); + transport.newStream(method, origHeaders, callOptions, tracers); ArgumentCaptor infoCaptor = ArgumentCaptor.forClass(null); verify(mockCreds).applyRequestMetadata( @@ -176,8 +183,10 @@ public void parameterPropagation_callOptionsSetAuthority() { when(mockTransport.getAttributes()).thenReturn(transportAttrs); Executor anotherExecutor = mock(Executor.class); - transport.newStream(method, origHeaders, - callOptions.withAuthority("calloptions-authority").withExecutor(anotherExecutor)); + transport.newStream( + method, origHeaders, + callOptions.withAuthority("calloptions-authority").withExecutor(anotherExecutor), + tracers); ArgumentCaptor infoCaptor = ArgumentCaptor.forClass(null); verify(mockCreds).applyRequestMetadata( @@ -199,13 +208,15 @@ public void credentialThrows() { any(io.grpc.CallCredentials2.MetadataApplier.class)); FailingClientStream stream = - (FailingClientStream) transport.newStream(method, origHeaders, callOptions); + (FailingClientStream) transport.newStream(method, origHeaders, callOptions, tracers); - verify(mockTransport, never()).newStream(method, origHeaders, callOptions); + verify(mockTransport, never()).newStream( + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); assertEquals(Status.Code.UNAUTHENTICATED, stream.getError().getCode()); assertSame(ex, stream.getError().getCause()); transport.shutdown(Status.UNAVAILABLE); - assertTrue(transport.newStream(method, origHeaders, callOptions) + assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) instanceof FailingClientStream); verify(mockTransport).shutdown(Status.UNAVAILABLE); } @@ -226,14 +237,14 @@ public Void answer(InvocationOnMock invocation) throws Throwable { any(RequestInfo.class), same(mockExecutor), any(io.grpc.CallCredentials2.MetadataApplier.class)); - ClientStream stream = transport.newStream(method, origHeaders, callOptions); + ClientStream stream = transport.newStream(method, origHeaders, callOptions, tracers); - verify(mockTransport).newStream(method, origHeaders, callOptions); + verify(mockTransport).newStream(method, origHeaders, callOptions, tracers); assertSame(mockStream, stream); assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY)); assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY)); transport.shutdown(Status.UNAVAILABLE); - assertTrue(transport.newStream(method, origHeaders, callOptions) + assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) instanceof FailingClientStream); verify(mockTransport).shutdown(Status.UNAVAILABLE); } @@ -254,12 +265,14 @@ public Void answer(InvocationOnMock invocation) throws Throwable { any(io.grpc.CallCredentials2.MetadataApplier.class)); FailingClientStream stream = - (FailingClientStream) transport.newStream(method, origHeaders, callOptions); + (FailingClientStream) transport.newStream(method, origHeaders, callOptions, tracers); - verify(mockTransport, never()).newStream(method, origHeaders, callOptions); + verify(mockTransport, never()).newStream( + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); assertSame(error, stream.getError()); transport.shutdownNow(Status.UNAVAILABLE); - assertTrue(transport.newStream(method, origHeaders, callOptions) + assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) instanceof FailingClientStream); verify(mockTransport).shutdownNow(Status.UNAVAILABLE); } @@ -269,12 +282,15 @@ public void applyMetadata_delayed() { when(mockTransport.getAttributes()).thenReturn(Attributes.EMPTY); // Will call applyRequestMetadata(), which is no-op. - DelayedStream stream = (DelayedStream) transport.newStream(method, origHeaders, callOptions); + DelayedStream stream = (DelayedStream) transport.newStream( + method, origHeaders, callOptions, tracers); ArgumentCaptor applierCaptor = ArgumentCaptor.forClass(null); verify(mockCreds).applyRequestMetadata( any(RequestInfo.class), same(mockExecutor), applierCaptor.capture()); - verify(mockTransport, never()).newStream(method, origHeaders, callOptions); + verify(mockTransport, never()).newStream( + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); transport.shutdown(Status.UNAVAILABLE); verify(mockTransport, never()).shutdown(Status.UNAVAILABLE); @@ -283,11 +299,11 @@ public void applyMetadata_delayed() { headers.put(CREDS_KEY, CREDS_VALUE); applierCaptor.getValue().apply(headers); - verify(mockTransport).newStream(method, origHeaders, callOptions); + verify(mockTransport).newStream(method, origHeaders, callOptions, tracers); assertSame(mockStream, stream.getRealStream()); assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY)); assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY)); - assertTrue(transport.newStream(method, origHeaders, callOptions) + assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) instanceof FailingClientStream); verify(mockTransport).shutdown(Status.UNAVAILABLE); } @@ -297,7 +313,8 @@ public void fail_delayed() { when(mockTransport.getAttributes()).thenReturn(Attributes.EMPTY); // Will call applyRequestMetadata(), which is no-op. - DelayedStream stream = (DelayedStream) transport.newStream(method, origHeaders, callOptions); + DelayedStream stream = (DelayedStream) transport.newStream( + method, origHeaders, callOptions, tracers); ArgumentCaptor applierCaptor = ArgumentCaptor.forClass(null); verify(mockCreds).applyRequestMetadata( @@ -306,11 +323,13 @@ public void fail_delayed() { Status error = Status.FAILED_PRECONDITION.withDescription("channel not secure for creds"); applierCaptor.getValue().fail(error); - verify(mockTransport, never()).newStream(method, origHeaders, callOptions); + verify(mockTransport, never()).newStream( + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); FailingClientStream failingStream = (FailingClientStream) stream.getRealStream(); assertSame(error, failingStream.getError()); transport.shutdown(Status.UNAVAILABLE); - assertTrue(transport.newStream(method, origHeaders, callOptions) + assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) instanceof FailingClientStream); verify(mockTransport).shutdown(Status.UNAVAILABLE); } @@ -318,14 +337,14 @@ public void fail_delayed() { @Test public void noCreds() { callOptions = callOptions.withCallCredentials(null); - ClientStream stream = transport.newStream(method, origHeaders, callOptions); + ClientStream stream = transport.newStream(method, origHeaders, callOptions, tracers); - verify(mockTransport).newStream(method, origHeaders, callOptions); + verify(mockTransport).newStream(method, origHeaders, callOptions, tracers); assertSame(mockStream, stream); assertNull(origHeaders.get(CREDS_KEY)); assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY)); transport.shutdown(Status.UNAVAILABLE); - assertTrue(transport.newStream(method, origHeaders, callOptions) + assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) instanceof FailingClientStream); verify(mockTransport).shutdown(Status.UNAVAILABLE); } diff --git a/core/src/test/java/io/grpc/internal/CallCredentialsApplyingTest.java b/core/src/test/java/io/grpc/internal/CallCredentialsApplyingTest.java index 61a221f73de..ef49e66bf2d 100644 --- a/core/src/test/java/io/grpc/internal/CallCredentialsApplyingTest.java +++ b/core/src/test/java/io/grpc/internal/CallCredentialsApplyingTest.java @@ -35,6 +35,7 @@ import io.grpc.CallCredentials.RequestInfo; import io.grpc.CallOptions; import io.grpc.ChannelLogger; +import io.grpc.ClientStreamTracer; import io.grpc.IntegerMarshaller; import io.grpc.Metadata; import io.grpc.MethodDescriptor; @@ -49,6 +50,7 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; +import org.mockito.ArgumentMatchers; import org.mockito.Mock; import org.mockito.invocation.InvocationOnMock; import org.mockito.junit.MockitoJUnit; @@ -86,6 +88,9 @@ public class CallCredentialsApplyingTest { @Mock private ChannelLogger channelLogger; + private static final ClientStreamTracer[] tracers = new ClientStreamTracer[] { + new ClientStreamTracer() {} + }; private static final String AUTHORITY = "testauthority"; private static final String USER_AGENT = "testuseragent"; private static final Attributes.Key ATTR_KEY = Attributes.Key.create("somekey"); @@ -117,7 +122,9 @@ public void setUp() { origHeaders.put(ORIG_HEADER_KEY, ORIG_HEADER_VALUE); when(mockTransportFactory.newClientTransport(address, clientTransportOptions, channelLogger)) .thenReturn(mockTransport); - when(mockTransport.newStream(same(method), any(Metadata.class), any(CallOptions.class))) + when(mockTransport.newStream( + same(method), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any())) .thenReturn(mockStream); ClientTransportFactory transportFactory = new CallCredentialsApplyingTransportFactory( mockTransportFactory, null, mockExecutor); @@ -133,7 +140,7 @@ public void parameterPropagation_base() { Attributes transportAttrs = Attributes.newBuilder().set(ATTR_KEY, ATTR_VALUE).build(); when(mockTransport.getAttributes()).thenReturn(transportAttrs); - transport.newStream(method, origHeaders, callOptions); + transport.newStream(method, origHeaders, callOptions, tracers); ArgumentCaptor infoCaptor = ArgumentCaptor.forClass(null); verify(mockCreds).applyRequestMetadata(infoCaptor.capture(), same(mockExecutor), @@ -154,8 +161,10 @@ public void parameterPropagation_overrideByCallOptions() { when(mockTransport.getAttributes()).thenReturn(transportAttrs); Executor anotherExecutor = mock(Executor.class); - transport.newStream(method, origHeaders, - callOptions.withAuthority("calloptions-authority").withExecutor(anotherExecutor)); + transport.newStream( + method, origHeaders, + callOptions.withAuthority("calloptions-authority").withExecutor(anotherExecutor), + tracers); ArgumentCaptor infoCaptor = ArgumentCaptor.forClass(null); verify(mockCreds).applyRequestMetadata(infoCaptor.capture(), @@ -175,15 +184,17 @@ public void credentialThrows() { any(RequestInfo.class), same(mockExecutor), any(CallCredentials.MetadataApplier.class)); - FailingClientStream stream = - (FailingClientStream) transport.newStream(method, origHeaders, callOptions); + FailingClientStream stream = (FailingClientStream) transport.newStream( + method, origHeaders, callOptions, tracers); - verify(mockTransport, never()).newStream(method, origHeaders, callOptions); + verify(mockTransport, never()).newStream( + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); assertEquals(Status.Code.UNAUTHENTICATED, stream.getError().getCode()); assertSame(ex, stream.getError().getCause()); transport.shutdown(Status.UNAVAILABLE); - assertTrue(transport.newStream(method, origHeaders, callOptions) + assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) instanceof FailingClientStream); verify(mockTransport).shutdown(Status.UNAVAILABLE); } @@ -193,14 +204,15 @@ public void applyMetadata_inline() { when(mockTransport.getAttributes()).thenReturn(Attributes.EMPTY); callOptions = callOptions.withCallCredentials(new FakeCallCredentials(CREDS_KEY, CREDS_VALUE)); - ClientStream stream = transport.newStream(method, origHeaders, callOptions); + ClientStream stream = transport.newStream( + method, origHeaders, callOptions, tracers); - verify(mockTransport).newStream(method, origHeaders, callOptions); + verify(mockTransport).newStream(method, origHeaders, callOptions, tracers); assertSame(mockStream, stream); assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY)); assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY)); transport.shutdown(Status.UNAVAILABLE); - assertTrue(transport.newStream(method, origHeaders, callOptions) + assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) instanceof FailingClientStream); verify(mockTransport).shutdown(Status.UNAVAILABLE); } @@ -220,13 +232,15 @@ public Void answer(InvocationOnMock invocation) throws Throwable { }).when(mockCreds).applyRequestMetadata(any(RequestInfo.class), same(mockExecutor), any(CallCredentials.MetadataApplier.class)); - FailingClientStream stream = - (FailingClientStream) transport.newStream(method, origHeaders, callOptions); + FailingClientStream stream = (FailingClientStream) transport.newStream( + method, origHeaders, callOptions, tracers); - verify(mockTransport, never()).newStream(method, origHeaders, callOptions); + verify(mockTransport, never()).newStream( + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); assertSame(error, stream.getError()); transport.shutdownNow(Status.UNAVAILABLE); - assertTrue(transport.newStream(method, origHeaders, callOptions) + assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) instanceof FailingClientStream); verify(mockTransport).shutdownNow(Status.UNAVAILABLE); } @@ -236,23 +250,26 @@ public void applyMetadata_delayed() { when(mockTransport.getAttributes()).thenReturn(Attributes.EMPTY); // Will call applyRequestMetadata(), which is no-op. - DelayedStream stream = (DelayedStream) transport.newStream(method, origHeaders, callOptions); + DelayedStream stream = (DelayedStream) transport.newStream( + method, origHeaders, callOptions, tracers); ArgumentCaptor applierCaptor = ArgumentCaptor.forClass(null); verify(mockCreds).applyRequestMetadata(any(RequestInfo.class), same(mockExecutor), applierCaptor.capture()); - verify(mockTransport, never()).newStream(method, origHeaders, callOptions); + verify(mockTransport, never()).newStream( + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); transport.shutdown(Status.UNAVAILABLE); verify(mockTransport, never()).shutdown(Status.UNAVAILABLE); - assertTrue(transport.newStream(method, origHeaders, callOptions) + assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) instanceof FailingClientStream); Metadata headers = new Metadata(); headers.put(CREDS_KEY, CREDS_VALUE); applierCaptor.getValue().apply(headers); - verify(mockTransport).newStream(method, origHeaders, callOptions); + verify(mockTransport).newStream(method, origHeaders, callOptions, tracers); assertSame(mockStream, stream.getRealStream()); assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY)); assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY)); @@ -261,20 +278,20 @@ public void applyMetadata_delayed() { @Test public void delayedShutdown_shutdownShutdownNowThenApply() { - transport.newStream(method, origHeaders, callOptions); + transport.newStream(method, origHeaders, callOptions, tracers); ArgumentCaptor applierCaptor = ArgumentCaptor.forClass(null); verify(mockCreds).applyRequestMetadata(any(RequestInfo.class), same(mockExecutor), applierCaptor.capture()); transport.shutdown(Status.UNAVAILABLE); transport.shutdownNow(Status.ABORTED); - assertTrue(transport.newStream(method, origHeaders, callOptions) + assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) instanceof FailingClientStream); verify(mockTransport, never()).shutdown(any(Status.class)); verify(mockTransport, never()).shutdownNow(any(Status.class)); Metadata headers = new Metadata(); headers.put(CREDS_KEY, CREDS_VALUE); applierCaptor.getValue().apply(headers); - assertTrue(transport.newStream(method, origHeaders, callOptions) + assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) instanceof FailingClientStream); verify(mockTransport).shutdown(Status.UNAVAILABLE); verify(mockTransport).shutdownNow(Status.ABORTED); @@ -282,12 +299,12 @@ public void delayedShutdown_shutdownShutdownNowThenApply() { @Test public void delayedShutdown_shutdownThenApplyThenShutdownNow() { - transport.newStream(method, origHeaders, callOptions); + transport.newStream(method, origHeaders, callOptions, tracers); ArgumentCaptor applierCaptor = ArgumentCaptor.forClass(null); verify(mockCreds).applyRequestMetadata(any(RequestInfo.class), same(mockExecutor), applierCaptor.capture()); transport.shutdown(Status.UNAVAILABLE); - assertTrue(transport.newStream(method, origHeaders, callOptions) + assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) instanceof FailingClientStream); verify(mockTransport, never()).shutdown(any(Status.class)); Metadata headers = new Metadata(); @@ -308,25 +325,25 @@ public void delayedShutdown_shutdownMulti() { Metadata headers = new Metadata(); headers.put(CREDS_KEY, CREDS_VALUE); - transport.newStream(method, origHeaders, callOptions); - transport.newStream(method, origHeaders, callOptions); - transport.newStream(method, origHeaders, callOptions); + transport.newStream(method, origHeaders, callOptions, tracers); + transport.newStream(method, origHeaders, callOptions, tracers); + transport.newStream(method, origHeaders, callOptions, tracers); ArgumentCaptor applierCaptor = ArgumentCaptor.forClass(null); verify(mockCreds, times(3)).applyRequestMetadata(any(RequestInfo.class), same(mockExecutor), applierCaptor.capture()); applierCaptor.getAllValues().get(1).apply(headers); transport.shutdown(Status.UNAVAILABLE); - assertTrue(transport.newStream(method, origHeaders, callOptions) + assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) instanceof FailingClientStream); verify(mockTransport, never()).shutdown(Status.UNAVAILABLE); applierCaptor.getAllValues().get(0).apply(headers); - assertTrue(transport.newStream(method, origHeaders, callOptions) + assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) instanceof FailingClientStream); verify(mockTransport, never()).shutdown(Status.UNAVAILABLE); applierCaptor.getAllValues().get(2).apply(headers); - assertTrue(transport.newStream(method, origHeaders, callOptions) + assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) instanceof FailingClientStream); verify(mockTransport).shutdown(Status.UNAVAILABLE); } @@ -336,7 +353,8 @@ public void fail_delayed() { when(mockTransport.getAttributes()).thenReturn(Attributes.EMPTY); // Will call applyRequestMetadata(), which is no-op. - DelayedStream stream = (DelayedStream) transport.newStream(method, origHeaders, callOptions); + DelayedStream stream = (DelayedStream) transport.newStream( + method, origHeaders, callOptions, tracers); ArgumentCaptor applierCaptor = ArgumentCaptor.forClass(null); verify(mockCreds).applyRequestMetadata(any(RequestInfo.class), @@ -345,11 +363,13 @@ public void fail_delayed() { Status error = Status.FAILED_PRECONDITION.withDescription("channel not secure for creds"); applierCaptor.getValue().fail(error); - verify(mockTransport, never()).newStream(method, origHeaders, callOptions); + verify(mockTransport, never()).newStream( + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); FailingClientStream failingStream = (FailingClientStream) stream.getRealStream(); assertSame(error, failingStream.getError()); transport.shutdown(Status.UNAVAILABLE); - assertTrue(transport.newStream(method, origHeaders, callOptions) + assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) instanceof FailingClientStream); verify(mockTransport).shutdown(Status.UNAVAILABLE); } @@ -357,14 +377,15 @@ public void fail_delayed() { @Test public void noCreds() { callOptions = callOptions.withCallCredentials(null); - ClientStream stream = transport.newStream(method, origHeaders, callOptions); + ClientStream stream = transport.newStream( + method, origHeaders, callOptions, tracers); - verify(mockTransport).newStream(method, origHeaders, callOptions); + verify(mockTransport).newStream(method, origHeaders, callOptions, tracers); assertSame(mockStream, stream); assertNull(origHeaders.get(CREDS_KEY)); assertEquals(ORIG_HEADER_VALUE, origHeaders.get(ORIG_HEADER_KEY)); transport.shutdown(Status.UNAVAILABLE); - assertTrue(transport.newStream(method, origHeaders, callOptions) + assertTrue(transport.newStream(method, origHeaders, callOptions, tracers) instanceof FailingClientStream); verify(mockTransport).shutdown(Status.UNAVAILABLE); } @@ -373,7 +394,8 @@ public void noCreds() { public void justCallOptionCreds() { callOptions = callOptions.withCallCredentials(new FakeCallCredentials(CREDS_KEY, CREDS_VALUE)); - ClientStream stream = transport.newStream(method, origHeaders, callOptions); + ClientStream stream = transport.newStream( + method, origHeaders, callOptions, tracers); assertSame(mockStream, stream); assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY)); @@ -388,7 +410,8 @@ public void justChannelCreds() { transportFactory.newClientTransport(address, clientTransportOptions, channelLogger); callOptions = callOptions.withCallCredentials(null); - ClientStream stream = transport.newStream(method, origHeaders, callOptions); + ClientStream stream = transport.newStream( + method, origHeaders, callOptions, tracers); assertSame(mockStream, stream); assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY)); @@ -406,7 +429,8 @@ public void callOptionAndChanelCreds() { String creds2Value = "some more credentials"; callOptions = callOptions.withCallCredentials(new FakeCallCredentials(creds2Key, creds2Value)); - ClientStream stream = transport.newStream(method, origHeaders, callOptions); + ClientStream stream = transport.newStream( + method, origHeaders, callOptions, tracers); assertSame(mockStream, stream); assertEquals(CREDS_VALUE, origHeaders.get(CREDS_KEY)); diff --git a/core/src/test/java/io/grpc/internal/ClientCallImplTest.java b/core/src/test/java/io/grpc/internal/ClientCallImplTest.java index 1808a4bd478..0e5e5f50599 100644 --- a/core/src/test/java/io/grpc/internal/ClientCallImplTest.java +++ b/core/src/test/java/io/grpc/internal/ClientCallImplTest.java @@ -37,6 +37,7 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; import com.google.common.collect.ImmutableMap; @@ -47,6 +48,7 @@ import io.grpc.CallOptions; import io.grpc.ClientCall; import io.grpc.ClientStreamTracer; +import io.grpc.ClientStreamTracer.StreamInfo; import io.grpc.Codec; import io.grpc.Context; import io.grpc.Deadline; @@ -143,6 +145,8 @@ public void setUp() { any(Metadata.class), any(Context.class))) .thenReturn(stream); + when(streamTracerFactory.newClientStreamTracer(any(StreamInfo.class), any(Metadata.class))) + .thenReturn(new ClientStreamTracer() {}); doAnswer(new Answer() { @Override public Void answer(InvocationOnMock in) { @@ -156,7 +160,7 @@ public Void answer(InvocationOnMock in) { @After public void tearDown() { - verifyNoInteractions(streamTracerFactory); + verifyNoMoreInteractions(streamTracerFactory); } @Test @@ -763,6 +767,7 @@ public void deadlineExceededBeforeCallStarted() { channelCallTracer, configSelector) .setDecompressorRegistry(decompressorRegistry); call.start(callListener, new Metadata()); + verify(streamTracerFactory).newClientStreamTracer(any(StreamInfo.class), any(Metadata.class)); verify(clientStreamProvider, never()) .newStream( (MethodDescriptor) any(MethodDescriptor.class), diff --git a/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java b/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java index 9f48b8987d1..4cae565a19e 100644 --- a/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java +++ b/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java @@ -36,6 +36,7 @@ import static org.mockito.Mockito.when; import io.grpc.CallOptions; +import io.grpc.ClientStreamTracer; import io.grpc.IntegerMarshaller; import io.grpc.LoadBalancer.PickResult; import io.grpc.LoadBalancer.PickSubchannelArgs; @@ -57,6 +58,7 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; +import org.mockito.ArgumentMatchers; import org.mockito.Captor; import org.mockito.InOrder; import org.mockito.Mock; @@ -89,6 +91,9 @@ public class DelayedClientTransportTest { = CallOptions.Key.createWithDefault("shard-id", -1); private static final Status SHUTDOWN_STATUS = Status.UNAVAILABLE.withDescription("shutdown called"); + private static final ClientStreamTracer[] tracers = new ClientStreamTracer[] { + new ClientStreamTracer() {} + }; private final MethodDescriptor method = MethodDescriptor.newBuilder() @@ -122,9 +127,13 @@ public void uncaughtException(Thread t, Throwable e) { .thenReturn(PickResult.withSubchannel(mockSubchannel)); when(mockSubchannel.getInternalSubchannel()).thenReturn(mockInternalSubchannel); when(mockInternalSubchannel.obtainActiveTransport()).thenReturn(mockRealTransport); - when(mockRealTransport.newStream(same(method), same(headers), same(callOptions))) + when(mockRealTransport.newStream( + same(method), same(headers), same(callOptions), + ArgumentMatchers.any())) .thenReturn(mockRealStream); - when(mockRealTransport2.newStream(same(method2), same(headers2), same(callOptions2))) + when(mockRealTransport2.newStream( + same(method2), same(headers2), same(callOptions2), + ArgumentMatchers.any())) .thenReturn(mockRealStream2); delayedTransport.start(transportListener); } @@ -135,7 +144,8 @@ public void uncaughtException(Thread t, Throwable e) { @Test public void streamStartThenAssignTransport() { assertFalse(delayedTransport.hasPendingStreams()); - ClientStream stream = delayedTransport.newStream(method, headers, callOptions); + ClientStream stream = delayedTransport.newStream( + method, headers, callOptions, tracers); stream.start(streamListener); assertEquals(1, delayedTransport.getPendingStreamsCount()); assertTrue(delayedTransport.hasPendingStreams()); @@ -145,7 +155,9 @@ public void uncaughtException(Thread t, Throwable e) { assertEquals(0, delayedTransport.getPendingStreamsCount()); assertFalse(delayedTransport.hasPendingStreams()); assertEquals(1, fakeExecutor.runDueTasks()); - verify(mockRealTransport).newStream(same(method), same(headers), same(callOptions)); + verify(mockRealTransport).newStream( + same(method), same(headers), same(callOptions), + ArgumentMatchers.any()); verify(mockRealStream).start(listenerCaptor.capture()); verifyNoMoreInteractions(streamListener); listenerCaptor.getValue().onReady(); @@ -154,7 +166,7 @@ public void uncaughtException(Thread t, Throwable e) { } @Test public void newStreamThenAssignTransportThenShutdown() { - ClientStream stream = delayedTransport.newStream(method, headers, callOptions); + ClientStream stream = delayedTransport.newStream(method, headers, callOptions, tracers); assertEquals(1, delayedTransport.getPendingStreamsCount()); assertTrue(stream instanceof DelayedStream); delayedTransport.reprocess(mockPicker); @@ -163,7 +175,9 @@ public void uncaughtException(Thread t, Throwable e) { verify(transportListener).transportShutdown(same(SHUTDOWN_STATUS)); verify(transportListener).transportTerminated(); assertEquals(0, fakeExecutor.runDueTasks()); - verify(mockRealTransport).newStream(same(method), same(headers), same(callOptions)); + verify(mockRealTransport).newStream( + same(method), same(headers), same(callOptions), + ArgumentMatchers.any()); stream.start(streamListener); verify(mockRealStream).start(same(streamListener)); } @@ -181,11 +195,13 @@ public void uncaughtException(Thread t, Throwable e) { delayedTransport.shutdown(SHUTDOWN_STATUS); verify(transportListener).transportShutdown(same(SHUTDOWN_STATUS)); verify(transportListener).transportTerminated(); - ClientStream stream = delayedTransport.newStream(method, headers, callOptions); + ClientStream stream = delayedTransport.newStream( + method, headers, callOptions, tracers); assertEquals(0, delayedTransport.getPendingStreamsCount()); assertTrue(stream instanceof FailingClientStream); verify(mockRealTransport, never()).newStream( - any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class)); + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); } @Test public void assignTransportThenShutdownNowThenNewStream() { @@ -193,15 +209,18 @@ public void uncaughtException(Thread t, Throwable e) { delayedTransport.shutdownNow(Status.UNAVAILABLE); verify(transportListener).transportShutdown(any(Status.class)); verify(transportListener).transportTerminated(); - ClientStream stream = delayedTransport.newStream(method, headers, callOptions); + ClientStream stream = delayedTransport.newStream( + method, headers, callOptions, tracers); assertEquals(0, delayedTransport.getPendingStreamsCount()); assertTrue(stream instanceof FailingClientStream); verify(mockRealTransport, never()).newStream( - any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class)); + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); } @Test public void startThenCancelStreamWithoutSetTransport() { - ClientStream stream = delayedTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + ClientStream stream = delayedTransport.newStream( + method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(streamListener); assertEquals(1, delayedTransport.getPendingStreamsCount()); stream.cancel(Status.CANCELLED); @@ -213,7 +232,8 @@ public void uncaughtException(Thread t, Throwable e) { } @Test public void newStreamThenShutdownTransportThenAssignTransport() { - ClientStream stream = delayedTransport.newStream(method, headers, callOptions); + ClientStream stream = delayedTransport.newStream( + method, headers, callOptions, tracers); stream.start(streamListener); delayedTransport.shutdown(SHUTDOWN_STATUS); @@ -225,7 +245,8 @@ public void uncaughtException(Thread t, Throwable e) { // ... and will proceed if a real transport is available delayedTransport.reprocess(mockPicker); fakeExecutor.runDueTasks(); - verify(mockRealTransport).newStream(method, headers, callOptions); + verify(mockRealTransport).newStream( + method, headers, callOptions, tracers); verify(mockRealStream).start(any(ClientStreamListener.class)); // Since no more streams are pending, delayed transport is now terminated @@ -233,7 +254,8 @@ public void uncaughtException(Thread t, Throwable e) { verify(transportListener).transportTerminated(); // Further newStream() will return a failing stream - stream = delayedTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + stream = delayedTransport.newStream( + method, new Metadata(), CallOptions.DEFAULT, tracers); verify(streamListener, never()).closed( any(Status.class), any(RpcProgress.class), any(Metadata.class)); stream.start(streamListener); @@ -247,7 +269,8 @@ public void uncaughtException(Thread t, Throwable e) { } @Test public void newStreamThenShutdownTransportThenCancelStream() { - ClientStream stream = delayedTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + ClientStream stream = delayedTransport.newStream( + method, new Metadata(), CallOptions.DEFAULT, tracers); delayedTransport.shutdown(SHUTDOWN_STATUS); verify(transportListener).transportShutdown(same(SHUTDOWN_STATUS)); verify(transportListener, times(0)).transportTerminated(); @@ -264,7 +287,8 @@ public void uncaughtException(Thread t, Throwable e) { delayedTransport.shutdown(SHUTDOWN_STATUS); verify(transportListener).transportShutdown(same(SHUTDOWN_STATUS)); verify(transportListener).transportTerminated(); - ClientStream stream = delayedTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + ClientStream stream = delayedTransport.newStream( + method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(streamListener); verify(streamListener).closed( statusCaptor.capture(), any(RpcProgress.class), any(Metadata.class)); @@ -272,7 +296,8 @@ public void uncaughtException(Thread t, Throwable e) { } @Test public void startStreamThenShutdownNow() { - ClientStream stream = delayedTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + ClientStream stream = delayedTransport.newStream( + method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(streamListener); delayedTransport.shutdownNow(Status.UNAVAILABLE); verify(transportListener).transportShutdown(any(Status.class)); @@ -286,7 +311,8 @@ public void uncaughtException(Thread t, Throwable e) { delayedTransport.shutdownNow(Status.UNAVAILABLE); verify(transportListener).transportShutdown(any(Status.class)); verify(transportListener).transportTerminated(); - ClientStream stream = delayedTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + ClientStream stream = delayedTransport.newStream( + method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(streamListener); verify(streamListener).closed( statusCaptor.capture(), any(RpcProgress.class), any(Metadata.class)); @@ -301,55 +327,59 @@ public void uncaughtException(Thread t, Throwable e) { AbstractSubchannel subchannel1 = mock(AbstractSubchannel.class); AbstractSubchannel subchannel2 = mock(AbstractSubchannel.class); AbstractSubchannel subchannel3 = mock(AbstractSubchannel.class); - when(mockRealTransport.newStream(any(MethodDescriptor.class), any(Metadata.class), - any(CallOptions.class))).thenReturn(mockRealStream); - when(mockRealTransport2.newStream(any(MethodDescriptor.class), any(Metadata.class), - any(CallOptions.class))).thenReturn(mockRealStream2); + when(mockRealTransport.newStream( + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any())) + .thenReturn(mockRealStream); + when(mockRealTransport2.newStream( + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any())) + .thenReturn(mockRealStream2); when(subchannel1.getInternalSubchannel()).thenReturn(newTransportProvider(mockRealTransport)); when(subchannel2.getInternalSubchannel()).thenReturn(newTransportProvider(mockRealTransport2)); when(subchannel3.getInternalSubchannel()).thenReturn(newTransportProvider(null)); // Fail-fast streams DelayedStream ff1 = (DelayedStream) delayedTransport.newStream( - method, headers, failFastCallOptions); + method, headers, failFastCallOptions, tracers); ff1.start(mock(ClientStreamListener.class)); ff1.halfClose(); PickSubchannelArgsImpl ff1args = new PickSubchannelArgsImpl(method, headers, failFastCallOptions); verify(transportListener).transportInUse(true); DelayedStream ff2 = (DelayedStream) delayedTransport.newStream( - method2, headers2, failFastCallOptions); + method2, headers2, failFastCallOptions, tracers); PickSubchannelArgsImpl ff2args = new PickSubchannelArgsImpl(method2, headers2, failFastCallOptions); DelayedStream ff3 = (DelayedStream) delayedTransport.newStream( - method, headers, failFastCallOptions); + method, headers, failFastCallOptions, tracers); PickSubchannelArgsImpl ff3args = new PickSubchannelArgsImpl(method, headers, failFastCallOptions); DelayedStream ff4 = (DelayedStream) delayedTransport.newStream( - method2, headers2, failFastCallOptions); + method2, headers2, failFastCallOptions, tracers); PickSubchannelArgsImpl ff4args = new PickSubchannelArgsImpl(method2, headers2, failFastCallOptions); // Wait-for-ready streams FakeClock wfr3Executor = new FakeClock(); DelayedStream wfr1 = (DelayedStream) delayedTransport.newStream( - method, headers, waitForReadyCallOptions); + method, headers, waitForReadyCallOptions, tracers); PickSubchannelArgsImpl wfr1args = new PickSubchannelArgsImpl(method, headers, waitForReadyCallOptions); DelayedStream wfr2 = (DelayedStream) delayedTransport.newStream( - method2, headers2, waitForReadyCallOptions); + method2, headers2, waitForReadyCallOptions, tracers); PickSubchannelArgsImpl wfr2args = new PickSubchannelArgsImpl(method2, headers2, waitForReadyCallOptions); CallOptions wfr3callOptions = waitForReadyCallOptions.withExecutor( wfr3Executor.getScheduledExecutorService()); DelayedStream wfr3 = (DelayedStream) delayedTransport.newStream( - method, headers, wfr3callOptions); + method, headers, wfr3callOptions, tracers); wfr3.start(mock(ClientStreamListener.class)); wfr3.halfClose(); PickSubchannelArgsImpl wfr3args = new PickSubchannelArgsImpl(method, headers, wfr3callOptions); DelayedStream wfr4 = (DelayedStream) delayedTransport.newStream( - method2, headers2, waitForReadyCallOptions); + method2, headers2, waitForReadyCallOptions, tracers); PickSubchannelArgsImpl wfr4args = new PickSubchannelArgsImpl(method2, headers2, waitForReadyCallOptions); @@ -386,8 +416,10 @@ public void uncaughtException(Thread t, Throwable e) { // streams are now owned by a real transport (which should prevent the Channel from // terminating). // ff1 and wfr1 went through - verify(mockRealTransport).newStream(method, headers, failFastCallOptions); - verify(mockRealTransport2).newStream(method, headers, waitForReadyCallOptions); + verify(mockRealTransport).newStream( + method, headers, failFastCallOptions, tracers); + verify(mockRealTransport2).newStream( + method, headers, waitForReadyCallOptions, tracers); assertSame(mockRealStream, ff1.getRealStream()); assertSame(mockRealStream2, wfr1.getRealStream()); verify(mockRealStream).start(any(ClientStreamListener.class)); @@ -443,7 +475,7 @@ public void uncaughtException(Thread t, Throwable e) { // New streams will use the last picker DelayedStream wfr5 = (DelayedStream) delayedTransport.newStream( - method, headers, waitForReadyCallOptions); + method, headers, waitForReadyCallOptions, tracers); assertNull(wfr5.getRealStream()); inOrder.verify(picker).pickSubchannel( new PickSubchannelArgsImpl(method, headers, waitForReadyCallOptions)); @@ -474,14 +506,17 @@ public void reprocess_NoPendingStream() { when(subchannel.getInternalSubchannel()).thenReturn(mockInternalSubchannel); when(picker.pickSubchannel(any(PickSubchannelArgs.class))).thenReturn( PickResult.withSubchannel(subchannel)); - when(mockRealTransport.newStream(any(MethodDescriptor.class), any(Metadata.class), - any(CallOptions.class))).thenReturn(mockRealStream); + when(mockRealTransport.newStream( + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any())) + .thenReturn(mockRealStream); delayedTransport.reprocess(picker); verifyNoMoreInteractions(picker); verifyNoMoreInteractions(transportListener); // Though picker was not originally used, it will be saved and serve future streams. - ClientStream stream = delayedTransport.newStream(method, headers, CallOptions.DEFAULT); + ClientStream stream = delayedTransport.newStream( + method, headers, CallOptions.DEFAULT, tracers); verify(picker).pickSubchannel(new PickSubchannelArgsImpl(method, headers, CallOptions.DEFAULT)); verify(mockInternalSubchannel).obtainActiveTransport(); assertSame(mockRealStream, stream); @@ -519,7 +554,7 @@ public PickResult answer(InvocationOnMock invocation) throws Throwable { @Override public void run() { // Will call pickSubchannel and wait on barrier - delayedTransport.newStream(method, headers, callOptions); + delayedTransport.newStream(method, headers, callOptions, tracers); } }; sideThread.start(); @@ -552,7 +587,7 @@ public void run() { @Override public void run() { // Will call pickSubchannel and wait on barrier - delayedTransport.newStream(method, headers2, callOptions); + delayedTransport.newStream(method, headers2, callOptions, tracers); } }; sideThread2.start(); @@ -600,7 +635,8 @@ public void newStream_racesWithReprocessIdleMode() throws Exception { // Because there is no pending stream yet, it will do nothing but save the picker. delayedTransport.reprocess(picker); - ClientStream stream = delayedTransport.newStream(method, headers, callOptions); + ClientStream stream = delayedTransport.newStream( + method, headers, callOptions, tracers); stream.start(streamListener); assertTrue(delayedTransport.hasPendingStreams()); verify(transportListener).transportInUse(true); @@ -609,7 +645,7 @@ public void newStream_racesWithReprocessIdleMode() throws Exception { @Test public void pendingStream_appendTimeoutInsight_waitForReady() { ClientStream stream = delayedTransport.newStream( - method, headers, callOptions.withWaitForReady()); + method, headers, callOptions.withWaitForReady(), tracers); stream.start(streamListener); InsightBuilder insight = new InsightBuilder(); stream.appendTimeoutInsight(insight); diff --git a/core/src/test/java/io/grpc/internal/FailingClientStreamTest.java b/core/src/test/java/io/grpc/internal/FailingClientStreamTest.java index dad82902395..c07812577d5 100644 --- a/core/src/test/java/io/grpc/internal/FailingClientStreamTest.java +++ b/core/src/test/java/io/grpc/internal/FailingClientStreamTest.java @@ -21,6 +21,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; +import io.grpc.ClientStreamTracer; import io.grpc.Metadata; import io.grpc.Status; import io.grpc.internal.ClientStreamListener.RpcProgress; @@ -33,13 +34,16 @@ */ @RunWith(JUnit4.class) public class FailingClientStreamTest { + private static final ClientStreamTracer[] tracers = new ClientStreamTracer[] { + new ClientStreamTracer() {} + }; @Test public void processedRpcProgressPopulatedToListener() { ClientStreamListener listener = mock(ClientStreamListener.class); Status status = Status.UNAVAILABLE; - ClientStream stream = new FailingClientStream(status); + ClientStream stream = new FailingClientStream(status, RpcProgress.PROCESSED, tracers); stream.start(listener); verify(listener).closed(eq(status), eq(RpcProgress.PROCESSED), any(Metadata.class)); } @@ -49,7 +53,7 @@ public void droppedRpcProgressPopulatedToListener() { ClientStreamListener listener = mock(ClientStreamListener.class); Status status = Status.UNAVAILABLE; - ClientStream stream = new FailingClientStream(status, RpcProgress.DROPPED); + ClientStream stream = new FailingClientStream(status, RpcProgress.DROPPED, tracers); stream.start(listener); verify(listener).closed(eq(status), eq(RpcProgress.DROPPED), any(Metadata.class)); } diff --git a/core/src/test/java/io/grpc/internal/FailingClientTransportTest.java b/core/src/test/java/io/grpc/internal/FailingClientTransportTest.java index ff15ef7ff02..98749d74910 100644 --- a/core/src/test/java/io/grpc/internal/FailingClientTransportTest.java +++ b/core/src/test/java/io/grpc/internal/FailingClientTransportTest.java @@ -22,6 +22,7 @@ import static org.mockito.Mockito.verify; import io.grpc.CallOptions; +import io.grpc.ClientStreamTracer; import io.grpc.Metadata; import io.grpc.Status; import io.grpc.internal.ClientStreamListener.RpcProgress; @@ -41,8 +42,9 @@ public void newStreamStart() { Status error = Status.UNAVAILABLE; RpcProgress rpcProgress = RpcProgress.DROPPED; FailingClientTransport transport = new FailingClientTransport(error, rpcProgress); - ClientStream stream = transport - .newStream(TestMethodDescriptors.voidMethod(), new Metadata(), CallOptions.DEFAULT); + ClientStream stream = transport.newStream( + TestMethodDescriptors.voidMethod(), new Metadata(), CallOptions.DEFAULT, + new ClientStreamTracer[] { new ClientStreamTracer() {} }); ClientStreamListener listener = mock(ClientStreamListener.class); stream.start(listener); diff --git a/core/src/test/java/io/grpc/internal/ForwardingClientStreamTracerTest.java b/core/src/test/java/io/grpc/internal/ForwardingClientStreamTracerTest.java new file mode 100644 index 00000000000..5eb5b49fa19 --- /dev/null +++ b/core/src/test/java/io/grpc/internal/ForwardingClientStreamTracerTest.java @@ -0,0 +1,49 @@ +/* + * Copyright 2021 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +import static org.mockito.Mockito.mock; + +import io.grpc.ClientStreamTracer; +import io.grpc.ForwardingTestUtil; +import java.lang.reflect.Method; +import java.util.Collections; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link ForwardingClientStreamTracer}. */ +@RunWith(JUnit4.class) +public class ForwardingClientStreamTracerTest { + private final ClientStreamTracer mockDelegate = mock(ClientStreamTracer.class); + + @Test + public void allMethodsForwarded() throws Exception { + ForwardingTestUtil.testMethodsForwarded( + ClientStreamTracer.class, + mockDelegate, + new ForwardingClientStreamTracerTest.TestClientStreamTracer(), + Collections.emptyList()); + } + + private final class TestClientStreamTracer extends ForwardingClientStreamTracer { + @Override + protected ClientStreamTracer delegate() { + return mockDelegate; + } + } +} diff --git a/core/src/test/java/io/grpc/internal/GrpcUtilTest.java b/core/src/test/java/io/grpc/internal/GrpcUtilTest.java index 7a3808de6e3..6d2c21ddab8 100644 --- a/core/src/test/java/io/grpc/internal/GrpcUtilTest.java +++ b/core/src/test/java/io/grpc/internal/GrpcUtilTest.java @@ -16,6 +16,7 @@ package io.grpc.internal; +import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; @@ -27,13 +28,18 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; +import io.grpc.Attributes; import io.grpc.CallOptions; +import io.grpc.ClientStreamTracer; +import io.grpc.ClientStreamTracer.StreamInfo; import io.grpc.LoadBalancer.PickResult; import io.grpc.Metadata; import io.grpc.Status; import io.grpc.internal.ClientStreamListener.RpcProgress; import io.grpc.internal.GrpcUtil.Http2Error; import io.grpc.testing.TestMethodDescriptors; +import java.util.ArrayDeque; +import java.util.concurrent.atomic.AtomicReference; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; @@ -44,6 +50,10 @@ @RunWith(JUnit4.class) public class GrpcUtilTest { + private static final ClientStreamTracer[] tracers = new ClientStreamTracer[] { + new ClientStreamTracer() {} + }; + @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 @Rule public final ExpectedException thrown = ExpectedException.none(); @@ -244,8 +254,9 @@ public void getTransportFromPickResult_errorPickResult_failFast() { assertNotNull(transport); - ClientStream stream = transport - .newStream(TestMethodDescriptors.voidMethod(), new Metadata(), CallOptions.DEFAULT); + ClientStream stream = transport.newStream( + TestMethodDescriptors.voidMethod(), new Metadata(), CallOptions.DEFAULT, + tracers); ClientStreamListener listener = mock(ClientStreamListener.class); stream.start(listener); @@ -260,8 +271,9 @@ public void getTransportFromPickResult_dropPickResult_waitForReady() { assertNotNull(transport); - ClientStream stream = transport - .newStream(TestMethodDescriptors.voidMethod(), new Metadata(), CallOptions.DEFAULT); + ClientStream stream = transport.newStream( + TestMethodDescriptors.voidMethod(), new Metadata(), CallOptions.DEFAULT, + tracers); ClientStreamListener listener = mock(ClientStreamListener.class); stream.start(listener); @@ -276,11 +288,45 @@ public void getTransportFromPickResult_dropPickResult_failFast() { assertNotNull(transport); - ClientStream stream = transport - .newStream(TestMethodDescriptors.voidMethod(), new Metadata(), CallOptions.DEFAULT); + ClientStream stream = transport.newStream( + TestMethodDescriptors.voidMethod(), new Metadata(), CallOptions.DEFAULT, + tracers); ClientStreamListener listener = mock(ClientStreamListener.class); stream.start(listener); verify(listener).closed(eq(status), eq(RpcProgress.DROPPED), any(Metadata.class)); } + + @Test + public void clientStreamTracerFactoryBackwardCompatibility() { + final AtomicReference transportAttrsRef = new AtomicReference<>(); + final ClientStreamTracer mockTracer = mock(ClientStreamTracer.class); + final Metadata.Key key = Metadata.Key.of("fake-key", Metadata.ASCII_STRING_MARSHALLER); + final ArrayDeque tracers = new ArrayDeque<>(); + ClientStreamTracer.Factory oldFactoryImpl = new ClientStreamTracer.Factory() { + @SuppressWarnings("deprecation") + @Override + public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata headers) { + transportAttrsRef.set(info.getTransportAttrs()); + headers.put(key, "fake-value"); + tracers.offer(mockTracer); + return mockTracer; + } + }; + + StreamInfo info = + StreamInfo.newBuilder().setCallOptions(CallOptions.DEFAULT.withWaitForReady()).build(); + Metadata metadata = new Metadata(); + Attributes transAttrs = + Attributes.newBuilder().set(Attributes.Key.create("foo"), "bar").build(); + ClientStreamTracer tracer = GrpcUtil.newClientStreamTracer(oldFactoryImpl, info, metadata); + tracer.streamCreated(transAttrs, metadata); + assertThat(tracers.poll()).isSameInstanceAs(mockTracer); + assertThat(transportAttrsRef.get()).isEqualTo(transAttrs); + assertThat(metadata.get(key)).isEqualTo("fake-value"); + + tracer.streamClosed(Status.UNAVAILABLE); + // verify that newClientStreamTracer() is called no more than once + assertThat(tracers).isEmpty(); + } } diff --git a/core/src/test/java/io/grpc/internal/InUseStateAggregatorTest.java b/core/src/test/java/io/grpc/internal/InUseStateAggregatorTest.java new file mode 100644 index 00000000000..e1bbc063ea3 --- /dev/null +++ b/core/src/test/java/io/grpc/internal/InUseStateAggregatorTest.java @@ -0,0 +1,63 @@ +/* + * Copyright 2021 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +import static org.junit.Assert.assertTrue; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * Unit tests for {@link InUseStateAggregator}. + */ +@RunWith(JUnit4.class) +public class InUseStateAggregatorTest { + + private InUseStateAggregator aggregator; + + @Before + public void setUp() { + aggregator = new InUseStateAggregator() { + @Override + protected void handleInUse() { + } + + @Override + protected void handleNotInUse() { + } + }; + } + + @Test + public void anyObjectInUse() { + String objectOne = "1"; + String objectTwo = "2"; + String objectThree = "3"; + + aggregator.updateObjectInUse(objectOne, true); + assertTrue(aggregator.anyObjectInUse(objectOne)); + + aggregator.updateObjectInUse(objectTwo, true); + aggregator.updateObjectInUse(objectThree, true); + assertTrue(aggregator.anyObjectInUse(objectOne, objectTwo, objectThree)); + + aggregator.updateObjectInUse(objectTwo, false); + assertTrue(aggregator.anyObjectInUse(objectOne, objectThree)); + } +} diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplIdlenessTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplIdlenessTest.java index 6a76f75c8b7..30e137cba22 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelImplIdlenessTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplIdlenessTest.java @@ -26,7 +26,9 @@ import static org.mockito.AdditionalAnswers.delegatesTo; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isA; import static org.mockito.ArgumentMatchers.same; +import static org.mockito.Mockito.atMostOnce; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; @@ -284,6 +286,35 @@ public void delayedTransportHoldsOffIdleness() throws Exception { verify(mockLoadBalancer).shutdown(); } + @Test + public void pendingCallExitsIdleAfterEnter() throws Exception { + // Create a pending call without starting it. + channel.newCall(method, CallOptions.DEFAULT); + + channel.enterIdle(); + + // Just the existence of a non-started, pending call means the channel cannot stay + // in idle mode because the expectation is that the pending call will also need to + // be handled. + verify(mockNameResolver, times(2)).start(any(NameResolver.Listener2.class)); + } + + @Test + public void delayedTransportExitsIdleAfterEnter() throws Exception { + // Start a new call that will go to the delayed transport + ClientCall call = channel.newCall(method, CallOptions.DEFAULT); + call.start(mockCallListener, new Metadata()); + deliverResolutionResult(); + + channel.enterIdle(); + + // Since we have a call in delayed transport, the call to enterIdle() should have resulted in + // the channel going to idle mode and then immediately exiting. We confirm this by verifying + // that the name resolver was started up twice - once when the call was first created and a + // second time after exiting idle mode. + verify(mockNameResolver, times(2)).start(any(NameResolver.Listener2.class)); + } + @Test public void realTransportsHoldsOffIdleness() throws Exception { final EquivalentAddressGroup addressGroup = servers.get(1); @@ -332,6 +363,50 @@ public void realTransportsHoldsOffIdleness() throws Exception { verify(mockLoadBalancer).shutdown(); } + @Test + public void enterIdleWhileRealTransportInProgress() { + final EquivalentAddressGroup addressGroup = servers.get(1); + + // Start a call, which goes to delayed transport + ClientCall call = channel.newCall(method, CallOptions.DEFAULT); + call.start(mockCallListener, new Metadata()); + + // Verify that we have exited the idle mode + ArgumentCaptor helperCaptor = ArgumentCaptor.forClass(null); + verify(mockLoadBalancerProvider).newLoadBalancer(helperCaptor.capture()); + deliverResolutionResult(); + Helper helper = helperCaptor.getValue(); + + // Create a subchannel for the real transport to happen on. + Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY); + requestConnectionSafely(helper, subchannel); + MockClientTransportInfo t0 = newTransports.poll(); + t0.listener.transportReady(); + + SubchannelPicker mockPicker = mock(SubchannelPicker.class); + when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))) + .thenReturn(PickResult.withSubchannel(subchannel)); + updateBalancingStateSafely(helper, READY, mockPicker); + + // Delayed transport creates real streams in the app executor + executor.runDueTasks(); + + // Move transport to the in-use state + t0.listener.transportInUse(true); + + // Now we enter Idle mode while real transport is happening + channel.enterIdle(); + + // Verify that the name resolver and the load balance were shut down. + verify(mockNameResolver).shutdown(); + verify(mockLoadBalancer).shutdown(); + + // When there are no pending streams, the call to enterIdle() should stick and + // we remain in idle mode. We verify this by making sure that the name resolver + // was not started up more than once (the initial startup). + verify(mockNameResolver, atMostOnce()).start(isA(NameResolver.Listener2.class)); + } + @Test public void updateSubchannelAddresses_newAddressConnects() { ClientCall call = channel.newCall(method, CallOptions.DEFAULT); diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java index e5e017d756a..668411d7ecc 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java @@ -71,6 +71,7 @@ import io.grpc.ClientInterceptor; import io.grpc.ClientInterceptors; import io.grpc.ClientStreamTracer; +import io.grpc.ClientStreamTracer.StreamInfo; import io.grpc.CompositeChannelCredentials; import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; @@ -151,6 +152,7 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; +import org.mockito.ArgumentMatchers; import org.mockito.Captor; import org.mockito.InOrder; import org.mockito.Mock; @@ -225,6 +227,8 @@ public boolean shouldAccept(Runnable command) { private ArgumentCaptor statusCaptor; @Captor private ArgumentCaptor callOptionsCaptor; + @Captor + private ArgumentCaptor tracersCaptor; @Mock private LoadBalancer mockLoadBalancer; @Mock @@ -351,6 +355,7 @@ public void close() throws SecurityException { channelBuilder = new ManagedChannelImplBuilder(TARGET, new UnsupportedClientTransportFactoryBuilder(), new FixedPortProvider(DEFAULT_PORT)); + channelBuilder.disableRetry(); configureBuilder(channelBuilder); } @@ -525,7 +530,9 @@ channelBuilder, mockTransportFactory, new FakeBackoffPolicyProvider(), MockClientTransportInfo transportInfo = transports.poll(); ConnectionClientTransport mockTransport = transportInfo.transport; ManagedClientTransport.Listener transportListener = transportInfo.listener; - when(mockTransport.newStream(same(method), same(headers), any(CallOptions.class))) + when(mockTransport.newStream( + same(method), same(headers), any(CallOptions.class), + ArgumentMatchers.any())) .thenReturn(mockStream); transportListener.transportReady(); when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))) @@ -534,7 +541,9 @@ channelBuilder, mockTransportFactory, new FakeBackoffPolicyProvider(), executor.runDueTasks(); ArgumentCaptor callOptionsCaptor = ArgumentCaptor.forClass(null); - verify(mockTransport).newStream(same(method), same(headers), callOptionsCaptor.capture()); + verify(mockTransport).newStream( + same(method), same(headers), callOptionsCaptor.capture(), + ArgumentMatchers.any()); assertThat(callOptionsCaptor.getValue().isWaitForReady()).isTrue(); verify(mockStream).start(streamListenerCaptor.capture()); @@ -600,7 +609,9 @@ public ClientCall interceptCall( MockClientTransportInfo transportInfo = transports.poll(); ConnectionClientTransport mockTransport = transportInfo.transport; ManagedClientTransport.Listener transportListener = transportInfo.listener; - when(mockTransport.newStream(same(method), same(headers), any(CallOptions.class))) + when(mockTransport.newStream( + same(method), same(headers), any(CallOptions.class), + ArgumentMatchers.any())) .thenReturn(mockStream); transportListener.transportReady(); when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))) @@ -609,7 +620,9 @@ public ClientCall interceptCall( executor.runDueTasks(); ArgumentCaptor callOptionsCaptor = ArgumentCaptor.forClass(null); - verify(mockTransport).newStream(same(method), same(headers), callOptionsCaptor.capture()); + verify(mockTransport).newStream( + same(method), same(headers), callOptionsCaptor.capture(), + ArgumentMatchers.any()); assertThat(callOptionsCaptor.getValue().getOption(callOptionsKey)).isEqualTo("fooValue"); verify(mockStream).start(streamListenerCaptor.capture()); @@ -800,9 +813,13 @@ private void subtestCallsAndShutdown(boolean shutdownNow, boolean shutdownNowAft ConnectionClientTransport mockTransport = transportInfo.transport; verify(mockTransport).start(any(ManagedClientTransport.Listener.class)); ManagedClientTransport.Listener transportListener = transportInfo.listener; - when(mockTransport.newStream(same(method), same(headers), same(CallOptions.DEFAULT))) + when(mockTransport.newStream( + same(method), same(headers), same(CallOptions.DEFAULT), + ArgumentMatchers.any())) .thenReturn(mockStream); - when(mockTransport.newStream(same(method), same(headers2), same(CallOptions.DEFAULT))) + when(mockTransport.newStream( + same(method), same(headers2), same(CallOptions.DEFAULT), + ArgumentMatchers.any())) .thenReturn(mockStream2); transportListener.transportReady(); when(mockPicker.pickSubchannel( @@ -820,14 +837,19 @@ private void subtestCallsAndShutdown(boolean shutdownNow, boolean shutdownNowAft any(SocketAddress.class), any(ClientTransportOptions.class), any(ChannelLogger.class)); call.start(mockCallListener, headers); - verify(mockTransport, never()) - .newStream(same(method), same(headers), same(CallOptions.DEFAULT)); + verify(mockTransport, never()).newStream( + same(method), same(headers), same(CallOptions.DEFAULT), + ArgumentMatchers.any()); // Second RPC, will be assigned to the real transport ClientCall call2 = channel.newCall(method, CallOptions.DEFAULT); call2.start(mockCallListener2, headers2); - verify(mockTransport).newStream(same(method), same(headers2), same(CallOptions.DEFAULT)); - verify(mockTransport).newStream(same(method), same(headers2), same(CallOptions.DEFAULT)); + verify(mockTransport).newStream( + same(method), same(headers2), same(CallOptions.DEFAULT), + ArgumentMatchers.any()); + verify(mockTransport).newStream( + same(method), same(headers2), same(CallOptions.DEFAULT), + ArgumentMatchers.any()); verify(mockStream2).start(any(ClientStreamListener.class)); // Shutdown @@ -872,7 +894,9 @@ private void subtestCallsAndShutdown(boolean shutdownNow, boolean shutdownNowAft .thenReturn(PickResult.withSubchannel(subchannel)); updateBalancingStateSafely(helper, READY, picker2); executor.runDueTasks(); - verify(mockTransport).newStream(same(method), same(headers), same(CallOptions.DEFAULT)); + verify(mockTransport).newStream( + same(method), same(headers), same(CallOptions.DEFAULT), + ArgumentMatchers.any()); verify(mockStream).start(any(ClientStreamListener.class)); } @@ -1021,7 +1045,9 @@ public void callOptionsExecutor() { MockClientTransportInfo transportInfo = transports.poll(); ConnectionClientTransport mockTransport = transportInfo.transport; ManagedClientTransport.Listener transportListener = transportInfo.listener; - when(mockTransport.newStream(same(method), same(headers), any(CallOptions.class))) + when(mockTransport.newStream( + same(method), same(headers), any(CallOptions.class), + ArgumentMatchers.any())) .thenReturn(mockStream); transportListener.transportReady(); when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))) @@ -1031,7 +1057,8 @@ public void callOptionsExecutor() { // Real streams are started in the call executor if they were previously buffered. assertEquals(1, callExecutor.runDueTasks()); - verify(mockTransport).newStream(same(method), same(headers), same(options)); + verify(mockTransport).newStream( + same(method), same(headers), same(options), ArgumentMatchers.any()); verify(mockStream).start(streamListenerCaptor.capture()); // Call listener callbacks are also run in the call executor @@ -1298,7 +1325,8 @@ public void firstResolvedServerFailedToConnect() throws Exception { same(goodAddress), any(ClientTransportOptions.class), any(ChannelLogger.class)); MockClientTransportInfo goodTransportInfo = transports.poll(); when(goodTransportInfo.transport.newStream( - any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class))) + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any())) .thenReturn(mock(ClientStream.class)); goodTransportInfo.listener.transportReady(); @@ -1310,11 +1338,13 @@ public void firstResolvedServerFailedToConnect() throws Exception { // Delayed transport uses the app executor to create real streams. executor.runDueTasks(); - verify(goodTransportInfo.transport).newStream(same(method), same(headers), - same(CallOptions.DEFAULT)); + verify(goodTransportInfo.transport).newStream( + same(method), same(headers), same(CallOptions.DEFAULT), + ArgumentMatchers.any()); // The bad transport was never used. - verify(badTransportInfo.transport, times(0)).newStream(any(MethodDescriptor.class), - any(Metadata.class), any(CallOptions.class)); + verify(badTransportInfo.transport, times(0)).newStream( + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); } @Test @@ -1464,10 +1494,12 @@ public void allServersFailedToConnect() throws Exception { // ... while the wait-for-ready call stays verifyNoMoreInteractions(mockCallListener); // No real stream was ever created - verify(transportInfo1.transport, times(0)) - .newStream(any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class)); - verify(transportInfo2.transport, times(0)) - .newStream(any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class)); + verify(transportInfo1.transport, times(0)).newStream( + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); + verify(transportInfo2.transport, times(0)).newStream( + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); } @Test @@ -1763,8 +1795,9 @@ public void oobchannels() { assertEquals(0, balancerRpcExecutor.numPendingTasks()); transportInfo.listener.transportReady(); assertEquals(1, balancerRpcExecutor.runDueTasks()); - verify(transportInfo.transport).newStream(same(method), same(headers), - same(CallOptions.DEFAULT)); + verify(transportInfo.transport).newStream( + same(method), same(headers), same(CallOptions.DEFAULT), + ArgumentMatchers.any()); // The transport goes away transportInfo.listener.transportShutdown(Status.UNAVAILABLE); @@ -1849,6 +1882,7 @@ public void oobChannelHasNoChannelCallCredentials() { TARGET, InsecureChannelCredentials.create(), new FakeCallCredentials(metadataKey, channelCredValue), new UnsupportedClientTransportFactoryBuilder(), new FixedPortProvider(DEFAULT_PORT)); + channelBuilder.disableRetry(); configureBuilder(channelBuilder); createChannel(); @@ -1870,7 +1904,9 @@ public void oobChannelHasNoChannelCallCredentials() { ClientCall call = channel.newCall(method, callOptions); call.start(mockCallListener, headers); - verify(transportInfo.transport).newStream(same(method), same(headers), same(callOptions)); + verify(transportInfo.transport).newStream( + same(method), same(headers), same(callOptions), + ArgumentMatchers.any()); assertThat(headers.getAll(metadataKey)) .containsExactly(channelCredValue, callCredValue).inOrder(); @@ -1887,7 +1923,9 @@ public void oobChannelHasNoChannelCallCredentials() { transportInfo.listener.transportReady(); balancerRpcExecutor.runDueTasks(); - verify(transportInfo.transport).newStream(same(method), same(headers), same(callOptions)); + verify(transportInfo.transport).newStream( + same(method), same(headers), same(callOptions), + ArgumentMatchers.any()); assertThat(headers.getAll(metadataKey)).containsExactly(callCredValue); oob.shutdownNow(); @@ -1897,6 +1935,7 @@ public void oobChannelHasNoChannelCallCredentials() { new FakeNameResolverFactory.Builder(URI.create("oobauthority")).build()) .defaultLoadBalancingPolicy(MOCK_POLICY_NAME) .idleTimeout(ManagedChannelImplBuilder.IDLE_MODE_MAX_TIMEOUT_DAYS, TimeUnit.DAYS) + .disableRetry() // irrelevant to what we test, disable retry to make verification easy .build(); oob.getState(true); ArgumentCaptor helperCaptor = ArgumentCaptor.forClass(Helper.class); @@ -1919,7 +1958,9 @@ public void oobChannelHasNoChannelCallCredentials() { call.start(mockCallListener2, headers); // CallOptions may contain StreamTracerFactory for census that is added by default. - verify(transportInfo.transport).newStream(same(method), same(headers), any(CallOptions.class)); + verify(transportInfo.transport).newStream( + same(method), same(headers), any(CallOptions.class), + ArgumentMatchers.any()); assertThat(headers.getAll(metadataKey)).containsExactly(callCredValue); oob.shutdownNow(); } @@ -1942,6 +1983,7 @@ public SwapChannelCredentialsResult answer(InvocationOnMock invocation) { TARGET, InsecureChannelCredentials.create(), new FakeCallCredentials(metadataKey, channelCredValue), new UnsupportedClientTransportFactoryBuilder(), new FixedPortProvider(DEFAULT_PORT)); + channelBuilder.disableRetry(); configureBuilder(channelBuilder); createChannel(); @@ -1962,7 +2004,9 @@ public SwapChannelCredentialsResult answer(InvocationOnMock invocation) { ClientCall call = channel.newCall(method, callOptions); call.start(mockCallListener, headers); - verify(transportInfo.transport).newStream(same(method), same(headers), same(callOptions)); + verify(transportInfo.transport).newStream( + same(method), same(headers), same(callOptions), + ArgumentMatchers.any()); assertThat(headers.getAll(metadataKey)) .containsExactly(channelCredValue, callCredValue).inOrder(); @@ -1977,6 +2021,7 @@ public SwapChannelCredentialsResult answer(InvocationOnMock invocation) { new FakeNameResolverFactory.Builder(URI.create("fake://oobauthority/")).build()) .defaultLoadBalancingPolicy(MOCK_POLICY_NAME) .idleTimeout(ManagedChannelImplBuilder.IDLE_MODE_MAX_TIMEOUT_DAYS, TimeUnit.DAYS) + .disableRetry() // irrelevant to what we test, disable retry to make verification easy .build(); oob.getState(true); ArgumentCaptor helperCaptor = ArgumentCaptor.forClass(Helper.class); @@ -1998,7 +2043,9 @@ public SwapChannelCredentialsResult answer(InvocationOnMock invocation) { call.start(mockCallListener2, headers); // CallOptions may contain StreamTracerFactory for census that is added by default. - verify(transportInfo.transport).newStream(same(method), same(headers), any(CallOptions.class)); + verify(transportInfo.transport).newStream( + same(method), same(headers), any(CallOptions.class), + ArgumentMatchers.any()); assertThat(headers.getAll(metadataKey)) .containsExactly(oobChannelCredValue, callCredValue).inOrder(); oob.shutdownNow(); @@ -2097,7 +2144,9 @@ public void subchannelChannel_normalUsage() { ClientCall call = sChannel.newCall(method, callOptions); call.start(mockCallListener, headers); - verify(mockTransport).newStream(same(method), same(headers), callOptionsCaptor.capture()); + verify(mockTransport).newStream( + same(method), same(headers), callOptionsCaptor.capture(), + ArgumentMatchers.any()); CallOptions capturedCallOption = callOptionsCaptor.getValue(); assertThat(capturedCallOption.getDeadline()).isSameInstanceAs(callOptions.getDeadline()); @@ -2125,7 +2174,8 @@ public void subchannelChannel_failWhenNotReady() { ClientCall call = sChannel.newCall(method, CallOptions.DEFAULT); call.start(mockCallListener, headers); verify(mockTransport, never()).newStream( - any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class)); + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); verifyNoInteractions(mockCallListener); assertEquals(1, balancerRpcExecutor.runDueTasks()); @@ -2157,7 +2207,8 @@ public void subchannelChannel_failWaitForReady() { sChannel.newCall(method, CallOptions.DEFAULT.withWaitForReady()); call.start(mockCallListener, headers); verify(mockTransport, never()).newStream( - any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class)); + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); verifyNoInteractions(mockCallListener); assertEquals(1, balancerRpcExecutor.runDueTasks()); @@ -2332,7 +2383,8 @@ public ClientStream answer(InvocationOnMock in) throws Throwable { return mock(ClientStream.class); } }).when(transport).newStream( - any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class)); + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); verify(creds, never()).applyRequestMetadata( any(RequestInfo.class), any(Executor.class), any(CallCredentials.MetadataApplier.class)); @@ -2351,11 +2403,14 @@ public ClientStream answer(InvocationOnMock in) throws Throwable { assertEquals(AUTHORITY, infoCaptor.getValue().getAuthority()); assertEquals(SecurityLevel.NONE, infoCaptor.getValue().getSecurityLevel()); verify(transport, never()).newStream( - any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class)); + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); // newStream() is called after apply() is called applierCaptor.getValue().apply(new Metadata()); - verify(transport).newStream(same(method), any(Metadata.class), same(callOptions)); + verify(transport).newStream( + same(method), any(Metadata.class), same(callOptions), + ArgumentMatchers.any()); assertEquals("testValue", testKey.get(newStreamContexts.poll())); // The context should not live beyond the scope of newStream() and applyRequestMetadata() assertNull(testKey.get()); @@ -2374,11 +2429,14 @@ public ClientStream answer(InvocationOnMock in) throws Throwable { assertEquals(SecurityLevel.NONE, infoCaptor.getValue().getSecurityLevel()); // This is from the first call verify(transport).newStream( - any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class)); + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); // Still, newStream() is called after apply() is called applierCaptor.getValue().apply(new Metadata()); - verify(transport, times(2)).newStream(same(method), any(Metadata.class), same(callOptions)); + verify(transport, times(2)).newStream( + same(method), any(Metadata.class), same(callOptions), + ArgumentMatchers.any()); assertEquals("testValue", testKey.get(newStreamContexts.poll())); assertNull(testKey.get()); @@ -2387,8 +2445,20 @@ public ClientStream answer(InvocationOnMock in) throws Throwable { @Test public void pickerReturnsStreamTracer_noDelay() { ClientStream mockStream = mock(ClientStream.class); - ClientStreamTracer.Factory factory1 = mock(ClientStreamTracer.Factory.class); - ClientStreamTracer.Factory factory2 = mock(ClientStreamTracer.Factory.class); + final ClientStreamTracer tracer1 = new ClientStreamTracer() {}; + final ClientStreamTracer tracer2 = new ClientStreamTracer() {}; + ClientStreamTracer.Factory factory1 = new ClientStreamTracer.InternalLimitedInfoFactory() { + @Override + public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata headers) { + return tracer1; + } + }; + ClientStreamTracer.Factory factory2 = new ClientStreamTracer.InternalLimitedInfoFactory() { + @Override + public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata headers) { + return tracer2; + } + }; createChannel(); Subchannel subchannel = createSubchannelSafely(helper, addressGroup, Attributes.EMPTY, subchannelStateListener); @@ -2397,7 +2467,8 @@ public void pickerReturnsStreamTracer_noDelay() { transportInfo.listener.transportReady(); ClientTransport mockTransport = transportInfo.transport; when(mockTransport.newStream( - any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class))) + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any())) .thenReturn(mockStream); when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))).thenReturn( @@ -2409,20 +2480,29 @@ public void pickerReturnsStreamTracer_noDelay() { call.start(mockCallListener, new Metadata()); verify(mockPicker).pickSubchannel(any(PickSubchannelArgs.class)); - verify(mockTransport).newStream(same(method), any(Metadata.class), callOptionsCaptor.capture()); - assertEquals( - Arrays.asList(factory1, factory2), - callOptionsCaptor.getValue().getStreamTracerFactories()); - // The factories are safely not stubbed because we do not expect any usage of them. - verifyNoInteractions(factory1); - verifyNoInteractions(factory2); + verify(mockTransport).newStream( + same(method), any(Metadata.class), callOptionsCaptor.capture(), + tracersCaptor.capture()); + assertThat(tracersCaptor.getValue()).isEqualTo(new ClientStreamTracer[] {tracer1, tracer2}); } @Test public void pickerReturnsStreamTracer_delayed() { ClientStream mockStream = mock(ClientStream.class); - ClientStreamTracer.Factory factory1 = mock(ClientStreamTracer.Factory.class); - ClientStreamTracer.Factory factory2 = mock(ClientStreamTracer.Factory.class); + final ClientStreamTracer tracer1 = new ClientStreamTracer() {}; + final ClientStreamTracer tracer2 = new ClientStreamTracer() {}; + ClientStreamTracer.Factory factory1 = new ClientStreamTracer.InternalLimitedInfoFactory() { + @Override + public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata headers) { + return tracer1; + } + }; + ClientStreamTracer.Factory factory2 = new ClientStreamTracer.InternalLimitedInfoFactory() { + @Override + public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata headers) { + return tracer2; + } + }; createChannel(); CallOptions callOptions = CallOptions.DEFAULT.withStreamTracerFactory(factory1); @@ -2436,7 +2516,8 @@ public void pickerReturnsStreamTracer_delayed() { transportInfo.listener.transportReady(); ClientTransport mockTransport = transportInfo.transport; when(mockTransport.newStream( - any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class))) + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any())) .thenReturn(mockStream); when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))).thenReturn( PickResult.withSubchannel(subchannel, factory2)); @@ -2445,13 +2526,10 @@ public void pickerReturnsStreamTracer_delayed() { assertEquals(1, executor.runDueTasks()); verify(mockPicker).pickSubchannel(any(PickSubchannelArgs.class)); - verify(mockTransport).newStream(same(method), any(Metadata.class), callOptionsCaptor.capture()); - assertEquals( - Arrays.asList(factory1, factory2), - callOptionsCaptor.getValue().getStreamTracerFactories()); - // The factories are safely not stubbed because we do not expect any usage of them. - verifyNoInteractions(factory1); - verifyNoInteractions(factory2); + verify(mockTransport).newStream( + same(method), any(Metadata.class), callOptionsCaptor.capture(), + tracersCaptor.capture()); + assertThat(tracersCaptor.getValue()).isEqualTo(new ClientStreamTracer[] {tracer1, tracer2}); } @Test @@ -2818,7 +2896,9 @@ public void idleMode_resetsDelayedTransportPicker() { MockClientTransportInfo transportInfo = transports.poll(); ConnectionClientTransport mockTransport = transportInfo.transport; ManagedClientTransport.Listener transportListener = transportInfo.listener; - when(mockTransport.newStream(same(method), any(Metadata.class), any(CallOptions.class))) + when(mockTransport.newStream( + same(method), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any())) .thenReturn(mockStream); transportListener.transportReady(); @@ -2829,7 +2909,9 @@ public void idleMode_resetsDelayedTransportPicker() { executor.runDueTasks(); // Verify the buffered call was drained - verify(mockTransport).newStream(same(method), any(Metadata.class), any(CallOptions.class)); + verify(mockTransport).newStream( + same(method), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); verify(mockStream).start(any(ClientStreamListener.class)); } @@ -2888,7 +2970,9 @@ public void enterIdle_exitsIdleIfDelayedStreamPending() { MockClientTransportInfo transportInfo = transports.poll(); ConnectionClientTransport mockTransport = transportInfo.transport; ManagedClientTransport.Listener transportListener = transportInfo.listener; - when(mockTransport.newStream(same(method), any(Metadata.class), any(CallOptions.class))) + when(mockTransport.newStream( + same(method), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any())) .thenReturn(mockStream); transportListener.transportReady(); when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))) @@ -2898,7 +2982,9 @@ public void enterIdle_exitsIdleIfDelayedStreamPending() { // Verify the original call was drained executor.runDueTasks(); - verify(mockTransport).newStream(same(method), any(Metadata.class), any(CallOptions.class)); + verify(mockTransport).newStream( + same(method), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); verify(mockStream).start(any(ClientStreamListener.class)); } @@ -2920,7 +3006,9 @@ public void updateBalancingStateDoesUpdatePicker() { MockClientTransportInfo transportInfo = transports.poll(); ConnectionClientTransport mockTransport = transportInfo.transport; ManagedClientTransport.Listener transportListener = transportInfo.listener; - when(mockTransport.newStream(same(method), any(Metadata.class), any(CallOptions.class))) + when(mockTransport.newStream( + same(method), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any())) .thenReturn(mockStream); transportListener.transportReady(); @@ -2929,8 +3017,9 @@ public void updateBalancingStateDoesUpdatePicker() { updateBalancingStateSafely(helper, READY, mockPicker); executor.runDueTasks(); - verify(mockTransport, never()) - .newStream(any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class)); + verify(mockTransport, never()).newStream( + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); verify(mockStream, never()).start(any(ClientStreamListener.class)); @@ -2939,7 +3028,9 @@ public void updateBalancingStateDoesUpdatePicker() { updateBalancingStateSafely(helper, READY, mockPicker); executor.runDueTasks(); - verify(mockTransport).newStream(same(method), any(Metadata.class), any(CallOptions.class)); + verify(mockTransport).newStream( + same(method), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); verify(mockStream).start(any(ClientStreamListener.class)); } @@ -2958,7 +3049,9 @@ public void updateBalancingState_withWrappedSubchannel() { MockClientTransportInfo transportInfo = transports.poll(); ConnectionClientTransport mockTransport = transportInfo.transport; ManagedClientTransport.Listener transportListener = transportInfo.listener; - when(mockTransport.newStream(same(method), any(Metadata.class), any(CallOptions.class))) + when(mockTransport.newStream( + same(method), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any())) .thenReturn(mockStream); transportListener.transportReady(); @@ -2973,7 +3066,9 @@ protected Subchannel delegate() { updateBalancingStateSafely(helper, READY, mockPicker); executor.runDueTasks(); - verify(mockTransport).newStream(same(method), any(Metadata.class), any(CallOptions.class)); + verify(mockTransport).newStream( + same(method), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); verify(mockStream).start(any(ClientStreamListener.class)); } @@ -3405,7 +3500,8 @@ private void channelsAndSubchannels_instrumented0(boolean success) throws Except transportInfo.listener.transportReady(); ClientTransport mockTransport = transportInfo.transport; when(mockTransport.newStream( - any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class))) + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any())) .thenReturn(mockStream); when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))).thenReturn( PickResult.withSubchannel(subchannel, factory)); @@ -3478,7 +3574,9 @@ private void channelsAndSubchannels_oob_instrumented0(boolean success) throws Ex MockClientTransportInfo transportInfo = transports.poll(); ConnectionClientTransport mockTransport = transportInfo.transport; ManagedClientTransport.Listener transportListener = transportInfo.listener; - when(mockTransport.newStream(same(method), same(headers), any(CallOptions.class))) + when(mockTransport.newStream( + same(method), same(headers), any(CallOptions.class), + ArgumentMatchers.any())) .thenReturn(mockStream); // subchannel stat bumped when call gets assigned to it @@ -3650,7 +3748,9 @@ public double nextDouble() { ConnectionClientTransport mockTransport = transportInfo.transport; ClientStream mockStream = mock(ClientStream.class); ClientStream mockStream2 = mock(ClientStream.class); - when(mockTransport.newStream(same(method), any(Metadata.class), any(CallOptions.class))) + when(mockTransport.newStream( + same(method), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any())) .thenReturn(mockStream).thenReturn(mockStream2); transportInfo.listener.transportReady(); updateBalancingStateSafely(helper, READY, mockPicker); @@ -3754,7 +3854,9 @@ public void hedgingScheduledThenChannelShutdown_hedgeShouldStillHappen_newCallSh ConnectionClientTransport mockTransport = transportInfo.transport; ClientStream mockStream = mock(ClientStream.class); ClientStream mockStream2 = mock(ClientStream.class); - when(mockTransport.newStream(same(method), any(Metadata.class), any(CallOptions.class))) + when(mockTransport.newStream( + same(method), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any())) .thenReturn(mockStream).thenReturn(mockStream2); transportInfo.listener.transportReady(); updateBalancingStateSafely(helper, READY, mockPicker); diff --git a/core/src/test/java/io/grpc/internal/MessageDeframerTest.java b/core/src/test/java/io/grpc/internal/MessageDeframerTest.java index c1907e51703..5edf64ef85d 100644 --- a/core/src/test/java/io/grpc/internal/MessageDeframerTest.java +++ b/core/src/test/java/io/grpc/internal/MessageDeframerTest.java @@ -378,7 +378,7 @@ public void sizeEnforcingInputStream_readByteAboveLimit() throws IOException { try { thrown.expect(StatusRuntimeException.class); - thrown.expectMessage("RESOURCE_EXHAUSTED: Compressed gRPC message exceeds"); + thrown.expectMessage("RESOURCE_EXHAUSTED: Decompressed gRPC message exceeds"); while (stream.read() != -1) { } @@ -424,7 +424,7 @@ public void sizeEnforcingInputStream_readAboveLimit() throws IOException { try { thrown.expect(StatusRuntimeException.class); - thrown.expectMessage("RESOURCE_EXHAUSTED: Compressed gRPC message exceeds"); + thrown.expectMessage("RESOURCE_EXHAUSTED: Decompressed gRPC message exceeds"); stream.read(buf, 0, buf.length); } finally { @@ -467,7 +467,7 @@ public void sizeEnforcingInputStream_skipAboveLimit() throws IOException { try { thrown.expect(StatusRuntimeException.class); - thrown.expectMessage("RESOURCE_EXHAUSTED: Compressed gRPC message exceeds"); + thrown.expectMessage("RESOURCE_EXHAUSTED: Decompressed gRPC message exceeds"); stream.skip(4); } finally { diff --git a/core/src/test/java/io/grpc/internal/RetriableStreamTest.java b/core/src/test/java/io/grpc/internal/RetriableStreamTest.java index a83964b5e91..8b851573b21 100644 --- a/core/src/test/java/io/grpc/internal/RetriableStreamTest.java +++ b/core/src/test/java/io/grpc/internal/RetriableStreamTest.java @@ -163,9 +163,11 @@ void postCommit() { } @Override - ClientStream newSubstream(ClientStreamTracer.Factory tracerFactory, Metadata metadata) { + ClientStream newSubstream( + Metadata metadata, ClientStreamTracer.Factory tracerFactory, int previousAttempts, + boolean isTransparentRetry) { bufferSizeTracer = - tracerFactory.newClientStreamTracer(STREAM_INFO, new Metadata()); + tracerFactory.newClientStreamTracer(STREAM_INFO, metadata); int actualPreviousRpcAttemptsInHeader = metadata.get(GRPC_PREVIOUS_RPC_ATTEMPTS) == null ? 0 : Integer.valueOf(metadata.get(GRPC_PREVIOUS_RPC_ATTEMPTS)); return retriableStreamRecorder.newSubstream(actualPreviousRpcAttemptsInHeader); @@ -254,6 +256,7 @@ public Void answer(InvocationOnMock in) { ArgumentCaptor sublistenerCaptor1 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream1).start(sublistenerCaptor1.capture()); + inOrder.verify(mockStream1).isReady(); inOrder.verifyNoMoreInteractions(); retriableStream.sendMessage("msg1"); @@ -306,6 +309,7 @@ public Void answer(InvocationOnMock in) { inOrder.verify(mockStream2).writeMessage(any(InputStream.class)); inOrder.verify(mockStream2).request(456); inOrder.verify(mockStream2, times(2)).writeMessage(any(InputStream.class)); + inOrder.verify(mockStream2).isReady(); inOrder.verifyNoMoreInteractions(); // send more messages @@ -354,6 +358,7 @@ public Void answer(InvocationOnMock in) { inOrder.verify(mockStream3).writeMessage(any(InputStream.class)); inOrder.verify(mockStream3).request(456); inOrder.verify(mockStream3, times(7)).writeMessage(any(InputStream.class)); + inOrder.verify(mockStream3).isReady(); inOrder.verifyNoMoreInteractions(); InsightBuilder insight = new InsightBuilder(); @@ -635,6 +640,7 @@ public void retry_cancelWhileBackoff() { ArgumentCaptor sublistenerCaptor1 = ArgumentCaptor.forClass(ClientStreamListener.class); verify(mockStream1).start(sublistenerCaptor1.capture()); + verify(mockStream1).isReady(); // retry ClientStream mockStream2 = mock(ClientStream.class); @@ -654,7 +660,7 @@ public void retry_cancelWhileBackoff() { @Test public void operationsWhileDraining() { - ArgumentCaptor sublistenerCaptor1 = + final ArgumentCaptor sublistenerCaptor1 = ArgumentCaptor.forClass(ClientStreamListener.class); final AtomicReference sublistenerCaptor2 = new AtomicReference<>(); @@ -667,10 +673,16 @@ public void operationsWhileDraining() { @Override public void request(int numMessages) { retriableStream.sendMessage("substream1 request " + numMessages); + sublistenerCaptor1.getValue().onReady(); if (numMessages > 1) { retriableStream.request(--numMessages); } } + + @Override + public boolean isReady() { + return true; + } })); final ClientStream mockStream2 = @@ -686,7 +698,7 @@ public void start(ClientStreamListener listener) { @Override public void request(int numMessages) { retriableStream.sendMessage("substream2 request " + numMessages); - + sublistenerCaptor2.get().onReady(); if (numMessages == 3) { sublistenerCaptor2.get().headersRead(new Metadata()); } @@ -697,9 +709,14 @@ public void request(int numMessages) { retriableStream.cancel(cancelStatus); } } + + @Override + public boolean isReady() { + return true; + } })); - InOrder inOrder = inOrder(retriableStreamRecorder, mockStream1, mockStream2); + InOrder inOrder = inOrder(retriableStreamRecorder, mockStream1, mockStream2, masterListener); doReturn(mockStream1).when(retriableStreamRecorder).newSubstream(0); retriableStream.start(masterListener); @@ -714,6 +731,7 @@ public void request(int numMessages) { inOrder.verify(mockStream1).writeMessage(any(InputStream.class)); // msg "substream1 request 2" inOrder.verify(mockStream1).request(1); inOrder.verify(mockStream1).writeMessage(any(InputStream.class)); // msg "substream1 request 1" + inOrder.verify(masterListener).onReady(); // retry doReturn(mockStream2).when(retriableStreamRecorder).newSubstream(1); @@ -741,13 +759,98 @@ public void request(int numMessages) { // msg "substream2 request 2" inOrder.verify(mockStream2, times(2)).writeMessage(any(InputStream.class)); inOrder.verify(mockStream2).request(100); - - verify(mockStream2).cancel(cancelStatus); + inOrder.verify(mockStream2).cancel(cancelStatus); + inOrder.verify(masterListener, never()).onReady(); // "substream2 request 1" will never be sent inOrder.verify(mockStream2, never()).writeMessage(any(InputStream.class)); } + @Test + public void cancelWhileDraining() { + ArgumentCaptor sublistenerCaptor1 = + ArgumentCaptor.forClass(ClientStreamListener.class); + ClientStream mockStream1 = mock(ClientStream.class); + ClientStream mockStream2 = + mock( + ClientStream.class, + delegatesTo( + new NoopClientStream() { + @Override + public void request(int numMessages) { + retriableStream.cancel( + Status.CANCELLED.withDescription("cancelled while requesting")); + } + })); + + InOrder inOrder = inOrder(retriableStreamRecorder, mockStream1, mockStream2); + doReturn(mockStream1).when(retriableStreamRecorder).newSubstream(0); + retriableStream.start(masterListener); + inOrder.verify(mockStream1).start(sublistenerCaptor1.capture()); + retriableStream.request(3); + inOrder.verify(mockStream1).request(3); + + // retry + doReturn(mockStream2).when(retriableStreamRecorder).newSubstream(1); + sublistenerCaptor1.getValue().closed( + Status.fromCode(RETRIABLE_STATUS_CODE_1), PROCESSED, new Metadata()); + fakeClock.forwardTime((long) (INITIAL_BACKOFF_IN_SECONDS * FAKE_RANDOM), TimeUnit.SECONDS); + + inOrder.verify(mockStream2).start(any(ClientStreamListener.class)); + inOrder.verify(mockStream2).request(3); + inOrder.verify(retriableStreamRecorder).postCommit(); + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); + inOrder.verify(mockStream2).cancel(statusCaptor.capture()); + assertThat(statusCaptor.getValue().getCode()).isEqualTo(Code.CANCELLED); + assertThat(statusCaptor.getValue().getDescription()) + .isEqualTo("Stream thrown away because RetriableStream committed"); + verify(masterListener).closed( + statusCaptor.capture(), any(RpcProgress.class), any(Metadata.class)); + assertThat(statusCaptor.getValue().getCode()).isEqualTo(Code.CANCELLED); + assertThat(statusCaptor.getValue().getDescription()).isEqualTo("cancelled while requesting"); + } + + @Test + public void cancelWhileRetryStart() { + ArgumentCaptor sublistenerCaptor1 = + ArgumentCaptor.forClass(ClientStreamListener.class); + ClientStream mockStream1 = mock(ClientStream.class); + ClientStream mockStream2 = + mock( + ClientStream.class, + delegatesTo( + new NoopClientStream() { + @Override + public void start(ClientStreamListener listener) { + retriableStream.cancel( + Status.CANCELLED.withDescription("cancelled while retry start")); + } + })); + + InOrder inOrder = inOrder(retriableStreamRecorder, mockStream1, mockStream2); + doReturn(mockStream1).when(retriableStreamRecorder).newSubstream(0); + retriableStream.start(masterListener); + inOrder.verify(mockStream1).start(sublistenerCaptor1.capture()); + + // retry + doReturn(mockStream2).when(retriableStreamRecorder).newSubstream(1); + sublistenerCaptor1.getValue().closed( + Status.fromCode(RETRIABLE_STATUS_CODE_1), PROCESSED, new Metadata()); + fakeClock.forwardTime((long) (INITIAL_BACKOFF_IN_SECONDS * FAKE_RANDOM), TimeUnit.SECONDS); + + inOrder.verify(mockStream2).start(any(ClientStreamListener.class)); + inOrder.verify(retriableStreamRecorder).postCommit(); + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); + inOrder.verify(mockStream2).cancel(statusCaptor.capture()); + assertThat(statusCaptor.getValue().getCode()).isEqualTo(Code.CANCELLED); + assertThat(statusCaptor.getValue().getDescription()) + .isEqualTo("Stream thrown away because RetriableStream committed"); + verify(masterListener).closed( + statusCaptor.capture(), any(RpcProgress.class), any(Metadata.class)); + assertThat(statusCaptor.getValue().getCode()).isEqualTo(Code.CANCELLED); + assertThat(statusCaptor.getValue().getDescription()).isEqualTo("cancelled while retry start"); + } + @Test public void operationsAfterImmediateCommit() { ArgumentCaptor sublistenerCaptor1 = @@ -915,6 +1018,47 @@ public void start(ClientStreamListener listener) { verify(mockStream3).request(1); } + @Test + public void commitAndCancelWhileDraining() { + ClientStream mockStream1 = mock(ClientStream.class); + ClientStream mockStream2 = + mock( + ClientStream.class, + delegatesTo( + new NoopClientStream() { + @Override + public void start(ClientStreamListener listener) { + // commit while draining + listener.headersRead(new Metadata()); + // cancel while draining + retriableStream.cancel( + Status.CANCELLED.withDescription("cancelled while drained")); + } + })); + + when(retriableStreamRecorder.newSubstream(anyInt())) + .thenReturn(mockStream1, mockStream2); + + retriableStream.start(masterListener); + + ArgumentCaptor sublistenerCaptor1 = + ArgumentCaptor.forClass(ClientStreamListener.class); + verify(mockStream1).start(sublistenerCaptor1.capture()); + + ClientStreamListener listener1 = sublistenerCaptor1.getValue(); + + // retry + listener1.closed( + Status.fromCode(RETRIABLE_STATUS_CODE_1), PROCESSED, new Metadata()); + fakeClock.forwardTime((long) (INITIAL_BACKOFF_IN_SECONDS * FAKE_RANDOM), TimeUnit.SECONDS); + + verify(mockStream2).start(any(ClientStreamListener.class)); + verify(retriableStreamRecorder).postCommit(); + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); + verify(mockStream2).cancel(statusCaptor.capture()); + assertThat(statusCaptor.getValue().getCode()).isEqualTo(Code.CANCELLED); + assertThat(statusCaptor.getValue().getDescription()).isEqualTo("cancelled while drained"); + } @Test public void perRpcBufferLimitExceeded() { @@ -945,6 +1089,7 @@ public void perRpcBufferLimitExceededDuringBackoff() { ArgumentCaptor sublistenerCaptor1 = ArgumentCaptor.forClass(ClientStreamListener.class); verify(mockStream1).start(sublistenerCaptor1.capture()); + verify(mockStream1).isReady(); bufferSizeTracer.outboundWireSize(PER_RPC_BUFFER_LIMIT - 1); @@ -961,6 +1106,7 @@ public void perRpcBufferLimitExceededDuringBackoff() { fakeClock.forwardTime((long) (INITIAL_BACKOFF_IN_SECONDS * FAKE_RANDOM), TimeUnit.SECONDS); verify(mockStream2).start(any(ClientStreamListener.class)); + verify(mockStream2).isReady(); // bufferLimitExceeded bufferSizeTracer.outboundWireSize(PER_RPC_BUFFER_LIMIT - 1); @@ -1024,6 +1170,7 @@ public void expBackoff_maxBackoff_maxRetryAttempts() { ArgumentCaptor sublistenerCaptor1 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream1).start(sublistenerCaptor1.capture()); + inOrder.verify(mockStream1).isReady(); inOrder.verifyNoMoreInteractions(); @@ -1039,6 +1186,7 @@ public void expBackoff_maxBackoff_maxRetryAttempts() { ArgumentCaptor sublistenerCaptor2 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream2).start(sublistenerCaptor2.capture()); + inOrder.verify(mockStream2).isReady(); inOrder.verifyNoMoreInteractions(); // retry2 @@ -1055,6 +1203,7 @@ public void expBackoff_maxBackoff_maxRetryAttempts() { ArgumentCaptor sublistenerCaptor3 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream3).start(sublistenerCaptor3.capture()); + inOrder.verify(mockStream3).isReady(); inOrder.verifyNoMoreInteractions(); // retry3 @@ -1072,6 +1221,7 @@ public void expBackoff_maxBackoff_maxRetryAttempts() { ArgumentCaptor sublistenerCaptor4 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream4).start(sublistenerCaptor4.capture()); + inOrder.verify(mockStream4).isReady(); inOrder.verifyNoMoreInteractions(); // retry4 @@ -1086,6 +1236,7 @@ public void expBackoff_maxBackoff_maxRetryAttempts() { ArgumentCaptor sublistenerCaptor5 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream5).start(sublistenerCaptor5.capture()); + inOrder.verify(mockStream5).isReady(); inOrder.verifyNoMoreInteractions(); // retry5 @@ -1100,6 +1251,7 @@ public void expBackoff_maxBackoff_maxRetryAttempts() { ArgumentCaptor sublistenerCaptor6 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream6).start(sublistenerCaptor6.capture()); + inOrder.verify(mockStream6).isReady(); inOrder.verifyNoMoreInteractions(); // can not retry any more @@ -1130,6 +1282,7 @@ public void pushback() { ArgumentCaptor sublistenerCaptor1 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream1).start(sublistenerCaptor1.capture()); + inOrder.verify(mockStream1).isReady(); inOrder.verifyNoMoreInteractions(); @@ -1148,6 +1301,7 @@ public void pushback() { ArgumentCaptor sublistenerCaptor2 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream2).start(sublistenerCaptor2.capture()); + inOrder.verify(mockStream2).isReady(); inOrder.verifyNoMoreInteractions(); // retry2 @@ -1165,6 +1319,7 @@ public void pushback() { ArgumentCaptor sublistenerCaptor3 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream3).start(sublistenerCaptor3.capture()); + inOrder.verify(mockStream3).isReady(); inOrder.verifyNoMoreInteractions(); // retry3 @@ -1179,6 +1334,7 @@ public void pushback() { ArgumentCaptor sublistenerCaptor4 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream4).start(sublistenerCaptor4.capture()); + inOrder.verify(mockStream4).isReady(); inOrder.verifyNoMoreInteractions(); // retry4 @@ -1195,6 +1351,7 @@ public void pushback() { ArgumentCaptor sublistenerCaptor5 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream5).start(sublistenerCaptor5.capture()); + inOrder.verify(mockStream5).isReady(); inOrder.verifyNoMoreInteractions(); // retry5 @@ -1212,6 +1369,7 @@ public void pushback() { ArgumentCaptor sublistenerCaptor6 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream6).start(sublistenerCaptor6.capture()); + inOrder.verify(mockStream6).isReady(); inOrder.verifyNoMoreInteractions(); // can not retry any more even pushback is positive @@ -1469,6 +1627,7 @@ public void transparentRetry() { ArgumentCaptor sublistenerCaptor1 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream1).start(sublistenerCaptor1.capture()); + inOrder.verify(mockStream1).isReady(); inOrder.verifyNoMoreInteractions(); // transparent retry @@ -1480,6 +1639,7 @@ public void transparentRetry() { ArgumentCaptor sublistenerCaptor2 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream2).start(sublistenerCaptor2.capture()); + inOrder.verify(mockStream2).isReady(); inOrder.verifyNoMoreInteractions(); verify(retriableStreamRecorder, never()).postCommit(); assertEquals(0, fakeClock.numPendingTasks()); @@ -1495,6 +1655,7 @@ public void transparentRetry() { ArgumentCaptor sublistenerCaptor3 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream3).start(sublistenerCaptor3.capture()); + inOrder.verify(mockStream3).isReady(); inOrder.verifyNoMoreInteractions(); verify(retriableStreamRecorder, never()).postCommit(); assertEquals(0, fakeClock.numPendingTasks()); @@ -1517,6 +1678,7 @@ public void normalRetry_thenNoTransparentRetry_butNormalRetry() { ArgumentCaptor sublistenerCaptor1 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream1).start(sublistenerCaptor1.capture()); + inOrder.verify(mockStream1).isReady(); inOrder.verifyNoMoreInteractions(); // normal retry @@ -1530,6 +1692,7 @@ public void normalRetry_thenNoTransparentRetry_butNormalRetry() { ArgumentCaptor sublistenerCaptor2 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream2).start(sublistenerCaptor2.capture()); + inOrder.verify(mockStream2).isReady(); inOrder.verifyNoMoreInteractions(); verify(retriableStreamRecorder, never()).postCommit(); assertEquals(0, fakeClock.numPendingTasks()); @@ -1546,6 +1709,7 @@ public void normalRetry_thenNoTransparentRetry_butNormalRetry() { ArgumentCaptor sublistenerCaptor3 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream3).start(sublistenerCaptor3.capture()); + inOrder.verify(mockStream3).isReady(); inOrder.verifyNoMoreInteractions(); verify(retriableStreamRecorder, never()).postCommit(); } @@ -1567,6 +1731,7 @@ public void normalRetry_thenNoTransparentRetry_andNoMoreRetry() { ArgumentCaptor sublistenerCaptor1 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream1).start(sublistenerCaptor1.capture()); + inOrder.verify(mockStream1).isReady(); inOrder.verifyNoMoreInteractions(); // normal retry @@ -1580,6 +1745,7 @@ public void normalRetry_thenNoTransparentRetry_andNoMoreRetry() { ArgumentCaptor sublistenerCaptor2 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream2).start(sublistenerCaptor2.capture()); + inOrder.verify(mockStream2).isReady(); inOrder.verifyNoMoreInteractions(); verify(retriableStreamRecorder, never()).postCommit(); assertEquals(0, fakeClock.numPendingTasks()); @@ -1610,6 +1776,7 @@ method, new Metadata(), channelBufferUsed, PER_RPC_BUFFER_LIMIT, CHANNEL_BUFFER_ ArgumentCaptor sublistenerCaptor1 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream1).start(sublistenerCaptor1.capture()); + inOrder.verify(mockStream1).isReady(); inOrder.verifyNoMoreInteractions(); // transparent retry @@ -1622,6 +1789,7 @@ method, new Metadata(), channelBufferUsed, PER_RPC_BUFFER_LIMIT, CHANNEL_BUFFER_ ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(retriableStreamRecorder).postCommit(); inOrder.verify(mockStream2).start(sublistenerCaptor2.capture()); + inOrder.verify(mockStream2).isReady(); inOrder.verifyNoMoreInteractions(); assertEquals(0, fakeClock.numPendingTasks()); } @@ -1640,6 +1808,7 @@ public void droppedShouldNeverRetry() { ArgumentCaptor sublistenerCaptor1 = ArgumentCaptor.forClass(ClientStreamListener.class); verify(mockStream1).start(sublistenerCaptor1.capture()); + verify(mockStream1).isReady(); // drop and verify no retry Status status = Status.fromCode(RETRIABLE_STATUS_CODE_1); @@ -1711,6 +1880,7 @@ public Void answer(InvocationOnMock in) { ArgumentCaptor sublistenerCaptor1 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream1).start(sublistenerCaptor1.capture()); + inOrder.verify(mockStream1).isReady(); inOrder.verifyNoMoreInteractions(); hedgingStream.sendMessage("msg1"); @@ -1752,6 +1922,8 @@ public Void answer(InvocationOnMock in) { inOrder.verify(mockStream2, times(2)).flush(); inOrder.verify(mockStream2).writeMessage(any(InputStream.class)); inOrder.verify(mockStream2).request(456); + inOrder.verify(mockStream1).isReady(); + inOrder.verify(mockStream2).isReady(); inOrder.verifyNoMoreInteractions(); // send more messages @@ -1789,6 +1961,9 @@ public Void answer(InvocationOnMock in) { inOrder.verify(mockStream3).writeMessage(any(InputStream.class)); inOrder.verify(mockStream3).request(456); inOrder.verify(mockStream3, times(2)).writeMessage(any(InputStream.class)); + inOrder.verify(mockStream1).isReady(); + inOrder.verify(mockStream2).isReady(); + inOrder.verify(mockStream3).isReady(); inOrder.verifyNoMoreInteractions(); // send one more message @@ -1831,6 +2006,9 @@ public Void answer(InvocationOnMock in) { inOrder.verify(mockStream4).writeMessage(any(InputStream.class)); inOrder.verify(mockStream4).request(456); inOrder.verify(mockStream4, times(4)).writeMessage(any(InputStream.class)); + inOrder.verify(mockStream1).isReady(); + inOrder.verify(mockStream2).isReady(); + inOrder.verify(mockStream4).isReady(); inOrder.verifyNoMoreInteractions(); InsightBuilder insight = new InsightBuilder(); @@ -1881,6 +2059,7 @@ public void hedging_maxAttempts() { ArgumentCaptor sublistenerCaptor1 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream1).start(sublistenerCaptor1.capture()); + inOrder.verify(mockStream1).isReady(); inOrder.verifyNoMoreInteractions(); fakeClock.forwardTime(HEDGING_DELAY_IN_SECONDS, TimeUnit.SECONDS); @@ -1888,6 +2067,7 @@ public void hedging_maxAttempts() { ArgumentCaptor sublistenerCaptor2 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream2).start(sublistenerCaptor2.capture()); + inOrder.verify(mockStream2).isReady(); inOrder.verifyNoMoreInteractions(); fakeClock.forwardTime(HEDGING_DELAY_IN_SECONDS, TimeUnit.SECONDS); @@ -1895,6 +2075,7 @@ public void hedging_maxAttempts() { ArgumentCaptor sublistenerCaptor3 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream3).start(sublistenerCaptor3.capture()); + inOrder.verify(mockStream3).isReady(); inOrder.verifyNoMoreInteractions(); fakeClock.forwardTime(HEDGING_DELAY_IN_SECONDS, TimeUnit.SECONDS); @@ -1902,6 +2083,7 @@ public void hedging_maxAttempts() { ArgumentCaptor sublistenerCaptor4 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream4).start(sublistenerCaptor4.capture()); + inOrder.verify(mockStream4).isReady(); inOrder.verifyNoMoreInteractions(); // a random one of the hedges fails @@ -1913,6 +2095,7 @@ public void hedging_maxAttempts() { ArgumentCaptor sublistenerCaptor5 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream5).start(sublistenerCaptor5.capture()); + inOrder.verify(mockStream5).isReady(); inOrder.verifyNoMoreInteractions(); fakeClock.forwardTime(HEDGING_DELAY_IN_SECONDS, TimeUnit.SECONDS); @@ -1920,6 +2103,7 @@ public void hedging_maxAttempts() { ArgumentCaptor sublistenerCaptor6 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream6).start(sublistenerCaptor6.capture()); + inOrder.verify(mockStream6).isReady(); inOrder.verifyNoMoreInteractions(); fakeClock.forwardTime(HEDGING_DELAY_IN_SECONDS, TimeUnit.SECONDS); @@ -1964,6 +2148,7 @@ public void hedging_receiveHeaders() { ArgumentCaptor sublistenerCaptor1 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream1).start(sublistenerCaptor1.capture()); + inOrder.verify(mockStream1).isReady(); inOrder.verifyNoMoreInteractions(); fakeClock.forwardTime(HEDGING_DELAY_IN_SECONDS, TimeUnit.SECONDS); @@ -1971,6 +2156,7 @@ public void hedging_receiveHeaders() { ArgumentCaptor sublistenerCaptor2 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream2).start(sublistenerCaptor2.capture()); + inOrder.verify(mockStream2).isReady(); inOrder.verifyNoMoreInteractions(); fakeClock.forwardTime(HEDGING_DELAY_IN_SECONDS, TimeUnit.SECONDS); @@ -1978,6 +2164,7 @@ public void hedging_receiveHeaders() { ArgumentCaptor sublistenerCaptor3 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream3).start(sublistenerCaptor3.capture()); + inOrder.verify(mockStream3).isReady(); inOrder.verifyNoMoreInteractions(); // a random one of the hedges receives headers @@ -2015,6 +2202,7 @@ public void hedging_pushback_negative() { ArgumentCaptor sublistenerCaptor1 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream1).start(sublistenerCaptor1.capture()); + inOrder.verify(mockStream1).isReady(); inOrder.verifyNoMoreInteractions(); fakeClock.forwardTime(HEDGING_DELAY_IN_SECONDS, TimeUnit.SECONDS); @@ -2022,6 +2210,7 @@ public void hedging_pushback_negative() { ArgumentCaptor sublistenerCaptor2 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream2).start(sublistenerCaptor2.capture()); + inOrder.verify(mockStream2).isReady(); inOrder.verifyNoMoreInteractions(); fakeClock.forwardTime(HEDGING_DELAY_IN_SECONDS, TimeUnit.SECONDS); @@ -2029,6 +2218,7 @@ public void hedging_pushback_negative() { ArgumentCaptor sublistenerCaptor3 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream3).start(sublistenerCaptor3.capture()); + inOrder.verify(mockStream3).isReady(); inOrder.verifyNoMoreInteractions(); // a random one of the hedges receives a negative pushback @@ -2060,6 +2250,7 @@ public void hedging_pushback_positive() { ArgumentCaptor sublistenerCaptor1 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream1).start(sublistenerCaptor1.capture()); + inOrder.verify(mockStream1).isReady(); inOrder.verifyNoMoreInteractions(); fakeClock.forwardTime(HEDGING_DELAY_IN_SECONDS, TimeUnit.SECONDS); @@ -2067,6 +2258,7 @@ public void hedging_pushback_positive() { ArgumentCaptor sublistenerCaptor2 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream2).start(sublistenerCaptor2.capture()); + inOrder.verify(mockStream2).isReady(); inOrder.verifyNoMoreInteractions(); @@ -2084,6 +2276,7 @@ public void hedging_pushback_positive() { ArgumentCaptor sublistenerCaptor3 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream3).start(sublistenerCaptor3.capture()); + inOrder.verify(mockStream3).isReady(); inOrder.verifyNoMoreInteractions(); // hedge2 receives a pushback for HEDGING_DELAY_IN_SECONDS - 1 second @@ -2097,6 +2290,7 @@ public void hedging_pushback_positive() { ArgumentCaptor sublistenerCaptor4 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream4).start(sublistenerCaptor4.capture()); + inOrder.verify(mockStream4).isReady(); inOrder.verifyNoMoreInteractions(); // commit @@ -2126,6 +2320,7 @@ public void hedging_cancelled() { ArgumentCaptor sublistenerCaptor1 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream1).start(sublistenerCaptor1.capture()); + inOrder.verify(mockStream1).isReady(); inOrder.verifyNoMoreInteractions(); fakeClock.forwardTime(HEDGING_DELAY_IN_SECONDS, TimeUnit.SECONDS); @@ -2133,6 +2328,8 @@ public void hedging_cancelled() { ArgumentCaptor sublistenerCaptor2 = ArgumentCaptor.forClass(ClientStreamListener.class); inOrder.verify(mockStream2).start(sublistenerCaptor2.capture()); + inOrder.verify(mockStream1).isReady(); + inOrder.verify(mockStream2).isReady(); inOrder.verifyNoMoreInteractions(); Status status = Status.CANCELLED.withDescription("cancelled"); @@ -2147,6 +2344,8 @@ public void hedging_cancelled() { assertEquals(CANCELLED_BECAUSE_COMMITTED, statusCaptor.getValue().getDescription()); inOrder.verify(retriableStreamRecorder).postCommit(); + inOrder.verify(masterListener).closed( + any(Status.class), any(RpcProgress.class), any(Metadata.class)); inOrder.verifyNoMoreInteractions(); } @@ -2161,6 +2360,7 @@ public void hedging_perRpcBufferLimitExceeded() { ArgumentCaptor sublistenerCaptor1 = ArgumentCaptor.forClass(ClientStreamListener.class); verify(mockStream1).start(sublistenerCaptor1.capture()); + verify(mockStream1).isReady(); ClientStreamTracer bufferSizeTracer1 = bufferSizeTracer; bufferSizeTracer1.outboundWireSize(PER_RPC_BUFFER_LIMIT - 1); @@ -2169,6 +2369,8 @@ public void hedging_perRpcBufferLimitExceeded() { ArgumentCaptor sublistenerCaptor2 = ArgumentCaptor.forClass(ClientStreamListener.class); verify(mockStream2).start(sublistenerCaptor2.capture()); + verify(mockStream1, times(2)).isReady(); + verify(mockStream2).isReady(); ClientStreamTracer bufferSizeTracer2 = bufferSizeTracer; bufferSizeTracer2.outboundWireSize(PER_RPC_BUFFER_LIMIT - 1); @@ -2185,6 +2387,7 @@ public void hedging_perRpcBufferLimitExceeded() { verify(retriableStreamRecorder).postCommit(); verifyNoMoreInteractions(mockStream1); + verify(mockStream2).isReady(); verifyNoMoreInteractions(mockStream2); } @@ -2199,6 +2402,7 @@ public void hedging_channelBufferLimitExceeded() { ArgumentCaptor sublistenerCaptor1 = ArgumentCaptor.forClass(ClientStreamListener.class); verify(mockStream1).start(sublistenerCaptor1.capture()); + verify(mockStream1).isReady(); ClientStreamTracer bufferSizeTracer1 = bufferSizeTracer; bufferSizeTracer1.outboundWireSize(100); @@ -2207,6 +2411,8 @@ public void hedging_channelBufferLimitExceeded() { ArgumentCaptor sublistenerCaptor2 = ArgumentCaptor.forClass(ClientStreamListener.class); verify(mockStream2).start(sublistenerCaptor2.capture()); + verify(mockStream1, times(2)).isReady(); + verify(mockStream2).isReady(); ClientStreamTracer bufferSizeTracer2 = bufferSizeTracer; bufferSizeTracer2.outboundWireSize(100); diff --git a/core/src/test/java/io/grpc/internal/ServerCallImplTest.java b/core/src/test/java/io/grpc/internal/ServerCallImplTest.java index 5130bc05aa7..ea49b94e8aa 100644 --- a/core/src/test/java/io/grpc/internal/ServerCallImplTest.java +++ b/core/src/test/java/io/grpc/internal/ServerCallImplTest.java @@ -18,6 +18,7 @@ import static com.google.common.base.Charsets.UTF_8; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -378,6 +379,9 @@ public void streamListener_closedOk() { verify(callListener).onComplete(); assertTrue(context.isCancelled()); assertNull(context.cancellationCause()); + // The call considers cancellation to be an exceptional situation so it should + // not be cancelled with an OK status. + assertFalse(call.isCancelled()); } @Test diff --git a/core/src/test/java/io/grpc/internal/ServerImplTest.java b/core/src/test/java/io/grpc/internal/ServerImplTest.java index 2a9dbd5a1fe..0f5c510f97c 100644 --- a/core/src/test/java/io/grpc/internal/ServerImplTest.java +++ b/core/src/test/java/io/grpc/internal/ServerImplTest.java @@ -1196,7 +1196,9 @@ public void testStreamClose_clientOkTriggersDelayedCancellation() throws Excepti context, contextCancelled, null); // For close status OK: - // isCancelled is expected to be true after all pending work is done + // The context isCancelled is expected to be true after all pending work is done, + // but for the call it should be false as it gets set cancelled only if the call + // fails with a non-OK status. assertFalse(callReference.get().isCancelled()); assertFalse(context.get().isCancelled()); streamListener.closed(Status.OK); @@ -1204,7 +1206,7 @@ public void testStreamClose_clientOkTriggersDelayedCancellation() throws Excepti assertFalse(context.get().isCancelled()); assertEquals(1, executor.runDueTasks()); - assertTrue(callReference.get().isCancelled()); + assertFalse(callReference.get().isCancelled()); assertTrue(context.get().isCancelled()); assertTrue(contextCancelled.get()); } diff --git a/core/src/test/java/io/grpc/internal/TestUtils.java b/core/src/test/java/io/grpc/internal/TestUtils.java index d5b4ce4949e..974f36e595c 100644 --- a/core/src/test/java/io/grpc/internal/TestUtils.java +++ b/core/src/test/java/io/grpc/internal/TestUtils.java @@ -23,6 +23,7 @@ import io.grpc.CallOptions; import io.grpc.ChannelLogger; +import io.grpc.ClientStreamTracer; import io.grpc.InternalLogId; import io.grpc.LoadBalancer.PickResult; import io.grpc.LoadBalancer.PickSubchannelArgs; @@ -35,6 +36,7 @@ import java.util.concurrent.BlockingQueue; import java.util.concurrent.LinkedBlockingQueue; import javax.annotation.Nullable; +import org.mockito.ArgumentMatchers; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; @@ -118,7 +120,8 @@ public ConnectionClientTransport answer(InvocationOnMock invocation) throws Thro when(mockTransport.getLogId()) .thenReturn(InternalLogId.allocate("mocktransport", /*details=*/ null)); when(mockTransport.newStream( - any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class))) + any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any())) .thenReturn(mock(ClientStream.class)); // Save the listener doAnswer(new Answer() { diff --git a/core/src/test/java/io/grpc/util/ForwardingClientStreamTracerTest.java b/core/src/test/java/io/grpc/util/ForwardingClientStreamTracerTest.java index fcb19b69eb8..dbd7e99b29a 100644 --- a/core/src/test/java/io/grpc/util/ForwardingClientStreamTracerTest.java +++ b/core/src/test/java/io/grpc/util/ForwardingClientStreamTracerTest.java @@ -40,6 +40,7 @@ public void allMethodsForwarded() throws Exception { Collections.emptyList()); } + @SuppressWarnings("deprecation") private final class TestClientStreamTracer extends ForwardingClientStreamTracer { @Override protected ClientStreamTracer delegate() { diff --git a/cronet/README.md b/cronet/README.md index bd5329e5192..8b220bd606d 100644 --- a/cronet/README.md +++ b/cronet/README.md @@ -26,7 +26,7 @@ In your app module's `build.gradle` file, include a dependency on both `grpc-cro Google Play Services Client Library for Cronet ``` -implementation 'io.grpc:grpc-cronet:1.39.0' +implementation 'io.grpc:grpc-cronet:1.41.0' implementation 'com.google.android.gms:play-services-cronet:16.0.0' ``` diff --git a/cronet/build.gradle b/cronet/build.gradle index 2d73cc4194f..7a6d58a7e4b 100644 --- a/cronet/build.gradle +++ b/cronet/build.gradle @@ -28,6 +28,10 @@ android { consumerProguardFiles 'proguard-rules.pro' } } + compileOptions { + sourceCompatibility JavaVersion.VERSION_1_8 + targetCompatibility JavaVersion.VERSION_1_8 + } testOptions { unitTests { includeAndroidResources = true } } lintOptions { disable 'InvalidPackage' } } diff --git a/cronet/src/main/java/io/grpc/cronet/CronetClientTransport.java b/cronet/src/main/java/io/grpc/cronet/CronetClientTransport.java index dc4fc45ae4e..d41ec372d4c 100644 --- a/cronet/src/main/java/io/grpc/cronet/CronetClientTransport.java +++ b/cronet/src/main/java/io/grpc/cronet/CronetClientTransport.java @@ -21,6 +21,7 @@ import com.google.common.util.concurrent.SettableFuture; import io.grpc.Attributes; import io.grpc.CallOptions; +import io.grpc.ClientStreamTracer; import io.grpc.InternalChannelz.SocketStats; import io.grpc.InternalLogId; import io.grpc.Metadata; @@ -118,7 +119,7 @@ public ListenableFuture getStats() { @Override public CronetClientStream newStream(final MethodDescriptor method, final Metadata headers, - final CallOptions callOptions) { + final CallOptions callOptions, ClientStreamTracer[] tracers) { Preconditions.checkNotNull(method, "method"); Preconditions.checkNotNull(headers, "headers"); @@ -126,7 +127,7 @@ public CronetClientStream newStream(final MethodDescriptor method, final M final String url = "https://" + authority + defaultPath; final StatsTraceContext statsTraceCtx = - StatsTraceContext.newClientContext(callOptions, attrs, headers); + StatsTraceContext.newClientContext(tracers, attrs, headers); class StartCallback implements Runnable { final CronetClientStream clientStream = new CronetClientStream( url, userAgent, executor, headers, CronetClientTransport.this, this, lock, maxMessageSize, diff --git a/cronet/src/test/java/io/grpc/cronet/CronetChannelBuilderTest.java b/cronet/src/test/java/io/grpc/cronet/CronetChannelBuilderTest.java index 39fe03991e4..c27963c6d56 100644 --- a/cronet/src/test/java/io/grpc/cronet/CronetChannelBuilderTest.java +++ b/cronet/src/test/java/io/grpc/cronet/CronetChannelBuilderTest.java @@ -25,6 +25,7 @@ import android.os.Build; import io.grpc.CallOptions; import io.grpc.ChannelLogger; +import io.grpc.ClientStreamTracer; import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.cronet.CronetChannelBuilder.CronetTransportFactory; @@ -50,6 +51,8 @@ public final class CronetChannelBuilderTest { @Mock private ExperimentalCronetEngine mockEngine; @Mock private ChannelLogger channelLogger; + private final ClientStreamTracer[] tracers = + new ClientStreamTracer[]{ new ClientStreamTracer() {} }; private MethodDescriptor method = TestMethodDescriptors.voidMethod(); @Before @@ -69,7 +72,8 @@ public void alwaysUsePutTrue_cronetStreamIsIdempotent() throws Exception { new InetSocketAddress("localhost", 443), new ClientTransportOptions(), channelLogger); - CronetClientStream stream = transport.newStream(method, new Metadata(), CallOptions.DEFAULT); + CronetClientStream stream = transport.newStream( + method, new Metadata(), CallOptions.DEFAULT, tracers); assertTrue(stream.idempotent); } @@ -85,7 +89,8 @@ public void alwaysUsePut_defaultsToFalse() throws Exception { new InetSocketAddress("localhost", 443), new ClientTransportOptions(), channelLogger); - CronetClientStream stream = transport.newStream(method, new Metadata(), CallOptions.DEFAULT); + CronetClientStream stream = transport.newStream( + method, new Metadata(), CallOptions.DEFAULT, tracers); assertFalse(stream.idempotent); } diff --git a/cronet/src/test/java/io/grpc/cronet/CronetClientTransportTest.java b/cronet/src/test/java/io/grpc/cronet/CronetClientTransportTest.java index 50017cb43f8..9503481e747 100644 --- a/cronet/src/test/java/io/grpc/cronet/CronetClientTransportTest.java +++ b/cronet/src/test/java/io/grpc/cronet/CronetClientTransportTest.java @@ -27,6 +27,7 @@ import android.os.Build; import io.grpc.Attributes; import io.grpc.CallOptions; +import io.grpc.ClientStreamTracer; import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.SecurityLevel; @@ -60,6 +61,8 @@ public final class CronetClientTransportTest { private static final Attributes EAG_ATTRS = Attributes.newBuilder().set(EAG_ATTR_KEY, "value").build(); + private final ClientStreamTracer[] tracers = + new ClientStreamTracer[]{ new ClientStreamTracer() {} }; private CronetClientTransport transport; @Mock private StreamBuilderFactory streamFactory; @Mock private Executor executor; @@ -101,9 +104,9 @@ public void transportAttributes() { @Test public void shutdownTransport() throws Exception { CronetClientStream stream1 = - transport.newStream(descriptor, new Metadata(), CallOptions.DEFAULT); + transport.newStream(descriptor, new Metadata(), CallOptions.DEFAULT, tracers); CronetClientStream stream2 = - transport.newStream(descriptor, new Metadata(), CallOptions.DEFAULT); + transport.newStream(descriptor, new Metadata(), CallOptions.DEFAULT, tracers); // Create a transport and start two streams on it. ArgumentCaptor callbackCaptor = @@ -137,7 +140,7 @@ public void shutdownTransport() throws Exception { @Test public void startStreamAfterShutdown() throws Exception { CronetClientStream stream = - transport.newStream(descriptor, new Metadata(), CallOptions.DEFAULT); + transport.newStream(descriptor, new Metadata(), CallOptions.DEFAULT, tracers); transport.shutdown(); BaseClientStreamListener listener = new BaseClientStreamListener(); stream.start(listener); diff --git a/documentation/android-channel-builder.md b/documentation/android-channel-builder.md index 93447639197..60e3bb35a85 100644 --- a/documentation/android-channel-builder.md +++ b/documentation/android-channel-builder.md @@ -36,8 +36,8 @@ In your `build.gradle` file, include a dependency on both `grpc-android` and `grpc-okhttp`: ``` -implementation 'io.grpc:grpc-android:1.39.0' -implementation 'io.grpc:grpc-okhttp:1.39.0' +implementation 'io.grpc:grpc-android:1.41.0' +implementation 'io.grpc:grpc-okhttp:1.41.0' ``` You also need permission to access the device's network state in your diff --git a/examples/android/clientcache/app/build.gradle b/examples/android/clientcache/app/build.gradle index 2d6cb8e0097..9b5ef9d448e 100644 --- a/examples/android/clientcache/app/build.gradle +++ b/examples/android/clientcache/app/build.gradle @@ -34,7 +34,7 @@ android { protobuf { protoc { artifact = 'com.google.protobuf:protoc:3.17.2' } plugins { - grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.40.0-SNAPSHOT' // CURRENT_GRPC_VERSION + grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.41.0' // CURRENT_GRPC_VERSION } } generateProtoTasks { @@ -54,12 +54,12 @@ dependencies { implementation 'com.android.support:appcompat-v7:27.0.2' // You need to build grpc-java to obtain these libraries below. - implementation 'io.grpc:grpc-okhttp:1.40.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-protobuf-lite:1.40.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-stub:1.40.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-okhttp:1.41.0' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-protobuf-lite:1.41.0' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-stub:1.41.0' // CURRENT_GRPC_VERSION implementation 'org.apache.tomcat:annotations-api:6.0.53' testImplementation 'junit:junit:4.12' testImplementation 'com.google.truth:truth:1.0.1' - testImplementation 'io.grpc:grpc-testing:1.40.0-SNAPSHOT' // CURRENT_GRPC_VERSION + testImplementation 'io.grpc:grpc-testing:1.41.0' // CURRENT_GRPC_VERSION } diff --git a/examples/android/helloworld/app/build.gradle b/examples/android/helloworld/app/build.gradle index 57490bc84ff..8a3ed89300d 100644 --- a/examples/android/helloworld/app/build.gradle +++ b/examples/android/helloworld/app/build.gradle @@ -32,7 +32,7 @@ android { protobuf { protoc { artifact = 'com.google.protobuf:protoc:3.17.2' } plugins { - grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.40.0-SNAPSHOT' // CURRENT_GRPC_VERSION + grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.41.0' // CURRENT_GRPC_VERSION } } generateProtoTasks { @@ -52,8 +52,8 @@ dependencies { implementation 'com.android.support:appcompat-v7:27.0.2' // You need to build grpc-java to obtain these libraries below. - implementation 'io.grpc:grpc-okhttp:1.40.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-protobuf-lite:1.40.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-stub:1.40.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-okhttp:1.41.0' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-protobuf-lite:1.41.0' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-stub:1.41.0' // CURRENT_GRPC_VERSION implementation 'org.apache.tomcat:annotations-api:6.0.53' } diff --git a/examples/android/routeguide/app/build.gradle b/examples/android/routeguide/app/build.gradle index ed7ae228853..4b2d3989e2c 100644 --- a/examples/android/routeguide/app/build.gradle +++ b/examples/android/routeguide/app/build.gradle @@ -32,7 +32,7 @@ android { protobuf { protoc { artifact = 'com.google.protobuf:protoc:3.17.2' } plugins { - grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.40.0-SNAPSHOT' // CURRENT_GRPC_VERSION + grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.41.0' // CURRENT_GRPC_VERSION } } generateProtoTasks { @@ -52,8 +52,8 @@ dependencies { implementation 'com.android.support:appcompat-v7:27.0.2' // You need to build grpc-java to obtain these libraries below. - implementation 'io.grpc:grpc-okhttp:1.40.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-protobuf-lite:1.40.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-stub:1.40.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-okhttp:1.41.0' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-protobuf-lite:1.41.0' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-stub:1.41.0' // CURRENT_GRPC_VERSION implementation 'org.apache.tomcat:annotations-api:6.0.53' } diff --git a/examples/android/strictmode/app/build.gradle b/examples/android/strictmode/app/build.gradle index b117eccc12a..e18273cbadc 100644 --- a/examples/android/strictmode/app/build.gradle +++ b/examples/android/strictmode/app/build.gradle @@ -33,7 +33,7 @@ android { protobuf { protoc { artifact = 'com.google.protobuf:protoc:3.17.2' } plugins { - grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.40.0-SNAPSHOT' // CURRENT_GRPC_VERSION + grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.41.0' // CURRENT_GRPC_VERSION } } generateProtoTasks { @@ -53,8 +53,8 @@ dependencies { implementation 'com.android.support:appcompat-v7:28.0.0' // You need to build grpc-java to obtain these libraries below. - implementation 'io.grpc:grpc-okhttp:1.40.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-protobuf-lite:1.40.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-stub:1.40.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-okhttp:1.41.0' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-protobuf-lite:1.41.0' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-stub:1.41.0' // CURRENT_GRPC_VERSION implementation 'org.apache.tomcat:annotations-api:6.0.53' } diff --git a/examples/build.gradle b/examples/build.gradle index 9b2226e38c6..b21a15f8443 100644 --- a/examples/build.gradle +++ b/examples/build.gradle @@ -22,7 +22,7 @@ targetCompatibility = 1.7 // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.40.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.41.0' // CURRENT_GRPC_VERSION def protobufVersion = '3.17.2' def protocVersion = protobufVersion diff --git a/examples/example-alts/build.gradle b/examples/example-alts/build.gradle index e3f38649708..e08d8fd1ce5 100644 --- a/examples/example-alts/build.gradle +++ b/examples/example-alts/build.gradle @@ -23,7 +23,7 @@ targetCompatibility = 1.7 // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.40.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.41.0' // CURRENT_GRPC_VERSION def protocVersion = '3.17.2' dependencies { diff --git a/examples/example-gauth/build.gradle b/examples/example-gauth/build.gradle index 708365e8973..0e7d7ece1f0 100644 --- a/examples/example-gauth/build.gradle +++ b/examples/example-gauth/build.gradle @@ -23,7 +23,7 @@ targetCompatibility = 1.7 // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.40.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.41.0' // CURRENT_GRPC_VERSION def protobufVersion = '3.17.2' def protocVersion = protobufVersion diff --git a/examples/example-gauth/pom.xml b/examples/example-gauth/pom.xml index cf901838401..91849437181 100644 --- a/examples/example-gauth/pom.xml +++ b/examples/example-gauth/pom.xml @@ -6,13 +6,13 @@ jar - 1.40.0-SNAPSHOT + 1.41.0 example-gauth https://github.com/grpc/grpc-java UTF-8 - 1.40.0-SNAPSHOT + 1.41.0 3.17.2 1.7 diff --git a/examples/example-hostname/build.gradle b/examples/example-hostname/build.gradle index 9ff9210ab7d..23779a52a2b 100644 --- a/examples/example-hostname/build.gradle +++ b/examples/example-hostname/build.gradle @@ -21,7 +21,7 @@ targetCompatibility = 1.7 // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.40.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.41.0' // CURRENT_GRPC_VERSION def protobufVersion = '3.17.2' dependencies { diff --git a/examples/example-hostname/pom.xml b/examples/example-hostname/pom.xml index 17e954bf67e..1df16a9acbe 100644 --- a/examples/example-hostname/pom.xml +++ b/examples/example-hostname/pom.xml @@ -6,13 +6,13 @@ jar - 1.40.0-SNAPSHOT + 1.41.0 example-hostname https://github.com/grpc/grpc-java UTF-8 - 1.40.0-SNAPSHOT + 1.41.0 3.17.2 1.7 diff --git a/examples/example-jwt-auth/build.gradle b/examples/example-jwt-auth/build.gradle index c3593a2dd08..762d8e08f3e 100644 --- a/examples/example-jwt-auth/build.gradle +++ b/examples/example-jwt-auth/build.gradle @@ -22,7 +22,7 @@ targetCompatibility = 1.7 // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.40.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.41.0' // CURRENT_GRPC_VERSION def protobufVersion = '3.17.2' def protocVersion = protobufVersion diff --git a/examples/example-jwt-auth/pom.xml b/examples/example-jwt-auth/pom.xml index 6e76f72b1e2..bbe496f0c90 100644 --- a/examples/example-jwt-auth/pom.xml +++ b/examples/example-jwt-auth/pom.xml @@ -7,13 +7,13 @@ jar - 1.40.0-SNAPSHOT + 1.41.0 example-jwt-auth https://github.com/grpc/grpc-java UTF-8 - 1.40.0-SNAPSHOT + 1.41.0 3.17.2 3.17.2 diff --git a/examples/example-tls/build.gradle b/examples/example-tls/build.gradle index a272fc30e6a..f60b54c146a 100644 --- a/examples/example-tls/build.gradle +++ b/examples/example-tls/build.gradle @@ -23,7 +23,7 @@ targetCompatibility = 1.7 // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.40.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.41.0' // CURRENT_GRPC_VERSION def protocVersion = '3.17.2' dependencies { diff --git a/examples/example-tls/pom.xml b/examples/example-tls/pom.xml index ec49708c2b0..67d6ff1a507 100644 --- a/examples/example-tls/pom.xml +++ b/examples/example-tls/pom.xml @@ -6,13 +6,13 @@ jar - 1.40.0-SNAPSHOT + 1.41.0 example-tls https://github.com/grpc/grpc-java UTF-8 - 1.40.0-SNAPSHOT + 1.41.0 3.17.2 2.0.34.Final diff --git a/examples/example-xds/build.gradle b/examples/example-xds/build.gradle index dfa860ec0d0..47151427c8a 100644 --- a/examples/example-xds/build.gradle +++ b/examples/example-xds/build.gradle @@ -22,7 +22,7 @@ targetCompatibility = 1.7 // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.40.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.41.0' // CURRENT_GRPC_VERSION def nettyTcNativeVersion = '2.0.31.Final' def protocVersion = '3.17.2' diff --git a/examples/pom.xml b/examples/pom.xml index 3ba96b6c197..a73da5d3bdf 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -6,13 +6,13 @@ jar - 1.40.0-SNAPSHOT + 1.41.0 examples https://github.com/grpc/grpc-java UTF-8 - 1.40.0-SNAPSHOT + 1.41.0 3.17.2 3.17.2 diff --git a/grpclb/src/main/java/io/grpc/grpclb/GrpclbClientLoadRecorder.java b/grpclb/src/main/java/io/grpc/grpclb/GrpclbClientLoadRecorder.java index d27c485dc13..75f2481254d 100644 --- a/grpclb/src/main/java/io/grpc/grpclb/GrpclbClientLoadRecorder.java +++ b/grpclb/src/main/java/io/grpc/grpclb/GrpclbClientLoadRecorder.java @@ -37,7 +37,7 @@ * span of an LB stream with the remote load-balancer. */ @ThreadSafe -final class GrpclbClientLoadRecorder extends ClientStreamTracer.Factory { +final class GrpclbClientLoadRecorder extends ClientStreamTracer.InternalLimitedInfoFactory { private static final AtomicLongFieldUpdater callsStartedUpdater = AtomicLongFieldUpdater.newUpdater(GrpclbClientLoadRecorder.class, "callsStarted"); diff --git a/grpclb/src/main/java/io/grpc/grpclb/TokenAttachingTracerFactory.java b/grpclb/src/main/java/io/grpc/grpclb/TokenAttachingTracerFactory.java index 03b9bdf7f1b..03e1447bb2c 100644 --- a/grpclb/src/main/java/io/grpc/grpclb/TokenAttachingTracerFactory.java +++ b/grpclb/src/main/java/io/grpc/grpclb/TokenAttachingTracerFactory.java @@ -22,6 +22,7 @@ import io.grpc.Attributes; import io.grpc.ClientStreamTracer; import io.grpc.Metadata; +import io.grpc.internal.ForwardingClientStreamTracer; import io.grpc.internal.GrpcAttributes; import javax.annotation.Nullable; @@ -29,7 +30,7 @@ * Wraps a {@link ClientStreamTracer.Factory}, retrieves tokens from transport attributes and * attaches them to headers. This is only used in the PICK_FIRST mode. */ -final class TokenAttachingTracerFactory extends ClientStreamTracer.Factory { +final class TokenAttachingTracerFactory extends ClientStreamTracer.InternalLimitedInfoFactory { private static final ClientStreamTracer NOOP_TRACER = new ClientStreamTracer() {}; @Nullable @@ -42,19 +43,30 @@ final class TokenAttachingTracerFactory extends ClientStreamTracer.Factory { @Override public ClientStreamTracer newClientStreamTracer( ClientStreamTracer.StreamInfo info, Metadata headers) { - Attributes transportAttrs = checkNotNull(info.getTransportAttrs(), "transportAttrs"); - Attributes eagAttrs = - checkNotNull(transportAttrs.get(GrpcAttributes.ATTR_CLIENT_EAG_ATTRS), "eagAttrs"); - String token = eagAttrs.get(GrpclbConstants.TOKEN_ATTRIBUTE_KEY); - headers.discardAll(GrpclbConstants.TOKEN_METADATA_KEY); - if (token != null) { - headers.put(GrpclbConstants.TOKEN_METADATA_KEY, token); - } - if (delegate != null) { - return delegate.newClientStreamTracer(info, headers); - } else { + if (delegate == null) { return NOOP_TRACER; } + final ClientStreamTracer clientStreamTracer = delegate.newClientStreamTracer(info, headers); + class TokenPropagationTracer extends ForwardingClientStreamTracer { + @Override + protected ClientStreamTracer delegate() { + return clientStreamTracer; + } + + @Override + public void streamCreated(Attributes transportAttrs, Metadata headers) { + Attributes eagAttrs = + checkNotNull(transportAttrs.get(GrpcAttributes.ATTR_CLIENT_EAG_ATTRS), "eagAttrs"); + String token = eagAttrs.get(GrpclbConstants.TOKEN_ATTRIBUTE_KEY); + headers.discardAll(GrpclbConstants.TOKEN_METADATA_KEY); + if (token != null) { + headers.put(GrpclbConstants.TOKEN_METADATA_KEY, token); + } + delegate().streamCreated(transportAttrs, headers); + } + } + + return new TokenPropagationTracer(); } @Override diff --git a/grpclb/src/test/java/io/grpc/grpclb/GrpclbLoadBalancerTest.java b/grpclb/src/test/java/io/grpc/grpclb/GrpclbLoadBalancerTest.java index 39f736dbcf4..a68962ad7d9 100644 --- a/grpclb/src/test/java/io/grpc/grpclb/GrpclbLoadBalancerTest.java +++ b/grpclb/src/test/java/io/grpc/grpclb/GrpclbLoadBalancerTest.java @@ -481,6 +481,7 @@ public void loadReporting() { ClientStreamTracer tracer1 = pick1.getStreamTracerFactory().newClientStreamTracer(STREAM_INFO, new Metadata()); + tracer1.streamCreated(Attributes.EMPTY, new Metadata()); PickResult pick2 = picker.pickSubchannel(args); assertNull(pick2.getSubchannel()); @@ -504,6 +505,7 @@ public void loadReporting() { assertSame(getLoadRecorder(), pick3.getStreamTracerFactory()); ClientStreamTracer tracer3 = pick3.getStreamTracerFactory().newClientStreamTracer(STREAM_INFO, new Metadata()); + tracer3.streamCreated(Attributes.EMPTY, new Metadata()); // pick3 has sent out headers tracer3.outboundHeaders(); @@ -541,6 +543,7 @@ public void loadReporting() { assertSame(getLoadRecorder(), pick5.getStreamTracerFactory()); ClientStreamTracer tracer5 = pick5.getStreamTracerFactory().newClientStreamTracer(STREAM_INFO, new Metadata()); + tracer5.streamCreated(Attributes.EMPTY, new Metadata()); // pick3 ended without receiving response headers tracer3.streamClosed(Status.DEADLINE_EXCEEDED); diff --git a/grpclb/src/test/java/io/grpc/grpclb/TokenAttachingTracerFactoryTest.java b/grpclb/src/test/java/io/grpc/grpclb/TokenAttachingTracerFactoryTest.java index 34b0a8ea1aa..29ded18d913 100644 --- a/grpclb/src/test/java/io/grpc/grpclb/TokenAttachingTracerFactoryTest.java +++ b/grpclb/src/test/java/io/grpc/grpclb/TokenAttachingTracerFactoryTest.java @@ -33,12 +33,23 @@ /** Unit tests for {@link TokenAttachingTracerFactory}. */ @RunWith(JUnit4.class) public class TokenAttachingTracerFactoryTest { - private static final ClientStreamTracer fakeTracer = new ClientStreamTracer() {}; + private static final class FakeClientStreamTracer extends ClientStreamTracer { + Attributes transportAttrs; + Metadata headers; + + @Override + public void streamCreated(Attributes transportAttrs, Metadata headers) { + this.transportAttrs = transportAttrs; + this.headers = headers; + } + } + + private static final FakeClientStreamTracer fakeTracer = new FakeClientStreamTracer(); private final ClientStreamTracer.Factory delegate = mock( ClientStreamTracer.Factory.class, delegatesTo( - new ClientStreamTracer.Factory() { + new ClientStreamTracer.InternalLimitedInfoFactory() { @Override public ClientStreamTracer newClientStreamTracer( ClientStreamTracer.StreamInfo info, Metadata headers) { @@ -51,28 +62,25 @@ public void hasToken() { TokenAttachingTracerFactory factory = new TokenAttachingTracerFactory(delegate); Attributes eagAttrs = Attributes.newBuilder() .set(GrpclbConstants.TOKEN_ATTRIBUTE_KEY, "token0001").build(); - ClientStreamTracer.StreamInfo info = ClientStreamTracer.StreamInfo.newBuilder() - .setTransportAttrs( - Attributes.newBuilder().set(GrpcAttributes.ATTR_CLIENT_EAG_ATTRS, eagAttrs).build()) - .build(); + ClientStreamTracer.StreamInfo info = ClientStreamTracer.StreamInfo.newBuilder().build(); Metadata headers = new Metadata(); // Preexisting token should be replaced headers.put(GrpclbConstants.TOKEN_METADATA_KEY, "preexisting-token"); ClientStreamTracer tracer = factory.newClientStreamTracer(info, headers); verify(delegate).newClientStreamTracer(same(info), same(headers)); - assertThat(tracer).isSameInstanceAs(fakeTracer); + Attributes transportAttrs = + Attributes.newBuilder().set(GrpcAttributes.ATTR_CLIENT_EAG_ATTRS, eagAttrs).build(); + tracer.streamCreated(transportAttrs, headers); + assertThat(fakeTracer.transportAttrs).isSameInstanceAs(transportAttrs); + assertThat(fakeTracer.headers).isSameInstanceAs(headers); assertThat(headers.getAll(GrpclbConstants.TOKEN_METADATA_KEY)).containsExactly("token0001"); } @Test public void noToken() { TokenAttachingTracerFactory factory = new TokenAttachingTracerFactory(delegate); - ClientStreamTracer.StreamInfo info = ClientStreamTracer.StreamInfo.newBuilder() - .setTransportAttrs( - Attributes.newBuilder() - .set(GrpcAttributes.ATTR_CLIENT_EAG_ATTRS, Attributes.EMPTY).build()) - .build(); + ClientStreamTracer.StreamInfo info = ClientStreamTracer.StreamInfo.newBuilder().build(); Metadata headers = new Metadata(); // Preexisting token should be removed @@ -80,22 +88,25 @@ public void noToken() { ClientStreamTracer tracer = factory.newClientStreamTracer(info, headers); verify(delegate).newClientStreamTracer(same(info), same(headers)); - assertThat(tracer).isSameInstanceAs(fakeTracer); + Attributes transportAttrs = + Attributes.newBuilder().set(GrpcAttributes.ATTR_CLIENT_EAG_ATTRS, Attributes.EMPTY).build(); + tracer.streamCreated(transportAttrs, headers); + assertThat(fakeTracer.transportAttrs).isSameInstanceAs(transportAttrs); + assertThat(fakeTracer.headers).isSameInstanceAs(headers); assertThat(headers.get(GrpclbConstants.TOKEN_METADATA_KEY)).isNull(); } @Test public void nullDelegate() { TokenAttachingTracerFactory factory = new TokenAttachingTracerFactory(null); - ClientStreamTracer.StreamInfo info = ClientStreamTracer.StreamInfo.newBuilder() - .setTransportAttrs( - Attributes.newBuilder() - .set(GrpcAttributes.ATTR_CLIENT_EAG_ATTRS, Attributes.EMPTY).build()) - .build(); + ClientStreamTracer.StreamInfo info = ClientStreamTracer.StreamInfo.newBuilder().build(); Metadata headers = new Metadata(); ClientStreamTracer tracer = factory.newClientStreamTracer(info, headers); + tracer.streamCreated( + Attributes.newBuilder().set(GrpcAttributes.ATTR_CLIENT_EAG_ATTRS, Attributes.EMPTY).build(), + headers); assertThat(tracer).isNotNull(); assertThat(headers.get(GrpclbConstants.TOKEN_METADATA_KEY)).isNull(); } diff --git a/interop-testing/build.gradle b/interop-testing/build.gradle index 79aa5356ecd..944c0daab81 100644 --- a/interop-testing/build.gradle +++ b/interop-testing/build.gradle @@ -27,6 +27,7 @@ dependencies { project(':grpc-stub'), project(':grpc-testing'), project(path: ':grpc-xds', configuration: 'shadow'), + libraries.hdrhistogram, libraries.junit, libraries.truth, libraries.opencensus_contrib_grpc_metrics, @@ -43,6 +44,8 @@ dependencies { libraries.netty_tcnative, project(':grpc-grpclb') testImplementation project(':grpc-context').sourceSets.test.output, + project(':grpc-api').sourceSets.test.output, + project(':grpc-core').sourceSets.test.output, libraries.mockito alpnagent libraries.jetty_alpn_agent } diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java b/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java index 758f99d5353..33d263e95de 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java @@ -92,6 +92,9 @@ import io.grpc.testing.integration.Messages.StreamingOutputCallRequest; import io.grpc.testing.integration.Messages.StreamingOutputCallResponse; import io.opencensus.contrib.grpc.metrics.RpcMeasureConstants; +import io.opencensus.stats.Measure; +import io.opencensus.stats.Measure.MeasureDouble; +import io.opencensus.stats.Measure.MeasureLong; import io.opencensus.tags.TagKey; import io.opencensus.tags.TagValue; import io.opencensus.trace.Span; @@ -124,6 +127,7 @@ import javax.annotation.Nullable; import javax.net.ssl.SSLPeerUnverifiedException; import javax.net.ssl.SSLSession; +import org.HdrHistogram.Histogram; import org.junit.After; import org.junit.Assert; import org.junit.Assume; @@ -152,6 +156,15 @@ public abstract class AbstractInteropTest { * SETTINGS/WINDOW_UPDATE exchange. */ public static final int TEST_FLOW_CONTROL_WINDOW = 65 * 1024; + private static final MeasureLong RETRIES_PER_CALL = + Measure.MeasureLong.create( + "grpc.io/client/retries_per_call", "Number of retries per call", "1"); + private static final MeasureLong TRANSPARENT_RETRIES_PER_CALL = + Measure.MeasureLong.create( + "grpc.io/client/transparent_retries_per_call", "Transparent retries per call", "1"); + private static final MeasureDouble RETRY_DELAY_PER_CALL = + Measure.MeasureDouble.create( + "grpc.io/client/retry_delay_per_call", "Retry delay per call", "ms"); private static final FakeTagger tagger = new FakeTagger(); private static final FakeTagContextBinarySerializer tagContextBinarySerializer = @@ -289,7 +302,7 @@ final SocketAddress getListenAddress() { new LinkedBlockingQueue<>(); private final ClientStreamTracer.Factory clientStreamTracerFactory = - new ClientStreamTracer.Factory() { + new ClientStreamTracer.InternalLimitedInfoFactory() { @Override public ClientStreamTracer newClientStreamTracer( ClientStreamTracer.StreamInfo info, Metadata headers) { @@ -375,7 +388,8 @@ protected final ClientInterceptor createCensusStatsClientInterceptor() { .getClientInterceptor( tagger, tagContextBinarySerializer, clientStatsRecorder, GrpcUtil.STOPWATCH_SUPPLIER, - true, true, true, false /* real-time metrics */); + true, true, true, + /* recordRealTimeMetrics= */ false); } protected final ServerStreamTracer.Factory createCustomCensusTracerFactory() { @@ -1042,19 +1056,18 @@ public void veryLargeResponse() throws Exception { @Test public void exchangeMetadataUnaryCall() throws Exception { - TestServiceGrpc.TestServiceBlockingStub stub = blockingStub; - // Capture the metadata exchange Metadata fixedHeaders = new Metadata(); // Send a context proto (as it's in the default extension registry) Messages.SimpleContext contextValue = Messages.SimpleContext.newBuilder().setValue("dog").build(); fixedHeaders.put(Util.METADATA_KEY, contextValue); - stub = MetadataUtils.attachHeaders(stub, fixedHeaders); // .. and expect it to be echoed back in trailers AtomicReference trailersCapture = new AtomicReference<>(); AtomicReference headersCapture = new AtomicReference<>(); - stub = MetadataUtils.captureMetadata(stub, headersCapture, trailersCapture); + TestServiceGrpc.TestServiceBlockingStub stub = blockingStub.withInterceptors( + MetadataUtils.newAttachHeadersInterceptor(fixedHeaders), + MetadataUtils.newCaptureMetadataInterceptor(headersCapture, trailersCapture)); assertNotNull(stub.emptyCall(EMPTY)); @@ -1065,19 +1078,18 @@ public void exchangeMetadataUnaryCall() throws Exception { @Test public void exchangeMetadataStreamingCall() throws Exception { - TestServiceGrpc.TestServiceStub stub = asyncStub; - // Capture the metadata exchange Metadata fixedHeaders = new Metadata(); // Send a context proto (as it's in the default extension registry) Messages.SimpleContext contextValue = Messages.SimpleContext.newBuilder().setValue("dog").build(); fixedHeaders.put(Util.METADATA_KEY, contextValue); - stub = MetadataUtils.attachHeaders(stub, fixedHeaders); // .. and expect it to be echoed back in trailers AtomicReference trailersCapture = new AtomicReference<>(); AtomicReference headersCapture = new AtomicReference<>(); - stub = MetadataUtils.captureMetadata(stub, headersCapture, trailersCapture); + TestServiceGrpc.TestServiceStub stub = asyncStub.withInterceptors( + MetadataUtils.newAttachHeadersInterceptor(fixedHeaders), + MetadataUtils.newCaptureMetadataInterceptor(headersCapture, trailersCapture)); List responseSizes = Arrays.asList(50, 100, 150, 200); Messages.StreamingOutputCallRequest.Builder streamingOutputBuilder = @@ -1179,6 +1191,7 @@ public void deadlineExceeded() throws Exception { public void deadlineExceededServerStreaming() throws Exception { // warm up the channel and JVM blockingStub.emptyCall(Empty.getDefaultInstance()); + assertStatsTrace("grpc.testing.TestService/EmptyCall", Status.Code.OK); ResponseParameters.Builder responseParameters = ResponseParameters.newBuilder() .setSize(1) .setIntervalUs(10000); @@ -1195,7 +1208,6 @@ public void deadlineExceededServerStreaming() throws Exception { recorder.awaitCompletion(); assertEquals(Status.DEADLINE_EXCEEDED.getCode(), Status.fromThrowable(recorder.getError()).getCode()); - assertStatsTrace("grpc.testing.TestService/EmptyCall", Status.Code.OK); if (metricsExpected()) { // Stream may not have been created when deadline is exceeded, thus we don't check tracer // stats. @@ -1235,10 +1247,18 @@ public void deadlineInPast() throws Exception { checkEndTags( clientEndRecord, "grpc.testing.TestService/EmptyCall", Status.DEADLINE_EXCEEDED.getCode(), true); + assertZeroRetryRecorded(); } // warm up the channel blockingStub.emptyCall(Empty.getDefaultInstance()); + if (metricsExpected()) { + // clientStartRecord + clientStatsRecorder.pollRecord(5, TimeUnit.SECONDS); + // clientEndRecord + clientStatsRecorder.pollRecord(5, TimeUnit.SECONDS); + assertZeroRetryRecorded(); + } try { blockingStub .withDeadlineAfter(-10, TimeUnit.SECONDS) @@ -1249,7 +1269,6 @@ public void deadlineInPast() throws Exception { assertThat(ex.getStatus().getDescription()) .startsWith("ClientCall started after deadline exceeded"); } - assertStatsTrace("grpc.testing.TestService/EmptyCall", Status.Code.OK); if (metricsExpected()) { MetricsRecord clientStartRecord = clientStatsRecorder.pollRecord(5, TimeUnit.SECONDS); checkStartTags(clientStartRecord, "grpc.testing.TestService/EmptyCall", true); @@ -1257,6 +1276,7 @@ public void deadlineInPast() throws Exception { checkEndTags( clientEndRecord, "grpc.testing.TestService/EmptyCall", Status.DEADLINE_EXCEEDED.getCode(), true); + assertZeroRetryRecorded(); } } @@ -1484,11 +1504,11 @@ public void customMetadata() throws Exception { Metadata metadata = new Metadata(); metadata.put(Util.ECHO_INITIAL_METADATA_KEY, "test_initial_metadata_value"); metadata.put(Util.ECHO_TRAILING_METADATA_KEY, trailingBytes); - TestServiceGrpc.TestServiceBlockingStub blockingStub = this.blockingStub; - blockingStub = MetadataUtils.attachHeaders(blockingStub, metadata); AtomicReference headersCapture = new AtomicReference<>(); AtomicReference trailersCapture = new AtomicReference<>(); - blockingStub = MetadataUtils.captureMetadata(blockingStub, headersCapture, trailersCapture); + TestServiceGrpc.TestServiceBlockingStub blockingStub = this.blockingStub.withInterceptors( + MetadataUtils.newAttachHeadersInterceptor(metadata), + MetadataUtils.newCaptureMetadataInterceptor(headersCapture, trailersCapture)); SimpleResponse response = blockingStub.unaryCall(request); assertResponse(goldenResponse, response); @@ -1503,11 +1523,11 @@ public void customMetadata() throws Exception { metadata = new Metadata(); metadata.put(Util.ECHO_INITIAL_METADATA_KEY, "test_initial_metadata_value"); metadata.put(Util.ECHO_TRAILING_METADATA_KEY, trailingBytes); - TestServiceGrpc.TestServiceStub stub = asyncStub; - stub = MetadataUtils.attachHeaders(stub, metadata); headersCapture = new AtomicReference<>(); trailersCapture = new AtomicReference<>(); - stub = MetadataUtils.captureMetadata(stub, headersCapture, trailersCapture); + TestServiceGrpc.TestServiceStub stub = asyncStub.withInterceptors( + MetadataUtils.newAttachHeadersInterceptor(metadata), + MetadataUtils.newCaptureMetadataInterceptor(headersCapture, trailersCapture)); StreamRecorder recorder = StreamRecorder.create(); StreamObserver requestStream = @@ -1867,6 +1887,128 @@ public void googleDefaultCredentials( assertResponse(goldenResponse, response); } + private static class SoakIterationResult { + public SoakIterationResult(long latencyMs, Status status) { + this.latencyMs = latencyMs; + this.status = status; + } + + public long getLatencyMs() { + return latencyMs; + } + + public Status getStatus() { + return status; + } + + private long latencyMs = -1; + private Status status = Status.OK; + } + + private SoakIterationResult performOneSoakIteration(boolean resetChannel) throws Exception { + long startNs = System.nanoTime(); + Status status = Status.OK; + ManagedChannel soakChannel = channel; + TestServiceGrpc.TestServiceBlockingStub soakStub = blockingStub; + if (resetChannel) { + soakChannel = createChannel(); + soakStub = TestServiceGrpc.newBlockingStub(soakChannel); + } + try { + final SimpleRequest request = + SimpleRequest.newBuilder() + .setResponseSize(314159) + .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(new byte[271828]))) + .build(); + final SimpleResponse goldenResponse = + SimpleResponse.newBuilder() + .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(new byte[314159]))) + .build(); + assertResponse(goldenResponse, soakStub.unaryCall(request)); + } catch (StatusRuntimeException e) { + status = e.getStatus(); + } + long elapsedNs = System.nanoTime() - startNs; + if (resetChannel) { + soakChannel.shutdownNow(); + soakChannel.awaitTermination(10, TimeUnit.SECONDS); + } + return new SoakIterationResult(TimeUnit.NANOSECONDS.toMillis(elapsedNs), status); + } + + /** + * Runs large unary RPCs in a loop with configurable failure thresholds + * and channel creation behavior. + */ + public void performSoakTest( + boolean resetChannelPerIteration, + int soakIterations, + int maxFailures, + int maxAcceptablePerIterationLatencyMs, + int overallTimeoutSeconds) + throws Exception { + int iterationsDone = 0; + int totalFailures = 0; + Histogram latencies = new Histogram(4 /* number of significant value digits */); + long startNs = System.nanoTime(); + for (int i = 0; i < soakIterations; i++) { + if (System.nanoTime() - startNs >= TimeUnit.SECONDS.toNanos(overallTimeoutSeconds)) { + break; + } + SoakIterationResult result = performOneSoakIteration(resetChannelPerIteration); + System.err.print( + String.format( + "soak iteration: %d elapsed: %d ms", i, result.getLatencyMs())); + if (!result.getStatus().equals(Status.OK)) { + totalFailures++; + System.err.println(String.format(" failed: %s", result.getStatus())); + } else if (result.getLatencyMs() > maxAcceptablePerIterationLatencyMs) { + totalFailures++; + System.err.println( + String.format( + " exceeds max acceptable latency: %d", maxAcceptablePerIterationLatencyMs)); + } else { + System.err.println(" succeeded"); + } + iterationsDone++; + latencies.recordValue(result.getLatencyMs()); + } + System.err.println( + String.format( + "soak test ran: %d / %d iterations\n" + + "total failures: %d\n" + + "max failures threshold: %d\n" + + "max acceptable per iteration latency ms: %d\n" + + " p50 soak iteration latency: %d ms\n" + + " p90 soak iteration latency: %d ms\n" + + "p100 soak iteration latency: %d ms\n" + + "See breakdown above for which iterations succeeded, failed, and " + + "why for more info.", + iterationsDone, + soakIterations, + totalFailures, + maxFailures, + maxAcceptablePerIterationLatencyMs, + latencies.getValueAtPercentile(50), + latencies.getValueAtPercentile(90), + latencies.getValueAtPercentile(100))); + // check if we timed out + String timeoutErrorMessage = + String.format( + "soak test consumed all %d seconds of time and quit early, only " + + "having ran %d out of desired %d iterations.", + overallTimeoutSeconds, + iterationsDone, + soakIterations); + assertEquals(timeoutErrorMessage, iterationsDone, soakIterations); + // check if we had too many failures + String tooManyFailuresErrorMessage = + String.format( + "soak test total failures: %d exceeds max failures threshold: %d.", + totalFailures, maxFailures); + assertTrue(tooManyFailuresErrorMessage, totalFailures <= maxFailures); + } + protected static void assertSuccess(StreamRecorder recorder) { if (recorder.getError() != null) { throw new AssertionError(recorder.getError()); @@ -1974,6 +2116,13 @@ private void assertStatsTrace(String method, Status.Code status) { assertStatsTrace(method, status, null, null); } + private void assertZeroRetryRecorded() { + MetricsRecord retryRecord = clientStatsRecorder.pollRecord(); + assertThat(retryRecord.getMetric(RETRIES_PER_CALL)).isEqualTo(0); + assertThat(retryRecord.getMetric(TRANSPARENT_RETRIES_PER_CALL)).isEqualTo(0); + assertThat(retryRecord.getMetric(RETRY_DELAY_PER_CALL)).isEqualTo(0D); + } + private void assertClientStatsTrace(String method, Status.Code code, Collection requests, Collection responses) { // Tracer-based stats @@ -2003,6 +2152,7 @@ private void assertClientStatsTrace(String method, Status.Code code, if (requests != null && responses != null) { checkCensus(clientEndRecord, false, requests, responses); } + assertZeroRetryRecorded(); } } diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/TestCases.java b/interop-testing/src/main/java/io/grpc/testing/integration/TestCases.java index 2d1648e157a..39afaa99d6e 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/TestCases.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/TestCases.java @@ -54,7 +54,9 @@ public enum TestCases { CANCEL_AFTER_FIRST_RESPONSE("cancel on first response"), TIMEOUT_ON_SLEEPING_SERVER("timeout before receiving a response"), VERY_LARGE_REQUEST("very large request"), - PICK_FIRST_UNARY("all requests are sent to one server despite multiple servers are resolved"); + PICK_FIRST_UNARY("all requests are sent to one server despite multiple servers are resolved"), + RPC_SOAK("sends 'soak_iterations' large_unary rpcs in a loop, each on the same channel"), + CHANNEL_SOAK("sends 'soak_iterations' large_unary rpcs in a loop, each on a new channel"); private final String description; diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceClient.java b/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceClient.java index f8546880eae..914db12e5a8 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceClient.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceClient.java @@ -86,6 +86,11 @@ public static void main(String[] args) throws Exception { private boolean fullStreamDecompression; private int localHandshakerPort = -1; private Map serviceConfig = null; + private int soakIterations = 10; + private int soakMaxFailures = 0; + private int soakPerIterationMaxAcceptableLatencyMs = 1000; + private int soakOverallTimeoutSeconds = + soakIterations * soakPerIterationMaxAcceptableLatencyMs / 1000; private Tester tester = new Tester(); @@ -150,6 +155,14 @@ void parseArgs(String[] args) throws Exception { @SuppressWarnings("unchecked") Map map = (Map) JsonParser.parse(value); serviceConfig = map; + } else if ("soak_iterations".equals(key)) { + soakIterations = Integer.parseInt(value); + } else if ("soak_max_failures".equals(key)) { + soakMaxFailures = Integer.parseInt(value); + } else if ("soak_per_iteration_max_acceptable_latency_ms".equals(key)) { + soakPerIterationMaxAcceptableLatencyMs = Integer.parseInt(value); + } else if ("soak_overall_timeout_seconds".equals(key)) { + soakOverallTimeoutSeconds = Integer.parseInt(value); } else { System.err.println("Unknown argument: " + key); usage = true; @@ -196,6 +209,23 @@ void parseArgs(String[] args) throws Exception { + "\n --service_config_json=SERVICE_CONFIG_JSON" + "\n Disables service config lookups and sets the provided " + "\n string as the default service config." + + "\n --soak_iterations The number of iterations to use for the two soak " + + "\n tests: rpc_soak and channel_soak. Default " + + c.soakIterations + + "\n --soak_max_failures The number of iterations in soak tests that are " + + "\n allowed to fail (either due to non-OK status code or " + + "\n exceeding the per-iteration max acceptable latency). " + + "\n Default " + c.soakMaxFailures + + "\n --soak_per_iteration_max_acceptable_latency_ms " + + "\n The number of milliseconds a single iteration in the " + + "\n two soak tests (rpc_soak and channel_soak) should " + + "\n take. Default " + + c.soakPerIterationMaxAcceptableLatencyMs + + "\n --soak_overall_timeout_seconds " + + "\n The overall number of seconds after which a soak test " + + "\n should stop and fail, if the desired number of " + + "\n iterations have not yet completed. Default " + + c.soakOverallTimeoutSeconds ); System.exit(1); } @@ -412,6 +442,26 @@ private void runTest(TestCases testCase) throws Exception { break; } + case RPC_SOAK: { + tester.performSoakTest( + false /* resetChannelPerIteration */, + soakIterations, + soakMaxFailures, + soakPerIterationMaxAcceptableLatencyMs, + soakOverallTimeoutSeconds); + break; + } + + case CHANNEL_SOAK: { + tester.performSoakTest( + true /* resetChannelPerIteration */, + soakIterations, + soakMaxFailures, + soakPerIterationMaxAcceptableLatencyMs, + soakOverallTimeoutSeconds); + break; + } + default: throw new IllegalArgumentException("Unknown test case: " + testCase); } diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/XdsTestClient.java b/interop-testing/src/main/java/io/grpc/testing/integration/XdsTestClient.java index d48be9f5031..087152dca64 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/XdsTestClient.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/XdsTestClient.java @@ -369,7 +369,7 @@ public void onCompleted() { @Override public void onError(Throwable t) { if (printResponse) { - logger.log(Level.WARNING, "Rpc failed: {0}", t); + logger.log(Level.WARNING, "Rpc failed", t); } handleRpcError(requestId, config.rpcType, Status.fromThrowable(t), savedWatchers); diff --git a/interop-testing/src/test/java/io/grpc/testing/integration/RetryTest.java b/interop-testing/src/test/java/io/grpc/testing/integration/RetryTest.java new file mode 100644 index 00000000000..eb815501d5c --- /dev/null +++ b/interop-testing/src/test/java/io/grpc/testing/integration/RetryTest.java @@ -0,0 +1,514 @@ +/* + * Copyright 2021 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.testing.integration; + +import static com.google.common.truth.Truth.assertThat; +import static java.util.concurrent.TimeUnit.SECONDS; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.timeout; +import static org.mockito.Mockito.verify; + +import com.google.common.collect.ImmutableMap; +import io.grpc.Attributes; +import io.grpc.CallOptions; +import io.grpc.ClientCall; +import io.grpc.ClientInterceptor; +import io.grpc.ClientStreamTracer; +import io.grpc.ClientStreamTracer.StreamInfo; +import io.grpc.Deadline; +import io.grpc.Deadline.Ticker; +import io.grpc.IntegerMarshaller; +import io.grpc.ManagedChannel; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.MethodDescriptor.MethodType; +import io.grpc.Server; +import io.grpc.ServerCall; +import io.grpc.ServerCall.Listener; +import io.grpc.ServerCallHandler; +import io.grpc.ServerMethodDefinition; +import io.grpc.ServerServiceDefinition; +import io.grpc.Status; +import io.grpc.Status.Code; +import io.grpc.StringMarshaller; +import io.grpc.census.InternalCensusStatsAccessor; +import io.grpc.census.internal.DeprecatedCensusConstants; +import io.grpc.internal.FakeClock; +import io.grpc.internal.testing.StatsTestUtils.FakeStatsRecorder; +import io.grpc.internal.testing.StatsTestUtils.FakeTagContextBinarySerializer; +import io.grpc.internal.testing.StatsTestUtils.FakeTagger; +import io.grpc.internal.testing.StatsTestUtils.MetricsRecord; +import io.grpc.netty.NettyChannelBuilder; +import io.grpc.netty.NettyServerBuilder; +import io.grpc.testing.GrpcCleanupRule; +import io.netty.channel.DefaultEventLoopGroup; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.local.LocalAddress; +import io.netty.channel.local.LocalChannel; +import io.netty.channel.local.LocalServerChannel; +import io.netty.util.concurrent.ScheduledFuture; +import io.opencensus.contrib.grpc.metrics.RpcMeasureConstants; +import io.opencensus.stats.Measure; +import io.opencensus.stats.Measure.MeasureDouble; +import io.opencensus.stats.Measure.MeasureLong; +import io.opencensus.tags.TagValue; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import org.junit.Ignore; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +@RunWith(JUnit4.class) +public class RetryTest { + private static final FakeTagger tagger = new FakeTagger(); + private static final FakeTagContextBinarySerializer tagContextBinarySerializer = + new FakeTagContextBinarySerializer(); + private static final MeasureLong RETRIES_PER_CALL = + Measure.MeasureLong.create( + "grpc.io/client/retries_per_call", "Number of retries per call", "1"); + private static final MeasureLong TRANSPARENT_RETRIES_PER_CALL = + Measure.MeasureLong.create( + "grpc.io/client/transparent_retries_per_call", "Transparent retries per call", "1"); + private static final MeasureDouble RETRY_DELAY_PER_CALL = + Measure.MeasureDouble.create( + "grpc.io/client/retry_delay_per_call", "Retry delay per call", "ms"); + + @Rule + public final MockitoRule mocks = MockitoJUnit.rule(); + @Rule + public final GrpcCleanupRule cleanupRule = new GrpcCleanupRule(); + private final FakeClock fakeClock = new FakeClock(); + @Mock + private ClientCall.Listener mockCallListener; + private CountDownLatch backoffLatch = new CountDownLatch(1); + private final EventLoopGroup group = new DefaultEventLoopGroup() { + @SuppressWarnings("FutureReturnValueIgnored") + @Override + public ScheduledFuture schedule( + final Runnable command, final long delay, final TimeUnit unit) { + if (!command.getClass().getName().contains("RetryBackoffRunnable")) { + return super.schedule(command, delay, unit); + } + fakeClock.getScheduledExecutorService().schedule( + new Runnable() { + @Override + public void run() { + group.execute(command); + } + }, + delay, + unit); + backoffLatch.countDown(); + return super.schedule( + new Runnable() { + @Override + public void run() {} // no-op + }, + 0, + TimeUnit.NANOSECONDS); + } + }; + private final FakeStatsRecorder clientStatsRecorder = new FakeStatsRecorder(); + private final ClientInterceptor statsInterceptor = + InternalCensusStatsAccessor.getClientInterceptor( + tagger, tagContextBinarySerializer, clientStatsRecorder, + fakeClock.getStopwatchSupplier(), true, true, true, + /* recordRealTimeMetrics= */ true); + private final MethodDescriptor clientStreamingMethod = + MethodDescriptor.newBuilder() + .setType(MethodType.CLIENT_STREAMING) + .setFullMethodName("service/method") + .setRequestMarshaller(new StringMarshaller()) + .setResponseMarshaller(new IntegerMarshaller()) + .build(); + private final LinkedBlockingQueue> serverCalls = + new LinkedBlockingQueue<>(); + private final ServerMethodDefinition methodDefinition = + ServerMethodDefinition.create( + clientStreamingMethod, + new ServerCallHandler() { + @Override + public Listener startCall(ServerCall call, Metadata headers) { + serverCalls.offer(call); + return new Listener() {}; + } + } + ); + private final ServerServiceDefinition serviceDefinition = + ServerServiceDefinition.builder(clientStreamingMethod.getServiceName()) + .addMethod(methodDefinition) + .build(); + private final LocalAddress localAddress = new LocalAddress(this.getClass().getName()); + private Server localServer; + private ManagedChannel channel; + private Map retryPolicy = null; + private long bufferLimit = 1L << 20; // 1M + + private void startNewServer() throws Exception { + localServer = cleanupRule.register(NettyServerBuilder.forAddress(localAddress) + .channelType(LocalServerChannel.class) + .bossEventLoopGroup(group) + .workerEventLoopGroup(group) + .addService(serviceDefinition) + .build()); + localServer.start(); + } + + private void createNewChannel() { + Map methodConfig = new HashMap<>(); + Map name = new HashMap<>(); + name.put("service", "service"); + methodConfig.put("name", Arrays.asList(name)); + if (retryPolicy != null) { + methodConfig.put("retryPolicy", retryPolicy); + } + Map rawServiceConfig = new HashMap<>(); + rawServiceConfig.put("methodConfig", Arrays.asList(methodConfig)); + channel = cleanupRule.register( + NettyChannelBuilder.forAddress(localAddress) + .channelType(LocalChannel.class) + .eventLoopGroup(group) + .usePlaintext() + .enableRetry() + .perRpcBufferLimit(bufferLimit) + .defaultServiceConfig(rawServiceConfig) + .intercept(statsInterceptor) + .build()); + } + + private void elapseBackoff(long time, TimeUnit unit) throws Exception { + assertThat(backoffLatch.await(5, SECONDS)).isTrue(); + backoffLatch = new CountDownLatch(1); + fakeClock.forwardTime(time, unit); + } + + private void assertRpcStartedRecorded() throws Exception { + MetricsRecord record = clientStatsRecorder.pollRecord(5, SECONDS); + assertThat(record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_STARTED_COUNT)) + .isEqualTo(1); + } + + private void assertOutboundMessageRecorded() throws Exception { + MetricsRecord record = clientStatsRecorder.pollRecord(5, SECONDS); + assertThat( + record.getMetricAsLongOrFail( + RpcMeasureConstants.GRPC_CLIENT_SENT_MESSAGES_PER_METHOD)) + .isEqualTo(1); + } + + private void assertInboundMessageRecorded() throws Exception { + MetricsRecord record = clientStatsRecorder.pollRecord(5, SECONDS); + assertThat( + record.getMetricAsLongOrFail( + RpcMeasureConstants.GRPC_CLIENT_RECEIVED_MESSAGES_PER_METHOD)) + .isEqualTo(1); + } + + private void assertOutboundWireSizeRecorded(long length) throws Exception { + MetricsRecord record = clientStatsRecorder.pollRecord(5, SECONDS); + assertThat(record.getMetricAsLongOrFail(RpcMeasureConstants.GRPC_CLIENT_SENT_BYTES_PER_METHOD)) + .isEqualTo(length); + } + + private void assertInboundWireSizeRecorded(long length) throws Exception { + MetricsRecord record = clientStatsRecorder.pollRecord(5, SECONDS); + assertThat( + record.getMetricAsLongOrFail(RpcMeasureConstants.GRPC_CLIENT_RECEIVED_BYTES_PER_METHOD)) + .isEqualTo(length); + } + + private void assertRpcStatusRecorded( + Status.Code code, long roundtripLatencyMs, long outboundMessages) throws Exception { + MetricsRecord record = clientStatsRecorder.pollRecord(5, SECONDS); + TagValue statusTag = record.tags.get(RpcMeasureConstants.GRPC_CLIENT_STATUS); + assertThat(statusTag.asString()).isEqualTo(code.toString()); + assertThat(record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_FINISHED_COUNT)) + .isEqualTo(1); + assertThat(record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_ROUNDTRIP_LATENCY)) + .isEqualTo(roundtripLatencyMs); + assertThat(record.getMetricAsLongOrFail(DeprecatedCensusConstants.RPC_CLIENT_REQUEST_COUNT)) + .isEqualTo(outboundMessages); + } + + private void assertRetryStatsRecorded( + int numRetries, int numTransparentRetries, long retryDelayMs) throws Exception { + MetricsRecord record = clientStatsRecorder.pollRecord(5, SECONDS); + assertThat(record.getMetricAsLongOrFail(RETRIES_PER_CALL)).isEqualTo(numRetries); + assertThat(record.getMetricAsLongOrFail(TRANSPARENT_RETRIES_PER_CALL)) + .isEqualTo(numTransparentRetries); + assertThat(record.getMetricAsLongOrFail(RETRY_DELAY_PER_CALL)).isEqualTo(retryDelayMs); + } + + @Test + public void retryUntilBufferLimitExceeded() throws Exception { + String message = "String of length 20."; + + startNewServer(); + bufferLimit = message.length() * 2L - 1; // Can buffer no more than 1 message. + retryPolicy = ImmutableMap.builder() + .put("maxAttempts", 4D) + .put("initialBackoff", "10s") + .put("maxBackoff", "10s") + .put("backoffMultiplier", 1D) + .put("retryableStatusCodes", Arrays.asList("UNAVAILABLE")) + .build(); + createNewChannel(); + ClientCall call = channel.newCall(clientStreamingMethod, CallOptions.DEFAULT); + call.start(mockCallListener, new Metadata()); + call.sendMessage(message); + + ServerCall serverCall = serverCalls.poll(5, SECONDS); + serverCall.request(2); + // trigger retry + serverCall.close( + Status.UNAVAILABLE.withDescription("original attempt failed"), + new Metadata()); + elapseBackoff(10, SECONDS); + // 2nd attempt received + serverCall = serverCalls.poll(5, SECONDS); + serverCall.request(2); + verify(mockCallListener, never()).onClose(any(Status.class), any(Metadata.class)); + // send one more message, should exceed buffer limit + call.sendMessage(message); + // let attempt fail + serverCall.close( + Status.UNAVAILABLE.withDescription("2nd attempt failed"), + new Metadata()); + // no more retry + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(null); + verify(mockCallListener, timeout(5000)).onClose(statusCaptor.capture(), any(Metadata.class)); + assertThat(statusCaptor.getValue().getDescription()).contains("2nd attempt failed"); + } + + @Test + public void statsRecorded() throws Exception { + startNewServer(); + retryPolicy = ImmutableMap.builder() + .put("maxAttempts", 4D) + .put("initialBackoff", "10s") + .put("maxBackoff", "10s") + .put("backoffMultiplier", 1D) + .put("retryableStatusCodes", Arrays.asList("UNAVAILABLE")) + .build(); + createNewChannel(); + + ClientCall call = channel.newCall(clientStreamingMethod, CallOptions.DEFAULT); + call.start(mockCallListener, new Metadata()); + assertRpcStartedRecorded(); + String message = "String of length 20."; + call.sendMessage(message); + assertOutboundMessageRecorded(); + ServerCall serverCall = serverCalls.poll(5, SECONDS); + serverCall.request(2); + assertOutboundWireSizeRecorded(message.length()); + // original attempt latency + fakeClock.forwardTime(1, SECONDS); + // trigger retry + serverCall.close( + Status.UNAVAILABLE.withDescription("original attempt failed"), + new Metadata()); + assertRpcStatusRecorded(Status.Code.UNAVAILABLE, 1000, 1); + elapseBackoff(10, SECONDS); + assertRpcStartedRecorded(); + assertOutboundMessageRecorded(); + serverCall = serverCalls.poll(5, SECONDS); + serverCall.request(2); + assertOutboundWireSizeRecorded(message.length()); + message = "new message"; + call.sendMessage(message); + assertOutboundMessageRecorded(); + assertOutboundWireSizeRecorded(message.length()); + // retry attempt latency + fakeClock.forwardTime(2, SECONDS); + serverCall.sendHeaders(new Metadata()); + serverCall.sendMessage(3); + call.request(1); + assertInboundMessageRecorded(); + assertInboundWireSizeRecorded(1); + serverCall.close(Status.OK, new Metadata()); + assertRpcStatusRecorded(Status.Code.OK, 2000, 2); + assertRetryStatsRecorded(1, 0, 10_000); + } + + @Test + public void statsRecorde_callCancelledBeforeCommit() throws Exception { + startNewServer(); + retryPolicy = ImmutableMap.builder() + .put("maxAttempts", 4D) + .put("initialBackoff", "10s") + .put("maxBackoff", "10s") + .put("backoffMultiplier", 1D) + .put("retryableStatusCodes", Arrays.asList("UNAVAILABLE")) + .build(); + createNewChannel(); + + // We will have streamClosed return at a particular moment that we want. + final CountDownLatch streamClosedLatch = new CountDownLatch(1); + ClientStreamTracer.Factory streamTracerFactory = new ClientStreamTracer.Factory() { + @Override + public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata headers) { + return new ClientStreamTracer() { + @Override + public void streamClosed(Status status) { + if (status.getCode().equals(Code.CANCELLED)) { + try { + streamClosedLatch.await(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new AssertionError("streamClosedLatch interrupted", e); + } + } + } + }; + } + }; + ClientCall call = channel.newCall( + clientStreamingMethod, CallOptions.DEFAULT.withStreamTracerFactory(streamTracerFactory)); + call.start(mockCallListener, new Metadata()); + assertRpcStartedRecorded(); + fakeClock.forwardTime(5, SECONDS); + String message = "String of length 20."; + call.sendMessage(message); + assertOutboundMessageRecorded(); + ServerCall serverCall = serverCalls.poll(5, SECONDS); + serverCall.request(2); + assertOutboundWireSizeRecorded(message.length()); + // trigger retry + serverCall.close( + Status.UNAVAILABLE.withDescription("original attempt failed"), + new Metadata()); + assertRpcStatusRecorded(Code.UNAVAILABLE, 5000, 1); + elapseBackoff(10, SECONDS); + assertRpcStartedRecorded(); + assertOutboundMessageRecorded(); + serverCall = serverCalls.poll(5, SECONDS); + serverCall.request(2); + assertOutboundWireSizeRecorded(message.length()); + fakeClock.forwardTime(7, SECONDS); + call.cancel("Cancelled before commit", null); // A noop substream will commit. + // The call listener is closed, but the netty substream listener is not yet closed. + verify(mockCallListener, timeout(5000)).onClose(any(Status.class), any(Metadata.class)); + // Let the netty substream listener be closed. + streamClosedLatch.countDown(); + assertRetryStatsRecorded(1, 0, 10_000); + assertRpcStatusRecorded(Code.CANCELLED, 7_000, 1); + } + + @Test + public void serverCancelledAndClientDeadlineExceeded() throws Exception { + startNewServer(); + createNewChannel(); + + class CloseDelayedTracer extends ClientStreamTracer { + @Override + public void streamClosed(Status status) { + fakeClock.forwardTime(10, SECONDS); + } + } + + class CloseDelayedTracerFactory extends ClientStreamTracer.InternalLimitedInfoFactory { + @Override + public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata headers) { + return new CloseDelayedTracer(); + } + } + + CallOptions callOptions = CallOptions.DEFAULT + .withDeadline(Deadline.after( + 10, + SECONDS, + new Ticker() { + @Override + public long nanoTime() { + return fakeClock.getTicker().read(); + } + })) + .withStreamTracerFactory(new CloseDelayedTracerFactory()); + ClientCall call = channel.newCall(clientStreamingMethod, callOptions); + call.start(mockCallListener, new Metadata()); + assertRpcStartedRecorded(); + ServerCall serverCall = serverCalls.poll(5, SECONDS); + serverCall.close(Status.CANCELLED, new Metadata()); + assertRpcStatusRecorded(Code.DEADLINE_EXCEEDED, 10_000, 0); + assertRetryStatsRecorded(0, 0, 0); + } + + @Ignore("flaky because old transportReportStatus() is not completely migrated yet") + @Test + public void transparentRetryStatsRecorded() throws Exception { + startNewServer(); + createNewChannel(); + + final AtomicBoolean transparentRetryTriggered = new AtomicBoolean(); + class TransparentRetryTriggeringTracer extends ClientStreamTracer { + + @Override + public void streamCreated(Attributes transportAttrs, Metadata metadata) { + if (transparentRetryTriggered.get()) { + return; + } + localServer.shutdownNow(); + } + + @Override + public void streamClosed(Status status) { + if (transparentRetryTriggered.get()) { + return; + } + transparentRetryTriggered.set(true); + try { + startNewServer(); + channel.resetConnectBackoff(); + channel.getState(true); + } catch (Exception e) { + throw new AssertionError("local server can not be restarted", e); + } + } + } + + class TransparentRetryTracerFactory extends ClientStreamTracer.InternalLimitedInfoFactory { + @Override + public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata headers) { + return new TransparentRetryTriggeringTracer(); + } + } + + CallOptions callOptions = CallOptions.DEFAULT + .withWaitForReady() + .withStreamTracerFactory(new TransparentRetryTracerFactory()); + ClientCall call = channel.newCall(clientStreamingMethod, callOptions); + call.start(mockCallListener, new Metadata()); + assertRpcStartedRecorded(); + assertRpcStatusRecorded(Code.UNAVAILABLE, 0, 0); + assertRpcStartedRecorded(); + call.cancel("cancel", null); + assertRpcStatusRecorded(Code.CANCELLED, 0, 0); + assertRetryStatsRecorded(0, 1, 0); + } +} diff --git a/interop-testing/src/test/java/io/grpc/testing/integration/TestCasesTest.java b/interop-testing/src/test/java/io/grpc/testing/integration/TestCasesTest.java index 4e511c1cfe5..14a98514918 100644 --- a/interop-testing/src/test/java/io/grpc/testing/integration/TestCasesTest.java +++ b/interop-testing/src/test/java/io/grpc/testing/integration/TestCasesTest.java @@ -73,7 +73,9 @@ public void testCaseNamesShouldMapToEnums() { "client_compressed_unary_noprobe", "client_compressed_streaming_noprobe", "very_large_request", - "pick_first_unary" + "pick_first_unary", + "channel_soak", + "rpc_soak" }; assertEquals(testCases.length + additionalTestCases.length, TestCases.values().length); diff --git a/netty/shaded/build.gradle b/netty/shaded/build.gradle index 521256ea13d..6b1dad644d1 100644 --- a/netty/shaded/build.gradle +++ b/netty/shaded/build.gradle @@ -4,16 +4,6 @@ import org.gradle.api.file.FileTreeElement import shadow.org.apache.tools.zip.ZipOutputStream import shadow.org.apache.tools.zip.ZipEntry - -buildscript { - repositories { - jcenter() - } - dependencies { - classpath "com.github.jengelman.gradle.plugins:shadow:6.1.0" - } -} - plugins { id "java" id "maven-publish" @@ -120,8 +110,9 @@ class NettyResourceTransformer implements Transformer { @Override void transform(TransformerContext context) { + String updatedPath = context.path.replace("io.netty", "io.grpc.netty.shaded.io.netty") String updatedContent = context.is.getText().replace("io.netty", "io.grpc.netty.shaded.io.netty") - resources.put(context.path, updatedContent) + resources.put(updatedPath, updatedContent) } @Override diff --git a/netty/shaded/src/testShadow/java/io/grpc/netty/shaded/ShadingTest.java b/netty/shaded/src/testShadow/java/io/grpc/netty/shaded/ShadingTest.java index d3bdc4394ca..5c2ff317ccd 100644 --- a/netty/shaded/src/testShadow/java/io/grpc/netty/shaded/ShadingTest.java +++ b/netty/shaded/src/testShadow/java/io/grpc/netty/shaded/ShadingTest.java @@ -78,7 +78,8 @@ public void noNormalNetty() throws Exception { @Test public void nettyResourcesUpdated() throws IOException { InputStream inputStream = NettyChannelBuilder.class.getClassLoader() - .getResourceAsStream("META-INF/native-image/io.netty/transport/reflection-config.json"); + .getResourceAsStream( + "META-INF/native-image/io.grpc.netty.shaded.io.netty/transport/reflection-config.json"); assertThat(inputStream).isNotNull(); Scanner s = new Scanner(inputStream, StandardCharsets.UTF_8.name()).useDelimiter("\\A"); diff --git a/netty/src/main/java/io/grpc/netty/GracefulServerCloseCommand.java b/netty/src/main/java/io/grpc/netty/GracefulServerCloseCommand.java new file mode 100644 index 00000000000..97904687548 --- /dev/null +++ b/netty/src/main/java/io/grpc/netty/GracefulServerCloseCommand.java @@ -0,0 +1,53 @@ +/* + * Copyright 2021 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.netty; + +import com.google.common.base.Preconditions; +import java.util.concurrent.TimeUnit; + +/** + * A command to trigger close and allow streams naturally close. + */ +class GracefulServerCloseCommand extends WriteQueue.AbstractQueuedCommand { + private final String goAwayDebugString; + private final long graceTime; + private final TimeUnit graceTimeUnit; + + public GracefulServerCloseCommand(String goAwayDebugString) { + this(goAwayDebugString, -1, null); + } + + public GracefulServerCloseCommand( + String goAwayDebugString, long graceTime, TimeUnit graceTimeUnit) { + this.goAwayDebugString = Preconditions.checkNotNull(goAwayDebugString, "goAwayDebugString"); + this.graceTime = graceTime; + this.graceTimeUnit = graceTimeUnit; + } + + public String getGoAwayDebugString() { + return goAwayDebugString; + } + + /** Has no meaning if {@code getGraceTimeUnit() == null}. */ + public long getGraceTime() { + return graceTime; + } + + public TimeUnit getGraceTimeUnit() { + return graceTimeUnit; + } +} diff --git a/netty/src/main/java/io/grpc/netty/InternalGracefulServerCloseCommand.java b/netty/src/main/java/io/grpc/netty/InternalGracefulServerCloseCommand.java new file mode 100644 index 00000000000..deb72373ac7 --- /dev/null +++ b/netty/src/main/java/io/grpc/netty/InternalGracefulServerCloseCommand.java @@ -0,0 +1,36 @@ +/* + * Copyright 2021 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.netty; + +import io.grpc.Internal; +import java.util.concurrent.TimeUnit; + +/** + * Internal accessor for {@link GracefulServerCloseCommand}. + */ +@Internal +public final class InternalGracefulServerCloseCommand { + private InternalGracefulServerCloseCommand() {} + + public static Object create(String goAwayDebugString) { + return new GracefulServerCloseCommand(goAwayDebugString); + } + + public static Object create(String goAwayDebugString, long graceTime, TimeUnit graceTimeUnit) { + return new GracefulServerCloseCommand(goAwayDebugString, graceTime, graceTimeUnit); + } +} diff --git a/netty/src/main/java/io/grpc/netty/NettyClientHandler.java b/netty/src/main/java/io/grpc/netty/NettyClientHandler.java index d263356204e..6dde8c825ef 100644 --- a/netty/src/main/java/io/grpc/netty/NettyClientHandler.java +++ b/netty/src/main/java/io/grpc/netty/NettyClientHandler.java @@ -62,7 +62,6 @@ import io.netty.handler.codec.http2.Http2ConnectionEncoder; import io.netty.handler.codec.http2.Http2Error; import io.netty.handler.codec.http2.Http2Exception; -import io.netty.handler.codec.http2.Http2FlowController; import io.netty.handler.codec.http2.Http2FrameAdapter; import io.netty.handler.codec.http2.Http2FrameLogger; import io.netty.handler.codec.http2.Http2FrameReader; @@ -217,17 +216,7 @@ static NettyClientHandler newHandler( Http2ConnectionDecoder decoder = new DefaultHttp2ConnectionDecoder(connection, encoder, frameReader); - transportTracer.setFlowControlWindowReader(new TransportTracer.FlowControlReader() { - final Http2FlowController local = connection.local().flowController(); - final Http2FlowController remote = connection.remote().flowController(); - - @Override - public TransportTracer.FlowControlWindows read() { - return new TransportTracer.FlowControlWindows( - local.windowSize(connection.connectionStream()), - remote.windowSize(connection.connectionStream())); - } - }); + transportTracer.setFlowControlWindowReader(new Utils.FlowControlReader(connection)); Http2Settings settings = new Http2Settings(); settings.pushEnabled(false); @@ -822,6 +811,7 @@ private void goingAway(long errorCode, byte[] debugData) { // UNAVAILABLE. https://github.com/netty/netty/issues/10670 final Status abruptGoAwayStatusConservative = statusFromH2Error( null, "Abrupt GOAWAY closed sent stream", errorCode, debugData); + final boolean mayBeHittingNettyBug = errorCode != Http2Error.NO_ERROR.code(); // Try to allocate as many in-flight streams as possible, to reduce race window of // https://github.com/grpc/grpc-java/issues/2562 . To be of any help, the server has to // gracefully shut down the connection with two GOAWAYs. gRPC servers generally send a PING @@ -848,11 +838,12 @@ public boolean visit(Http2Stream stream) throws Http2Exception { if (clientStream != null) { // RpcProgress _should_ be REFUSED, but are being conservative. See comment for // abruptGoAwayStatusConservative. This does reduce our ability to perform transparent - // retries, but our main goal of transporent retries is to resolve the local race. We - // still hope/expect servers to use the graceful double-GOAWAY when closing - // connections. + // retries, but only if something else caused a connection failure. + RpcProgress progress = mayBeHittingNettyBug + ? RpcProgress.PROCESSED + : RpcProgress.REFUSED; clientStream.transportReportStatus( - abruptGoAwayStatusConservative, RpcProgress.PROCESSED, false, new Metadata()); + abruptGoAwayStatusConservative, progress, false, new Metadata()); } stream.close(); } diff --git a/netty/src/main/java/io/grpc/netty/NettyClientTransport.java b/netty/src/main/java/io/grpc/netty/NettyClientTransport.java index c3807986c9f..a7a1044059c 100644 --- a/netty/src/main/java/io/grpc/netty/NettyClientTransport.java +++ b/netty/src/main/java/io/grpc/netty/NettyClientTransport.java @@ -28,6 +28,7 @@ import io.grpc.Attributes; import io.grpc.CallOptions; import io.grpc.ChannelLogger; +import io.grpc.ClientStreamTracer; import io.grpc.InternalChannelz.SocketStats; import io.grpc.InternalLogId; import io.grpc.Metadata; @@ -167,14 +168,15 @@ public void operationComplete(ChannelFuture future) throws Exception { @Override public ClientStream newStream( - MethodDescriptor method, Metadata headers, CallOptions callOptions) { + MethodDescriptor method, Metadata headers, CallOptions callOptions, + ClientStreamTracer[] tracers) { Preconditions.checkNotNull(method, "method"); Preconditions.checkNotNull(headers, "headers"); if (channel == null) { - return new FailingClientStream(statusExplainingWhyTheChannelIsNull); + return new FailingClientStream(statusExplainingWhyTheChannelIsNull, tracers); } StatsTraceContext statsTraceCtx = - StatsTraceContext.newClientContext(callOptions, getAttributes(), headers); + StatsTraceContext.newClientContext(tracers, getAttributes(), headers); return new NettyClientStream( new NettyClientStream.TransportState( handler, diff --git a/netty/src/main/java/io/grpc/netty/NettyServerHandler.java b/netty/src/main/java/io/grpc/netty/NettyServerHandler.java index 01ed7c0c373..0a34644267f 100644 --- a/netty/src/main/java/io/grpc/netty/NettyServerHandler.java +++ b/netty/src/main/java/io/grpc/netty/NettyServerHandler.java @@ -69,7 +69,6 @@ import io.netty.handler.codec.http2.Http2Error; import io.netty.handler.codec.http2.Http2Exception; import io.netty.handler.codec.http2.Http2Exception.StreamException; -import io.netty.handler.codec.http2.Http2FlowController; import io.netty.handler.codec.http2.Http2FrameAdapter; import io.netty.handler.codec.http2.Http2FrameLogger; import io.netty.handler.codec.http2.Http2FrameReader; @@ -367,23 +366,8 @@ public void run() { keepAliveManager.onTransportStarted(); } - - if (transportTracer != null) { - assert encoder().connection().equals(decoder().connection()); - final Http2Connection connection = encoder().connection(); - transportTracer.setFlowControlWindowReader(new TransportTracer.FlowControlReader() { - private final Http2FlowController local = connection.local().flowController(); - private final Http2FlowController remote = connection.remote().flowController(); - - @Override - public TransportTracer.FlowControlWindows read() { - assert ctx.executor().inEventLoop(); - return new TransportTracer.FlowControlWindows( - local.windowSize(connection.connectionStream()), - remote.windowSize(connection.connectionStream())); - } - }); - } + assert encoder().connection().equals(decoder().connection()); + transportTracer.setFlowControlWindowReader(new Utils.FlowControlReader(encoder().connection())); super.handlerAdded(ctx); } @@ -634,6 +618,8 @@ public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) sendResponseHeaders(ctx, (SendResponseHeadersCommand) msg, promise); } else if (msg instanceof CancelServerStreamCommand) { cancelStream(ctx, (CancelServerStreamCommand) msg, promise); + } else if (msg instanceof GracefulServerCloseCommand) { + gracefulClose(ctx, (GracefulServerCloseCommand) msg, promise); } else if (msg instanceof ForcefulCloseCommand) { forcefulClose(ctx, (ForcefulCloseCommand) msg, promise); } else { @@ -647,11 +633,8 @@ public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) @Override public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { - if (gracefulShutdown == null) { - gracefulShutdown = new GracefulShutdown("app_requested", null); - gracefulShutdown.start(ctx); - ctx.flush(); - } + gracefulClose(ctx, new GracefulServerCloseCommand("app_requested"), promise); + ctx.flush(); } /** @@ -732,6 +715,21 @@ private void cancelStream(ChannelHandlerContext ctx, CancelServerStreamCommand c } } + private void gracefulClose(final ChannelHandlerContext ctx, final GracefulServerCloseCommand msg, + ChannelPromise promise) throws Exception { + // Ideally we'd adjust a pre-existing graceful shutdown's grace period to at least what is + // requested here. But that's an edge case and seems bug-prone. + if (gracefulShutdown == null) { + Long graceTimeInNanos = null; + if (msg.getGraceTimeUnit() != null) { + graceTimeInNanos = msg.getGraceTimeUnit().toNanos(msg.getGraceTime()); + } + gracefulShutdown = new GracefulShutdown(msg.getGoAwayDebugString(), graceTimeInNanos); + gracefulShutdown.start(ctx); + } + promise.setSuccess(); + } + private void forcefulClose(final ChannelHandlerContext ctx, final ForcefulCloseCommand msg, ChannelPromise promise) throws Exception { super.close(ctx, promise); @@ -895,16 +893,14 @@ public void ping() { ChannelFuture pingFuture = encoder().writePing( ctx, false /* isAck */, KEEPALIVE_PING, ctx.newPromise()); ctx.flush(); - if (transportTracer != null) { - pingFuture.addListener(new ChannelFutureListener() { - @Override - public void operationComplete(ChannelFuture future) throws Exception { - if (future.isSuccess()) { - transportTracer.reportKeepAliveSent(); - } + pingFuture.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + if (future.isSuccess()) { + transportTracer.reportKeepAliveSent(); } - }); - } + } + }); } @Override diff --git a/netty/src/main/java/io/grpc/netty/Utils.java b/netty/src/main/java/io/grpc/netty/Utils.java index 082ce63dd54..c2f2fa4a7bf 100644 --- a/netty/src/main/java/io/grpc/netty/Utils.java +++ b/netty/src/main/java/io/grpc/netty/Utils.java @@ -32,6 +32,7 @@ import io.grpc.Status; import io.grpc.internal.GrpcUtil; import io.grpc.internal.SharedResourceHolder.Resource; +import io.grpc.internal.TransportTracer; import io.grpc.netty.GrpcHttp2HeadersUtils.GrpcHttp2InboundHeaders; import io.grpc.netty.NettySocketSupport.NativeSocketOptions; import io.netty.buffer.ByteBufAllocator; @@ -47,8 +48,11 @@ import io.netty.channel.socket.nio.NioServerSocketChannel; import io.netty.channel.socket.nio.NioSocketChannel; import io.netty.handler.codec.DecoderException; +import io.netty.handler.codec.http2.Http2Connection; import io.netty.handler.codec.http2.Http2Exception; +import io.netty.handler.codec.http2.Http2FlowController; import io.netty.handler.codec.http2.Http2Headers; +import io.netty.handler.codec.http2.Http2Stream; import io.netty.util.AsciiString; import io.netty.util.NettyRuntime; import io.netty.util.concurrent.DefaultThreadFactory; @@ -441,6 +445,25 @@ public String toString() { } } + static final class FlowControlReader implements TransportTracer.FlowControlReader { + private final Http2Stream connectionStream; + private final Http2FlowController local; + private final Http2FlowController remote; + + FlowControlReader(Http2Connection connection) { + local = connection.local().flowController(); + remote = connection.remote().flowController(); + connectionStream = connection.connectionStream(); + } + + @Override + public TransportTracer.FlowControlWindows read() { + return new TransportTracer.FlowControlWindows( + local.windowSize(connectionStream), + remote.windowSize(connectionStream)); + } + } + static InternalChannelz.SocketOptions getSocketOptions(Channel channel) { ChannelConfig config = channel.config(); InternalChannelz.SocketOptions.Builder b = new InternalChannelz.SocketOptions.Builder(); diff --git a/netty/src/main/java/io/grpc/netty/WriteBufferingAndExceptionHandler.java b/netty/src/main/java/io/grpc/netty/WriteBufferingAndExceptionHandler.java index 9521fc93889..100367625fa 100644 --- a/netty/src/main/java/io/grpc/netty/WriteBufferingAndExceptionHandler.java +++ b/netty/src/main/java/io/grpc/netty/WriteBufferingAndExceptionHandler.java @@ -124,6 +124,8 @@ public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) promise.setFailure(failCause); ReferenceCountUtil.release(msg); } else { + // Do not special case GracefulServerCloseCommand, as we don't want to cause handshake + // failures. if (msg instanceof GracefulCloseCommand || msg instanceof ForcefulCloseCommand) { // No point in continuing negotiation ctx.close(); diff --git a/netty/src/test/java/io/grpc/netty/AdvancedTlsTest.java b/netty/src/test/java/io/grpc/netty/AdvancedTlsTest.java new file mode 100644 index 00000000000..7dd5ec75e54 --- /dev/null +++ b/netty/src/test/java/io/grpc/netty/AdvancedTlsTest.java @@ -0,0 +1,480 @@ +/* + * Copyright 2021 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.netty; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + +import com.google.common.util.concurrent.MoreExecutors; +import io.grpc.ChannelCredentials; +import io.grpc.Grpc; +import io.grpc.ManagedChannel; +import io.grpc.Server; +import io.grpc.ServerCredentials; +import io.grpc.StatusRuntimeException; +import io.grpc.TlsChannelCredentials; +import io.grpc.TlsServerCredentials; +import io.grpc.TlsServerCredentials.ClientAuth; +import io.grpc.internal.testing.TestUtils; +import io.grpc.stub.StreamObserver; +import io.grpc.testing.protobuf.SimpleRequest; +import io.grpc.testing.protobuf.SimpleResponse; +import io.grpc.testing.protobuf.SimpleServiceGrpc; +import io.grpc.util.AdvancedTlsX509KeyManager; +import io.grpc.util.AdvancedTlsX509TrustManager; +import io.grpc.util.AdvancedTlsX509TrustManager.SslSocketAndEnginePeerVerifier; +import io.grpc.util.AdvancedTlsX509TrustManager.Verification; +import io.grpc.util.CertificateUtils; + +import java.io.Closeable; +import java.io.File; +import java.io.IOException; +import java.net.Socket; +import java.security.NoSuchAlgorithmException; +import java.security.PrivateKey; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; +import java.security.spec.InvalidKeySpecException; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import javax.net.ssl.SSLEngine; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class AdvancedTlsTest { + public static final String SERVER_0_KEY_FILE = "server0.key"; + public static final String SERVER_0_PEM_FILE = "server0.pem"; + public static final String CLIENT_0_KEY_FILE = "client.key"; + public static final String CLIENT_0_PEM_FILE = "client.pem"; + public static final String CA_PEM_FILE = "ca.pem"; + public static final String SERVER_BAD_KEY_FILE = "badserver.key"; + public static final String SERVER_BAD_PEM_FILE = "badserver.pem"; + + private ScheduledExecutorService executor; + private Server server; + private ManagedChannel channel; + + private File caCertFile; + private File serverKey0File; + private File serverCert0File; + private File clientKey0File; + private File clientCert0File; + private X509Certificate[] caCert; + private PrivateKey serverKey0; + private X509Certificate[] serverCert0; + private PrivateKey clientKey0; + private X509Certificate[] clientCert0; + private PrivateKey serverKeyBad; + private X509Certificate[] serverCertBad; + + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + @Before + public void setUp() + throws NoSuchAlgorithmException, IOException, CertificateException, InvalidKeySpecException { + executor = Executors.newSingleThreadScheduledExecutor(); + caCertFile = TestUtils.loadCert(CA_PEM_FILE); + serverKey0File = TestUtils.loadCert(SERVER_0_KEY_FILE); + serverCert0File = TestUtils.loadCert(SERVER_0_PEM_FILE); + clientKey0File = TestUtils.loadCert(CLIENT_0_KEY_FILE); + clientCert0File = TestUtils.loadCert(CLIENT_0_PEM_FILE); + caCert = CertificateUtils.getX509Certificates( + TestUtils.class.getResourceAsStream("/certs/" + CA_PEM_FILE)); + serverKey0 = CertificateUtils.getPrivateKey( + TestUtils.class.getResourceAsStream("/certs/" + SERVER_0_KEY_FILE)); + serverCert0 = CertificateUtils.getX509Certificates( + TestUtils.class.getResourceAsStream("/certs/" + SERVER_0_PEM_FILE)); + clientKey0 = CertificateUtils.getPrivateKey( + TestUtils.class.getResourceAsStream("/certs/" + CLIENT_0_KEY_FILE)); + clientCert0 = CertificateUtils.getX509Certificates( + TestUtils.class.getResourceAsStream("/certs/" + CLIENT_0_PEM_FILE)); + serverKeyBad = CertificateUtils.getPrivateKey( + TestUtils.class.getResourceAsStream("/certs/" + SERVER_BAD_KEY_FILE)); + serverCertBad = CertificateUtils.getX509Certificates( + TestUtils.class.getResourceAsStream("/certs/" + SERVER_BAD_PEM_FILE)); + } + + @After + public void tearDown() { + if (server != null) { + server.shutdown(); + } + if (channel != null) { + channel.shutdown(); + } + MoreExecutors.shutdownAndAwaitTermination(executor, 5, TimeUnit.SECONDS); + } + + @Test + public void basicMutualTlsTest() throws Exception { + // Create & start a server. + ServerCredentials serverCredentials = TlsServerCredentials.newBuilder() + .keyManager(serverCert0File, serverKey0File).trustManager(caCertFile) + .clientAuth(ClientAuth.REQUIRE).build(); + server = Grpc.newServerBuilderForPort(0, serverCredentials).addService( + new SimpleServiceImpl()).build().start(); + // Create a client to connect. + ChannelCredentials channelCredentials = TlsChannelCredentials.newBuilder() + .keyManager(clientCert0File, clientKey0File).trustManager(caCertFile).build(); + channel = Grpc.newChannelBuilderForAddress("localhost", server.getPort(), channelCredentials) + .overrideAuthority("foo.test.google.com.au").build(); + // Start the connection. + try { + SimpleServiceGrpc.SimpleServiceBlockingStub client = + SimpleServiceGrpc.newBlockingStub(channel); + // Send an actual request, via the full GRPC & network stack, and check that a proper + // response comes back. + client.unaryRpc(SimpleRequest.getDefaultInstance()); + } catch (StatusRuntimeException e) { + e.printStackTrace(); + fail("Failed to make a connection"); + e.printStackTrace(); + } + } + + @Test + public void advancedTlsKeyManagerTrustManagerMutualTlsTest() throws Exception { + // Create a server with the key manager and trust manager. + AdvancedTlsX509KeyManager serverKeyManager = new AdvancedTlsX509KeyManager(); + serverKeyManager.updateIdentityCredentials(serverKey0, serverCert0); + AdvancedTlsX509TrustManager serverTrustManager = AdvancedTlsX509TrustManager.newBuilder() + .setVerification(Verification.CERTIFICATE_ONLY_VERIFICATION) + .build(); + serverTrustManager.updateTrustCredentials(caCert); + ServerCredentials serverCredentials = TlsServerCredentials.newBuilder() + .keyManager(serverKeyManager).trustManager(serverTrustManager) + .clientAuth(ClientAuth.REQUIRE).build(); + server = Grpc.newServerBuilderForPort(0, serverCredentials).addService( + new SimpleServiceImpl()).build().start(); + TimeUnit.SECONDS.sleep(5); + // Create a client with the key manager and trust manager. + AdvancedTlsX509KeyManager clientKeyManager = new AdvancedTlsX509KeyManager(); + clientKeyManager.updateIdentityCredentials(clientKey0, clientCert0); + AdvancedTlsX509TrustManager clientTrustManager = AdvancedTlsX509TrustManager.newBuilder() + .setVerification(Verification.CERTIFICATE_AND_HOST_NAME_VERIFICATION) + .build(); + clientTrustManager.updateTrustCredentials(caCert); + ChannelCredentials channelCredentials = TlsChannelCredentials.newBuilder() + .keyManager(clientKeyManager).trustManager(clientTrustManager).build(); + channel = Grpc.newChannelBuilderForAddress("localhost", server.getPort(), channelCredentials) + .overrideAuthority("foo.test.google.com.au").build(); + // Start the connection. + try { + SimpleServiceGrpc.SimpleServiceBlockingStub client = + SimpleServiceGrpc.newBlockingStub(channel); + client.unaryRpc(SimpleRequest.getDefaultInstance()); + } catch (StatusRuntimeException e) { + fail("Failed to make a connection"); + e.printStackTrace(); + } + } + + @Test + public void trustManagerCustomVerifierMutualTlsTest() throws Exception { + AdvancedTlsX509KeyManager serverKeyManager = new AdvancedTlsX509KeyManager(); + serverKeyManager.updateIdentityCredentials(serverKey0, serverCert0); + // Set server's custom verification based on the information of clientCert0. + AdvancedTlsX509TrustManager serverTrustManager = AdvancedTlsX509TrustManager.newBuilder() + .setVerification(Verification.CERTIFICATE_ONLY_VERIFICATION) + .setSslSocketAndEnginePeerVerifier( + new SslSocketAndEnginePeerVerifier() { + @Override + public void verifyPeerCertificate(X509Certificate[] peerCertChain, String authType, + Socket socket) throws CertificateException { + if (peerCertChain == null || peerCertChain.length == 0) { + throw new CertificateException("peerCertChain is empty"); + } + X509Certificate leafCert = peerCertChain[0]; + if (!leafCert.getSubjectDN().getName().contains("testclient")) { + throw new CertificateException("SslSocketAndEnginePeerVerifier failed"); + } + } + + @Override + public void verifyPeerCertificate(X509Certificate[] peerCertChain, String authType, + SSLEngine engine) throws CertificateException { + if (peerCertChain == null || peerCertChain.length == 0) { + throw new CertificateException("peerCertChain is empty"); + } + X509Certificate leafCert = peerCertChain[0]; + if (!leafCert.getSubjectDN().getName().contains("testclient")) { + throw new CertificateException("SslSocketAndEnginePeerVerifier failed"); + } + } + }) + .build(); + serverTrustManager.updateTrustCredentials(caCert); + ServerCredentials serverCredentials = TlsServerCredentials.newBuilder() + .keyManager(serverKeyManager).trustManager(serverTrustManager) + .clientAuth(ClientAuth.REQUIRE).build(); + server = Grpc.newServerBuilderForPort(0, serverCredentials).addService( + new SimpleServiceImpl()).build().start(); + TimeUnit.SECONDS.sleep(5); + + AdvancedTlsX509KeyManager clientKeyManager = new AdvancedTlsX509KeyManager(); + clientKeyManager.updateIdentityCredentials(clientKey0, clientCert0); + // Set client's custom verification based on the information of serverCert0. + AdvancedTlsX509TrustManager clientTrustManager = AdvancedTlsX509TrustManager.newBuilder() + .setVerification(Verification.CERTIFICATE_ONLY_VERIFICATION) + .setSslSocketAndEnginePeerVerifier( + new SslSocketAndEnginePeerVerifier() { + @Override + public void verifyPeerCertificate(X509Certificate[] peerCertChain, String authType, + Socket socket) throws CertificateException { + if (peerCertChain == null || peerCertChain.length == 0) { + throw new CertificateException("peerCertChain is empty"); + } + X509Certificate leafCert = peerCertChain[0]; + if (!leafCert.getSubjectDN().getName().contains("*.test.google.com.au")) { + throw new CertificateException("SslSocketAndEnginePeerVerifier failed"); + } + } + + @Override + public void verifyPeerCertificate(X509Certificate[] peerCertChain, String authType, + SSLEngine engine) throws CertificateException { + if (peerCertChain == null || peerCertChain.length == 0) { + throw new CertificateException("peerCertChain is empty"); + } + X509Certificate leafCert = peerCertChain[0]; + if (!leafCert.getSubjectDN().getName().contains("*.test.google.com.au")) { + throw new CertificateException("SslSocketAndEnginePeerVerifier failed"); + } + } + }) + .build(); + clientTrustManager.updateTrustCredentials(caCert); + ChannelCredentials channelCredentials = TlsChannelCredentials.newBuilder() + .keyManager(clientKeyManager).trustManager(clientTrustManager).build(); + channel = Grpc.newChannelBuilderForAddress( + "localhost", server.getPort(), channelCredentials).build(); + // Start the connection. + try { + SimpleServiceGrpc.SimpleServiceBlockingStub client = + SimpleServiceGrpc.newBlockingStub(channel); + client.unaryRpc(SimpleRequest.getDefaultInstance()); + } catch (StatusRuntimeException e) { + fail("Failed to make a connection"); + e.printStackTrace(); + } + } + + @Test + public void trustManagerInsecurelySkipAllTest() throws Exception { + AdvancedTlsX509KeyManager serverKeyManager = new AdvancedTlsX509KeyManager(); + // Even if we provide bad credentials for the server, the test should still pass, because we + // will configure the client to skip all checks later. + serverKeyManager.updateIdentityCredentials(serverKeyBad, serverCertBad); + AdvancedTlsX509TrustManager serverTrustManager = AdvancedTlsX509TrustManager.newBuilder() + .setVerification(Verification.CERTIFICATE_ONLY_VERIFICATION) + .setSslSocketAndEnginePeerVerifier( + new SslSocketAndEnginePeerVerifier() { + @Override + public void verifyPeerCertificate(X509Certificate[] peerCertChain, String authType, + Socket socket) throws CertificateException { } + + @Override + public void verifyPeerCertificate(X509Certificate[] peerCertChain, String authType, + SSLEngine engine) throws CertificateException { } + }) + .build(); + serverTrustManager.updateTrustCredentials(caCert); + ServerCredentials serverCredentials = TlsServerCredentials.newBuilder() + .keyManager(serverKeyManager).trustManager(serverTrustManager) + .clientAuth(ClientAuth.REQUIRE).build(); + server = Grpc.newServerBuilderForPort(0, serverCredentials).addService( + new SimpleServiceImpl()).build().start(); + TimeUnit.SECONDS.sleep(5); + + AdvancedTlsX509KeyManager clientKeyManager = new AdvancedTlsX509KeyManager(); + clientKeyManager.updateIdentityCredentials(clientKey0, clientCert0); + // Set the client to skip all checks, including traditional certificate verification. + // Note this is very dangerous in production environment - only do so if you are confident on + // what you are doing! + AdvancedTlsX509TrustManager clientTrustManager = AdvancedTlsX509TrustManager.newBuilder() + .setVerification(Verification.INSECURELY_SKIP_ALL_VERIFICATION) + .setSslSocketAndEnginePeerVerifier( + new SslSocketAndEnginePeerVerifier() { + @Override + public void verifyPeerCertificate(X509Certificate[] peerCertChain, String authType, + Socket socket) throws CertificateException { } + + @Override + public void verifyPeerCertificate(X509Certificate[] peerCertChain, String authType, + SSLEngine engine) throws CertificateException { } + }) + .build(); + clientTrustManager.updateTrustCredentials(caCert); + ChannelCredentials channelCredentials = TlsChannelCredentials.newBuilder() + .keyManager(clientKeyManager).trustManager(clientTrustManager).build(); + channel = Grpc.newChannelBuilderForAddress( + "localhost", server.getPort(), channelCredentials).build(); + // Start the connection. + try { + SimpleServiceGrpc.SimpleServiceBlockingStub client = + SimpleServiceGrpc.newBlockingStub(channel); + client.unaryRpc(SimpleRequest.getDefaultInstance()); + } catch (StatusRuntimeException e) { + fail("Failed to make a connection"); + e.printStackTrace(); + } + } + + @Test + public void onFileReloadingKeyManagerTrustManagerTest() throws Exception { + // Create & start a server. + AdvancedTlsX509KeyManager serverKeyManager = new AdvancedTlsX509KeyManager(); + Closeable serverKeyShutdown = serverKeyManager.updateIdentityCredentialsFromFile(serverKey0File, + serverCert0File, 100, TimeUnit.MILLISECONDS, executor); + AdvancedTlsX509TrustManager serverTrustManager = AdvancedTlsX509TrustManager.newBuilder() + .setVerification(Verification.CERTIFICATE_ONLY_VERIFICATION) + .build(); + Closeable serverTrustShutdown = serverTrustManager.updateTrustCredentialsFromFile(caCertFile, + 100, TimeUnit.MILLISECONDS, executor); + ServerCredentials serverCredentials = TlsServerCredentials.newBuilder() + .keyManager(serverKeyManager).trustManager(serverTrustManager) + .clientAuth(ClientAuth.REQUIRE).build(); + server = Grpc.newServerBuilderForPort(0, serverCredentials).addService( + new SimpleServiceImpl()).build().start(); + TimeUnit.SECONDS.sleep(5); + // Create a client to connect. + AdvancedTlsX509KeyManager clientKeyManager = new AdvancedTlsX509KeyManager(); + Closeable clientKeyShutdown = clientKeyManager.updateIdentityCredentialsFromFile(clientKey0File, + clientCert0File,100, TimeUnit.MILLISECONDS, executor); + AdvancedTlsX509TrustManager clientTrustManager = AdvancedTlsX509TrustManager.newBuilder() + .setVerification(Verification.CERTIFICATE_AND_HOST_NAME_VERIFICATION) + .build(); + Closeable clientTrustShutdown = clientTrustManager.updateTrustCredentialsFromFile(caCertFile, + 100, TimeUnit.MILLISECONDS, executor); + ChannelCredentials channelCredentials = TlsChannelCredentials.newBuilder() + .keyManager(clientKeyManager).trustManager(clientTrustManager).build(); + channel = Grpc.newChannelBuilderForAddress("localhost", server.getPort(), channelCredentials) + .overrideAuthority("foo.test.google.com.au").build(); + // Start the connection. + try { + SimpleServiceGrpc.SimpleServiceBlockingStub client = + SimpleServiceGrpc.newBlockingStub(channel); + // Send an actual request, via the full GRPC & network stack, and check that a proper + // response comes back. + client.unaryRpc(SimpleRequest.getDefaultInstance()); + } catch (StatusRuntimeException e) { + e.printStackTrace(); + fail("Find error: " + e.getMessage()); + } + // Clean up. + serverKeyShutdown.close(); + serverTrustShutdown.close(); + clientKeyShutdown.close(); + clientTrustShutdown.close(); + } + + @Test + public void keyManagerAliasesTest() throws Exception { + AdvancedTlsX509KeyManager km = new AdvancedTlsX509KeyManager(); + assertArrayEquals( + new String[] {"default"}, km.getClientAliases("", null)); + assertEquals( + "default", km.chooseClientAlias(new String[] {"default"}, null, null)); + assertArrayEquals( + new String[] {"default"}, km.getServerAliases("", null)); + assertEquals( + "default", km.chooseServerAlias("default", null, null)); + } + + @Test + public void trustManagerCheckTrustedWithSocketTest() throws Exception { + AdvancedTlsX509TrustManager tm = AdvancedTlsX509TrustManager.newBuilder() + .setVerification(Verification.INSECURELY_SKIP_ALL_VERIFICATION).build(); + tm.updateTrustCredentials(caCert); + tm.checkClientTrusted(serverCert0, "RSA", new Socket()); + tm.useSystemDefaultTrustCerts(); + tm.checkServerTrusted(clientCert0, "RSA", new Socket()); + } + + @Test + public void trustManagerCheckClientTrustedWithoutParameterTest() throws Exception { + exceptionRule.expect(CertificateException.class); + exceptionRule.expectMessage( + "Not enough information to validate peer. SSLEngine or Socket required."); + AdvancedTlsX509TrustManager tm = AdvancedTlsX509TrustManager.newBuilder() + .setVerification(Verification.INSECURELY_SKIP_ALL_VERIFICATION).build(); + tm.checkClientTrusted(serverCert0, "RSA"); + } + + @Test + public void trustManagerCheckServerTrustedWithoutParameterTest() throws Exception { + exceptionRule.expect(CertificateException.class); + exceptionRule.expectMessage( + "Not enough information to validate peer. SSLEngine or Socket required."); + AdvancedTlsX509TrustManager tm = AdvancedTlsX509TrustManager.newBuilder() + .setVerification(Verification.INSECURELY_SKIP_ALL_VERIFICATION).build(); + tm.checkServerTrusted(serverCert0, "RSA"); + } + + @Test + public void trustManagerEmptyChainTest() throws Exception { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage( + "Want certificate verification but got null or empty certificates"); + AdvancedTlsX509TrustManager tm = AdvancedTlsX509TrustManager.newBuilder() + .setVerification(Verification.CERTIFICATE_ONLY_VERIFICATION) + .build(); + tm.updateTrustCredentials(caCert); + tm.checkClientTrusted(null, "RSA", (SSLEngine) null); + } + + @Test + public void trustManagerBadCustomVerificationTest() throws Exception { + exceptionRule.expect(CertificateException.class); + exceptionRule.expectMessage("Bad Custom Verification"); + AdvancedTlsX509TrustManager tm = AdvancedTlsX509TrustManager.newBuilder() + .setVerification(Verification.CERTIFICATE_ONLY_VERIFICATION) + .setSslSocketAndEnginePeerVerifier( + new SslSocketAndEnginePeerVerifier() { + @Override + public void verifyPeerCertificate(X509Certificate[] peerCertChain, String authType, + Socket socket) throws CertificateException { + throw new CertificateException("Bad Custom Verification"); + } + + @Override + public void verifyPeerCertificate(X509Certificate[] peerCertChain, String authType, + SSLEngine engine) throws CertificateException { + throw new CertificateException("Bad Custom Verification"); + } + }).build(); + tm.updateTrustCredentials(caCert); + tm.checkClientTrusted(serverCert0, "RSA", new Socket()); + } + + private static class SimpleServiceImpl extends SimpleServiceGrpc.SimpleServiceImplBase { + @Override + public void unaryRpc(SimpleRequest req, StreamObserver respOb) { + respOb.onNext(SimpleResponse.getDefaultInstance()); + respOb.onCompleted(); + } + } +} diff --git a/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java b/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java index 5f1d27c37e2..d0d48fe9b48 100644 --- a/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java @@ -331,9 +331,18 @@ public void inboundShouldForwardToStream() throws Exception { } @Test - public void receivedGoAwayShouldRefuseLaterStreamId() throws Exception { + public void receivedGoAwayNoErrorShouldRefuseLaterStreamId() throws Exception { ChannelFuture future = enqueue(newCreateStreamCommand(grpcHeaders, streamTransportState)); channelRead(goAwayFrame(streamId - 1)); + verify(streamListener).closed(any(Status.class), eq(REFUSED), any(Metadata.class)); + assertTrue(future.isDone()); + } + + @Test + public void receivedGoAwayErrorShouldRefuseLaterStreamId() throws Exception { + ChannelFuture future = enqueue(newCreateStreamCommand(grpcHeaders, streamTransportState)); + channelRead( + goAwayFrame(streamId - 1, (int) Http2Error.PROTOCOL_ERROR.code(), Unpooled.EMPTY_BUFFER)); // This _should_ be REFUSED, but we purposefully use PROCESSED. See comment for // abruptGoAwayStatusConservative in NettyClientHandler verify(streamListener).closed(any(Status.class), eq(PROCESSED), any(Metadata.class)); diff --git a/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java b/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java index e4165e89243..018ca9b6594 100644 --- a/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java @@ -41,6 +41,7 @@ import io.grpc.Attributes; import io.grpc.CallOptions; import io.grpc.ChannelLogger; +import io.grpc.ClientStreamTracer; import io.grpc.Grpc; import io.grpc.InternalChannelz; import io.grpc.Metadata; @@ -828,7 +829,9 @@ private static class Rpc { } Rpc(NettyClientTransport transport, Metadata headers) { - stream = transport.newStream(METHOD, headers, CallOptions.DEFAULT); + stream = transport.newStream( + METHOD, headers, CallOptions.DEFAULT, + new ClientStreamTracer[]{ new ClientStreamTracer() {} }); stream.start(listener); stream.request(1); stream.writeMessage(new ByteArrayInputStream(MESSAGE.getBytes(UTF_8))); diff --git a/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java b/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java index 961f983d9cd..8c44088afa7 100644 --- a/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java @@ -350,6 +350,56 @@ public void closeShouldGracefullyCloseChannel() throws Exception { assertFalse(channel().isOpen()); } + @Test + public void gracefulCloseShouldGracefullyCloseChannel() throws Exception { + manualSetUp(); + handler() + .write(ctx(), new GracefulServerCloseCommand("test", 1, TimeUnit.MINUTES), newPromise()); + + verifyWrite().writeGoAway(eq(ctx()), eq(Integer.MAX_VALUE), eq(Http2Error.NO_ERROR.code()), + isA(ByteBuf.class), any(ChannelPromise.class)); + verifyWrite().writePing( + eq(ctx()), + eq(false), + eq(NettyServerHandler.GRACEFUL_SHUTDOWN_PING), + isA(ChannelPromise.class)); + channelRead(pingFrame(/*ack=*/ true , NettyServerHandler.GRACEFUL_SHUTDOWN_PING)); + + verifyWrite().writeGoAway(eq(ctx()), eq(0), eq(Http2Error.NO_ERROR.code()), + isA(ByteBuf.class), any(ChannelPromise.class)); + + // Verify that the channel was closed. + assertFalse(channel().isOpen()); + } + + @Test + public void secondGracefulCloseIsSafe() throws Exception { + manualSetUp(); + handler().write(ctx(), new GracefulServerCloseCommand("test"), newPromise()); + + verifyWrite().writeGoAway(eq(ctx()), eq(Integer.MAX_VALUE), eq(Http2Error.NO_ERROR.code()), + isA(ByteBuf.class), any(ChannelPromise.class)); + verifyWrite().writePing( + eq(ctx()), + eq(false), + eq(NettyServerHandler.GRACEFUL_SHUTDOWN_PING), + isA(ChannelPromise.class)); + + handler().write(ctx(), new GracefulServerCloseCommand("test2"), newPromise()); + + channel().runPendingTasks(); + // No additional GOAWAYs. + verifyWrite().writeGoAway(any(ChannelHandlerContext.class), any(Integer.class), any(Long.class), + any(ByteBuf.class), any(ChannelPromise.class)); + channel().checkException(); + assertTrue(channel().isOpen()); + + channelRead(pingFrame(/*ack=*/ true , NettyServerHandler.GRACEFUL_SHUTDOWN_PING)); + verifyWrite().writeGoAway(eq(ctx()), eq(0), eq(Http2Error.NO_ERROR.code()), + isA(ByteBuf.class), any(ChannelPromise.class)); + assertFalse(channel().isOpen()); + } + @Test public void exceptionCaughtShouldCloseConnection() throws Exception { manualSetUp(); diff --git a/netty/src/test/java/io/grpc/netty/NettyServerTest.java b/netty/src/test/java/io/grpc/netty/NettyServerTest.java index 3f277ed4356..12dc5b9fa51 100644 --- a/netty/src/test/java/io/grpc/netty/NettyServerTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyServerTest.java @@ -486,9 +486,8 @@ public void run() {} assertEquals(ns.getListenSocketAddress(), socketStats.local); assertNull(socketStats.remote); - // TODO(zpencer): uncomment when sock options are exposed // by default, there are some socket options set on the listen socket - // assertThat(socketStats.socketOptions.additional).isNotEmpty(); + assertThat(socketStats.socketOptions.others).isNotEmpty(); // Cleanup ns.shutdown(); diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java index a001ddb73e7..121093716db 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java @@ -34,6 +34,7 @@ import com.squareup.okhttp.internal.http.StatusLine; import io.grpc.Attributes; import io.grpc.CallOptions; +import io.grpc.ClientStreamTracer; import io.grpc.Grpc; import io.grpc.HttpConnectProxiedSocketAddress; import io.grpc.InternalChannelz; @@ -387,12 +388,13 @@ public void ping(final PingCallback callback, Executor executor) { } @Override - public OkHttpClientStream newStream(final MethodDescriptor method, - final Metadata headers, CallOptions callOptions) { + public OkHttpClientStream newStream( + MethodDescriptor method, Metadata headers, CallOptions callOptions, + ClientStreamTracer[] tracers) { Preconditions.checkNotNull(method, "method"); Preconditions.checkNotNull(headers, "headers"); - StatsTraceContext statsTraceCtx = - StatsTraceContext.newClientContext(callOptions, attributes, headers); + StatsTraceContext statsTraceContext = + StatsTraceContext.newClientContext(tracers, getAttributes(), headers); // FIXME: it is likely wrong to pass the transportTracer here as it'll exit the lock's scope synchronized (lock) { // to make @GuardedBy linter happy return new OkHttpClientStream( @@ -406,7 +408,7 @@ public OkHttpClientStream newStream(final MethodDescriptor method, initialWindowSize, defaultAuthority, userAgent, - statsTraceCtx, + statsTraceContext, transportTracer, callOptions, useGetForSafeMethods); diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java index b03a2dedc00..b70b832a797 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java @@ -56,6 +56,7 @@ import com.google.common.util.concurrent.SettableFuture; import io.grpc.Attributes; import io.grpc.CallOptions; +import io.grpc.ClientStreamTracer; import io.grpc.HttpConnectProxiedSocketAddress; import io.grpc.InternalChannelz.SocketStats; import io.grpc.InternalChannelz.TransportStats; @@ -146,6 +147,9 @@ public class OkHttpClientTransportTest { private static final int DEFAULT_MAX_INBOUND_METADATA_SIZE = Integer.MAX_VALUE; private static final Attributes EAG_ATTRS = Attributes.EMPTY; private static final Logger logger = Logger.getLogger(OkHttpClientTransport.class.getName()); + private static final ClientStreamTracer[] tracers = new ClientStreamTracer[] { + new ClientStreamTracer() {} + }; @Rule public final Timeout globalTimeout = Timeout.seconds(10); @@ -299,7 +303,7 @@ public void close() throws SecurityException { MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); @@ -387,7 +391,7 @@ public void maxMessageSizeShouldBeEnforced() throws Exception { MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); assertContainStream(3); @@ -443,11 +447,11 @@ public void nextFrameThrowIoException() throws Exception { MockStreamListener listener1 = new MockStreamListener(); MockStreamListener listener2 = new MockStreamListener(); OkHttpClientStream stream1 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); stream1.request(1); OkHttpClientStream stream2 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); stream2.request(1); assertEquals(2, activeStreamCount()); @@ -477,7 +481,7 @@ public void nextFrameThrowsError() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); assertEquals(1, activeStreamCount()); @@ -498,7 +502,7 @@ public void nextFrameReturnFalse() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); frameReader.nextFrameAtEndOfStream(); @@ -516,7 +520,7 @@ public void readMessages() throws Exception { final String message = "Hello Client"; MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(numMessages); assertContainStream(3); @@ -566,7 +570,7 @@ public void invalidInboundHeadersCancelStream() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); assertContainStream(3); @@ -590,7 +594,7 @@ public void invalidInboundTrailersPropagateToMetadata() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); assertContainStream(3); @@ -610,7 +614,7 @@ public void readStatus() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); assertContainStream(3); frameHandler().headers(true, true, 3, 0, grpcResponseTrailers(), HeadersMode.HTTP_20_HEADERS); @@ -624,7 +628,7 @@ public void receiveReset() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); assertContainStream(3); frameHandler().rstStream(3, ErrorCode.PROTOCOL_ERROR); @@ -641,7 +645,7 @@ public void receiveResetNoError() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); assertContainStream(3); frameHandler().headers(false, false, 3, 0, grpcResponseHeaders(), HeadersMode.HTTP_20_HEADERS); @@ -661,7 +665,7 @@ public void cancelStream() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); getStream(3).cancel(Status.CANCELLED); verify(frameWriter, timeout(TIME_OUT_MS)).rstStream(eq(3), eq(ErrorCode.CANCEL)); @@ -676,7 +680,7 @@ public void addDefaultUserAgent() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); Header userAgentHeader = new Header(GrpcUtil.USER_AGENT_KEY.name(), GrpcUtil.getGrpcUserAgent("okhttp", null)); @@ -695,7 +699,7 @@ public void overrideDefaultUserAgent() throws Exception { startTransport(3, null, true, DEFAULT_MAX_MESSAGE_SIZE, INITIAL_WINDOW_SIZE, "fakeUserAgent"); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); List
expectedHeaders = Arrays.asList(HTTP_SCHEME_HEADER, METHOD_HEADER, new Header(Header.TARGET_AUTHORITY, "notarealauthority:80"), @@ -714,7 +718,7 @@ public void cancelStreamForDeadlineExceeded() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); getStream(3).cancel(Status.DEADLINE_EXCEEDED); verify(frameWriter, timeout(TIME_OUT_MS)).rstStream(eq(3), eq(ErrorCode.CANCEL)); @@ -728,7 +732,7 @@ public void writeMessage() throws Exception { final String message = "Hello Server"; MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); InputStream input = new ByteArrayInputStream(message.getBytes(UTF_8)); assertEquals(12, input.available()); @@ -772,12 +776,12 @@ public void windowUpdate() throws Exception { MockStreamListener listener1 = new MockStreamListener(); MockStreamListener listener2 = new MockStreamListener(); OkHttpClientStream stream1 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); stream1.request(2); OkHttpClientStream stream2 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); stream2.request(2); assertEquals(2, activeStreamCount()); @@ -838,7 +842,7 @@ public void windowUpdateWithInboundFlowControl() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); int messageLength = INITIAL_WINDOW_SIZE / 2 + 1; byte[] fakeMessage = new byte[messageLength]; @@ -874,7 +878,7 @@ public void outboundFlowControl() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); // Outbound window always starts at 65535 until changed by Settings.INITIAL_WINDOW_SIZE @@ -920,7 +924,7 @@ public void outboundFlowControl_smallWindowSize() throws Exception { MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); int messageLength = 75; @@ -963,7 +967,7 @@ public void outboundFlowControl_bigWindowSize() throws Exception { MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); int messageLength = 100000; @@ -999,7 +1003,7 @@ public void outboundFlowControlWithInitialWindowSizeChange() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); int messageLength = 20; setInitialWindowSize(HEADER_LENGTH + 10); @@ -1045,7 +1049,7 @@ public void outboundFlowControlWithInitialWindowSizeChangeInMiddleOfStream() thr initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); int messageLength = 20; setInitialWindowSize(HEADER_LENGTH + 10); @@ -1080,10 +1084,10 @@ public void stopNormally() throws Exception { MockStreamListener listener1 = new MockStreamListener(); MockStreamListener listener2 = new MockStreamListener(); OkHttpClientStream stream1 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); OkHttpClientStream stream2 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); assertEquals(2, activeStreamCount()); clientTransport.shutdown(SHUTDOWN_REASON); @@ -1110,11 +1114,11 @@ public void receiveGoAway() throws Exception { MockStreamListener listener1 = new MockStreamListener(); MockStreamListener listener2 = new MockStreamListener(); OkHttpClientStream stream1 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); stream1.request(1); OkHttpClientStream stream2 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); stream2.request(1); assertEquals(2, activeStreamCount()); @@ -1168,7 +1172,7 @@ public void streamIdExhausted() throws Exception { MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); @@ -1204,11 +1208,11 @@ public void pendingStreamSucceed() throws Exception { final MockStreamListener listener1 = new MockStreamListener(); final MockStreamListener listener2 = new MockStreamListener(); OkHttpClientStream stream1 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); // The second stream should be pending. OkHttpClientStream stream2 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); String sentMessage = "hello"; InputStream input = new ByteArrayInputStream(sentMessage.getBytes(UTF_8)); @@ -1241,7 +1245,7 @@ public void pendingStreamCancelled() throws Exception { setMaxConcurrentStreams(0); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); waitForStreamPending(1); stream.cancel(Status.CANCELLED); @@ -1260,11 +1264,11 @@ public void pendingStreamFailedByGoAway() throws Exception { final MockStreamListener listener1 = new MockStreamListener(); final MockStreamListener listener2 = new MockStreamListener(); OkHttpClientStream stream1 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); // The second stream should be pending. OkHttpClientStream stream2 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); waitForStreamPending(1); @@ -1290,7 +1294,7 @@ public void pendingStreamSucceedAfterShutdown() throws Exception { final MockStreamListener listener = new MockStreamListener(); // The second stream should be pending. OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); waitForStreamPending(1); @@ -1314,15 +1318,15 @@ public void pendingStreamFailedByIdExhausted() throws Exception { final MockStreamListener listener3 = new MockStreamListener(); OkHttpClientStream stream1 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); // The second and third stream should be pending. OkHttpClientStream stream2 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); OkHttpClientStream stream3 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream3.start(listener3); waitForStreamPending(2); @@ -1346,7 +1350,7 @@ public void receivingWindowExceeded() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); @@ -1398,7 +1402,7 @@ private void shouldHeadersBeFlushed(boolean shouldBeFlushed) throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); verify(frameWriter, timeout(TIME_OUT_MS)).synStream( eq(false), eq(false), eq(3), eq(0), ArgumentMatchers.
anyList()); @@ -1415,7 +1419,7 @@ public void receiveDataWithoutHeader() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); Buffer buffer = createMessageFrame(new byte[1]); @@ -1437,7 +1441,7 @@ public void receiveDataWithoutHeaderAndTrailer() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); Buffer buffer = createMessageFrame(new byte[1]); @@ -1459,7 +1463,7 @@ public void receiveLongEnoughDataWithoutHeaderAndTrailer() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); Buffer buffer = createMessageFrame(new byte[1000]); @@ -1480,7 +1484,7 @@ public void receiveDataForUnknownStreamUpdateConnectionWindow() throws Exception initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.cancel(Status.CANCELLED); @@ -1507,7 +1511,7 @@ public void receiveWindowUpdateForUnknownStream() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.cancel(Status.CANCELLED); // This should be ignored. @@ -1527,7 +1531,7 @@ public void shouldBeInitiallyReady() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); assertTrue(stream.isReady()); assertTrue(listener.isOnReadyCalled()); @@ -1545,7 +1549,7 @@ public void notifyOnReady() throws Exception { setInitialWindowSize(0); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); assertTrue(stream.isReady()); // Be notified at the beginning. @@ -1695,7 +1699,7 @@ public void writeBeforeConnected() throws Exception { final String message = "Hello Server"; MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); InputStream input = new ByteArrayInputStream(message.getBytes(UTF_8)); stream.writeMessage(input); @@ -1720,7 +1724,7 @@ public void cancelBeforeConnected() throws Exception { final String message = "Hello Server"; MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); InputStream input = new ByteArrayInputStream(message.getBytes(UTF_8)); stream.writeMessage(input); @@ -1738,7 +1742,7 @@ public void shutdownDuringConnecting() throws Exception { initTransportAndDelayConnected(); MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); clientTransport.shutdown(SHUTDOWN_REASON); allowTransportConnected(); @@ -1810,7 +1814,8 @@ public void unreachableServer() throws Exception { assertTrue(status.getCause().toString(), status.getCause() instanceof IOException); MockStreamListener streamListener = new MockStreamListener(); - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT).start(streamListener); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers) + .start(streamListener); streamListener.waitUntilStreamClosed(); assertEquals(Status.UNAVAILABLE.getCode(), streamListener.status.getCode()); } @@ -2054,13 +2059,13 @@ public void goAway_streamListenerRpcProgress() throws Exception { MockStreamListener listener2 = new MockStreamListener(); MockStreamListener listener3 = new MockStreamListener(); OkHttpClientStream stream1 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); OkHttpClientStream stream2 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); OkHttpClientStream stream3 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream3.start(listener3); waitForStreamPending(1); @@ -2094,13 +2099,13 @@ public void reset_streamListenerRpcProgress() throws Exception { MockStreamListener listener2 = new MockStreamListener(); MockStreamListener listener3 = new MockStreamListener(); OkHttpClientStream stream1 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); OkHttpClientStream stream2 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); OkHttpClientStream stream3 = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream3.start(listener3); assertEquals(3, activeStreamCount()); @@ -2158,7 +2163,7 @@ private void waitForStreamPending(int expected) throws Exception { private void assertNewStreamFail() throws Exception { MockStreamListener listener = new MockStreamListener(); OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT); + clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); listener.waitUntilStreamClosed(); assertFalse(listener.status.isOk()); diff --git a/repositories.bzl b/repositories.bzl index ad50272d286..0d6e9ab2f74 100644 --- a/repositories.bzl +++ b/repositories.bzl @@ -16,7 +16,9 @@ IO_GRPC_GRPC_JAVA_ARTIFACTS = [ "com.google.auth:google-auth-library-oauth2-http:0.22.0", "com.google.code.findbugs:jsr305:3.0.2", "com.google.code.gson:gson:jar:2.8.6", - "com.google.errorprone:error_prone_annotations:2.4.0", + "com.google.auto.value:auto-value:1.7.4", + "com.google.auto.value:auto-value-annotations:1.7.4", + "com.google.errorprone:error_prone_annotations:2.9.0", "com.google.guava:failureaccess:1.0.1", "com.google.guava:guava:30.1-android", "com.google.j2objc:j2objc-annotations:1.3", diff --git a/rls/BUILD.bazel b/rls/BUILD.bazel index 0da03fc924b..4daa7029560 100644 --- a/rls/BUILD.bazel +++ b/rls/BUILD.bazel @@ -7,12 +7,14 @@ java_library( ]), visibility = ["//visibility:public"], deps = [ + ":autovalue", ":rls_java_grpc", "//api", "//core", "//core:internal", "//core:util", "//stub", + "@com_google_auto_value_auto_value_annotations//jar", "@com_google_code_findbugs_jsr305//jar", "@com_google_guava_guava//jar", "@io_grpc_grpc_proto//:rls_java_proto", @@ -20,6 +22,25 @@ java_library( ], ) +java_plugin( + name = "autovalue_plugin", + processor_class = "com.google.auto.value.processor.AutoValueProcessor", + deps = [ + "@com_google_auto_value_auto_value//jar", + ], +) + +java_library( + name = "autovalue", + exported_plugins = [ + ":autovalue_plugin", + ], + neverlink = 1, + exports = [ + "@com_google_auto_value_auto_value//jar", + ], +) + java_grpc_library( name = "rls_java_grpc", srcs = ["@io_grpc_grpc_proto//:rls_proto"], diff --git a/rls/build.gradle b/rls/build.gradle index a2ebf2a62ef..45f17fb71c3 100644 --- a/rls/build.gradle +++ b/rls/build.gradle @@ -14,7 +14,9 @@ dependencies { implementation project(':grpc-core'), project(':grpc-protobuf'), project(':grpc-stub'), + libraries.autovalue_annotation, libraries.guava + annotationProcessor libraries.autovalue compileOnly libraries.javax_annotation testImplementation libraries.truth, project(':grpc-grpclb'), @@ -24,6 +26,17 @@ dependencies { signature "org.codehaus.mojo.signature:java17:1.0@signature" } +[compileJava].each() { + it.options.compilerArgs += [ + // only has AutoValue annotation processor + "-Xlint:-processing", + ] + appendToProperty( + it.options.errorprone.excludedPaths, + ".*/build/generated/sources/annotationProcessor/java/.*", + "|") +} + javadoc { // Do not publish javadoc since currently there is no public API. failOnError false // no public or protected classes found to document diff --git a/rls/src/main/java/io/grpc/rls/RlsLoadBalancer.java b/rls/src/main/java/io/grpc/rls/RlsLoadBalancer.java index 12903044a21..289098e2554 100644 --- a/rls/src/main/java/io/grpc/rls/RlsLoadBalancer.java +++ b/rls/src/main/java/io/grpc/rls/RlsLoadBalancer.java @@ -82,7 +82,9 @@ public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { @Override public void requestConnection() { - routeLookupClient.requestConnection(); + if (routeLookupClient != null) { + routeLookupClient.requestConnection(); + } } @Override diff --git a/rls/src/main/java/io/grpc/rls/RlsProtoConverters.java b/rls/src/main/java/io/grpc/rls/RlsProtoConverters.java index 32df13c4262..ce89def4467 100644 --- a/rls/src/main/java/io/grpc/rls/RlsProtoConverters.java +++ b/rls/src/main/java/io/grpc/rls/RlsProtoConverters.java @@ -20,9 +20,11 @@ import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.base.Converter; +import com.google.common.collect.ImmutableMap; import io.grpc.internal.JsonUtil; import io.grpc.lookup.v1.RouteLookupRequest; import io.grpc.lookup.v1.RouteLookupResponse; +import io.grpc.rls.RlsProtoData.ExtraKeys; import io.grpc.rls.RlsProtoData.GrpcKeyBuilder; import io.grpc.rls.RlsProtoData.GrpcKeyBuilder.Name; import io.grpc.rls.RlsProtoData.NameMatcher; @@ -46,25 +48,16 @@ final class RlsProtoConverters { static final class RouteLookupRequestConverter extends Converter { - @SuppressWarnings("deprecation") @Override protected RlsProtoData.RouteLookupRequest doForward(RouteLookupRequest routeLookupRequest) { - return - new RlsProtoData.RouteLookupRequest( - /* server= */ routeLookupRequest.getServer(), - /* path= */ routeLookupRequest.getPath(), - /* targetType= */ routeLookupRequest.getTargetType(), - routeLookupRequest.getKeyMapMap()); + return new RlsProtoData.RouteLookupRequest(routeLookupRequest.getKeyMapMap()); } - @SuppressWarnings("deprecation") @Override protected RouteLookupRequest doBackward(RlsProtoData.RouteLookupRequest routeLookupRequest) { return RouteLookupRequest.newBuilder() - .setServer(routeLookupRequest.getServer()) - .setPath(routeLookupRequest.getPath()) - .setTargetType(routeLookupRequest.getTargetType()) + .setTargetType("grpc") .putAllKeyMap(routeLookupRequest.getKeyMap()) .build(); } @@ -183,7 +176,19 @@ static GrpcKeyBuilder convert(Map keyBuilder) { matcher.isOptional(), "NameMatcher for GrpcKeyBuilders shouldn't be required"); nameMatchers.add(matcher); } - return new GrpcKeyBuilder(names, nameMatchers); + ExtraKeys extraKeys = ExtraKeys.DEFAULT; + Map rawExtraKeys = + (Map) JsonUtil.getObject(keyBuilder, "extraKeys"); + if (rawExtraKeys != null) { + extraKeys = ExtraKeys.create( + rawExtraKeys.get("host"), rawExtraKeys.get("service"), rawExtraKeys.get("method")); + } + Map constantKeys = + (Map) JsonUtil.getObject(keyBuilder, "constantKeys"); + if (constantKeys == null) { + constantKeys = ImmutableMap.of(); + } + return new GrpcKeyBuilder(names, nameMatchers, extraKeys, constantKeys); } } diff --git a/rls/src/main/java/io/grpc/rls/RlsProtoData.java b/rls/src/main/java/io/grpc/rls/RlsProtoData.java index fbcb6feb21c..3556ca609be 100644 --- a/rls/src/main/java/io/grpc/rls/RlsProtoData.java +++ b/rls/src/main/java/io/grpc/rls/RlsProtoData.java @@ -20,6 +20,7 @@ import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; +import com.google.auto.value.AutoValue; import com.google.common.base.MoreObjects; import com.google.common.base.Objects; import com.google.common.collect.ImmutableList; @@ -40,46 +41,12 @@ final class RlsProtoData { @Immutable static final class RouteLookupRequest { - private final String server; - - private final String path; - - private final String targetType; - private final ImmutableMap keyMap; - RouteLookupRequest( - String server, String path, String targetType, Map keyMap) { - this.server = checkNotNull(server, "server"); - this.path = checkNotNull(path, "path"); - this.targetType = checkNotNull(targetType, "targetName"); + RouteLookupRequest(Map keyMap) { this.keyMap = ImmutableMap.copyOf(checkNotNull(keyMap, "keyMap")); } - /** - * Returns a full host name of the target server, {@literal e.g.} firestore.googleapis.com. Only - * set for gRPC requests; HTTP requests must use key_map explicitly. - */ - String getServer() { - return server; - } - - /** - * Returns a full path of the request, {@literal i.e.} "/service/method". Only set for gRPC - * requests; HTTP requests must use key_map explicitly. - */ - String getPath() { - return path; - } - - /** - * Returns the target type allows the client to specify what kind of target format it would like - * from RLS to allow it to find the regional server, {@literal e.g.} "grpc". - */ - String getTargetType() { - return targetType; - } - /** Returns a map of key values extracted via key builders for the gRPC or HTTP request. */ ImmutableMap getKeyMap() { return keyMap; @@ -94,23 +61,17 @@ public boolean equals(Object o) { return false; } RouteLookupRequest that = (RouteLookupRequest) o; - return Objects.equal(server, that.server) - && Objects.equal(path, that.path) - && Objects.equal(targetType, that.targetType) - && Objects.equal(keyMap, that.keyMap); + return Objects.equal(keyMap, that.keyMap); } @Override public int hashCode() { - return Objects.hashCode(server, path, targetType, keyMap); + return Objects.hashCode(keyMap); } @Override public String toString() { return MoreObjects.toStringHelper(this) - .add("server", server) - .add("path", path) - .add("targetName", targetType) .add("keyMap", keyMap) .toString(); } @@ -300,6 +261,7 @@ ImmutableList getValidTargets() { * error. Note that requests can be routed only to a subdomain of the original target, * {@literal e.g.} "us_east_1.cloudbigtable.googleapis.com". */ + @Nullable String getDefaultTarget() { return defaultTarget; } @@ -431,12 +393,18 @@ static final class GrpcKeyBuilder { private final ImmutableList names; private final ImmutableList headers; + private final ExtraKeys extraKeys; + private final ImmutableMap constantKeys; - public GrpcKeyBuilder(List names, List headers) { + public GrpcKeyBuilder( + List names, List headers, ExtraKeys extraKeys, + Map constantKeys) { checkState(names != null && !names.isEmpty(), "names cannot be empty"); this.names = ImmutableList.copyOf(names); checkUniqueKey(checkNotNull(headers, "headers")); this.headers = ImmutableList.copyOf(headers); + this.extraKeys = checkNotNull(extraKeys, "extraKeys"); + this.constantKeys = ImmutableMap.copyOf(checkNotNull(constantKeys, "constantKeys")); } private static void checkUniqueKey(List headers) { @@ -464,6 +432,14 @@ ImmutableList getHeaders() { return headers; } + ExtraKeys getExtraKeys() { + return extraKeys; + } + + ImmutableMap getConstantKeys() { + return constantKeys; + } + @Override public boolean equals(Object o) { if (this == o) { @@ -473,12 +449,14 @@ public boolean equals(Object o) { return false; } GrpcKeyBuilder that = (GrpcKeyBuilder) o; - return Objects.equal(names, that.names) && Objects.equal(headers, that.headers); + return Objects.equal(names, that.names) && Objects.equal(headers, that.headers) + && Objects.equal(extraKeys, that.extraKeys) + && Objects.equal(constantKeys, that.constantKeys); } @Override public int hashCode() { - return Objects.hashCode(names, headers); + return Objects.hashCode(names, headers, extraKeys, constantKeys); } @Override @@ -486,6 +464,8 @@ public String toString() { return MoreObjects.toStringHelper(this) .add("names", names) .add("headers", headers) + .add("extraKeys", extraKeys) + .add("constantKeys", constantKeys) .toString(); } @@ -548,4 +528,20 @@ public String toString() { } } } + + @AutoValue + abstract static class ExtraKeys { + static final ExtraKeys DEFAULT = create(null, null, null); + + @Nullable abstract String host(); + + @Nullable abstract String service(); + + @Nullable abstract String method(); + + static ExtraKeys create( + @Nullable String host, @Nullable String service, @Nullable String method) { + return new AutoValue_RlsProtoData_ExtraKeys(host, service, method); + } + } } diff --git a/rls/src/main/java/io/grpc/rls/RlsRequestFactory.java b/rls/src/main/java/io/grpc/rls/RlsRequestFactory.java index b9bcb037cf5..e181d64833d 100644 --- a/rls/src/main/java/io/grpc/rls/RlsRequestFactory.java +++ b/rls/src/main/java/io/grpc/rls/RlsRequestFactory.java @@ -19,17 +19,18 @@ import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.base.MoreObjects; -import com.google.common.collect.HashBasedTable; -import com.google.common.collect.Table; +import com.google.common.collect.ImmutableMap; import io.grpc.Metadata; import io.grpc.Status; import io.grpc.StatusRuntimeException; +import io.grpc.rls.RlsProtoData.ExtraKeys; import io.grpc.rls.RlsProtoData.GrpcKeyBuilder; import io.grpc.rls.RlsProtoData.GrpcKeyBuilder.Name; import io.grpc.rls.RlsProtoData.NameMatcher; import io.grpc.rls.RlsProtoData.RouteLookupConfig; import io.grpc.rls.RlsProtoData.RouteLookupRequest; import java.util.HashMap; +import java.util.List; import java.util.Map; import javax.annotation.CheckReturnValue; @@ -40,10 +41,7 @@ final class RlsRequestFactory { private final String target; - /** - * schema: Path(/serviceName/methodName or /serviceName/*), rls request headerName, header fields. - */ - private final Table keyBuilderTable; + private final Map keyBuilderTable; RlsRequestFactory(RouteLookupConfig rlsConfig, String target) { checkNotNull(rlsConfig, "rlsConfig"); @@ -51,18 +49,15 @@ final class RlsRequestFactory { this.keyBuilderTable = createKeyBuilderTable(rlsConfig); } - private static Table createKeyBuilderTable( + private static Map createKeyBuilderTable( RouteLookupConfig config) { - Table table = HashBasedTable.create(); + Map table = new HashMap<>(); for (GrpcKeyBuilder grpcKeyBuilder : config.getGrpcKeyBuilders()) { - for (NameMatcher nameMatcher : grpcKeyBuilder.getHeaders()) { - for (Name name : grpcKeyBuilder.getNames()) { - String method = - name.getMethod() == null || name.getMethod().isEmpty() - ? "*" : name.getMethod(); - String path = "/" + name.getService() + "/" + method; - table.put(path, nameMatcher.getKey(), nameMatcher); - } + for (Name name : grpcKeyBuilder.getNames()) { + boolean hasMethod = name.getMethod() == null || name.getMethod().isEmpty(); + String method = hasMethod ? "*" : name.getMethod(); + String path = "/" + name.getService() + "/" + method; + table.put(path, grpcKeyBuilder); } } return table; @@ -74,20 +69,35 @@ RouteLookupRequest create(String service, String method, Metadata metadata) { checkNotNull(service, "service"); checkNotNull(method, "method"); String path = "/" + service + "/" + method; - Map keyBuilder = keyBuilderTable.row(path); - // if no matching keyBuilder found, fall back to wildcard match (ServiceName/*) - if (keyBuilder.isEmpty()) { - keyBuilder = keyBuilderTable.row("/" + service + "/*"); + GrpcKeyBuilder grpcKeyBuilder = keyBuilderTable.get(path); + if (grpcKeyBuilder == null) { + // if no matching keyBuilder found, fall back to wildcard match (ServiceName/*) + grpcKeyBuilder = keyBuilderTable.get("/" + service + "/*"); + } + if (grpcKeyBuilder == null) { + return new RouteLookupRequest(ImmutableMap.of()); + } + Map rlsRequestHeaders = + createRequestHeaders(metadata, grpcKeyBuilder.getHeaders()); + ExtraKeys extraKeys = grpcKeyBuilder.getExtraKeys(); + Map constantKeys = grpcKeyBuilder.getConstantKeys(); + if (extraKeys.host() != null) { + rlsRequestHeaders.put(extraKeys.host(), target); + } + if (extraKeys.service() != null) { + rlsRequestHeaders.put(extraKeys.service(), service); + } + if (extraKeys.method() != null) { + rlsRequestHeaders.put(extraKeys.method(), method); } - Map rlsRequestHeaders = createRequestHeaders(metadata, keyBuilder); - return new RouteLookupRequest(target, path, "grpc", rlsRequestHeaders); + rlsRequestHeaders.putAll(constantKeys); + return new RouteLookupRequest(rlsRequestHeaders); } private Map createRequestHeaders( - Metadata metadata, Map keyBuilder) { + Metadata metadata, List keyBuilder) { Map rlsRequestHeaders = new HashMap<>(); - for (Map.Entry entry : keyBuilder.entrySet()) { - NameMatcher nameMatcher = entry.getValue(); + for (NameMatcher nameMatcher : keyBuilder) { String value = null; for (String requestHeaderName : nameMatcher.names()) { value = metadata.get(Metadata.Key.of(requestHeaderName, Metadata.ASCII_STRING_MARSHALLER)); @@ -96,11 +106,11 @@ private Map createRequestHeaders( } } if (value != null) { - rlsRequestHeaders.put(entry.getKey(), value); + rlsRequestHeaders.put(nameMatcher.getKey(), value); } else if (!nameMatcher.isOptional()) { throw new StatusRuntimeException( Status.INVALID_ARGUMENT.withDescription( - String.format("Missing mandatory metadata(%s) not found", entry.getKey()))); + String.format("Missing mandatory metadata(%s) not found", nameMatcher.getKey()))); } } return rlsRequestHeaders; diff --git a/rls/src/test/java/io/grpc/rls/CachingRlsLbClientTest.java b/rls/src/test/java/io/grpc/rls/CachingRlsLbClientTest.java index c2c221800ab..aa64ec890b6 100644 --- a/rls/src/test/java/io/grpc/rls/CachingRlsLbClientTest.java +++ b/rls/src/test/java/io/grpc/rls/CachingRlsLbClientTest.java @@ -65,6 +65,7 @@ import io.grpc.rls.LruCache.EvictionListener; import io.grpc.rls.LruCache.EvictionType; import io.grpc.rls.RlsProtoConverters.RouteLookupResponseConverter; +import io.grpc.rls.RlsProtoData.ExtraKeys; import io.grpc.rls.RlsProtoData.GrpcKeyBuilder; import io.grpc.rls.RlsProtoData.GrpcKeyBuilder.Name; import io.grpc.rls.RlsProtoData.NameMatcher; @@ -191,9 +192,8 @@ public void run() { public void get_noError_lifeCycle() throws Exception { setUpRlsLbClient(); InOrder inOrder = inOrder(evictionListener); - RouteLookupRequest routeLookupRequest = - new RouteLookupRequest( - "bigtable.googleapis.com", "/foo/bar", "grpc", ImmutableMap.of()); + RouteLookupRequest routeLookupRequest = new RouteLookupRequest(ImmutableMap.of( + "server", "bigtable.googleapis.com", "service-key", "foo", "method-key", "bar")); rlsServerImpl.setLookupTable( ImmutableMap.of( routeLookupRequest, @@ -242,9 +242,8 @@ public void get_noError_lifeCycle() throws Exception { public void rls_overDirectPath() throws Exception { CachingRlsLbClient.enableOobChannelDirectPath = true; setUpRlsLbClient(); - RouteLookupRequest routeLookupRequest = - new RouteLookupRequest( - "bigtable.googleapis.com", "/foo/bar", "grpc", ImmutableMap.of()); + RouteLookupRequest routeLookupRequest = new RouteLookupRequest(ImmutableMap.of( + "server", "bigtable.googleapis.com", "service-key", "foo", "method-key", "bar")); rlsServerImpl.setLookupTable( ImmutableMap.of( routeLookupRequest, @@ -276,8 +275,8 @@ public void rls_overDirectPath() throws Exception { @Test public void get_throttledAndRecover() throws Exception { setUpRlsLbClient(); - RouteLookupRequest routeLookupRequest = - new RouteLookupRequest("server", "/foo/bar", "grpc", ImmutableMap.of()); + RouteLookupRequest routeLookupRequest = new RouteLookupRequest(ImmutableMap.of( + "server", "bigtable.googleapis.com", "service-key", "foo", "method-key", "bar")); rlsServerImpl.setLookupTable( ImmutableMap.of( routeLookupRequest, @@ -319,9 +318,8 @@ public void get_throttledAndRecover() throws Exception { public void get_updatesLbState() throws Exception { setUpRlsLbClient(); InOrder inOrder = inOrder(helper); - RouteLookupRequest routeLookupRequest = - new RouteLookupRequest( - "bigtable.googleapis.com", "/foo/bar", "grpc", ImmutableMap.of()); + RouteLookupRequest routeLookupRequest = new RouteLookupRequest(ImmutableMap.of( + "server", "bigtable.googleapis.com", "service-key", "service1", "method-key", "create")); rlsServerImpl.setLookupTable( ImmutableMap.of( routeLookupRequest, @@ -349,7 +347,8 @@ public void get_updatesLbState() throws Exception { Metadata headers = new Metadata(); PickResult pickResult = pickerCaptor.getValue().pickSubchannel( new PickSubchannelArgsImpl( - TestMethodDescriptors.voidMethod().toBuilder().setFullMethodName("foo/bar").build(), + TestMethodDescriptors.voidMethod().toBuilder().setFullMethodName("service1/create") + .build(), headers, CallOptions.DEFAULT)); assertThat(pickResult.getStatus().isOk()).isTrue(); @@ -360,8 +359,7 @@ public void get_updatesLbState() throws Exception { fakeBackoffProvider.nextPolicy = createBackoffPolicy(100, TimeUnit.MILLISECONDS); // try to get invalid RouteLookupRequest invalidRouteLookupRequest = - new RouteLookupRequest( - "bigtable.googleapis.com", "/doesn/exists", "grpc", ImmutableMap.of()); + new RouteLookupRequest(ImmutableMap.of()); CachedRouteLookupResponse errorResp = getInSyncContext(invalidRouteLookupRequest); assertThat(errorResp.isPending()).isTrue(); fakeTimeProvider.forwardTime(SERVER_LATENCY_MILLIS, TimeUnit.MILLISECONDS); @@ -369,7 +367,7 @@ public void get_updatesLbState() throws Exception { errorResp = getInSyncContext(invalidRouteLookupRequest); assertThat(errorResp.hasError()).isTrue(); - // Channel is still READY because the subchannel for method /foo/bar is still READY. + // Channel is still READY because the subchannel for method /service1/create is still READY. // Method /doesn/exists will use fallback child balancer and fail immediately. inOrder.verify(helper) .updateBalancingState(eq(ConnectivityState.READY), pickerCaptor.capture()); @@ -387,10 +385,10 @@ public void get_updatesLbState() throws Exception { @Test public void get_childPolicyWrapper_reusedForSameTarget() throws Exception { setUpRlsLbClient(); - RouteLookupRequest routeLookupRequest = - new RouteLookupRequest("server", "/foo/bar", "grpc", ImmutableMap.of()); - RouteLookupRequest routeLookupRequest2 = - new RouteLookupRequest("server", "/foo/baz", "grpc", ImmutableMap.of()); + RouteLookupRequest routeLookupRequest = new RouteLookupRequest(ImmutableMap.of( + "server", "bigtable.googleapis.com", "service-key", "foo", "method-key", "bar")); + RouteLookupRequest routeLookupRequest2 = new RouteLookupRequest(ImmutableMap.of( + "server", "bigtable.googleapis.com", "service-key", "foo", "method-key", "baz")); rlsServerImpl.setLookupTable( ImmutableMap.of( routeLookupRequest, new RouteLookupResponse(ImmutableList.of("target"), "header"), @@ -426,7 +424,9 @@ private static RouteLookupConfig getRouteLookupConfig() { ImmutableList.of(new Name("service1", "create")), ImmutableList.of( new NameMatcher("user", ImmutableList.of("User", "Parent"), true), - new NameMatcher("id", ImmutableList.of("X-Google-Id"), true)))), + new NameMatcher("id", ImmutableList.of("X-Google-Id"), true)), + ExtraKeys.create("server", "service-key", "method-key"), + ImmutableMap.of())), /* lookupService= */ "service1", /* lookupServiceTimeoutInMillis= */ TimeUnit.SECONDS.toMillis(2), /* maxAgeInMillis= */ TimeUnit.SECONDS.toMillis(300), diff --git a/rls/src/test/java/io/grpc/rls/RlsLoadBalancerTest.java b/rls/src/test/java/io/grpc/rls/RlsLoadBalancerTest.java index 3b7e84bd543..fba295f98f0 100644 --- a/rls/src/test/java/io/grpc/rls/RlsLoadBalancerTest.java +++ b/rls/src/test/java/io/grpc/rls/RlsLoadBalancerTest.java @@ -144,13 +144,15 @@ public void setUp() throws Exception { .build(); fakeRlsServerImpl.setLookupTable( ImmutableMap.of( - new RouteLookupRequest( - "fake-bigtable.googleapis.com", "/com.google/Search", "grpc", - ImmutableMap.of()), + new RouteLookupRequest(ImmutableMap.of( + "server", "fake-bigtable.googleapis.com", + "service-key", "com.google", + "method-key", "Search")), new RouteLookupResponse(ImmutableList.of("wilderness"), "where are you?"), - new RouteLookupRequest( - "fake-bigtable.googleapis.com", "/com.google/Rescue", "grpc", - ImmutableMap.of()), + new RouteLookupRequest(ImmutableMap.of( + "server", "fake-bigtable.googleapis.com", + "service-key", "com.google", + "method-key", "Rescue")), new RouteLookupResponse(ImmutableList.of("civilization"), "you are safe"))); rlsLb = (RlsLoadBalancer) provider.newLoadBalancer(helper); @@ -409,7 +411,12 @@ private String getRlsConfigJsonStr() { + " \"names\": [\"PermitId\"],\n" + " \"optional\": true\n" + " }\n" - + " ]\n" + + " ],\n" + + " \"extraKeys\": {\n" + + " \"host\": \"server\",\n" + + " \"service\": \"service-key\",\n" + + " \"method\": \"method-key\"\n" + + " }\n" + " }\n" + " ],\n" + " \"lookupService\": \"localhost:8972\",\n" diff --git a/rls/src/test/java/io/grpc/rls/RlsProtoConvertersTest.java b/rls/src/test/java/io/grpc/rls/RlsProtoConvertersTest.java index a50cdeb9f68..bfb331e6cc8 100644 --- a/rls/src/test/java/io/grpc/rls/RlsProtoConvertersTest.java +++ b/rls/src/test/java/io/grpc/rls/RlsProtoConvertersTest.java @@ -27,6 +27,7 @@ import io.grpc.rls.RlsProtoConverters.RouteLookupConfigConverter; import io.grpc.rls.RlsProtoConverters.RouteLookupRequestConverter; import io.grpc.rls.RlsProtoConverters.RouteLookupResponseConverter; +import io.grpc.rls.RlsProtoData.ExtraKeys; import io.grpc.rls.RlsProtoData.GrpcKeyBuilder; import io.grpc.rls.RlsProtoData.GrpcKeyBuilder.Name; import io.grpc.rls.RlsProtoData.NameMatcher; @@ -41,40 +42,29 @@ @RunWith(JUnit4.class) public class RlsProtoConvertersTest { - @SuppressWarnings("deprecation") @Test public void convert_toRequestProto() { Converter converter = new RouteLookupRequestConverter(); RouteLookupRequest proto = RouteLookupRequest.newBuilder() - .setServer("server") - .setPath("path") - .setTargetType("target") .putKeyMap("key1", "val1") .build(); RlsProtoData.RouteLookupRequest object = converter.convert(proto); - assertThat(object.getServer()).isEqualTo("server"); - assertThat(object.getPath()).isEqualTo("path"); - assertThat(object.getTargetType()).isEqualTo("target"); assertThat(object.getKeyMap()).containsExactly("key1", "val1"); } - @SuppressWarnings("deprecation") @Test public void convert_toRequestObject() { Converter converter = new RouteLookupRequestConverter().reverse(); RlsProtoData.RouteLookupRequest requestObject = - new RlsProtoData.RouteLookupRequest( - "server", "path", "target", ImmutableMap.of("key1", "val1")); + new RlsProtoData.RouteLookupRequest(ImmutableMap.of("key1", "val1")); RouteLookupRequest proto = converter.convert(requestObject); - assertThat(proto.getServer()).isEqualTo("server"); - assertThat(proto.getPath()).isEqualTo("path"); - assertThat(proto.getTargetType()).isEqualTo("target"); + assertThat(proto.getTargetType()).isEqualTo("grpc"); assertThat(proto.getKeyMapMap()).containsExactly("key1", "val1"); } @@ -164,7 +154,15 @@ public void convert_jsonRlsConfig() throws IOException { + " \"names\": [\"User\", \"Parent\"],\n" + " \"optional\": true\n" + " }\n" - + " ]\n" + + " ],\n" + + " \"extraKeys\": {\n" + + " \"host\": \"host-key\",\n" + + " \"service\": \"service-key\",\n" + + " \"method\": \"method-key\"\n" + + " }, \n" + + " \"constantKeys\": {\n" + + " \"constKey1\": \"value1\"\n" + + " }\n" + " }\n" + " ],\n" + " \"lookupService\": \"service1\",\n" @@ -183,16 +181,22 @@ public void convert_jsonRlsConfig() throws IOException { ImmutableList.of(new Name("service1", "create")), ImmutableList.of( new NameMatcher("user", ImmutableList.of("User", "Parent"), true), - new NameMatcher("id", ImmutableList.of("X-Google-Id"), true))), + new NameMatcher("id", ImmutableList.of("X-Google-Id"), true)), + ExtraKeys.DEFAULT, + ImmutableMap.of()), new GrpcKeyBuilder( ImmutableList.of(new Name("service1")), ImmutableList.of( new NameMatcher("user", ImmutableList.of("User", "Parent"), true), - new NameMatcher("password", ImmutableList.of("Password"), true))), + new NameMatcher("password", ImmutableList.of("Password"), true)), + ExtraKeys.DEFAULT, + ImmutableMap.of()), new GrpcKeyBuilder( ImmutableList.of(new Name("service3")), ImmutableList.of( - new NameMatcher("user", ImmutableList.of("User", "Parent"), true)))), + new NameMatcher("user", ImmutableList.of("User", "Parent"), true)), + ExtraKeys.create("host-key", "service-key", "method-key"), + ImmutableMap.of("constKey1", "value1"))), /* lookupService= */ "service1", /* lookupServiceTimeoutInMillis= */ TimeUnit.SECONDS.toMillis(2), /* maxAgeInMillis= */ TimeUnit.SECONDS.toMillis(300), diff --git a/rls/src/test/java/io/grpc/rls/RlsRequestFactoryTest.java b/rls/src/test/java/io/grpc/rls/RlsRequestFactoryTest.java index 8661e023346..b0d197ff525 100644 --- a/rls/src/test/java/io/grpc/rls/RlsRequestFactoryTest.java +++ b/rls/src/test/java/io/grpc/rls/RlsRequestFactoryTest.java @@ -20,9 +20,11 @@ import static org.junit.Assert.fail; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import io.grpc.Metadata; import io.grpc.Status.Code; import io.grpc.StatusRuntimeException; +import io.grpc.rls.RlsProtoData.ExtraKeys; import io.grpc.rls.RlsProtoData.GrpcKeyBuilder; import io.grpc.rls.RlsProtoData.GrpcKeyBuilder.Name; import io.grpc.rls.RlsProtoData.NameMatcher; @@ -43,21 +45,29 @@ public class RlsRequestFactoryTest { ImmutableList.of(new Name("com.google.service1", "Create")), ImmutableList.of( new NameMatcher("user", ImmutableList.of("User", "Parent"), true), - new NameMatcher("id", ImmutableList.of("X-Google-Id"), true))), + new NameMatcher("id", ImmutableList.of("X-Google-Id"), true)), + ExtraKeys.create("server-1", null, null), + ImmutableMap.of("const-key-1", "const-value-1")), new GrpcKeyBuilder( ImmutableList.of(new Name("com.google.service1")), ImmutableList.of( new NameMatcher("user", ImmutableList.of("User", "Parent"), true), - new NameMatcher("password", ImmutableList.of("Password"), true))), + new NameMatcher("password", ImmutableList.of("Password"), true)), + ExtraKeys.create(null, "service-2", null), + ImmutableMap.of("const-key-2", "const-value-2")), new GrpcKeyBuilder( ImmutableList.of(new Name("com.google.service2")), ImmutableList.of( new NameMatcher("user", ImmutableList.of("User", "Parent"), false), - new NameMatcher("password", ImmutableList.of("Password"), true))), + new NameMatcher("password", ImmutableList.of("Password"), true)), + ExtraKeys.create(null, "service-3", "method-3"), + ImmutableMap.of()), new GrpcKeyBuilder( ImmutableList.of(new Name("com.google.service3")), ImmutableList.of( - new NameMatcher("user", ImmutableList.of("User", "Parent"), true)))), + new NameMatcher("user", ImmutableList.of("User", "Parent"), true)), + ExtraKeys.create(null, null, null), + ImmutableMap.of("const-key-4", "const-value-4"))), /* lookupService= */ "bigtable-rls.googleapis.com", /* lookupServiceTimeoutInMillis= */ TimeUnit.SECONDS.toMillis(2), /* maxAgeInMillis= */ TimeUnit.SECONDS.toMillis(300), @@ -77,10 +87,11 @@ public void create_pathMatches() { metadata.put(Metadata.Key.of("foo", Metadata.ASCII_STRING_MARSHALLER), "bar"); RouteLookupRequest request = factory.create("com.google.service1", "Create", metadata); - assertThat(request.getTargetType()).isEqualTo("grpc"); - assertThat(request.getPath()).isEqualTo("/com.google.service1/Create"); - assertThat(request.getServer()).isEqualTo("bigtable.googleapis.com"); - assertThat(request.getKeyMap()).containsExactly("user", "test", "id", "123"); + assertThat(request.getKeyMap()).containsExactly( + "user", "test", + "id", "123", + "server-1", "bigtable.googleapis.com", + "const-key-1", "const-value-1"); } @Test @@ -106,10 +117,11 @@ public void create_pathFallbackMatches() { RouteLookupRequest request = factory.create("com.google.service1" , "Update", metadata); - assertThat(request.getTargetType()).isEqualTo("grpc"); - assertThat(request.getPath()).isEqualTo("/com.google.service1/Update"); - assertThat(request.getServer()).isEqualTo("bigtable.googleapis.com"); - assertThat(request.getKeyMap()).containsExactly("user", "test", "password", "hunter2"); + assertThat(request.getKeyMap()).containsExactly( + "user", "test", + "password", "hunter2", + "service-2", "com.google.service1", + "const-key-2", "const-value-2"); } @Test @@ -121,10 +133,10 @@ public void create_pathFallbackMatches_optionalHeaderMissing() { RouteLookupRequest request = factory.create("com.google.service1", "Update", metadata); - assertThat(request.getTargetType()).isEqualTo("grpc"); - assertThat(request.getPath()).isEqualTo("/com.google.service1/Update"); - assertThat(request.getServer()).isEqualTo("bigtable.googleapis.com"); - assertThat(request.getKeyMap()).containsExactly("user", "test"); + assertThat(request.getKeyMap()).containsExactly( + "user", "test", + "service-2", "com.google.service1", + "const-key-2", "const-value-2"); } @Test @@ -135,10 +147,6 @@ public void create_unknownPath() { metadata.put(Metadata.Key.of("foo", Metadata.ASCII_STRING_MARSHALLER), "bar"); RouteLookupRequest request = factory.create("abc.def.service999", "Update", metadata); - - assertThat(request.getTargetType()).isEqualTo("grpc"); - assertThat(request.getPath()).isEqualTo("/abc.def.service999/Update"); - assertThat(request.getServer()).isEqualTo("bigtable.googleapis.com"); assertThat(request.getKeyMap()).isEmpty(); } @@ -151,9 +159,7 @@ public void create_noMethodInRlsConfig() { RouteLookupRequest request = factory.create("com.google.service3", "Update", metadata); - assertThat(request.getTargetType()).isEqualTo("grpc"); - assertThat(request.getPath()).isEqualTo("/com.google.service3/Update"); - assertThat(request.getServer()).isEqualTo("bigtable.googleapis.com"); - assertThat(request.getKeyMap()).containsExactly("user", "test"); + assertThat(request.getKeyMap()).containsExactly( + "user", "test", "const-key-4", "const-value-4"); } } diff --git a/stub/BUILD.bazel b/stub/BUILD.bazel index 181ffe0485d..c65b01a23dc 100644 --- a/stub/BUILD.bazel +++ b/stub/BUILD.bazel @@ -8,6 +8,7 @@ java_library( "//api", "//context", "@com_google_code_findbugs_jsr305//jar", + "@com_google_errorprone_error_prone_annotations//jar", "@com_google_guava_guava//jar", "@com_google_j2objc_j2objc_annotations//jar", ], diff --git a/stub/build.gradle b/stub/build.gradle index 4076460377c..2b5a6a4edb6 100644 --- a/stub/build.gradle +++ b/stub/build.gradle @@ -10,6 +10,7 @@ description = "gRPC: Stub" dependencies { api project(':grpc-api'), libraries.guava + implementation libraries.errorprone testImplementation libraries.truth, project(':grpc-testing') signature "org.codehaus.mojo.signature:java17:1.0@signature" diff --git a/stub/src/main/java/io/grpc/stub/ClientCallStreamObserver.java b/stub/src/main/java/io/grpc/stub/ClientCallStreamObserver.java index ea09bb99d55..5fb70c76de3 100644 --- a/stub/src/main/java/io/grpc/stub/ClientCallStreamObserver.java +++ b/stub/src/main/java/io/grpc/stub/ClientCallStreamObserver.java @@ -20,7 +20,8 @@ /** * A refinement of {@link CallStreamObserver} that allows for lower-level interaction with - * client calls. + * client calls. An instance of this class is obtained via {@link ClientResponseObserver}, or by + * manually casting the {@code StreamObserver} returned by a stub. * *

Like {@code StreamObserver}, implementations are not required to be thread-safe; if multiple * threads will be writing to an instance concurrently, the application must synchronize its calls. diff --git a/stub/src/main/java/io/grpc/stub/MetadataUtils.java b/stub/src/main/java/io/grpc/stub/MetadataUtils.java index 0fedf3711f7..5395ba9b5e3 100644 --- a/stub/src/main/java/io/grpc/stub/MetadataUtils.java +++ b/stub/src/main/java/io/grpc/stub/MetadataUtils.java @@ -18,6 +18,7 @@ import static com.google.common.base.Preconditions.checkNotNull; +import com.google.errorprone.annotations.InlineMe; import io.grpc.CallOptions; import io.grpc.Channel; import io.grpc.ClientCall; @@ -43,8 +44,14 @@ private MetadataUtils() {} * @param stub to bind the headers to. * @param extraHeaders the headers to be passed by each call on the returned stub. * @return an implementation of the stub with {@code extraHeaders} bound to each call. + * @deprecated Use {@code stub.withInterceptors(newAttachHeadersInterceptor(...))} instead. */ @ExperimentalApi("https://github.com/grpc/grpc-java/issues/1789") + @Deprecated + @InlineMe( + replacement = + "stub.withInterceptors(MetadataUtils.newAttachHeadersInterceptor(extraHeaders))", + imports = "io.grpc.stub.MetadataUtils") public static > T attachHeaders(T stub, Metadata extraHeaders) { return stub.withInterceptors(newAttachHeadersInterceptor(extraHeaders)); } @@ -98,8 +105,15 @@ public void start(Listener responseListener, Metadata headers) { * @param trailersCapture to record the last received trailers * @return an implementation of the stub that allows to access the last received call's * headers and trailers via {@code headersCapture} and {@code trailersCapture}. + * @deprecated Use {@code stub.withInterceptors(newCaptureMetadataInterceptor())} instead. */ @ExperimentalApi("https://github.com/grpc/grpc-java/issues/1789") + @Deprecated + @InlineMe( + replacement = + "stub.withInterceptors(MetadataUtils.newCaptureMetadataInterceptor(headersCapture," + + " trailersCapture))", + imports = "io.grpc.stub.MetadataUtils") public static > T captureMetadata( T stub, AtomicReference headersCapture, diff --git a/stub/src/main/java/io/grpc/stub/ServerCallStreamObserver.java b/stub/src/main/java/io/grpc/stub/ServerCallStreamObserver.java index 3ba1bf563ef..a4d4564a46d 100644 --- a/stub/src/main/java/io/grpc/stub/ServerCallStreamObserver.java +++ b/stub/src/main/java/io/grpc/stub/ServerCallStreamObserver.java @@ -18,7 +18,8 @@ /** * A refinement of {@link CallStreamObserver} to allows for interaction with call - * cancellation events on the server side. + * cancellation events on the server side. An instance of this class is obtained by casting the + * {@code StreamObserver} passed as an argument to service implementations. * *

Like {@code StreamObserver}, implementations are not required to be thread-safe; if multiple * threads will be writing to an instance concurrently, the application must synchronize its calls. diff --git a/stub/src/main/java/io/grpc/stub/StreamObserver.java b/stub/src/main/java/io/grpc/stub/StreamObserver.java index 92040d9bc58..cf7cc258961 100644 --- a/stub/src/main/java/io/grpc/stub/StreamObserver.java +++ b/stub/src/main/java/io/grpc/stub/StreamObserver.java @@ -31,6 +31,16 @@ * not need to be synchronized together; incoming and outgoing directions are independent. * Since individual {@code StreamObserver}s are not thread-safe, if multiple threads will be * writing to a {@code StreamObserver} concurrently, the application must synchronize calls. + * + *

This API is asynchronous, so methods may return before the operation completes. The API + * provides no guarantees for how quickly an operation will complete, so utilizing flow control via + * {@link ClientCallStreamObserver} and {@link ServerCallStreamObserver} to avoid excessive + * buffering is recommended for streaming RPCs. gRPC's implementation of {@code onError()} on + * client-side causes the RPC to be cancelled and discards all messages, so completes quickly. + * + *

gRPC guarantees it does not block on I/O in its implementation, but applications are allowed + * to perform blocking operations in their implementations. However, doing so will delay other + * callbacks because the methods cannot be called concurrently. */ public interface StreamObserver { /** diff --git a/xds/build.gradle b/xds/build.gradle index ae8d8d208a9..fdee8fab203 100644 --- a/xds/build.gradle +++ b/xds/build.gradle @@ -97,9 +97,9 @@ javadoc { exclude 'io/grpc/xds/*LoadBalancer*' exclude 'io/grpc/xds/Bootstrapper.java' exclude 'io/grpc/xds/Envoy*' + exclude 'io/grpc/xds/FilterChainMatchingProtocolNegotiators.java' exclude 'io/grpc/xds/TlsContextManager.java' exclude 'io/grpc/xds/XdsAttributes.java' - exclude 'io/grpc/xds/XdsClientWrapperForServerSds.java' exclude 'io/grpc/xds/XdsInitializationException.java' exclude 'io/grpc/xds/XdsNameResolverProvider.java' exclude 'io/grpc/xds/internal/**' diff --git a/xds/src/main/java/io/grpc/xds/Bootstrapper.java b/xds/src/main/java/io/grpc/xds/Bootstrapper.java index 08d11d174d4..e1ba4aa176e 100644 --- a/xds/src/main/java/io/grpc/xds/Bootstrapper.java +++ b/xds/src/main/java/io/grpc/xds/Bootstrapper.java @@ -84,7 +84,8 @@ public static class CertificateProviderInfo { private final String pluginName; private final Map config; - CertificateProviderInfo(String pluginName, Map config) { + @VisibleForTesting + public CertificateProviderInfo(String pluginName, Map config) { this.pluginName = checkNotNull(pluginName, "pluginName"); this.config = checkNotNull(config, "config"); } @@ -135,8 +136,9 @@ public Node getNode() { } /** Returns the cert-providers config map. */ + @Nullable public Map getCertProviders() { - return Collections.unmodifiableMap(certProviders); + return certProviders == null ? null : Collections.unmodifiableMap(certProviders); } @Nullable diff --git a/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java b/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java index e91e76090ab..036f77f7cd1 100644 --- a/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java +++ b/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java @@ -188,8 +188,10 @@ private void handleClusterDiscovered() { if (root.result.lbPolicy() == LbPolicy.RING_HASH) { lbProvider = lbRegistry.getProvider("ring_hash"); lbConfig = new RingHashConfig(root.result.minRingSize(), root.result.maxRingSize()); - } else { + } + if (lbProvider == null) { lbProvider = lbRegistry.getProvider("round_robin"); + lbConfig = null; } ClusterResolverConfig config = new ClusterResolverConfig( Collections.unmodifiableList(instances), new PolicySelection(lbProvider, lbConfig)); diff --git a/xds/src/main/java/io/grpc/xds/ClientXdsClient.java b/xds/src/main/java/io/grpc/xds/ClientXdsClient.java index 4ae6651784f..f39992c24ac 100644 --- a/xds/src/main/java/io/grpc/xds/ClientXdsClient.java +++ b/xds/src/main/java/io/grpc/xds/ClientXdsClient.java @@ -134,8 +134,12 @@ final class ClientXdsClient extends AbstractXdsClient { || Boolean.parseBoolean(System.getenv("GRPC_XDS_EXPERIMENTAL_FAULT_INJECTION")); @VisibleForTesting static boolean enableRetry = - !Strings.isNullOrEmpty(System.getenv("GRPC_XDS_EXPERIMENTAL_ENABLE_RETRY")) - && Boolean.parseBoolean(System.getenv("GRPC_XDS_EXPERIMENTAL_ENABLE_RETRY")); + Strings.isNullOrEmpty(System.getenv("GRPC_XDS_EXPERIMENTAL_ENABLE_RETRY")) + || Boolean.parseBoolean(System.getenv("GRPC_XDS_EXPERIMENTAL_ENABLE_RETRY")); + @VisibleForTesting + static boolean enableRbac = + !Strings.isNullOrEmpty(System.getenv("GRPC_XDS_EXPERIMENTAL_RBAC")) + && Boolean.parseBoolean(System.getenv("GRPC_XDS_EXPERIMENTAL_RBAC")); private static final String TYPE_URL_HTTP_CONNECTION_MANAGER_V2 = "type.googleapis.com/envoy.config.filter.network.http_connection_manager.v2" @@ -190,6 +194,7 @@ final class ClientXdsClient extends AbstractXdsClient { protected void handleLdsResponse(String versionInfo, List resources, String nonce) { Map parsedResources = new HashMap<>(resources.size()); Set unpackedResources = new HashSet<>(resources.size()); + Set invalidResources = new HashSet<>(); List errors = new ArrayList<>(); Set retainedRdsResources = new HashSet<>(); @@ -217,11 +222,12 @@ protected void handleLdsResponse(String versionInfo, List resources, String listener, retainedRdsResources, enableFaultInjection && isResourceV3); } else { ldsUpdate = processServerSideListener( - listener, retainedRdsResources, enableFaultInjection && isResourceV3); + listener, retainedRdsResources, enableRbac && isResourceV3); } } catch (ResourceInvalidException e) { errors.add( "LDS response Listener '" + listenerName + "' validation error: " + e.getMessage()); + invalidResources.add(listenerName); continue; } @@ -231,19 +237,9 @@ protected void handleLdsResponse(String versionInfo, List resources, String getLogger().log(XdsLogLevel.INFO, "Received LDS Response version {0} nonce {1}. Parsed resources: {2}", versionInfo, nonce, unpackedResources); - - if (!errors.isEmpty()) { - handleResourcesRejected(ResourceType.LDS, unpackedResources, versionInfo, nonce, errors); - return; - } - - handleResourcesAccepted(ResourceType.LDS, parsedResources, versionInfo, nonce); - for (String resource : rdsResourceSubscribers.keySet()) { - if (!retainedRdsResources.contains(resource)) { - ResourceSubscriber subscriber = rdsResourceSubscribers.get(resource); - subscriber.onAbsent(); - } - } + handleResourceUpdate( + ResourceType.LDS, parsedResources, invalidResources, retainedRdsResources, versionInfo, + nonce, errors); } private LdsUpdate processClientSideListener( @@ -266,14 +262,20 @@ private LdsUpdate processClientSideListener( private LdsUpdate processServerSideListener( Listener proto, Set rdsResources, boolean parseHttpFilter) throws ResourceInvalidException { + Set certProviderInstances = null; + if (getBootstrapInfo() != null && getBootstrapInfo().getCertProviders() != null) { + certProviderInstances = getBootstrapInfo().getCertProviders().keySet(); + } return LdsUpdate.forTcpListener(parseServerSideListener( - proto, rdsResources, tlsContextManager, filterRegistry, parseHttpFilter)); + proto, rdsResources, tlsContextManager, filterRegistry, certProviderInstances, + parseHttpFilter)); } @VisibleForTesting static EnvoyServerProtoData.Listener parseServerSideListener( Listener proto, Set rdsResources, TlsContextManager tlsContextManager, - FilterRegistry filterRegistry, boolean parseHttpFilter) throws ResourceInvalidException { + FilterRegistry filterRegistry, Set certProviderInstances, boolean parseHttpFilter) + throws ResourceInvalidException { if (!proto.getTrafficDirection().equals(TrafficDirection.INBOUND)) { throw new ResourceInvalidException( "Listener " + proto.getName() + " with invalid traffic direction: " @@ -309,13 +311,13 @@ static EnvoyServerProtoData.Listener parseServerSideListener( for (io.envoyproxy.envoy.config.listener.v3.FilterChain fc : proto.getFilterChainsList()) { filterChains.add( parseFilterChain(fc, rdsResources, tlsContextManager, filterRegistry, uniqueSet, - parseHttpFilter)); + certProviderInstances, parseHttpFilter)); } FilterChain defaultFilterChain = null; if (proto.hasDefaultFilterChain()) { defaultFilterChain = parseFilterChain( proto.getDefaultFilterChain(), rdsResources, tlsContextManager, filterRegistry, - null, parseHttpFilter); + null, certProviderInstances, parseHttpFilter); } return new EnvoyServerProtoData.Listener( @@ -326,43 +328,34 @@ static EnvoyServerProtoData.Listener parseServerSideListener( static FilterChain parseFilterChain( io.envoyproxy.envoy.config.listener.v3.FilterChain proto, Set rdsResources, TlsContextManager tlsContextManager, FilterRegistry filterRegistry, - Set uniqueSet, boolean parseHttpFilters) + Set uniqueSet, Set certProviderInstances, boolean parseHttpFilters) throws ResourceInvalidException { - io.grpc.xds.HttpConnectionManager httpConnectionManager = null; - HashSet uniqueNames = new HashSet<>(); - for (io.envoyproxy.envoy.config.listener.v3.Filter filter : proto.getFiltersList()) { - if (!uniqueNames.add(filter.getName())) { - throw new ResourceInvalidException( - "FilterChain " + proto.getName() + " with duplicated filter: " + filter.getName()); - } - if (!filter.hasTypedConfig()) { - throw new ResourceInvalidException( - "FilterChain " + proto.getName() + " contains filter " + filter.getName() - + " without typed_config"); - } - Any any = filter.getTypedConfig(); - // HttpConnectionManager is the only supported network filter at the moment. - if (!any.getTypeUrl().equals(TYPE_URL_HTTP_CONNECTION_MANAGER)) { - throw new ResourceInvalidException( - "FilterChain " + proto.getName() + " contains filter " + filter.getName() - + " with unsupported typed_config type " + any.getTypeUrl()); - } - if (httpConnectionManager == null) { - HttpConnectionManager hcmProto; - try { - hcmProto = any.unpack(HttpConnectionManager.class); - } catch (InvalidProtocolBufferException e) { - throw new ResourceInvalidException("FilterChain " + proto.getName() + " with filter " - + filter.getName() + " failed to unpack message", e); - } - httpConnectionManager = parseHttpConnectionManager( - hcmProto, rdsResources, filterRegistry, parseHttpFilters, false /* isForClient */); - } - } - if (httpConnectionManager == null) { + if (proto.getFiltersCount() != 1) { throw new ResourceInvalidException("FilterChain " + proto.getName() - + " missing required HttpConnectionManager filter"); + + " should contain exact one HttpConnectionManager filter"); } + io.envoyproxy.envoy.config.listener.v3.Filter filter = proto.getFiltersList().get(0); + if (!filter.hasTypedConfig()) { + throw new ResourceInvalidException( + "FilterChain " + proto.getName() + " contains filter " + filter.getName() + + " without typed_config"); + } + Any any = filter.getTypedConfig(); + // HttpConnectionManager is the only supported network filter at the moment. + if (!any.getTypeUrl().equals(TYPE_URL_HTTP_CONNECTION_MANAGER)) { + throw new ResourceInvalidException( + "FilterChain " + proto.getName() + " contains filter " + filter.getName() + + " with unsupported typed_config type " + any.getTypeUrl()); + } + HttpConnectionManager hcmProto; + try { + hcmProto = any.unpack(HttpConnectionManager.class); + } catch (InvalidProtocolBufferException e) { + throw new ResourceInvalidException("FilterChain " + proto.getName() + " with filter " + + filter.getName() + " failed to unpack message", e); + } + io.grpc.xds.HttpConnectionManager httpConnectionManager = parseHttpConnectionManager( + hcmProto, rdsResources, filterRegistry, parseHttpFilters, false /* isForClient */); EnvoyServerProtoData.DownstreamTlsContext downstreamTlsContext = null; if (proto.hasTransportSocket()) { @@ -380,7 +373,7 @@ static FilterChain parseFilterChain( } downstreamTlsContext = EnvoyServerProtoData.DownstreamTlsContext.fromEnvoyProtoDownstreamTlsContext( - validateDownstreamTlsContext(downstreamTlsContextProto)); + validateDownstreamTlsContext(downstreamTlsContextProto, certProviderInstances)); } String name = proto.getName(); @@ -399,13 +392,12 @@ static FilterChain parseFilterChain( } @VisibleForTesting - static io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.DownstreamTlsContext - validateDownstreamTlsContext( - io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.DownstreamTlsContext - downstreamTlsContext) + static DownstreamTlsContext validateDownstreamTlsContext( + DownstreamTlsContext downstreamTlsContext, Set certProviderInstances) throws ResourceInvalidException { if (downstreamTlsContext.hasCommonTlsContext()) { - validateCommonTlsContext(downstreamTlsContext.getCommonTlsContext(), true); + validateCommonTlsContext(downstreamTlsContext.getCommonTlsContext(), certProviderInstances, + true); } else { throw new ResourceInvalidException( "common-tls-context is required in downstream-tls-context"); @@ -414,22 +406,6 @@ static FilterChain parseFilterChain( throw new ResourceInvalidException( "downstream-tls-context with require-sni is not supported"); } - if (downstreamTlsContext.hasSessionTicketKeys()) { - throw new ResourceInvalidException( - "downstream-tls-context with session_ticket_keys is not supported"); - } - if (downstreamTlsContext.hasSessionTicketKeysSdsSecretConfig()) { - throw new ResourceInvalidException( - "downstream-tls-context with session_ticket_keys_sds_secret_config is not supported"); - } - if (downstreamTlsContext.hasDisableStatelessSessionResumption()) { - throw new ResourceInvalidException( - "downstream-tls-context with disable_stateless_session_resumption is not supported"); - } - if (downstreamTlsContext.hasSessionTimeout()) { - throw new ResourceInvalidException( - "downstream-tls-context with session_timeout is not supported"); - } DownstreamTlsContext.OcspStaplePolicy ocspStaplePolicy = downstreamTlsContext .getOcspStaplePolicy(); if (ocspStaplePolicy != DownstreamTlsContext.OcspStaplePolicy.UNRECOGNIZED @@ -444,30 +420,22 @@ static FilterChain parseFilterChain( @VisibleForTesting static io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext validateUpstreamTlsContext( - io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext upstreamTlsContext) + io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext upstreamTlsContext, + Set certProviderInstances) throws ResourceInvalidException { if (upstreamTlsContext.hasCommonTlsContext()) { - validateCommonTlsContext(upstreamTlsContext.getCommonTlsContext(), false); + validateCommonTlsContext(upstreamTlsContext.getCommonTlsContext(), certProviderInstances, + false); } else { throw new ResourceInvalidException("common-tls-context is required in upstream-tls-context"); } - if (!Strings.isNullOrEmpty(upstreamTlsContext.getSni())) { - throw new ResourceInvalidException("upstream-tls-context with sni is not supported"); - } - if (upstreamTlsContext.getAllowRenegotiation()) { - throw new ResourceInvalidException( - "upstream-tls-context with allow_renegotiation is not supported"); - } - if (upstreamTlsContext.hasMaxSessionKeys()) { - throw new ResourceInvalidException( - "upstream-tls-context with max_session_keys is not supported"); - } return upstreamTlsContext; } @VisibleForTesting static void validateCommonTlsContext( - CommonTlsContext commonTlsContext, boolean server) throws ResourceInvalidException { + CommonTlsContext commonTlsContext, Set certProviderInstances, boolean server) + throws ResourceInvalidException { if (commonTlsContext.hasCustomHandshaker()) { throw new ResourceInvalidException( "common-tls-context with custom_handshaker is not supported"); @@ -475,10 +443,6 @@ static void validateCommonTlsContext( if (commonTlsContext.hasTlsParams()) { throw new ResourceInvalidException("common-tls-context with tls_params is not supported"); } - if (commonTlsContext.hasValidationContext()) { - throw new ResourceInvalidException( - "common-tls-context with validation_context is not supported"); - } if (commonTlsContext.hasValidationContextSdsSecretConfig()) { throw new ResourceInvalidException( "common-tls-context with validation_context_sds_secret_config is not supported"); @@ -492,52 +456,54 @@ static void validateCommonTlsContext( "common-tls-context with validation_context_certificate_provider_instance is not" + " supported"); } - if (!commonTlsContext.hasTlsCertificateCertificateProviderInstance()) { + String certInstanceName = getIdentityCertInstanceName(commonTlsContext); + if (certInstanceName == null) { if (server) { throw new ResourceInvalidException( - "tls_certificate_certificate_provider_instance is required in downstream-tls-context"); + "tls_certificate_provider_instance is required in downstream-tls-context"); } if (commonTlsContext.getTlsCertificatesCount() > 0) { throw new ResourceInvalidException( - "common-tls-context with tls_certificates is not supported"); + "tls_certificate_provider_instance is unset"); } if (commonTlsContext.getTlsCertificateSdsSecretConfigsCount() > 0) { throw new ResourceInvalidException( - "common-tls-context with tls_certificate_sds_secret_configs is not supported"); + "tls_certificate_provider_instance is unset"); } if (commonTlsContext.hasTlsCertificateCertificateProvider()) { throw new ResourceInvalidException( - "common-tls-context with tls_certificate_certificate_provider is not supported"); + "tls_certificate_provider_instance is unset"); } + } else if (certProviderInstances == null || !certProviderInstances.contains(certInstanceName)) { + throw new ResourceInvalidException( + "CertificateProvider instance name '" + certInstanceName + + "' not defined in the bootstrap file."); } - if (!commonTlsContext.hasCombinedValidationContext()) { + String rootCaInstanceName = getRootCertInstanceName(commonTlsContext); + if (rootCaInstanceName == null) { if (!server) { throw new ResourceInvalidException( - "combined_validation_context is required in upstream-tls-context"); + "ca_certificate_provider_instance is required in upstream-tls-context"); } } else { - CommonTlsContext.CombinedCertificateValidationContext combinedCertificateValidationContext - = commonTlsContext.getCombinedValidationContext(); - if (!combinedCertificateValidationContext.hasValidationContextCertificateProviderInstance()) { + if (certProviderInstances == null || !certProviderInstances.contains(rootCaInstanceName)) { throw new ResourceInvalidException( - "validation_context_certificate_provider_instance is required in" - + " combined_validation_context"); - } - if (combinedCertificateValidationContext.hasDefaultValidationContext()) { - CertificateValidationContext certificateValidationContext - = combinedCertificateValidationContext.getDefaultValidationContext(); + "ca_certificate_provider_instance name '" + rootCaInstanceName + + "' not defined in the bootstrap file."); + } + CertificateValidationContext certificateValidationContext = null; + if (commonTlsContext.hasValidationContext()) { + certificateValidationContext = commonTlsContext.getValidationContext(); + } else if (commonTlsContext.hasCombinedValidationContext() && commonTlsContext + .getCombinedValidationContext().hasDefaultValidationContext()) { + certificateValidationContext = commonTlsContext.getCombinedValidationContext() + .getDefaultValidationContext(); + } + if (certificateValidationContext != null) { if (certificateValidationContext.getMatchSubjectAltNamesCount() > 0 && server) { throw new ResourceInvalidException( "match_subject_alt_names only allowed in upstream_tls_context"); } - if (certificateValidationContext.hasTrustedCa()) { - throw new ResourceInvalidException( - "trusted_ca in default_validation_context is not supported"); - } - if (certificateValidationContext.hasWatchedDirectory()) { - throw new ResourceInvalidException( - "watched_directory in default_validation_context is not supported"); - } if (certificateValidationContext.getVerifyCertificateSpkiCount() > 0) { throw new ResourceInvalidException( "verify_certificate_spki in default_validation_context is not supported"); @@ -554,17 +520,6 @@ static void validateCommonTlsContext( if (certificateValidationContext.hasCrl()) { throw new ResourceInvalidException("crl in default_validation_context is not supported"); } - if (certificateValidationContext.getAllowExpiredCertificate()) { - throw new ResourceInvalidException( - "allow_expired_certificate in default_validation_context is not supported"); - } - CertificateValidationContext.TrustChainVerification trustChainVerification - = certificateValidationContext.getTrustChainVerification(); - if (trustChainVerification - != CertificateValidationContext.TrustChainVerification.VERIFY_TRUST_CHAIN) { - throw new ResourceInvalidException( - "Only VERIFY_TRUST_CHAIN for trust_chain_verification supported"); - } if (certificateValidationContext.hasCustomValidatorConfig()) { throw new ResourceInvalidException( "custom_validator_config in default_validation_context is not supported"); @@ -573,6 +528,38 @@ static void validateCommonTlsContext( } } + private static String getIdentityCertInstanceName(CommonTlsContext commonTlsContext) { + if (commonTlsContext.hasTlsCertificateProviderInstance()) { + return commonTlsContext.getTlsCertificateProviderInstance().getInstanceName(); + } else if (commonTlsContext.hasTlsCertificateCertificateProviderInstance()) { + return commonTlsContext.getTlsCertificateCertificateProviderInstance().getInstanceName(); + } + return null; + } + + private static String getRootCertInstanceName(CommonTlsContext commonTlsContext) { + if (commonTlsContext.hasValidationContext()) { + if (commonTlsContext.getValidationContext().hasCaCertificateProviderInstance()) { + return commonTlsContext.getValidationContext().getCaCertificateProviderInstance() + .getInstanceName(); + } + } else if (commonTlsContext.hasCombinedValidationContext()) { + CommonTlsContext.CombinedCertificateValidationContext combinedCertificateValidationContext + = commonTlsContext.getCombinedValidationContext(); + if (combinedCertificateValidationContext.hasDefaultValidationContext() + && combinedCertificateValidationContext.getDefaultValidationContext() + .hasCaCertificateProviderInstance()) { + return combinedCertificateValidationContext.getDefaultValidationContext() + .getCaCertificateProviderInstance().getInstanceName(); + } else if (combinedCertificateValidationContext + .hasValidationContextCertificateProviderInstance()) { + return combinedCertificateValidationContext + .getValidationContextCertificateProviderInstance().getInstanceName(); + } + } + return null; + } + private static void checkForUniqueness(Set uniqueSet, FilterChainMatch filterChainMatch) throws ResourceInvalidException { if (uniqueSet != null) { @@ -746,10 +733,14 @@ private static FilterChainMatch parseFilterChainMatch( static io.grpc.xds.HttpConnectionManager parseHttpConnectionManager( HttpConnectionManager proto, Set rdsResources, FilterRegistry filterRegistry, boolean parseHttpFilter, boolean isForClient) throws ResourceInvalidException { - if (proto.getXffNumTrustedHops() != 0) { + if (enableRbac && proto.getXffNumTrustedHops() != 0) { throw new ResourceInvalidException( "HttpConnectionManager with xff_num_trusted_hops unsupported"); } + if (enableRbac && !proto.getOriginalIpDetectionExtensionsList().isEmpty()) { + throw new ResourceInvalidException("HttpConnectionManager with " + + "original_ip_detection_extensions unsupported"); + } // Obtain max_stream_duration from Http Protocol Options. long maxStreamDuration = 0; if (proto.hasCommonHttpProtocolOptions()) { @@ -762,10 +753,14 @@ static io.grpc.xds.HttpConnectionManager parseHttpConnectionManager( // Parse http filters. List filterConfigs = null; if (parseHttpFilter) { + if (proto.getHttpFiltersList().isEmpty()) { + throw new ResourceInvalidException("Missing HttpFilter in HttpConnectionManager."); + } filterConfigs = new ArrayList<>(); Set names = new HashSet<>(); - for (io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.HttpFilter - httpFilter : proto.getHttpFiltersList()) { + for (int i = 0; i < proto.getHttpFiltersCount(); i++) { + io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.HttpFilter + httpFilter = proto.getHttpFiltersList().get(i); String filterName = httpFilter.getName(); if (!names.add(filterName)) { throw new ResourceInvalidException( @@ -773,6 +768,11 @@ static io.grpc.xds.HttpConnectionManager parseHttpConnectionManager( } StructOrError filterConfig = parseHttpFilter(httpFilter, filterRegistry, isForClient); + if ((i == proto.getHttpFiltersCount() - 1) + && (filterConfig == null || !isTerminalFilter(filterConfig.struct))) { + throw new ResourceInvalidException("The last HttpFilter must be a terminal filter: " + + filterName); + } if (filterConfig == null) { continue; } @@ -781,6 +781,10 @@ static io.grpc.xds.HttpConnectionManager parseHttpConnectionManager( "HttpConnectionManager contains invalid HttpFilter: " + filterConfig.getErrorDetail()); } + if ((i < proto.getHttpFiltersCount() - 1) && isTerminalFilter(filterConfig.getStruct())) { + throw new ResourceInvalidException("A terminal HttpFilter must be the last filter: " + + filterName); + } filterConfigs.add(new NamedFilterConfig(filterName, filterConfig.struct)); } } @@ -821,6 +825,11 @@ static io.grpc.xds.HttpConnectionManager parseHttpConnectionManager( "HttpConnectionManager neither has inlined route_config nor RDS"); } + // hard-coded: currently router config is the only terminal filter. + private static boolean isTerminalFilter(FilterConfig filterConfig) { + return RouterFilter.ROUTER_CONFIG.equals(filterConfig); + } + @VisibleForTesting @Nullable // Returns null if the filter is optional but not supported. static StructOrError parseHttpFilter( @@ -1254,32 +1263,28 @@ private static StructOrError parseRetryPolicy( maxBackoff = Durations.fromNanos(Durations.toNanos(initialBackoff) * 10); } } - Iterable retryOns = Splitter.on(',').split(retryPolicyProto.getRetryOn()); + Iterable retryOns = + Splitter.on(',').omitEmptyStrings().trimResults().split(retryPolicyProto.getRetryOn()); ImmutableList.Builder retryableStatusCodesBuilder = ImmutableList.builder(); for (String retryOn : retryOns) { Code code; try { code = Code.valueOf(retryOn.toUpperCase(Locale.US).replace('-', '_')); } catch (IllegalArgumentException e) { - // TODO(zdapeng): TBD // unsupported value, such as "5xx" - return null; + continue; } if (!SUPPORTED_RETRYABLE_CODES.contains(code)) { - // TODO(zdapeng): TBD // unsupported value - return null; + continue; } retryableStatusCodesBuilder.add(code); } List retryableStatusCodes = retryableStatusCodesBuilder.build(); - if (!retryableStatusCodes.isEmpty()) { - return StructOrError.fromStruct( - RetryPolicy.create( - maxAttempts, retryableStatusCodes, initialBackoff, maxBackoff, - /* perAttemptRecvTimeout= */ null)); - } - return null; + return StructOrError.fromStruct( + RetryPolicy.create( + maxAttempts, retryableStatusCodes, initialBackoff, maxBackoff, + /* perAttemptRecvTimeout= */ null)); } @VisibleForTesting @@ -1305,6 +1310,7 @@ static StructOrError parseClusterWeight( protected void handleRdsResponse(String versionInfo, List resources, String nonce) { Map parsedResources = new HashMap<>(resources.size()); Set unpackedResources = new HashSet<>(resources.size()); + Set invalidResources = new HashSet<>(); List errors = new ArrayList<>(); for (int i = 0; i < resources.size(); i++) { @@ -1332,6 +1338,7 @@ protected void handleRdsResponse(String versionInfo, List resources, String errors.add( "RDS response RouteConfiguration '" + routeConfigName + "' validation error: " + e .getMessage()); + invalidResources.add(routeConfigName); continue; } @@ -1340,12 +1347,9 @@ protected void handleRdsResponse(String versionInfo, List resources, String getLogger().log(XdsLogLevel.INFO, "Received RDS Response version {0} nonce {1}. Parsed resources: {2}", versionInfo, nonce, unpackedResources); - - if (!errors.isEmpty()) { - handleResourcesRejected(ResourceType.RDS, unpackedResources, versionInfo, nonce, errors); - } else { - handleResourcesAccepted(ResourceType.RDS, parsedResources, versionInfo, nonce); - } + handleResourceUpdate( + ResourceType.RDS, parsedResources, invalidResources, Collections.emptySet(), + versionInfo, nonce, errors); } private static RdsUpdate processRouteConfiguration( @@ -1369,6 +1373,7 @@ private static RdsUpdate processRouteConfiguration( protected void handleCdsResponse(String versionInfo, List resources, String nonce) { Map parsedResources = new HashMap<>(resources.size()); Set unpackedResources = new HashSet<>(resources.size()); + Set invalidResources = new HashSet<>(); List errors = new ArrayList<>(); Set retainedEdsResources = new HashSet<>(); @@ -1397,10 +1402,15 @@ protected void handleCdsResponse(String versionInfo, List resources, String // Process Cluster into CdsUpdate. CdsUpdate cdsUpdate; try { - cdsUpdate = parseCluster(cluster, retainedEdsResources); + Set certProviderInstances = null; + if (getBootstrapInfo() != null && getBootstrapInfo().getCertProviders() != null) { + certProviderInstances = getBootstrapInfo().getCertProviders().keySet(); + } + cdsUpdate = parseCluster(cluster, retainedEdsResources, certProviderInstances); } catch (ResourceInvalidException e) { errors.add( "CDS response Cluster '" + clusterName + "' validation error: " + e.getMessage()); + invalidResources.add(clusterName); continue; } parsedResources.put(clusterName, new ParsedResource(cdsUpdate, resource)); @@ -1408,30 +1418,20 @@ protected void handleCdsResponse(String versionInfo, List resources, String getLogger().log(XdsLogLevel.INFO, "Received CDS Response version {0} nonce {1}. Parsed resources: {2}", versionInfo, nonce, unpackedResources); - - if (!errors.isEmpty()) { - handleResourcesRejected(ResourceType.CDS, unpackedResources, versionInfo, nonce, errors); - return; - } - - handleResourcesAccepted(ResourceType.CDS, parsedResources, versionInfo, nonce); - // CDS responses represents the state of the world, EDS resources not referenced in CDS - // resources should be deleted. - for (String resource : edsResourceSubscribers.keySet()) { - ResourceSubscriber subscriber = edsResourceSubscribers.get(resource); - if (!retainedEdsResources.contains(resource)) { - subscriber.onAbsent(); - } - } + handleResourceUpdate( + ResourceType.CDS, parsedResources, invalidResources, retainedEdsResources, versionInfo, + nonce, errors); } @VisibleForTesting - static CdsUpdate parseCluster(Cluster cluster, Set retainedEdsResources) + static CdsUpdate parseCluster(Cluster cluster, Set retainedEdsResources, + Set certProviderInstances) throws ResourceInvalidException { StructOrError structOrError; switch (cluster.getClusterDiscoveryTypeCase()) { case TYPE: - structOrError = parseNonAggregateCluster(cluster, retainedEdsResources); + structOrError = parseNonAggregateCluster(cluster, retainedEdsResources, + certProviderInstances); break; case CLUSTER_TYPE: structOrError = parseAggregateCluster(cluster); @@ -1494,7 +1494,7 @@ private static StructOrError parseAggregateCluster(Cluster cl } private static StructOrError parseNonAggregateCluster( - Cluster cluster, Set edsResources) { + Cluster cluster, Set edsResources, Set certProviderInstances) { String clusterName = cluster.getName(); String lrsServerName = null; Long maxConcurrentRequests = null; @@ -1517,6 +1517,10 @@ private static StructOrError parseNonAggregateCluster( } } } + if (cluster.getTransportSocketMatchesCount() > 0) { + return StructOrError.fromError("Cluster " + clusterName + + ": transport-socket-matches not supported."); + } if (cluster.hasTransportSocket()) { if (!TRANSPORT_SOCKET_NAME_TLS.equals(cluster.getTransportSocket().getName())) { return StructOrError.fromError("transport-socket with name " @@ -1527,7 +1531,8 @@ private static StructOrError parseNonAggregateCluster( validateUpstreamTlsContext( unpackCompatibleType(cluster.getTransportSocket().getTypedConfig(), io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext.class, - TYPE_URL_UPSTREAM_TLS_CONTEXT, TYPE_URL_UPSTREAM_TLS_CONTEXT_V2))); + TYPE_URL_UPSTREAM_TLS_CONTEXT, TYPE_URL_UPSTREAM_TLS_CONTEXT_V2), + certProviderInstances)); } catch (InvalidProtocolBufferException | ResourceInvalidException e) { return StructOrError.fromError( "Cluster " + clusterName + ": malformed UpstreamTlsContext: " + e); @@ -1596,6 +1601,7 @@ private static StructOrError parseNonAggregateCluster( protected void handleEdsResponse(String versionInfo, List resources, String nonce) { Map parsedResources = new HashMap<>(resources.size()); Set unpackedResources = new HashSet<>(resources.size()); + Set invalidResources = new HashSet<>(); List errors = new ArrayList<>(); for (int i = 0; i < resources.size(); i++) { @@ -1630,16 +1636,17 @@ protected void handleEdsResponse(String versionInfo, List resources, String } catch (ResourceInvalidException e) { errors.add("EDS response ClusterLoadAssignment '" + clusterName + "' validation error: " + e.getMessage()); + invalidResources.add(clusterName); continue; } parsedResources.put(clusterName, new ParsedResource(edsUpdate, resource)); } - - if (!errors.isEmpty()) { - handleResourcesRejected(ResourceType.EDS, unpackedResources, versionInfo, nonce, errors); - } else { - handleResourcesAccepted(ResourceType.EDS, parsedResources, versionInfo, nonce); - } + getLogger().log( + XdsLogLevel.INFO, "Received EDS Response version {0} nonce {1}. Parsed resources: {2}", + versionInfo, nonce, unpackedResources); + handleResourceUpdate( + ResourceType.EDS, parsedResources, invalidResources, Collections.emptySet(), + versionInfo, nonce, errors); } private static EdsUpdate processClusterLoadAssignment(ClusterLoadAssignment assignment) @@ -2029,43 +2036,67 @@ private void cleanUpResourceTimers() { } } - private void handleResourcesAccepted( - ResourceType type, Map parsedResources, String version, - String nonce) { - ackResponse(type, version, nonce); - + private void handleResourceUpdate( + ResourceType type, Map parsedResources, Set invalidResources, + Set retainedResources, String version, String nonce, List errors) { + String errorDetail = null; + if (errors.isEmpty()) { + checkArgument(invalidResources.isEmpty(), "found invalid resources but missing errors"); + ackResponse(type, version, nonce); + } else { + errorDetail = Joiner.on('\n').join(errors); + getLogger().log(XdsLogLevel.WARNING, + "Failed processing {0} Response version {1} nonce {2}. Errors:\n{3}", + type, version, nonce, errorDetail); + nackResponse(type, nonce, errorDetail); + } long updateTime = timeProvider.currentTimeNanos(); for (Map.Entry entry : getSubscribedResourcesMap(type).entrySet()) { String resourceName = entry.getKey(); ResourceSubscriber subscriber = entry.getValue(); + // Attach error details to the subscribed resources that included in the ADS update. + if (invalidResources.contains(resourceName)) { + subscriber.onRejected(version, updateTime, errorDetail); + } // Notify the watchers. if (parsedResources.containsKey(resourceName)) { subscriber.onData(parsedResources.get(resourceName), version, updateTime); } else if (type == ResourceType.LDS || type == ResourceType.CDS) { + if (subscriber.data != null && invalidResources.contains(resourceName)) { + // Update is rejected but keep using the cached data. + if (type == ResourceType.LDS) { + LdsUpdate ldsUpdate = (LdsUpdate) subscriber.data; + io.grpc.xds.HttpConnectionManager hcm = ldsUpdate.httpConnectionManager(); + if (hcm != null) { + String rdsName = hcm.rdsName(); + if (rdsName != null) { + retainedResources.add(rdsName); + } + } + } else { + CdsUpdate cdsUpdate = (CdsUpdate) subscriber.data; + String edsName = cdsUpdate.edsServiceName(); + if (edsName == null) { + edsName = cdsUpdate.clusterName(); + } + retainedResources.add(edsName); + } + continue; + } // For State of the World services, notify watchers when their watched resource is missing // from the ADS update. subscriber.onAbsent(); } } - } - - private void handleResourcesRejected( - ResourceType type, Set unpackedResourceNames, String version, - String nonce, List errors) { - String errorDetail = Joiner.on('\n').join(errors); - getLogger().log(XdsLogLevel.WARNING, - "Failed processing {0} Response version {1} nonce {2}. Errors:\n{3}", - type, version, nonce, errorDetail); - nackResponse(type, nonce, errorDetail); - - long updateTime = timeProvider.currentTimeNanos(); - for (Map.Entry entry : getSubscribedResourcesMap(type).entrySet()) { - String resourceName = entry.getKey(); - ResourceSubscriber subscriber = entry.getValue(); - - // Attach error details to the subscribed resources that included in the ADS update. - if (unpackedResourceNames.contains(resourceName)) { - subscriber.onRejected(version, updateTime, errorDetail); + // LDS/CDS responses represents the state of the world, RDS/EDS resources not referenced in + // LDS/CDS resources should be deleted. + if (type == ResourceType.LDS || type == ResourceType.CDS) { + Map dependentSubscribers = + type == ResourceType.LDS ? rdsResourceSubscribers : edsResourceSubscribers; + for (String resource : dependentSubscribers.keySet()) { + if (!retainedResources.contains(resource)) { + dependentSubscribers.get(resource).onAbsent(); + } } } } diff --git a/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java b/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java index 5beefc3384c..d95361935a7 100644 --- a/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java @@ -31,8 +31,8 @@ import io.grpc.LoadBalancer; import io.grpc.Metadata; import io.grpc.Status; +import io.grpc.internal.ForwardingClientStreamTracer; import io.grpc.internal.ObjectPool; -import io.grpc.util.ForwardingClientStreamTracer; import io.grpc.util.ForwardingLoadBalancerHelper; import io.grpc.util.ForwardingSubchannel; import io.grpc.xds.ClusterImplLoadBalancerProvider.ClusterImplConfig; @@ -69,7 +69,8 @@ final class ClusterImplLoadBalancer extends LoadBalancer { || Boolean.parseBoolean(System.getenv("GRPC_XDS_EXPERIMENTAL_CIRCUIT_BREAKING")); @VisibleForTesting static boolean enableSecurity = - Boolean.parseBoolean(System.getenv("GRPC_XDS_EXPERIMENTAL_SECURITY_SUPPORT")); + Strings.isNullOrEmpty(System.getenv("GRPC_XDS_EXPERIMENTAL_SECURITY_SUPPORT")) + || Boolean.parseBoolean(System.getenv("GRPC_XDS_EXPERIMENTAL_SECURITY_SUPPORT")); private static final Attributes.Key ATTR_CLUSTER_LOCALITY_STATS = Attributes.Key.create("io.grpc.xds.ClusterImplLoadBalancer.clusterLocalityStats"); @@ -329,7 +330,8 @@ public String toString() { } } - private static final class CountingStreamTracerFactory extends ClientStreamTracer.Factory { + private static final class CountingStreamTracerFactory extends + ClientStreamTracer.InternalLimitedInfoFactory { private ClusterLocalityStats stats; private final AtomicLong inFlights; @Nullable diff --git a/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java b/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java index f05e6dd6c9a..aa53d834d3b 100644 --- a/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java +++ b/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java @@ -313,7 +313,7 @@ public String toString() { /** * Corresponds to Envoy proto message {@link io.envoyproxy.envoy.api.v2.listener.FilterChain}. */ - public static final class FilterChain { + static final class FilterChain { // Unique name for the FilterChain. private final String name; // TODO(sanjaypujare): flatten structure by moving FilterChainMatch class members here. diff --git a/xds/src/main/java/io/grpc/xds/FilterChainMatchingProtocolNegotiators.java b/xds/src/main/java/io/grpc/xds/FilterChainMatchingProtocolNegotiators.java new file mode 100644 index 00000000000..b828b862454 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/FilterChainMatchingProtocolNegotiators.java @@ -0,0 +1,446 @@ +/* + * Copyright 2021 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.base.Preconditions.checkNotNull; +import static io.grpc.xds.InternalXdsAttributes.ATTR_DRAIN_GRACE_NANOS; +import static io.grpc.xds.InternalXdsAttributes.ATTR_FILTER_CHAIN_SELECTOR_MANAGER; +import static io.grpc.xds.XdsServerWrapper.ATTR_SERVER_ROUTING_CONFIG; +import static io.grpc.xds.internal.sds.SdsProtocolNegotiators.ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Strings; +import com.google.common.collect.Iterables; +import com.google.protobuf.UInt32Value; +import io.grpc.Attributes; +import io.grpc.internal.ObjectPool; +import io.grpc.netty.GrpcHttp2ConnectionHandler; +import io.grpc.netty.InternalGracefulServerCloseCommand; +import io.grpc.netty.InternalProtocolNegotiationEvent; +import io.grpc.netty.InternalProtocolNegotiator; +import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator; +import io.grpc.netty.ProtocolNegotiationEvent; +import io.grpc.xds.EnvoyServerProtoData.CidrRange; +import io.grpc.xds.EnvoyServerProtoData.ConnectionSourceType; +import io.grpc.xds.EnvoyServerProtoData.FilterChain; +import io.grpc.xds.EnvoyServerProtoData.FilterChainMatch; +import io.grpc.xds.FilterChainMatchingProtocolNegotiators.FilterChainMatchingHandler.FilterChainSelector; +import io.grpc.xds.XdsServerWrapper.ServerRoutingConfig; +import io.grpc.xds.internal.Matchers.CidrMatcher; +import io.grpc.xds.internal.sds.SslContextProviderSupplier; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.util.AsciiString; +import java.net.Inet6Address; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.Executor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import java.util.logging.Level; +import java.util.logging.Logger; +import javax.annotation.Nullable; + + +/** + * Handles L4 filter chain match for the connection based on the xds configuration. + * */ +@SuppressWarnings("FutureReturnValueIgnored") // Netty doesn't follow this pattern +final class FilterChainMatchingProtocolNegotiators { + private static final Logger log = Logger.getLogger( + FilterChainMatchingProtocolNegotiators.class.getName()); + + private static final AsciiString SCHEME = AsciiString.of("http"); + + private FilterChainMatchingProtocolNegotiators() { + } + + @VisibleForTesting + static final class FilterChainMatchingHandler extends ChannelInboundHandlerAdapter { + + private final GrpcHttp2ConnectionHandler grpcHandler; + private final FilterChainSelectorManager filterChainSelectorManager; + private final ProtocolNegotiator delegate; + + FilterChainMatchingHandler( + GrpcHttp2ConnectionHandler grpcHandler, + FilterChainSelectorManager filterChainSelectorManager, + ProtocolNegotiator delegate) { + this.grpcHandler = checkNotNull(grpcHandler, "grpcHandler"); + this.filterChainSelectorManager = + checkNotNull(filterChainSelectorManager, "filterChainSelectorManager"); + this.delegate = checkNotNull(delegate, "delegate"); + } + + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + if (!(evt instanceof ProtocolNegotiationEvent)) { + super.userEventTriggered(ctx, evt); + return; + } + long drainGraceTime = 0; + TimeUnit drainGraceTimeUnit = null; + Long drainGraceNanosObj = grpcHandler.getEagAttributes().get(ATTR_DRAIN_GRACE_NANOS); + if (drainGraceNanosObj != null) { + drainGraceTime = drainGraceNanosObj; + drainGraceTimeUnit = TimeUnit.NANOSECONDS; + } + FilterChainSelectorManager.Closer closer = new FilterChainSelectorManager.Closer( + new GracefullyShutdownChannelRunnable(ctx.channel(), drainGraceTime, drainGraceTimeUnit)); + FilterChainSelector selector = filterChainSelectorManager.register(closer); + ctx.channel().closeFuture().addListener( + new FilterChainSelectorManagerDeregister(filterChainSelectorManager, closer)); + checkNotNull(selector, "selector"); + SelectedConfig config = selector.select( + (InetSocketAddress) ctx.channel().localAddress(), + (InetSocketAddress) ctx.channel().remoteAddress()); + if (config == null) { + log.log(Level.WARNING, "Connection from {0} to {1} has no matching filter chain. Closing", + new Object[] {ctx.channel().remoteAddress(), ctx.channel().localAddress()}); + ctx.close().addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE); + return; + } + ProtocolNegotiationEvent pne = (ProtocolNegotiationEvent) evt; + // TODO(zivy): merge into one key and take care of this outer class visibility. + Attributes attr = InternalProtocolNegotiationEvent.getAttributes(pne).toBuilder() + .set(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER, config.sslContextProviderSupplier) + .set(ATTR_SERVER_ROUTING_CONFIG, config.routingConfig) + .build(); + pne = InternalProtocolNegotiationEvent.withAttributes(pne, attr); + ctx.pipeline().replace(this, null, delegate.newHandler(grpcHandler)); + ctx.fireUserEventTriggered(pne); + } + + static final class FilterChainSelector { + public static final FilterChainSelector NO_FILTER_CHAIN = new FilterChainSelector( + Collections.>emptyMap(), + null, new AtomicReference()); + private final Map> routingConfigs; + @Nullable + private final SslContextProviderSupplier defaultSslContextProviderSupplier; + @Nullable + private final AtomicReference defaultRoutingConfig; + + FilterChainSelector(Map> routingConfigs, + @Nullable SslContextProviderSupplier defaultSslContextProviderSupplier, + @Nullable AtomicReference defaultRoutingConfig) { + this.routingConfigs = checkNotNull(routingConfigs, "routingConfigs"); + this.defaultSslContextProviderSupplier = defaultSslContextProviderSupplier; + this.defaultRoutingConfig = checkNotNull(defaultRoutingConfig, "defaultRoutingConfig"); + } + + @VisibleForTesting + Map> getRoutingConfigs() { + return routingConfigs; + } + + @VisibleForTesting + AtomicReference getDefaultRoutingConfig() { + return defaultRoutingConfig; + } + + @VisibleForTesting + SslContextProviderSupplier getDefaultSslContextProviderSupplier() { + return defaultSslContextProviderSupplier; + } + + /** + * Throws IllegalStateException when no exact one match, and we should close the connection. + */ + SelectedConfig select(InetSocketAddress localAddr, InetSocketAddress remoteAddr) { + Collection filterChains = routingConfigs.keySet(); + filterChains = filterOnDestinationPort(filterChains); + filterChains = filterOnIpAddress(filterChains, localAddr.getAddress(), true); + filterChains = filterOnServerNames(filterChains); + filterChains = filterOnTransportProtocol(filterChains); + filterChains = filterOnApplicationProtocols(filterChains); + filterChains = + filterOnSourceType(filterChains, remoteAddr.getAddress(), localAddr.getAddress()); + filterChains = filterOnIpAddress(filterChains, remoteAddr.getAddress(), false); + filterChains = filterOnSourcePort(filterChains, remoteAddr.getPort()); + + if (filterChains.size() > 1) { + throw new IllegalStateException("Found more than one matching filter chains. This should " + + "not be possible as ClientXdsClient validated the chains for uniqueness."); + } + if (filterChains.size() == 1) { + FilterChain selected = Iterables.getOnlyElement(filterChains); + return new SelectedConfig( + routingConfigs.get(selected), selected.getSslContextProviderSupplier()); + } + if (defaultRoutingConfig.get() != null) { + return new SelectedConfig(defaultRoutingConfig, defaultSslContextProviderSupplier); + } + return null; + } + + // reject if filer-chain-match has non-empty application_protocols + private static Collection filterOnApplicationProtocols( + Collection filterChains) { + ArrayList filtered = new ArrayList<>(filterChains.size()); + for (FilterChain filterChain : filterChains) { + FilterChainMatch filterChainMatch = filterChain.getFilterChainMatch(); + + if (filterChainMatch.getApplicationProtocols().isEmpty()) { + filtered.add(filterChain); + } + } + return filtered; + } + + // reject if filer-chain-match has non-empty transport protocol other than "raw_buffer" + private static Collection filterOnTransportProtocol( + Collection filterChains) { + ArrayList filtered = new ArrayList<>(filterChains.size()); + for (FilterChain filterChain : filterChains) { + FilterChainMatch filterChainMatch = filterChain.getFilterChainMatch(); + + String transportProtocol = filterChainMatch.getTransportProtocol(); + if (Strings.isNullOrEmpty(transportProtocol) || "raw_buffer".equals(transportProtocol)) { + filtered.add(filterChain); + } + } + return filtered; + } + + // reject if filer-chain-match has server_name(s) + private static Collection filterOnServerNames( + Collection filterChains) { + ArrayList filtered = new ArrayList<>(filterChains.size()); + for (FilterChain filterChain : filterChains) { + FilterChainMatch filterChainMatch = filterChain.getFilterChainMatch(); + + if (filterChainMatch.getServerNames().isEmpty()) { + filtered.add(filterChain); + } + } + return filtered; + } + + // destination_port present => Always fail match + private static Collection filterOnDestinationPort( + Collection filterChains) { + ArrayList filtered = new ArrayList<>(filterChains.size()); + for (FilterChain filterChain : filterChains) { + FilterChainMatch filterChainMatch = filterChain.getFilterChainMatch(); + + if (filterChainMatch.getDestinationPort() + == UInt32Value.getDefaultInstance().getValue()) { + filtered.add(filterChain); + } + } + return filtered; + } + + private static Collection filterOnSourcePort( + Collection filterChains, int sourcePort) { + ArrayList filteredOnMatch = new ArrayList<>(filterChains.size()); + ArrayList filteredOnEmpty = new ArrayList<>(filterChains.size()); + for (FilterChain filterChain : filterChains) { + FilterChainMatch filterChainMatch = filterChain.getFilterChainMatch(); + + List sourcePortsToMatch = filterChainMatch.getSourcePorts(); + if (sourcePortsToMatch.isEmpty()) { + filteredOnEmpty.add(filterChain); + } else if (sourcePortsToMatch.contains(sourcePort)) { + filteredOnMatch.add(filterChain); + } + } + // match against source port is more specific than match against empty list + return filteredOnMatch.isEmpty() ? filteredOnEmpty : filteredOnMatch; + } + + private static Collection filterOnSourceType( + Collection filterChains, InetAddress sourceAddress, + InetAddress destAddress) { + ArrayList filtered = new ArrayList<>(filterChains.size()); + for (FilterChain filterChain : filterChains) { + FilterChainMatch filterChainMatch = filterChain.getFilterChainMatch(); + ConnectionSourceType sourceType = + filterChainMatch.getConnectionSourceType(); + + boolean matching = false; + if (sourceType == ConnectionSourceType.SAME_IP_OR_LOOPBACK) { + matching = + sourceAddress.isLoopbackAddress() + || sourceAddress.isAnyLocalAddress() + || sourceAddress.equals(destAddress); + } else if (sourceType == ConnectionSourceType.EXTERNAL) { + matching = !sourceAddress.isLoopbackAddress() && !sourceAddress.isAnyLocalAddress(); + } else { // ANY or null + matching = true; + } + if (matching) { + filtered.add(filterChain); + } + } + return filtered; + } + + private static int getMatchingPrefixLength( + FilterChainMatch filterChainMatch, InetAddress address, boolean forDestination) { + boolean isIPv6 = address instanceof Inet6Address; + List cidrRanges = + forDestination + ? filterChainMatch.getPrefixRanges() + : filterChainMatch.getSourcePrefixRanges(); + int matchingPrefixLength; + if (cidrRanges.isEmpty()) { // if there is no CidrRange assume 0-length match + matchingPrefixLength = 0; + } else { + matchingPrefixLength = -1; + for (CidrRange cidrRange : cidrRanges) { + InetAddress cidrAddr = cidrRange.getAddressPrefix(); + boolean cidrIsIpv6 = cidrAddr instanceof Inet6Address; + if (isIPv6 == cidrIsIpv6) { + int prefixLen = cidrRange.getPrefixLen(); + CidrMatcher matcher = CidrMatcher.create(cidrAddr, prefixLen); + if (matcher.matches(address) && prefixLen > matchingPrefixLength) { + matchingPrefixLength = prefixLen; + } + } + } + } + return matchingPrefixLength; + } + + // use prefix_ranges (CIDR) and get the most specific matches + private static Collection filterOnIpAddress( + Collection filterChains, InetAddress address, boolean forDestination) { + // curent list of top ones + ArrayList topOnes = new ArrayList<>(filterChains.size()); + int topMatchingPrefixLen = -1; + for (FilterChain filterChain : filterChains) { + int currentMatchingPrefixLen = getMatchingPrefixLength( + filterChain.getFilterChainMatch(), address, forDestination); + + if (currentMatchingPrefixLen >= 0) { + if (currentMatchingPrefixLen < topMatchingPrefixLen) { + continue; + } + if (currentMatchingPrefixLen > topMatchingPrefixLen) { + topMatchingPrefixLen = currentMatchingPrefixLen; + topOnes.clear(); + } + topOnes.add(filterChain); + } + } + return topOnes; + } + } + } + + static final class FilterChainMatchingNegotiatorServerFactory + implements InternalProtocolNegotiator.ServerFactory { + private final InternalProtocolNegotiator.ServerFactory delegate; + + public FilterChainMatchingNegotiatorServerFactory( + InternalProtocolNegotiator.ServerFactory delegate) { + this.delegate = checkNotNull(delegate, "delegate"); + } + + @Override + public ProtocolNegotiator newNegotiator( + final ObjectPool offloadExecutorPool) { + + class FilterChainMatchingNegotiator implements ProtocolNegotiator { + + @Override + public AsciiString scheme() { + return SCHEME; + } + + @Override + public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { + FilterChainSelectorManager filterChainSelectorManager = + grpcHandler.getEagAttributes().get(ATTR_FILTER_CHAIN_SELECTOR_MANAGER); + checkNotNull(filterChainSelectorManager, "filterChainSelectorManager"); + return new FilterChainMatchingHandler(grpcHandler, filterChainSelectorManager, + delegate.newNegotiator(offloadExecutorPool)); + } + + @Override + public void close() { + } + } + + return new FilterChainMatchingNegotiator(); + } + } + + /** + * The FilterChain level configuration. + */ + private static final class SelectedConfig { + private final AtomicReference routingConfig; + @Nullable + private final SslContextProviderSupplier sslContextProviderSupplier; + + private SelectedConfig(AtomicReference routingConfig, + @Nullable SslContextProviderSupplier sslContextProviderSupplier) { + this.routingConfig = checkNotNull(routingConfig, "routingConfig"); + this.sslContextProviderSupplier = sslContextProviderSupplier; + } + } + + private static class FilterChainSelectorManagerDeregister implements ChannelFutureListener { + private final FilterChainSelectorManager filterChainSelectorManager; + private final FilterChainSelectorManager.Closer closer; + + public FilterChainSelectorManagerDeregister( + FilterChainSelectorManager filterChainSelectorManager, + FilterChainSelectorManager.Closer closer) { + this.filterChainSelectorManager = + checkNotNull(filterChainSelectorManager, "filterChainSelectorManager"); + this.closer = checkNotNull(closer, "closer"); + } + + @Override public void operationComplete(ChannelFuture future) throws Exception { + filterChainSelectorManager.deregister(closer); + } + } + + private static class GracefullyShutdownChannelRunnable implements Runnable { + private final Channel channel; + private final long drainGraceTime; + @Nullable + private final TimeUnit drainGraceTimeUnit; + + public GracefullyShutdownChannelRunnable( + Channel channel, long drainGraceTime, @Nullable TimeUnit drainGraceTimeUnit) { + this.channel = checkNotNull(channel, "channel"); + this.drainGraceTime = drainGraceTime; + this.drainGraceTimeUnit = drainGraceTimeUnit; + } + + @Override public void run() { + Object gracefulCloseCommand = InternalGracefulServerCloseCommand.create( + "xds_drain", drainGraceTime, drainGraceTimeUnit); + channel.writeAndFlush(gracefulCloseCommand) + .addListener(ChannelFutureListener.CLOSE_ON_FAILURE); + } + } +} diff --git a/xds/src/main/java/io/grpc/xds/FilterChainSelectorManager.java b/xds/src/main/java/io/grpc/xds/FilterChainSelectorManager.java new file mode 100644 index 00000000000..4295d75f59b --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/FilterChainSelectorManager.java @@ -0,0 +1,95 @@ +/* + * Copyright 2021 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; +import io.grpc.xds.FilterChainMatchingProtocolNegotiators.FilterChainMatchingHandler.FilterChainSelector; +import java.util.Comparator; +import java.util.TreeSet; +import java.util.concurrent.atomic.AtomicLong; +import javax.annotation.concurrent.GuardedBy; + +/** + * Maintains the current xDS selector and any resources using that selector. When the selector + * changes, old resources are closed to avoid old config usages. + */ +final class FilterChainSelectorManager { + private static final AtomicLong closerId = new AtomicLong(); + + private final Object lock = new Object(); + @GuardedBy("lock") + private FilterChainSelector selector; + // Avoid HashSet since it does not decrease in size, forming a high water mark. + @GuardedBy("lock") + private TreeSet closers = new TreeSet(new CloserComparator()); + + public FilterChainSelector register(Closer closer) { + synchronized (lock) { + Preconditions.checkState(closers.add(closer), "closer already registered"); + return selector; + } + } + + public void deregister(Closer closer) { + synchronized (lock) { + closers.remove(closer); + } + } + + /** Only safe to be called by code that is responsible for updating the selector. */ + public FilterChainSelector getSelectorToUpdateSelector() { + synchronized (lock) { + return selector; + } + } + + public void updateSelector(FilterChainSelector newSelector) { + TreeSet oldClosers; + synchronized (lock) { + oldClosers = closers; + closers = new TreeSet(closers.comparator()); + selector = newSelector; + } + for (Closer closer : oldClosers) { + closer.closer.run(); + } + } + + @VisibleForTesting + int getRegisterCount() { + synchronized (lock) { + return closers.size(); + } + } + + public static final class Closer { + private final long id = closerId.getAndIncrement(); + private final Runnable closer; + + /** {@code closer} may be run multiple times. */ + public Closer(Runnable closer) { + this.closer = Preconditions.checkNotNull(closer, "closer"); + } + } + + private static class CloserComparator implements Comparator { + @Override public int compare(Closer c1, Closer c2) { + return Long.compare(c1.id, c2.id); + } + } +} diff --git a/xds/src/main/java/io/grpc/xds/GoogleCloudToProdNameResolver.java b/xds/src/main/java/io/grpc/xds/GoogleCloudToProdNameResolver.java index 55ec809772f..2845d0a00e8 100644 --- a/xds/src/main/java/io/grpc/xds/GoogleCloudToProdNameResolver.java +++ b/xds/src/main/java/io/grpc/xds/GoogleCloudToProdNameResolver.java @@ -184,11 +184,11 @@ public void run() { ImmutableMap.of("TRAFFICDIRECTOR_DIRECTPATH_C2P_IPV6_CAPABLE", true)); } ImmutableMap.Builder serverBuilder = ImmutableMap.builder(); - String server_uri = "directpath-trafficdirector.googleapis.com"; + String serverUri = "directpath-pa.googleapis.com"; if (serverUriOverride != null && serverUriOverride.length() > 0) { - server_uri = serverUriOverride; + serverUri = serverUriOverride; } - serverBuilder.put("server_uri", server_uri); + serverBuilder.put("server_uri", serverUri); serverBuilder.put("channel_creds", ImmutableList.of(ImmutableMap.of("type", "google_default"))); serverBuilder.put("server_features", ImmutableList.of("xds_v3")); diff --git a/xds/src/main/java/io/grpc/xds/InternalXdsAttributes.java b/xds/src/main/java/io/grpc/xds/InternalXdsAttributes.java index efeee0758a3..410a64df9ca 100644 --- a/xds/src/main/java/io/grpc/xds/InternalXdsAttributes.java +++ b/xds/src/main/java/io/grpc/xds/InternalXdsAttributes.java @@ -75,5 +75,18 @@ public final class InternalXdsAttributes { static final Attributes.Key ATTR_SERVER_WEIGHT = Attributes.Key.create("io.grpc.xds.InternalXdsAttributes.serverWeight"); + /** + * Filter chain match for network filters. + */ + @Grpc.TransportAttr + static final Attributes.Key + ATTR_FILTER_CHAIN_SELECTOR_MANAGER = Attributes.Key.create( + "io.grpc.xds.InternalXdsAttributes.filterChainSelectorManager"); + + /** Grace time to use when draining. Null for an infinite grace time. */ + @Grpc.TransportAttr + static final Attributes.Key ATTR_DRAIN_GRACE_NANOS = + Attributes.Key.create("io.grpc.xds.InternalXdsAttributes.drainGraceTime"); + private InternalXdsAttributes() {} } diff --git a/xds/src/main/java/io/grpc/xds/LameFilter.java b/xds/src/main/java/io/grpc/xds/LameFilter.java deleted file mode 100644 index 4dd1d3c96ed..00000000000 --- a/xds/src/main/java/io/grpc/xds/LameFilter.java +++ /dev/null @@ -1,121 +0,0 @@ -/* - * Copyright 2021 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.xds; - -import com.google.common.util.concurrent.MoreExecutors; -import com.google.protobuf.Message; -import io.grpc.CallOptions; -import io.grpc.Channel; -import io.grpc.ClientCall; -import io.grpc.ClientInterceptor; -import io.grpc.Context; -import io.grpc.LoadBalancer.PickSubchannelArgs; -import io.grpc.Metadata; -import io.grpc.MethodDescriptor; -import io.grpc.Status; -import io.grpc.xds.Filter.ClientInterceptorBuilder; -import java.util.concurrent.Executor; -import java.util.concurrent.ScheduledExecutorService; -import javax.annotation.Nullable; - -/** - * A filter that fails all RPCs. To be added to the end of filter chain if RouterFilter is absent. - */ -enum LameFilter implements Filter, ClientInterceptorBuilder { - INSTANCE; - - static final FilterConfig LAME_CONFIG = new FilterConfig() { - @Override - public String typeUrl() { - throw new UnsupportedOperationException("shouldn't be called"); - } - - @Override - public String toString() { - return "LAME_CONFIG"; - } - }; - - @Override - public String[] typeUrls() { - return new String[0]; - } - - @Override - public ConfigOrError parseFilterConfig(Message rawProtoMessage) { - throw new UnsupportedOperationException(); - } - - @Override - public ConfigOrError parseFilterConfigOverride(Message rawProtoMessage) { - throw new UnsupportedOperationException(); - } - - @Nullable - @Override - public ClientInterceptor buildClientInterceptor( - FilterConfig config, @Nullable FilterConfig overrideConfig, PickSubchannelArgs args, - ScheduledExecutorService scheduler) { - class LameInterceptor implements ClientInterceptor { - - @Override - public ClientCall interceptCall( - MethodDescriptor method, final CallOptions callOptions, Channel next) { - final Context context = Context.current(); - return new ClientCall() { - @Override - public void start(final Listener listener, Metadata headers) { - Executor callExecutor = callOptions.getExecutor(); - if (callExecutor == null) { // This should never happen in practice because - // ManagedChannelImpl.ConfigSelectingClientCall always provides CallOptions with - // a callExecutor. - // TODO(https://github.com/grpc/grpc-java/issues/7868) - callExecutor = MoreExecutors.directExecutor(); - } - callExecutor.execute( - new Runnable() { - @Override - public void run() { - Context previous = context.attach(); - try { - listener.onClose( - Status.UNAVAILABLE.withDescription("No router filter"), new Metadata()); - } finally { - context.detach(previous); - } - } - }); - } - - @Override - public void request(int numMessages) {} - - @Override - public void cancel(@Nullable String message, @Nullable Throwable cause) {} - - @Override - public void halfClose() {} - - @Override - public void sendMessage(ReqT message) {} - }; - } - } - - return new LameInterceptor(); - } -} diff --git a/xds/src/main/java/io/grpc/xds/OrcaPerRequestUtil.java b/xds/src/main/java/io/grpc/xds/OrcaPerRequestUtil.java index c193f5e35e5..156d53f638e 100644 --- a/xds/src/main/java/io/grpc/xds/OrcaPerRequestUtil.java +++ b/xds/src/main/java/io/grpc/xds/OrcaPerRequestUtil.java @@ -25,8 +25,8 @@ import io.grpc.ClientStreamTracer.StreamInfo; import io.grpc.LoadBalancer; import io.grpc.Metadata; +import io.grpc.internal.ForwardingClientStreamTracer; import io.grpc.protobuf.ProtoUtils; -import io.grpc.util.ForwardingClientStreamTracer; import java.util.ArrayList; import java.util.List; @@ -37,7 +37,7 @@ abstract class OrcaPerRequestUtil { private static final ClientStreamTracer NOOP_CLIENT_STREAM_TRACER = new ClientStreamTracer() {}; private static final ClientStreamTracer.Factory NOOP_CLIENT_STREAM_TRACER_FACTORY = - new ClientStreamTracer.Factory() { + new ClientStreamTracer.InternalLimitedInfoFactory() { @Override public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata headers) { return NOOP_CLIENT_STREAM_TRACER; @@ -189,7 +189,8 @@ public interface OrcaPerRequestReportListener { * per-request ORCA reports and push to registered listeners for calls they trace. */ @VisibleForTesting - static final class OrcaReportingTracerFactory extends ClientStreamTracer.Factory { + static final class OrcaReportingTracerFactory extends + ClientStreamTracer.InternalLimitedInfoFactory { @VisibleForTesting static final Metadata.Key ORCA_ENDPOINT_LOAD_METRICS_KEY = diff --git a/xds/src/main/java/io/grpc/xds/RbacFilter.java b/xds/src/main/java/io/grpc/xds/RbacFilter.java index 48b4954767a..39f91b475ae 100644 --- a/xds/src/main/java/io/grpc/xds/RbacFilter.java +++ b/xds/src/main/java/io/grpc/xds/RbacFilter.java @@ -28,6 +28,7 @@ import io.envoyproxy.envoy.config.rbac.v3.Principal; import io.envoyproxy.envoy.extensions.filters.http.rbac.v3.RBAC; import io.envoyproxy.envoy.extensions.filters.http.rbac.v3.RBACPerRoute; +import io.envoyproxy.envoy.type.v3.Int32Range; import io.grpc.Metadata; import io.grpc.ServerCall; import io.grpc.ServerCallHandler; @@ -45,6 +46,7 @@ import io.grpc.xds.internal.rbac.engine.GrpcAuthorizationEngine.AuthenticatedMatcher; import io.grpc.xds.internal.rbac.engine.GrpcAuthorizationEngine.DestinationIpMatcher; import io.grpc.xds.internal.rbac.engine.GrpcAuthorizationEngine.DestinationPortMatcher; +import io.grpc.xds.internal.rbac.engine.GrpcAuthorizationEngine.DestinationPortRangeMatcher; import io.grpc.xds.internal.rbac.engine.GrpcAuthorizationEngine.InvertMatcher; import io.grpc.xds.internal.rbac.engine.GrpcAuthorizationEngine.Matcher; import io.grpc.xds.internal.rbac.engine.GrpcAuthorizationEngine.OrMatcher; @@ -216,6 +218,8 @@ private static Matcher parsePermission(Permission permission) { return createDestinationIpMatcher(permission.getDestinationIp()); case DESTINATION_PORT: return createDestinationPortMatcher(permission.getDestinationPort()); + case DESTINATION_PORT_RANGE: + return parseDestinationPortRangeMatcher(permission.getDestinationPortRange()); case NOT_RULE: return new InvertMatcher(parsePermission(permission.getNotRule())); case METADATA: // hard coded, never match. @@ -291,6 +295,14 @@ private static RequestedServerNameMatcher parseRequestedServerNameMatcher( private static AuthHeaderMatcher parseHeaderMatcher( io.envoyproxy.envoy.config.route.v3.HeaderMatcher proto) { + if (proto.getName().startsWith("grpc-")) { + throw new IllegalArgumentException("Invalid header matcher config: [grpc-] prefixed " + + "header name is not allowed."); + } + if (":scheme".equals(proto.getName())) { + throw new IllegalArgumentException("Invalid header matcher config: header name [:scheme] " + + "is not allowed."); + } return new AuthHeaderMatcher(MatcherParser.parseHeaderMatcher(proto)); } @@ -304,6 +316,10 @@ private static DestinationPortMatcher createDestinationPortMatcher(int port) { return new DestinationPortMatcher(port); } + private static DestinationPortRangeMatcher parseDestinationPortRangeMatcher(Int32Range range) { + return new DestinationPortRangeMatcher(range.getStart(), range.getEnd()); + } + private static DestinationIpMatcher createDestinationIpMatcher(CidrRange cidrRange) { return new DestinationIpMatcher(Matchers.CidrMatcher.create( resolve(cidrRange), cidrRange.getPrefixLen().getValue())); diff --git a/xds/src/main/java/io/grpc/xds/RingHashLoadBalancerProvider.java b/xds/src/main/java/io/grpc/xds/RingHashLoadBalancerProvider.java index af613b26078..102a18000a9 100644 --- a/xds/src/main/java/io/grpc/xds/RingHashLoadBalancerProvider.java +++ b/xds/src/main/java/io/grpc/xds/RingHashLoadBalancerProvider.java @@ -17,6 +17,7 @@ package io.grpc.xds; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Strings; import io.grpc.Internal; import io.grpc.LoadBalancer; import io.grpc.LoadBalancer.Helper; @@ -45,7 +46,8 @@ public final class RingHashLoadBalancerProvider extends LoadBalancerProvider { static final long MAX_RING_SIZE = 8 * 1024 * 1024L; private static final boolean enableRingHash = - Boolean.parseBoolean(System.getenv("GRPC_XDS_EXPERIMENTAL_ENABLE_RING_HASH")); + Strings.isNullOrEmpty(System.getenv("GRPC_XDS_EXPERIMENTAL_ENABLE_RING_HASH")) + || Boolean.parseBoolean(System.getenv("GRPC_XDS_EXPERIMENTAL_ENABLE_RING_HASH")); @Override public LoadBalancer newLoadBalancer(Helper helper) { diff --git a/xds/src/main/java/io/grpc/xds/RoutingUtils.java b/xds/src/main/java/io/grpc/xds/RoutingUtils.java new file mode 100644 index 00000000000..8bf879f43b0 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/RoutingUtils.java @@ -0,0 +1,219 @@ +/* + * Copyright 2021 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.base.Preconditions.checkArgument; + +import com.google.common.base.Joiner; +import io.grpc.Metadata; +import io.grpc.xds.VirtualHost.Route.RouteMatch; +import io.grpc.xds.VirtualHost.Route.RouteMatch.PathMatcher; +import io.grpc.xds.internal.Matchers.FractionMatcher; +import io.grpc.xds.internal.Matchers.HeaderMatcher; +import java.util.List; +import java.util.Locale; +import javax.annotation.Nullable; + +/** + * Utilities for performing virtual host domain name matching and route matching. + */ +// TODO(chengyuanzhang): clean up implementations in XdsNameResolver. +final class RoutingUtils { + // Prevent instantiation. + private RoutingUtils() { + } + + /** + * Returns the {@link VirtualHost} with the best match domain for the given hostname. + */ + @Nullable + static VirtualHost findVirtualHostForHostName(List virtualHosts, String hostName) { + // Domain search order: + // 1. Exact domain names: ``www.foo.com``. + // 2. Suffix domain wildcards: ``*.foo.com`` or ``*-bar.foo.com``. + // 3. Prefix domain wildcards: ``foo.*`` or ``foo-*``. + // 4. Special wildcard ``*`` matching any domain. + // + // The longest wildcards match first. + // Assuming only a single virtual host in the entire route configuration can match + // on ``*`` and a domain must be unique across all virtual hosts. + int matchingLen = -1; // longest length of wildcard pattern that matches host name + boolean exactMatchFound = false; // true if a virtual host with exactly matched domain found + VirtualHost targetVirtualHost = null; // target VirtualHost with longest matched domain + for (VirtualHost vHost : virtualHosts) { + for (String domain : vHost.domains()) { + boolean selected = false; + if (matchHostName(hostName, domain)) { // matching + if (!domain.contains("*")) { // exact matching + exactMatchFound = true; + targetVirtualHost = vHost; + break; + } else if (domain.length() > matchingLen) { // longer matching pattern + selected = true; + } else if (domain.length() == matchingLen && domain.startsWith("*")) { // suffix matching + selected = true; + } + } + if (selected) { + matchingLen = domain.length(); + targetVirtualHost = vHost; + } + } + if (exactMatchFound) { + break; + } + } + return targetVirtualHost; + } + + /** + * Returns {@code true} iff {@code hostName} matches the domain name {@code pattern} with + * case-insensitive. + * + *

Wildcard pattern rules: + *

    + *
  1. A single asterisk (*) matches any domain.
  2. + *
  3. Asterisk (*) is only permitted in the left-most or the right-most part of the pattern, + * but not both.
  4. + *
+ */ + private static boolean matchHostName(String hostName, String pattern) { + checkArgument(hostName.length() != 0 && !hostName.startsWith(".") && !hostName.endsWith("."), + "Invalid host name"); + checkArgument(pattern.length() != 0 && !pattern.startsWith(".") && !pattern.endsWith("."), + "Invalid pattern/domain name"); + + hostName = hostName.toLowerCase(Locale.US); + pattern = pattern.toLowerCase(Locale.US); + // hostName and pattern are now in lower case -- domain names are case-insensitive. + + if (!pattern.contains("*")) { + // Not a wildcard pattern -- hostName and pattern must match exactly. + return hostName.equals(pattern); + } + // Wildcard pattern + + if (pattern.length() == 1) { + return true; + } + + int index = pattern.indexOf('*'); + + // At most one asterisk (*) is allowed. + if (pattern.indexOf('*', index + 1) != -1) { + return false; + } + + // Asterisk can only match prefix or suffix. + if (index != 0 && index != pattern.length() - 1) { + return false; + } + + // HostName must be at least as long as the pattern because asterisk has to + // match one or more characters. + if (hostName.length() < pattern.length()) { + return false; + } + + if (index == 0 && hostName.endsWith(pattern.substring(1))) { + // Prefix matching fails. + return true; + } + + // Pattern matches hostname if suffix matching succeeds. + return index == pattern.length() - 1 + && hostName.startsWith(pattern.substring(0, pattern.length() - 1)); + } + + /** + * Returns {@code true} iff the given {@link RouteMatch} matches the RPC's full method name and + * headers. + */ + static boolean matchRoute(RouteMatch routeMatch, String fullMethodName, + Metadata headers, ThreadSafeRandom random) { + if (!matchPath(routeMatch.pathMatcher(), fullMethodName)) { + return false; + } + for (HeaderMatcher headerMatcher : routeMatch.headerMatchers()) { + if (!matchHeader(headerMatcher, getHeaderValue(headers, headerMatcher.name()))) { + return false; + } + } + FractionMatcher fraction = routeMatch.fractionMatcher(); + return fraction == null || random.nextInt(fraction.denominator()) < fraction.numerator(); + } + + private static boolean matchPath(PathMatcher pathMatcher, String fullMethodName) { + if (pathMatcher.path() != null) { + return pathMatcher.caseSensitive() + ? pathMatcher.path().equals(fullMethodName) + : pathMatcher.path().equalsIgnoreCase(fullMethodName); + } else if (pathMatcher.prefix() != null) { + return pathMatcher.caseSensitive() + ? fullMethodName.startsWith(pathMatcher.prefix()) + : fullMethodName.toLowerCase().startsWith(pathMatcher.prefix().toLowerCase()); + } + return pathMatcher.regEx().matches(fullMethodName); + } + + private static boolean matchHeader(HeaderMatcher headerMatcher, @Nullable String value) { + if (headerMatcher.present() != null) { + return (value == null) == headerMatcher.present().equals(headerMatcher.inverted()); + } + if (value == null) { + return false; + } + boolean baseMatch; + if (headerMatcher.exactValue() != null) { + baseMatch = headerMatcher.exactValue().equals(value); + } else if (headerMatcher.safeRegEx() != null) { + baseMatch = headerMatcher.safeRegEx().matches(value); + } else if (headerMatcher.range() != null) { + long numValue; + try { + numValue = Long.parseLong(value); + baseMatch = numValue >= headerMatcher.range().start() + && numValue <= headerMatcher.range().end(); + } catch (NumberFormatException ignored) { + baseMatch = false; + } + } else if (headerMatcher.prefix() != null) { + baseMatch = value.startsWith(headerMatcher.prefix()); + } else { + baseMatch = value.endsWith(headerMatcher.suffix()); + } + return baseMatch != headerMatcher.inverted(); + } + + @Nullable + private static String getHeaderValue(Metadata headers, String headerName) { + if (headerName.endsWith(Metadata.BINARY_HEADER_SUFFIX)) { + return null; + } + if (headerName.equals("content-type")) { + return "application/grpc"; + } + Metadata.Key key; + try { + key = Metadata.Key.of(headerName, Metadata.ASCII_STRING_MARSHALLER); + } catch (IllegalArgumentException e) { + return null; + } + Iterable values = headers.getAll(key); + return values == null ? null : Joiner.on(",").join(values); + } +} diff --git a/xds/src/main/java/io/grpc/xds/SharedCallCounterMap.java b/xds/src/main/java/io/grpc/xds/SharedCallCounterMap.java index 71cff0cedf3..7aa55c27429 100644 --- a/xds/src/main/java/io/grpc/xds/SharedCallCounterMap.java +++ b/xds/src/main/java/io/grpc/xds/SharedCallCounterMap.java @@ -58,8 +58,14 @@ public synchronized AtomicLong getOrCreate(String cluster, @Nullable String edsS counters.put(cluster, clusterCounters); } CounterReference ref = clusterCounters.get(edsServiceName); - AtomicLong counter; - if (ref == null || (counter = ref.get()) == null) { + AtomicLong counter = null; + if (ref != null) { + counter = ref.get(); + if (counter == null) { + ref.enqueue(); + } + } + if (counter == null) { counter = new AtomicLong(); ref = new CounterReference(counter, refQueue, cluster, edsServiceName); clusterCounters.put(edsServiceName, ref); @@ -73,6 +79,9 @@ void cleanQueue() { CounterReference ref; while ((ref = (CounterReference) refQueue.poll()) != null) { Map clusterCounter = counters.get(ref.cluster); + if (clusterCounter.get(ref.edsServiceName) != ref) { + continue; + } clusterCounter.remove(ref.edsServiceName); if (clusterCounter.isEmpty()) { counters.remove(ref.cluster); diff --git a/xds/src/main/java/io/grpc/xds/SharedXdsClientPoolProvider.java b/xds/src/main/java/io/grpc/xds/SharedXdsClientPoolProvider.java index cc87b9c6b6f..95eef3e3d80 100644 --- a/xds/src/main/java/io/grpc/xds/SharedXdsClientPoolProvider.java +++ b/xds/src/main/java/io/grpc/xds/SharedXdsClientPoolProvider.java @@ -52,7 +52,7 @@ final class SharedXdsClientPoolProvider implements XdsClientPoolFactory { private final AtomicReference> bootstrapOverride = new AtomicReference<>(); private volatile ObjectPool xdsClientPool; - private SharedXdsClientPoolProvider() { + SharedXdsClientPoolProvider() { this(new BootstrapperImpl()); } diff --git a/xds/src/main/java/io/grpc/xds/XdsClientWrapperForServerSds.java b/xds/src/main/java/io/grpc/xds/XdsClientWrapperForServerSds.java deleted file mode 100644 index c73c071d951..00000000000 --- a/xds/src/main/java/io/grpc/xds/XdsClientWrapperForServerSds.java +++ /dev/null @@ -1,447 +0,0 @@ -/* - * Copyright 2020 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.xds; - -import static com.google.common.base.Preconditions.checkNotNull; -import static com.google.common.base.Preconditions.checkState; - -import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Strings; -import com.google.common.collect.ImmutableSet; -import com.google.protobuf.UInt32Value; -import io.grpc.Internal; -import io.grpc.Status; -import io.grpc.internal.ObjectPool; -import io.grpc.xds.EnvoyServerProtoData.CidrRange; -import io.grpc.xds.EnvoyServerProtoData.FilterChain; -import io.grpc.xds.EnvoyServerProtoData.FilterChainMatch; -import io.grpc.xds.internal.Matchers.CidrMatcher; -import io.grpc.xds.internal.sds.SslContextProviderSupplier; -import io.netty.channel.Channel; -import java.net.Inet6Address; -import java.net.InetAddress; -import java.net.InetSocketAddress; -import java.net.SocketAddress; -import java.util.ArrayList; -import java.util.HashSet; -import java.util.List; -import java.util.Set; -import java.util.concurrent.atomic.AtomicReference; -import java.util.logging.Level; -import java.util.logging.Logger; -import javax.annotation.Nullable; - -/** - * Serves as a wrapper for {@link XdsClient} used on the server side by {@link - * XdsServerBuilder}. - */ -@Internal -public final class XdsClientWrapperForServerSds { - private static final Logger logger = - Logger.getLogger(XdsClientWrapperForServerSds.class.getName()); - - private AtomicReference curListener = new AtomicReference<>(); - private ObjectPool xdsClientPool; - private final XdsNameResolverProvider.XdsClientPoolFactory xdsClientPoolFactory; - @Nullable private XdsClient xdsClient; - private final int port; - private XdsClient.LdsResourceWatcher listenerWatcher; - private boolean newServerApi; - private String grpcServerResourceId; - @VisibleForTesting final Set serverWatchers = new HashSet<>(); - - /** - * Creates a {@link XdsClientWrapperForServerSds}. - * - * @param port server's port for which listener config is needed. - */ - XdsClientWrapperForServerSds(int port) { - this(port, SharedXdsClientPoolProvider.getDefaultProvider()); - } - - @VisibleForTesting - XdsClientWrapperForServerSds(int port, - XdsNameResolverProvider.XdsClientPoolFactory xdsClientPoolFactory) { - this.port = port; - this.xdsClientPoolFactory = checkNotNull(xdsClientPoolFactory, "xdsClientPoolFactory"); - } - - @VisibleForTesting XdsClient getXdsClient() { - return xdsClient; - } - - public TlsContextManager getTlsContextManager() { - return xdsClient.getTlsContextManager(); - } - - /** Accepts an XdsClient and starts a watch. */ - @VisibleForTesting - public void start() { - try { - xdsClientPool = xdsClientPoolFactory.getOrCreate(); - } catch (Exception e) { - reportError(e, true); - return; - } - xdsClient = xdsClientPool.getObject(); - this.listenerWatcher = - new XdsClient.LdsResourceWatcher() { - @Override - public void onChanged(XdsClient.LdsUpdate update) { - releaseOldSuppliers(curListener.getAndSet(update.listener())); - reportSuccess(); - } - - @Override - public void onResourceDoesNotExist(String resourceName) { - logger.log(Level.WARNING, "Resource {0} is unavailable", resourceName); - releaseOldSuppliers(curListener.getAndSet(null)); - reportError(Status.NOT_FOUND.asException(), true); - } - - @Override - public void onError(Status error) { - logger.log( - Level.WARNING, "LdsResourceWatcher in XdsClientWrapperForServerSds: {0}", error); - if (isResourceAbsent(error)) { - releaseOldSuppliers(curListener.getAndSet(null)); - reportError(error.asException(), true); - } else { - reportError(error.asException(), false); - } - } - }; - grpcServerResourceId = xdsClient.getBootstrapInfo() - .getServerListenerResourceNameTemplate(); - newServerApi = xdsClient.getBootstrapInfo().getServers().get(0).isUseProtocolV3(); - if (newServerApi && grpcServerResourceId == null) { - reportError( - new XdsInitializationException( - "missing server_listener_resource_name_template value in xds bootstrap"), - true); - } - grpcServerResourceId = grpcServerResourceId.replaceAll("%s", "0.0.0.0:" + port); - xdsClient.watchLdsResource(grpcServerResourceId, listenerWatcher); - } - - // go thru the old listener and release all the old SslContextProviderSupplier - private void releaseOldSuppliers(EnvoyServerProtoData.Listener oldListener) { - if (oldListener != null) { - List filterChains = oldListener.getFilterChains(); - for (FilterChain filterChain : filterChains) { - releaseSupplier(filterChain); - } - releaseSupplier(oldListener.getDefaultFilterChain()); - } - } - - private static void releaseSupplier(FilterChain filterChain) { - if (filterChain != null) { - SslContextProviderSupplier sslContextProviderSupplier = - filterChain.getSslContextProviderSupplier(); - if (sslContextProviderSupplier != null) { - sslContextProviderSupplier.close(); - } - } - } - - /** Whether the throwable indicates our listener resource is absent/deleted. */ - private static boolean isResourceAbsent(Status status) { - Status.Code code = status.getCode(); - switch (code) { - case NOT_FOUND: - case INVALID_ARGUMENT: - case PERMISSION_DENIED: // means resource not available for us - case UNIMPLEMENTED: - case UNAUTHENTICATED: // same as above, resource not available for us - return true; - default: - return false; - } - } - - /** - * Locates the best matching FilterChain to the channel from the current listener and if found - * returns the SslContextProviderSupplier from that FilterChain, else null. - */ - @Nullable - public SslContextProviderSupplier getSslContextProviderSupplier(Channel channel) { - EnvoyServerProtoData.Listener copyListener = curListener.get(); - if (copyListener != null && channel != null) { - SocketAddress localAddress = channel.localAddress(); - SocketAddress remoteAddress = channel.remoteAddress(); - if (localAddress instanceof InetSocketAddress && remoteAddress instanceof InetSocketAddress) { - InetSocketAddress localInetAddr = (InetSocketAddress) localAddress; - InetSocketAddress remoteInetAddr = (InetSocketAddress) remoteAddress; - checkState( - port == localInetAddr.getPort(), - "Channel localAddress port does not match requested listener port"); - return getSslContextProviderSupplier(localInetAddr, remoteInetAddr, copyListener); - } - } - return null; - } - - /** - * Using the logic specified at - * https://www.envoyproxy.io/docs/envoy/latest/api-v2/api/v2/listener/listener_components.proto.html?highlight=filter%20chain#listener-filterchainmatch - * locate a matching filter and return the corresponding SslContextProviderSupplier or else - * return one from default filter chain. - * - * @param localInetAddr dest address of the inbound connection - * @param remoteInetAddr source address of the inbound connection - */ - private static SslContextProviderSupplier getSslContextProviderSupplier( - InetSocketAddress localInetAddr, InetSocketAddress remoteInetAddr, - EnvoyServerProtoData.Listener listener) { - List filterChains = listener.getFilterChains(); - - filterChains = filterOnDestinationPort(filterChains); - filterChains = filterOnIpAddress(filterChains, localInetAddr.getAddress(), true); - filterChains = filterOnServerNames(filterChains); - filterChains = filterOnTransportProtocol(filterChains); - filterChains = filterOnApplicationProtocols(filterChains); - filterChains = - filterOnSourceType(filterChains, remoteInetAddr.getAddress(), localInetAddr.getAddress()); - filterChains = filterOnIpAddress(filterChains, remoteInetAddr.getAddress(), false); - filterChains = filterOnSourcePort(filterChains, remoteInetAddr.getPort()); - - if (filterChains.size() > 1) { - // close the connection - throw new IllegalStateException("Found 2 matching filter-chains"); - } else if (filterChains.size() == 1) { - return filterChains.get(0).getSslContextProviderSupplier(); - } - if (listener.getDefaultFilterChain() == null) { - // close the connection - throw new RuntimeException( - "no matching filter chain. local: " + localInetAddr + " remote: " + remoteInetAddr); - } - return listener.getDefaultFilterChain().getSslContextProviderSupplier(); - } - - // reject if filer-chain-match has non-empty application_protocols - private static List filterOnApplicationProtocols(List filterChains) { - ArrayList filtered = new ArrayList<>(filterChains.size()); - for (FilterChain filterChain : filterChains) { - FilterChainMatch filterChainMatch = filterChain.getFilterChainMatch(); - - if (filterChainMatch.getApplicationProtocols().isEmpty()) { - filtered.add(filterChain); - } - } - return filtered; - } - - // reject if filer-chain-match has non-empty transport protocol other than "raw_buffer" - private static List filterOnTransportProtocol(List filterChains) { - ArrayList filtered = new ArrayList<>(filterChains.size()); - for (FilterChain filterChain : filterChains) { - FilterChainMatch filterChainMatch = filterChain.getFilterChainMatch(); - - String transportProtocol = filterChainMatch.getTransportProtocol(); - if ( Strings.isNullOrEmpty(transportProtocol) || "raw_buffer".equals(transportProtocol)) { - filtered.add(filterChain); - } - } - return filtered; - } - - // reject if filer-chain-match has server_name(s) - private static List filterOnServerNames(List filterChains) { - ArrayList filtered = new ArrayList<>(filterChains.size()); - for (FilterChain filterChain : filterChains) { - FilterChainMatch filterChainMatch = filterChain.getFilterChainMatch(); - - if (filterChainMatch.getServerNames().isEmpty()) { - filtered.add(filterChain); - } - } - return filtered; - } - - // destination_port present => Always fail match - private static List filterOnDestinationPort(List filterChains) { - ArrayList filtered = new ArrayList<>(filterChains.size()); - for (FilterChain filterChain : filterChains) { - FilterChainMatch filterChainMatch = filterChain.getFilterChainMatch(); - - if (filterChainMatch.getDestinationPort() == UInt32Value.getDefaultInstance().getValue()) { - filtered.add(filterChain); - } - } - return filtered; - } - - private static List filterOnSourcePort( - List filterChains, int sourcePort) { - ArrayList filteredOnMatch = new ArrayList<>(filterChains.size()); - ArrayList filteredOnEmpty = new ArrayList<>(filterChains.size()); - for (FilterChain filterChain : filterChains) { - FilterChainMatch filterChainMatch = filterChain.getFilterChainMatch(); - - List sourcePortsToMatch = filterChainMatch.getSourcePorts(); - if (sourcePortsToMatch.isEmpty()) { - filteredOnEmpty.add(filterChain); - } else if (sourcePortsToMatch.contains(sourcePort)) { - filteredOnMatch.add(filterChain); - } - } - // match against source port is more specific than match against empty list - return filteredOnMatch.isEmpty() ? filteredOnEmpty : filteredOnMatch; - } - - private static List filterOnSourceType( - List filterChains, InetAddress sourceAddress, InetAddress destAddress) { - ArrayList filtered = new ArrayList<>(filterChains.size()); - for (FilterChain filterChain : filterChains) { - FilterChainMatch filterChainMatch = filterChain.getFilterChainMatch(); - EnvoyServerProtoData.ConnectionSourceType sourceType = - filterChainMatch.getConnectionSourceType(); - - boolean matching = false; - if (sourceType == EnvoyServerProtoData.ConnectionSourceType.SAME_IP_OR_LOOPBACK) { - matching = - sourceAddress.isLoopbackAddress() - || sourceAddress.isAnyLocalAddress() - || sourceAddress.equals(destAddress); - } else if (sourceType == EnvoyServerProtoData.ConnectionSourceType.EXTERNAL) { - matching = !sourceAddress.isLoopbackAddress() && !sourceAddress.isAnyLocalAddress(); - } else { // ANY or null - matching = true; - } - if (matching) { - filtered.add(filterChain); - } - } - return filtered; - } - - private static int getMatchingPrefixLength( - FilterChainMatch filterChainMatch, InetAddress address, boolean forDestination) { - boolean isIPv6 = address instanceof Inet6Address; - List cidrRanges = - forDestination - ? filterChainMatch.getPrefixRanges() - : filterChainMatch.getSourcePrefixRanges(); - int matchingPrefixLength; - if (cidrRanges.isEmpty()) { // if there is no CidrRange assume 0-length match - matchingPrefixLength = 0; - } else { - matchingPrefixLength = -1; - for (CidrRange cidrRange : cidrRanges) { - InetAddress cidrAddr = cidrRange.getAddressPrefix(); - boolean cidrIsIpv6 = cidrAddr instanceof Inet6Address; - if (isIPv6 == cidrIsIpv6) { - int prefixLen = cidrRange.getPrefixLen(); - CidrMatcher matcher = CidrMatcher.create(cidrAddr, prefixLen); - if (matcher.matches(address) && prefixLen > matchingPrefixLength) { - matchingPrefixLength = prefixLen; - } - } - } - } - return matchingPrefixLength; - } - - // use prefix_ranges (CIDR) and get the most specific matches - private static List filterOnIpAddress( - List filterChains, InetAddress address, boolean forDestination) { - // curent list of top ones - ArrayList topOnes = new ArrayList<>(filterChains.size()); - int topMatchingPrefixLen = -1; - for (FilterChain filterChain : filterChains) { - int currentMatchingPrefixLen = - getMatchingPrefixLength(filterChain.getFilterChainMatch(), address, forDestination); - - if (currentMatchingPrefixLen >= 0) { - if (currentMatchingPrefixLen < topMatchingPrefixLen) { - continue; - } - if (currentMatchingPrefixLen > topMatchingPrefixLen) { - topMatchingPrefixLen = currentMatchingPrefixLen; - topOnes.clear(); - } - topOnes.add(filterChain); - } - } - return topOnes; - } - - /** Adds a {@link ServerWatcher} to the list. */ - public void addServerWatcher(ServerWatcher serverWatcher) { - checkNotNull(serverWatcher, "serverWatcher"); - synchronized (serverWatchers) { - serverWatchers.add(serverWatcher); - } - EnvoyServerProtoData.Listener copyListener = curListener.get(); - if (copyListener != null) { - serverWatcher.onListenerUpdate(); - } - } - - /** Removes a {@link ServerWatcher} from the list. */ - public void removeServerWatcher(ServerWatcher serverWatcher) { - checkNotNull(serverWatcher, "serverWatcher"); - synchronized (serverWatchers) { - serverWatchers.remove(serverWatcher); - } - } - - private Set getServerWatchers() { - synchronized (serverWatchers) { - return ImmutableSet.copyOf(serverWatchers); - } - } - - private void reportError(Throwable throwable, boolean isAbsent) { - for (ServerWatcher watcher : getServerWatchers()) { - watcher.onError(throwable, isAbsent); - } - } - - private void reportSuccess() { - for (ServerWatcher watcher : getServerWatchers()) { - watcher.onListenerUpdate(); - } - } - - @VisibleForTesting - public XdsClient.LdsResourceWatcher getListenerWatcher() { - return listenerWatcher; - } - - /** Watcher interface for the clients of this class. */ - public interface ServerWatcher { - - /** Called to report errors from the control plane including "not found". */ - void onError(Throwable throwable, boolean isAbsent); - - /** Called to report successful receipt of listener config. */ - void onListenerUpdate(); - } - - /** Shutdown this instance and release resources. */ - public void shutdown() { - logger.log(Level.FINER, "Shutdown"); - if (xdsClient != null) { - xdsClient.cancelLdsResourceWatch(grpcServerResourceId, listenerWatcher); - xdsClient = xdsClientPool.returnObject(xdsClient); - } - releaseOldSuppliers(curListener.getAndSet(null)); - } -} diff --git a/xds/src/main/java/io/grpc/xds/XdsNameResolver.java b/xds/src/main/java/io/grpc/xds/XdsNameResolver.java index 3ae6346c158..4cd52c8b3f9 100644 --- a/xds/src/main/java/io/grpc/xds/XdsNameResolver.java +++ b/xds/src/main/java/io/grpc/xds/XdsNameResolver.java @@ -23,7 +23,6 @@ import com.google.common.base.Joiner; import com.google.common.base.Strings; import com.google.common.collect.ImmutableMap; -import com.google.common.collect.Iterables; import com.google.common.collect.Sets; import com.google.gson.Gson; import com.google.protobuf.util.Durations; @@ -95,8 +94,6 @@ final class XdsNameResolver extends NameResolver { CallOptions.Key.create("io.grpc.xds.CLUSTER_SELECTION_KEY"); static final CallOptions.Key RPC_HASH_KEY = CallOptions.Key.create("io.grpc.xds.RPC_HASH_KEY"); - private static final NamedFilterConfig LAME_FILTER = - new NamedFilterConfig(null, LameFilter.LAME_CONFIG); @VisibleForTesting static boolean enableTimeout = Strings.isNullOrEmpty(System.getenv("GRPC_XDS_EXPERIMENTAL_ENABLE_TIMEOUT")) @@ -124,22 +121,25 @@ final class XdsNameResolver extends NameResolver { private ResolveState resolveState; XdsNameResolver(String name, ServiceConfigParser serviceConfigParser, - SynchronizationContext syncContext, ScheduledExecutorService scheduler) { + SynchronizationContext syncContext, ScheduledExecutorService scheduler, + @Nullable Map bootstrapOverride) { this(name, serviceConfigParser, syncContext, scheduler, SharedXdsClientPoolProvider.getDefaultProvider(), ThreadSafeRandomImpl.instance, - FilterRegistry.getDefaultRegistry()); + FilterRegistry.getDefaultRegistry(), bootstrapOverride); } @VisibleForTesting XdsNameResolver(String name, ServiceConfigParser serviceConfigParser, SynchronizationContext syncContext, ScheduledExecutorService scheduler, XdsClientPoolFactory xdsClientPoolFactory, ThreadSafeRandom random, - FilterRegistry filterRegistry) { + FilterRegistry filterRegistry, @Nullable Map bootstrapOverride) { authority = GrpcUtil.checkAuthority(checkNotNull(name, "name")); this.serviceConfigParser = checkNotNull(serviceConfigParser, "serviceConfigParser"); this.syncContext = checkNotNull(syncContext, "syncContext"); this.scheduler = checkNotNull(scheduler, "scheduler"); - this.xdsClientPoolFactory = checkNotNull(xdsClientPoolFactory, "xdsClientPoolFactory"); + this.xdsClientPoolFactory = bootstrapOverride == null ? checkNotNull(xdsClientPoolFactory, + "xdsClientPoolFactory") : new SharedXdsClientPoolProvider(); + this.xdsClientPoolFactory.setBootstrapOverride(bootstrapOverride); this.random = checkNotNull(random, "random"); this.filterRegistry = checkNotNull(filterRegistry, "filterRegistry"); logId = InternalLogId.allocate("xds-resolver", name); @@ -182,13 +182,14 @@ public void shutdown() { @VisibleForTesting static Map generateServiceConfigWithMethodConfig( @Nullable Long timeoutNano, @Nullable RetryPolicy retryPolicy) { - if (timeoutNano == null && retryPolicy == null) { + if (timeoutNano == null + && (retryPolicy == null || retryPolicy.retryableStatusCodes().isEmpty())) { return Collections.emptyMap(); } ImmutableMap.Builder methodConfig = ImmutableMap.builder(); methodConfig.put( "name", Collections.singletonList(Collections.emptyMap())); - if (retryPolicy != null) { + if (retryPolicy != null && !retryPolicy.retryableStatusCodes().isEmpty()) { ImmutableMap.Builder rawRetryPolicy = ImmutableMap.builder(); rawRetryPolicy.put("maxAttempts", (double) retryPolicy.maxAttempts()); rawRetryPolicy.put("initialBackoff", Durations.toString(retryPolicy.initialBackoff())); @@ -371,10 +372,6 @@ public Result selectConfig(PickSubchannelArgs args) { do { routingCfg = routingConfig; selectedOverrideConfigs = new HashMap<>(routingCfg.virtualHostOverrideConfig); - if (routingCfg.filterChain != null - && Iterables.getLast(routingCfg.filterChain).equals(LAME_FILTER)) { - break; - } for (Route route : routingCfg.routes) { if (matchRoute(route.routeMatch(), "/" + args.getMethodDescriptor().getFullMethodName(), headers, random)) { @@ -439,12 +436,7 @@ public Result selectConfig(PickSubchannelArgs args) { if (routingCfg.filterChain != null) { for (NamedFilterConfig namedFilter : routingCfg.filterChain) { FilterConfig filterConfig = namedFilter.filterConfig; - Filter filter; - if (namedFilter.equals(LAME_FILTER)) { - filter = LameFilter.INSTANCE; - } else { - filter = filterRegistry.get(filterConfig.typeUrl()); - } + Filter filter = filterRegistry.get(filterConfig.typeUrl()); if (filter instanceof ClientInterceptorBuilder) { ClientInterceptor interceptor = ((ClientInterceptorBuilder) filter) .buildClientInterceptor( @@ -455,12 +447,6 @@ public Result selectConfig(PickSubchannelArgs args) { } } } - if (Iterables.getLast(routingCfg.filterChain).equals(LAME_FILTER)) { - return Result.newBuilder() - .setConfig(config) - .setInterceptor(combineInterceptors(filterInterceptors)) - .build(); - } } final String finalCluster = cluster; final long hash = generateHash(selectedRoute.routeAction().hashPolicies(), headers); @@ -591,7 +577,7 @@ static boolean matchRoute(RouteMatch routeMatch, String fullMethodName, return false; } for (HeaderMatcher headerMatcher : routeMatch.headerMatchers()) { - if (!matchHeader(headerMatcher, getHeaderValue(headers, headerMatcher.name()))) { + if (!headerMatcher.matches(getHeaderValue(headers, headerMatcher.name()))) { return false; } } @@ -612,36 +598,6 @@ private static boolean matchPath(PathMatcher pathMatcher, String fullMethodName) return pathMatcher.regEx().matches(fullMethodName); } - // TODO(zivy): consider reuse Matchers.HeaderMatcher.matches() - private static boolean matchHeader(HeaderMatcher headerMatcher, @Nullable String value) { - if (headerMatcher.present() != null) { - return (value == null) == headerMatcher.present().equals(headerMatcher.inverted()); - } - if (value == null) { - return false; - } - boolean baseMatch; - if (headerMatcher.exactValue() != null) { - baseMatch = headerMatcher.exactValue().equals(value); - } else if (headerMatcher.safeRegEx() != null) { - baseMatch = headerMatcher.safeRegEx().matches(value); - } else if (headerMatcher.range() != null) { - long numValue; - try { - numValue = Long.parseLong(value); - baseMatch = numValue >= headerMatcher.range().start() - && numValue <= headerMatcher.range().end(); - } catch (NumberFormatException ignored) { - baseMatch = false; - } - } else if (headerMatcher.prefix() != null) { - baseMatch = value.startsWith(headerMatcher.prefix()); - } else { - baseMatch = value.endsWith(headerMatcher.suffix()); - } - return baseMatch != headerMatcher.inverted(); - } - @Nullable private static String getHeaderValue(Metadata headers, String headerName) { if (headerName.endsWith(Metadata.BINARY_HEADER_SUFFIX)) { @@ -751,27 +707,7 @@ private void updateRoutes(List virtualHosts, long httpMaxStreamDura return; } - // A router filter is required for request routing. For backward compatibility, routing - // is always enabled for gRPC clients without HttpFilter support. List routes = virtualHost.routes(); - List filterChain = null; - if (filterConfigs != null) { - boolean hasRouter = false; - filterChain = new ArrayList<>(filterConfigs.size()); - for (NamedFilterConfig namedFilter : filterConfigs) { - filterChain.add(namedFilter); - if (namedFilter.filterConfig.equals(RouterFilter.ROUTER_CONFIG)) { - hasRouter = true; - break; - } - } - if (!hasRouter) { - // Fail all RPCs if a router filter is not present. Reference counts for all currently - // selectable clusters should be reclaimed. - filterChain.add(LAME_FILTER); - routes = Collections.emptyList(); - } - } // Populate all clusters to which requests can be routed to through the virtual host. Set clusters = new HashSet<>(); @@ -812,7 +748,7 @@ private void updateRoutes(List virtualHosts, long httpMaxStreamDura // selectable. routingConfig = new RoutingConfig( - httpMaxStreamDurationNano, routes, filterChain, + httpMaxStreamDurationNano, routes, filterConfigs, virtualHost.filterConfigOverrides()); shouldUpdateResult = false; for (String cluster : deletedClusters) { diff --git a/xds/src/main/java/io/grpc/xds/XdsNameResolverProvider.java b/xds/src/main/java/io/grpc/xds/XdsNameResolverProvider.java index 40aa4f919e9..03d88a9752e 100644 --- a/xds/src/main/java/io/grpc/xds/XdsNameResolverProvider.java +++ b/xds/src/main/java/io/grpc/xds/XdsNameResolverProvider.java @@ -42,10 +42,31 @@ public final class XdsNameResolverProvider extends NameResolverProvider { private static final String SCHEME = "xds"; + private final String scheme; + private final Map bootstrapOverride; + + public XdsNameResolverProvider() { + this(SCHEME, null); + } + + private XdsNameResolverProvider(String scheme, + @Nullable Map bootstrapOverride) { + this.scheme = checkNotNull(scheme, "scheme"); + this.bootstrapOverride = bootstrapOverride; + } + + /** + * A convenient method to allow creating a {@link XdsNameResolverProvider} with custom scheme + * and bootstrap. + */ + public static XdsNameResolverProvider createForTest(String scheme, + @Nullable Map bootstrapOverride) { + return new XdsNameResolverProvider(scheme, bootstrapOverride); + } @Override public XdsNameResolver newNameResolver(URI targetUri, Args args) { - if (SCHEME.equals(targetUri.getScheme())) { + if (scheme.equals(targetUri.getScheme())) { String targetPath = checkNotNull(targetUri.getPath(), "targetPath"); Preconditions.checkArgument( targetPath.startsWith("/"), @@ -54,14 +75,15 @@ public XdsNameResolver newNameResolver(URI targetUri, Args args) { targetUri); String name = targetPath.substring(1); return new XdsNameResolver(name, args.getServiceConfigParser(), - args.getSynchronizationContext(), args.getScheduledExecutorService()); + args.getSynchronizationContext(), args.getScheduledExecutorService(), + bootstrapOverride); } return null; } @Override public String getDefaultScheme() { - return SCHEME; + return scheme; } @Override diff --git a/xds/src/main/java/io/grpc/xds/XdsServerBuilder.java b/xds/src/main/java/io/grpc/xds/XdsServerBuilder.java index d201c565caa..c95c1e6d48f 100644 --- a/xds/src/main/java/io/grpc/xds/XdsServerBuilder.java +++ b/xds/src/main/java/io/grpc/xds/XdsServerBuilder.java @@ -16,8 +16,11 @@ package io.grpc.xds; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; +import static io.grpc.xds.InternalXdsAttributes.ATTR_DRAIN_GRACE_NANOS; +import static io.grpc.xds.InternalXdsAttributes.ATTR_FILTER_CHAIN_SELECTOR_MANAGER; import com.google.common.annotations.VisibleForTesting; import com.google.errorprone.annotations.DoNotCall; @@ -29,9 +32,12 @@ import io.grpc.ServerBuilder; import io.grpc.ServerCredentials; import io.grpc.netty.InternalNettyServerBuilder; +import io.grpc.netty.InternalNettyServerCredentials; +import io.grpc.netty.InternalProtocolNegotiator; import io.grpc.netty.NettyServerBuilder; -import io.grpc.xds.internal.sds.SdsProtocolNegotiators; -import io.grpc.xds.internal.sds.ServerWrapperForXds; +import io.grpc.xds.FilterChainMatchingProtocolNegotiators.FilterChainMatchingNegotiatorServerFactory; +import io.grpc.xds.XdsNameResolverProvider.XdsClientPoolFactory; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.logging.Logger; @@ -40,11 +46,17 @@ */ @ExperimentalApi("https://github.com/grpc/grpc-java/issues/7514") public final class XdsServerBuilder extends ForwardingServerBuilder { + private static final long AS_LARGE_AS_INFINITE = TimeUnit.DAYS.toNanos(1000); private final NettyServerBuilder delegate; private final int port; private XdsServingStatusListener xdsServingStatusListener; private AtomicBoolean isServerBuilt = new AtomicBoolean(false); + private final FilterRegistry filterRegistry = FilterRegistry.getDefaultRegistry(); + private XdsClientPoolFactory xdsClientPoolFactory = + SharedXdsClientPoolProvider.getDefaultProvider(); + private long drainGraceTime = 10; + private TimeUnit drainGraceTimeUnit = TimeUnit.MINUTES; private XdsServerBuilder(NettyServerBuilder nettyDelegate, int port) { this.delegate = nettyDelegate; @@ -68,36 +80,60 @@ public XdsServerBuilder xdsServingStatusListener( } /** - * Unsupported call. Users should only use {@link #forPort(int, ServerCredentials)}. + * Sets the grace time when draining connections with outdated configuration. When an xDS config + * update changes connection configuration, pre-existing connections stop accepting new RPCs to be + * replaced by new connections. RPCs on those pre-existing connections have the grace time to + * complete. RPCs that do not complete in time will be cancelled, allowing the connection to + * terminate. {@code Long.MAX_VALUE} nano seconds or an unreasonably large value are considered + * infinite. The default is 10 minutes. */ + public XdsServerBuilder drainGraceTime(long drainGraceTime, TimeUnit drainGraceTimeUnit) { + checkArgument(drainGraceTime >= 0, "drain grace time must be non-negative: %s", + drainGraceTime); + checkNotNull(drainGraceTimeUnit, "drainGraceTimeUnit"); + if (drainGraceTimeUnit.toNanos(drainGraceTime) >= AS_LARGE_AS_INFINITE) { + drainGraceTimeUnit = null; + } + this.drainGraceTime = drainGraceTime; + this.drainGraceTimeUnit = drainGraceTimeUnit; + return this; + } + @DoNotCall("Unsupported. Use forPort(int, ServerCredentials) instead") public static ServerBuilder forPort(int port) { throw new UnsupportedOperationException( - "Unsupported call - use forPort(int, ServerCredentials)"); + "Unsupported call - use forPort(int, ServerCredentials)"); } /** Creates a gRPC server builder for the given port. */ public static XdsServerBuilder forPort(int port, ServerCredentials serverCredentials) { - NettyServerBuilder nettyDelegate = NettyServerBuilder.forPort(port, serverCredentials); + checkNotNull(serverCredentials, "serverCredentials"); + InternalProtocolNegotiator.ServerFactory originalNegotiatorFactory = + InternalNettyServerCredentials.toNegotiator(serverCredentials); + ServerCredentials wrappedCredentials = InternalNettyServerCredentials.create( + new FilterChainMatchingNegotiatorServerFactory(originalNegotiatorFactory)); + NettyServerBuilder nettyDelegate = NettyServerBuilder.forPort(port, wrappedCredentials); return new XdsServerBuilder(nettyDelegate, port); } @Override public Server build() { - return buildServer(new XdsClientWrapperForServerSds(port)); + checkState(isServerBuilt.compareAndSet(false, true), "Server already built!"); + FilterChainSelectorManager filterChainSelectorManager = new FilterChainSelectorManager(); + Attributes.Builder builder = Attributes.newBuilder() + .set(ATTR_FILTER_CHAIN_SELECTOR_MANAGER, filterChainSelectorManager); + if (drainGraceTimeUnit != null) { + builder.set(ATTR_DRAIN_GRACE_NANOS, drainGraceTimeUnit.toNanos(drainGraceTime)); + } + InternalNettyServerBuilder.eagAttributes(delegate, builder.build()); + return new XdsServerWrapper("0.0.0.0:" + port, delegate, xdsServingStatusListener, + filterChainSelectorManager, xdsClientPoolFactory, filterRegistry); } - /** - * Creates a Server using the given xdsClient. - */ @VisibleForTesting - ServerWrapperForXds buildServer( - XdsClientWrapperForServerSds xdsClient) { - checkState(isServerBuilt.compareAndSet(false, true), "Server already built!"); - InternalNettyServerBuilder.eagAttributes(delegate, Attributes.newBuilder() - .set(SdsProtocolNegotiators.SERVER_XDS_CLIENT, xdsClient) - .build()); - return new ServerWrapperForXds(delegate, xdsClient, xdsServingStatusListener); + XdsServerBuilder xdsClientPoolFactory(XdsClientPoolFactory xdsClientPoolFactory) { + this.xdsClientPoolFactory = checkNotNull(xdsClientPoolFactory, "xdsClientPoolFactory"); + return this; } /** diff --git a/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java b/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java new file mode 100644 index 00000000000..5f7cc43d670 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java @@ -0,0 +1,811 @@ +/* + * Copyright 2021 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; + +import com.google.auto.value.AutoValue; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.util.concurrent.SettableFuture; +import io.grpc.Attributes; +import io.grpc.InternalServerInterceptors; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.Server; +import io.grpc.ServerBuilder; +import io.grpc.ServerCall; +import io.grpc.ServerCall.Listener; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.grpc.ServerServiceDefinition; +import io.grpc.Status; +import io.grpc.StatusException; +import io.grpc.SynchronizationContext; +import io.grpc.SynchronizationContext.ScheduledHandle; +import io.grpc.internal.GrpcUtil; +import io.grpc.internal.ObjectPool; +import io.grpc.internal.SharedResourceHolder; +import io.grpc.xds.EnvoyServerProtoData.FilterChain; +import io.grpc.xds.Filter.FilterConfig; +import io.grpc.xds.Filter.NamedFilterConfig; +import io.grpc.xds.Filter.ServerInterceptorBuilder; +import io.grpc.xds.FilterChainMatchingProtocolNegotiators.FilterChainMatchingHandler.FilterChainSelector; +import io.grpc.xds.ThreadSafeRandom.ThreadSafeRandomImpl; +import io.grpc.xds.VirtualHost.Route; +import io.grpc.xds.XdsClient.LdsResourceWatcher; +import io.grpc.xds.XdsClient.LdsUpdate; +import io.grpc.xds.XdsClient.RdsResourceWatcher; +import io.grpc.xds.XdsClient.RdsUpdate; +import io.grpc.xds.XdsNameResolverProvider.XdsClientPoolFactory; +import io.grpc.xds.XdsServerBuilder.XdsServingStatusListener; +import io.grpc.xds.internal.sds.SslContextProviderSupplier; +import java.io.IOException; +import java.net.SocketAddress; +import java.util.ArrayList; +import java.util.Collections; +import java.util.EnumSet; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.logging.Level; +import java.util.logging.Logger; +import javax.annotation.Nullable; + +final class XdsServerWrapper extends Server { + private static final Logger logger = Logger.getLogger(XdsServerWrapper.class.getName()); + + private final SynchronizationContext syncContext = new SynchronizationContext( + new Thread.UncaughtExceptionHandler() { + @Override + public void uncaughtException(Thread t, Throwable e) { + logger.log(Level.SEVERE, "Exception!" + e); + // TODO(chengyuanzhang): implement cleanup. + } + }); + + public static final Attributes.Key> + ATTR_SERVER_ROUTING_CONFIG = + Attributes.Key.create("io.grpc.xds.ServerWrapper.serverRoutingConfig"); + + @VisibleForTesting + static final long RETRY_DELAY_NANOS = TimeUnit.MINUTES.toNanos(1); + private final String listenerAddress; + private final ServerBuilder delegateBuilder; + private boolean sharedTimeService; + private final ScheduledExecutorService timeService; + private final FilterRegistry filterRegistry; + private final ThreadSafeRandom random = ThreadSafeRandomImpl.instance; + private final XdsClientPoolFactory xdsClientPoolFactory; + private final XdsServingStatusListener listener; + private final FilterChainSelectorManager filterChainSelectorManager; + private final AtomicBoolean started = new AtomicBoolean(false); + private final AtomicBoolean shutdown = new AtomicBoolean(false); + private boolean isServing; + private final CountDownLatch internalTerminationLatch = new CountDownLatch(1); + private final SettableFuture initialStartFuture = SettableFuture.create(); + private boolean initialStarted; + private ScheduledHandle restartTimer; + private ObjectPool xdsClientPool; + private XdsClient xdsClient; + private DiscoveryState discoveryState; + private volatile Server delegate; + + XdsServerWrapper( + String listenerAddress, + ServerBuilder delegateBuilder, + XdsServingStatusListener listener, + FilterChainSelectorManager filterChainSelectorManager, + XdsClientPoolFactory xdsClientPoolFactory, + FilterRegistry filterRegistry) { + this(listenerAddress, delegateBuilder, listener, filterChainSelectorManager, + xdsClientPoolFactory, filterRegistry, SharedResourceHolder.get(GrpcUtil.TIMER_SERVICE)); + sharedTimeService = true; + } + + @VisibleForTesting + XdsServerWrapper( + String listenerAddress, + ServerBuilder delegateBuilder, + XdsServingStatusListener listener, + FilterChainSelectorManager filterChainSelectorManager, + XdsClientPoolFactory xdsClientPoolFactory, + FilterRegistry filterRegistry, + ScheduledExecutorService timeService) { + this.listenerAddress = checkNotNull(listenerAddress, "listenerAddress"); + this.delegateBuilder = checkNotNull(delegateBuilder, "delegateBuilder"); + this.delegateBuilder.intercept(new ConfigApplyingInterceptor()); + this.listener = checkNotNull(listener, "listener"); + this.filterChainSelectorManager + = checkNotNull(filterChainSelectorManager, "filterChainSelectorManager"); + this.xdsClientPoolFactory = checkNotNull(xdsClientPoolFactory, "xdsClientPoolFactory"); + this.timeService = checkNotNull(timeService, "timeService"); + this.filterRegistry = checkNotNull(filterRegistry,"filterRegistry"); + this.delegate = delegateBuilder.build(); + } + + @Override + public Server start() throws IOException { + checkState(started.compareAndSet(false, true), "Already started"); + syncContext.execute(new Runnable() { + @Override + public void run() { + internalStart(); + } + }); + Exception exception; + try { + exception = initialStartFuture.get(); + } catch (InterruptedException | ExecutionException e) { + throw new RuntimeException(e); + } + if (exception != null) { + throw (exception instanceof IOException) ? (IOException) exception : + new IOException(exception); + } + return this; + } + + private void internalStart() { + try { + xdsClientPool = xdsClientPoolFactory.getOrCreate(); + } catch (Exception e) { + StatusException statusException = Status.UNAVAILABLE.withDescription( + "Failed to initialize xDS").withCause(e).asException(); + listener.onNotServing(statusException); + initialStartFuture.set(statusException); + return; + } + xdsClient = xdsClientPool.getObject(); + boolean useProtocolV3 = xdsClient.getBootstrapInfo().getServers().get(0).isUseProtocolV3(); + String listenerTemplate = xdsClient.getBootstrapInfo().getServerListenerResourceNameTemplate(); + if (!useProtocolV3 || listenerTemplate == null) { + StatusException statusException = + Status.UNAVAILABLE.withDescription( + "Can only support xDS v3 with listener resource name template").asException(); + listener.onNotServing(statusException); + initialStartFuture.set(statusException); + xdsClient = xdsClientPool.returnObject(xdsClient); + return; + } + discoveryState = new DiscoveryState(listenerTemplate.replaceAll("%s", listenerAddress)); + } + + @Override + public Server shutdown() { + if (!shutdown.compareAndSet(false, true)) { + return this; + } + syncContext.execute(new Runnable() { + @Override + public void run() { + if (!delegate.isShutdown()) { + delegate.shutdown(); + } + internalShutdown(); + } + }); + return this; + } + + @Override + public Server shutdownNow() { + if (!shutdown.compareAndSet(false, true)) { + return this; + } + syncContext.execute(new Runnable() { + @Override + public void run() { + if (!delegate.isShutdown()) { + delegate.shutdownNow(); + } + internalShutdown(); + } + }); + return this; + } + + // Must run in SynchronizationContext + private void internalShutdown() { + logger.log(Level.FINER, "Shutting down XdsServerWrapper"); + if (discoveryState != null) { + discoveryState.shutdown(); + } + if (xdsClient != null) { + xdsClient = xdsClientPool.returnObject(xdsClient); + } + if (restartTimer != null) { + restartTimer.cancel(); + } + if (sharedTimeService) { + SharedResourceHolder.release(GrpcUtil.TIMER_SERVICE, timeService); + } + isServing = false; + internalTerminationLatch.countDown(); + } + + @Override + public boolean isShutdown() { + return shutdown.get(); + } + + @Override + public boolean isTerminated() { + return internalTerminationLatch.getCount() == 0 && delegate.isTerminated(); + } + + @Override + public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException { + long startTime = System.nanoTime(); + if (!internalTerminationLatch.await(timeout, unit)) { + return false; + } + long remainingTime = unit.toNanos(timeout) - (System.nanoTime() - startTime); + return delegate.awaitTermination(remainingTime, TimeUnit.NANOSECONDS); + } + + @Override + public void awaitTermination() throws InterruptedException { + internalTerminationLatch.await(); + delegate.awaitTermination(); + } + + @Override + public int getPort() { + return delegate.getPort(); + } + + @Override + public List getListenSockets() { + return delegate.getListenSockets(); + } + + @Override + public List getServices() { + return delegate.getServices(); + } + + @Override + public List getImmutableServices() { + return delegate.getImmutableServices(); + } + + @Override + public List getMutableServices() { + return delegate.getMutableServices(); + } + + // Must run in SynchronizationContext + private void startDelegateServer() { + if (restartTimer != null && restartTimer.isPending()) { + return; + } + if (isServing) { + return; + } + if (delegate.isShutdown()) { + delegate = delegateBuilder.build(); + } + try { + delegate.start(); + listener.onServing(); + isServing = true; + if (!initialStarted) { + initialStarted = true; + initialStartFuture.set(null); + } + } catch (IOException e) { + logger.log(Level.FINE, "Fail to start delegate server: {0}", e); + if (!initialStarted) { + initialStarted = true; + initialStartFuture.set(e); + } + restartTimer = syncContext.schedule( + new RestartTask(), RETRY_DELAY_NANOS, TimeUnit.NANOSECONDS, timeService); + } + } + + private final class RestartTask implements Runnable { + @Override + public void run() { + startDelegateServer(); + } + } + + private final class DiscoveryState implements LdsResourceWatcher { + private final String resourceName; + // RDS resource name is the key. + private final Map routeDiscoveryStates = new HashMap<>(); + // Track pending RDS resources using rds name. + private final Set pendingRds = new HashSet<>(); + // Most recently discovered filter chains. + private List filterChains = new ArrayList<>(); + // Most recently discovered default filter chain. + @Nullable + private FilterChain defaultFilterChain; + private boolean stopped; + private final Map> savedRdsRoutingConfigRef + = new HashMap<>(); + private final ServerInterceptor noopInterceptor = new ServerInterceptor() { + @Override + public Listener interceptCall(ServerCall call, + Metadata headers, ServerCallHandler next) { + return next.startCall(call, headers); + } + }; + + private DiscoveryState(String resourceName) { + this.resourceName = checkNotNull(resourceName, "resourceName"); + xdsClient.watchLdsResource(resourceName, this); + } + + @Override + public void onChanged(final LdsUpdate update) { + syncContext.execute(new Runnable() { + @Override + public void run() { + if (stopped) { + return; + } + checkNotNull(update.listener(), "update"); + if (!pendingRds.isEmpty()) { + // filter chain state has not yet been applied to filterChainSelectorManager and there + // are two sets of sslContextProviderSuppliers, so we release the old ones. + releaseSuppliersInFlight(); + pendingRds.clear(); + } + filterChains = update.listener().getFilterChains(); + defaultFilterChain = update.listener().getDefaultFilterChain(); + List allFilterChains = filterChains; + if (defaultFilterChain != null) { + allFilterChains = new ArrayList<>(filterChains); + allFilterChains.add(defaultFilterChain); + } + Set allRds = new HashSet<>(); + for (FilterChain filterChain : allFilterChains) { + HttpConnectionManager hcm = filterChain.getHttpConnectionManager(); + if (hcm.virtualHosts() == null) { + RouteDiscoveryState rdsState = routeDiscoveryStates.get(hcm.rdsName()); + if (rdsState == null) { + rdsState = new RouteDiscoveryState(hcm.rdsName()); + routeDiscoveryStates.put(hcm.rdsName(), rdsState); + xdsClient.watchRdsResource(hcm.rdsName(), rdsState); + } + if (rdsState.isPending) { + pendingRds.add(hcm.rdsName()); + } + allRds.add(hcm.rdsName()); + } + } + for (Map.Entry entry: routeDiscoveryStates.entrySet()) { + if (!allRds.contains(entry.getKey())) { + xdsClient.cancelRdsResourceWatch(entry.getKey(), entry.getValue()); + } + } + routeDiscoveryStates.keySet().retainAll(allRds); + if (pendingRds.isEmpty()) { + updateSelector(); + } + } + }); + } + + @Override + public void onResourceDoesNotExist(final String resourceName) { + syncContext.execute(new Runnable() { + @Override + public void run() { + if (stopped) { + return; + } + StatusException statusException = Status.UNAVAILABLE.withDescription( + "Listener " + resourceName + " unavailable").asException(); + handleConfigNotFound(statusException); + } + }); + } + + @Override + public void onError(final Status error) { + syncContext.execute(new Runnable() { + @Override + public void run() { + if (stopped) { + return; + } + boolean isPermanentError = isPermanentError(error); + logger.log(Level.FINE, "{0} error from XdsClient: {1}", + new Object[]{isPermanentError ? "Permanent" : "Transient", error}); + if (isPermanentError) { + handleConfigNotFound(error.asException()); + } else if (!isServing) { + listener.onNotServing(error.asException()); + } + } + }); + } + + private void shutdown() { + stopped = true; + cleanUpRouteDiscoveryStates(); + logger.log(Level.FINE, "Stop watching LDS resource {0}", resourceName); + xdsClient.cancelLdsResourceWatch(resourceName, this); + List toRelease = getSuppliersInUse(); + filterChainSelectorManager.updateSelector(FilterChainSelector.NO_FILTER_CHAIN); + for (SslContextProviderSupplier s: toRelease) { + s.close(); + } + releaseSuppliersInFlight(); + } + + private void updateSelector() { + Map> filterChainRouting = new HashMap<>(); + savedRdsRoutingConfigRef.clear(); + for (FilterChain filterChain: filterChains) { + filterChainRouting.put(filterChain, generateRoutingConfig(filterChain)); + } + FilterChainSelector selector = new FilterChainSelector( + Collections.unmodifiableMap(filterChainRouting), + defaultFilterChain == null ? null : defaultFilterChain.getSslContextProviderSupplier(), + defaultFilterChain == null ? new AtomicReference() : + generateRoutingConfig(defaultFilterChain)); + List toRelease = getSuppliersInUse(); + filterChainSelectorManager.updateSelector(selector); + for (SslContextProviderSupplier e: toRelease) { + e.close(); + } + startDelegateServer(); + } + + private AtomicReference generateRoutingConfig(FilterChain filterChain) { + HttpConnectionManager hcm = filterChain.getHttpConnectionManager(); + if (hcm.virtualHosts() != null) { + ImmutableMap interceptors = generatePerRouteInterceptors( + hcm.httpFilterConfigs(), hcm.virtualHosts()); + return new AtomicReference<>(ServerRoutingConfig.create(hcm.virtualHosts(),interceptors)); + } else { + RouteDiscoveryState rds = routeDiscoveryStates.get(hcm.rdsName()); + checkNotNull(rds, "rds"); + AtomicReference serverRoutingConfigRef = new AtomicReference<>(); + if (rds.savedVirtualHosts != null) { + ImmutableMap interceptors = generatePerRouteInterceptors( + hcm.httpFilterConfigs(), rds.savedVirtualHosts); + ServerRoutingConfig serverRoutingConfig = + ServerRoutingConfig.create(rds.savedVirtualHosts, interceptors); + serverRoutingConfigRef.set(serverRoutingConfig); + } else { + serverRoutingConfigRef.set(ServerRoutingConfig.FAILING_ROUTING_CONFIG); + } + savedRdsRoutingConfigRef.put(filterChain, serverRoutingConfigRef); + return serverRoutingConfigRef; + } + } + + private ImmutableMap generatePerRouteInterceptors( + List namedFilterConfigs, List virtualHosts) { + ImmutableMap.Builder perRouteInterceptors = + new ImmutableMap.Builder<>(); + for (VirtualHost virtualHost : virtualHosts) { + for (Route route : virtualHost.routes()) { + List filterInterceptors = new ArrayList<>(); + Map selectedOverrideConfigs = + new HashMap<>(virtualHost.filterConfigOverrides()); + selectedOverrideConfigs.putAll(route.filterConfigOverrides()); + if (namedFilterConfigs != null) { + for (NamedFilterConfig namedFilterConfig : namedFilterConfigs) { + FilterConfig filterConfig = namedFilterConfig.filterConfig; + Filter filter = filterRegistry.get(filterConfig.typeUrl()); + if (filter instanceof ServerInterceptorBuilder) { + ServerInterceptor interceptor = + ((ServerInterceptorBuilder) filter).buildServerInterceptor( + filterConfig, selectedOverrideConfigs.get(namedFilterConfig.name)); + if (interceptor != null) { + filterInterceptors.add(interceptor); + } + } else { + logger.log(Level.WARNING, "HttpFilterConfig(type URL: " + + filterConfig.typeUrl() + ") is not supported on server-side. " + + "Probably a bug at ClientXdsClient verification."); + } + } + } + ServerInterceptor interceptor = combineInterceptors(filterInterceptors); + perRouteInterceptors.put(route, interceptor); + } + } + return perRouteInterceptors.build(); + } + + private ServerInterceptor combineInterceptors(final List interceptors) { + if (interceptors.isEmpty()) { + return noopInterceptor; + } + if (interceptors.size() == 1) { + return interceptors.get(0); + } + return new ServerInterceptor() { + @Override + public Listener interceptCall(ServerCall call, + Metadata headers, ServerCallHandler next) { + // intercept forward + for (int i = interceptors.size() - 1; i >= 0; i--) { + next = InternalServerInterceptors.interceptCallHandlerCreate( + interceptors.get(i), next); + } + return next.startCall(call, headers); + } + }; + } + + private void handleConfigNotFound(StatusException exception) { + cleanUpRouteDiscoveryStates(); + List toRelease = getSuppliersInUse(); + filterChainSelectorManager.updateSelector(FilterChainSelector.NO_FILTER_CHAIN); + for (SslContextProviderSupplier s: toRelease) { + s.close(); + } + if (!initialStarted) { + initialStarted = true; + initialStartFuture.set(exception); + } + if (restartTimer != null) { + restartTimer.cancel(); + } + if (!delegate.isShutdown()) { + delegate.shutdown(); // let in-progress calls finish + } + isServing = false; + listener.onNotServing(exception); + } + + private void cleanUpRouteDiscoveryStates() { + for (RouteDiscoveryState rdsState : routeDiscoveryStates.values()) { + String rdsName = rdsState.resourceName; + logger.log(Level.FINE, "Stop watching RDS resource {0}", rdsName); + xdsClient.cancelRdsResourceWatch(rdsName, rdsState); + } + routeDiscoveryStates.clear(); + savedRdsRoutingConfigRef.clear(); + } + + private List getSuppliersInUse() { + List toRelease = new ArrayList<>(); + FilterChainSelector selector = filterChainSelectorManager.getSelectorToUpdateSelector(); + if (selector != null) { + for (FilterChain f: selector.getRoutingConfigs().keySet()) { + if (f.getSslContextProviderSupplier() != null) { + toRelease.add(f.getSslContextProviderSupplier()); + } + } + SslContextProviderSupplier defaultSupplier = + selector.getDefaultSslContextProviderSupplier(); + if (defaultSupplier != null) { + toRelease.add(defaultSupplier); + } + } + return toRelease; + } + + private void releaseSuppliersInFlight() { + SslContextProviderSupplier supplier; + for (FilterChain filterChain : filterChains) { + supplier = filterChain.getSslContextProviderSupplier(); + if (supplier != null) { + supplier.close(); + } + } + if (defaultFilterChain != null + && (supplier = defaultFilterChain.getSslContextProviderSupplier()) != null) { + supplier.close(); + } + } + + private final class RouteDiscoveryState implements RdsResourceWatcher { + private final String resourceName; + private ImmutableList savedVirtualHosts; + private boolean isPending = true; + + private RouteDiscoveryState(String resourceName) { + this.resourceName = checkNotNull(resourceName, "resourceName"); + } + + @Override + public void onChanged(final RdsUpdate update) { + syncContext.execute(new Runnable() { + @Override + public void run() { + if (!routeDiscoveryStates.containsKey(resourceName)) { + return; + } + if (savedVirtualHosts == null && !isPending) { + logger.log(Level.WARNING, "Received valid Rds {0} configuration.", resourceName); + } + savedVirtualHosts = ImmutableList.copyOf(update.virtualHosts); + updateRdsRoutingConfig(); + maybeUpdateSelector(); + } + }); + } + + @Override + public void onResourceDoesNotExist(final String resourceName) { + syncContext.execute(new Runnable() { + @Override + public void run() { + if (!routeDiscoveryStates.containsKey(resourceName)) { + return; + } + logger.log(Level.WARNING, "Rds {0} unavailable", resourceName); + savedVirtualHosts = null; + updateRdsRoutingConfig(); + maybeUpdateSelector(); + } + }); + } + + @Override + public void onError(final Status error) { + syncContext.execute(new Runnable() { + @Override + public void run() { + if (!routeDiscoveryStates.containsKey(resourceName)) { + return; + } + logger.log(Level.WARNING, "Error loading RDS resource {0} from XdsClient: {1}.", + new Object[]{resourceName, error}); + maybeUpdateSelector(); + } + }); + } + + private void updateRdsRoutingConfig() { + for (FilterChain filterChain : savedRdsRoutingConfigRef.keySet()) { + if (resourceName.equals(filterChain.getHttpConnectionManager().rdsName())) { + ServerRoutingConfig updatedRoutingConfig; + if (savedVirtualHosts == null) { + updatedRoutingConfig = ServerRoutingConfig.FAILING_ROUTING_CONFIG; + } else { + ImmutableMap updatedInterceptors = + generatePerRouteInterceptors( + filterChain.getHttpConnectionManager().httpFilterConfigs(), + savedVirtualHosts); + updatedRoutingConfig = ServerRoutingConfig.create(savedVirtualHosts, + updatedInterceptors); + } + savedRdsRoutingConfigRef.get(filterChain).set(updatedRoutingConfig); + } + } + } + + // Update the selector to use the most recently updated configs only after all rds have been + // discovered for the first time. Later changes on rds will be applied through virtual host + // list atomic ref. + private void maybeUpdateSelector() { + isPending = false; + boolean isLastPending = pendingRds.remove(resourceName) && pendingRds.isEmpty(); + if (isLastPending) { + updateSelector(); + } + } + } + + private boolean isPermanentError(Status error) { + return EnumSet.of( + Status.Code.INTERNAL, + Status.Code.INVALID_ARGUMENT, + Status.Code.FAILED_PRECONDITION, + Status.Code.PERMISSION_DENIED, + Status.Code.UNAUTHENTICATED) + .contains(error.getCode()); + } + } + + @VisibleForTesting + final class ConfigApplyingInterceptor implements ServerInterceptor { + private final ServerInterceptor noopInterceptor = new ServerInterceptor() { + @Override + public Listener interceptCall(ServerCall call, + Metadata headers, ServerCallHandler next) { + return next.startCall(call, headers); + } + }; + + @Override + public Listener interceptCall(ServerCall call, + Metadata headers, ServerCallHandler next) { + AtomicReference routingConfigRef = + call.getAttributes().get(ATTR_SERVER_ROUTING_CONFIG); + ServerRoutingConfig routingConfig = routingConfigRef == null ? null : + routingConfigRef.get(); + if (routingConfig == null || routingConfig == ServerRoutingConfig.FAILING_ROUTING_CONFIG) { + String errorMsg = "Missing or broken xDS routing config: RDS config unavailable."; + call.close(Status.UNAVAILABLE.withDescription(errorMsg), new Metadata()); + return new Listener() {}; + } + List virtualHosts = routingConfig.virtualHosts(); + VirtualHost virtualHost = RoutingUtils.findVirtualHostForHostName( + virtualHosts, call.getAuthority()); + if (virtualHost == null) { + call.close( + Status.UNAVAILABLE.withDescription("Could not find xDS virtual host matching RPC"), + new Metadata()); + return new Listener() {}; + } + Route selectedRoute = null; + MethodDescriptor method = call.getMethodDescriptor(); + for (Route route : virtualHost.routes()) { + if (RoutingUtils.matchRoute( + route.routeMatch(), "/" + method.getFullMethodName(), headers, random)) { + selectedRoute = route; + break; + } + } + if (selectedRoute == null) { + call.close(Status.UNAVAILABLE.withDescription("Could not find xDS route matching RPC"), + new Metadata()); + return new ServerCall.Listener() {}; + } + if (selectedRoute.routeAction() != null) { + call.close(Status.UNAVAILABLE.withDescription("Invalid xDS route action for matching " + + "route: only Route.non_forwarding_action should be allowed."), new Metadata()); + return new ServerCall.Listener() {}; + } + ServerInterceptor routeInterceptor = noopInterceptor; + Map perRouteInterceptors = routingConfig.interceptors(); + if (perRouteInterceptors != null && perRouteInterceptors.get(selectedRoute) != null) { + routeInterceptor = perRouteInterceptors.get(selectedRoute); + } + return routeInterceptor.interceptCall(call, headers, next); + } + } + + /** + * The HttpConnectionManager level configuration. + */ + @AutoValue + abstract static class ServerRoutingConfig { + @VisibleForTesting + static final ServerRoutingConfig FAILING_ROUTING_CONFIG = ServerRoutingConfig.create( + ImmutableList.of(), ImmutableMap.of()); + + abstract ImmutableList virtualHosts(); + + // Prebuilt per route server interceptors from http filter configs. + abstract ImmutableMap interceptors(); + + /** + * Server routing configuration. + * */ + public static ServerRoutingConfig create( + ImmutableList virtualHosts, + ImmutableMap interceptors) { + checkNotNull(virtualHosts, "virtualHosts"); + checkNotNull(interceptors, "interceptors"); + return new AutoValue_XdsServerWrapper_ServerRoutingConfig(virtualHosts, interceptors); + } + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/Matchers.java b/xds/src/main/java/io/grpc/xds/internal/Matchers.java index 28ec8418297..3bf7b7723e2 100644 --- a/xds/src/main/java/io/grpc/xds/internal/Matchers.java +++ b/xds/src/main/java/io/grpc/xds/internal/Matchers.java @@ -117,13 +117,8 @@ private static HeaderMatcher create(String name, @Nullable String exactValue, /** Returns the matching result. */ public boolean matches(@Nullable String value) { - if (present() != null) { - return (value == null) == present().equals(inverted()); - } - // FIXME(zivy@): invert result for null value. - // https://github.com/envoyproxy/envoy/blob/0fae6970ddaf93f024908ba304bbd2b34e997a51/source/common/http/header_utility.cc#L130 if (value == null) { - return false; + return present() != null && present() == inverted(); } boolean baseMatch; if (exactValue() != null) { @@ -141,6 +136,8 @@ public boolean matches(@Nullable String value) { } } else if (prefix() != null) { baseMatch = value.startsWith(prefix()); + } else if (present() != null) { + baseMatch = present(); } else { baseMatch = value.endsWith(suffix()); } diff --git a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderClientSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderClientSslContextProvider.java index 1dc7be1be33..ce9ef3de680 100644 --- a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderClientSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderClientSslContextProvider.java @@ -22,7 +22,6 @@ import io.envoyproxy.envoy.config.core.v3.Node; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; -import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.CombinedCertificateValidationContext; import io.grpc.Internal; import io.grpc.netty.GrpcSslContexts; import io.grpc.xds.Bootstrapper.CertificateProviderInfo; @@ -32,6 +31,7 @@ import java.security.cert.CertStoreException; import java.security.cert.X509Certificate; import java.util.Map; +import javax.annotation.Nullable; /** A client SslContext provider using CertificateProviderInstance to fetch secrets. */ @Internal @@ -39,7 +39,7 @@ public final class CertProviderClientSslContextProvider extends CertProviderSslC private CertProviderClientSslContextProvider( Node node, - Map certProviders, + @Nullable Map certProviders, CommonTlsContext.CertificateProviderInstance certInstance, CommonTlsContext.CertificateProviderInstance rootCertInstance, CertificateValidationContext staticCertValidationContext, @@ -90,30 +90,15 @@ public static Factory getInstance() { public CertProviderClientSslContextProvider getProvider( UpstreamTlsContext upstreamTlsContext, Node node, - Map certProviders) { + @Nullable Map certProviders) { checkNotNull(upstreamTlsContext, "upstreamTlsContext"); CommonTlsContext commonTlsContext = upstreamTlsContext.getCommonTlsContext(); - CommonTlsContext.CertificateProviderInstance rootCertInstance = null; - CertificateValidationContext staticCertValidationContext = null; - if (commonTlsContext.hasCombinedValidationContext()) { - CombinedCertificateValidationContext combinedValidationContext = - commonTlsContext.getCombinedValidationContext(); - if (combinedValidationContext.hasValidationContextCertificateProviderInstance()) { - rootCertInstance = - combinedValidationContext.getValidationContextCertificateProviderInstance(); - } - if (combinedValidationContext.hasDefaultValidationContext()) { - staticCertValidationContext = combinedValidationContext.getDefaultValidationContext(); - } - } else if (commonTlsContext.hasValidationContextCertificateProviderInstance()) { - rootCertInstance = commonTlsContext.getValidationContextCertificateProviderInstance(); - } else if (commonTlsContext.hasValidationContext()) { - staticCertValidationContext = commonTlsContext.getValidationContext(); - } - CommonTlsContext.CertificateProviderInstance certInstance = null; - if (commonTlsContext.hasTlsCertificateCertificateProviderInstance()) { - certInstance = commonTlsContext.getTlsCertificateCertificateProviderInstance(); - } + CertificateValidationContext staticCertValidationContext = getStaticValidationContext( + commonTlsContext); + CommonTlsContext.CertificateProviderInstance rootCertInstance = getRootCertProviderInstance( + commonTlsContext); + CommonTlsContext.CertificateProviderInstance certInstance = getCertProviderInstance( + commonTlsContext); return new CertProviderClientSslContextProvider( node, certProviders, diff --git a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderServerSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderServerSslContextProvider.java index 78e825f60fd..a7f0849d00b 100644 --- a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderServerSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderServerSslContextProvider.java @@ -22,7 +22,6 @@ import io.envoyproxy.envoy.config.core.v3.Node; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; -import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.CombinedCertificateValidationContext; import io.grpc.Internal; import io.grpc.netty.GrpcSslContexts; import io.grpc.xds.Bootstrapper.CertificateProviderInfo; @@ -35,6 +34,7 @@ import java.security.cert.CertificateException; import java.security.cert.X509Certificate; import java.util.Map; +import javax.annotation.Nullable; /** A server SslContext provider using CertificateProviderInstance to fetch secrets. */ @Internal @@ -42,7 +42,7 @@ public final class CertProviderServerSslContextProvider extends CertProviderSslC private CertProviderServerSslContextProvider( Node node, - Map certProviders, + @Nullable Map certProviders, CommonTlsContext.CertificateProviderInstance certInstance, CommonTlsContext.CertificateProviderInstance rootCertInstance, CertificateValidationContext staticCertValidationContext, @@ -93,30 +93,15 @@ public static Factory getInstance() { public CertProviderServerSslContextProvider getProvider( DownstreamTlsContext downstreamTlsContext, Node node, - Map certProviders) { + @Nullable Map certProviders) { checkNotNull(downstreamTlsContext, "downstreamTlsContext"); CommonTlsContext commonTlsContext = downstreamTlsContext.getCommonTlsContext(); - CommonTlsContext.CertificateProviderInstance rootCertInstance = null; - CertificateValidationContext staticCertValidationContext = null; - if (commonTlsContext.hasCombinedValidationContext()) { - CombinedCertificateValidationContext combinedValidationContext = - commonTlsContext.getCombinedValidationContext(); - if (combinedValidationContext.hasValidationContextCertificateProviderInstance()) { - rootCertInstance = - combinedValidationContext.getValidationContextCertificateProviderInstance(); - } - if (combinedValidationContext.hasDefaultValidationContext()) { - staticCertValidationContext = combinedValidationContext.getDefaultValidationContext(); - } - } else if (commonTlsContext.hasValidationContextCertificateProviderInstance()) { - rootCertInstance = commonTlsContext.getValidationContextCertificateProviderInstance(); - } else if (commonTlsContext.hasValidationContext()) { - staticCertValidationContext = commonTlsContext.getValidationContext(); - } - CommonTlsContext.CertificateProviderInstance certInstance = null; - if (commonTlsContext.hasTlsCertificateCertificateProviderInstance()) { - certInstance = commonTlsContext.getTlsCertificateCertificateProviderInstance(); - } + CertificateValidationContext staticCertValidationContext = getStaticValidationContext( + commonTlsContext); + CommonTlsContext.CertificateProviderInstance rootCertInstance = getRootCertProviderInstance( + commonTlsContext); + CommonTlsContext.CertificateProviderInstance certInstance = getCertProviderInstance( + commonTlsContext); return new CertProviderServerSslContextProvider( node, certProviders, diff --git a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderSslContextProvider.java index eef5ee551e7..1ec58764196 100644 --- a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertProviderSslContextProvider.java @@ -18,9 +18,11 @@ import io.envoyproxy.envoy.config.core.v3.Node; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; +import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.CertificateProviderInstance; import io.grpc.xds.Bootstrapper.CertificateProviderInfo; import io.grpc.xds.EnvoyServerProtoData.BaseTlsContext; +import io.grpc.xds.internal.sds.CommonTlsContextUtil; import io.grpc.xds.internal.sds.DynamicSslContextProvider; import java.security.PrivateKey; import java.security.cert.X509Certificate; @@ -42,7 +44,7 @@ abstract class CertProviderSslContextProvider extends DynamicSslContextProvider protected CertProviderSslContextProvider( Node node, - Map certProviders, + @Nullable Map certProviders, CertificateProviderInstance certInstance, CertificateProviderInstance rootCertInstance, CertificateValidationContext staticCertValidationContext, @@ -56,8 +58,8 @@ protected CertProviderSslContextProvider( certInstanceName = certInstance.getInstanceName(); CertificateProviderInfo certProviderInstanceConfig = getCertProviderConfig(certProviders, certInstanceName); - certHandle = - certificateProviderStore.createOrGetProvider( + certHandle = certProviderInstanceConfig == null ? null + : certificateProviderStore.createOrGetProvider( certInstance.getCertificateName(), certProviderInstanceConfig.getPluginName(), certProviderInstanceConfig.getConfig(), @@ -71,8 +73,8 @@ protected CertProviderSslContextProvider( && !rootCertInstance.getInstanceName().equals(certInstanceName)) { CertificateProviderInfo certProviderInstanceConfig = getCertProviderConfig(certProviders, rootCertInstance.getInstanceName()); - rootCertHandle = - certificateProviderStore.createOrGetProvider( + rootCertHandle = certProviderInstanceConfig == null ? null + : certificateProviderStore.createOrGetProvider( rootCertInstance.getCertificateName(), certProviderInstanceConfig.getPluginName(), certProviderInstanceConfig.getConfig(), @@ -84,8 +86,54 @@ protected CertProviderSslContextProvider( } private static CertificateProviderInfo getCertProviderConfig( - Map certProviders, String pluginInstanceName) { - return certProviders.get(pluginInstanceName); + @Nullable Map certProviders, String pluginInstanceName) { + return certProviders != null ? certProviders.get(pluginInstanceName) : null; + } + + @Nullable + protected static CertificateProviderInstance getCertProviderInstance( + CommonTlsContext commonTlsContext) { + if (commonTlsContext.hasTlsCertificateProviderInstance()) { + return CommonTlsContextUtil.convert(commonTlsContext.getTlsCertificateProviderInstance()); + } else if (commonTlsContext.hasTlsCertificateCertificateProviderInstance()) { + return commonTlsContext.getTlsCertificateCertificateProviderInstance(); + } + return null; + } + + @Nullable + protected static CertificateValidationContext getStaticValidationContext( + CommonTlsContext commonTlsContext) { + if (commonTlsContext.hasValidationContext()) { + return commonTlsContext.getValidationContext(); + } else if (commonTlsContext.hasCombinedValidationContext()) { + CommonTlsContext.CombinedCertificateValidationContext combinedValidationContext = + commonTlsContext.getCombinedValidationContext(); + if (combinedValidationContext.hasDefaultValidationContext()) { + return combinedValidationContext.getDefaultValidationContext(); + } + } + return null; + } + + @Nullable + protected static CommonTlsContext.CertificateProviderInstance getRootCertProviderInstance( + CommonTlsContext commonTlsContext) { + CertificateValidationContext certValidationContext = getStaticValidationContext( + commonTlsContext); + if (certValidationContext != null && certValidationContext.hasCaCertificateProviderInstance()) { + return CommonTlsContextUtil.convert(certValidationContext.getCaCertificateProviderInstance()); + } + if (commonTlsContext.hasCombinedValidationContext()) { + CommonTlsContext.CombinedCertificateValidationContext combinedValidationContext = + commonTlsContext.getCombinedValidationContext(); + if (combinedValidationContext.hasValidationContextCertificateProviderInstance()) { + return combinedValidationContext.getValidationContextCertificateProviderInstance(); + } + } else if (commonTlsContext.hasValidationContextCertificateProviderInstance()) { + return commonTlsContext.getValidationContextCertificateProviderInstance(); + } + return null; } @Override diff --git a/xds/src/main/java/io/grpc/xds/internal/rbac/engine/GrpcAuthorizationEngine.java b/xds/src/main/java/io/grpc/xds/internal/rbac/engine/GrpcAuthorizationEngine.java index 6d275d322a2..bb911461a27 100644 --- a/xds/src/main/java/io/grpc/xds/internal/rbac/engine/GrpcAuthorizationEngine.java +++ b/xds/src/main/java/io/grpc/xds/internal/rbac/engine/GrpcAuthorizationEngine.java @@ -20,6 +20,7 @@ import com.google.auto.value.AutoValue; import com.google.common.base.Joiner; +import com.google.common.io.BaseEncoding; import io.grpc.Grpc; import io.grpc.Metadata; import io.grpc.ServerCall; @@ -35,6 +36,7 @@ import java.util.Collection; import java.util.Collections; import java.util.List; +import java.util.Locale; import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; @@ -234,6 +236,23 @@ public boolean matches(EvaluateArgs args) { } } + public static final class DestinationPortRangeMatcher implements Matcher { + private final int start; + private final int end; + + /** Start of the range is inclusive. End of the range is exclusive.*/ + public DestinationPortRangeMatcher(int start, int end) { + this.start = start; + this.end = end; + } + + @Override + public boolean matches(EvaluateArgs args) { + int port = args.getDestinationPort(); + return port >= start && port < end; + } + } + public static final class RequestedServerNameMatcher implements Matcher { private final Matchers.StringMatcher delegate; @@ -316,9 +335,44 @@ private Collection getPrincipalNames() { @Nullable private String getHeader(String headerName) { - if (headerName.endsWith(Metadata.BINARY_HEADER_SUFFIX)) { + headerName = headerName.toLowerCase(Locale.ROOT); + if ("te".equals(headerName)) { return null; } + if (":authority".equals(headerName)) { + headerName = "host"; + } + if ("host".equals(headerName)) { + return serverCall.getAuthority(); + } + if (":path".equals(headerName)) { + return getPath(); + } + if (":method".equals(headerName)) { + return "POST"; + } + return deserializeHeader(headerName); + } + + @Nullable + private String deserializeHeader(String headerName) { + if (headerName.endsWith(Metadata.BINARY_HEADER_SUFFIX)) { + Metadata.Key key; + try { + key = Metadata.Key.of(headerName, Metadata.BINARY_BYTE_MARSHALLER); + } catch (IllegalArgumentException e) { + return null; + } + Iterable values = metadata.getAll(key); + if (values == null) { + return null; + } + List encoded = new ArrayList<>(); + for (byte[] v : values) { + encoded.add(BaseEncoding.base64().omitPadding().encode(v)); + } + return Joiner.on(",").join(encoded); + } Metadata.Key key; try { key = Metadata.Key.of(headerName, Metadata.ASCII_STRING_MARSHALLER); diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/CommonTlsContextUtil.java b/xds/src/main/java/io/grpc/xds/internal/sds/CommonTlsContextUtil.java index 234989ad115..0c28c79ee22 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/CommonTlsContextUtil.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/CommonTlsContextUtil.java @@ -16,11 +16,12 @@ package io.grpc.xds.internal.sds; +import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateProviderPluginInstance; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.CombinedCertificateValidationContext; /** Class for utility functions for {@link CommonTlsContext}. */ -final class CommonTlsContextUtil { +public final class CommonTlsContextUtil { private CommonTlsContextUtil() {} @@ -38,4 +39,15 @@ private static boolean hasCertProviderValidationContext(CommonTlsContext commonT } return commonTlsContext.hasValidationContextCertificateProviderInstance(); } + + /** + * Converts {@link CertificateProviderPluginInstance} to + * {@link CommonTlsContext.CertificateProviderInstance}. + */ + public static CommonTlsContext.CertificateProviderInstance convert( + CertificateProviderPluginInstance pluginInstance) { + return CommonTlsContext.CertificateProviderInstance.newBuilder() + .setInstanceName(pluginInstance.getInstanceName()) + .setCertificateName(pluginInstance.getCertificateName()).build(); + } } diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/SdsProtocolNegotiators.java b/xds/src/main/java/io/grpc/xds/internal/sds/SdsProtocolNegotiators.java index 37161325746..0128fa53106 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/SdsProtocolNegotiators.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/SdsProtocolNegotiators.java @@ -29,7 +29,6 @@ import io.grpc.netty.InternalProtocolNegotiators; import io.grpc.netty.ProtocolNegotiationEvent; import io.grpc.xds.InternalXdsAttributes; -import io.grpc.xds.XdsClientWrapperForServerSds; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerAdapter; import io.netty.channel.ChannelHandlerContext; @@ -57,10 +56,12 @@ private SdsProtocolNegotiators() { private static final Logger logger = Logger.getLogger(SdsProtocolNegotiators.class.getName()); - public static final Attributes.Key SERVER_XDS_CLIENT - = Attributes.Key.create("serverXdsClient"); private static final AsciiString SCHEME = AsciiString.of("http"); + public static final Attributes.Key + ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER = + Attributes.Key.create("io.grpc.xds.internal.sds.server.sslContextProviderSupplier"); + /** * Returns a {@link InternalProtocolNegotiator.ClientFactory}. * @@ -253,10 +254,7 @@ public AsciiString scheme() { @Override public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { - XdsClientWrapperForServerSds xdsClientWrapperForServerSds = - grpcHandler.getEagAttributes().get(SERVER_XDS_CLIENT); - return new HandlerPickerHandler(grpcHandler, xdsClientWrapperForServerSds, - fallbackProtocolNegotiator); + return new HandlerPickerHandler(grpcHandler, fallbackProtocolNegotiator); } @Override @@ -267,25 +265,21 @@ public void close() {} static final class HandlerPickerHandler extends ChannelInboundHandlerAdapter { private final GrpcHttp2ConnectionHandler grpcHandler; - private final XdsClientWrapperForServerSds xdsClientWrapperForServerSds; @Nullable private final ProtocolNegotiator fallbackProtocolNegotiator; HandlerPickerHandler( GrpcHttp2ConnectionHandler grpcHandler, - @Nullable XdsClientWrapperForServerSds xdsClientWrapperForServerSds, - ProtocolNegotiator fallbackProtocolNegotiator) { + @Nullable ProtocolNegotiator fallbackProtocolNegotiator) { this.grpcHandler = checkNotNull(grpcHandler, "grpcHandler"); - this.xdsClientWrapperForServerSds = xdsClientWrapperForServerSds; this.fallbackProtocolNegotiator = fallbackProtocolNegotiator; } @Override public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { if (evt instanceof ProtocolNegotiationEvent) { - SslContextProviderSupplier sslContextProviderSupplier = - xdsClientWrapperForServerSds == null - ? null - : xdsClientWrapperForServerSds.getSslContextProviderSupplier(ctx.channel()); + ProtocolNegotiationEvent pne = (ProtocolNegotiationEvent)evt; + SslContextProviderSupplier sslContextProviderSupplier = InternalProtocolNegotiationEvent + .getAttributes(pne).get(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER); if (sslContextProviderSupplier == null) { if (fallbackProtocolNegotiator == null) { ctx.fireExceptionCaught(new CertStoreException("No certificate source found!")); @@ -297,7 +291,6 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc this, null, fallbackProtocolNegotiator.newHandler(grpcHandler)); - ProtocolNegotiationEvent pne = InternalProtocolNegotiationEvent.getDefault(); ctx.fireUserEventTriggered(pne); return; } else { @@ -307,7 +300,6 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc null, new ServerSdsHandler( grpcHandler, sslContextProviderSupplier)); - ProtocolNegotiationEvent pne = InternalProtocolNegotiationEvent.getDefault(); ctx.fireUserEventTriggered(pne); return; } diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/ServerWrapperForXds.java b/xds/src/main/java/io/grpc/xds/internal/sds/ServerWrapperForXds.java deleted file mode 100644 index 968c7385499..00000000000 --- a/xds/src/main/java/io/grpc/xds/internal/sds/ServerWrapperForXds.java +++ /dev/null @@ -1,368 +0,0 @@ -/* - * Copyright 2020 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.xds.internal.sds; - -import static com.google.common.base.Preconditions.checkNotNull; -import static com.google.common.base.Preconditions.checkState; - -import com.google.common.annotations.VisibleForTesting; -import com.google.common.util.concurrent.SettableFuture; -import io.grpc.Server; -import io.grpc.ServerBuilder; -import io.grpc.ServerServiceDefinition; -import io.grpc.Status; -import io.grpc.internal.GrpcUtil; -import io.grpc.internal.SharedResourceHolder; -import io.grpc.xds.XdsClientWrapperForServerSds; -import io.grpc.xds.XdsInitializationException; -import io.grpc.xds.XdsServerBuilder; -import java.io.IOException; -import java.net.BindException; -import java.net.SocketAddress; -import java.util.EnumSet; -import java.util.List; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.ScheduledFuture; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicReference; -import javax.annotation.Nullable; - -/** - * Wraps a {@link Server} delegate and {@link XdsClientWrapperForServerSds} and intercepts {@link - * Server#shutdown()} and {@link Server#start()} to shut down and start the - * {@link XdsClientWrapperForServerSds} object. - */ -@VisibleForTesting -public final class ServerWrapperForXds extends Server { - private Server delegate; - private final ServerBuilder delegateBuilder; - private final XdsClientWrapperForServerSds xdsClientWrapperForServerSds; - private XdsServerBuilder.XdsServingStatusListener xdsServingStatusListener; - @Nullable XdsClientWrapperForServerSds.ServerWatcher serverWatcher; - private AtomicBoolean started = new AtomicBoolean(); - private volatile ServingState currentServingState; - private final long delayForRetry; - private final TimeUnit timeUnitForDelayForRetry; - private StartRetryTask startRetryTask; - - @VisibleForTesting public enum ServingState { - // during start() i.e. first start - STARTING, - - // after start (1st or subsequent ones) - STARTED, - - // not serving due to listener deletion - NOT_SERVING, - - // enter serving mode after NOT_SERVING - ENTER_SERVING, - - // shut down - could be due to failure - SHUTDOWN - } - - /** Creates the wrapper object using the delegate passed. */ - public ServerWrapperForXds( - ServerBuilder delegateBuilder, - XdsClientWrapperForServerSds xdsClientWrapperForServerSds, - XdsServerBuilder.XdsServingStatusListener xdsServingStatusListener) { - this( - delegateBuilder, - xdsClientWrapperForServerSds, - xdsServingStatusListener, - 1L, - TimeUnit.MINUTES); - } - - /** Creates the wrapper object using params passed - used for tests. */ - @VisibleForTesting - public ServerWrapperForXds(ServerBuilder delegateBuilder, - XdsClientWrapperForServerSds xdsClientWrapperForServerSds, - XdsServerBuilder.XdsServingStatusListener xdsServingStatusListener, - long delayForRetry, TimeUnit timeUnitForDelayForRetry) { - this.delegateBuilder = checkNotNull(delegateBuilder, "delegateBuilder"); - this.delegate = delegateBuilder.build(); - this.xdsClientWrapperForServerSds = - checkNotNull(xdsClientWrapperForServerSds, "xdsClientWrapperForServerSds"); - this.xdsServingStatusListener = - checkNotNull(xdsServingStatusListener, "xdsServingStatusListener"); - this.delayForRetry = delayForRetry; - this.timeUnitForDelayForRetry = - checkNotNull(timeUnitForDelayForRetry, "timeUnitForDelayForRetry"); - } - - @Override - public Server start() throws IOException { - checkState(started.compareAndSet(false, true), "Already started"); - currentServingState = ServingState.STARTING; - SettableFuture future = addServerWatcher(); - xdsClientWrapperForServerSds.start(); - try { - Throwable throwable = future.get(); - if (throwable != null) { - removeServerWatcher(); - if (throwable instanceof IOException) { - throw (IOException) throwable; - } - throw new IOException(throwable); - } - } catch (InterruptedException | ExecutionException ex) { - removeServerWatcher(); - if (ex instanceof InterruptedException) { - Thread.currentThread().interrupt(); - } - throw new RuntimeException(ex); - } - return this; - } - - @VisibleForTesting public ServingState getCurrentServingState() { - return currentServingState; - } - - private SettableFuture addServerWatcher() { - final SettableFuture future = SettableFuture.create(); - serverWatcher = - new XdsClientWrapperForServerSds.ServerWatcher() { - @Override - public void onError(Throwable throwable, boolean isAbsent) { - synchronized (ServerWrapperForXds.this) { - if (currentServingState == ServingState.SHUTDOWN) { - return; - } else if (currentServingState == ServingState.STARTING) { - // during start - if (isPermanentErrorFromXds(throwable)) { - currentServingState = ServingState.SHUTDOWN; - future.set(throwable); - return; - } - xdsServingStatusListener.onNotServing(throwable); - } else { - // is one of STARTED, NOT_SERVING or ENTER_SERVING - if (isAbsent) { - xdsServingStatusListener.onNotServing(throwable); - if (currentServingState == ServingState.STARTED) { - // shutdown the server - delegate.shutdown(); // let existing calls finish on delegate - currentServingState = ServingState.NOT_SERVING; - } - } - } - } - } - - @Override - public void onListenerUpdate() { - synchronized (ServerWrapperForXds.this) { - if (currentServingState == ServingState.SHUTDOWN) { - return; - } else if (currentServingState == ServingState.STARTING) { - // during start() - try { - delegate.start(); - currentServingState = ServingState.STARTED; - xdsServingStatusListener.onServing(); - future.set(null); - } catch (IOException ioe) { - future.set(ioe); - } - } else if (currentServingState == ServingState.NOT_SERVING) { - // coming out of NOT_SERVING - currentServingState = ServingState.ENTER_SERVING; - startRetryTask = new StartRetryTask(); - startRetryTask.createTask(0L); - } - } - } - }; - xdsClientWrapperForServerSds.addServerWatcher(serverWatcher); - return future; - } - - private final class StartRetryTask implements Runnable { - - ScheduledExecutorService timerService; - AtomicReference> future = new AtomicReference<>(); - - private void createTask(long delay) { - if (timerService == null) { - timerService = SharedResourceHolder.get(GrpcUtil.TIMER_SERVICE); - } - future.set(timerService.schedule(this, delay, timeUnitForDelayForRetry)); - } - - private void rebuildAndRestartServer() { - delegate = delegateBuilder.build(); - try { - delegate = delegate.start(); - currentServingState = ServingState.STARTED; - xdsServingStatusListener.onServing(); - cleanUpStartRetryTask(); - } catch (IOException ioe) { - xdsServingStatusListener.onNotServing(ioe); - if (isRetriableErrorInDelegateStart(ioe)) { - createTask(delayForRetry); - } else { - // permanent failure - currentServingState = ServingState.SHUTDOWN; - cleanUpStartRetryTask(); - } - } - } - - @Override - public void run() { - if (currentServingState == ServingState.SHUTDOWN) { - return; - } else if (currentServingState != ServingState.ENTER_SERVING) { - throw new AssertionError("Wrong state:" + currentServingState); - } - rebuildAndRestartServer(); - } - - private void cleanUpStartRetryTask() { - synchronized (ServerWrapperForXds.this) { - if (timerService != null) { - timerService = SharedResourceHolder.release(GrpcUtil.TIMER_SERVICE, timerService); - } - startRetryTask = null; - } - } - - public void shutdownNow() { - ScheduledFuture oldValue = future.getAndSet(null); - if (oldValue != null) { - oldValue.cancel(true); - } - cleanUpStartRetryTask(); - } - } - - private void removeServerWatcher() { - synchronized (xdsClientWrapperForServerSds) { - if (serverWatcher != null) { - xdsClientWrapperForServerSds.removeServerWatcher(serverWatcher); - serverWatcher = null; - } - } - } - - // if the IOException indicates we can rebuild delegate and retry start... - private static boolean isRetriableErrorInDelegateStart(IOException ioe) { - if (ioe instanceof BindException) { - return true; - } - Throwable cause = ioe.getCause(); - return cause instanceof BindException; - } - - // if the Throwable indicates a permanent error in xDS processing - private static boolean isPermanentErrorFromXds(Throwable throwable) { - if (throwable instanceof XdsInitializationException) { - return true; - } - Status.Code code = Status.fromThrowable(throwable).getCode(); - return EnumSet.of( - Status.Code.INTERNAL, - Status.Code.INVALID_ARGUMENT, - Status.Code.FAILED_PRECONDITION, - Status.Code.PERMISSION_DENIED, - Status.Code.UNAUTHENTICATED) - .contains(code); - } - - private void cleanupStartRetryTaskAndShutdownDelegateAndXdsClient(boolean shutdownNow) { - Server delegateCopy = null; - synchronized (ServerWrapperForXds.this) { - if (startRetryTask != null) { - startRetryTask.shutdownNow(); - } - currentServingState = ServingState.SHUTDOWN; - if (delegate != null && !delegate.isShutdown()) { - delegateCopy = delegate; - } - } - if (delegateCopy != null) { - if (shutdownNow) { - delegateCopy.shutdownNow(); - } else { - delegateCopy.shutdown(); - } - } - xdsClientWrapperForServerSds.shutdown(); - } - - @Override - public Server shutdown() { - cleanupStartRetryTaskAndShutdownDelegateAndXdsClient(false); - return this; - } - - @Override - public Server shutdownNow() { - cleanupStartRetryTaskAndShutdownDelegateAndXdsClient(true); - return this; - } - - @Override - public boolean isShutdown() { - return delegate.isShutdown(); - } - - @Override - public boolean isTerminated() { - return delegate.isTerminated(); - } - - @Override - public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException { - return delegate.awaitTermination(timeout, unit); - } - - @Override - public void awaitTermination() throws InterruptedException { - delegate.awaitTermination(); - } - - @Override - public int getPort() { - return delegate.getPort(); - } - - @Override - public List getListenSockets() { - return delegate.getListenSockets(); - } - - @Override - public List getServices() { - return delegate.getServices(); - } - - @Override - public List getImmutableServices() { - return delegate.getImmutableServices(); - } - - @Override - public List getMutableServices() { - return delegate.getMutableServices(); - } -} diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProviderSupplier.java b/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProviderSupplier.java index 3300c22b2bf..664b4881bc2 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProviderSupplier.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProviderSupplier.java @@ -128,15 +128,13 @@ public boolean equals(Object o) { return false; } SslContextProviderSupplier that = (SslContextProviderSupplier) o; - return shutdown == that.shutdown - && Objects.equals(tlsContext, that.tlsContext) - && Objects.equals(tlsContextManager, that.tlsContextManager) - && Objects.equals(sslContextProvider, that.sslContextProvider); + return Objects.equals(tlsContext, that.tlsContext) + && Objects.equals(tlsContextManager, that.tlsContextManager); } @Override public int hashCode() { - return Objects.hash(tlsContext, tlsContextManager, sslContextProvider, shutdown); + return Objects.hash(tlsContext, tlsContextManager); } @Override diff --git a/xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java b/xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java index 4b12ebd71c8..17ca907da42 100644 --- a/xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java +++ b/xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java @@ -19,9 +19,9 @@ import static com.google.common.truth.Truth.assertThat; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import com.google.protobuf.Any; import com.google.protobuf.BoolValue; -import com.google.protobuf.Duration; import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.StringValue; import com.google.protobuf.UInt32Value; @@ -46,7 +46,6 @@ import io.envoyproxy.envoy.config.core.v3.TrafficDirection; import io.envoyproxy.envoy.config.core.v3.TransportSocket; import io.envoyproxy.envoy.config.core.v3.TypedExtensionConfig; -import io.envoyproxy.envoy.config.core.v3.WatchedDirectory; import io.envoyproxy.envoy.config.endpoint.v3.Endpoint; import io.envoyproxy.envoy.config.listener.v3.Filter; import io.envoyproxy.envoy.config.listener.v3.FilterChain; @@ -78,6 +77,7 @@ import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager; import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.HttpFilter; import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.Rds; +import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateProviderPluginInstance; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.CertificateProviderInstance; @@ -86,7 +86,6 @@ import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.SdsSecretConfig; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.TlsCertificate; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.TlsParameters; -import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.TlsSessionTicketKeys; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext; import io.envoyproxy.envoy.type.matcher.v3.RegexMatchAndSubstitute; import io.envoyproxy.envoy.type.matcher.v3.RegexMatcher; @@ -130,18 +129,23 @@ public class ClientXdsClientDataTest { @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 @Rule public final ExpectedException thrown = ExpectedException.none(); - private final FilterRegistry filterRegistry = FilterRegistry.newRegistry(); + private final FilterRegistry filterRegistry = FilterRegistry.getDefaultRegistry(); private boolean originalEnableRetry; + private boolean originalEnableRbac; @Before public void setUp() { originalEnableRetry = ClientXdsClient.enableRetry; - assertThat(originalEnableRetry).isFalse(); + assertThat(originalEnableRetry).isTrue(); + originalEnableRbac = ClientXdsClient.enableRbac; + assertThat(originalEnableRbac).isFalse(); + ClientXdsClient.enableRbac = true; } @After public void tearDown() { ClientXdsClient.enableRetry = originalEnableRetry; + ClientXdsClient.enableRbac = originalEnableRbac; } @Test @@ -258,6 +262,7 @@ public void parseRoute_skipRouteWithUnsupportedAction() { } @Test + @SuppressWarnings("deprecation") public void parseRouteMatch_withHeaderMatcher() { io.envoyproxy.envoy.config.route.v3.RouteMatch proto = io.envoyproxy.envoy.config.route.v3.RouteMatch.newBuilder() @@ -338,6 +343,7 @@ public void parsePathMatcher_withSafeRegEx() { } @Test + @SuppressWarnings("deprecation") public void parseHeaderMatcher_withExactMatch() { io.envoyproxy.envoy.config.route.v3.HeaderMatcher proto = io.envoyproxy.envoy.config.route.v3.HeaderMatcher.newBuilder() @@ -351,6 +357,7 @@ public void parseHeaderMatcher_withExactMatch() { } @Test + @SuppressWarnings("deprecation") public void parseHeaderMatcher_withSafeRegExMatch() { io.envoyproxy.envoy.config.route.v3.HeaderMatcher proto = io.envoyproxy.envoy.config.route.v3.HeaderMatcher.newBuilder() @@ -390,6 +397,7 @@ public void parseHeaderMatcher_withPresentMatch() { } @Test + @SuppressWarnings("deprecation") public void parseHeaderMatcher_withPrefixMatch() { io.envoyproxy.envoy.config.route.v3.HeaderMatcher proto = io.envoyproxy.envoy.config.route.v3.HeaderMatcher.newBuilder() @@ -403,6 +411,7 @@ public void parseHeaderMatcher_withPrefixMatch() { } @Test + @SuppressWarnings("deprecation") public void parseHeaderMatcher_withSuffixMatch() { io.envoyproxy.envoy.config.route.v3.HeaderMatcher proto = io.envoyproxy.envoy.config.route.v3.HeaderMatcher.newBuilder() @@ -416,6 +425,7 @@ public void parseHeaderMatcher_withSuffixMatch() { } @Test + @SuppressWarnings("deprecation") public void parseHeaderMatcher_malformedRegExPattern() { io.envoyproxy.envoy.config.route.v3.HeaderMatcher proto = io.envoyproxy.envoy.config.route.v3.HeaderMatcher.newBuilder() @@ -547,7 +557,8 @@ public void parseRouteAction_withRetryPolicy() { .setRetryPolicy(builder.build()) .build(); struct = ClientXdsClient.parseRouteAction(proto, filterRegistry, false); - assertThat(struct.getStruct().retryPolicy()).isNull(); + assertThat(struct.getStruct().retryPolicy()).isNotNull(); + assertThat(struct.getStruct().retryPolicy().retryableStatusCodes()).isEmpty(); // base_interval unset builder @@ -646,7 +657,8 @@ public void parseRouteAction_withRetryPolicy() { .setRetryPolicy(builder) .build(); struct = ClientXdsClient.parseRouteAction(proto, filterRegistry, false); - assertThat(struct.getStruct().retryPolicy()).isNull(); + assertThat(struct.getStruct().retryPolicy().retryableStatusCodes()) + .containsExactly(Code.CANCELLED); // unsupported retry_on code builder = RetryPolicy.newBuilder() @@ -662,7 +674,25 @@ public void parseRouteAction_withRetryPolicy() { .setRetryPolicy(builder) .build(); struct = ClientXdsClient.parseRouteAction(proto, filterRegistry, false); - assertThat(struct.getStruct().retryPolicy()).isNull(); + assertThat(struct.getStruct().retryPolicy().retryableStatusCodes()) + .containsExactly(Code.CANCELLED); + + // whitespace in retry_on + builder = RetryPolicy.newBuilder() + .setNumRetries(UInt32Value.of(3)) + .setRetryBackOff( + RetryBackOff.newBuilder() + .setBaseInterval(Durations.fromMillis(500)) + .setMaxInterval(Durations.fromMillis(600))) + .setPerTryTimeout(Durations.fromMillis(300)) + .setRetryOn("abort, , cancelled , "); + proto = io.envoyproxy.envoy.config.route.v3.RouteAction.newBuilder() + .setCluster("cluster-foo") + .setRetryPolicy(builder) + .build(); + struct = ClientXdsClient.parseRouteAction(proto, filterRegistry, false); + assertThat(struct.getStruct().retryPolicy().retryableStatusCodes()) + .containsExactly(Code.CANCELLED); } @Test @@ -1083,6 +1113,19 @@ public void parseHttpConnectionManager_xffNumTrustedHopsUnsupported() hcm, new HashSet(), filterRegistry, false /* does not matter */, true /* does not matter */); } + + @Test + public void parseHttpConnectionManager_OriginalIpDetectionExtensionsMustEmpty() + throws ResourceInvalidException { + @SuppressWarnings("deprecation") + HttpConnectionManager hcm = HttpConnectionManager.newBuilder() + .addOriginalIpDetectionExtensions(TypedExtensionConfig.newBuilder().build()) + .build(); + thrown.expect(ResourceInvalidException.class); + thrown.expectMessage("HttpConnectionManager with original_ip_detection_extensions unsupported"); + ClientXdsClient.parseHttpConnectionManager( + hcm, new HashSet(), filterRegistry, false /* does not matter */, false); + } @Test public void parseHttpConnectionManager_missingRdsAndInlinedRouteConfiguration() @@ -1108,6 +1151,9 @@ public void parseHttpConnectionManager_duplicateHttpFilters() throws ResourceInv HttpFilter.newBuilder().setName("envoy.filter.foo").setIsOptional(true)) .addHttpFilters( HttpFilter.newBuilder().setName("envoy.filter.foo").setIsOptional(true)) + .addHttpFilters( + HttpFilter.newBuilder().setName("terminal").setTypedConfig( + Any.pack(Router.newBuilder().build())).setIsOptional(true)) .build(); thrown.expect(ResourceInvalidException.class); thrown.expectMessage("HttpConnectionManager contains duplicate HttpFilter: envoy.filter.foo"); @@ -1116,6 +1162,70 @@ public void parseHttpConnectionManager_duplicateHttpFilters() throws ResourceInv true /* does not matter */); } + @Test + public void parseHttpConnectionManager_lastNotTerminal() throws ResourceInvalidException { + filterRegistry.register(FaultFilter.INSTANCE); + HttpConnectionManager hcm = + HttpConnectionManager.newBuilder() + .addHttpFilters( + HttpFilter.newBuilder().setName("envoy.filter.foo").setIsOptional(true)) + .addHttpFilters( + HttpFilter.newBuilder().setName("envoy.filter.bar").setIsOptional(true) + .setTypedConfig(Any.pack(HTTPFault.newBuilder().build()))) + .build(); + thrown.expect(ResourceInvalidException.class); + thrown.expectMessage("The last HttpFilter must be a terminal filter: envoy.filter.bar"); + ClientXdsClient.parseHttpConnectionManager( + hcm, new HashSet(), filterRegistry, true /* parseHttpFilter */, + true /* does not matter */); + } + + @Test + public void parseHttpConnectionManager_terminalNotLast() throws ResourceInvalidException { + filterRegistry.register(RouterFilter.INSTANCE); + HttpConnectionManager hcm = + HttpConnectionManager.newBuilder() + .addHttpFilters( + HttpFilter.newBuilder().setName("terminal").setTypedConfig( + Any.pack(Router.newBuilder().build())).setIsOptional(true)) + .addHttpFilters( + HttpFilter.newBuilder().setName("envoy.filter.foo").setIsOptional(true)) + .build(); + thrown.expect(ResourceInvalidException.class); + thrown.expectMessage("A terminal HttpFilter must be the last filter: terminal"); + ClientXdsClient.parseHttpConnectionManager( + hcm, new HashSet(), filterRegistry, true /* parseHttpFilter */, + true); + } + + @Test + public void parseHttpConnectionManager_unknownFilters() throws ResourceInvalidException { + HttpConnectionManager hcm = + HttpConnectionManager.newBuilder() + .addHttpFilters( + HttpFilter.newBuilder().setName("envoy.filter.foo").setIsOptional(true)) + .addHttpFilters( + HttpFilter.newBuilder().setName("envoy.filter.bar").setIsOptional(true)) + .build(); + thrown.expect(ResourceInvalidException.class); + thrown.expectMessage("The last HttpFilter must be a terminal filter: envoy.filter.bar"); + ClientXdsClient.parseHttpConnectionManager( + hcm, new HashSet(), filterRegistry, true /* parseHttpFilter */, + true /* does not matter */); + } + + @Test + public void parseHttpConnectionManager_emptyFilters() throws ResourceInvalidException { + HttpConnectionManager hcm = + HttpConnectionManager.newBuilder() + .build(); + thrown.expect(ResourceInvalidException.class); + thrown.expectMessage("Missing HttpFilter in HttpConnectionManager."); + ClientXdsClient.parseHttpConnectionManager( + hcm, new HashSet(), filterRegistry, true /* parseHttpFilter */, + true /* does not matter */); + } + @Test public void parseCluster_ringHashLbPolicy_defaultLbConfig() throws ResourceInvalidException { Cluster cluster = Cluster.newBuilder() @@ -1130,7 +1240,7 @@ public void parseCluster_ringHashLbPolicy_defaultLbConfig() throws ResourceInval .setLbPolicy(LbPolicy.RING_HASH) .build(); - CdsUpdate update = ClientXdsClient.parseCluster(cluster, new HashSet()); + CdsUpdate update = ClientXdsClient.parseCluster(cluster, new HashSet(), null); assertThat(update.lbPolicy()).isEqualTo(CdsUpdate.LbPolicy.RING_HASH); assertThat(update.minRingSize()) .isEqualTo(ClientXdsClient.DEFAULT_RING_HASH_LB_POLICY_MIN_RING_SIZE); @@ -1138,6 +1248,28 @@ public void parseCluster_ringHashLbPolicy_defaultLbConfig() throws ResourceInval .isEqualTo(ClientXdsClient.DEFAULT_RING_HASH_LB_POLICY_MAX_RING_SIZE); } + @Test + public void parseCluster_transportSocketMatches_exception() throws ResourceInvalidException { + Cluster cluster = Cluster.newBuilder() + .setName("cluster-foo.googleapis.com") + .setType(DiscoveryType.EDS) + .setEdsClusterConfig( + EdsClusterConfig.newBuilder() + .setEdsConfig( + ConfigSource.newBuilder() + .setAds(AggregatedConfigSource.getDefaultInstance())) + .setServiceName("service-foo.googleapis.com")) + .setLbPolicy(LbPolicy.ROUND_ROBIN) + .addTransportSocketMatches( + Cluster.TransportSocketMatch.newBuilder().setName("match1").build()) + .build(); + + thrown.expect(ResourceInvalidException.class); + thrown.expectMessage( + "Cluster cluster-foo.googleapis.com: transport-socket-matches not supported."); + ClientXdsClient.parseCluster(cluster, new HashSet(), null); + } + @Test public void parseCluster_ringHashLbPolicy_invalidRingSizeConfig_minGreaterThanMax() throws ResourceInvalidException { @@ -1160,7 +1292,7 @@ public void parseCluster_ringHashLbPolicy_invalidRingSizeConfig_minGreaterThanMa thrown.expect(ResourceInvalidException.class); thrown.expectMessage("Cluster cluster-foo.googleapis.com: invalid ring_hash_lb_config"); - ClientXdsClient.parseCluster(cluster, new HashSet()); + ClientXdsClient.parseCluster(cluster, new HashSet(), null); } @Test @@ -1187,7 +1319,7 @@ public void parseCluster_ringHashLbPolicy_invalidRingSizeConfig_tooLargeRingSize thrown.expect(ResourceInvalidException.class); thrown.expectMessage("Cluster cluster-foo.googleapis.com: invalid ring_hash_lb_config"); - ClientXdsClient.parseCluster(cluster, new HashSet()); + ClientXdsClient.parseCluster(cluster, new HashSet(), null); } @Test @@ -1200,7 +1332,7 @@ public void parseServerSideListener_invalidTrafficDirection() throws ResourceInv thrown.expect(ResourceInvalidException.class); thrown.expectMessage("Listener listener1 with invalid traffic direction: OUTBOUND"); ClientXdsClient.parseServerSideListener( - listener, new HashSet(), null, filterRegistry, true /* does not matter */); + listener, new HashSet(), null, filterRegistry, null, true /* does not matter */); } @Test @@ -1214,7 +1346,7 @@ public void parseServerSideListener_listenerFiltersPresent() throws ResourceInva thrown.expect(ResourceInvalidException.class); thrown.expectMessage("Listener listener1 cannot have listener_filters"); ClientXdsClient.parseServerSideListener( - listener, new HashSet(), null, filterRegistry, true /* does not matter */); + listener, new HashSet(), null, filterRegistry, null, true /* does not matter */); } @Test @@ -1228,13 +1360,14 @@ public void parseServerSideListener_useOriginalDst() throws ResourceInvalidExcep thrown.expect(ResourceInvalidException.class); thrown.expectMessage("Listener listener1 cannot have use_original_dst set to true"); ClientXdsClient.parseServerSideListener( - listener, new HashSet(), null, filterRegistry, true /* does not matter */); + listener, new HashSet(), null, filterRegistry, null, true /* does not matter */); } @Test public void parseServerSideListener_nonUniqueFilterChainMatch() throws ResourceInvalidException { Filter filter1 = buildHttpConnectionManagerFilter( - HttpFilter.newBuilder().setName("http-filter-1").setIsOptional(true).build()); + HttpFilter.newBuilder().setName("http-filter-1").setTypedConfig( + Any.pack(Router.newBuilder().build())).setIsOptional(true).build()); FilterChainMatch filterChainMatch1 = FilterChainMatch.newBuilder() .addAllSourcePorts(Arrays.asList(80, 8080)) @@ -1250,7 +1383,8 @@ public void parseServerSideListener_nonUniqueFilterChainMatch() throws ResourceI .addFilters(filter1) .build(); Filter filter2 = buildHttpConnectionManagerFilter( - HttpFilter.newBuilder().setName("http-filter-2").setIsOptional(true).build()); + HttpFilter.newBuilder().setName("http-filter-2").setTypedConfig( + Any.pack(Router.newBuilder().build())).setIsOptional(true).build()); FilterChainMatch filterChainMatch2 = FilterChainMatch.newBuilder() .addAllSourcePorts(Arrays.asList(443, 8080)) @@ -1275,14 +1409,15 @@ public void parseServerSideListener_nonUniqueFilterChainMatch() throws ResourceI thrown.expect(ResourceInvalidException.class); thrown.expectMessage("Found duplicate matcher:"); ClientXdsClient.parseServerSideListener( - listener, new HashSet(), null, filterRegistry, true /* does not matter */); + listener, new HashSet(), null, filterRegistry, null, true /* does not matter */); } @Test public void parseServerSideListener_nonUniqueFilterChainMatch_sameFilter() throws ResourceInvalidException { Filter filter1 = buildHttpConnectionManagerFilter( - HttpFilter.newBuilder().setName("http-filter-1").setIsOptional(true).build()); + HttpFilter.newBuilder().setName("http-filter-1").setTypedConfig( + Any.pack(Router.newBuilder().build())).setIsOptional(true).build()); FilterChainMatch filterChainMatch1 = FilterChainMatch.newBuilder() .addAllSourcePorts(Arrays.asList(80, 8080)) @@ -1297,7 +1432,8 @@ public void parseServerSideListener_nonUniqueFilterChainMatch_sameFilter() .addFilters(filter1) .build(); Filter filter2 = buildHttpConnectionManagerFilter( - HttpFilter.newBuilder().setName("http-filter-2").setIsOptional(true).build()); + HttpFilter.newBuilder().setName("http-filter-2").setTypedConfig( + Any.pack(Router.newBuilder().build())).setIsOptional(true).build()); FilterChainMatch filterChainMatch2 = FilterChainMatch.newBuilder() .addAllSourcePorts(Arrays.asList(443, 8080)) @@ -1322,13 +1458,14 @@ public void parseServerSideListener_nonUniqueFilterChainMatch_sameFilter() thrown.expect(ResourceInvalidException.class); thrown.expectMessage("Found duplicate matcher:"); ClientXdsClient.parseServerSideListener( - listener, new HashSet(), null, filterRegistry, true /* does not matter */); + listener, new HashSet(), null, filterRegistry, null, true /* does not matter */); } @Test public void parseServerSideListener_uniqueFilterChainMatch() throws ResourceInvalidException { Filter filter1 = buildHttpConnectionManagerFilter( - HttpFilter.newBuilder().setName("http-filter-1").setIsOptional(true).build()); + HttpFilter.newBuilder().setName("http-filter-1").setTypedConfig( + Any.pack(Router.newBuilder().build())).setIsOptional(true).build()); FilterChainMatch filterChainMatch1 = FilterChainMatch.newBuilder() .addAllSourcePorts(Arrays.asList(80, 8080)) @@ -1345,7 +1482,8 @@ public void parseServerSideListener_uniqueFilterChainMatch() throws ResourceInva .addFilters(filter1) .build(); Filter filter2 = buildHttpConnectionManagerFilter( - HttpFilter.newBuilder().setName("http-filter-2").setIsOptional(true).build()); + HttpFilter.newBuilder().setName("http-filter-2").setTypedConfig( + Any.pack(Router.newBuilder().build())).setIsOptional(true).build()); FilterChainMatch filterChainMatch2 = FilterChainMatch.newBuilder() .addAllSourcePorts(Arrays.asList(443, 8080)) @@ -1369,7 +1507,7 @@ public void parseServerSideListener_uniqueFilterChainMatch() throws ResourceInva .addAllFilterChains(Arrays.asList(filterChain1, filterChain2)) .build(); ClientXdsClient.parseServerSideListener( - listener, new HashSet(), null, filterRegistry, true /* does not matter */); + listener, new HashSet(), null, filterRegistry, null, true /* does not matter */); } @Test @@ -1382,9 +1520,10 @@ public void parseFilterChain_noHcm() throws ResourceInvalidException { .build(); thrown.expect(ResourceInvalidException.class); thrown.expectMessage( - "FilterChain filter-chain-foo missing required HttpConnectionManager filter"); + "FilterChain filter-chain-foo should contain exact one HttpConnectionManager filter"); ClientXdsClient.parseFilterChain( - filterChain, new HashSet(), null, filterRegistry, null, true /* does not matter */); + filterChain, new HashSet(), null, filterRegistry, null, null, + true /* does not matter */); } @Test @@ -1400,9 +1539,10 @@ public void parseFilterChain_duplicateFilter() throws ResourceInvalidException { .build(); thrown.expect(ResourceInvalidException.class); thrown.expectMessage( - "FilterChain filter-chain-foo with duplicated filter: envoy.http_connection_manager"); + "FilterChain filter-chain-foo should contain exact one HttpConnectionManager filter"); ClientXdsClient.parseFilterChain( - filterChain, new HashSet(), null, filterRegistry, null, true /* does not matter */); + filterChain, new HashSet(), null, filterRegistry, null, null, + true /* does not matter */); } @Test @@ -1420,7 +1560,8 @@ public void parseFilterChain_filterMissingTypedConfig() throws ResourceInvalidEx "FilterChain filter-chain-foo contains filter envoy.http_connection_manager " + "without typed_config"); ClientXdsClient.parseFilterChain( - filterChain, new HashSet(), null, filterRegistry, null, true /* does not matter */); + filterChain, new HashSet(), null, filterRegistry, null, null, + true /* does not matter */); } @Test @@ -1442,7 +1583,8 @@ public void parseFilterChain_unsupportedFilter() throws ResourceInvalidException "FilterChain filter-chain-foo contains filter unsupported with unsupported " + "typed_config type unsupported-type-url"); ClientXdsClient.parseFilterChain( - filterChain, new HashSet(), null, filterRegistry, null, true /* does not matter */); + filterChain, new HashSet(), null, filterRegistry, null, null, + true /* does not matter */); } @Test @@ -1454,6 +1596,7 @@ public void parseFilterChain_noName_generatedUuid() throws ResourceInvalidExcept HttpFilter.newBuilder() .setName("http-filter-foo") .setIsOptional(true) + .setTypedConfig(Any.pack(Router.newBuilder().build())) .build())) .build(); FilterChain filterChain2 = @@ -1462,20 +1605,20 @@ public void parseFilterChain_noName_generatedUuid() throws ResourceInvalidExcept .addFilters(buildHttpConnectionManagerFilter( HttpFilter.newBuilder() .setName("http-filter-bar") + .setTypedConfig(Any.pack(Router.newBuilder().build())) .setIsOptional(true) .build())) .build(); EnvoyServerProtoData.FilterChain parsedFilterChain1 = ClientXdsClient.parseFilterChain( filterChain1, new HashSet(), null, filterRegistry, null, - true /* does not matter */); + null, true /* does not matter */); EnvoyServerProtoData.FilterChain parsedFilterChain2 = ClientXdsClient.parseFilterChain( filterChain2, new HashSet(), null, filterRegistry, null, - true /* does not matter */); + null, true /* does not matter */); assertThat(parsedFilterChain1.getName()).isNotEqualTo(parsedFilterChain2.getName()); } - @Test public void validateCommonTlsContext_tlsParams() throws ResourceInvalidException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() @@ -1483,7 +1626,7 @@ public void validateCommonTlsContext_tlsParams() throws ResourceInvalidException .build(); thrown.expect(ResourceInvalidException.class); thrown.expectMessage("common-tls-context with tls_params is not supported"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, false); + ClientXdsClient.validateCommonTlsContext(commonTlsContext, null, false); } @Test @@ -1493,7 +1636,7 @@ public void validateCommonTlsContext_customHandshaker() throws ResourceInvalidEx .build(); thrown.expect(ResourceInvalidException.class); thrown.expectMessage("common-tls-context with custom_handshaker is not supported"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, false); + ClientXdsClient.validateCommonTlsContext(commonTlsContext, null, false); } @Test @@ -1502,8 +1645,8 @@ public void validateCommonTlsContext_validationContext() throws ResourceInvalidE .setValidationContext(CertificateValidationContext.getDefaultInstance()) .build(); thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("common-tls-context with validation_context is not supported"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, false); + thrown.expectMessage("ca_certificate_provider_instance is required in upstream-tls-context"); + ClientXdsClient.validateCommonTlsContext(commonTlsContext, null, false); } @Test @@ -1515,10 +1658,11 @@ public void validateCommonTlsContext_validationContextSdsSecretConfig() thrown.expect(ResourceInvalidException.class); thrown.expectMessage( "common-tls-context with validation_context_sds_secret_config is not supported"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, false); + ClientXdsClient.validateCommonTlsContext(commonTlsContext, null, false); } @Test + @SuppressWarnings("deprecation") public void validateCommonTlsContext_validationContextCertificateProvider() throws ResourceInvalidException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() @@ -1528,10 +1672,11 @@ public void validateCommonTlsContext_validationContextCertificateProvider() thrown.expect(ResourceInvalidException.class); thrown.expectMessage( "common-tls-context with validation_context_certificate_provider is not supported"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, false); + ClientXdsClient.validateCommonTlsContext(commonTlsContext, null, false); } @Test + @SuppressWarnings("deprecation") public void validateCommonTlsContext_validationContextCertificateProviderInstance() throws ResourceInvalidException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() @@ -1542,7 +1687,7 @@ public void validateCommonTlsContext_validationContextCertificateProviderInstanc thrown.expectMessage( "common-tls-context with validation_context_certificate_provider_instance is not " + "supported"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, false); + ClientXdsClient.validateCommonTlsContext(commonTlsContext, null, false); } @Test @@ -1552,18 +1697,91 @@ public void validateCommonTlsContext_tlsCertificateProviderInstance_isRequiredFo .build(); thrown.expect(ResourceInvalidException.class); thrown.expectMessage( - "tls_certificate_certificate_provider_instance is required in downstream-tls-context"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, true); + "tls_certificate_provider_instance is required in downstream-tls-context"); + ClientXdsClient.validateCommonTlsContext(commonTlsContext, null, true); + } + + @Test + @SuppressWarnings("deprecation") + public void validateCommonTlsContext_tlsNewCertificateProviderInstance() + throws ResourceInvalidException { + CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() + .setTlsCertificateProviderInstance( + CertificateProviderPluginInstance.newBuilder().setInstanceName("name1").build()) + .build(); + ClientXdsClient + .validateCommonTlsContext(commonTlsContext, ImmutableSet.of("name1", "name2"), true); + } + + @Test + @SuppressWarnings("deprecation") + public void validateCommonTlsContext_tlsCertificateProviderInstance() + throws ResourceInvalidException { + CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() + .setTlsCertificateCertificateProviderInstance( + CertificateProviderInstance.newBuilder().setInstanceName("name1").build()) + .build(); + ClientXdsClient + .validateCommonTlsContext(commonTlsContext, ImmutableSet.of("name1", "name2"), true); + } + + @Test + @SuppressWarnings("deprecation") + public void validateCommonTlsContext_tlsCertificateProviderInstance_absentInBootstrapFile() + throws ResourceInvalidException { + CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() + .setTlsCertificateCertificateProviderInstance( + CertificateProviderInstance.newBuilder().setInstanceName("bad-name").build()) + .build(); + thrown.expect(ResourceInvalidException.class); + thrown.expectMessage( + "CertificateProvider instance name 'bad-name' not defined in the bootstrap file."); + ClientXdsClient + .validateCommonTlsContext(commonTlsContext, ImmutableSet.of("name1", "name2"), true); } + @Test + @SuppressWarnings("deprecation") + public void validateCommonTlsContext_validationContextProviderInstance() + throws ResourceInvalidException { + CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() + .setCombinedValidationContext( + CommonTlsContext.CombinedCertificateValidationContext.newBuilder() + .setValidationContextCertificateProviderInstance( + CertificateProviderInstance.newBuilder().setInstanceName("name1").build()) + .build()) + .build(); + ClientXdsClient + .validateCommonTlsContext(commonTlsContext, ImmutableSet.of("name1", "name2"), false); + } + + @Test + @SuppressWarnings("deprecation") + public void validateCommonTlsContext_validationContextProviderInstance_absentInBootstrapFile() + throws ResourceInvalidException { + CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() + .setCombinedValidationContext( + CommonTlsContext.CombinedCertificateValidationContext.newBuilder() + .setValidationContextCertificateProviderInstance( + CertificateProviderInstance.newBuilder().setInstanceName("bad-name").build()) + .build()) + .build(); + thrown.expect(ResourceInvalidException.class); + thrown.expectMessage( + "ca_certificate_provider_instance name 'bad-name' not defined in the bootstrap file."); + ClientXdsClient + .validateCommonTlsContext(commonTlsContext, ImmutableSet.of("name1", "name2"), false); + } + + @Test public void validateCommonTlsContext_tlsCertificatesCount() throws ResourceInvalidException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() .addTlsCertificates(TlsCertificate.getDefaultInstance()) .build(); thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("common-tls-context with tls_certificates is not supported"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, false); + thrown.expectMessage("tls_certificate_provider_instance is unset"); + ClientXdsClient.validateCommonTlsContext(commonTlsContext, null, false); } @Test @@ -1574,11 +1792,12 @@ public void validateCommonTlsContext_tlsCertificateSdsSecretConfigsCount() .build(); thrown.expect(ResourceInvalidException.class); thrown.expectMessage( - "common-tls-context with tls_certificate_sds_secret_configs is not supported"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, false); + "tls_certificate_provider_instance is unset"); + ClientXdsClient.validateCommonTlsContext(commonTlsContext, null, false); } @Test + @SuppressWarnings("deprecation") public void validateCommonTlsContext_tlsCertificateCertificateProvider() throws ResourceInvalidException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() @@ -1587,8 +1806,8 @@ public void validateCommonTlsContext_tlsCertificateCertificateProvider() .build(); thrown.expect(ResourceInvalidException.class); thrown.expectMessage( - "common-tls-context with tls_certificate_certificate_provider is not supported"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, false); + "tls_certificate_provider_instance is unset"); + ClientXdsClient.validateCommonTlsContext(commonTlsContext, null, false); } @Test @@ -1597,8 +1816,8 @@ public void validateCommonTlsContext_combinedValidationContext_isRequiredForClie CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() .build(); thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("combined_validation_context is required in upstream-tls-context"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, false); + thrown.expectMessage("ca_certificate_provider_instance is required in upstream-tls-context"); + ClientXdsClient.validateCommonTlsContext(commonTlsContext, null, false); } @Test @@ -1610,12 +1829,12 @@ public void validateCommonTlsContext_combinedValidationContextWithoutCertProvide .build(); thrown.expect(ResourceInvalidException.class); thrown.expectMessage( - "validation_context_certificate_provider_instance is required in " - + "combined_validation_context"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, false); + "ca_certificate_provider_instance is required in upstream-tls-context"); + ClientXdsClient.validateCommonTlsContext(commonTlsContext, null, false); } @Test + @SuppressWarnings("deprecation") public void validateCommonTlsContext_combinedValContextWithDefaultValContextForServer() throws ResourceInvalidException, InvalidProtocolBufferException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() @@ -1631,46 +1850,11 @@ public void validateCommonTlsContext_combinedValContextWithDefaultValContextForS .build(); thrown.expect(ResourceInvalidException.class); thrown.expectMessage("match_subject_alt_names only allowed in upstream_tls_context"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, true); - } - - @Test - public void validateCommonTlsContext_combinedValContextWithDefaultValidationContextTrustedCa() - throws ResourceInvalidException { - CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() - .setCombinedValidationContext( - CommonTlsContext.CombinedCertificateValidationContext.newBuilder() - .setValidationContextCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance()) - .setDefaultValidationContext(CertificateValidationContext.newBuilder() - .setTrustedCa(DataSource.getDefaultInstance()))) - .setTlsCertificateCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance()) - .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("trusted_ca in default_validation_context is not supported"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, false); - } - - @Test - public void validateCommonTlsContext_combinedValContextWithDefaultValContextWatchedDirectory() - throws ResourceInvalidException { - CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() - .setCombinedValidationContext( - CommonTlsContext.CombinedCertificateValidationContext.newBuilder() - .setValidationContextCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance()) - .setDefaultValidationContext(CertificateValidationContext.newBuilder() - .setWatchedDirectory(WatchedDirectory.getDefaultInstance()))) - .setTlsCertificateCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance()) - .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("watched_directory in default_validation_context is not supported"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, false); + ClientXdsClient.validateCommonTlsContext(commonTlsContext, ImmutableSet.of(""), true); } @Test + @SuppressWarnings("deprecation") public void validateCommonTlsContext_combinedValContextWithDefaultValContextVerifyCertSpki() throws ResourceInvalidException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() @@ -1686,10 +1870,11 @@ public void validateCommonTlsContext_combinedValContextWithDefaultValContextVeri thrown.expect(ResourceInvalidException.class); thrown.expectMessage("verify_certificate_spki in default_validation_context is not " + "supported"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, false); + ClientXdsClient.validateCommonTlsContext(commonTlsContext, ImmutableSet.of(""), false); } @Test + @SuppressWarnings("deprecation") public void validateCommonTlsContext_combinedValContextWithDefaultValContextVerifyCertHash() throws ResourceInvalidException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() @@ -1705,10 +1890,11 @@ public void validateCommonTlsContext_combinedValContextWithDefaultValContextVeri thrown.expect(ResourceInvalidException.class); thrown.expectMessage("verify_certificate_hash in default_validation_context is not " + "supported"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, false); + ClientXdsClient.validateCommonTlsContext(commonTlsContext, ImmutableSet.of(""), false); } @Test + @SuppressWarnings("deprecation") public void validateCommonTlsContext_combinedValContextDfltValContextRequireSignedCertTimestamp() throws ResourceInvalidException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() @@ -1725,10 +1911,11 @@ public void validateCommonTlsContext_combinedValContextDfltValContextRequireSign thrown.expectMessage( "require_signed_certificate_timestamp in default_validation_context is not " + "supported"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, false); + ClientXdsClient.validateCommonTlsContext(commonTlsContext, ImmutableSet.of(""), false); } @Test + @SuppressWarnings("deprecation") public void validateCommonTlsContext_combinedValidationContextWithDefaultValidationContextCrl() throws ResourceInvalidException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() @@ -1743,49 +1930,11 @@ public void validateCommonTlsContext_combinedValidationContextWithDefaultValidat .build(); thrown.expect(ResourceInvalidException.class); thrown.expectMessage("crl in default_validation_context is not supported"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, false); - } - - @Test - public void validateCommonTlsContext_combinedValContextWithDefaultValContextAllowExpiredCert() - throws ResourceInvalidException { - CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() - .setCombinedValidationContext( - CommonTlsContext.CombinedCertificateValidationContext.newBuilder() - .setValidationContextCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance()) - .setDefaultValidationContext( - CertificateValidationContext.newBuilder().setAllowExpiredCertificate(true))) - .setTlsCertificateCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance()) - .build(); - thrown.expect(ResourceInvalidException.class); - thrown - .expectMessage("allow_expired_certificate in default_validation_context is not " - + "supported"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, false); - } - - @Test - public void validateCommonTlsContext_combinedValContextWithDfltValContextTrustChainVerification() - throws ResourceInvalidException { - CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() - .setCombinedValidationContext( - CommonTlsContext.CombinedCertificateValidationContext.newBuilder() - .setValidationContextCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance()) - .setDefaultValidationContext(CertificateValidationContext.newBuilder() - .setTrustChainVerification( - CertificateValidationContext.TrustChainVerification.ACCEPT_UNTRUSTED))) - .setTlsCertificateCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance()) - .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("Only VERIFY_TRUST_CHAIN for trust_chain_verification supported"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, false); + ClientXdsClient.validateCommonTlsContext(commonTlsContext, ImmutableSet.of(""), false); } @Test + @SuppressWarnings("deprecation") public void validateCommonTlsContext_combinedValContextWithDfltValContextCustomValidatorConfig() throws ResourceInvalidException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() @@ -1801,7 +1950,7 @@ public void validateCommonTlsContext_combinedValContextWithDfltValContextCustomV thrown.expect(ResourceInvalidException.class); thrown.expectMessage("custom_validator_config in default_validation_context is not " + "supported"); - ClientXdsClient.validateCommonTlsContext(commonTlsContext, false); + ClientXdsClient.validateCommonTlsContext(commonTlsContext, ImmutableSet.of(""), false); } @Test @@ -1809,10 +1958,11 @@ public void validateDownstreamTlsContext_noCommonTlsContext() throws ResourceInv DownstreamTlsContext downstreamTlsContext = DownstreamTlsContext.getDefaultInstance(); thrown.expect(ResourceInvalidException.class); thrown.expectMessage("common-tls-context is required in downstream-tls-context"); - ClientXdsClient.validateDownstreamTlsContext(downstreamTlsContext); + ClientXdsClient.validateDownstreamTlsContext(downstreamTlsContext, null); } @Test + @SuppressWarnings("deprecation") public void validateDownstreamTlsContext_hasRequireSni() throws ResourceInvalidException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() .setCombinedValidationContext( @@ -1828,90 +1978,11 @@ public void validateDownstreamTlsContext_hasRequireSni() throws ResourceInvalidE .build(); thrown.expect(ResourceInvalidException.class); thrown.expectMessage("downstream-tls-context with require-sni is not supported"); - ClientXdsClient.validateDownstreamTlsContext(downstreamTlsContext); - } - - @Test - public void validateDownstreamTlsContext_hasSessionTikcetKeys() throws ResourceInvalidException { - CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() - .setCombinedValidationContext( - CommonTlsContext.CombinedCertificateValidationContext.newBuilder() - .setValidationContextCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance())) - .setTlsCertificateCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance()) - .build(); - DownstreamTlsContext downstreamTlsContext = DownstreamTlsContext.newBuilder() - .setCommonTlsContext(commonTlsContext) - .setSessionTicketKeys(TlsSessionTicketKeys.getDefaultInstance()) - .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("downstream-tls-context with session_ticket_keys is not supported"); - ClientXdsClient.validateDownstreamTlsContext(downstreamTlsContext); - } - - @Test - public void validateDownstreamTlsContext_hasSessionTikcetKeysSdsSecretConfig() - throws ResourceInvalidException { - CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() - .setCombinedValidationContext( - CommonTlsContext.CombinedCertificateValidationContext.newBuilder() - .setValidationContextCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance())) - .setTlsCertificateCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance()) - .build(); - DownstreamTlsContext downstreamTlsContext = DownstreamTlsContext.newBuilder() - .setCommonTlsContext(commonTlsContext) - .setSessionTicketKeysSdsSecretConfig(SdsSecretConfig.getDefaultInstance()) - .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage( - "downstream-tls-context with session_ticket_keys_sds_secret_config is not supported"); - ClientXdsClient.validateDownstreamTlsContext(downstreamTlsContext); - } - - @Test - public void validateDownstreamTlsContext_hasDisableStatelessSessionResumption() - throws ResourceInvalidException { - CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() - .setCombinedValidationContext( - CommonTlsContext.CombinedCertificateValidationContext.newBuilder() - .setValidationContextCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance())) - .setTlsCertificateCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance()) - .build(); - DownstreamTlsContext downstreamTlsContext = DownstreamTlsContext.newBuilder() - .setCommonTlsContext(commonTlsContext) - .setDisableStatelessSessionResumption(true) - .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage( - "downstream-tls-context with disable_stateless_session_resumption is not supported"); - ClientXdsClient.validateDownstreamTlsContext(downstreamTlsContext); - } - - @Test - public void validateDownstreamTlsContext_hasSessionTimeout() throws ResourceInvalidException { - CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() - .setCombinedValidationContext( - CommonTlsContext.CombinedCertificateValidationContext.newBuilder() - .setValidationContextCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance())) - .setTlsCertificateCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance()) - .build(); - DownstreamTlsContext downstreamTlsContext = DownstreamTlsContext.newBuilder() - .setCommonTlsContext(commonTlsContext) - .setSessionTimeout(Duration.getDefaultInstance()) - .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("downstream-tls-context with session_timeout is not supported"); - ClientXdsClient.validateDownstreamTlsContext(downstreamTlsContext); + ClientXdsClient.validateDownstreamTlsContext(downstreamTlsContext, ImmutableSet.of("")); } @Test + @SuppressWarnings("deprecation") public void validateDownstreamTlsContext_hasOcspStaplePolicy() throws ResourceInvalidException { CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() .setCombinedValidationContext( @@ -1928,7 +1999,7 @@ public void validateDownstreamTlsContext_hasOcspStaplePolicy() throws ResourceIn thrown.expect(ResourceInvalidException.class); thrown.expectMessage( "downstream-tls-context with ocsp_staple_policy value STRICT_STAPLING is not supported"); - ClientXdsClient.validateDownstreamTlsContext(downstreamTlsContext); + ClientXdsClient.validateDownstreamTlsContext(downstreamTlsContext, ImmutableSet.of("")); } @Test @@ -1936,58 +2007,7 @@ public void validateUpstreamTlsContext_noCommonTlsContext() throws ResourceInval UpstreamTlsContext upstreamTlsContext = UpstreamTlsContext.getDefaultInstance(); thrown.expect(ResourceInvalidException.class); thrown.expectMessage("common-tls-context is required in upstream-tls-context"); - ClientXdsClient.validateUpstreamTlsContext(upstreamTlsContext); - } - - @Test - public void validateUpstreamTlsContext_sni() throws ResourceInvalidException { - CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() - .setCombinedValidationContext( - CommonTlsContext.CombinedCertificateValidationContext.newBuilder() - .setValidationContextCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance())) - .build(); - UpstreamTlsContext upstreamTlsContext = UpstreamTlsContext.newBuilder() - .setCommonTlsContext(commonTlsContext) - .setSni("foo") - .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("upstream-tls-context with sni is not supported"); - ClientXdsClient.validateUpstreamTlsContext(upstreamTlsContext); - } - - @Test - public void validateUpstreamTlsContext_allowRenegotiation() throws ResourceInvalidException { - CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() - .setCombinedValidationContext( - CommonTlsContext.CombinedCertificateValidationContext.newBuilder() - .setValidationContextCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance())) - .build(); - UpstreamTlsContext upstreamTlsContext = UpstreamTlsContext.newBuilder() - .setCommonTlsContext(commonTlsContext) - .setAllowRenegotiation(true) - .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("upstream-tls-context with allow_renegotiation is not supported"); - ClientXdsClient.validateUpstreamTlsContext(upstreamTlsContext); - } - - @Test - public void validateUpstreamTlsContext_maxSessionKeys() throws ResourceInvalidException { - CommonTlsContext commonTlsContext = CommonTlsContext.newBuilder() - .setCombinedValidationContext( - CommonTlsContext.CombinedCertificateValidationContext.newBuilder() - .setValidationContextCertificateProviderInstance( - CommonTlsContext.CertificateProviderInstance.getDefaultInstance())) - .build(); - UpstreamTlsContext upstreamTlsContext = UpstreamTlsContext.newBuilder() - .setCommonTlsContext(commonTlsContext) - .setMaxSessionKeys(UInt32Value.getDefaultInstance()) - .build(); - thrown.expect(ResourceInvalidException.class); - thrown.expectMessage("upstream-tls-context with max_session_keys is not supported"); - ClientXdsClient.validateUpstreamTlsContext(upstreamTlsContext); + ClientXdsClient.validateUpstreamTlsContext(upstreamTlsContext, null); } private static Filter buildHttpConnectionManagerFilter(HttpFilter... httpFilters) { diff --git a/xds/src/test/java/io/grpc/xds/ClientXdsClientTestBase.java b/xds/src/test/java/io/grpc/xds/ClientXdsClientTestBase.java index 913eb208ad7..3a9ab23aa74 100644 --- a/xds/src/test/java/io/grpc/xds/ClientXdsClientTestBase.java +++ b/xds/src/test/java/io/grpc/xds/ClientXdsClientTestBase.java @@ -39,6 +39,8 @@ import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.Message; import io.envoyproxy.envoy.config.route.v3.FilterConfig; +import io.envoyproxy.envoy.extensions.filters.http.router.v3.Router; +import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateProviderPluginInstance; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; import io.grpc.BindableService; import io.grpc.Context; @@ -56,6 +58,7 @@ import io.grpc.internal.TimeProvider; import io.grpc.testing.GrpcCleanupRule; import io.grpc.xds.AbstractXdsClient.ResourceType; +import io.grpc.xds.Bootstrapper.CertificateProviderInfo; import io.grpc.xds.Endpoints.DropOverload; import io.grpc.xds.Endpoints.LbEndpoint; import io.grpc.xds.Endpoints.LocalityLbEndpoints; @@ -243,6 +246,7 @@ public long currentTimeNanos() { private ManagedChannel channel; private ClientXdsClient xdsClient; private boolean originalEnableFaultInjection; + private boolean originalEnableRbac; @Before public void setUp() throws IOException { @@ -255,6 +259,9 @@ public void setUp() throws IOException { // Start the server and the client. originalEnableFaultInjection = ClientXdsClient.enableFaultInjection; ClientXdsClient.enableFaultInjection = true; + originalEnableRbac = ClientXdsClient.enableRbac; + assertThat(originalEnableRbac).isFalse(); + ClientXdsClient.enableRbac = true; final String serverName = InProcessServerBuilder.generateName(); cleanupRule.register( InProcessServerBuilder @@ -273,7 +280,8 @@ public void setUp() throws IOException { new Bootstrapper.ServerInfo( SERVER_URI, InsecureChannelCredentials.create(), useProtocolV3())), EnvoyProtoData.Node.newBuilder().build(), - null, + ImmutableMap.of("cert-instance-name", + new CertificateProviderInfo("file-watcher", ImmutableMap.of())), null); xdsClient = new ClientXdsClient( @@ -293,6 +301,7 @@ public void setUp() throws IOException { @After public void tearDown() { ClientXdsClient.enableFaultInjection = originalEnableFaultInjection; + ClientXdsClient.enableRbac = originalEnableRbac; xdsClient.shutdown(); channel.shutdown(); // channel not owned by XdsClient assertThat(adsEnded.get()).isTrue(); @@ -469,11 +478,11 @@ public void ldsResponseErrorHandling_someResourcesFailedUnpack() { List errors = ImmutableList.of( "LDS response Resource index 0 - can't decode Listener: ", "LDS response Resource index 2 - can't decode Listener: "); - verifyResourceMetadataNacked(LDS, LDS_RESOURCE, null, "", 0, VERSION_1, TIME_INCREMENT, errors); + verifyResourceMetadataAcked(LDS, LDS_RESOURCE, testListenerRds, VERSION_1, TIME_INCREMENT); verifySubscribedResourcesMetadataSizes(1, 0, 0, 0); // The response is NACKed with the same error message. call.verifyRequestNack(LDS, LDS_RESOURCE, "", "0000", NODE, errors); - verifyNoInteractions(ldsResourceWatcher); + verify(ldsResourceWatcher).onChanged(any(LdsUpdate.class)); } /** @@ -513,14 +522,14 @@ public void ldsResponseErrorHandling_subscribedResourceInvalid() { "A", Any.pack(mf.buildListenerWithApiListenerForRds("A", "A.2")), "B", Any.pack(mf.buildListenerWithApiListenerInvalid("B"))); call.sendResponse(LDS, resourcesV2.values().asList(), VERSION_2, "0001"); - // {A, B} -> NACK, version 1, rejected version 2, rejected reason: Failed to parse B - // {C} -> ACK, version 1 + // {A} -> ACK, version 2 + // {B} -> NACK, version 1, rejected version 2, rejected reason: Failed to parse B + // {C} -> does not exist List errorsV2 = ImmutableList.of("LDS response Listener 'B' validation error: "); - verifyResourceMetadataNacked(LDS, "A", resourcesV1.get("A"), VERSION_1, TIME_INCREMENT, - VERSION_2, TIME_INCREMENT * 2, errorsV2); + verifyResourceMetadataAcked(LDS, "A", resourcesV2.get("A"), VERSION_2, TIME_INCREMENT * 2); verifyResourceMetadataNacked(LDS, "B", resourcesV1.get("B"), VERSION_1, TIME_INCREMENT, VERSION_2, TIME_INCREMENT * 2, errorsV2); - verifyResourceMetadataAcked(LDS, "C", resourcesV1.get("C"), VERSION_1, TIME_INCREMENT); + verifyResourceMetadataDoesNotExist(LDS, "C"); call.verifyRequestNack(LDS, subscribedResourceNames, VERSION_1, "0001", NODE, errorsV2); // LDS -> {B, C} version 3 @@ -528,7 +537,7 @@ public void ldsResponseErrorHandling_subscribedResourceInvalid() { "B", Any.pack(mf.buildListenerWithApiListenerForRds("B", "B.3")), "C", Any.pack(mf.buildListenerWithApiListenerForRds("C", "C.3"))); call.sendResponse(LDS, resourcesV3.values().asList(), VERSION_3, "0002"); - // {A} -> NACK, version 1, rejected version 2, rejected reason: Failed to parse B + // {A} -> does not exist // {B, C} -> ACK, version 3 verifyResourceMetadataDoesNotExist(LDS, "A"); verifyResourceMetadataAcked(LDS, "B", resourcesV3.get("B"), VERSION_3, TIME_INCREMENT * 3); @@ -537,6 +546,73 @@ public void ldsResponseErrorHandling_subscribedResourceInvalid() { verifySubscribedResourcesMetadataSizes(3, 0, 0, 0); } + @Test + public void ldsResponseErrorHandling_subscribedResourceInvalid_withRdsSubscriptioin() { + List subscribedResourceNames = ImmutableList.of("A", "B", "C"); + xdsClient.watchLdsResource("A", ldsResourceWatcher); + xdsClient.watchRdsResource("A.1", rdsResourceWatcher); + xdsClient.watchLdsResource("B", ldsResourceWatcher); + xdsClient.watchRdsResource("B.1", rdsResourceWatcher); + xdsClient.watchLdsResource("C", ldsResourceWatcher); + xdsClient.watchRdsResource("C.1", rdsResourceWatcher); + DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); + assertThat(call).isNotNull(); + verifyResourceMetadataRequested(LDS, "A"); + verifyResourceMetadataRequested(LDS, "B"); + verifyResourceMetadataRequested(LDS, "C"); + verifyResourceMetadataRequested(RDS, "A.1"); + verifyResourceMetadataRequested(RDS, "B.1"); + verifyResourceMetadataRequested(RDS, "C.1"); + verifySubscribedResourcesMetadataSizes(3, 0, 3, 0); + + // LDS -> {A, B, C}, version 1 + ImmutableMap resourcesV1 = ImmutableMap.of( + "A", Any.pack(mf.buildListenerWithApiListenerForRds("A", "A.1")), + "B", Any.pack(mf.buildListenerWithApiListenerForRds("B", "B.1")), + "C", Any.pack(mf.buildListenerWithApiListenerForRds("C", "C.1"))); + call.sendResponse(LDS, resourcesV1.values().asList(), VERSION_1, "0000"); + // {A, B, C} -> ACK, version 1 + verifyResourceMetadataAcked(LDS, "A", resourcesV1.get("A"), VERSION_1, TIME_INCREMENT); + verifyResourceMetadataAcked(LDS, "B", resourcesV1.get("B"), VERSION_1, TIME_INCREMENT); + verifyResourceMetadataAcked(LDS, "C", resourcesV1.get("C"), VERSION_1, TIME_INCREMENT); + call.verifyRequest(LDS, subscribedResourceNames, VERSION_1, "0000", NODE); + + // RDS -> {A.1, B.1, C.1}, version 1 + List vhostsV1 = mf.buildOpaqueVirtualHosts(1); + ImmutableMap resourcesV11 = ImmutableMap.of( + "A.1", Any.pack(mf.buildRouteConfiguration("A.1", vhostsV1)), + "B.1", Any.pack(mf.buildRouteConfiguration("B.1", vhostsV1)), + "C.1", Any.pack(mf.buildRouteConfiguration("C.1", vhostsV1))); + call.sendResponse(RDS, resourcesV11.values().asList(), VERSION_1, "0000"); + // {A.1, B.1, C.1} -> ACK, version 1 + verifyResourceMetadataAcked(RDS, "A.1", resourcesV11.get("A.1"), VERSION_1, TIME_INCREMENT * 2); + verifyResourceMetadataAcked(RDS, "B.1", resourcesV11.get("B.1"), VERSION_1, TIME_INCREMENT * 2); + verifyResourceMetadataAcked(RDS, "C.1", resourcesV11.get("C.1"), VERSION_1, TIME_INCREMENT * 2); + + // LDS -> {A, B}, version 2 + // Failed to parse endpoint B + ImmutableMap resourcesV2 = ImmutableMap.of( + "A", Any.pack(mf.buildListenerWithApiListenerForRds("A", "A.2")), + "B", Any.pack(mf.buildListenerWithApiListenerInvalid("B"))); + call.sendResponse(LDS, resourcesV2.values().asList(), VERSION_2, "0001"); + // {A} -> ACK, version 2 + // {B} -> NACK, version 1, rejected version 2, rejected reason: Failed to parse B + // {C} -> does not exist + List errorsV2 = ImmutableList.of("LDS response Listener 'B' validation error: "); + verifyResourceMetadataAcked(LDS, "A", resourcesV2.get("A"), VERSION_2, TIME_INCREMENT * 3); + verifyResourceMetadataNacked( + LDS, "B", resourcesV1.get("B"), VERSION_1, TIME_INCREMENT, VERSION_2, TIME_INCREMENT * 3, + errorsV2); + verifyResourceMetadataDoesNotExist(LDS, "C"); + call.verifyRequestNack(LDS, subscribedResourceNames, VERSION_1, "0001", NODE, errorsV2); + // {A.1} -> does not exist + // {B.1} -> version 1 + // {C.1} -> does not exist + verifyResourceMetadataDoesNotExist(RDS, "A.1"); + verifyResourceMetadataAcked(RDS, "B.1", resourcesV11.get("B.1"), VERSION_1, TIME_INCREMENT * 2); + verifyResourceMetadataDoesNotExist(RDS, "C.1"); + } + @Test public void ldsResourceFound_containsVirtualHosts() { DiscoveryRpcCall call = startResourceWatcher(LDS, LDS_RESOURCE, ldsResourceWatcher); @@ -655,7 +731,8 @@ public void ldsResourceUpdate_withFaultInjection() { mf.buildHttpFaultTypedConfig( 1L, 2, "cluster1", ImmutableList.of(), 3, null, null, null), - false)))); + false), + mf.buildHttpFilter("terminal", Any.pack(Router.newBuilder().build()), true)))); call.sendResponse(LDS, listener, VERSION_1, "0000"); // Client sends an ACK LDS request. @@ -802,11 +879,11 @@ public void rdsResponseErrorHandling_someResourcesFailedUnpack() { List errors = ImmutableList.of( "RDS response Resource index 0 - can't decode RouteConfiguration: ", "RDS response Resource index 2 - can't decode RouteConfiguration: "); - verifyResourceMetadataNacked(RDS, RDS_RESOURCE, null, "", 0, VERSION_1, TIME_INCREMENT, errors); + verifyResourceMetadataAcked(RDS, RDS_RESOURCE, testRouteConfig, VERSION_1, TIME_INCREMENT); verifySubscribedResourcesMetadataSizes(0, 0, 1, 0); // The response is NACKed with the same error message. call.verifyRequestNack(RDS, RDS_RESOURCE, "", "0000", NODE, errors); - verifyNoInteractions(rdsResourceWatcher); + verify(rdsResourceWatcher).onChanged(any(RdsUpdate.class)); } /** @@ -847,12 +924,12 @@ public void rdsResponseErrorHandling_subscribedResourceInvalid() { "A", Any.pack(mf.buildRouteConfiguration("A", mf.buildOpaqueVirtualHosts(2))), "B", Any.pack(mf.buildRouteConfigurationInvalid("B"))); call.sendResponse(RDS, resourcesV2.values().asList(), VERSION_2, "0001"); - // {A, B} -> NACK, version 1, rejected version 2, rejected reason: Failed to parse B + // {A} -> ACK, version 2 + // {B} -> NACK, version 1, rejected version 2, rejected reason: Failed to parse B // {C} -> ACK, version 1 List errorsV2 = ImmutableList.of("RDS response RouteConfiguration 'B' validation error: "); - verifyResourceMetadataNacked(RDS, "A", resourcesV1.get("A"), VERSION_1, TIME_INCREMENT, - VERSION_2, TIME_INCREMENT * 2, errorsV2); + verifyResourceMetadataAcked(RDS, "A", resourcesV2.get("A"), VERSION_2, TIME_INCREMENT * 2); verifyResourceMetadataNacked(RDS, "B", resourcesV1.get("B"), VERSION_1, TIME_INCREMENT, VERSION_2, TIME_INCREMENT * 2, errorsV2); verifyResourceMetadataAcked(RDS, "C", resourcesV1.get("C"), VERSION_1, TIME_INCREMENT); @@ -864,10 +941,9 @@ public void rdsResponseErrorHandling_subscribedResourceInvalid() { "B", Any.pack(mf.buildRouteConfiguration("B", vhostsV3)), "C", Any.pack(mf.buildRouteConfiguration("C", vhostsV3))); call.sendResponse(RDS, resourcesV3.values().asList(), VERSION_3, "0002"); - // {A} -> NACK, version 1, rejected version 2, rejected reason: Failed to parse B + // {A} -> ACK, version 2 // {B, C} -> ACK, version 3 - verifyResourceMetadataNacked(RDS, "A", resourcesV1.get("A"), VERSION_1, TIME_INCREMENT, - VERSION_2, TIME_INCREMENT * 2, errorsV2); + verifyResourceMetadataAcked(RDS, "A", resourcesV2.get("A"), VERSION_2, TIME_INCREMENT * 2); verifyResourceMetadataAcked(RDS, "B", resourcesV3.get("B"), VERSION_3, TIME_INCREMENT * 3); verifyResourceMetadataAcked(RDS, "C", resourcesV3.get("C"), VERSION_3, TIME_INCREMENT * 3); call.verifyRequest(RDS, subscribedResourceNames, VERSION_3, "0002", NODE); @@ -990,7 +1066,7 @@ public void rdsResourcesDeletedByLdsTcpListener() { verifySubscribedResourcesMetadataSizes(1, 0, 1, 0); Message hcmFilter = mf.buildHttpConnectionManagerFilter( - RDS_RESOURCE, null, Collections.emptyList()); + RDS_RESOURCE, null, Collections.singletonList(mf.buildTerminalFilter())); Message downstreamTlsContext = CommonTlsContextTestsUtil.buildTestDownstreamTlsContext( "google-sds-config-default", "ROOTCA", false); Message filterChain = mf.buildFilterChain( @@ -1025,7 +1101,7 @@ public void rdsResourcesDeletedByLdsTcpListener() { null, mf.buildRouteConfiguration( "route-bar.googleapis.com", mf.buildOpaqueVirtualHosts(VHOST_SIZE)), - Collections.emptyList()); + Collections.singletonList(mf.buildTerminalFilter())); filterChain = mf.buildFilterChain( Collections.emptyList(), downstreamTlsContext, "envoy.transport_sockets.tls", hcmFilter); @@ -1141,11 +1217,12 @@ public void cdsResponseErrorHandling_someResourcesFailedUnpack() { List errors = ImmutableList.of( "CDS response Resource index 0 - can't decode Cluster: ", "CDS response Resource index 2 - can't decode Cluster: "); - verifyResourceMetadataNacked(CDS, CDS_RESOURCE, null, "", 0, VERSION_1, TIME_INCREMENT, errors); + verifyResourceMetadataAcked( + CDS, CDS_RESOURCE, testClusterRoundRobin, VERSION_1, TIME_INCREMENT); verifySubscribedResourcesMetadataSizes(0, 1, 0, 0); // The response is NACKed with the same error message. call.verifyRequestNack(CDS, CDS_RESOURCE, "", "0000", NODE, errors); - verifyNoInteractions(cdsResourceWatcher); + verify(cdsResourceWatcher).onChanged(any(CdsUpdate.class)); } /** @@ -1193,14 +1270,14 @@ public void cdsResponseErrorHandling_subscribedResourceInvalid() { )), "B", Any.pack(mf.buildClusterInvalid("B"))); call.sendResponse(CDS, resourcesV2.values().asList(), VERSION_2, "0001"); - // {A, B} -> NACK, version 1, rejected version 2, rejected reason: Failed to parse B - // {C} -> ACK, version 1 + // {A} -> ACK, version 2 + // {B} -> NACK, version 1, rejected version 2, rejected reason: Failed to parse B + // {C} -> does not exist List errorsV2 = ImmutableList.of("CDS response Cluster 'B' validation error: "); - verifyResourceMetadataNacked(CDS, "A", resourcesV1.get("A"), VERSION_1, TIME_INCREMENT, - VERSION_2, TIME_INCREMENT * 2, errorsV2); + verifyResourceMetadataAcked(CDS, "A", resourcesV2.get("A"), VERSION_2, TIME_INCREMENT * 2); verifyResourceMetadataNacked(CDS, "B", resourcesV1.get("B"), VERSION_1, TIME_INCREMENT, VERSION_2, TIME_INCREMENT * 2, errorsV2); - verifyResourceMetadataAcked(CDS, "C", resourcesV1.get("C"), VERSION_1, TIME_INCREMENT); + verifyResourceMetadataDoesNotExist(CDS, "C"); call.verifyRequestNack(CDS, subscribedResourceNames, VERSION_1, "0001", NODE, errorsV2); // CDS -> {B, C} version 3 @@ -1212,7 +1289,7 @@ public void cdsResponseErrorHandling_subscribedResourceInvalid() { "envoy.transport_sockets.tls", null ))); call.sendResponse(CDS, resourcesV3.values().asList(), VERSION_3, "0002"); - // {A} -> NACK, version 1, rejected version 2, rejected reason: Failed to parse B + // {A} -> does not exit // {B, C} -> ACK, version 3 verifyResourceMetadataDoesNotExist(CDS, "A"); verifyResourceMetadataAcked(CDS, "B", resourcesV3.get("B"), VERSION_3, TIME_INCREMENT * 3); @@ -1220,6 +1297,82 @@ public void cdsResponseErrorHandling_subscribedResourceInvalid() { call.verifyRequest(CDS, subscribedResourceNames, VERSION_3, "0002", NODE); } + @Test + public void cdsResponseErrorHandling_subscribedResourceInvalid_withEdsSubscription() { + List subscribedResourceNames = ImmutableList.of("A", "B", "C"); + xdsClient.watchCdsResource("A", cdsResourceWatcher); + xdsClient.watchEdsResource("A.1", edsResourceWatcher); + xdsClient.watchCdsResource("B", cdsResourceWatcher); + xdsClient.watchEdsResource("B.1", edsResourceWatcher); + xdsClient.watchCdsResource("C", cdsResourceWatcher); + xdsClient.watchEdsResource("C.1", edsResourceWatcher); + DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); + assertThat(call).isNotNull(); + verifyResourceMetadataRequested(CDS, "A"); + verifyResourceMetadataRequested(CDS, "B"); + verifyResourceMetadataRequested(CDS, "C"); + verifyResourceMetadataRequested(EDS, "A.1"); + verifyResourceMetadataRequested(EDS, "B.1"); + verifyResourceMetadataRequested(EDS, "C.1"); + verifySubscribedResourcesMetadataSizes(0, 3, 0, 3); + + // CDS -> {A, B, C}, version 1 + ImmutableMap resourcesV1 = ImmutableMap.of( + "A", Any.pack(mf.buildEdsCluster("A", "A.1", "round_robin", null, false, null, + "envoy.transport_sockets.tls", null + )), + "B", Any.pack(mf.buildEdsCluster("B", "B.1", "round_robin", null, false, null, + "envoy.transport_sockets.tls", null + )), + "C", Any.pack(mf.buildEdsCluster("C", "C.1", "round_robin", null, false, null, + "envoy.transport_sockets.tls", null + ))); + call.sendResponse(CDS, resourcesV1.values().asList(), VERSION_1, "0000"); + // {A, B, C} -> ACK, version 1 + verifyResourceMetadataAcked(CDS, "A", resourcesV1.get("A"), VERSION_1, TIME_INCREMENT); + verifyResourceMetadataAcked(CDS, "B", resourcesV1.get("B"), VERSION_1, TIME_INCREMENT); + verifyResourceMetadataAcked(CDS, "C", resourcesV1.get("C"), VERSION_1, TIME_INCREMENT); + call.verifyRequest(CDS, subscribedResourceNames, VERSION_1, "0000", NODE); + + // EDS -> {A.1, B.1, C.1}, version 1 + List dropOverloads = ImmutableList.of(); + List endpointsV1 = ImmutableList.of(lbEndpointHealthy); + ImmutableMap resourcesV11 = ImmutableMap.of( + "A.1", Any.pack(mf.buildClusterLoadAssignment("A.1", endpointsV1, dropOverloads)), + "B.1", Any.pack(mf.buildClusterLoadAssignment("B.1", endpointsV1, dropOverloads)), + "C.1", Any.pack(mf.buildClusterLoadAssignment("C.1", endpointsV1, dropOverloads))); + call.sendResponse(EDS, resourcesV11.values().asList(), VERSION_1, "0000"); + // {A.1, B.1, C.1} -> ACK, version 1 + verifyResourceMetadataAcked(EDS, "A.1", resourcesV11.get("A.1"), VERSION_1, TIME_INCREMENT * 2); + verifyResourceMetadataAcked(EDS, "B.1", resourcesV11.get("B.1"), VERSION_1, TIME_INCREMENT * 2); + verifyResourceMetadataAcked(EDS, "C.1", resourcesV11.get("C.1"), VERSION_1, TIME_INCREMENT * 2); + + // CDS -> {A, B}, version 2 + // Failed to parse endpoint B + ImmutableMap resourcesV2 = ImmutableMap.of( + "A", Any.pack(mf.buildEdsCluster("A", "A.2", "round_robin", null, false, null, + "envoy.transport_sockets.tls", null + )), + "B", Any.pack(mf.buildClusterInvalid("B"))); + call.sendResponse(CDS, resourcesV2.values().asList(), VERSION_2, "0001"); + // {A} -> ACK, version 2 + // {B} -> NACK, version 1, rejected version 2, rejected reason: Failed to parse B + // {C} -> does not exist + List errorsV2 = ImmutableList.of("CDS response Cluster 'B' validation error: "); + verifyResourceMetadataAcked(CDS, "A", resourcesV2.get("A"), VERSION_2, TIME_INCREMENT * 3); + verifyResourceMetadataNacked( + CDS, "B", resourcesV1.get("B"), VERSION_1, TIME_INCREMENT, VERSION_2, TIME_INCREMENT * 3, + errorsV2); + verifyResourceMetadataDoesNotExist(CDS, "C"); + call.verifyRequestNack(CDS, subscribedResourceNames, VERSION_1, "0001", NODE, errorsV2); + // {A.1} -> does not exist + // {B.1} -> version 1 + // {C.1} -> does not exist + verifyResourceMetadataDoesNotExist(EDS, "A.1"); + verifyResourceMetadataAcked(EDS, "B.1", resourcesV11.get("B.1"), VERSION_1, TIME_INCREMENT * 2); + verifyResourceMetadataDoesNotExist(EDS, "C.1"); + } + @Test public void cdsResourceFound() { DiscoveryRpcCall call = startResourceWatcher(CDS, CDS_RESOURCE, cdsResourceWatcher); @@ -1319,6 +1472,7 @@ public void cdsResponseWithCircuitBreakers() { * CDS response containing UpstreamTlsContext for a cluster. */ @Test + @SuppressWarnings("deprecation") public void cdsResponseWithUpstreamTlsContext() { Assume.assumeTrue(useProtocolV3()); DiscoveryRpcCall call = startResourceWatcher(CDS, CDS_RESOURCE, cdsResourceWatcher); @@ -1327,7 +1481,8 @@ public void cdsResponseWithUpstreamTlsContext() { Any clusterEds = Any.pack(mf.buildEdsCluster(CDS_RESOURCE, "eds-cluster-foo.googleapis.com", "round_robin", null, true, - mf.buildUpstreamTlsContext("secret1", "cert1"), "envoy.transport_sockets.tls", null)); + mf.buildUpstreamTlsContext("cert-instance-name", "cert1"), + "envoy.transport_sockets.tls", null)); List clusters = ImmutableList.of( Any.pack(mf.buildLogicalDnsCluster("cluster-bar.googleapis.com", "dns-service-bar.googleapis.com", 443, "round_robin", null, false, null, null)), @@ -1343,7 +1498,43 @@ public void cdsResponseWithUpstreamTlsContext() { CommonTlsContext.CertificateProviderInstance certificateProviderInstance = cdsUpdate.upstreamTlsContext().getCommonTlsContext().getCombinedValidationContext() .getValidationContextCertificateProviderInstance(); - assertThat(certificateProviderInstance.getInstanceName()).isEqualTo("secret1"); + assertThat(certificateProviderInstance.getInstanceName()).isEqualTo("cert-instance-name"); + assertThat(certificateProviderInstance.getCertificateName()).isEqualTo("cert1"); + verifyResourceMetadataAcked(CDS, CDS_RESOURCE, clusterEds, VERSION_1, TIME_INCREMENT); + verifySubscribedResourcesMetadataSizes(0, 1, 0, 0); + } + + /** + * CDS response containing new UpstreamTlsContext for a cluster. + */ + @Test + @SuppressWarnings("deprecation") + public void cdsResponseWithNewUpstreamTlsContext() { + Assume.assumeTrue(useProtocolV3()); + DiscoveryRpcCall call = startResourceWatcher(CDS, CDS_RESOURCE, cdsResourceWatcher); + + // Management server sends back CDS response with UpstreamTlsContext. + Any clusterEds = + Any.pack(mf.buildEdsCluster(CDS_RESOURCE, "eds-cluster-foo.googleapis.com", "round_robin", + null, true, + mf.buildNewUpstreamTlsContext("cert-instance-name", "cert1"), + "envoy.transport_sockets.tls", null)); + List clusters = ImmutableList.of( + Any.pack(mf.buildLogicalDnsCluster("cluster-bar.googleapis.com", + "dns-service-bar.googleapis.com", 443, "round_robin", null, false, null, null)), + clusterEds, + Any.pack(mf.buildEdsCluster("cluster-baz.googleapis.com", null, "round_robin", null, false, + null, "envoy.transport_sockets.tls", null))); + call.sendResponse(CDS, clusters, VERSION_1, "0000"); + + // Client sent an ACK CDS request. + call.verifyRequest(CDS, CDS_RESOURCE, VERSION_1, "0000", NODE); + verify(cdsResourceWatcher, times(1)).onChanged(cdsUpdateCaptor.capture()); + CdsUpdate cdsUpdate = cdsUpdateCaptor.getValue(); + CertificateProviderPluginInstance certificateProviderInstance = + cdsUpdate.upstreamTlsContext().getCommonTlsContext().getValidationContext() + .getCaCertificateProviderInstance(); + assertThat(certificateProviderInstance.getInstanceName()).isEqualTo("cert-instance-name"); assertThat(certificateProviderInstance.getCertificateName()).isEqualTo("cert1"); verifyResourceMetadataAcked(CDS, CDS_RESOURCE, clusterEds, VERSION_1, TIME_INCREMENT); verifySubscribedResourcesMetadataSizes(0, 1, 0, 0); @@ -1369,7 +1560,7 @@ public void cdsResponseErrorHandling_badUpstreamTlsContext() { "CDS response Cluster 'cluster.googleapis.com' validation error: " + "Cluster cluster.googleapis.com: malformed UpstreamTlsContext: " + "io.grpc.xds.ClientXdsClient$ResourceInvalidException: " - + "combined_validation_context is required in upstream-tls-context")); + + "ca_certificate_provider_instance is required in upstream-tls-context")); verifyNoInteractions(cdsResourceWatcher); } @@ -1623,11 +1814,14 @@ public void edsResponseErrorHandling_someResourcesFailedUnpack() { List errors = ImmutableList.of( "EDS response Resource index 0 - can't decode ClusterLoadAssignment: ", "EDS response Resource index 2 - can't decode ClusterLoadAssignment: "); - verifyResourceMetadataNacked(EDS, EDS_RESOURCE, null, "", 0, VERSION_1, TIME_INCREMENT, errors); + verifyResourceMetadataAcked( + EDS, EDS_RESOURCE, testClusterLoadAssignment, VERSION_1, TIME_INCREMENT); verifySubscribedResourcesMetadataSizes(0, 0, 0, 1); // The response is NACKed with the same error message. call.verifyRequestNack(EDS, EDS_RESOURCE, "", "0000", NODE, errors); - verifyNoInteractions(edsResourceWatcher); + verify(edsResourceWatcher).onChanged(edsUpdateCaptor.capture()); + EdsUpdate edsUpdate = edsUpdateCaptor.getValue(); + assertThat(edsUpdate.clusterName).isEqualTo(EDS_RESOURCE); } /** @@ -1670,12 +1864,12 @@ public void edsResponseErrorHandling_subscribedResourceInvalid() { "A", Any.pack(mf.buildClusterLoadAssignment("A", endpointsV2, dropOverloads)), "B", Any.pack(mf.buildClusterLoadAssignmentInvalid("B"))); call.sendResponse(EDS, resourcesV2.values().asList(), VERSION_2, "0001"); - // {A, B} -> NACK, version 1, rejected version 2, rejected reason: Failed to parse B + // {A} -> ACK, version 2 + // {B} -> NACK, version 1, rejected version 2, rejected reason: Failed to parse B // {C} -> ACK, version 1 List errorsV2 = ImmutableList.of("EDS response ClusterLoadAssignment 'B' validation error: "); - verifyResourceMetadataNacked(EDS, "A", resourcesV1.get("A"), VERSION_1, TIME_INCREMENT, - VERSION_2, TIME_INCREMENT * 2, errorsV2); + verifyResourceMetadataAcked(EDS, "A", resourcesV2.get("A"), VERSION_2, TIME_INCREMENT * 2); verifyResourceMetadataNacked(EDS, "B", resourcesV1.get("B"), VERSION_1, TIME_INCREMENT, VERSION_2, TIME_INCREMENT * 2, errorsV2); verifyResourceMetadataAcked(EDS, "C", resourcesV1.get("C"), VERSION_1, TIME_INCREMENT); @@ -1688,10 +1882,9 @@ public void edsResponseErrorHandling_subscribedResourceInvalid() { "B", Any.pack(mf.buildClusterLoadAssignment("B", endpointsV3, dropOverloads)), "C", Any.pack(mf.buildClusterLoadAssignment("C", endpointsV3, dropOverloads))); call.sendResponse(EDS, resourcesV3.values().asList(), VERSION_3, "0002"); - // {A} -> NACK, version 1, rejected version 2, rejected reason: Failed to parse B + // {A} -> ACK, version 2 // {B, C} -> ACK, version 3 - verifyResourceMetadataNacked(EDS, "A", resourcesV1.get("A"), VERSION_1, TIME_INCREMENT, - VERSION_2, TIME_INCREMENT * 2, errorsV2); + verifyResourceMetadataAcked(EDS, "A", resourcesV2.get("A"), VERSION_2, TIME_INCREMENT * 2); verifyResourceMetadataAcked(EDS, "B", resourcesV3.get("B"), VERSION_3, TIME_INCREMENT * 3); verifyResourceMetadataAcked(EDS, "C", resourcesV3.get("C"), VERSION_3, TIME_INCREMENT * 3); call.verifyRequest(EDS, subscribedResourceNames, VERSION_3, "0002", NODE); @@ -2162,7 +2355,8 @@ public void serverSideListenerFound() { ClientXdsClientTestBase.DiscoveryRpcCall call = startResourceWatcher(LDS, LISTENER_RESOURCE, ldsResourceWatcher); Message hcmFilter = mf.buildHttpConnectionManagerFilter( - "route-foo.googleapis.com", null, Collections.emptyList()); + "route-foo.googleapis.com", null, + Collections.singletonList(mf.buildTerminalFilter())); Message downstreamTlsContext = CommonTlsContextTestsUtil.buildTestDownstreamTlsContext( "google-sds-config-default", "ROOTCA", false); Message filterChain = mf.buildFilterChain( @@ -2185,7 +2379,8 @@ public void serverSideListenerFound() { assertThat(parsedFilterChain.getFilterChainMatch().getApplicationProtocols()).isEmpty(); assertThat(parsedFilterChain.getHttpConnectionManager().rdsName()) .isEqualTo("route-foo.googleapis.com"); - assertThat(parsedFilterChain.getHttpConnectionManager().httpFilterConfigs()).isEmpty(); + assertThat(parsedFilterChain.getHttpConnectionManager().httpFilterConfigs().get(0).filterConfig) + .isEqualTo(RouterFilter.ROUTER_CONFIG); assertThat(fakeClock.getPendingTasks(LDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)).isEmpty(); } @@ -2196,7 +2391,8 @@ public void serverSideListenerNotFound() { ClientXdsClientTestBase.DiscoveryRpcCall call = startResourceWatcher(LDS, LISTENER_RESOURCE, ldsResourceWatcher); Message hcmFilter = mf.buildHttpConnectionManagerFilter( - "route-foo.googleapis.com", null, Collections.emptyList()); + "route-foo.googleapis.com", null, + Collections.singletonList(mf.buildTerminalFilter())); Message downstreamTlsContext = CommonTlsContextTestsUtil.buildTestDownstreamTlsContext( "google-sds-config-default", "ROOTCA", false); Message filterChain = mf.buildFilterChain( @@ -2222,7 +2418,8 @@ public void serverSideListenerResponseErrorHandling_badDownstreamTlsContext() { ClientXdsClientTestBase.DiscoveryRpcCall call = startResourceWatcher(LDS, LISTENER_RESOURCE, ldsResourceWatcher); Message hcmFilter = mf.buildHttpConnectionManagerFilter( - "route-foo.googleapis.com", null, Collections.emptyList()); + "route-foo.googleapis.com", null, + Collections.singletonList(mf.buildTerminalFilter())); Message downstreamTlsContext = CommonTlsContextTestsUtil.buildTestDownstreamTlsContext( null, null,false); Message filterChain = mf.buildFilterChain( @@ -2245,7 +2442,8 @@ public void serverSideListenerResponseErrorHandling_badTransportSocketName() { ClientXdsClientTestBase.DiscoveryRpcCall call = startResourceWatcher(LDS, LISTENER_RESOURCE, ldsResourceWatcher); Message hcmFilter = mf.buildHttpConnectionManagerFilter( - "route-foo.googleapis.com", null, Collections.emptyList()); + "route-foo.googleapis.com", null, + Collections.singletonList(mf.buildTerminalFilter())); Message downstreamTlsContext = CommonTlsContextTestsUtil.buildTestDownstreamTlsContext( "cert1", "cert2",false); Message filterChain = mf.buildFilterChain( @@ -2344,7 +2542,7 @@ protected abstract static class MessageFactory { /** Throws {@link InvalidProtocolBufferException} on {@link Any#unpack(Class)}. */ protected static final Any FAILING_ANY = Any.newBuilder().setTypeUrl("fake").build(); - protected final Message buildListenerWithApiListener(String name, Message routeConfiguration) { + protected Message buildListenerWithApiListener(String name, Message routeConfiguration) { return buildListenerWithApiListener( name, routeConfiguration, Collections.emptyList()); } @@ -2396,6 +2594,8 @@ protected abstract Message buildRingHashLbConfig(String hashFunction, long minRi protected abstract Message buildUpstreamTlsContext(String instanceName, String certName); + protected abstract Message buildNewUpstreamTlsContext(String instanceName, String certName); + protected abstract Message buildCircuitBreakers(int highPriorityMaxRequests, int defaultPriorityMaxRequests); @@ -2427,5 +2627,7 @@ protected abstract Message buildListenerWithFilterChain( protected abstract Message buildHttpConnectionManagerFilter( @Nullable String rdsName, @Nullable Message routeConfig, List httpFilters); + + protected abstract Message buildTerminalFilter(); } } diff --git a/xds/src/test/java/io/grpc/xds/ClientXdsClientV2Test.java b/xds/src/test/java/io/grpc/xds/ClientXdsClientV2Test.java index 409613aecf7..1a69b6fc650 100644 --- a/xds/src/test/java/io/grpc/xds/ClientXdsClientV2Test.java +++ b/xds/src/test/java/io/grpc/xds/ClientXdsClientV2Test.java @@ -515,6 +515,12 @@ protected Message buildUpstreamTlsContext(String instanceName, String certName) .build(); } + @Override + protected Message buildNewUpstreamTlsContext(String instanceName, String certName) { + return buildUpstreamTlsContext(instanceName, certName); + } + + @Override protected Message buildCircuitBreakers(int highPriorityMaxRequests, int defaultPriorityMaxRequests) { @@ -635,6 +641,11 @@ protected Message buildHttpConnectionManagerFilter( @Nullable String rdsName, @Nullable Message routeConfig, List httpFilters) { throw new UnsupportedOperationException(); } + + @Override + protected Message buildTerminalFilter() { + throw new UnsupportedOperationException(); + } } /** diff --git a/xds/src/test/java/io/grpc/xds/ClientXdsClientV3Test.java b/xds/src/test/java/io/grpc/xds/ClientXdsClientV3Test.java index 0da6bf7bde5..6df36e1c31e 100644 --- a/xds/src/test/java/io/grpc/xds/ClientXdsClientV3Test.java +++ b/xds/src/test/java/io/grpc/xds/ClientXdsClientV3Test.java @@ -74,9 +74,12 @@ import io.envoyproxy.envoy.extensions.filters.http.fault.v3.FaultAbort; import io.envoyproxy.envoy.extensions.filters.http.fault.v3.FaultAbort.HeaderAbort; import io.envoyproxy.envoy.extensions.filters.http.fault.v3.HTTPFault; +import io.envoyproxy.envoy.extensions.filters.http.router.v3.Router; import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager; import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.HttpFilter; import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.Rds; +import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateProviderPluginInstance; +import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext; import io.envoyproxy.envoy.service.discovery.v3.AggregatedDiscoveryServiceGrpc.AggregatedDiscoveryServiceImplBase; @@ -274,6 +277,15 @@ protected Message buildListenerWithApiListener( .build(); } + @Override + protected Message buildListenerWithApiListener(String name, Message routeConfiguration) { + return buildListenerWithApiListener(name, routeConfiguration, Arrays.asList( + HttpFilter.newBuilder() + .setName("terminal") + .setTypedConfig(Any.pack(Router.newBuilder().build())).build() + )); + } + @Override protected Message buildListenerWithApiListenerForRds(String name, String rdsResourceName) { return Listener.newBuilder() @@ -289,6 +301,10 @@ protected Message buildListenerWithApiListenerForRds(String name, String rdsReso .setConfigSource( ConfigSource.newBuilder() .setAds(AggregatedConfigSource.getDefaultInstance()))) + .addHttpFilters( + HttpFilter.newBuilder() + .setName("terminal") + .setTypedConfig(Any.pack(Router.newBuilder().build()))) .build()))) .build(); } @@ -535,6 +551,7 @@ protected Message buildRingHashLbConfig(String hashFunction, long minRingSize, } @Override + @SuppressWarnings("deprecation") protected Message buildUpstreamTlsContext(String instanceName, String certName) { CommonTlsContext.Builder commonTlsContextBuilder = CommonTlsContext.newBuilder(); if (instanceName != null && certName != null) { @@ -554,6 +571,20 @@ protected Message buildUpstreamTlsContext(String instanceName, String certName) .build(); } + @Override + protected Message buildNewUpstreamTlsContext(String instanceName, String certName) { + CommonTlsContext.Builder commonTlsContextBuilder = CommonTlsContext.newBuilder(); + if (instanceName != null && certName != null) { + commonTlsContextBuilder.setValidationContext(CertificateValidationContext.newBuilder() + .setCaCertificateProviderInstance( + CertificateProviderPluginInstance.newBuilder().setInstanceName(instanceName) + .setCertificateName(certName).build())); + } + return UpstreamTlsContext.newBuilder() + .setCommonTlsContext(commonTlsContextBuilder) + .build(); + } + @Override protected Message buildCircuitBreakers(int highPriorityMaxRequests, int defaultPriorityMaxRequests) { @@ -725,6 +756,13 @@ protected Message buildHttpConnectionManagerFilter( Any.pack(hcmBuilder.build(), "type.googleapis.com")) .build(); } + + @Override + protected Message buildTerminalFilter() { + return HttpFilter.newBuilder() + .setName("terminal") + .setTypedConfig(Any.pack(Router.newBuilder().build())).build(); + } } /** diff --git a/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java index 74aa85501a9..dfcf101fcf5 100644 --- a/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java @@ -341,8 +341,8 @@ private void subtest_maxConcurrentRequests_appliedByLbConfig(boolean enableCircu PickResult result = currentPicker.pickSubchannel(mock(PickSubchannelArgs.class)); assertThat(result.getStatus().isOk()).isTrue(); ClientStreamTracer.Factory streamTracerFactory = result.getStreamTracerFactory(); - streamTracerFactory.newClientStreamTracer(ClientStreamTracer.StreamInfo.newBuilder().build(), - new Metadata()); + streamTracerFactory.newClientStreamTracer( + ClientStreamTracer.StreamInfo.newBuilder().build(), new Metadata()); } ClusterStats clusterStats = Iterables.getOnlyElement(loadStatsManager.getClusterStatsReports(CLUSTER)); @@ -429,8 +429,8 @@ private void subtest_maxConcurrentRequests_appliedWithDefaultValue( PickResult result = currentPicker.pickSubchannel(mock(PickSubchannelArgs.class)); assertThat(result.getStatus().isOk()).isTrue(); ClientStreamTracer.Factory streamTracerFactory = result.getStreamTracerFactory(); - streamTracerFactory.newClientStreamTracer(ClientStreamTracer.StreamInfo.newBuilder().build(), - new Metadata()); + streamTracerFactory.newClientStreamTracer( + ClientStreamTracer.StreamInfo.newBuilder().build(), new Metadata()); } ClusterStats clusterStats = Iterables.getOnlyElement(loadStatsManager.getClusterStatsReports(CLUSTER)); @@ -480,16 +480,16 @@ public void endpointAddressesAttachedWithClusterName() { } @Test - public void endpointAddressesAttachedWithTlsConfig_enableSecurity() { + public void endpointAddressesAttachedWithTlsConfig_disableSecurity() { boolean originalEnableSecurity = ClusterImplLoadBalancer.enableSecurity; - ClusterImplLoadBalancer.enableSecurity = true; - subtest_endpointAddressesAttachedWithTlsConfig(true); + ClusterImplLoadBalancer.enableSecurity = false; + subtest_endpointAddressesAttachedWithTlsConfig(false); ClusterImplLoadBalancer.enableSecurity = originalEnableSecurity; } @Test - public void endpointAddressesAttachedWithTlsConfig_securityDisabledByDefault() { - subtest_endpointAddressesAttachedWithTlsConfig(false); + public void endpointAddressesAttachedWithTlsConfig_securityEnabledByDefault() { + subtest_endpointAddressesAttachedWithTlsConfig(true); } private void subtest_endpointAddressesAttachedWithTlsConfig(boolean enableSecurity) { diff --git a/xds/src/test/java/io/grpc/xds/FilterChainMatchTest.java b/xds/src/test/java/io/grpc/xds/FilterChainMatchTest.java deleted file mode 100644 index 8ee3e87a242..00000000000 --- a/xds/src/test/java/io/grpc/xds/FilterChainMatchTest.java +++ /dev/null @@ -1,941 +0,0 @@ -/* - * Copyright 2021 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.xds; - -import static com.google.common.truth.Truth.assertThat; -import static org.junit.Assert.fail; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - -import io.grpc.xds.Filter.NamedFilterConfig; -import io.grpc.xds.XdsClient.LdsUpdate; -import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil; -import io.grpc.xds.internal.sds.SslContextProviderSupplier; -import io.netty.channel.Channel; -import java.io.IOException; -import java.net.InetAddress; -import java.net.InetSocketAddress; -import java.net.UnknownHostException; -import java.util.Arrays; -import java.util.Collections; -import org.junit.After; -import org.junit.Before; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; - -/** Tests for {@link XdsClientWrapperForServerSds}. */ -@RunWith(JUnit4.class) -public class FilterChainMatchTest { - - private static final int PORT = 7000; - private static final String LOCAL_IP = "10.1.2.3"; // dest - private static final String REMOTE_IP = "10.4.2.3"; // source - private static final HttpConnectionManager HTTP_CONNECTION_MANAGER = - HttpConnectionManager.forRdsName( - 10L, "route-config", Collections.emptyList()); - - @Mock private Channel channel; - @Mock private TlsContextManager tlsContextManager; - - private XdsClientWrapperForServerSds xdsClientWrapperForServerSds; - private XdsClient.LdsResourceWatcher registeredWatcher; - - @Before - public void setUp() throws IOException { - MockitoAnnotations.initMocks(this); - xdsClientWrapperForServerSds = XdsServerTestHelper - .createXdsClientWrapperForServerSds(PORT, tlsContextManager); - registeredWatcher = - XdsServerTestHelper.startAndGetWatcher(xdsClientWrapperForServerSds); - } - - @After - public void tearDown() { - xdsClientWrapperForServerSds.shutdown(); - } - - private EnvoyServerProtoData.DownstreamTlsContext getDownstreamTlsContext() { - SslContextProviderSupplier sslContextProviderSupplier = - xdsClientWrapperForServerSds.getSslContextProviderSupplier(channel); - if (sslContextProviderSupplier != null) { - EnvoyServerProtoData.BaseTlsContext tlsContext = sslContextProviderSupplier.getTlsContext(); - assertThat(tlsContext).isInstanceOf(EnvoyServerProtoData.DownstreamTlsContext.class); - return (EnvoyServerProtoData.DownstreamTlsContext) tlsContext; - } - return null; - } - - @Test - public void singleFilterChainWithoutAlpn() throws UnknownHostException { - setupChannel(LOCAL_IP, REMOTE_IP, 15000); - EnvoyServerProtoData.FilterChainMatch filterChainMatch = - new EnvoyServerProtoData.FilterChainMatch( - 0, - Arrays.asList(), - Arrays.asList(), - Arrays.asList(), - EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList(), - Arrays.asList(), - null); - EnvoyServerProtoData.DownstreamTlsContext tlsContext = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); - EnvoyServerProtoData.FilterChain filterChain = new EnvoyServerProtoData.FilterChain( - "filter-chain-foo", filterChainMatch, HTTP_CONNECTION_MANAGER, tlsContext, - tlsContextManager); - EnvoyServerProtoData.Listener listener = - new EnvoyServerProtoData.Listener("listener1", LOCAL_IP, Arrays.asList(filterChain), null); - LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(listener); - registeredWatcher.onChanged(listenerUpdate); - EnvoyServerProtoData.DownstreamTlsContext tlsContext1 = getDownstreamTlsContext(); - assertThat(tlsContext1).isSameInstanceAs(tlsContext); - } - - @Test - public void singleFilterChainWithAlpn() throws UnknownHostException { - setupChannel(LOCAL_IP, REMOTE_IP, 15000); - EnvoyServerProtoData.FilterChainMatch filterChainMatch = - new EnvoyServerProtoData.FilterChainMatch( - 0, - Arrays.asList(), - Arrays.asList("managed-mtls"), - Arrays.asList(), - EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList(), - Arrays.asList(), - null); - EnvoyServerProtoData.DownstreamTlsContext tlsContext = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); - EnvoyServerProtoData.FilterChain filterChain = new EnvoyServerProtoData.FilterChain( - "filter-chain-foo", filterChainMatch, HTTP_CONNECTION_MANAGER, tlsContext, - tlsContextManager); - EnvoyServerProtoData.DownstreamTlsContext defaultTlsContext = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); - EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( - "filter-chain-bar", null, HTTP_CONNECTION_MANAGER, defaultTlsContext, - tlsContextManager); - EnvoyServerProtoData.Listener listener = new EnvoyServerProtoData.Listener( - "listener1", LOCAL_IP, Arrays.asList(filterChain), defaultFilterChain); - LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(listener); - registeredWatcher.onChanged(listenerUpdate); - EnvoyServerProtoData.DownstreamTlsContext tlsContext1 = getDownstreamTlsContext(); - assertThat(tlsContext1).isSameInstanceAs(defaultTlsContext); - } - - @Test - public void defaultFilterChain() throws UnknownHostException { - setupChannel(LOCAL_IP, REMOTE_IP, 15000); - EnvoyServerProtoData.DownstreamTlsContext tlsContext = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); - EnvoyServerProtoData.FilterChain filterChain = new EnvoyServerProtoData.FilterChain( - "filter-chain-foo", null, HTTP_CONNECTION_MANAGER, tlsContext, tlsContextManager); - EnvoyServerProtoData.Listener listener = - new EnvoyServerProtoData.Listener( - "listener1", LOCAL_IP, Arrays.asList(), filterChain); - LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(listener); - registeredWatcher.onChanged(listenerUpdate); - EnvoyServerProtoData.DownstreamTlsContext tlsContext1 = getDownstreamTlsContext(); - assertThat(tlsContext1).isSameInstanceAs(tlsContext); - } - - @Test - public void destPortFails_returnDefaultFilterChain() throws UnknownHostException { - setupChannel(LOCAL_IP, REMOTE_IP, 15000); - EnvoyServerProtoData.DownstreamTlsContext tlsContextWithDestPort = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); - EnvoyServerProtoData.FilterChainMatch filterChainMatchWithDestPort = - new EnvoyServerProtoData.FilterChainMatch( - PORT, - Arrays.asList(), - Arrays.asList("managed-mtls"), - Arrays.asList(), - EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList(), - Arrays.asList(), - null); - EnvoyServerProtoData.FilterChain filterChainWithDestPort = - new EnvoyServerProtoData.FilterChain( - "filter-chain-foo", filterChainMatchWithDestPort, HTTP_CONNECTION_MANAGER, - tlsContextWithDestPort, tlsContextManager); - EnvoyServerProtoData.DownstreamTlsContext tlsContextForDefaultFilterChain = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); - EnvoyServerProtoData.FilterChain defaultFilterChain = - new EnvoyServerProtoData.FilterChain( - "filter-chain-bar", null, HTTP_CONNECTION_MANAGER, - tlsContextForDefaultFilterChain, tlsContextManager); - EnvoyServerProtoData.Listener listener = - new EnvoyServerProtoData.Listener( - "listener1", LOCAL_IP, Arrays.asList(filterChainWithDestPort), defaultFilterChain); - LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(listener); - registeredWatcher.onChanged(listenerUpdate); - EnvoyServerProtoData.DownstreamTlsContext tlsContext1 = getDownstreamTlsContext(); - assertThat(tlsContext1).isSameInstanceAs(tlsContextForDefaultFilterChain); - } - - @Test - public void destPrefixRangeMatch() throws UnknownHostException { - setupChannel(LOCAL_IP, REMOTE_IP, 15000); - EnvoyServerProtoData.DownstreamTlsContext tlsContextMatch = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); - EnvoyServerProtoData.FilterChainMatch filterChainMatchWithMatch = - new EnvoyServerProtoData.FilterChainMatch( - 0, - Arrays.asList(new EnvoyServerProtoData.CidrRange("10.1.2.0", 24)), - Arrays.asList(), - Arrays.asList(), - EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList(), - Arrays.asList(), - null); - EnvoyServerProtoData.FilterChain filterChainWithMatch = new EnvoyServerProtoData.FilterChain( - "filter-chain-foo", filterChainMatchWithMatch, HTTP_CONNECTION_MANAGER, - tlsContextMatch, tlsContextManager); - EnvoyServerProtoData.DownstreamTlsContext tlsContextForDefaultFilterChain = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); - EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( - "filter-chain-bar", null, HTTP_CONNECTION_MANAGER, - tlsContextForDefaultFilterChain, tlsContextManager); - EnvoyServerProtoData.Listener listener = - new EnvoyServerProtoData.Listener( - "listener1", LOCAL_IP, Arrays.asList(filterChainWithMatch), defaultFilterChain); - LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(listener); - registeredWatcher.onChanged(listenerUpdate); - EnvoyServerProtoData.DownstreamTlsContext tlsContext1 = getDownstreamTlsContext(); - assertThat(tlsContext1).isSameInstanceAs(tlsContextMatch); - } - - @Test - public void destPrefixRangeMismatch_returnDefaultFilterChain() - throws UnknownHostException { - setupChannel(LOCAL_IP, REMOTE_IP, 15000); - EnvoyServerProtoData.DownstreamTlsContext tlsContextMismatch = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); - // 10.2.2.0/24 doesn't match LOCAL_IP - EnvoyServerProtoData.FilterChainMatch filterChainMatchWithMismatch = - new EnvoyServerProtoData.FilterChainMatch( - 0, - Arrays.asList(new EnvoyServerProtoData.CidrRange("10.2.2.0", 24)), - Arrays.asList(), - Arrays.asList(), - EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList(), - Arrays.asList(), - null); - EnvoyServerProtoData.FilterChain filterChainWithMismatch = - new EnvoyServerProtoData.FilterChain( - "filter-chain-foo", filterChainMatchWithMismatch, HTTP_CONNECTION_MANAGER, - tlsContextMismatch, tlsContextManager); - EnvoyServerProtoData.DownstreamTlsContext tlsContextForDefaultFilterChain = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); - EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( - "filter-chain-bar", null, HTTP_CONNECTION_MANAGER, - tlsContextForDefaultFilterChain, tlsContextManager); - EnvoyServerProtoData.Listener listener = - new EnvoyServerProtoData.Listener( - "listener1", LOCAL_IP, Arrays.asList(filterChainWithMismatch), defaultFilterChain); - LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(listener); - registeredWatcher.onChanged(listenerUpdate); - EnvoyServerProtoData.DownstreamTlsContext tlsContext1 = getDownstreamTlsContext(); - assertThat(tlsContext1).isSameInstanceAs(tlsContextForDefaultFilterChain); - } - - @Test - public void dest0LengthPrefixRange() - throws UnknownHostException { - setupChannel(LOCAL_IP, REMOTE_IP, 15000); - EnvoyServerProtoData.DownstreamTlsContext tlsContext0Length = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); - // 10.2.2.0/24 doesn't match LOCAL_IP - EnvoyServerProtoData.FilterChainMatch filterChainMatch0Length = - new EnvoyServerProtoData.FilterChainMatch( - 0, - Arrays.asList(new EnvoyServerProtoData.CidrRange("10.2.2.0", 0)), - Arrays.asList(), - Arrays.asList(), - EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList(), - Arrays.asList(), - null); - EnvoyServerProtoData.FilterChain filterChain0Length = new EnvoyServerProtoData.FilterChain( - "filter-chain-foo", filterChainMatch0Length, HTTP_CONNECTION_MANAGER, - tlsContext0Length, tlsContextManager); - EnvoyServerProtoData.DownstreamTlsContext tlsContextForDefaultFilterChain = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); - EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( - "filter-chain-bar", null, HTTP_CONNECTION_MANAGER, - tlsContextForDefaultFilterChain, tlsContextManager); - EnvoyServerProtoData.Listener listener = - new EnvoyServerProtoData.Listener( - "listener1", LOCAL_IP, Arrays.asList(filterChain0Length), defaultFilterChain); - LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(listener); - registeredWatcher.onChanged(listenerUpdate); - EnvoyServerProtoData.DownstreamTlsContext tlsContext1 = getDownstreamTlsContext(); - assertThat(tlsContext1).isSameInstanceAs(tlsContext0Length); - } - - @Test - public void destPrefixRange_moreSpecificWins() - throws UnknownHostException { - setupChannel(LOCAL_IP, REMOTE_IP, 15000); - EnvoyServerProtoData.DownstreamTlsContext tlsContextLessSpecific = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); - EnvoyServerProtoData.FilterChainMatch filterChainMatchLessSpecific = - new EnvoyServerProtoData.FilterChainMatch( - 0, - Arrays.asList(new EnvoyServerProtoData.CidrRange("10.1.2.0", 24)), - Arrays.asList(), - Arrays.asList(), - EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList(), - Arrays.asList(), - null); - EnvoyServerProtoData.FilterChain filterChainLessSpecific = - new EnvoyServerProtoData.FilterChain( - "filter-chain-foo", filterChainMatchLessSpecific, HTTP_CONNECTION_MANAGER, - tlsContextLessSpecific, tlsContextManager); - - EnvoyServerProtoData.DownstreamTlsContext tlsContextMoreSpecific = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); - EnvoyServerProtoData.FilterChainMatch filterChainMatchMoreSpecific = - new EnvoyServerProtoData.FilterChainMatch( - 0, - Arrays.asList(new EnvoyServerProtoData.CidrRange("10.1.2.2", 31)), - Arrays.asList(), - Arrays.asList(), - EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList(), - Arrays.asList(), - null); - EnvoyServerProtoData.FilterChain filterChainMoreSpecific = - new EnvoyServerProtoData.FilterChain( - "filter-chain-bar", filterChainMatchMoreSpecific, HTTP_CONNECTION_MANAGER, - tlsContextMoreSpecific, - tlsContextManager); - EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( - "filter-chain-baz", null, HTTP_CONNECTION_MANAGER, null, tlsContextManager); - EnvoyServerProtoData.Listener listener = - new EnvoyServerProtoData.Listener( - "listener1", - LOCAL_IP, - Arrays.asList(filterChainLessSpecific, filterChainMoreSpecific), - defaultFilterChain); - LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(listener); - registeredWatcher.onChanged(listenerUpdate); - EnvoyServerProtoData.DownstreamTlsContext tlsContext1 = getDownstreamTlsContext(); - assertThat(tlsContext1).isSameInstanceAs(tlsContextMoreSpecific); - } - - @Test - public void destPrefixRange_emptyListLessSpecific() - throws UnknownHostException { - setupChannel(LOCAL_IP, REMOTE_IP, 15000); - EnvoyServerProtoData.DownstreamTlsContext tlsContextLessSpecific = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); - EnvoyServerProtoData.FilterChainMatch filterChainMatchLessSpecific = - new EnvoyServerProtoData.FilterChainMatch( - 0, - Arrays.asList(), - Arrays.asList(), - Arrays.asList(), - EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList(), - Arrays.asList(), - null); - EnvoyServerProtoData.FilterChain filterChainLessSpecific = - new EnvoyServerProtoData.FilterChain( - "filter-chain-foo", filterChainMatchLessSpecific, HTTP_CONNECTION_MANAGER, - tlsContextLessSpecific, tlsContextManager); - - EnvoyServerProtoData.DownstreamTlsContext tlsContextMoreSpecific = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); - EnvoyServerProtoData.FilterChainMatch filterChainMatchMoreSpecific = - new EnvoyServerProtoData.FilterChainMatch( - 0, - Arrays.asList(new EnvoyServerProtoData.CidrRange("8.0.0.0", 5)), - Arrays.asList(), - Arrays.asList(), - EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList(), - Arrays.asList(), - null); - EnvoyServerProtoData.FilterChain filterChainMoreSpecific = - new EnvoyServerProtoData.FilterChain( - "filter-chain-bar", filterChainMatchMoreSpecific, HTTP_CONNECTION_MANAGER, - tlsContextMoreSpecific, - tlsContextManager); - EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( - "filter-chain-baz", null, HTTP_CONNECTION_MANAGER, null, tlsContextManager); - EnvoyServerProtoData.Listener listener = - new EnvoyServerProtoData.Listener( - "listener1", - LOCAL_IP, - Arrays.asList(filterChainLessSpecific, filterChainMoreSpecific), - defaultFilterChain); - LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(listener); - registeredWatcher.onChanged(listenerUpdate); - EnvoyServerProtoData.DownstreamTlsContext tlsContext1 = getDownstreamTlsContext(); - assertThat(tlsContext1).isSameInstanceAs(tlsContextMoreSpecific); - } - - @Test - public void destPrefixRangeIpv6_moreSpecificWins() - throws UnknownHostException { - setupChannel("FE80:0000:0000:0000:0202:B3FF:FE1E:8329", "2001:DB8::8:800:200C:417A", 15000); - EnvoyServerProtoData.DownstreamTlsContext tlsContextLessSpecific = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); - EnvoyServerProtoData.FilterChainMatch filterChainMatchLessSpecific = - new EnvoyServerProtoData.FilterChainMatch( - 0, - Arrays.asList(new EnvoyServerProtoData.CidrRange("FE80:0:0:0:0:0:0:0", 60)), - Arrays.asList(), - Arrays.asList(), - EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList(), - Arrays.asList(), - null); - EnvoyServerProtoData.FilterChain filterChainLessSpecific = - new EnvoyServerProtoData.FilterChain( - "filter-chain-foo", filterChainMatchLessSpecific, HTTP_CONNECTION_MANAGER, - tlsContextLessSpecific, tlsContextManager); - - EnvoyServerProtoData.DownstreamTlsContext tlsContextMoreSpecific = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); - EnvoyServerProtoData.FilterChainMatch filterChainMatchMoreSpecific = - new EnvoyServerProtoData.FilterChainMatch( - 0, - Arrays.asList(new EnvoyServerProtoData.CidrRange("FE80:0000:0000:0000:0202:0:0:0", 80)), - Arrays.asList(), - Arrays.asList(), - EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList(), - Arrays.asList(), - null); - EnvoyServerProtoData.FilterChain filterChainMoreSpecific = - new EnvoyServerProtoData.FilterChain( - "filter-chain-bar", filterChainMatchMoreSpecific, HTTP_CONNECTION_MANAGER, - tlsContextMoreSpecific, tlsContextManager); - EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( - "filter-chain-baz", null, HTTP_CONNECTION_MANAGER, null, tlsContextManager); - EnvoyServerProtoData.Listener listener = - new EnvoyServerProtoData.Listener( - "listener1", - "FE80:0000:0000:0000:0202:B3FF:FE1E:8329", - Arrays.asList(filterChainLessSpecific, filterChainMoreSpecific), - defaultFilterChain); - LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(listener); - registeredWatcher.onChanged(listenerUpdate); - EnvoyServerProtoData.DownstreamTlsContext tlsContext1 = getDownstreamTlsContext(); - assertThat(tlsContext1).isSameInstanceAs(tlsContextMoreSpecific); - } - - @Test - public void destPrefixRange_moreSpecificWith2Wins() - throws UnknownHostException { - setupChannel(LOCAL_IP, REMOTE_IP, 15000); - EnvoyServerProtoData.DownstreamTlsContext tlsContextMoreSpecificWith2 = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); - EnvoyServerProtoData.FilterChainMatch filterChainMatchMoreSpecificWith2 = - new EnvoyServerProtoData.FilterChainMatch( - 0, - Arrays.asList( - new EnvoyServerProtoData.CidrRange("10.1.2.0", 24), - new EnvoyServerProtoData.CidrRange(LOCAL_IP, 32)), - Arrays.asList(), - Arrays.asList(), - EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList(), - Arrays.asList(), - null); - EnvoyServerProtoData.FilterChain filterChainMoreSpecificWith2 = - new EnvoyServerProtoData.FilterChain( - "filter-chain-foo", filterChainMatchMoreSpecificWith2, HTTP_CONNECTION_MANAGER, - tlsContextMoreSpecificWith2, tlsContextManager); - - EnvoyServerProtoData.DownstreamTlsContext tlsContextLessSpecific = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); - EnvoyServerProtoData.FilterChainMatch filterChainMatchLessSpecific = - new EnvoyServerProtoData.FilterChainMatch( - 0, - Arrays.asList(new EnvoyServerProtoData.CidrRange("10.1.2.2", 31)), - Arrays.asList(), - Arrays.asList(), - EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList(), - Arrays.asList(), - null); - EnvoyServerProtoData.FilterChain filterChainLessSpecific = - new EnvoyServerProtoData.FilterChain( - "filter-chain-bar", filterChainMatchLessSpecific, HTTP_CONNECTION_MANAGER, - tlsContextLessSpecific, tlsContextManager); - EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( - "filter-chain-baz", null, HTTP_CONNECTION_MANAGER, null, tlsContextManager); - EnvoyServerProtoData.Listener listener = - new EnvoyServerProtoData.Listener( - "listener1", - LOCAL_IP, - Arrays.asList(filterChainMoreSpecificWith2, filterChainLessSpecific), - defaultFilterChain); - LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(listener); - registeredWatcher.onChanged(listenerUpdate); - EnvoyServerProtoData.DownstreamTlsContext tlsContext1 = getDownstreamTlsContext(); - assertThat(tlsContext1).isSameInstanceAs(tlsContextMoreSpecificWith2); - } - - @Test - public void sourceTypeMismatch_returnDefaultFilterChain() throws UnknownHostException { - setupChannel(LOCAL_IP, REMOTE_IP, 15000); - EnvoyServerProtoData.DownstreamTlsContext tlsContextMismatch = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); - EnvoyServerProtoData.FilterChainMatch filterChainMatchWithMismatch = - new EnvoyServerProtoData.FilterChainMatch( - 0, - Arrays.asList(), - Arrays.asList(), - Arrays.asList(), - EnvoyServerProtoData.ConnectionSourceType.SAME_IP_OR_LOOPBACK, - Arrays.asList(), - Arrays.asList(), - null); - EnvoyServerProtoData.FilterChain filterChainWithMismatch = - new EnvoyServerProtoData.FilterChain( - "filter-chain-foo", filterChainMatchWithMismatch, HTTP_CONNECTION_MANAGER, - tlsContextMismatch, tlsContextManager); - EnvoyServerProtoData.DownstreamTlsContext tlsContextForDefaultFilterChain = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); - EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( - "filter-chain-bar", null, HTTP_CONNECTION_MANAGER,tlsContextForDefaultFilterChain, - tlsContextManager); - EnvoyServerProtoData.Listener listener = - new EnvoyServerProtoData.Listener( - "listener1", LOCAL_IP, Arrays.asList(filterChainWithMismatch), defaultFilterChain); - LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(listener); - registeredWatcher.onChanged(listenerUpdate); - EnvoyServerProtoData.DownstreamTlsContext tlsContext1 = getDownstreamTlsContext(); - assertThat(tlsContext1).isSameInstanceAs(tlsContextForDefaultFilterChain); - } - - @Test - public void sourceTypeLocal() throws UnknownHostException { - setupChannel(LOCAL_IP, LOCAL_IP, 15000); - EnvoyServerProtoData.DownstreamTlsContext tlsContextMatch = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); - EnvoyServerProtoData.FilterChainMatch filterChainMatchWithMatch = - new EnvoyServerProtoData.FilterChainMatch( - 0, - Arrays.asList(), - Arrays.asList(), - Arrays.asList(), - EnvoyServerProtoData.ConnectionSourceType.SAME_IP_OR_LOOPBACK, - Arrays.asList(), - Arrays.asList(), - null); - EnvoyServerProtoData.FilterChain filterChainWithMatch = new EnvoyServerProtoData.FilterChain( - "filter-chain-foo", filterChainMatchWithMatch, HTTP_CONNECTION_MANAGER, tlsContextMatch, - tlsContextManager); - EnvoyServerProtoData.DownstreamTlsContext tlsContextForDefaultFilterChain = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); - EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( - "filter-chain-bar", null, HTTP_CONNECTION_MANAGER, tlsContextForDefaultFilterChain, - tlsContextManager); - EnvoyServerProtoData.Listener listener = - new EnvoyServerProtoData.Listener( - "listener1", LOCAL_IP, Arrays.asList(filterChainWithMatch), defaultFilterChain); - LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(listener); - registeredWatcher.onChanged(listenerUpdate); - EnvoyServerProtoData.DownstreamTlsContext tlsContext1 = getDownstreamTlsContext(); - assertThat(tlsContext1).isSameInstanceAs(tlsContextMatch); - } - - @Test - public void sourcePrefixRange_moreSpecificWith2Wins() - throws UnknownHostException { - setupChannel(LOCAL_IP, REMOTE_IP, 15000); - EnvoyServerProtoData.DownstreamTlsContext tlsContextMoreSpecificWith2 = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); - EnvoyServerProtoData.FilterChainMatch filterChainMatchMoreSpecificWith2 = - new EnvoyServerProtoData.FilterChainMatch( - 0, - Arrays.asList(), - Arrays.asList(), - Arrays.asList( - new EnvoyServerProtoData.CidrRange("10.4.2.0", 24), - new EnvoyServerProtoData.CidrRange(REMOTE_IP, 32)), - EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList(), - Arrays.asList(), - null); - EnvoyServerProtoData.FilterChain filterChainMoreSpecificWith2 = - new EnvoyServerProtoData.FilterChain( - "filter-chain-foo", filterChainMatchMoreSpecificWith2, HTTP_CONNECTION_MANAGER, - tlsContextMoreSpecificWith2, tlsContextManager); - - EnvoyServerProtoData.DownstreamTlsContext tlsContextLessSpecific = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); - EnvoyServerProtoData.FilterChainMatch filterChainMatchLessSpecific = - new EnvoyServerProtoData.FilterChainMatch( - 0, - Arrays.asList(), - Arrays.asList(), - Arrays.asList(new EnvoyServerProtoData.CidrRange("10.4.2.2", 31)), - EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList(), - Arrays.asList(), - null); - EnvoyServerProtoData.FilterChain filterChainLessSpecific = - new EnvoyServerProtoData.FilterChain( - "filter-chain-bar", filterChainMatchLessSpecific, HTTP_CONNECTION_MANAGER, - tlsContextLessSpecific, tlsContextManager); - EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( - "filter-chain-baz", null, HTTP_CONNECTION_MANAGER, null, tlsContextManager); - EnvoyServerProtoData.Listener listener = - new EnvoyServerProtoData.Listener( - "listener1", - LOCAL_IP, - Arrays.asList(filterChainMoreSpecificWith2, filterChainLessSpecific), - defaultFilterChain); - LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(listener); - registeredWatcher.onChanged(listenerUpdate); - EnvoyServerProtoData.DownstreamTlsContext tlsContext1 = getDownstreamTlsContext(); - assertThat(tlsContext1).isSameInstanceAs(tlsContextMoreSpecificWith2); - } - - @Test - public void sourcePrefixRange_2Matchers_expectException() - throws UnknownHostException { - setupChannel(LOCAL_IP, REMOTE_IP, 15000); - EnvoyServerProtoData.DownstreamTlsContext tlsContext1 = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); - EnvoyServerProtoData.FilterChainMatch filterChainMatch1 = - new EnvoyServerProtoData.FilterChainMatch( - 0, - Arrays.asList(), - Arrays.asList(), - Arrays.asList( - new EnvoyServerProtoData.CidrRange("10.4.2.0", 24), - new EnvoyServerProtoData.CidrRange("192.168.10.2", 32)), - EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList(), - Arrays.asList(), - null); - EnvoyServerProtoData.FilterChain filterChain1 = new EnvoyServerProtoData.FilterChain( - "filter-chain-foo", filterChainMatch1, HTTP_CONNECTION_MANAGER, tlsContext1, - tlsContextManager); - - EnvoyServerProtoData.DownstreamTlsContext tlsContext2 = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); - EnvoyServerProtoData.FilterChainMatch filterChainMatch2 = - new EnvoyServerProtoData.FilterChainMatch( - 0, - Arrays.asList(), - Arrays.asList(), - Arrays.asList(new EnvoyServerProtoData.CidrRange("10.4.2.0", 24)), - EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList(), - Arrays.asList(), - null); - EnvoyServerProtoData.FilterChain filterChain2 = new EnvoyServerProtoData.FilterChain( - "filter-chain-bar", filterChainMatch2, HTTP_CONNECTION_MANAGER, tlsContext2, - tlsContextManager); - EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( - "filter-chain-baz", null, HTTP_CONNECTION_MANAGER, null, null); - EnvoyServerProtoData.Listener listener = - new EnvoyServerProtoData.Listener( - "listener1", LOCAL_IP, Arrays.asList(filterChain1, filterChain2), defaultFilterChain); - LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(listener); - registeredWatcher.onChanged(listenerUpdate); - try { - xdsClientWrapperForServerSds.getSslContextProviderSupplier(channel); - fail("expect exception!"); - } catch (IllegalStateException ise) { - assertThat(ise).hasMessageThat().isEqualTo("Found 2 matching filter-chains"); - } - } - - @Test - public void sourcePortMatch_exactMatchWinsOverEmptyList() - throws UnknownHostException { - setupChannel(LOCAL_IP, REMOTE_IP, 15000); - EnvoyServerProtoData.DownstreamTlsContext tlsContextEmptySourcePorts = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); - EnvoyServerProtoData.FilterChainMatch filterChainMatchEmptySourcePorts = - new EnvoyServerProtoData.FilterChainMatch( - 0, - Arrays.asList(), - Arrays.asList(), - Arrays.asList( - new EnvoyServerProtoData.CidrRange("10.4.2.0", 24), - new EnvoyServerProtoData.CidrRange("10.4.2.2", 31)), - EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList(), - Arrays.asList(), - null); - EnvoyServerProtoData.FilterChain filterChainEmptySourcePorts = - new EnvoyServerProtoData.FilterChain( - "filter-chain-foo", filterChainMatchEmptySourcePorts, HTTP_CONNECTION_MANAGER, - tlsContextEmptySourcePorts, tlsContextManager); - - EnvoyServerProtoData.DownstreamTlsContext tlsContextSourcePortMatch = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); - EnvoyServerProtoData.FilterChainMatch filterChainMatchSourcePortMatch = - new EnvoyServerProtoData.FilterChainMatch( - 0, - Arrays.asList(), - Arrays.asList(), - Arrays.asList(new EnvoyServerProtoData.CidrRange("10.4.2.2", 31)), - EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList(7000, 15000), - Arrays.asList(), - null); - EnvoyServerProtoData.FilterChain filterChainSourcePortMatch = - new EnvoyServerProtoData.FilterChain( - "filter-chain-bar", filterChainMatchSourcePortMatch, HTTP_CONNECTION_MANAGER, - tlsContextSourcePortMatch, tlsContextManager); - EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( - "filter-chain-baz", null, HTTP_CONNECTION_MANAGER, null, tlsContextManager); - EnvoyServerProtoData.Listener listener = - new EnvoyServerProtoData.Listener( - "listener1", - LOCAL_IP, - Arrays.asList(filterChainEmptySourcePorts, filterChainSourcePortMatch), - defaultFilterChain); - LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(listener); - registeredWatcher.onChanged(listenerUpdate); - EnvoyServerProtoData.DownstreamTlsContext tlsContext1 = getDownstreamTlsContext(); - assertThat(tlsContext1).isSameInstanceAs(tlsContextSourcePortMatch); - } - - /** - * Create 6 filterChains: - 1st filter chain has dest port & specific prefix range but is - * eliminated due to dest port - 5 advance to next step: 1 is eliminated due to being less - * specific than the remaining 4. - 4 advance to 3rd step: source type external eliminates one - * with local source_type. - 3 advance to 4th step: more specific 2 get picked based on - * source-prefix range. - 5th step: out of 2 one with matching source port gets picked - */ - @Test - public void filterChain_5stepMatch() throws UnknownHostException { - setupChannel(LOCAL_IP, REMOTE_IP, 15000); - EnvoyServerProtoData.DownstreamTlsContext tlsContext1 = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); - EnvoyServerProtoData.DownstreamTlsContext tlsContext2 = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); - EnvoyServerProtoData.DownstreamTlsContext tlsContext3 = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT3", "VA3"); - EnvoyServerProtoData.DownstreamTlsContext tlsContext4 = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT4", "VA4"); - EnvoyServerProtoData.DownstreamTlsContext tlsContext5 = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT5", "VA5"); - EnvoyServerProtoData.DownstreamTlsContext tlsContext6 = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT6", "VA6"); - - // has dest port and specific prefix ranges: gets eliminated in step 1 - EnvoyServerProtoData.FilterChainMatch filterChainMatch1 = - new EnvoyServerProtoData.FilterChainMatch( - PORT, - Arrays.asList(), - Arrays.asList(), - Arrays.asList(new EnvoyServerProtoData.CidrRange(REMOTE_IP, 32)), - EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList(), - Arrays.asList(), - null); - EnvoyServerProtoData.FilterChain filterChain1 = new EnvoyServerProtoData.FilterChain( - "filter-chain-1", filterChainMatch1, HTTP_CONNECTION_MANAGER, tlsContext1, - tlsContextManager); - - // next 5 use prefix range: 4 with prefixLen of 30 and last one with 29 - - // has single prefix range: and less specific source prefix range: gets eliminated in step 4 - EnvoyServerProtoData.FilterChainMatch filterChainMatch2 = - new EnvoyServerProtoData.FilterChainMatch( - 0, - Arrays.asList(new EnvoyServerProtoData.CidrRange("10.1.2.0", 30)), - Arrays.asList(), - Arrays.asList(new EnvoyServerProtoData.CidrRange("10.4.0.0", 16)), - EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList(), - Arrays.asList(), - null); - EnvoyServerProtoData.FilterChain filterChain2 = new EnvoyServerProtoData.FilterChain( - "filter-chain-2", filterChainMatch2, HTTP_CONNECTION_MANAGER, tlsContext2, - tlsContextManager); - - // has prefix ranges with one not matching and source type local: gets eliminated in step 3 - EnvoyServerProtoData.FilterChainMatch filterChainMatch3 = - new EnvoyServerProtoData.FilterChainMatch( - 0, - Arrays.asList( - new EnvoyServerProtoData.CidrRange("192.168.2.0", 24), - new EnvoyServerProtoData.CidrRange("10.1.2.0", 30)), - Arrays.asList(), - Arrays.asList(), - EnvoyServerProtoData.ConnectionSourceType.SAME_IP_OR_LOOPBACK, - Arrays.asList(), - Arrays.asList(), - null); - EnvoyServerProtoData.FilterChain filterChain3 = new EnvoyServerProtoData.FilterChain( - "filter-chain-3", filterChainMatch3, HTTP_CONNECTION_MANAGER, tlsContext3, - tlsContextManager); - - // has prefix ranges with both matching and source type external but non matching source port: - // gets eliminated in step 5 - EnvoyServerProtoData.FilterChainMatch filterChainMatch4 = - new EnvoyServerProtoData.FilterChainMatch( - 0, - Arrays.asList( - new EnvoyServerProtoData.CidrRange("10.1.0.0", 16), - new EnvoyServerProtoData.CidrRange("10.1.2.0", 30)), - Arrays.asList(), - Arrays.asList(new EnvoyServerProtoData.CidrRange("10.4.2.0", 24)), - EnvoyServerProtoData.ConnectionSourceType.EXTERNAL, - Arrays.asList(16000, 9000), - Arrays.asList(), - null); - EnvoyServerProtoData.FilterChain filterChain4 = - new EnvoyServerProtoData.FilterChain( - "filter-chain-4", filterChainMatch4, HTTP_CONNECTION_MANAGER, tlsContext4, - tlsContextManager); - - // has prefix ranges with both matching and source type external and matching source port: this - // gets selected - EnvoyServerProtoData.FilterChainMatch filterChainMatch5 = - new EnvoyServerProtoData.FilterChainMatch( - 0, - Arrays.asList( - new EnvoyServerProtoData.CidrRange("10.1.0.0", 16), - new EnvoyServerProtoData.CidrRange("10.1.2.0", 30)), - Arrays.asList(), - Arrays.asList( - new EnvoyServerProtoData.CidrRange("10.4.2.0", 24), - new EnvoyServerProtoData.CidrRange("192.168.2.0", 24)), - EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList(15000, 8000), - Arrays.asList(), - null); - EnvoyServerProtoData.FilterChain filterChain5 = - new EnvoyServerProtoData.FilterChain( - "filter-chain-5", filterChainMatch5, HTTP_CONNECTION_MANAGER, tlsContext5, - tlsContextManager); - - // has prefix range with prefixLen of 29: gets eliminated in step 2 - EnvoyServerProtoData.FilterChainMatch filterChainMatch6 = - new EnvoyServerProtoData.FilterChainMatch( - 0, - Arrays.asList(new EnvoyServerProtoData.CidrRange("10.1.2.0", 29)), - Arrays.asList(), - Arrays.asList(), - EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList(), - Arrays.asList(), - null); - EnvoyServerProtoData.FilterChain filterChain6 = - new EnvoyServerProtoData.FilterChain( - "filter-chain-6", filterChainMatch6, HTTP_CONNECTION_MANAGER, tlsContext6, - tlsContextManager); - - EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( - "filter-chain-7", null, HTTP_CONNECTION_MANAGER, null, tlsContextManager); - EnvoyServerProtoData.Listener listener = - new EnvoyServerProtoData.Listener( - "listener1", - LOCAL_IP, - Arrays.asList( - filterChain1, filterChain2, filterChain3, filterChain4, filterChain5, filterChain6), - defaultFilterChain); - LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(listener); - registeredWatcher.onChanged(listenerUpdate); - EnvoyServerProtoData.DownstreamTlsContext tlsContextPicked = getDownstreamTlsContext(); - assertThat(tlsContextPicked).isSameInstanceAs(tlsContext5); - } - - @Test - public void filterChainMatch_unsupportedMatchers() - throws UnknownHostException { - setupChannel(LOCAL_IP, REMOTE_IP, 15000); - EnvoyServerProtoData.DownstreamTlsContext tlsContext1 = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "ROOTCA"); - EnvoyServerProtoData.DownstreamTlsContext tlsContext2 = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "ROOTCA"); - EnvoyServerProtoData.DownstreamTlsContext tlsContext3 = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT3", "ROOTCA"); - - EnvoyServerProtoData.FilterChainMatch filterChainMatch1 = - new EnvoyServerProtoData.FilterChainMatch( - 0 /* destinationPort */, - Collections.singletonList( - new EnvoyServerProtoData.CidrRange("10.1.0.0", 16)) /* prefixRange */, - Arrays.asList("managed-mtls", "h2") /* applicationProtocol */, - Collections.emptyList() /* sourcePrefixRanges */, - EnvoyServerProtoData.ConnectionSourceType.ANY /* sourceType */, - Collections.emptyList() /* sourcePorts */, - Arrays.asList("server1", "server2") /* serverNames */, - "tls" /* transportProtocol */); - - EnvoyServerProtoData.FilterChainMatch filterChainMatch2 = - new EnvoyServerProtoData.FilterChainMatch( - 0 /* destinationPort */, - Collections.singletonList( - new EnvoyServerProtoData.CidrRange("10.0.0.0", 8)) /* prefixRange */, - Collections.emptyList() /* applicationProtocol */, - Collections.emptyList() /* sourcePrefixRanges */, - EnvoyServerProtoData.ConnectionSourceType.ANY /* sourceType */, - Collections.emptyList() /* sourcePorts */, - Collections.emptyList() /* serverNames */, - "" /* transportProtocol */); - - EnvoyServerProtoData.FilterChainMatch defaultFilterChainMatch = - new EnvoyServerProtoData.FilterChainMatch( - 0 /* destinationPort */, - Collections.emptyList() /* prefixRange */, - Collections.emptyList() /* applicationProtocol */, - Collections.emptyList() /* sourcePrefixRanges */, - EnvoyServerProtoData.ConnectionSourceType.ANY /* sourceType */, - Collections.emptyList() /* sourcePorts */, - Collections.emptyList() /* serverNames */, - "" /* transportProtocol */); - - EnvoyServerProtoData.FilterChain filterChain1 = new EnvoyServerProtoData.FilterChain( - "filter-chain-foo", filterChainMatch1, HTTP_CONNECTION_MANAGER, tlsContext1, - mock(TlsContextManager.class)); - EnvoyServerProtoData.FilterChain filterChain2 = new EnvoyServerProtoData.FilterChain( - "filter-chain-bar", filterChainMatch2, HTTP_CONNECTION_MANAGER, tlsContext2, - mock(TlsContextManager.class)); - - EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( - "filter-chain-baz", defaultFilterChainMatch, HTTP_CONNECTION_MANAGER, tlsContext3, - mock(TlsContextManager.class)); - - EnvoyServerProtoData.Listener listener = new EnvoyServerProtoData.Listener( - "", "10.2.1.34:8000", Arrays.asList(filterChain1, filterChain2), defaultFilterChain); - LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(listener); - registeredWatcher.onChanged(listenerUpdate); - EnvoyServerProtoData.DownstreamTlsContext tlsContextPicked = getDownstreamTlsContext(); - // assert defaultFilterChain match - assertThat(tlsContextPicked.getCommonTlsContext().getTlsCertificateCertificateProviderInstance() - .getCertificateName()).isEqualTo("CERT3"); - } - - private void setupChannel(String localIp, String remoteIp, int remotePort) - throws UnknownHostException { - when(channel.localAddress()) - .thenReturn(new InetSocketAddress(InetAddress.getByName(localIp), PORT)); - when(channel.remoteAddress()) - .thenReturn(new InetSocketAddress(InetAddress.getByName(remoteIp), remotePort)); - } -} diff --git a/xds/src/test/java/io/grpc/xds/FilterChainMatchingProtocolNegotiatorsTest.java b/xds/src/test/java/io/grpc/xds/FilterChainMatchingProtocolNegotiatorsTest.java new file mode 100644 index 00000000000..b223516465f --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/FilterChainMatchingProtocolNegotiatorsTest.java @@ -0,0 +1,1233 @@ +/* + * Copyright 2021 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.truth.Truth.assertThat; +import static io.grpc.xds.XdsServerWrapper.ATTR_SERVER_ROUTING_CONFIG; +import static io.grpc.xds.internal.sds.SdsProtocolNegotiators.ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER; +import static org.junit.Assert.fail; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.util.concurrent.SettableFuture; +import io.grpc.ServerInterceptor; +import io.grpc.internal.TestUtils.NoopChannelLogger; +import io.grpc.netty.GrpcHttp2ConnectionHandler; +import io.grpc.netty.InternalProtocolNegotiationEvent; +import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator; +import io.grpc.netty.ProtocolNegotiationEvent; +import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; +import io.grpc.xds.EnvoyServerProtoData.FilterChain; +import io.grpc.xds.Filter.FilterConfig; +import io.grpc.xds.Filter.NamedFilterConfig; +import io.grpc.xds.FilterChainMatchingProtocolNegotiators.FilterChainMatchingHandler; +import io.grpc.xds.FilterChainMatchingProtocolNegotiators.FilterChainMatchingHandler.FilterChainSelector; +import io.grpc.xds.VirtualHost.Route; +import io.grpc.xds.XdsServerWrapper.ServerRoutingConfig; +import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil; +import io.grpc.xds.internal.sds.SslContextProviderSupplier; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.ChannelPromise; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.http2.DefaultHttp2Connection; +import io.netty.handler.codec.http2.DefaultHttp2ConnectionDecoder; +import io.netty.handler.codec.http2.DefaultHttp2ConnectionEncoder; +import io.netty.handler.codec.http2.DefaultHttp2FrameReader; +import io.netty.handler.codec.http2.DefaultHttp2FrameWriter; +import io.netty.handler.codec.http2.Http2ConnectionDecoder; +import io.netty.handler.codec.http2.Http2ConnectionEncoder; +import io.netty.handler.codec.http2.Http2Settings; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.net.UnknownHostException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; + +import org.junit.After; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +@RunWith(JUnit4.class) +public class FilterChainMatchingProtocolNegotiatorsTest { + @Rule + public final MockitoRule mocks = MockitoJUnit.rule(); + + private final GrpcHttp2ConnectionHandler grpcHandler = + FakeGrpcHttp2ConnectionHandler.newHandler(); + @Mock private TlsContextManager tlsContextManager; + private ProtocolNegotiationEvent event = InternalProtocolNegotiationEvent.getDefault(); + private ChannelPipeline pipeline; + private EmbeddedChannel channel; + private ChannelHandlerContext channelHandlerCtx; + @Mock + private ProtocolNegotiator mockDelegate; + private FilterChainSelectorManager selectorManager = new FilterChainSelectorManager(); + private static final HttpConnectionManager HTTP_CONNECTION_MANAGER = createRds("routing-config"); + private static final String LOCAL_IP = "10.1.2.3"; // dest + private static final String REMOTE_IP = "10.4.2.3"; // source + private static final int PORT = 7000; + private final AtomicReference noopConfig = new AtomicReference<>( + ServerRoutingConfig.create(ImmutableList.of(), + ImmutableMap.of())); + final SettableFuture sslSet = SettableFuture.create(); + final SettableFuture> routingSettable = + SettableFuture.create(); + + @After + @SuppressWarnings("FutureReturnValueIgnored") + public void tearDown() { + if (channel.isActive()) { + channel.close(); + channel.runPendingTasks(); + } + assertThat(selectorManager.getRegisterCount()).isEqualTo(0); + } + + @Test + public void nofilterChainMatch_defaultSslContext() throws Exception { + ChannelHandler next = captureAttrHandler(sslSet, routingSettable); + when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); + + SslContextProviderSupplier defaultSsl = new SslContextProviderSupplier(createTls(), + tlsContextManager); + selectorManager.updateSelector(new FilterChainSelector( + new HashMap>(), + defaultSsl, noopConfig)); + FilterChainMatchingHandler filterChainMatchingHandler = + new FilterChainMatchingHandler(grpcHandler, selectorManager, mockDelegate); + setupChannel("172.168.1.1", "172.168.1.2", 80, filterChainMatchingHandler); + ChannelHandlerContext channelHandlerCtx = pipeline.context(filterChainMatchingHandler); + assertThat(channelHandlerCtx).isNotNull(); + + pipeline.fireUserEventTriggered(event); + channelHandlerCtx = pipeline.context(filterChainMatchingHandler); + assertThat(channelHandlerCtx).isNull(); + + channel.runPendingTasks(); + assertThat(sslSet.isDone()).isTrue(); + assertThat(sslSet.get()).isEqualTo(defaultSsl); + assertThat(routingSettable.get()).isEqualTo(noopConfig); + channelHandlerCtx = pipeline.context(next); + assertThat(channelHandlerCtx).isNotNull(); + } + + @Test + public void noFilterChainMatch_noDefaultSslContext() { + selectorManager.updateSelector(new FilterChainSelector( + new HashMap>(), + null, new AtomicReference())); + FilterChainMatchingHandler filterChainMatchingHandler = + new FilterChainMatchingHandler(grpcHandler, selectorManager, mockDelegate); + setupChannel("172.168.1.1", "172.168.2.2", 90, filterChainMatchingHandler); + channelHandlerCtx = pipeline.context(filterChainMatchingHandler); + assertThat(channelHandlerCtx).isNotNull(); + + assertThat(channel.closeFuture().isDone()).isFalse(); + pipeline.fireUserEventTriggered(event); + channel.runPendingTasks(); + assertThat(channel.closeFuture().isDone()).isTrue(); + } + + @Test + public void filterSelectorChange_drainsConnection() { + ChannelHandler next = new ChannelInboundHandlerAdapter(); + when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); + selectorManager.updateSelector(new FilterChainSelector( + new HashMap>(), null, noopConfig)); + FilterChainMatchingHandler filterChainMatchingHandler = + new FilterChainMatchingHandler(grpcHandler, selectorManager, mockDelegate); + setupChannel("172.168.1.1", "172.168.2.2", 90, filterChainMatchingHandler); + channelHandlerCtx = pipeline.context(filterChainMatchingHandler); + assertThat(channelHandlerCtx).isNotNull(); + + pipeline.fireUserEventTriggered(event); + channelHandlerCtx = pipeline.context(filterChainMatchingHandler); + assertThat(channelHandlerCtx).isNull(); + + channel.runPendingTasks(); + channelHandlerCtx = pipeline.context(next); + assertThat(channelHandlerCtx).isNotNull(); + assertThat(channel.readOutbound()).isNull(); + + selectorManager.updateSelector(new FilterChainSelector( + new HashMap>(), null, noopConfig)); + assertThat(channel.readOutbound().getClass().getName()) + .isEqualTo("io.grpc.netty.GracefulServerCloseCommand"); + } + + @Test + public void singleFilterChainWithoutAlpn() throws Exception { + EnvoyServerProtoData.FilterChainMatch filterChainMatch = + new EnvoyServerProtoData.FilterChainMatch( + 0, + Arrays.asList(), + Arrays.asList(), + Arrays.asList(), + EnvoyServerProtoData.ConnectionSourceType.ANY, + Arrays.asList(), + Arrays.asList(), + null); + EnvoyServerProtoData.DownstreamTlsContext tlsContext = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); + EnvoyServerProtoData.FilterChain filterChain = new EnvoyServerProtoData.FilterChain( + "filter-chain-foo", filterChainMatch, HTTP_CONNECTION_MANAGER, tlsContext, + tlsContextManager); + + selectorManager.updateSelector(new FilterChainSelector(ImmutableMap.of(filterChain, noopConfig), + null, new AtomicReference())); + FilterChainMatchingHandler filterChainMatchingHandler = + new FilterChainMatchingHandler(grpcHandler, selectorManager, mockDelegate); + ChannelHandler next = captureAttrHandler(sslSet, routingSettable); + when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); + setupChannel(LOCAL_IP, REMOTE_IP, 15000, filterChainMatchingHandler); + pipeline.fireUserEventTriggered(event); + channel.runPendingTasks(); + assertThat(sslSet.isDone()).isTrue(); + assertThat(sslSet.get()).isEqualTo(filterChain.getSslContextProviderSupplier()); + assertThat(routingSettable.get()).isEqualTo(noopConfig); + assertThat(sslSet.get().getTlsContext()).isSameInstanceAs(tlsContext); + } + + @Test + public void singleFilterChainWithAlpn() throws Exception { + EnvoyServerProtoData.FilterChainMatch filterChainMatch = + new EnvoyServerProtoData.FilterChainMatch( + 0, + Arrays.asList(), + Arrays.asList("managed-mtls"), + Arrays.asList(), + EnvoyServerProtoData.ConnectionSourceType.ANY, + Arrays.asList(), + Arrays.asList(), + null); + EnvoyServerProtoData.DownstreamTlsContext tlsContext = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); + EnvoyServerProtoData.FilterChain filterChain = new EnvoyServerProtoData.FilterChain( + "filter-chain-foo", filterChainMatch, HTTP_CONNECTION_MANAGER, tlsContext, + tlsContextManager); + EnvoyServerProtoData.DownstreamTlsContext defaultTlsContext = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); + EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( + "filter-chain-bar", null, HTTP_CONNECTION_MANAGER, defaultTlsContext, + tlsContextManager); + selectorManager.updateSelector(new FilterChainSelector( + ImmutableMap.of(filterChain, randomConfig("no-match")), + defaultFilterChain.getSslContextProviderSupplier(), noopConfig)); + FilterChainMatchingHandler filterChainMatchingHandler = + new FilterChainMatchingHandler(grpcHandler, selectorManager, mockDelegate); + + ChannelHandler next = captureAttrHandler(sslSet, routingSettable); + when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); + setupChannel(LOCAL_IP, REMOTE_IP, 15000, filterChainMatchingHandler); + pipeline.fireUserEventTriggered(event); + channel.runPendingTasks(); + assertThat(sslSet.get()).isEqualTo(defaultFilterChain.getSslContextProviderSupplier()); + assertThat(routingSettable.get()).isEqualTo(noopConfig); + assertThat(sslSet.get().getTlsContext()).isSameInstanceAs(defaultTlsContext); + } + + @Test + public void destPortFails_returnDefaultFilterChain() throws Exception { + EnvoyServerProtoData.DownstreamTlsContext tlsContextWithDestPort = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); + EnvoyServerProtoData.FilterChainMatch filterChainMatchWithDestPort = + new EnvoyServerProtoData.FilterChainMatch( + PORT, + Arrays.asList(), + Arrays.asList("managed-mtls"), + Arrays.asList(), + EnvoyServerProtoData.ConnectionSourceType.ANY, + Arrays.asList(), + Arrays.asList(), + null); + EnvoyServerProtoData.FilterChain filterChainWithDestPort = + new EnvoyServerProtoData.FilterChain( + "filter-chain-foo", filterChainMatchWithDestPort, HTTP_CONNECTION_MANAGER, + tlsContextWithDestPort, tlsContextManager); + EnvoyServerProtoData.DownstreamTlsContext tlsContextForDefaultFilterChain = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); + EnvoyServerProtoData.FilterChain defaultFilterChain = + new EnvoyServerProtoData.FilterChain( + "filter-chain-bar", null, HTTP_CONNECTION_MANAGER, + tlsContextForDefaultFilterChain, tlsContextManager); + + ServerRoutingConfig routingConfig = ServerRoutingConfig.create( + ImmutableList.of(createVirtualHost("virtual")), + ImmutableMap.of()); + selectorManager.updateSelector(new FilterChainSelector( + ImmutableMap.of(filterChainWithDestPort, + new AtomicReference(routingConfig)), + defaultFilterChain.getSslContextProviderSupplier(), noopConfig)); + + FilterChainMatchingHandler filterChainMatchingHandler = + new FilterChainMatchingHandler(grpcHandler, selectorManager, mockDelegate); + + ChannelHandler next = captureAttrHandler(sslSet, routingSettable); + when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); + setupChannel(LOCAL_IP, REMOTE_IP, 15000, filterChainMatchingHandler); + pipeline.fireUserEventTriggered(event); + channel.runPendingTasks(); + assertThat(sslSet.get()).isEqualTo(defaultFilterChain.getSslContextProviderSupplier()); + assertThat(routingSettable.get()).isEqualTo(noopConfig); + assertThat(sslSet.get().getTlsContext()) + .isSameInstanceAs(tlsContextForDefaultFilterChain); + } + + @Test + public void destPrefixRangeMatch() throws Exception { + EnvoyServerProtoData.DownstreamTlsContext tlsContextMatch = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); + EnvoyServerProtoData.FilterChainMatch filterChainMatchWithMatch = + new EnvoyServerProtoData.FilterChainMatch( + 0, + Arrays.asList(new EnvoyServerProtoData.CidrRange("10.1.2.0", 24)), + Arrays.asList(), + Arrays.asList(), + EnvoyServerProtoData.ConnectionSourceType.ANY, + Arrays.asList(), + Arrays.asList(), + null); + EnvoyServerProtoData.FilterChain filterChainWithMatch = new EnvoyServerProtoData.FilterChain( + "filter-chain-foo", filterChainMatchWithMatch, HTTP_CONNECTION_MANAGER, + tlsContextMatch, tlsContextManager); + EnvoyServerProtoData.DownstreamTlsContext tlsContextForDefaultFilterChain = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); + EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( + "filter-chain-bar", null, HTTP_CONNECTION_MANAGER, + tlsContextForDefaultFilterChain, tlsContextManager); + + selectorManager.updateSelector(new FilterChainSelector( + ImmutableMap.of(filterChainWithMatch, noopConfig), + defaultFilterChain.getSslContextProviderSupplier(), randomConfig("no-match"))); + + FilterChainMatchingHandler filterChainMatchingHandler = + new FilterChainMatchingHandler(grpcHandler, selectorManager, mockDelegate); + + ChannelHandler next = captureAttrHandler(sslSet, routingSettable); + when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); + setupChannel(LOCAL_IP, REMOTE_IP, 15000, filterChainMatchingHandler); + pipeline.fireUserEventTriggered(event); + channel.runPendingTasks(); + assertThat(sslSet.get()).isEqualTo(filterChainWithMatch.getSslContextProviderSupplier()); + assertThat(routingSettable.get()).isEqualTo(noopConfig); + assertThat(sslSet.get().getTlsContext()).isSameInstanceAs(tlsContextMatch); + } + + @Test + public void destPrefixRangeMismatch_returnDefaultFilterChain() + throws Exception { + EnvoyServerProtoData.DownstreamTlsContext tlsContextMismatch = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); + // 10.2.2.0/24 doesn't match LOCAL_IP + EnvoyServerProtoData.FilterChainMatch filterChainMatchWithMismatch = + new EnvoyServerProtoData.FilterChainMatch( + 0, + Arrays.asList(new EnvoyServerProtoData.CidrRange("10.2.2.0", 24)), + Arrays.asList(), + Arrays.asList(), + EnvoyServerProtoData.ConnectionSourceType.ANY, + Arrays.asList(), + Arrays.asList(), + null); + EnvoyServerProtoData.FilterChain filterChainWithMismatch = + new EnvoyServerProtoData.FilterChain( + "filter-chain-foo", filterChainMatchWithMismatch, HTTP_CONNECTION_MANAGER, + tlsContextMismatch, tlsContextManager); + EnvoyServerProtoData.DownstreamTlsContext tlsContextForDefaultFilterChain = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); + EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( + "filter-chain-bar", null, HTTP_CONNECTION_MANAGER, + tlsContextForDefaultFilterChain, tlsContextManager); + selectorManager.updateSelector(new FilterChainSelector( + ImmutableMap.of(filterChainWithMismatch, randomConfig("no-match")), + defaultFilterChain.getSslContextProviderSupplier(), noopConfig)); + + FilterChainMatchingHandler filterChainMatchingHandler = + new FilterChainMatchingHandler(grpcHandler, selectorManager, mockDelegate); + + ChannelHandler next = captureAttrHandler(sslSet, routingSettable); + when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); + setupChannel(LOCAL_IP, REMOTE_IP, 15000, filterChainMatchingHandler); + pipeline.fireUserEventTriggered(event); + channel.runPendingTasks(); + assertThat(sslSet.isDone()).isTrue(); + assertThat(sslSet.get()).isEqualTo(defaultFilterChain.getSslContextProviderSupplier()); + assertThat(routingSettable.get()).isEqualTo(noopConfig); + assertThat(sslSet.get().getTlsContext()).isSameInstanceAs(tlsContextForDefaultFilterChain); + } + + @Test + public void dest0LengthPrefixRange() + throws Exception { + EnvoyServerProtoData.DownstreamTlsContext tlsContext0Length = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); + // 10.2.2.0/24 doesn't match LOCAL_IP + EnvoyServerProtoData.FilterChainMatch filterChainMatch0Length = + new EnvoyServerProtoData.FilterChainMatch( + 0, + Arrays.asList(new EnvoyServerProtoData.CidrRange("10.2.2.0", 0)), + Arrays.asList(), + Arrays.asList(), + EnvoyServerProtoData.ConnectionSourceType.ANY, + Arrays.asList(), + Arrays.asList(), + null); + EnvoyServerProtoData.FilterChain filterChain0Length = new EnvoyServerProtoData.FilterChain( + "filter-chain-foo", filterChainMatch0Length, HTTP_CONNECTION_MANAGER, + tlsContext0Length, tlsContextManager); + EnvoyServerProtoData.DownstreamTlsContext tlsContextForDefaultFilterChain = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); + EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( + "filter-chain-bar", null, HTTP_CONNECTION_MANAGER, + tlsContextForDefaultFilterChain, tlsContextManager); + + selectorManager.updateSelector(new FilterChainSelector( + ImmutableMap.of(filterChain0Length, noopConfig), + defaultFilterChain.getSslContextProviderSupplier(), + new AtomicReference())); + FilterChainMatchingHandler filterChainMatchingHandler = + new FilterChainMatchingHandler(grpcHandler, selectorManager, mockDelegate); + + ChannelHandler next = captureAttrHandler(sslSet, routingSettable); + when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); + setupChannel(LOCAL_IP, REMOTE_IP, 15000, filterChainMatchingHandler); + pipeline.fireUserEventTriggered(event); + channel.runPendingTasks(); + assertThat(sslSet.get()).isEqualTo(filterChain0Length.getSslContextProviderSupplier()); + assertThat(routingSettable.get()).isEqualTo(noopConfig); + assertThat(sslSet.get().getTlsContext()).isSameInstanceAs(tlsContext0Length); + } + + @Test + public void destPrefixRange_moreSpecificWins() + throws Exception { + EnvoyServerProtoData.DownstreamTlsContext tlsContextLessSpecific = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); + EnvoyServerProtoData.FilterChainMatch filterChainMatchLessSpecific = + new EnvoyServerProtoData.FilterChainMatch( + 0, + Arrays.asList(new EnvoyServerProtoData.CidrRange("10.1.2.0", 24)), + Arrays.asList(), + Arrays.asList(), + EnvoyServerProtoData.ConnectionSourceType.ANY, + Arrays.asList(), + Arrays.asList(), + null); + EnvoyServerProtoData.FilterChain filterChainLessSpecific = + new EnvoyServerProtoData.FilterChain( + "filter-chain-foo", filterChainMatchLessSpecific, HTTP_CONNECTION_MANAGER, + tlsContextLessSpecific, tlsContextManager); + + EnvoyServerProtoData.DownstreamTlsContext tlsContextMoreSpecific = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); + EnvoyServerProtoData.FilterChainMatch filterChainMatchMoreSpecific = + new EnvoyServerProtoData.FilterChainMatch( + 0, + Arrays.asList(new EnvoyServerProtoData.CidrRange("10.1.2.2", 31)), + Arrays.asList(), + Arrays.asList(), + EnvoyServerProtoData.ConnectionSourceType.ANY, + Arrays.asList(), + Arrays.asList(), + null); + EnvoyServerProtoData.FilterChain filterChainMoreSpecific = + new EnvoyServerProtoData.FilterChain( + "filter-chain-bar", filterChainMatchMoreSpecific, HTTP_CONNECTION_MANAGER, + tlsContextMoreSpecific, + tlsContextManager); + EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( + "filter-chain-baz", null, HTTP_CONNECTION_MANAGER, null, tlsContextManager); + selectorManager.updateSelector(new FilterChainSelector( + ImmutableMap.of(filterChainLessSpecific, randomConfig("no-match"), + filterChainMoreSpecific, noopConfig), + defaultFilterChain.getSslContextProviderSupplier(), randomConfig("default"))); + + FilterChainMatchingHandler filterChainMatchingHandler = + new FilterChainMatchingHandler(grpcHandler, selectorManager, mockDelegate); + + ChannelHandler next = captureAttrHandler(sslSet, routingSettable); + when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); + setupChannel(LOCAL_IP, REMOTE_IP, 15000, filterChainMatchingHandler); + pipeline.fireUserEventTriggered(event); + channel.runPendingTasks(); + assertThat(sslSet.get()).isEqualTo(filterChainMoreSpecific.getSslContextProviderSupplier()); + assertThat(routingSettable.get()).isEqualTo(noopConfig); + assertThat(sslSet.get().getTlsContext()).isSameInstanceAs(tlsContextMoreSpecific); + } + + @Test + public void destPrefixRange_emptyListLessSpecific() + throws Exception { + EnvoyServerProtoData.DownstreamTlsContext tlsContextLessSpecific = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); + EnvoyServerProtoData.FilterChainMatch filterChainMatchLessSpecific = + new EnvoyServerProtoData.FilterChainMatch( + 0, + Arrays.asList(), + Arrays.asList(), + Arrays.asList(), + EnvoyServerProtoData.ConnectionSourceType.ANY, + Arrays.asList(), + Arrays.asList(), + null); + EnvoyServerProtoData.FilterChain filterChainLessSpecific = + new EnvoyServerProtoData.FilterChain( + "filter-chain-foo", filterChainMatchLessSpecific, HTTP_CONNECTION_MANAGER, + tlsContextLessSpecific, tlsContextManager); + + EnvoyServerProtoData.DownstreamTlsContext tlsContextMoreSpecific = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); + EnvoyServerProtoData.FilterChainMatch filterChainMatchMoreSpecific = + new EnvoyServerProtoData.FilterChainMatch( + 0, + Arrays.asList(new EnvoyServerProtoData.CidrRange("8.0.0.0", 5)), + Arrays.asList(), + Arrays.asList(), + EnvoyServerProtoData.ConnectionSourceType.ANY, + Arrays.asList(), + Arrays.asList(), + null); + EnvoyServerProtoData.FilterChain filterChainMoreSpecific = + new EnvoyServerProtoData.FilterChain( + "filter-chain-bar", filterChainMatchMoreSpecific, HTTP_CONNECTION_MANAGER, + tlsContextMoreSpecific, + tlsContextManager); + EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( + "filter-chain-baz", null, HTTP_CONNECTION_MANAGER, null, tlsContextManager); + selectorManager.updateSelector(new FilterChainSelector( + ImmutableMap.of(filterChainLessSpecific, randomConfig("no-match"), + filterChainMoreSpecific, noopConfig), + defaultFilterChain.getSslContextProviderSupplier(), randomConfig("default"))); + FilterChainMatchingHandler filterChainMatchingHandler = + new FilterChainMatchingHandler(grpcHandler, selectorManager, mockDelegate); + + ChannelHandler next = captureAttrHandler(sslSet, routingSettable); + when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); + setupChannel(LOCAL_IP, REMOTE_IP, 15000, filterChainMatchingHandler); + pipeline.fireUserEventTriggered(event); + channel.runPendingTasks(); + assertThat(sslSet.get()).isEqualTo(filterChainMoreSpecific.getSslContextProviderSupplier()); + assertThat(routingSettable.get()).isEqualTo(noopConfig); + assertThat(sslSet.get().getTlsContext()).isSameInstanceAs(tlsContextMoreSpecific); + } + + @Test + public void destPrefixRangeIpv6_moreSpecificWins() + throws Exception { + EnvoyServerProtoData.DownstreamTlsContext tlsContextLessSpecific = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); + EnvoyServerProtoData.FilterChainMatch filterChainMatchLessSpecific = + new EnvoyServerProtoData.FilterChainMatch( + 0, + Arrays.asList(new EnvoyServerProtoData.CidrRange("FE80:0:0:0:0:0:0:0", 60)), + Arrays.asList(), + Arrays.asList(), + EnvoyServerProtoData.ConnectionSourceType.ANY, + Arrays.asList(), + Arrays.asList(), + null); + EnvoyServerProtoData.FilterChain filterChainLessSpecific = + new EnvoyServerProtoData.FilterChain( + "filter-chain-foo", filterChainMatchLessSpecific, HTTP_CONNECTION_MANAGER, + tlsContextLessSpecific, tlsContextManager); + + EnvoyServerProtoData.DownstreamTlsContext tlsContextMoreSpecific = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); + EnvoyServerProtoData.FilterChainMatch filterChainMatchMoreSpecific = + new EnvoyServerProtoData.FilterChainMatch( + 0, + Arrays.asList(new EnvoyServerProtoData.CidrRange("FE80:0000:0000:0000:0202:0:0:0", 80)), + Arrays.asList(), + Arrays.asList(), + EnvoyServerProtoData.ConnectionSourceType.ANY, + Arrays.asList(), + Arrays.asList(), + null); + EnvoyServerProtoData.FilterChain filterChainMoreSpecific = + new EnvoyServerProtoData.FilterChain( + "filter-chain-bar", filterChainMatchMoreSpecific, HTTP_CONNECTION_MANAGER, + tlsContextMoreSpecific, tlsContextManager); + EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( + "filter-chain-baz", null, HTTP_CONNECTION_MANAGER, null, tlsContextManager); + selectorManager.updateSelector(new FilterChainSelector( + ImmutableMap.of(filterChainLessSpecific, randomConfig("no-match"), + filterChainMoreSpecific, noopConfig), + defaultFilterChain.getSslContextProviderSupplier(), randomConfig("default"))); + + FilterChainMatchingHandler filterChainMatchingHandler = + new FilterChainMatchingHandler(grpcHandler, selectorManager, mockDelegate); + + ChannelHandler next = captureAttrHandler(sslSet, routingSettable); + when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); + + setupChannel("FE80:0000:0000:0000:0202:B3FF:FE1E:8329", "2001:DB8::8:800:200C:417A", + 15000, filterChainMatchingHandler); + pipeline.fireUserEventTriggered(event); + channel.runPendingTasks(); + assertThat(sslSet.get()).isEqualTo(filterChainMoreSpecific.getSslContextProviderSupplier()); + assertThat(routingSettable.get()).isEqualTo(noopConfig); + assertThat(sslSet.get().getTlsContext()).isSameInstanceAs(tlsContextMoreSpecific); + } + + @Test + public void destPrefixRange_moreSpecificWith2Wins() + throws Exception { + EnvoyServerProtoData.DownstreamTlsContext tlsContextMoreSpecificWith2 = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); + EnvoyServerProtoData.FilterChainMatch filterChainMatchMoreSpecificWith2 = + new EnvoyServerProtoData.FilterChainMatch( + 0, + Arrays.asList( + new EnvoyServerProtoData.CidrRange("10.1.2.0", 24), + new EnvoyServerProtoData.CidrRange(LOCAL_IP, 32)), + Arrays.asList(), + Arrays.asList(), + EnvoyServerProtoData.ConnectionSourceType.ANY, + Arrays.asList(), + Arrays.asList(), + null); + EnvoyServerProtoData.FilterChain filterChainMoreSpecificWith2 = + new EnvoyServerProtoData.FilterChain( + "filter-chain-foo", filterChainMatchMoreSpecificWith2, HTTP_CONNECTION_MANAGER, + tlsContextMoreSpecificWith2, tlsContextManager); + + EnvoyServerProtoData.DownstreamTlsContext tlsContextLessSpecific = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); + EnvoyServerProtoData.FilterChainMatch filterChainMatchLessSpecific = + new EnvoyServerProtoData.FilterChainMatch( + 0, + Arrays.asList(new EnvoyServerProtoData.CidrRange("10.1.2.2", 31)), + Arrays.asList(), + Arrays.asList(), + EnvoyServerProtoData.ConnectionSourceType.ANY, + Arrays.asList(), + Arrays.asList(), + null); + EnvoyServerProtoData.FilterChain filterChainLessSpecific = + new EnvoyServerProtoData.FilterChain( + "filter-chain-bar", filterChainMatchLessSpecific, HTTP_CONNECTION_MANAGER, + tlsContextLessSpecific, tlsContextManager); + EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( + "filter-chain-baz", null, HTTP_CONNECTION_MANAGER, null, tlsContextManager); + + selectorManager.updateSelector(new FilterChainSelector( + ImmutableMap.of(filterChainMoreSpecificWith2, noopConfig, + filterChainLessSpecific, randomConfig("no-match")), + defaultFilterChain.getSslContextProviderSupplier(), randomConfig("default"))); + FilterChainMatchingHandler filterChainMatchingHandler = + new FilterChainMatchingHandler(grpcHandler, selectorManager, mockDelegate); + + ChannelHandler next = captureAttrHandler(sslSet, routingSettable); + when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); + setupChannel(LOCAL_IP, REMOTE_IP, 15000, filterChainMatchingHandler); + pipeline.fireUserEventTriggered(event); + channel.runPendingTasks(); + assertThat(sslSet.get()).isEqualTo( + filterChainMoreSpecificWith2.getSslContextProviderSupplier()); + assertThat(routingSettable.get()).isEqualTo(noopConfig); + assertThat(sslSet.get().getTlsContext()).isSameInstanceAs(tlsContextMoreSpecificWith2); + } + + @Test + public void sourceTypeMismatch_returnDefaultFilterChain() throws Exception { + EnvoyServerProtoData.DownstreamTlsContext tlsContextMismatch = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); + EnvoyServerProtoData.FilterChainMatch filterChainMatchWithMismatch = + new EnvoyServerProtoData.FilterChainMatch( + 0, + Arrays.asList(), + Arrays.asList(), + Arrays.asList(), + EnvoyServerProtoData.ConnectionSourceType.SAME_IP_OR_LOOPBACK, + Arrays.asList(), + Arrays.asList(), + null); + EnvoyServerProtoData.FilterChain filterChainWithMismatch = + new EnvoyServerProtoData.FilterChain( + "filter-chain-foo", filterChainMatchWithMismatch, HTTP_CONNECTION_MANAGER, + tlsContextMismatch, tlsContextManager); + EnvoyServerProtoData.DownstreamTlsContext tlsContextForDefaultFilterChain = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); + EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( + "filter-chain-bar", null, HTTP_CONNECTION_MANAGER,tlsContextForDefaultFilterChain, + tlsContextManager); + selectorManager.updateSelector(new FilterChainSelector( + ImmutableMap.of(filterChainWithMismatch, randomConfig("no-match")), + defaultFilterChain.getSslContextProviderSupplier(), noopConfig)); + FilterChainMatchingHandler filterChainMatchingHandler = + new FilterChainMatchingHandler(grpcHandler, selectorManager, mockDelegate); + + + ChannelHandler next = captureAttrHandler(sslSet, routingSettable); + when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); + setupChannel(LOCAL_IP, REMOTE_IP, 15000, filterChainMatchingHandler); + pipeline.fireUserEventTriggered(event); + channel.runPendingTasks(); + assertThat(sslSet.get()).isEqualTo(defaultFilterChain.getSslContextProviderSupplier()); + assertThat(routingSettable.get()).isEqualTo(noopConfig); + assertThat(sslSet.get().getTlsContext()).isSameInstanceAs(tlsContextForDefaultFilterChain); + } + + @Test + public void sourceTypeLocal() throws Exception { + ChannelHandler next = captureAttrHandler(sslSet, routingSettable); + when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); + EnvoyServerProtoData.DownstreamTlsContext tlsContextMatch = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); + EnvoyServerProtoData.FilterChainMatch filterChainMatchWithMatch = + new EnvoyServerProtoData.FilterChainMatch( + 0, + Arrays.asList(), + Arrays.asList(), + Arrays.asList(), + EnvoyServerProtoData.ConnectionSourceType.SAME_IP_OR_LOOPBACK, + Arrays.asList(), + Arrays.asList(), + null); + EnvoyServerProtoData.FilterChain filterChainWithMatch = new EnvoyServerProtoData.FilterChain( + "filter-chain-foo", filterChainMatchWithMatch, HTTP_CONNECTION_MANAGER, tlsContextMatch, + tlsContextManager); + EnvoyServerProtoData.DownstreamTlsContext tlsContextForDefaultFilterChain = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); + EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( + "filter-chain-bar", null, HTTP_CONNECTION_MANAGER, tlsContextForDefaultFilterChain, + tlsContextManager); + + selectorManager.updateSelector(new FilterChainSelector( + ImmutableMap.of(filterChainWithMatch, noopConfig), + defaultFilterChain.getSslContextProviderSupplier(), randomConfig("default"))); + FilterChainMatchingHandler filterChainMatchingHandler = + new FilterChainMatchingHandler(grpcHandler, selectorManager, mockDelegate); + setupChannel(LOCAL_IP, LOCAL_IP, 15000, filterChainMatchingHandler); + pipeline.fireUserEventTriggered(event); + channel.runPendingTasks(); + assertThat(sslSet.get()).isEqualTo(filterChainWithMatch.getSslContextProviderSupplier()); + assertThat(routingSettable.get()).isEqualTo(noopConfig); + assertThat(sslSet.get().getTlsContext()).isSameInstanceAs(tlsContextMatch); + } + + @Test + public void sourcePrefixRange_moreSpecificWith2Wins() + throws Exception { + ChannelHandler next = captureAttrHandler(sslSet, routingSettable); + when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); + + EnvoyServerProtoData.DownstreamTlsContext tlsContextMoreSpecificWith2 = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); + EnvoyServerProtoData.FilterChainMatch filterChainMatchMoreSpecificWith2 = + new EnvoyServerProtoData.FilterChainMatch( + 0, + Arrays.asList(), + Arrays.asList(), + Arrays.asList( + new EnvoyServerProtoData.CidrRange("10.4.2.0", 24), + new EnvoyServerProtoData.CidrRange(REMOTE_IP, 32)), + EnvoyServerProtoData.ConnectionSourceType.ANY, + Arrays.asList(), + Arrays.asList(), + null); + EnvoyServerProtoData.FilterChain filterChainMoreSpecificWith2 = + new EnvoyServerProtoData.FilterChain( + "filter-chain-foo", filterChainMatchMoreSpecificWith2, HTTP_CONNECTION_MANAGER, + tlsContextMoreSpecificWith2, tlsContextManager); + + EnvoyServerProtoData.DownstreamTlsContext tlsContextLessSpecific = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); + EnvoyServerProtoData.FilterChainMatch filterChainMatchLessSpecific = + new EnvoyServerProtoData.FilterChainMatch( + 0, + Arrays.asList(), + Arrays.asList(), + Arrays.asList(new EnvoyServerProtoData.CidrRange("10.4.2.2", 31)), + EnvoyServerProtoData.ConnectionSourceType.ANY, + Arrays.asList(), + Arrays.asList(), + null); + EnvoyServerProtoData.FilterChain filterChainLessSpecific = + new EnvoyServerProtoData.FilterChain( + "filter-chain-bar", filterChainMatchLessSpecific, HTTP_CONNECTION_MANAGER, + tlsContextLessSpecific, tlsContextManager); + EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( + "filter-chain-baz", null, HTTP_CONNECTION_MANAGER, null, tlsContextManager); + + selectorManager.updateSelector(new FilterChainSelector( + ImmutableMap.of(filterChainMoreSpecificWith2, noopConfig, + filterChainLessSpecific, randomConfig("no-match")), + defaultFilterChain.getSslContextProviderSupplier(), randomConfig("default"))); + + FilterChainMatchingHandler filterChainMatchingHandler = + new FilterChainMatchingHandler(grpcHandler, selectorManager, mockDelegate); + setupChannel(LOCAL_IP, REMOTE_IP, 15000, filterChainMatchingHandler); + pipeline.fireUserEventTriggered(event); + channel.runPendingTasks(); + assertThat(sslSet.get()).isEqualTo( + filterChainMoreSpecificWith2.getSslContextProviderSupplier()); + assertThat(routingSettable.get()).isEqualTo(noopConfig); + assertThat(sslSet.get().getTlsContext()).isSameInstanceAs(tlsContextMoreSpecificWith2); + } + + @Test + public void sourcePrefixRange_2Matchers_expectException() + throws UnknownHostException { + ChannelHandler next = new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + ProtocolNegotiationEvent e = (ProtocolNegotiationEvent)evt; + sslSet.set(InternalProtocolNegotiationEvent.getAttributes(e) + .get(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER)); + } + }; + when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); + + EnvoyServerProtoData.DownstreamTlsContext tlsContext1 = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); + EnvoyServerProtoData.FilterChainMatch filterChainMatch1 = + new EnvoyServerProtoData.FilterChainMatch( + 0, + Arrays.asList(), + Arrays.asList(), + Arrays.asList( + new EnvoyServerProtoData.CidrRange("10.4.2.0", 24), + new EnvoyServerProtoData.CidrRange("192.168.10.2", 32)), + EnvoyServerProtoData.ConnectionSourceType.ANY, + Arrays.asList(), + Arrays.asList(), + null); + EnvoyServerProtoData.FilterChain filterChain1 = new EnvoyServerProtoData.FilterChain( + "filter-chain-foo", filterChainMatch1, HTTP_CONNECTION_MANAGER, tlsContext1, + tlsContextManager); + + EnvoyServerProtoData.DownstreamTlsContext tlsContext2 = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); + EnvoyServerProtoData.FilterChainMatch filterChainMatch2 = + new EnvoyServerProtoData.FilterChainMatch( + 0, + Arrays.asList(), + Arrays.asList(), + Arrays.asList(new EnvoyServerProtoData.CidrRange("10.4.2.0", 24)), + EnvoyServerProtoData.ConnectionSourceType.ANY, + Arrays.asList(), + Arrays.asList(), + null); + EnvoyServerProtoData.FilterChain filterChain2 = new EnvoyServerProtoData.FilterChain( + "filter-chain-bar", filterChainMatch2, HTTP_CONNECTION_MANAGER, tlsContext2, + tlsContextManager); + EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( + "filter-chain-baz", null, HTTP_CONNECTION_MANAGER, null, null); + + selectorManager.updateSelector(new FilterChainSelector( + ImmutableMap.of(filterChain1, noopConfig, filterChain2, noopConfig), + defaultFilterChain.getSslContextProviderSupplier(), noopConfig)); + + FilterChainMatchingHandler filterChainMatchingHandler = + new FilterChainMatchingHandler(grpcHandler, selectorManager, mockDelegate); + setupChannel(LOCAL_IP, REMOTE_IP, 15000, filterChainMatchingHandler); + pipeline.fireUserEventTriggered(event); + channel.runPendingTasks(); + try { + channel.checkException(); + fail("expect exception!"); + } catch (IllegalStateException ise) { + assertThat(ise).hasMessageThat().isEqualTo("Found more than one matching filter chains. This " + + "should not be possible as ClientXdsClient validated the chains for uniqueness."); + assertThat(sslSet.isDone()).isFalse(); + channelHandlerCtx = pipeline.context(filterChainMatchingHandler); + assertThat(channelHandlerCtx).isNotNull(); + } + } + + @Test + public void sourcePortMatch_exactMatchWinsOverEmptyList() throws Exception { + EnvoyServerProtoData.DownstreamTlsContext tlsContextEmptySourcePorts = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); + EnvoyServerProtoData.FilterChainMatch filterChainMatchEmptySourcePorts = + new EnvoyServerProtoData.FilterChainMatch( + 0, + Arrays.asList(), + Arrays.asList(), + Arrays.asList( + new EnvoyServerProtoData.CidrRange("10.4.2.0", 24), + new EnvoyServerProtoData.CidrRange("10.4.2.2", 31)), + EnvoyServerProtoData.ConnectionSourceType.ANY, + Arrays.asList(), + Arrays.asList(), + null); + EnvoyServerProtoData.FilterChain filterChainEmptySourcePorts = + new EnvoyServerProtoData.FilterChain( + "filter-chain-foo", filterChainMatchEmptySourcePorts, HTTP_CONNECTION_MANAGER, + tlsContextEmptySourcePorts, tlsContextManager); + + EnvoyServerProtoData.DownstreamTlsContext tlsContextSourcePortMatch = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); + EnvoyServerProtoData.FilterChainMatch filterChainMatchSourcePortMatch = + new EnvoyServerProtoData.FilterChainMatch( + 0, + Arrays.asList(), + Arrays.asList(), + Arrays.asList(new EnvoyServerProtoData.CidrRange("10.4.2.2", 31)), + EnvoyServerProtoData.ConnectionSourceType.ANY, + Arrays.asList(7000, 15000), + Arrays.asList(), + null); + EnvoyServerProtoData.FilterChain filterChainSourcePortMatch = + new EnvoyServerProtoData.FilterChain( + "filter-chain-bar", filterChainMatchSourcePortMatch, HTTP_CONNECTION_MANAGER, + tlsContextSourcePortMatch, tlsContextManager); + EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( + "filter-chain-baz", null, HTTP_CONNECTION_MANAGER, null, tlsContextManager); + + selectorManager.updateSelector(new FilterChainSelector( + ImmutableMap.of(filterChainEmptySourcePorts, randomConfig("no-match"), + filterChainSourcePortMatch, noopConfig), + defaultFilterChain.getSslContextProviderSupplier(), randomConfig("default"))); + + FilterChainMatchingHandler filterChainMatchingHandler = + new FilterChainMatchingHandler(grpcHandler, selectorManager, mockDelegate); + ChannelHandler next = captureAttrHandler(sslSet, routingSettable); + when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); + setupChannel(LOCAL_IP, REMOTE_IP, 15000, filterChainMatchingHandler); + pipeline.fireUserEventTriggered(event); + channel.runPendingTasks(); + assertThat(sslSet.get()).isEqualTo(filterChainSourcePortMatch.getSslContextProviderSupplier()); + assertThat(routingSettable.get()).isEqualTo(noopConfig); + assertThat(sslSet.get().getTlsContext()).isSameInstanceAs(tlsContextSourcePortMatch); + } + + /** + * Create 6 filterChains: - 1st filter chain has dest port & specific prefix range but is + * eliminated due to dest port - 5 advance to next step: 1 is eliminated due to being less + * specific than the remaining 4. - 4 advance to 3rd step: source type external eliminates one + * with local source_type. - 3 advance to 4th step: more specific 2 get picked based on + * source-prefix range. - 5th step: out of 2 one with matching source port gets picked + */ + @Test + public void filterChain_5stepMatch() throws Exception { + EnvoyServerProtoData.DownstreamTlsContext tlsContext1 = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); + EnvoyServerProtoData.DownstreamTlsContext tlsContext2 = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); + EnvoyServerProtoData.DownstreamTlsContext tlsContext3 = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT3", "VA3"); + EnvoyServerProtoData.DownstreamTlsContext tlsContext4 = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT4", "VA4"); + EnvoyServerProtoData.DownstreamTlsContext tlsContext5 = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT5", "VA5"); + EnvoyServerProtoData.DownstreamTlsContext tlsContext6 = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT6", "VA6"); + + // has dest port and specific prefix ranges: gets eliminated in step 1 + EnvoyServerProtoData.FilterChainMatch filterChainMatch1 = + new EnvoyServerProtoData.FilterChainMatch( + PORT, + Arrays.asList(), + Arrays.asList(), + Arrays.asList(new EnvoyServerProtoData.CidrRange(REMOTE_IP, 32)), + EnvoyServerProtoData.ConnectionSourceType.ANY, + Arrays.asList(), + Arrays.asList(), + null); + EnvoyServerProtoData.FilterChain filterChain1 = new EnvoyServerProtoData.FilterChain( + "filter-chain-1", filterChainMatch1, HTTP_CONNECTION_MANAGER, tlsContext1, + tlsContextManager); + + // next 5 use prefix range: 4 with prefixLen of 30 and last one with 29 + + // has single prefix range: and less specific source prefix range: gets eliminated in step 4 + EnvoyServerProtoData.FilterChainMatch filterChainMatch2 = + new EnvoyServerProtoData.FilterChainMatch( + 0, + Arrays.asList(new EnvoyServerProtoData.CidrRange("10.1.2.0", 30)), + Arrays.asList(), + Arrays.asList(new EnvoyServerProtoData.CidrRange("10.4.0.0", 16)), + EnvoyServerProtoData.ConnectionSourceType.ANY, + Arrays.asList(), + Arrays.asList(), + null); + EnvoyServerProtoData.FilterChain filterChain2 = new EnvoyServerProtoData.FilterChain( + "filter-chain-2", filterChainMatch2, HTTP_CONNECTION_MANAGER, tlsContext2, + tlsContextManager); + + // has prefix ranges with one not matching and source type local: gets eliminated in step 3 + EnvoyServerProtoData.FilterChainMatch filterChainMatch3 = + new EnvoyServerProtoData.FilterChainMatch( + 0, + Arrays.asList( + new EnvoyServerProtoData.CidrRange("192.168.2.0", 24), + new EnvoyServerProtoData.CidrRange("10.1.2.0", 30)), + Arrays.asList(), + Arrays.asList(), + EnvoyServerProtoData.ConnectionSourceType.SAME_IP_OR_LOOPBACK, + Arrays.asList(), + Arrays.asList(), + null); + EnvoyServerProtoData.FilterChain filterChain3 = new EnvoyServerProtoData.FilterChain( + "filter-chain-3", filterChainMatch3, HTTP_CONNECTION_MANAGER, tlsContext3, + tlsContextManager); + + // has prefix ranges with both matching and source type external but non matching source port: + // gets eliminated in step 5 + EnvoyServerProtoData.FilterChainMatch filterChainMatch4 = + new EnvoyServerProtoData.FilterChainMatch( + 0, + Arrays.asList( + new EnvoyServerProtoData.CidrRange("10.1.0.0", 16), + new EnvoyServerProtoData.CidrRange("10.1.2.0", 30)), + Arrays.asList(), + Arrays.asList(new EnvoyServerProtoData.CidrRange("10.4.2.0", 24)), + EnvoyServerProtoData.ConnectionSourceType.EXTERNAL, + Arrays.asList(16000, 9000), + Arrays.asList(), + null); + EnvoyServerProtoData.FilterChain filterChain4 = + new EnvoyServerProtoData.FilterChain( + "filter-chain-4", filterChainMatch4, HTTP_CONNECTION_MANAGER, tlsContext4, + tlsContextManager); + + // has prefix ranges with both matching and source type external and matching source port: this + // gets selected + EnvoyServerProtoData.FilterChainMatch filterChainMatch5 = + new EnvoyServerProtoData.FilterChainMatch( + 0, + Arrays.asList( + new EnvoyServerProtoData.CidrRange("10.1.0.0", 16), + new EnvoyServerProtoData.CidrRange("10.1.2.0", 30)), + Arrays.asList(), + Arrays.asList( + new EnvoyServerProtoData.CidrRange("10.4.2.0", 24), + new EnvoyServerProtoData.CidrRange("192.168.2.0", 24)), + EnvoyServerProtoData.ConnectionSourceType.ANY, + Arrays.asList(15000, 8000), + Arrays.asList(), + null); + EnvoyServerProtoData.FilterChain filterChain5 = + new EnvoyServerProtoData.FilterChain( + "filter-chain-5", filterChainMatch5, HTTP_CONNECTION_MANAGER, tlsContext5, + tlsContextManager); + + // has prefix range with prefixLen of 29: gets eliminated in step 2 + EnvoyServerProtoData.FilterChainMatch filterChainMatch6 = + new EnvoyServerProtoData.FilterChainMatch( + 0, + Arrays.asList(new EnvoyServerProtoData.CidrRange("10.1.2.0", 29)), + Arrays.asList(), + Arrays.asList(), + EnvoyServerProtoData.ConnectionSourceType.ANY, + Arrays.asList(), + Arrays.asList(), + null); + EnvoyServerProtoData.FilterChain filterChain6 = + new EnvoyServerProtoData.FilterChain( + "filter-chain-6", filterChainMatch6, HTTP_CONNECTION_MANAGER, tlsContext6, + tlsContextManager); + + EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( + "filter-chain-7", null, HTTP_CONNECTION_MANAGER, null, tlsContextManager); + + Map> map = new HashMap<>(); + map.put(filterChain1, randomConfig("1")); + map.put(filterChain2, randomConfig("2")); + map.put(filterChain3, randomConfig("3")); + map.put(filterChain4, randomConfig("4")); + map.put(filterChain5, noopConfig); + map.put(filterChain6, randomConfig("6")); + selectorManager.updateSelector(new FilterChainSelector( + map, defaultFilterChain.getSslContextProviderSupplier(), randomConfig("default"))); + + FilterChainMatchingHandler filterChainMatchingHandler = + new FilterChainMatchingHandler(grpcHandler, selectorManager, mockDelegate); + + ChannelHandler next = captureAttrHandler(sslSet, routingSettable); + when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); + setupChannel(LOCAL_IP, REMOTE_IP, 15000, filterChainMatchingHandler); + pipeline.fireUserEventTriggered(event); + channel.runPendingTasks(); + assertThat(sslSet.get()).isEqualTo(filterChain5.getSslContextProviderSupplier()); + assertThat(routingSettable.get()).isEqualTo(noopConfig); + assertThat(sslSet.get().getTlsContext()).isSameInstanceAs(tlsContext5); + } + + @Test + @SuppressWarnings("deprecation") + public void filterChainMatch_unsupportedMatchers() throws Exception { + EnvoyServerProtoData.DownstreamTlsContext tlsContext1 = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "ROOTCA"); + EnvoyServerProtoData.DownstreamTlsContext tlsContext2 = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "ROOTCA"); + EnvoyServerProtoData.DownstreamTlsContext tlsContext3 = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT3", "ROOTCA"); + + EnvoyServerProtoData.FilterChainMatch filterChainMatch1 = + new EnvoyServerProtoData.FilterChainMatch( + 0 /* destinationPort */, + Collections.singletonList( + new EnvoyServerProtoData.CidrRange("10.1.0.0", 16)) /* prefixRange */, + Arrays.asList("managed-mtls", "h2") /* applicationProtocol */, + Collections.emptyList() /* sourcePrefixRanges */, + EnvoyServerProtoData.ConnectionSourceType.ANY /* sourceType */, + Collections.emptyList() /* sourcePorts */, + Arrays.asList("server1", "server2") /* serverNames */, + "tls" /* transportProtocol */); + + EnvoyServerProtoData.FilterChainMatch filterChainMatch2 = + new EnvoyServerProtoData.FilterChainMatch( + 0 /* destinationPort */, + Collections.singletonList( + new EnvoyServerProtoData.CidrRange("10.0.0.0", 8)) /* prefixRange */, + Collections.emptyList() /* applicationProtocol */, + Collections.emptyList() /* sourcePrefixRanges */, + EnvoyServerProtoData.ConnectionSourceType.ANY /* sourceType */, + Collections.emptyList() /* sourcePorts */, + Collections.emptyList() /* serverNames */, + "" /* transportProtocol */); + + EnvoyServerProtoData.FilterChainMatch defaultFilterChainMatch = + new EnvoyServerProtoData.FilterChainMatch( + 0 /* destinationPort */, + Collections.emptyList() /* prefixRange */, + Collections.emptyList() /* applicationProtocol */, + Collections.emptyList() /* sourcePrefixRanges */, + EnvoyServerProtoData.ConnectionSourceType.ANY /* sourceType */, + Collections.emptyList() /* sourcePorts */, + Collections.emptyList() /* serverNames */, + "" /* transportProtocol */); + + EnvoyServerProtoData.FilterChain filterChain1 = new EnvoyServerProtoData.FilterChain( + "filter-chain-foo", filterChainMatch1, HTTP_CONNECTION_MANAGER, tlsContext1, + mock(TlsContextManager.class)); + EnvoyServerProtoData.FilterChain filterChain2 = new EnvoyServerProtoData.FilterChain( + "filter-chain-bar", filterChainMatch2, HTTP_CONNECTION_MANAGER, tlsContext2, + mock(TlsContextManager.class)); + + EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( + "filter-chain-baz", defaultFilterChainMatch, HTTP_CONNECTION_MANAGER, tlsContext3, + mock(TlsContextManager.class)); + + selectorManager.updateSelector(new FilterChainSelector( + ImmutableMap.of(filterChain1, randomConfig("1"), filterChain2, randomConfig("2")), + defaultFilterChain.getSslContextProviderSupplier(), noopConfig)); + + FilterChainMatchingHandler filterChainMatchingHandler = + new FilterChainMatchingHandler(grpcHandler, selectorManager, mockDelegate); + ChannelHandler next = captureAttrHandler(sslSet, routingSettable); + when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); + setupChannel(LOCAL_IP, REMOTE_IP, 15000, filterChainMatchingHandler); + pipeline.fireUserEventTriggered(event); + channel.runPendingTasks(); + assertThat(sslSet.get()).isEqualTo(defaultFilterChain.getSslContextProviderSupplier()); + assertThat(routingSettable.get()).isEqualTo(noopConfig); + assertThat(sslSet.get().getTlsContext().getCommonTlsContext() + .getTlsCertificateCertificateProviderInstance() + .getCertificateName()).isEqualTo("CERT3"); + } + + private static HttpConnectionManager createRds(String name) { + return HttpConnectionManager.forRdsName(0L, name, + new ArrayList()); + } + + private static VirtualHost createVirtualHost(String name) { + return VirtualHost.create( + name, Collections.singletonList("auth"), new ArrayList(), + ImmutableMap.of()); + } + + private static AtomicReference randomConfig(String domain) { + return new AtomicReference<>( + ServerRoutingConfig.create(ImmutableList.of(createVirtualHost(domain)), + ImmutableMap.of()) + ); + } + + private EnvoyServerProtoData.DownstreamTlsContext createTls() { + return DownstreamTlsContext.fromEnvoyProtoDownstreamTlsContext( + io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.DownstreamTlsContext + .getDefaultInstance()); + } + + private void setupChannel(final String localIp, final String remoteIp, final int remotePort, + FilterChainMatchingHandler matchingHandler) { + channel = + new EmbeddedChannel() { + @Override + public SocketAddress localAddress() { + return new InetSocketAddress(localIp, 80); + } + + @Override + public SocketAddress remoteAddress() { + return new InetSocketAddress(remoteIp, remotePort); + } + }; + pipeline = channel.pipeline(); + pipeline.addLast(matchingHandler); + } + + private static ChannelHandler captureAttrHandler( + final SettableFuture sslSet, + final SettableFuture> routingSettable) { + return new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + ProtocolNegotiationEvent e = (ProtocolNegotiationEvent)evt; + sslSet.set(InternalProtocolNegotiationEvent.getAttributes(e) + .get(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER)); + routingSettable.set(InternalProtocolNegotiationEvent.getAttributes(e) + .get(ATTR_SERVER_ROUTING_CONFIG)); + } + }; + } + + private static final class FakeGrpcHttp2ConnectionHandler extends GrpcHttp2ConnectionHandler { + FakeGrpcHttp2ConnectionHandler( + ChannelPromise channelUnused, + Http2ConnectionDecoder decoder, + Http2ConnectionEncoder encoder, + Http2Settings initialSettings) { + super(channelUnused, decoder, encoder, initialSettings, new NoopChannelLogger()); + } + + static FakeGrpcHttp2ConnectionHandler newHandler() { + DefaultHttp2Connection conn = new DefaultHttp2Connection(/*server=*/ false); + DefaultHttp2ConnectionEncoder encoder = + new DefaultHttp2ConnectionEncoder(conn, new DefaultHttp2FrameWriter()); + DefaultHttp2ConnectionDecoder decoder = + new DefaultHttp2ConnectionDecoder(conn, encoder, new DefaultHttp2FrameReader()); + Http2Settings settings = new Http2Settings(); + return new FakeGrpcHttp2ConnectionHandler( + /*channelUnused=*/ null, decoder, encoder, settings); + } + + @Override + public String getAuthority() { + return "authority"; + } + } +} diff --git a/xds/src/test/java/io/grpc/xds/FilterChainSelectorManagerTest.java b/xds/src/test/java/io/grpc/xds/FilterChainSelectorManagerTest.java new file mode 100644 index 00000000000..a3a2218d4c3 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/FilterChainSelectorManagerTest.java @@ -0,0 +1,110 @@ +/* + * Copyright 2021 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.truth.Truth.assertThat; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.grpc.ServerInterceptor; +import io.grpc.xds.EnvoyServerProtoData.FilterChain; +import io.grpc.xds.FilterChainMatchingProtocolNegotiators.FilterChainMatchingHandler.FilterChainSelector; +import io.grpc.xds.FilterChainSelectorManager.Closer; +import io.grpc.xds.XdsServerWrapper.ServerRoutingConfig; +import java.util.Collections; +import java.util.concurrent.atomic.AtomicReference; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class FilterChainSelectorManagerTest { + private FilterChainSelectorManager manager = new FilterChainSelectorManager(); + private AtomicReference noopConfig = new AtomicReference<>( + ServerRoutingConfig.create(ImmutableList.of(), + ImmutableMap.of())); + private FilterChainSelector selector1 = new FilterChainSelector( + Collections.>emptyMap(), + null, new AtomicReference()); + private FilterChainSelector selector2 = new FilterChainSelector( + Collections.>emptyMap(), + null, noopConfig); + private CounterRunnable runnable1 = new CounterRunnable(); + private CounterRunnable runnable2 = new CounterRunnable(); + + @Test + public void updateSelector_changesSelector() { + assertThat(manager.getSelectorToUpdateSelector()).isNull(); + assertThat(manager.register(new Closer(runnable1))).isNull(); + + manager.updateSelector(selector1); + + assertThat(runnable1.counter).isEqualTo(1); + assertThat(manager.getSelectorToUpdateSelector()).isSameInstanceAs(selector1); + assertThat(manager.register(new Closer(runnable2))).isSameInstanceAs(selector1); + assertThat(runnable2.counter).isEqualTo(0); + } + + @Test + public void updateSelector_callsCloserOnce() { + assertThat(manager.register(new Closer(runnable1))).isNull(); + + manager.updateSelector(selector1); + manager.updateSelector(selector2); + + assertThat(runnable1.counter).isEqualTo(1); + } + + @Test + public void deregister_removesCloser() { + Closer closer1 = new Closer(runnable1); + manager.updateSelector(selector1); + assertThat(manager.register(closer1)).isSameInstanceAs(selector1); + assertThat(manager.getRegisterCount()).isEqualTo(1); + + manager.deregister(closer1); + + assertThat(manager.getRegisterCount()).isEqualTo(0); + manager.updateSelector(selector2); + assertThat(runnable1.counter).isEqualTo(0); + } + + @Test + public void deregister_removesCorrectCloser() { + Closer closer1 = new Closer(runnable1); + Closer closer2 = new Closer(runnable2); + manager.updateSelector(selector1); + assertThat(manager.register(closer1)).isSameInstanceAs(selector1); + assertThat(manager.register(closer2)).isSameInstanceAs(selector1); + assertThat(manager.getRegisterCount()).isEqualTo(2); + + manager.deregister(closer1); + + assertThat(manager.getRegisterCount()).isEqualTo(1); + manager.updateSelector(selector2); + assertThat(runnable1.counter).isEqualTo(0); + assertThat(runnable2.counter).isEqualTo(1); + } + + private static class CounterRunnable implements Runnable { + int counter; + + @Override public void run() { + counter++; + } + } +} diff --git a/xds/src/test/java/io/grpc/xds/GoogleCloudToProdNameResolverTest.java b/xds/src/test/java/io/grpc/xds/GoogleCloudToProdNameResolverTest.java index 07e957b24c4..421b2a1dd0a 100644 --- a/xds/src/test/java/io/grpc/xds/GoogleCloudToProdNameResolverTest.java +++ b/xds/src/test/java/io/grpc/xds/GoogleCloudToProdNameResolverTest.java @@ -189,7 +189,7 @@ public void onGcpAndNoProvidedBootstrapDelegateToXds() { Map server = Iterables.getOnlyElement( (List>) bootstrap.get("xds_servers")); assertThat(server).containsExactly( - "server_uri", "directpath-trafficdirector.googleapis.com", + "server_uri", "directpath-pa.googleapis.com", "channel_creds", ImmutableList.of(ImmutableMap.of("type", "google_default")), "server_features", ImmutableList.of("xds_v3")); } diff --git a/xds/src/test/java/io/grpc/xds/RbacFilterTest.java b/xds/src/test/java/io/grpc/xds/RbacFilterTest.java index c5fe7b3d1bd..082c49ef665 100644 --- a/xds/src/test/java/io/grpc/xds/RbacFilterTest.java +++ b/xds/src/test/java/io/grpc/xds/RbacFilterTest.java @@ -41,6 +41,7 @@ import io.envoyproxy.envoy.type.matcher.v3.MetadataMatcher; import io.envoyproxy.envoy.type.matcher.v3.PathMatcher; import io.envoyproxy.envoy.type.matcher.v3.StringMatcher; +import io.envoyproxy.envoy.type.v3.Int32Range; import io.grpc.Attributes; import io.grpc.Grpc; import io.grpc.Metadata; @@ -109,6 +110,33 @@ public void ipPortParser() { assertThat(decision.decision()).isEqualTo(GrpcAuthorizationEngine.Action.DENY); } + @Test + @SuppressWarnings({"unchecked", "deprecation"}) + public void portRangeParser() { + List permissionList = Arrays.asList( + Permission.newBuilder().setDestinationPortRange( + Int32Range.newBuilder().setStart(1010).setEnd(65535).build() + ).build()); + List principalList = Arrays.asList( + Principal.newBuilder().setRemoteIp( + CidrRange.newBuilder().setAddressPrefix("10.10.10.0") + .setPrefixLen(UInt32Value.of(24)).build() + ).build()); + ConfigOrError result = parse(permissionList, principalList); + assertThat(result.errorDetail).isNull(); + ServerCall serverCall = mock(ServerCall.class); + Attributes attributes = Attributes.newBuilder() + .set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, new InetSocketAddress("10.10.10.0", 1)) + .set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, new InetSocketAddress("10.10.10.0",9090)) + .build(); + when(serverCall.getAttributes()).thenReturn(attributes); + when(serverCall.getMethodDescriptor()).thenReturn(method().build()); + GrpcAuthorizationEngine engine = + new GrpcAuthorizationEngine(((RbacConfig)result.config).authConfig()); + AuthDecision decision = engine.evaluate(new Metadata(), serverCall); + assertThat(decision.decision()).isEqualTo(GrpcAuthorizationEngine.Action.DENY); + } + @Test @SuppressWarnings("unchecked") public void pathParser() { @@ -155,7 +183,7 @@ public void authenticatedParser() throws Exception { } @Test - @SuppressWarnings("unchecked") + @SuppressWarnings({"unchecked", "deprecation"}) public void headerParser() { HeaderMatcher headerMatcher = HeaderMatcher.newBuilder() .setName("party").setExactMatch("win").build(); @@ -172,6 +200,21 @@ public void headerParser() { assertThat(decision.decision()).isEqualTo(GrpcAuthorizationEngine.Action.DENY); } + @Test + @SuppressWarnings("deprecation") + public void headerParser_headerName() { + HeaderMatcher headerMatcher = HeaderMatcher.newBuilder() + .setName("grpc--feature").setExactMatch("win").build(); + List permissionList = Arrays.asList( + Permission.newBuilder().setHeader(headerMatcher).build()); + HeaderMatcher headerMatcher2 = HeaderMatcher.newBuilder() + .setName(":scheme").setExactMatch("win").build(); + List principalList = Arrays.asList( + Principal.newBuilder().setHeader(headerMatcher2).build()); + ConfigOrError result = parseOverride(permissionList, principalList); + assertThat(result.errorDetail).isNotNull(); + } + @Test @SuppressWarnings("unchecked") public void compositeRules() { diff --git a/xds/src/test/java/io/grpc/xds/ServerWrapperForXdsTest.java b/xds/src/test/java/io/grpc/xds/ServerWrapperForXdsTest.java deleted file mode 100644 index c4e888f5439..00000000000 --- a/xds/src/test/java/io/grpc/xds/ServerWrapperForXdsTest.java +++ /dev/null @@ -1,320 +0,0 @@ -/* - * Copyright 2021 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.xds; - -import static com.google.common.truth.Truth.assertThat; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.doReturn; -import static org.mockito.Mockito.doThrow; -import static org.mockito.Mockito.inOrder; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.reset; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import com.google.common.util.concurrent.SettableFuture; -import io.grpc.Server; -import io.grpc.ServerBuilder; -import io.grpc.Status; -import io.grpc.StatusException; -import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil; -import io.grpc.xds.internal.sds.ServerWrapperForXds; -import java.io.IOException; -import java.net.BindException; -import java.net.NoRouteToHostException; -import java.util.List; -import java.util.concurrent.CancellationException; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.Executors; -import java.util.concurrent.Future; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; -import org.junit.Before; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; -import org.mockito.ArgumentCaptor; -import org.mockito.InOrder; -import org.mockito.invocation.InvocationOnMock; -import org.mockito.stubbing.Answer; - -/** - * Unit tests for {@link ServerWrapperForXds}. - */ -@RunWith(JUnit4.class) -public class ServerWrapperForXdsTest { - - private ServerWrapperForXds serverWrapperForXds; - private ServerBuilder mockDelegateBuilder; - private int port; - private XdsClientWrapperForServerSds xdsClientWrapperForServerSds; - private XdsServerBuilder.XdsServingStatusListener mockXdsServingStatusListener; - private XdsClient.LdsResourceWatcher listenerWatcher; - private Server mockServer; - private TlsContextManager tlsContextManager; - - @Before - public void setUp() throws IOException { - port = XdsServerTestHelper.findFreePort(); - mockDelegateBuilder = mock(ServerBuilder.class); - tlsContextManager = mock(TlsContextManager.class); - xdsClientWrapperForServerSds = XdsServerTestHelper - .createXdsClientWrapperForServerSds(port, tlsContextManager); - mockXdsServingStatusListener = mock(XdsServerBuilder.XdsServingStatusListener.class); - listenerWatcher = - XdsServerTestHelper.startAndGetWatcher(xdsClientWrapperForServerSds); - mockServer = mock(Server.class); - when(mockDelegateBuilder.build()).thenReturn(mockServer); - serverWrapperForXds = new ServerWrapperForXds(mockDelegateBuilder, - xdsClientWrapperForServerSds, - mockXdsServingStatusListener, - 100, TimeUnit.MILLISECONDS); - } - - private Future startServerAsync() throws InterruptedException { - final SettableFuture settableFuture = SettableFuture.create(); - Executors.newSingleThreadExecutor().execute(new Runnable() { - @Override - public void run() { - try { - serverWrapperForXds.start(); - settableFuture.set(null); - } catch (Throwable e) { - settableFuture.set(e); - } - } - }); - // wait until xdsClientWrapperForServerSds.serverWatchers populated - for (int i = 0; i < 10; i++) { - synchronized (xdsClientWrapperForServerSds.serverWatchers) { - if (!xdsClientWrapperForServerSds.serverWatchers.isEmpty()) { - break; - } - } - Thread.sleep(100L); - } - return settableFuture; - } - - @Test - public void start() - throws InterruptedException, TimeoutException, ExecutionException, IOException { - Future future = startServerAsync(); - listenerWatcher.onError(Status.ABORTED); - verifyCapturedCodeAndNotServing(Status.Code.ABORTED, ServerWrapperForXds.ServingState.STARTING); - XdsServerTestHelper.generateListenerUpdate( - listenerWatcher, - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"), - tlsContextManager); - Throwable exception = future.get(2, TimeUnit.SECONDS); - assertThat(exception).isNull(); - assertThat(serverWrapperForXds.getCurrentServingState()) - .isEqualTo(ServerWrapperForXds.ServingState.STARTED); - listenerWatcher.onResourceDoesNotExist("name"); - verifyCapturedCodeAndNotServing(Status.Code.NOT_FOUND, - ServerWrapperForXds.ServingState.NOT_SERVING); - listenerWatcher.onError(Status.NOT_FOUND); - verifyCapturedCodeAndNotServing(Status.Code.NOT_FOUND, - ServerWrapperForXds.ServingState.NOT_SERVING); - listenerWatcher.onError(Status.INVALID_ARGUMENT); - verifyCapturedCodeAndNotServing(Status.Code.INVALID_ARGUMENT, - ServerWrapperForXds.ServingState.NOT_SERVING); - listenerWatcher.onError(Status.PERMISSION_DENIED); - verifyCapturedCodeAndNotServing(Status.Code.PERMISSION_DENIED, - ServerWrapperForXds.ServingState.NOT_SERVING); - listenerWatcher.onError(Status.UNIMPLEMENTED); - verifyCapturedCodeAndNotServing(Status.Code.UNIMPLEMENTED, - ServerWrapperForXds.ServingState.NOT_SERVING); - listenerWatcher.onError(Status.UNAUTHENTICATED); - verifyCapturedCodeAndNotServing(Status.Code.UNAUTHENTICATED, - ServerWrapperForXds.ServingState.NOT_SERVING); - listenerWatcher.onError(Status.ABORTED); - verifyCapturedCodeAndNotServing(null, ServerWrapperForXds.ServingState.NOT_SERVING); - Server mockServer1 = mock(Server.class); - Server mockServer2 = mock(Server.class); - Server mockServer3 = mock(Server.class); - final SettableFuture settableFutureForThrow = SettableFuture.create(); - final SettableFuture settableFutureToSignalStart = SettableFuture.create(); - doAnswer(new Answer() { - @Override - public Server answer(InvocationOnMock invocation) throws Throwable { - settableFutureToSignalStart.set(null); - throw settableFutureForThrow.get(); - } - }).when(mockServer1).start(); - doThrow(new BindException()).when(mockServer2).start(); - doReturn(mockServer3).when(mockServer3).start(); - when(mockDelegateBuilder.build()).thenReturn(mockServer1, mockServer2, mockServer3); - new Thread(new Runnable() { - @Override - public void run() { - XdsServerTestHelper.generateListenerUpdate( - listenerWatcher, - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"), - tlsContextManager); - } - }).start(); - assertThat(settableFutureToSignalStart.get()).isNull(); - assertThat(serverWrapperForXds.getCurrentServingState()) - .isEqualTo(ServerWrapperForXds.ServingState.ENTER_SERVING); - settableFutureForThrow.set(new IOException(new BindException())); - Thread.sleep(1000L); - ArgumentCaptor argCaptor = ArgumentCaptor.forClass(null); - InOrder inOrder = inOrder(mockXdsServingStatusListener); - inOrder.verify(mockXdsServingStatusListener, times(2)).onNotServing(argCaptor.capture()); - List throwableList = argCaptor.getAllValues(); - assertThat(throwableList.size()).isEqualTo(2); - Throwable throwable = throwableList.remove(0); - assertThat(throwable).isInstanceOf(IOException.class); - assertThat(throwable.getCause()).isInstanceOf(BindException.class); - throwable = throwableList.remove(0); - assertThat(throwable).isInstanceOf(BindException.class); - inOrder.verify(mockXdsServingStatusListener).onServing(); - assertThat(serverWrapperForXds.getCurrentServingState()) - .isEqualTo(ServerWrapperForXds.ServingState.STARTED); - serverWrapperForXds.shutdown(); - } - - @Test - public void delegateInitialStartError() - throws InterruptedException, TimeoutException, ExecutionException, IOException { - Future future = startServerAsync(); - doThrow(new IOException("test exception")).when(mockServer).start(); - new Thread(new Runnable() { - @Override - public void run() { - XdsServerTestHelper.generateListenerUpdate( - listenerWatcher, - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"), - tlsContextManager); - } - }).start(); - Throwable exception = future.get(2, TimeUnit.SECONDS); - assertThat(exception).isInstanceOf(IOException.class); - assertThat(exception).hasMessageThat().isEqualTo("test exception"); - } - - private void verifyCapturedCodeAndNotServing(Status.Code expected, - ServerWrapperForXds.ServingState servingState) { - ArgumentCaptor argCaptor = ArgumentCaptor.forClass(null); - verify(mockXdsServingStatusListener, times(expected != null ? 1 : 0)) - .onNotServing(argCaptor.capture()); - if (expected != null) { - Throwable throwable = argCaptor.getValue(); - assertThat(throwable).isInstanceOf(StatusException.class); - Status captured = ((StatusException) throwable).getStatus(); - assertThat(captured.getCode()).isEqualTo(expected); - } - assertThat(serverWrapperForXds.getCurrentServingState()).isEqualTo(servingState); - reset(mockXdsServingStatusListener); - } - - @Test - public void start_internalError() - throws InterruptedException, TimeoutException, ExecutionException { - Future future = startServerAsync(); - listenerWatcher.onError(Status.INTERNAL); - Throwable exception = future.get(2, TimeUnit.SECONDS); - assertThat(exception).isInstanceOf(IOException.class); - Throwable cause = exception.getCause(); - assertThat(cause).isInstanceOf(StatusException.class); - assertThat(((StatusException) cause).getStatus().getCode()).isEqualTo(Status.Code.INTERNAL); - assertThat(serverWrapperForXds.getCurrentServingState()) - .isEqualTo(ServerWrapperForXds.ServingState.SHUTDOWN); - } - - @Test - public void delegateStartError_shutdown() - throws InterruptedException, TimeoutException, ExecutionException, IOException { - Future future = startServerAsync(); - XdsServerTestHelper.generateListenerUpdate( - listenerWatcher, - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"), - tlsContextManager); - Throwable exception = future.get(2, TimeUnit.SECONDS); - assertThat(exception).isNull(); - assertThat(serverWrapperForXds.getCurrentServingState()) - .isEqualTo(ServerWrapperForXds.ServingState.STARTED); - listenerWatcher.onResourceDoesNotExist("name"); - verifyCapturedCodeAndNotServing(Status.Code.NOT_FOUND, - ServerWrapperForXds.ServingState.NOT_SERVING); - Server mockServer = mock(Server.class); - doThrow(new IOException(new NoRouteToHostException())).when(mockServer).start(); - when(mockDelegateBuilder.build()).thenReturn(mockServer); - XdsServerTestHelper.generateListenerUpdate( - listenerWatcher, - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT3", "VA3"), - tlsContextManager); - Thread.sleep(100L); - assertThat(serverWrapperForXds.getCurrentServingState()) - .isEqualTo(ServerWrapperForXds.ServingState.SHUTDOWN); - } - - @Test - public void shutdownDuringRestart() - throws InterruptedException, TimeoutException, ExecutionException, IOException { - Future future = startServerAsync(); - XdsServerTestHelper.generateListenerUpdate( - listenerWatcher, - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"), - tlsContextManager); - Throwable exception = future.get(2, TimeUnit.SECONDS); - assertThat(exception).isNull(); - assertThat(serverWrapperForXds.getCurrentServingState()) - .isEqualTo(ServerWrapperForXds.ServingState.STARTED); - listenerWatcher.onResourceDoesNotExist("name"); - verifyCapturedCodeAndNotServing(Status.Code.NOT_FOUND, - ServerWrapperForXds.ServingState.NOT_SERVING); - Server mockServer = mock(Server.class); - final SettableFuture settableFutureForStart = SettableFuture.create(); - final SettableFuture settableFutureToSignalStart = SettableFuture.create(); - final SettableFuture settableFutureForInterrupt = SettableFuture.create(); - doAnswer(new Answer() { - @Override - public Server answer(InvocationOnMock invocation) - throws InterruptedException, ExecutionException { - settableFutureToSignalStart.set(null); - try { - settableFutureForStart.get(); - } catch (InterruptedException | CancellationException e) { - settableFutureForInterrupt.set(e); - throw e; - } - return null; // never reach here - } - }).when(mockServer).start(); - when(mockDelegateBuilder.build()).thenReturn(mockServer); - new Thread(new Runnable() { - @Override - public void run() { - XdsServerTestHelper.generateListenerUpdate( - listenerWatcher, - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"), - tlsContextManager); - } - }).start(); - assertThat(settableFutureToSignalStart.get()).isNull(); - assertThat(serverWrapperForXds.getCurrentServingState()) - .isEqualTo(ServerWrapperForXds.ServingState.ENTER_SERVING); - serverWrapperForXds.shutdown(); - assertThat(serverWrapperForXds.getCurrentServingState()) - .isEqualTo(ServerWrapperForXds.ServingState.SHUTDOWN); - Throwable interruptedException = settableFutureForInterrupt.get(1L, TimeUnit.SECONDS); - assertThat(interruptedException).isInstanceOf(InterruptedException.class); - } -} diff --git a/xds/src/test/java/io/grpc/xds/SharedCallCounterMapTest.java b/xds/src/test/java/io/grpc/xds/SharedCallCounterMapTest.java index 9f2293d3c53..3051a021870 100644 --- a/xds/src/test/java/io/grpc/xds/SharedCallCounterMapTest.java +++ b/xds/src/test/java/io/grpc/xds/SharedCallCounterMapTest.java @@ -62,4 +62,22 @@ public boolean isDone() { map.cleanQueue(); assertThat(counters).isEmpty(); } + + @Test + public void gcAndRecreate() { + @SuppressWarnings("UnusedVariable") // assign to null for GC only + AtomicLong counter = map.getOrCreate(CLUSTER, EDS_SERVICE_NAME); + final CounterReference ref = counters.get(CLUSTER).get(EDS_SERVICE_NAME); + assertThat(counter.get()).isEqualTo(0); + counter = null; + GcFinalization.awaitDone(new FinalizationPredicate() { + @Override + public boolean isDone() { + return ref.isEnqueued(); + } + }); + map.getOrCreate(CLUSTER, EDS_SERVICE_NAME); + assertThat(counters.get(CLUSTER)).isNotNull(); + assertThat(counters.get(CLUSTER).get(EDS_SERVICE_NAME)).isNotNull(); + } } diff --git a/xds/src/test/java/io/grpc/xds/XdsClientWrapperForServerSdsTestMisc.java b/xds/src/test/java/io/grpc/xds/XdsClientWrapperForServerSdsTestMisc.java index a8566a682bb..1871cb79770 100644 --- a/xds/src/test/java/io/grpc/xds/XdsClientWrapperForServerSdsTestMisc.java +++ b/xds/src/test/java/io/grpc/xds/XdsClientWrapperForServerSdsTestMisc.java @@ -17,6 +17,8 @@ package io.grpc.xds; import static com.google.common.truth.Truth.assertThat; +import static io.grpc.xds.FilterChainMatchingProtocolNegotiators.FilterChainMatchingHandler.FilterChainSelector.NO_FILTER_CHAIN; +import static io.grpc.xds.internal.sds.SdsProtocolNegotiators.ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER; import static org.junit.Assert.fail; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.any; @@ -27,50 +29,83 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import com.google.common.util.concurrent.SettableFuture; +import io.grpc.Server; +import io.grpc.ServerBuilder; import io.grpc.Status; -import io.grpc.StatusException; import io.grpc.inprocess.InProcessSocketAddress; +import io.grpc.internal.TestUtils.NoopChannelLogger; +import io.grpc.netty.GrpcHttp2ConnectionHandler; +import io.grpc.netty.InternalProtocolNegotiationEvent; +import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator; +import io.grpc.netty.ProtocolNegotiationEvent; import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; +import io.grpc.xds.FilterChainMatchingProtocolNegotiators.FilterChainMatchingHandler; +import io.grpc.xds.FilterChainMatchingProtocolNegotiators.FilterChainMatchingHandler.FilterChainSelector; import io.grpc.xds.XdsClient.LdsUpdate; +import io.grpc.xds.XdsServerBuilder.XdsServingStatusListener; +import io.grpc.xds.XdsServerTestHelper.FakeXdsClient; +import io.grpc.xds.XdsServerTestHelper.FakeXdsClientPoolFactory; import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil; import io.grpc.xds.internal.sds.SslContextProvider; import io.grpc.xds.internal.sds.SslContextProviderSupplier; -import io.netty.channel.Channel; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.ChannelPromise; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.http2.DefaultHttp2Connection; +import io.netty.handler.codec.http2.DefaultHttp2ConnectionDecoder; +import io.netty.handler.codec.http2.DefaultHttp2ConnectionEncoder; +import io.netty.handler.codec.http2.DefaultHttp2FrameReader; +import io.netty.handler.codec.http2.DefaultHttp2FrameWriter; +import io.netty.handler.codec.http2.Http2ConnectionDecoder; +import io.netty.handler.codec.http2.Http2ConnectionEncoder; +import io.netty.handler.codec.http2.Http2Settings; import java.io.IOException; import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.SocketAddress; -import java.net.UnknownHostException; import java.util.Arrays; import java.util.Collections; -import org.junit.After; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; -import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; -/** Tests for {@link XdsClientWrapperForServerSds}. */ +/** Migration test XdsServerWrapper from previous XdsClientWrapperForServerSds. */ @RunWith(JUnit4.class) public class XdsClientWrapperForServerSdsTestMisc { private static final int PORT = 7000; - @Mock private Channel channel; + private EmbeddedChannel channel; + private ChannelPipeline pipeline; @Mock private TlsContextManager tlsContextManager; - @Mock private XdsClientWrapperForServerSds.ServerWatcher mockServerWatcher; - - private XdsClientWrapperForServerSds xdsClientWrapperForServerSds; - private XdsClient.LdsResourceWatcher registeredWatcher; private InetSocketAddress localAddress; private DownstreamTlsContext tlsContext1; private DownstreamTlsContext tlsContext2; private DownstreamTlsContext tlsContext3; + @Mock + private ServerBuilder mockBuilder; + @Mock + Server mockServer; + @Mock + private XdsServingStatusListener listener; + private FakeXdsClient xdsClient = new FakeXdsClient(); + private FilterChainSelectorManager selectorManager = new FilterChainSelectorManager(); + private XdsServerWrapper xdsServerWrapper; + + @Before - public void setUp() throws IOException { + public void setUp() { MockitoAnnotations.initMocks(this); tlsContext1 = CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); @@ -78,54 +113,52 @@ public void setUp() throws IOException { CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); tlsContext3 = CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT3", "VA3"); - xdsClientWrapperForServerSds = XdsServerTestHelper - .createXdsClientWrapperForServerSds(PORT, tlsContextManager); - } - - @After - public void tearDown() { - xdsClientWrapperForServerSds.shutdown(); + when(mockBuilder.build()).thenReturn(mockServer); + when(mockServer.isShutdown()).thenReturn(false); + xdsServerWrapper = new XdsServerWrapper("0.0.0.0:" + PORT, mockBuilder, listener, + selectorManager, new FakeXdsClientPoolFactory(xdsClient), FilterRegistry.newRegistry()); } @Test - public void nonInetSocketAddress_expectNull() throws UnknownHostException { - registeredWatcher = - XdsServerTestHelper.startAndGetWatcher(xdsClientWrapperForServerSds); - assertThat( - sendListenerUpdate(new InProcessSocketAddress("test1"), null, null, tlsContextManager)) + public void nonInetSocketAddress_expectNull() throws Exception { + sendListenerUpdate(new InProcessSocketAddress("test1"), null, null, tlsContextManager); + assertThat(getSslContextProviderSupplier(selectorManager.getSelectorToUpdateSelector())) .isNull(); } @Test - public void nonMatchingPort_expectException() throws UnknownHostException { - registeredWatcher = - XdsServerTestHelper.startAndGetWatcher(xdsClientWrapperForServerSds); - try { - InetAddress ipLocalAddress = InetAddress.getByName("10.1.2.3"); - InetSocketAddress localAddress = new InetSocketAddress(ipLocalAddress, PORT + 1); - sendListenerUpdate(localAddress, null, null, tlsContextManager); - fail("exception expected"); - } catch (IllegalStateException expected) { - assertThat(expected) - .hasMessageThat() - .isEqualTo("Channel localAddress port does not match requested listener port"); - } - } - - @Test - public void emptyFilterChain_expectNull() throws UnknownHostException { - registeredWatcher = - XdsServerTestHelper.startAndGetWatcher(xdsClientWrapperForServerSds); + public void emptyFilterChain_expectNull() throws Exception { InetAddress ipLocalAddress = InetAddress.getByName("10.1.2.3"); - InetSocketAddress localAddress = new InetSocketAddress(ipLocalAddress, PORT); - ArgumentCaptor listenerWatcherCaptor = ArgumentCaptor - .forClass(null); - XdsClient xdsClient = xdsClientWrapperForServerSds.getXdsClient(); - verify(xdsClient) - .watchLdsResource(eq("grpc/server?udpa.resource.listening_address=0.0.0.0:" + PORT), - listenerWatcherCaptor.capture()); - XdsClient.LdsResourceWatcher registeredWatcher = listenerWatcherCaptor.getValue(); - when(channel.localAddress()).thenReturn(localAddress); + final InetSocketAddress localAddress = new InetSocketAddress(ipLocalAddress, PORT); + InetAddress ipRemoteAddress = InetAddress.getByName("10.4.5.6"); + final InetSocketAddress remoteAddress = new InetSocketAddress(ipRemoteAddress, 1234); + channel = new EmbeddedChannel() { + @Override + public SocketAddress localAddress() { + return localAddress; + } + + @Override + public SocketAddress remoteAddress() { + return remoteAddress; + } + }; + pipeline = channel.pipeline(); + + final SettableFuture start = SettableFuture.create(); + Executors.newSingleThreadExecutor().execute(new Runnable() { + @Override + public void run() { + try { + start.set(xdsServerWrapper.start()); + } catch (Exception ex) { + start.setException(ex); + } + } + }); + String ldsWatched = xdsClient.ldsResource.get(5, TimeUnit.SECONDS); + assertThat(ldsWatched).isEqualTo("grpc/server?udpa.resource.listening_address=0.0.0.0:" + PORT); + EnvoyServerProtoData.Listener listener = new EnvoyServerProtoData.Listener( "listener1", @@ -133,50 +166,112 @@ public void emptyFilterChain_expectNull() throws UnknownHostException { Collections.emptyList(), null); LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(listener); - registeredWatcher.onChanged(listenerUpdate); - DownstreamTlsContext tlsContext = getDownstreamTlsContext(); - assertThat(tlsContext).isNull(); - } - - @Test - public void registerServerWatcher_afterListenerUpdate() throws UnknownHostException { - registerWatcherAndCreateListenerUpdate(tlsContext1); - verify(mockServerWatcher).onListenerUpdate(); + xdsClient.ldsWatcher.onChanged(listenerUpdate); + start.get(5, TimeUnit.SECONDS); + FilterChainSelector selector = selectorManager.getSelectorToUpdateSelector(); + assertThat(getSslContextProviderSupplier(selector)).isNull(); } @Test - public void registerServerWatcher_notifyNotFound() throws UnknownHostException { - commonErrorCheck(true, Status.NOT_FOUND, true); + public void registerServerWatcher_notifyNotFound() throws Exception { + final SettableFuture start = SettableFuture.create(); + Executors.newSingleThreadExecutor().execute(new Runnable() { + @Override + public void run() { + try { + start.set(xdsServerWrapper.start()); + } catch (Exception ex) { + start.setException(ex); + } + } + }); + String ldsWatched = xdsClient.ldsResource.get(5, TimeUnit.SECONDS); + xdsClient.ldsWatcher.onResourceDoesNotExist(ldsWatched); + try { + start.get(5, TimeUnit.SECONDS); + fail("Start should throw exception"); + } catch (ExecutionException ex) { + assertThat(ex.getCause()).isInstanceOf(IOException.class); + } + assertThat(selectorManager.getSelectorToUpdateSelector()).isSameInstanceAs(NO_FILTER_CHAIN); } @Test - public void registerServerWatcher_notifyInternalError() throws UnknownHostException { - commonErrorCheck(false, Status.INTERNAL, false); + public void registerServerWatcher_notifyInternalError() throws Exception { + final SettableFuture start = SettableFuture.create(); + Executors.newSingleThreadExecutor().execute(new Runnable() { + @Override + public void run() { + try { + start.set(xdsServerWrapper.start()); + } catch (Exception ex) { + start.setException(ex); + } + } + }); + xdsClient.ldsResource.get(5, TimeUnit.SECONDS); + xdsClient.ldsWatcher.onError(Status.INTERNAL); + try { + start.get(5, TimeUnit.SECONDS); + fail("Start should throw exception"); + } catch (ExecutionException ex) { + assertThat(ex.getCause()).isInstanceOf(IOException.class); + } + assertThat(selectorManager.getSelectorToUpdateSelector()).isSameInstanceAs(NO_FILTER_CHAIN); } @Test - public void registerServerWatcher_notifyPermDeniedError() throws UnknownHostException { - commonErrorCheck(false, Status.PERMISSION_DENIED, true); + public void registerServerWatcher_notifyPermDeniedError() throws Exception { + final SettableFuture start = SettableFuture.create(); + Executors.newSingleThreadExecutor().execute(new Runnable() { + @Override + public void run() { + try { + start.set(xdsServerWrapper.start()); + } catch (Exception ex) { + start.setException(ex); + } + } + }); + xdsClient.ldsResource.get(5, TimeUnit.SECONDS); + xdsClient.ldsWatcher.onError(Status.PERMISSION_DENIED); + try { + start.get(5, TimeUnit.SECONDS); + fail("Start should throw exception"); + } catch (ExecutionException ex) { + assertThat(ex.getCause()).isInstanceOf(IOException.class); + } + assertThat(selectorManager.getSelectorToUpdateSelector()).isSameInstanceAs(NO_FILTER_CHAIN); } @Test - public void releaseOldSupplierOnChanged_noCloseDueToLazyLoading() throws UnknownHostException { - registerWatcherAndCreateListenerUpdate(tlsContext1); - XdsServerTestHelper.generateListenerUpdate(registeredWatcher, tlsContext2, tlsContextManager); + public void releaseOldSupplierOnChanged_noCloseDueToLazyLoading() throws Exception { + InetAddress ipLocalAddress = InetAddress.getByName("10.1.2.3"); + localAddress = new InetSocketAddress(ipLocalAddress, PORT); + sendListenerUpdate(localAddress, tlsContext2, null, + tlsContextManager); verify(tlsContextManager, never()) .findOrCreateServerSslContextProvider(any(DownstreamTlsContext.class)); } @Test - public void releaseOldSupplierOnChangedOnShutdown_verifyClose() throws UnknownHostException { + public void releaseOldSupplierOnChangedOnShutdown_verifyClose() throws Exception { SslContextProvider sslContextProvider1 = mock(SslContextProvider.class); when(tlsContextManager.findOrCreateServerSslContextProvider(eq(tlsContext1))) .thenReturn(sslContextProvider1); - registerWatcherAndCreateListenerUpdate(tlsContext1); - callUpdateSslContext(channel); + InetAddress ipLocalAddress = InetAddress.getByName("10.1.2.3"); + localAddress = new InetSocketAddress(ipLocalAddress, PORT); + sendListenerUpdate(localAddress, tlsContext1, null, + tlsContextManager); + SslContextProviderSupplier returnedSupplier = + getSslContextProviderSupplier(selectorManager.getSelectorToUpdateSelector()); + assertThat(returnedSupplier.getTlsContext()).isSameInstanceAs(tlsContext1); + callUpdateSslContext(returnedSupplier); XdsServerTestHelper - .generateListenerUpdate(registeredWatcher, Arrays.asList(1234), tlsContext2, + .generateListenerUpdate(xdsClient, Arrays.asList(1234), tlsContext2, tlsContext3, tlsContextManager); + returnedSupplier = getSslContextProviderSupplier(selectorManager.getSelectorToUpdateSelector()); + assertThat(returnedSupplier.getTlsContext()).isSameInstanceAs(tlsContext2); verify(tlsContextManager, times(1)).releaseServerSslContextProvider(eq(sslContextProvider1)); reset(tlsContextManager); SslContextProvider sslContextProvider2 = mock(SslContextProvider.class); @@ -185,129 +280,175 @@ public void releaseOldSupplierOnChangedOnShutdown_verifyClose() throws UnknownHo SslContextProvider sslContextProvider3 = mock(SslContextProvider.class); when(tlsContextManager.findOrCreateServerSslContextProvider(eq(tlsContext3))) .thenReturn(sslContextProvider3); - callUpdateSslContext(channel); + callUpdateSslContext(returnedSupplier); InetAddress ipRemoteAddress = InetAddress.getByName("10.4.5.6"); - InetSocketAddress remoteAddress = new InetSocketAddress(ipRemoteAddress, 1111); - when(channel.remoteAddress()).thenReturn(remoteAddress); - callUpdateSslContext(channel); - XdsClient mockXdsClient = xdsClientWrapperForServerSds.getXdsClient(); - xdsClientWrapperForServerSds.shutdown(); - verify(mockXdsClient, times(1)) - .cancelLdsResourceWatch(eq("grpc/server?udpa.resource.listening_address=0.0.0.0:" + PORT), - eq(registeredWatcher)); + final InetSocketAddress remoteAddress = new InetSocketAddress(ipRemoteAddress, 1111); + channel = new EmbeddedChannel() { + @Override + public SocketAddress localAddress() { + return localAddress; + } + + @Override + public SocketAddress remoteAddress() { + return remoteAddress; + } + }; + pipeline = channel.pipeline(); + returnedSupplier = getSslContextProviderSupplier(selectorManager.getSelectorToUpdateSelector()); + assertThat(returnedSupplier.getTlsContext()).isSameInstanceAs(tlsContext3); + callUpdateSslContext(returnedSupplier); + xdsServerWrapper.shutdown(); + assertThat(xdsClient.ldsResource).isNull(); verify(tlsContextManager, never()).releaseServerSslContextProvider(eq(sslContextProvider1)); verify(tlsContextManager, times(1)).releaseServerSslContextProvider(eq(sslContextProvider2)); verify(tlsContextManager, times(1)).releaseServerSslContextProvider(eq(sslContextProvider3)); } @Test - public void releaseOldSupplierOnNotFound_verifyClose() throws UnknownHostException { + public void releaseOldSupplierOnNotFound_verifyClose() throws Exception { SslContextProvider sslContextProvider1 = mock(SslContextProvider.class); when(tlsContextManager.findOrCreateServerSslContextProvider(eq(tlsContext1))) .thenReturn(sslContextProvider1); - registerWatcherAndCreateListenerUpdate(tlsContext1); - callUpdateSslContext(channel); - registeredWatcher.onResourceDoesNotExist("not-found Error"); + InetAddress ipLocalAddress = InetAddress.getByName("10.1.2.3"); + localAddress = new InetSocketAddress(ipLocalAddress, PORT); + sendListenerUpdate(localAddress, tlsContext1, null, + tlsContextManager); + SslContextProviderSupplier returnedSupplier = + getSslContextProviderSupplier(selectorManager.getSelectorToUpdateSelector()); + assertThat(returnedSupplier.getTlsContext()).isSameInstanceAs(tlsContext1); + callUpdateSslContext(returnedSupplier); + xdsClient.ldsWatcher.onResourceDoesNotExist("not-found Error"); verify(tlsContextManager, times(1)).releaseServerSslContextProvider(eq(sslContextProvider1)); } @Test - public void releaseOldSupplierOnPermDeniedError_verifyClose() throws UnknownHostException { + public void releaseOldSupplierOnPermDeniedError_verifyClose() throws Exception { SslContextProvider sslContextProvider1 = mock(SslContextProvider.class); when(tlsContextManager.findOrCreateServerSslContextProvider(eq(tlsContext1))) .thenReturn(sslContextProvider1); - registerWatcherAndCreateListenerUpdate(tlsContext1); - callUpdateSslContext(channel); - registeredWatcher.onError(Status.PERMISSION_DENIED); + InetAddress ipLocalAddress = InetAddress.getByName("10.1.2.3"); + localAddress = new InetSocketAddress(ipLocalAddress, PORT); + sendListenerUpdate(localAddress, tlsContext1, null, + tlsContextManager); + SslContextProviderSupplier returnedSupplier = + getSslContextProviderSupplier(selectorManager.getSelectorToUpdateSelector()); + assertThat(returnedSupplier.getTlsContext()).isSameInstanceAs(tlsContext1); + callUpdateSslContext(returnedSupplier); + xdsClient.ldsWatcher.onError(Status.PERMISSION_DENIED); verify(tlsContextManager, times(1)).releaseServerSslContextProvider(eq(sslContextProvider1)); } @Test - public void releaseOldSupplierOnInternalError_noClose() throws UnknownHostException { + public void releaseOldSupplierOnTemporaryError_noClose() throws Exception { SslContextProvider sslContextProvider1 = mock(SslContextProvider.class); when(tlsContextManager.findOrCreateServerSslContextProvider(eq(tlsContext1))) .thenReturn(sslContextProvider1); - registerWatcherAndCreateListenerUpdate(tlsContext1); - callUpdateSslContext(channel); - registeredWatcher.onError(Status.INTERNAL); + InetAddress ipLocalAddress = InetAddress.getByName("10.1.2.3"); + localAddress = new InetSocketAddress(ipLocalAddress, PORT); + sendListenerUpdate(localAddress, tlsContext1, null, + tlsContextManager); + SslContextProviderSupplier returnedSupplier = + getSslContextProviderSupplier(selectorManager.getSelectorToUpdateSelector()); + assertThat(returnedSupplier.getTlsContext()).isSameInstanceAs(tlsContext1); + callUpdateSslContext(returnedSupplier); + xdsClient.ldsWatcher.onError(Status.CANCELLED); verify(tlsContextManager, never()).releaseServerSslContextProvider(eq(sslContextProvider1)); } - private void callUpdateSslContext(Channel channel) { - SslContextProviderSupplier sslContextProviderSupplier = - xdsClientWrapperForServerSds.getSslContextProviderSupplier(channel); + private void callUpdateSslContext(SslContextProviderSupplier sslContextProviderSupplier) { assertThat(sslContextProviderSupplier).isNotNull(); SslContextProvider.Callback callback = mock(SslContextProvider.Callback.class); sslContextProviderSupplier.updateSslContext(callback); } - private void registerWatcherAndCreateListenerUpdate(DownstreamTlsContext tlsContext) - throws UnknownHostException { - registeredWatcher = - XdsServerTestHelper.startAndGetWatcher(xdsClientWrapperForServerSds); - InetAddress ipLocalAddress = InetAddress.getByName("10.1.2.3"); - localAddress = new InetSocketAddress(ipLocalAddress, PORT); - xdsClientWrapperForServerSds.addServerWatcher(mockServerWatcher); - DownstreamTlsContext returnedTlsContext = sendListenerUpdate(localAddress, tlsContext, null, - tlsContextManager); - assertThat(returnedTlsContext).isSameInstanceAs(tlsContext); - } + private void sendListenerUpdate( + final SocketAddress localAddress, DownstreamTlsContext tlsContext, + DownstreamTlsContext tlsContextForDefaultFilterChain, TlsContextManager tlsContextManager) + throws Exception { + final SettableFuture start = SettableFuture.create(); + Executors.newSingleThreadExecutor().execute(new Runnable() { + @Override + public void run() { + try { + start.set(xdsServerWrapper.start()); + } catch (Exception ex) { + start.setException(ex); + } + } + }); + xdsClient.ldsResource.get(5, TimeUnit.SECONDS); + XdsServerTestHelper + .generateListenerUpdate(xdsClient, Arrays.asList(), tlsContext, + tlsContextForDefaultFilterChain, tlsContextManager); + start.get(5, TimeUnit.SECONDS); + InetAddress ipRemoteAddress = InetAddress.getByName("10.4.5.6"); + final InetSocketAddress remoteAddress = new InetSocketAddress(ipRemoteAddress, 1234); + channel = new EmbeddedChannel() { + @Override + public SocketAddress localAddress() { + return localAddress; + } - private void commonErrorCheck(boolean generateResourceDoesNotExist, Status status, - boolean isAbsent) throws UnknownHostException { - registerWatcherAndCreateListenerUpdate(tlsContext1); - reset(mockServerWatcher); - if (generateResourceDoesNotExist) { - registeredWatcher.onResourceDoesNotExist("not-found Error"); - } else { - registeredWatcher.onError(status); - } - ArgumentCaptor argCaptor = ArgumentCaptor.forClass(null); - verify(mockServerWatcher).onError(argCaptor.capture(), eq(isAbsent)); - Throwable throwable = argCaptor.getValue(); - assertThat(throwable).isInstanceOf(StatusException.class); - Status captured = ((StatusException) throwable).getStatus(); - assertThat(captured.getCode()).isEqualTo(status.getCode()); - if (isAbsent) { - assertThat(xdsClientWrapperForServerSds.getSslContextProviderSupplier(channel)).isNull(); - } else { - assertThat(xdsClientWrapperForServerSds.getSslContextProviderSupplier(channel)).isNotNull(); - } + @Override + public SocketAddress remoteAddress() { + return remoteAddress; + } + }; + pipeline = channel.pipeline(); } - private DownstreamTlsContext sendListenerUpdate( - SocketAddress localAddress, DownstreamTlsContext tlsContext, - DownstreamTlsContext tlsContextForDefaultFilterChain, TlsContextManager tlsContextManager) - throws UnknownHostException { - when(channel.localAddress()).thenReturn(localAddress); - InetAddress ipRemoteAddress = InetAddress.getByName("10.4.5.6"); - InetSocketAddress remoteAddress = new InetSocketAddress(ipRemoteAddress, 1234); - when(channel.remoteAddress()).thenReturn(remoteAddress); - XdsServerTestHelper - .generateListenerUpdate(registeredWatcher, Arrays.asList(), tlsContext, - tlsContextForDefaultFilterChain, tlsContextManager); - return getDownstreamTlsContext(); + private SslContextProviderSupplier getSslContextProviderSupplier( + FilterChainSelector selector) throws Exception { + final SettableFuture sslSet = SettableFuture.create(); + ChannelHandler next = new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + ProtocolNegotiationEvent e = (ProtocolNegotiationEvent)evt; + sslSet.set(InternalProtocolNegotiationEvent.getAttributes(e) + .get(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER)); + ctx.pipeline().remove(this); + } + }; + ProtocolNegotiator mockDelegate = mock(ProtocolNegotiator.class); + GrpcHttp2ConnectionHandler grpcHandler = FakeGrpcHttp2ConnectionHandler.newHandler(); + when(mockDelegate.newHandler(grpcHandler)).thenReturn(next); + FilterChainSelectorManager manager = new FilterChainSelectorManager(); + manager.updateSelector(selector); + FilterChainMatchingHandler filterChainMatchingHandler = + new FilterChainMatchingHandler(grpcHandler, manager, mockDelegate); + pipeline.addLast(filterChainMatchingHandler); + ProtocolNegotiationEvent event = InternalProtocolNegotiationEvent.getDefault(); + pipeline.fireUserEventTriggered(event); + channel.runPendingTasks(); + sslSet.set(InternalProtocolNegotiationEvent.getAttributes(event) + .get(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER)); + return sslSet.get(); } - private DownstreamTlsContext getDownstreamTlsContext() { - SslContextProviderSupplier sslContextProviderSupplier = - xdsClientWrapperForServerSds.getSslContextProviderSupplier(channel); - if (sslContextProviderSupplier != null) { - EnvoyServerProtoData.BaseTlsContext tlsContext = sslContextProviderSupplier.getTlsContext(); - assertThat(tlsContext).isInstanceOf(DownstreamTlsContext.class); - return (DownstreamTlsContext)tlsContext; + private static final class FakeGrpcHttp2ConnectionHandler extends GrpcHttp2ConnectionHandler { + FakeGrpcHttp2ConnectionHandler( + ChannelPromise channelUnused, + Http2ConnectionDecoder decoder, + Http2ConnectionEncoder encoder, + Http2Settings initialSettings) { + super(channelUnused, decoder, encoder, initialSettings, new NoopChannelLogger()); } - return null; - } - /** Creates XdsClientWrapperForServerSds: also used by other classes. */ - public static XdsClientWrapperForServerSds createXdsClientWrapperForServerSds( - int port, DownstreamTlsContext downstreamTlsContext, TlsContextManager tlsContextManager) { - XdsClientWrapperForServerSds xdsClientWrapperForServerSds = - XdsServerTestHelper.createXdsClientWrapperForServerSds(port, tlsContextManager); - xdsClientWrapperForServerSds.start(); - XdsSdsClientServerTest.generateListenerUpdateToWatcher( - downstreamTlsContext, xdsClientWrapperForServerSds.getListenerWatcher(), tlsContextManager); - return xdsClientWrapperForServerSds; + static FakeGrpcHttp2ConnectionHandler newHandler() { + DefaultHttp2Connection conn = new DefaultHttp2Connection(/*server=*/ false); + DefaultHttp2ConnectionEncoder encoder = + new DefaultHttp2ConnectionEncoder(conn, new DefaultHttp2FrameWriter()); + DefaultHttp2ConnectionDecoder decoder = + new DefaultHttp2ConnectionDecoder(conn, encoder, new DefaultHttp2FrameReader()); + Http2Settings settings = new Http2Settings(); + return new FakeGrpcHttp2ConnectionHandler( + /*channelUnused=*/ null, decoder, encoder, settings); + } + + @Override + public String getAuthority() { + return "authority"; + } } } diff --git a/xds/src/test/java/io/grpc/xds/XdsNameResolverProviderTest.java b/xds/src/test/java/io/grpc/xds/XdsNameResolverProviderTest.java index ba1e561410f..32850b441d7 100644 --- a/xds/src/test/java/io/grpc/xds/XdsNameResolverProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsNameResolverProviderTest.java @@ -20,15 +20,20 @@ import static org.junit.Assert.fail; import static org.mockito.Mockito.mock; +import com.google.common.collect.ImmutableMap; import io.grpc.ChannelLogger; import io.grpc.InternalServiceProviders; import io.grpc.NameResolver; import io.grpc.NameResolver.ServiceConfigParser; import io.grpc.NameResolverProvider; +import io.grpc.NameResolverRegistry; import io.grpc.SynchronizationContext; import io.grpc.internal.FakeClock; import io.grpc.internal.GrpcUtil; import java.net.URI; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -114,4 +119,54 @@ public void invalidName_hostnameContainsUnderscore() { // Expected } } + + @Test + public void newProvider_multipleScheme() { + NameResolverRegistry registry = NameResolverRegistry.getDefaultRegistry(); + XdsNameResolverProvider provider0 = XdsNameResolverProvider.createForTest("no-scheme", null); + registry.register(provider0); + XdsNameResolverProvider provider1 = XdsNameResolverProvider.createForTest("new-xds-scheme", + new HashMap()); + registry.register(provider1); + assertThat(registry.asFactory() + .newNameResolver(URI.create("xds:///localhost"), args)).isNotNull(); + assertThat(registry.asFactory() + .newNameResolver(URI.create("new-xds-scheme:///localhost"), args)).isNotNull(); + assertThat(registry.asFactory() + .newNameResolver(URI.create("no-scheme:///localhost"), args)).isNotNull(); + registry.deregister(provider1); + assertThat(registry.asFactory() + .newNameResolver(URI.create("new-xds-scheme:///localhost"), args)).isNull(); + registry.deregister(provider0); + assertThat(registry.asFactory() + .newNameResolver(URI.create("xds:///localhost"), args)).isNotNull(); + } + + @Test + public void newProvider_overrideBootstrap() { + Map b = ImmutableMap.of( + "node", ImmutableMap.of( + "id", "ENVOY_NODE_ID", + "cluster", "ENVOY_CLUSTER"), + "xds_servers", Collections.singletonList( + ImmutableMap.of( + "server_uri", "trafficdirector.googleapis.com:443", + "channel_creds", Collections.singletonList( + ImmutableMap.of("type", "insecure") + ) + ) + ) + ); + NameResolverRegistry registry = new NameResolverRegistry(); + XdsNameResolverProvider provider = XdsNameResolverProvider.createForTest("no-scheme", b); + registry.register(provider); + NameResolver resolver = registry.asFactory() + .newNameResolver(URI.create("no-scheme:///localhost"), args); + resolver.start(mock(NameResolver.Listener2.class)); + assertThat(resolver).isInstanceOf(XdsNameResolver.class); + assertThat(((XdsNameResolver)resolver).getXdsClient().getBootstrapInfo().getNode().getId()) + .isEqualTo("ENVOY_NODE_ID"); + resolver.shutdown(); + registry.deregister(provider); + } } diff --git a/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java b/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java index cb2b4481fad..babaa2b3034 100644 --- a/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java @@ -153,7 +153,7 @@ public void setUp() { new FaultFilter(mockRandom, new AtomicLong()), RouterFilter.INSTANCE); resolver = new XdsNameResolver(AUTHORITY, serviceConfigParser, syncContext, scheduler, - xdsClientPoolFactory, mockRandom, filterRegistry); + xdsClientPoolFactory, mockRandom, filterRegistry, null); } @After @@ -172,7 +172,6 @@ public void resolving_failToCreateXdsClientPool() { XdsClientPoolFactory xdsClientPoolFactory = new XdsClientPoolFactory() { @Override public void setBootstrapOverride(Map bootstrap) { - throw new UnsupportedOperationException("Should not be called"); } @Override @@ -187,7 +186,7 @@ public ObjectPool getOrCreate() throws XdsInitializationException { } }; resolver = new XdsNameResolver(AUTHORITY, serviceConfigParser, syncContext, scheduler, - xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry()); + xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), null); resolver.start(mockListener); verify(mockListener).onError(errorCaptor.capture()); Status error = errorCaptor.getValue(); @@ -437,7 +436,7 @@ public void retryPolicyInPerMethodConfigGeneratedByResolverIsValid() { ServiceConfigParser realParser = new ScParser( true, 5, 5, new AutoConfiguredLoadBalancerFactory("pick-first")); resolver = new XdsNameResolver(AUTHORITY, realParser, syncContext, scheduler, - xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry()); + xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), null); resolver.start(mockListener); FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); RetryPolicy retryPolicy = RetryPolicy.create( @@ -640,7 +639,7 @@ public void resolved_rpcHashingByChannelId() { resolver.shutdown(); reset(mockListener); resolver = new XdsNameResolver(AUTHORITY, serviceConfigParser, syncContext, scheduler, - xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry()); + xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), null); resolver.start(mockListener); xdsClient = (FakeXdsClient) resolver.getXdsClient(); xdsClient.deliverLdsUpdate( @@ -989,6 +988,8 @@ public void generateServiceConfig_forPerMethodConfig() throws IOException { RetryPolicy retryPolicy = RetryPolicy.create( 4, ImmutableList.of(Code.UNAVAILABLE, Code.CANCELLED), Durations.fromMillis(100), Durations.fromMillis(200), null); + RetryPolicy retryPolicyWithEmptyStatusCodes = RetryPolicy.create( + 4, ImmutableList.of(), Durations.fromMillis(100), Durations.fromMillis(200), null); // timeout only String expectedServiceConfigJson = "{\n" @@ -1002,6 +1003,11 @@ public void generateServiceConfig_forPerMethodConfig() throws IOException { assertThat(XdsNameResolver.generateServiceConfigWithMethodConfig(timeoutNano, null)) .isEqualTo(expectedServiceConfig); + // timeout and retry with empty retriable status codes + assertThat(XdsNameResolver.generateServiceConfigWithMethodConfig( + timeoutNano, retryPolicyWithEmptyStatusCodes)) + .isEqualTo(expectedServiceConfig); + // retry only expectedServiceConfigJson = "{\n" + " \"methodConfig\": [{\n" @@ -1022,6 +1028,7 @@ public void generateServiceConfig_forPerMethodConfig() throws IOException { assertThat(XdsNameResolver.generateServiceConfigWithMethodConfig(null, retryPolicy)) .isEqualTo(expectedServiceConfig); + // timeout and retry expectedServiceConfigJson = "{\n" + " \"methodConfig\": [{\n" @@ -1044,12 +1051,16 @@ public void generateServiceConfig_forPerMethodConfig() throws IOException { .isEqualTo(expectedServiceConfig); // no timeout and no retry - // timeout and retry expectedServiceConfigJson = "{}"; expectedServiceConfig = (Map) JsonParser.parse(expectedServiceConfigJson); assertThat(XdsNameResolver.generateServiceConfigWithMethodConfig(null, null)) .isEqualTo(expectedServiceConfig); + + // retry with emtry retriable status codes only + assertThat(XdsNameResolver.generateServiceConfigWithMethodConfig( + null, retryPolicyWithEmptyStatusCodes)) + .isEqualTo(expectedServiceConfig); } @Test @@ -1385,20 +1396,6 @@ public long nanoTime() { + " Deadline exceeded after 0.000004000s. ")); } - @Test - public void resolved_withNoRouterFilter() { - resolver.start(mockListener); - FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); - xdsClient.deliverLdsUpdateWithNoRouterFilter(); - verify(mockListener).onResult(resolutionResultCaptor.capture()); - ResolutionResult result = resolutionResultCaptor.getValue(); - InternalConfigSelector configSelector = result.getAttributes().get(InternalConfigSelector.KEY); - ClientCall.Listener observer = startNewCall( - TestMethodDescriptors.voidMethod(), configSelector, Collections.emptyMap(), - CallOptions.DEFAULT); - verifyRpcFailed(observer, Status.UNAVAILABLE.withDescription("No router filter")); - } - @Test public void resolved_faultAbortAndDelayInLdsUpdateInLdsUpdate() { resolver.start(mockListener); @@ -1701,10 +1698,8 @@ public void routeMatching_withHeaders() { } private final class FakeXdsClientPoolFactory implements XdsClientPoolFactory { - @Override public void setBootstrapOverride(Map bootstrap) { - throw new UnsupportedOperationException("Should not be called"); } @Override @@ -1829,16 +1824,6 @@ void deliverLdsUpdateWithFaultInjection( 0L, Collections.singletonList(virtualHost), filterChain))); } - void deliverLdsUpdateWithNoRouterFilter() { - VirtualHost virtualHost = VirtualHost.create( - "virtual-host", - Collections.singletonList(AUTHORITY), - Collections.emptyList(), - Collections.emptyMap()); - ldsWatcher.onChanged(LdsUpdate.forApiListener(HttpConnectionManager.forVirtualHosts( - 0L, Collections.singletonList(virtualHost), ImmutableList.of()))); - } - void deliverLdsUpdateForRdsNameWithFaultInjection( final String rdsName, @Nullable FaultConfig httpFilterFaultConfig) { if (httpFilterFaultConfig == null) { diff --git a/xds/src/test/java/io/grpc/xds/XdsSdsClientServerTest.java b/xds/src/test/java/io/grpc/xds/XdsSdsClientServerTest.java index c71841b2678..579542a2777 100644 --- a/xds/src/test/java/io/grpc/xds/XdsSdsClientServerTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsSdsClientServerTest.java @@ -29,6 +29,8 @@ import static org.junit.Assert.fail; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.util.concurrent.SettableFuture; import io.grpc.Attributes; import io.grpc.EquivalentAddressGroup; import io.grpc.Grpc; @@ -38,6 +40,7 @@ import io.grpc.NameResolver; import io.grpc.NameResolverProvider; import io.grpc.NameResolverRegistry; +import io.grpc.Server; import io.grpc.ServerCredentials; import io.grpc.Status; import io.grpc.StatusRuntimeException; @@ -48,13 +51,19 @@ import io.grpc.testing.protobuf.SimpleServiceGrpc; import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; +import io.grpc.xds.Filter.FilterConfig; import io.grpc.xds.Filter.NamedFilterConfig; +import io.grpc.xds.VirtualHost.Route; +import io.grpc.xds.VirtualHost.Route.RouteMatch; +import io.grpc.xds.VirtualHost.Route.RouteMatch.PathMatcher; import io.grpc.xds.XdsClient.LdsUpdate; +import io.grpc.xds.XdsServerTestHelper.FakeXdsClient; +import io.grpc.xds.XdsServerTestHelper.FakeXdsClientPoolFactory; +import io.grpc.xds.internal.Matchers.HeaderMatcher; import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil; import io.grpc.xds.internal.sds.SslContextProviderSupplier; import io.grpc.xds.internal.sds.TlsContextManagerImpl; import io.netty.handler.ssl.NotSslRecordException; -import java.io.IOException; import java.net.Inet4Address; import java.net.InetSocketAddress; import java.net.URI; @@ -63,10 +72,11 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; import javax.net.ssl.SSLException; import javax.net.ssl.SSLHandshakeException; import org.junit.After; -import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; @@ -86,14 +96,9 @@ public class XdsSdsClientServerTest { private Bootstrapper.BootstrapInfo bootstrapInfoForServer = null; private TlsContextManagerImpl tlsContextManagerForClient; private TlsContextManagerImpl tlsContextManagerForServer; - - @Before - public void setUp() throws Exception { - port = XdsServerTestHelper.findFreePort(); - URI expectedUri = new URI("sdstest://localhost:" + port); - fakeNameResolverFactory = new FakeNameResolverFactory.Builder(expectedUri).build(); - NameResolverRegistry.getDefaultRegistry().register(fakeNameResolverFactory); - } + private FakeXdsClient xdsClient = new FakeXdsClient(); + private FakeXdsClientPoolFactory fakePoolFactory = new FakeXdsClientPoolFactory(xdsClient); + private static final String OVERRIDE_AUTHORITY = "foo.test.google.fr"; @After public void tearDown() { @@ -103,16 +108,17 @@ public void tearDown() { } @Test - public void plaintextClientServer() throws IOException, URISyntaxException { + public void plaintextClientServer() throws Exception { buildServerWithTlsContext(/* downstreamTlsContext= */ null); SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = - getBlockingStub(/* upstreamTlsContext= */ null, /* overrideAuthority= */ null); + getBlockingStub(/* upstreamTlsContext= */ null, + /* overrideAuthority= */ OVERRIDE_AUTHORITY); assertThat(unaryRpc("buddy", blockingStub)).isEqualTo("Hello buddy"); } @Test - public void nullFallbackCredentials_expectException() throws IOException, URISyntaxException { + public void nullFallbackCredentials_expectException() throws Exception { try { buildServerWithTlsContext(/* downstreamTlsContext= */ null, /* fallbackCredentials= */ null); fail("exception expected"); @@ -123,7 +129,7 @@ public void nullFallbackCredentials_expectException() throws IOException, URISyn /** TLS channel - no mTLS. */ @Test - public void tlsClientServer_noClientAuthentication() throws IOException, URISyntaxException { + public void tlsClientServer_noClientAuthentication() throws Exception { DownstreamTlsContext downstreamTlsContext = setBootstrapInfoAndBuildDownstreamTlsContext(null, null, null, null, false, false); buildServerWithTlsContext(downstreamTlsContext); @@ -134,13 +140,13 @@ public void tlsClientServer_noClientAuthentication() throws IOException, URISynt CLIENT_PEM_FILE, false); SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = - getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ "foo.test.google.fr"); + getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); assertThat(unaryRpc(/* requestMessage= */ "buddy", blockingStub)).isEqualTo("Hello buddy"); } @Test public void requireClientAuth_noClientCert_expectException() - throws IOException, URISyntaxException { + throws Exception { DownstreamTlsContext downstreamTlsContext = setBootstrapInfoAndBuildDownstreamTlsContext(null, null, null, null, true, true); buildServerWithTlsContext(downstreamTlsContext); @@ -151,7 +157,7 @@ public void requireClientAuth_noClientCert_expectException() CLIENT_PEM_FILE, false); SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = - getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ "foo.test.google.fr"); + getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); try { unaryRpc(/* requestMessage= */ "buddy", blockingStub); fail("exception expected"); @@ -168,7 +174,7 @@ public void requireClientAuth_noClientCert_expectException() } @Test - public void noClientAuth_sendBadClientCert_passes() throws IOException, URISyntaxException { + public void noClientAuth_sendBadClientCert_passes() throws Exception { DownstreamTlsContext downstreamTlsContext = setBootstrapInfoAndBuildDownstreamTlsContext(null, null, null, null, false, false); buildServerWithTlsContext(downstreamTlsContext); @@ -178,17 +184,17 @@ public void noClientAuth_sendBadClientCert_passes() throws IOException, URISynta BAD_CLIENT_PEM_FILE, true); SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = - getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ "foo.test.google.fr"); + getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); assertThat(unaryRpc("buddy", blockingStub)).isEqualTo("Hello buddy"); } @Test - public void mtls_badClientCert_expectException() throws IOException, URISyntaxException { + public void mtls_badClientCert_expectException() throws Exception { UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContext( BAD_CLIENT_KEY_FILE, BAD_CLIENT_PEM_FILE, true); try { - performMtlsTestAndGetListenerWatcher(upstreamTlsContext, false, null, null, null, null); + performMtlsTestAndGetListenerWatcher(upstreamTlsContext, null, null, null, null); fail("exception expected"); } catch (StatusRuntimeException sre) { if (sre.getCause() instanceof SSLHandshakeException) { @@ -202,27 +208,18 @@ public void mtls_badClientCert_expectException() throws IOException, URISyntaxEx } } - /** mTLS - client auth enabled. */ - @Test - public void mtlsClientServer_withClientAuthentication() throws IOException, URISyntaxException { - UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContext( - CLIENT_KEY_FILE, - CLIENT_PEM_FILE, true); - performMtlsTestAndGetListenerWatcher(upstreamTlsContext, false, null, null, null, null); - } - /** mTLS - client auth enabled - using {@link XdsChannelCredentials} API. */ @Test public void mtlsClientServer_withClientAuthentication_withXdsChannelCreds() - throws IOException, URISyntaxException { + throws Exception { UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContext( CLIENT_KEY_FILE, CLIENT_PEM_FILE, true); - performMtlsTestAndGetListenerWatcher(upstreamTlsContext, true, null, null, null, null); + performMtlsTestAndGetListenerWatcher(upstreamTlsContext, null, null, null, null); } @Test - public void tlsServer_plaintextClient_expectException() throws IOException, URISyntaxException { + public void tlsServer_plaintextClient_expectException() throws Exception { DownstreamTlsContext downstreamTlsContext = setBootstrapInfoAndBuildDownstreamTlsContext(null, null, null, null, false, false); buildServerWithTlsContext(downstreamTlsContext); @@ -239,7 +236,7 @@ public void tlsServer_plaintextClient_expectException() throws IOException, URIS } @Test - public void plaintextServer_tlsClient_expectException() throws IOException, URISyntaxException { + public void plaintextServer_tlsClient_expectException() throws Exception { buildServerWithTlsContext(/* downstreamTlsContext= */ null); // for TLS, client only needs trustCa @@ -248,7 +245,7 @@ public void plaintextServer_tlsClient_expectException() throws IOException, URIS CLIENT_PEM_FILE, false); SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = - getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ "foo.test.google.fr"); + getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); try { unaryRpc("buddy", blockingStub); fail("exception expected"); @@ -261,22 +258,23 @@ public void plaintextServer_tlsClient_expectException() throws IOException, URIS /** mTLS - client auth enabled then update server certs to untrusted. */ @Test public void mtlsClientServer_changeServerContext_expectException() - throws IOException, URISyntaxException { + throws Exception { UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContext( CLIENT_KEY_FILE, CLIENT_PEM_FILE, true); - XdsClient.LdsResourceWatcher listenerWatcher = - performMtlsTestAndGetListenerWatcher(upstreamTlsContext, false, "cert-instance-name2", + performMtlsTestAndGetListenerWatcher(upstreamTlsContext, "cert-instance-name2", BAD_SERVER_KEY_FILE, BAD_SERVER_PEM_FILE, CA_PEM_FILE); DownstreamTlsContext downstreamTlsContext = CommonTlsContextTestsUtil.buildDownstreamTlsContext( "cert-instance-name2", true, true); - generateListenerUpdateToWatcher(downstreamTlsContext, listenerWatcher, - tlsContextManagerForServer); + EnvoyServerProtoData.Listener listener = buildListener("listener1", "0.0.0.0", + downstreamTlsContext, + tlsContextManagerForServer); + xdsClient.deliverLdsUpdate(LdsUpdate.forTcpListener(listener)); try { SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = - getBlockingStub(upstreamTlsContext, "foo.test.google.fr"); + getBlockingStub(upstreamTlsContext, OVERRIDE_AUTHORITY); assertThat(unaryRpc("buddy", blockingStub)).isEqualTo("Hello buddy"); fail("exception expected"); } catch (StatusRuntimeException sre) { @@ -285,27 +283,20 @@ public void mtlsClientServer_changeServerContext_expectException() } } - private XdsClient.LdsResourceWatcher performMtlsTestAndGetListenerWatcher( - UpstreamTlsContext upstreamTlsContext, boolean newApi, String certInstanceName2, + private void performMtlsTestAndGetListenerWatcher( + UpstreamTlsContext upstreamTlsContext, String certInstanceName2, String privateKey2, String cert2, String trustCa2) - throws IOException, URISyntaxException { + throws Exception { DownstreamTlsContext downstreamTlsContext = setBootstrapInfoAndBuildDownstreamTlsContext(certInstanceName2, privateKey2, cert2, trustCa2, true, true); - final XdsClientWrapperForServerSds xdsClientWrapperForServerSds = - createXdsClientWrapperForServerSds(port); buildServerWithFallbackServerCredentials( - xdsClientWrapperForServerSds, InsecureServerCredentials.create(), downstreamTlsContext); + InsecureServerCredentials.create(), downstreamTlsContext); - XdsClient.LdsResourceWatcher listenerWatcher = xdsClientWrapperForServerSds - .getListenerWatcher(); - - SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = newApi - ? getBlockingStub(upstreamTlsContext, "foo.test.google.fr") : - getBlockingStub(upstreamTlsContext, "foo.test.google.fr"); + SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = + getBlockingStub(upstreamTlsContext, OVERRIDE_AUTHORITY); assertThat(unaryRpc("buddy", blockingStub)).isEqualTo("Hello buddy"); - return listenerWatcher; } private DownstreamTlsContext setBootstrapInfoAndBuildDownstreamTlsContext( @@ -330,60 +321,43 @@ private UpstreamTlsContext setBootstrapInfoAndBuildUpstreamTlsContext(String cli } private void buildServerWithTlsContext(DownstreamTlsContext downstreamTlsContext) - throws IOException { + throws Exception { buildServerWithTlsContext(downstreamTlsContext, InsecureServerCredentials.create()); } private void buildServerWithTlsContext( DownstreamTlsContext downstreamTlsContext, ServerCredentials fallbackCredentials) - throws IOException { - XdsClientWrapperForServerSds xdsClientWrapperForServerSds = - createXdsClientWrapperForServerSds(port); - xdsClientWrapperForServerSds.start(); - buildServerWithFallbackServerCredentials( - xdsClientWrapperForServerSds, fallbackCredentials, downstreamTlsContext); + throws Exception { + buildServerWithFallbackServerCredentials(fallbackCredentials, downstreamTlsContext); } private void buildServerWithFallbackServerCredentials( - XdsClientWrapperForServerSds xdsClientWrapperForServerSds, ServerCredentials fallbackCredentials, DownstreamTlsContext downstreamTlsContext) - throws IOException { + throws Exception { ServerCredentials xdsCredentials = XdsServerCredentials.create(fallbackCredentials); - buildServer(port, xdsCredentials, xdsClientWrapperForServerSds, downstreamTlsContext); - } - - /** Creates XdsClientWrapperForServerSds. */ - private XdsClientWrapperForServerSds createXdsClientWrapperForServerSds(int port) { - tlsContextManagerForServer = new TlsContextManagerImpl(bootstrapInfoForServer); - XdsClientWrapperForServerSds xdsClientWrapperForServerSds = - XdsServerTestHelper.createXdsClientWrapperForServerSds(port, tlsContextManagerForServer); - xdsClientWrapperForServerSds.start(); - return xdsClientWrapperForServerSds; - } - - static void generateListenerUpdateToWatcher( - DownstreamTlsContext tlsContext, XdsClient.LdsResourceWatcher registeredWatcher, - TlsContextManager tlsContextManager) { - EnvoyServerProtoData.Listener listener = buildListener("listener1", "0.0.0.0", tlsContext, - tlsContextManager); - LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(listener); - registeredWatcher.onChanged(listenerUpdate); + XdsServerBuilder builder = XdsServerBuilder.forPort(0, xdsCredentials) + .xdsClientPoolFactory(fakePoolFactory) + .addService(new SimpleServiceImpl()); + buildServer(builder, downstreamTlsContext); } private void buildServer( - int port, - ServerCredentials serverCredentials, - XdsClientWrapperForServerSds xdsClientWrapperForServerSds, + XdsServerBuilder builder, DownstreamTlsContext downstreamTlsContext) - throws IOException { - XdsServerBuilder builder = XdsServerBuilder.forPort(port, serverCredentials) - .addService(new SimpleServiceImpl()); + throws Exception { tlsContextManagerForServer = new TlsContextManagerImpl(bootstrapInfoForServer); - XdsServerTestHelper.generateListenerUpdate( - xdsClientWrapperForServerSds.getListenerWatcher(), downstreamTlsContext, - tlsContextManagerForServer); - cleanupRule.register(builder.buildServer(xdsClientWrapperForServerSds)).start(); + XdsServerWrapper xdsServer = (XdsServerWrapper) builder.build(); + SettableFuture startFuture = startServerAsync(xdsServer); + EnvoyServerProtoData.Listener listener = buildListener("listener1", "10.1.2.3", + downstreamTlsContext, tlsContextManagerForServer); + LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(listener); + xdsClient.deliverLdsUpdate(listenerUpdate); + startFuture.get(10, TimeUnit.SECONDS); + port = xdsServer.getPort(); + URI expectedUri = new URI("sdstest://localhost:" + port); + fakeNameResolverFactory = new FakeNameResolverFactory.Builder(expectedUri).build(); + NameResolverRegistry.getDefaultRegistry().register(fakeNameResolverFactory); } static EnvoyServerProtoData.Listener buildListener( @@ -399,9 +373,19 @@ static EnvoyServerProtoData.Listener buildListener( Arrays.asList(), Arrays.asList(), null); - // HttpConnectionManager currently not used for server side. - HttpConnectionManager httpConnectionManager = HttpConnectionManager.forRdsName( - 0L, "does not matter", Collections.emptyList()); + String fullPath = "/" + SimpleServiceGrpc.SERVICE_NAME + "/" + "UnaryRpc"; + RouteMatch routeMatch = + RouteMatch.create( + PathMatcher.fromPath(fullPath, true), + Collections.emptyList(), null); + VirtualHost virtualHost = VirtualHost.create( + "virtual-host", Collections.singletonList(OVERRIDE_AUTHORITY), + Arrays.asList(Route.forAction(routeMatch, null, + ImmutableMap.of())), + ImmutableMap.of()); + HttpConnectionManager httpConnectionManager = HttpConnectionManager.forVirtualHosts( + 0L, Collections.singletonList(virtualHost), + new ArrayList()); EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( "filter-chain-foo", filterChainMatch, httpConnectionManager, tlsContext, tlsContextManager); @@ -445,6 +429,24 @@ private static String unaryRpc( return response.getResponseMessage(); } + private SettableFuture startServerAsync(final Server xdsServer) throws Exception { + cleanupRule.register(xdsServer); + final SettableFuture settableFuture = SettableFuture.create(); + Executors.newSingleThreadExecutor().execute(new Runnable() { + @Override + public void run() { + try { + xdsServer.start(); + settableFuture.set(null); + } catch (Throwable e) { + settableFuture.set(e); + } + } + }); + xdsClient.ldsResource.get(8000, TimeUnit.MILLISECONDS); + return settableFuture; + } + private static class SimpleServiceImpl extends SimpleServiceGrpc.SimpleServiceImplBase { @Override diff --git a/xds/src/test/java/io/grpc/xds/XdsServerBuilderTest.java b/xds/src/test/java/io/grpc/xds/XdsServerBuilderTest.java index 0b174a4a313..0d15c1f660e 100644 --- a/xds/src/test/java/io/grpc/xds/XdsServerBuilderTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsServerBuilderTest.java @@ -33,8 +33,9 @@ import io.grpc.Status; import io.grpc.StatusException; import io.grpc.testing.GrpcCleanupRule; +import io.grpc.xds.XdsServerTestHelper.FakeXdsClient; +import io.grpc.xds.XdsServerTestHelper.FakeXdsClientPoolFactory; import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil; -import io.grpc.xds.internal.sds.ServerWrapperForXds; import java.io.IOException; import java.net.InetSocketAddress; import java.net.ServerSocket; @@ -51,6 +52,7 @@ import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; +// TODO (zivy@): move certain tests down to XdsServerWrapperTest, or up to XdsSdsClientServerTest. /** * Unit tests for {@link XdsServerBuilder}. */ @@ -59,31 +61,27 @@ public class XdsServerBuilderTest { @Rule public final GrpcCleanupRule cleanupRule = new GrpcCleanupRule(); private XdsServerBuilder builder; - private ServerWrapperForXds xdsServer; - private XdsClient.LdsResourceWatcher listenerWatcher; + private XdsServerWrapper xdsServer; private int port; - private XdsClientWrapperForServerSds xdsClientWrapperForServerSds; private TlsContextManager tlsContextManager; + private FakeXdsClient xdsClient = new FakeXdsClient(); private void buildServer(XdsServerBuilder.XdsServingStatusListener xdsServingStatusListener) throws IOException { buildBuilder(xdsServingStatusListener); - xdsServer = cleanupRule.register(builder.buildServer(xdsClientWrapperForServerSds)); + xdsServer = cleanupRule.register((XdsServerWrapper) builder.build()); } private void buildBuilder(XdsServerBuilder.XdsServingStatusListener xdsServingStatusListener) throws IOException { - port = XdsServerTestHelper.findFreePort(); builder = XdsServerBuilder.forPort( port, XdsServerCredentials.create(InsecureServerCredentials.create())); + builder.xdsClientPoolFactory(new FakeXdsClientPoolFactory(xdsClient)); if (xdsServingStatusListener != null) { - builder = builder.xdsServingStatusListener(xdsServingStatusListener); + builder.xdsServingStatusListener(xdsServingStatusListener); } tlsContextManager = mock(TlsContextManager.class); - xdsClientWrapperForServerSds = XdsServerTestHelper - .createXdsClientWrapperForServerSds(port, tlsContextManager); - listenerWatcher = XdsServerTestHelper.startAndGetWatcher(xdsClientWrapperForServerSds); } private void verifyServer( @@ -99,7 +97,7 @@ private void verifyServer( assertThat(list).hasSize(1); InetSocketAddress socketAddress = (InetSocketAddress) list.get(0); assertThat(socketAddress.getAddress().isAnyLocalAddress()).isTrue(); - assertThat(socketAddress.getPort()).isEqualTo(port); + assertThat(socketAddress.getPort()).isGreaterThan(-1); if (mockXdsServingStatusListener != null) { if (notServingStatus != null) { ArgumentCaptor argCaptor = ArgumentCaptor.forClass(null); @@ -117,10 +115,11 @@ private void verifyServer( private void verifyShutdown() throws InterruptedException { xdsServer.shutdown(); xdsServer.awaitTermination(500L, TimeUnit.MILLISECONDS); - assertThat(xdsClientWrapperForServerSds.getXdsClient()).isNull(); + assertThat(xdsClient.isShutDown()).isTrue(); } - private Future startServerAsync() throws InterruptedException { + private Future startServerAsync() throws + InterruptedException, TimeoutException, ExecutionException { final SettableFuture settableFuture = SettableFuture.create(); Executors.newSingleThreadExecutor().execute(new Runnable() { @Override @@ -133,15 +132,7 @@ public void run() { } } }); - // wait until xdsClientWrapperForServerSds.serverWatchers populated - for (int i = 0; i < 10; i++) { - synchronized (xdsClientWrapperForServerSds.serverWatchers) { - if (!xdsClientWrapperForServerSds.serverWatchers.isEmpty()) { - break; - } - } - Thread.sleep(100L); - } + xdsClient.ldsResource.get(5000, TimeUnit.MILLISECONDS); return settableFuture; } @@ -151,29 +142,29 @@ public void xdsServerStartAndShutdown() buildServer(null); Future future = startServerAsync(); XdsServerTestHelper.generateListenerUpdate( - listenerWatcher, + xdsClient, CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"), - tlsContextManager); + tlsContextManager); verifyServer(future, null, null); verifyShutdown(); } @Test - public void xdsServerStartAfterListenerUpdate() + public void xdsServerRestartAfterListenerUpdate() throws IOException, InterruptedException, TimeoutException, ExecutionException { buildServer(null); + Future future = startServerAsync(); XdsServerTestHelper.generateListenerUpdate( - listenerWatcher, + xdsClient, CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"), tlsContextManager); - xdsServer.start(); try { xdsServer.start(); fail("expected exception"); } catch (IllegalStateException expected) { assertThat(expected).hasMessageThat().contains("Already started"); } - verifyServer(null,null, null); + verifyServer(future,null, null); } @Test @@ -184,51 +175,37 @@ public void xdsServerStartAndShutdownWithXdsServingStatusListener() buildServer(mockXdsServingStatusListener); Future future = startServerAsync(); XdsServerTestHelper.generateListenerUpdate( - listenerWatcher, + xdsClient, CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"), tlsContextManager); verifyServer(future, mockXdsServingStatusListener, null); } @Test - public void xdsServer_serverWatcher() - throws IOException, InterruptedException, TimeoutException, ExecutionException { + public void xdsServer_discoverState() throws Exception { XdsServerBuilder.XdsServingStatusListener mockXdsServingStatusListener = mock(XdsServerBuilder.XdsServingStatusListener.class); buildServer(mockXdsServingStatusListener); Future future = startServerAsync(); - listenerWatcher.onError(Status.ABORTED); - ArgumentCaptor argCaptor = ArgumentCaptor.forClass(null); - verify(mockXdsServingStatusListener).onNotServing(argCaptor.capture()); - Throwable throwable = argCaptor.getValue(); - assertThat(throwable).isInstanceOf(StatusException.class); - Status captured = ((StatusException) throwable).getStatus(); - assertThat(captured.getCode()).isEqualTo(Status.Code.ABORTED); - assertThat(xdsClientWrapperForServerSds.serverWatchers).hasSize(1); - assertThat(future.isDone()).isFalse(); + XdsServerTestHelper.generateListenerUpdate( + xdsClient, + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"), + tlsContextManager); + future.get(5000, TimeUnit.MILLISECONDS); + xdsClient.ldsWatcher.onError(Status.ABORTED); + verify(mockXdsServingStatusListener, never()).onNotServing(any(StatusException.class)); reset(mockXdsServingStatusListener); - listenerWatcher.onError(Status.NOT_FOUND); - argCaptor = ArgumentCaptor.forClass(null); - verify(mockXdsServingStatusListener).onNotServing(argCaptor.capture()); - throwable = argCaptor.getValue(); - assertThat(throwable).isInstanceOf(StatusException.class); - captured = ((StatusException) throwable).getStatus(); - assertThat(captured.getCode()).isEqualTo(Status.Code.NOT_FOUND); + xdsClient.ldsWatcher.onError(Status.CANCELLED); + verify(mockXdsServingStatusListener, never()).onNotServing(any(StatusException.class)); reset(mockXdsServingStatusListener); - listenerWatcher.onResourceDoesNotExist("not found error"); - argCaptor = ArgumentCaptor.forClass(null); - verify(mockXdsServingStatusListener).onNotServing(argCaptor.capture()); - throwable = argCaptor.getValue(); - assertThat(throwable).isInstanceOf(StatusException.class); - captured = ((StatusException) throwable).getStatus(); - assertThat(captured.getCode()).isEqualTo(Status.Code.NOT_FOUND); - assertThat(future.isDone()).isFalse(); + xdsClient.ldsWatcher.onResourceDoesNotExist("not found error"); + verify(mockXdsServingStatusListener).onNotServing(any(StatusException.class)); reset(mockXdsServingStatusListener); XdsServerTestHelper.generateListenerUpdate( - listenerWatcher, + xdsClient, CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"), tlsContextManager); - verifyServer(future, mockXdsServingStatusListener, null); + verifyServer(null, mockXdsServingStatusListener, null); } @Test @@ -236,12 +213,13 @@ public void xdsServer_startError() throws IOException, InterruptedException, TimeoutException, ExecutionException { XdsServerBuilder.XdsServingStatusListener mockXdsServingStatusListener = mock(XdsServerBuilder.XdsServingStatusListener.class); + ServerSocket serverSocket = new ServerSocket(0); + port = serverSocket.getLocalPort(); buildServer(mockXdsServingStatusListener); Future future = startServerAsync(); // create port conflict for start to fail - ServerSocket serverSocket = new ServerSocket(port); XdsServerTestHelper.generateListenerUpdate( - listenerWatcher, + xdsClient, CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"), tlsContextManager); Throwable exception = future.get(5, TimeUnit.SECONDS); @@ -259,16 +237,16 @@ public void xdsServerStartSecondUpdateAndError() buildServer(mockXdsServingStatusListener); Future future = startServerAsync(); XdsServerTestHelper.generateListenerUpdate( - listenerWatcher, + xdsClient, CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"), tlsContextManager); XdsServerTestHelper.generateListenerUpdate( - listenerWatcher, + xdsClient, CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"), tlsContextManager); verify(mockXdsServingStatusListener, never()).onNotServing(any(Throwable.class)); verifyServer(future, mockXdsServingStatusListener, null); - listenerWatcher.onError(Status.ABORTED); + xdsClient.ldsWatcher.onError(Status.ABORTED); verifyServer(null, mockXdsServingStatusListener, null); } @@ -295,7 +273,7 @@ public void xdsServer_2ndSetter_expectException() throws IOException { .builder("mock").build(); when(mockBindableService.bindService()).thenReturn(serverServiceDefinition); builder.addService(mockBindableService); - xdsServer = cleanupRule.register(builder.buildServer(xdsClientWrapperForServerSds)); + xdsServer = cleanupRule.register((XdsServerWrapper) builder.build()); try { builder.addService(mock(BindableService.class)); fail("exception expected"); @@ -303,4 +281,15 @@ public void xdsServer_2ndSetter_expectException() throws IOException { assertThat(expected).hasMessageThat().contains("Server already built!"); } } + + @Test + public void drainGraceTime_negativeThrows() throws IOException { + buildBuilder(null); + try { + builder.drainGraceTime(-1, TimeUnit.SECONDS); + fail("exception expected"); + } catch (IllegalArgumentException expected) { + assertThat(expected).hasMessageThat().contains("drain grace time"); + } + } } diff --git a/xds/src/test/java/io/grpc/xds/XdsServerTestHelper.java b/xds/src/test/java/io/grpc/xds/XdsServerTestHelper.java index 071fbd8a108..f289c4726fb 100644 --- a/xds/src/test/java/io/grpc/xds/XdsServerTestHelper.java +++ b/xds/src/test/java/io/grpc/xds/XdsServerTestHelper.java @@ -16,23 +16,27 @@ package io.grpc.xds; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; +import static com.google.common.truth.Truth.assertThat; +import com.google.common.collect.ImmutableMap; +import com.google.common.util.concurrent.SettableFuture; import io.grpc.InsecureChannelCredentials; import io.grpc.internal.ObjectPool; +import io.grpc.xds.Bootstrapper.BootstrapInfo; +import io.grpc.xds.EnvoyServerProtoData.FilterChain; +import io.grpc.xds.EnvoyServerProtoData.Listener; +import io.grpc.xds.Filter.FilterConfig; import io.grpc.xds.Filter.NamedFilterConfig; +import io.grpc.xds.VirtualHost.Route; import io.grpc.xds.XdsClient.LdsUpdate; -import java.io.IOException; -import java.net.ServerSocket; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.concurrent.CountDownLatch; import javax.annotation.Nullable; -import org.mockito.ArgumentCaptor; /** * Helper methods related to {@link XdsServerBuilder} and related classes. @@ -47,13 +51,66 @@ public class XdsServerTestHelper { static final Bootstrapper.BootstrapInfo BOOTSTRAP_INFO = new Bootstrapper.BootstrapInfo( Arrays.asList( - new Bootstrapper.ServerInfo(SERVER_URI, InsecureChannelCredentials.create(), false)), + new Bootstrapper.ServerInfo(SERVER_URI, InsecureChannelCredentials.create(), true)), BOOTSTRAP_NODE, null, "grpc/server?udpa.resource.listening_address=%s"); + static void generateListenerUpdate(FakeXdsClient xdsClient, + EnvoyServerProtoData.DownstreamTlsContext tlsContext, + TlsContextManager tlsContextManager) { + EnvoyServerProtoData.Listener listener = buildTestListener("listener1", "10.1.2.3", + Arrays.asList(), tlsContext, null, tlsContextManager); + LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(listener); + xdsClient.deliverLdsUpdate(listenerUpdate); + } + + static void generateListenerUpdate( + FakeXdsClient xdsClient, List sourcePorts, + EnvoyServerProtoData.DownstreamTlsContext tlsContext, + EnvoyServerProtoData.DownstreamTlsContext tlsContextForDefaultFilterChain, + TlsContextManager tlsContextManager) { + EnvoyServerProtoData.Listener listener = buildTestListener("listener1", "10.1.2.3", sourcePorts, + tlsContext, tlsContextForDefaultFilterChain, tlsContextManager); + LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(listener); + xdsClient.deliverLdsUpdate(listenerUpdate); + } + + static EnvoyServerProtoData.Listener buildTestListener( + String name, String address, List sourcePorts, + EnvoyServerProtoData.DownstreamTlsContext tlsContext, + EnvoyServerProtoData.DownstreamTlsContext tlsContextForDefaultFilterChain, + TlsContextManager tlsContextManager) { + EnvoyServerProtoData.FilterChainMatch filterChainMatch1 = + new EnvoyServerProtoData.FilterChainMatch( + 0, + Arrays.asList(), + Arrays.asList(), + Arrays.asList(), + null, + sourcePorts, + Arrays.asList(), + null); + VirtualHost virtualHost = + VirtualHost.create( + "virtual-host", Collections.singletonList("auth"), new ArrayList(), + ImmutableMap.of()); + HttpConnectionManager httpConnectionManager = HttpConnectionManager.forVirtualHosts( + 0L, Collections.singletonList(virtualHost), new ArrayList()); + EnvoyServerProtoData.FilterChain filterChain1 = new EnvoyServerProtoData.FilterChain( + "filter-chain-foo", filterChainMatch1, httpConnectionManager, tlsContext, + tlsContextManager); + EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( + "filter-chain-bar", null, httpConnectionManager, tlsContextForDefaultFilterChain, + tlsContextManager); + EnvoyServerProtoData.Listener listener = + new EnvoyServerProtoData.Listener( + name, address, Arrays.asList(filterChain1), defaultFilterChain); + return listener; + } + static final class FakeXdsClientPoolFactory - implements XdsNameResolverProvider.XdsClientPoolFactory { + implements XdsNameResolverProvider.XdsClientPoolFactory { private XdsClient xdsClient; @@ -82,103 +139,77 @@ public XdsClient getObject() { @Override public XdsClient returnObject(Object object) { + xdsClient.shutdown(); return null; } }; } } - /** Create an XdsClientWrapperForServerSds with a mock XdsClient. */ - public static XdsClientWrapperForServerSds createXdsClientWrapperForServerSds(int port, - TlsContextManager tlsContextManager) { - FakeXdsClientPoolFactory fakeXdsClientPoolFactory = new FakeXdsClientPoolFactory( - buildMockXdsClient(tlsContextManager)); - return new XdsClientWrapperForServerSds(port, fakeXdsClientPoolFactory); - } + static final class FakeXdsClient extends XdsClient { + boolean shutdown; + SettableFuture ldsResource = SettableFuture.create(); + LdsResourceWatcher ldsWatcher; + CountDownLatch rdsCount = new CountDownLatch(1); + final Map rdsWatchers = new HashMap<>(); - private static XdsClient buildMockXdsClient(TlsContextManager tlsContextManager) { - XdsClient xdsClient = mock(XdsClient.class); - when(xdsClient.getBootstrapInfo()).thenReturn(BOOTSTRAP_INFO); - when(xdsClient.getTlsContextManager()).thenReturn(tlsContextManager); - return xdsClient; - } + @Override + public TlsContextManager getTlsContextManager() { + return null; + } - static XdsClient.LdsResourceWatcher startAndGetWatcher( - XdsClientWrapperForServerSds xdsClientWrapperForServerSds) { - xdsClientWrapperForServerSds.start(); - XdsClient mockXdsClient = xdsClientWrapperForServerSds.getXdsClient(); - ArgumentCaptor listenerWatcherCaptor = - ArgumentCaptor.forClass(null); - verify(mockXdsClient).watchLdsResource(any(String.class), listenerWatcherCaptor.capture()); - return listenerWatcherCaptor.getValue(); - } + @Override + public BootstrapInfo getBootstrapInfo() { + return BOOTSTRAP_INFO; + } - /** - * Creates a {@link XdsClient.LdsUpdate} with {@link - * io.grpc.xds.EnvoyServerProtoData.FilterChain} with a destination port and an optional {@link - * EnvoyServerProtoData.DownstreamTlsContext}. - * @param registeredWatcher the watcher on which to generate the update - * @param tlsContext if non-null, used to populate filterChain - */ - static void generateListenerUpdate( - XdsClient.LdsResourceWatcher registeredWatcher, - EnvoyServerProtoData.DownstreamTlsContext tlsContext, TlsContextManager tlsContextManager) { - EnvoyServerProtoData.Listener listener = buildTestListener("listener1", "10.1.2.3", - Arrays.asList(), tlsContext, null, tlsContextManager); - LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(listener); - registeredWatcher.onChanged(listenerUpdate); - } + @Override + void watchLdsResource(String resourceName, LdsResourceWatcher watcher) { + assertThat(ldsWatcher).isNull(); + ldsWatcher = watcher; + ldsResource.set(resourceName); + } - static void generateListenerUpdate( - XdsClient.LdsResourceWatcher registeredWatcher, List sourcePorts, - EnvoyServerProtoData.DownstreamTlsContext tlsContext, - EnvoyServerProtoData.DownstreamTlsContext tlsContextForDefaultFilterChain, - TlsContextManager tlsContextManager) { - EnvoyServerProtoData.Listener listener = buildTestListener("listener1", "10.1.2.3", sourcePorts, - tlsContext, tlsContextForDefaultFilterChain, tlsContextManager); - LdsUpdate listenerUpdate = LdsUpdate.forTcpListener(listener); - registeredWatcher.onChanged(listenerUpdate); - } + @Override + void cancelLdsResourceWatch(String resourceName, LdsResourceWatcher watcher) { + assertThat(ldsWatcher).isNotNull(); + ldsResource = null; + ldsWatcher = null; + } - public static void generateListenerUpdate( - XdsClient.LdsResourceWatcher registeredWatcher, EnvoyServerProtoData.Listener listener) { - registeredWatcher.onChanged(LdsUpdate.forTcpListener(listener)); - } + @Override + void watchRdsResource(String resourceName, RdsResourceWatcher watcher) { + assertThat(rdsWatchers.put(resourceName, watcher)).isNull(); //re-register is not allowed. + rdsCount.countDown(); + } - static int findFreePort() throws IOException { - try (ServerSocket socket = new ServerSocket(0)) { - socket.setReuseAddress(true); - return socket.getLocalPort(); + @Override + void cancelRdsResourceWatch(String resourceName, RdsResourceWatcher watcher) { + rdsWatchers.remove(resourceName); } - } - static EnvoyServerProtoData.Listener buildTestListener( - String name, String address, List sourcePorts, - EnvoyServerProtoData.DownstreamTlsContext tlsContext, - EnvoyServerProtoData.DownstreamTlsContext tlsContextForDefaultFilterChain, - TlsContextManager tlsContextManager) { - EnvoyServerProtoData.FilterChainMatch filterChainMatch1 = - new EnvoyServerProtoData.FilterChainMatch( - 0, - Arrays.asList(), - Arrays.asList(), - Arrays.asList(), - null, - sourcePorts, - Arrays.asList(), - null); - // HttpConnectionManager currently not used for server side. - HttpConnectionManager httpConnectionManager = HttpConnectionManager.forRdsName( - 0L, "does not matter", Collections.emptyList()); - EnvoyServerProtoData.FilterChain filterChain1 = new EnvoyServerProtoData.FilterChain( - "filter-chain-foo", filterChainMatch1, httpConnectionManager, tlsContext, - tlsContextManager); - EnvoyServerProtoData.FilterChain defaultFilterChain = new EnvoyServerProtoData.FilterChain( - "filter-chain-bar", null, httpConnectionManager, tlsContextForDefaultFilterChain, - tlsContextManager); - EnvoyServerProtoData.Listener listener = - new EnvoyServerProtoData.Listener( - name, address, Arrays.asList(filterChain1), defaultFilterChain); - return listener; + @Override + void shutdown() { + shutdown = true; + } + + @Override + boolean isShutDown() { + return shutdown; + } + + void deliverLdsUpdate(List filterChains, + FilterChain defaultFilterChain) { + ldsWatcher.onChanged(LdsUpdate.forTcpListener(new Listener( + "listener", "0.0.0.0:1", filterChains, defaultFilterChain))); + } + + void deliverLdsUpdate(LdsUpdate ldsUpdate) { + ldsWatcher.onChanged(ldsUpdate); + } + + void deliverRdsUpdate(String rdsName, List virtualHosts) { + rdsWatchers.get(rdsName).onChanged(new RdsUpdate(virtualHosts)); + } } } diff --git a/xds/src/test/java/io/grpc/xds/XdsServerWrapperTest.java b/xds/src/test/java/io/grpc/xds/XdsServerWrapperTest.java new file mode 100644 index 00000000000..f2b6e9e4790 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/XdsServerWrapperTest.java @@ -0,0 +1,1184 @@ +/* + * Copyright 2021 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package io.grpc.xds; + +import static com.google.common.truth.Truth.assertThat; +import static io.grpc.xds.XdsServerWrapper.ATTR_SERVER_ROUTING_CONFIG; +import static io.grpc.xds.XdsServerWrapper.RETRY_DELAY_NANOS; +import static org.junit.Assert.fail; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.mockito.Mockito.withSettings; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.util.concurrent.SettableFuture; +import io.grpc.Attributes; +import io.grpc.InsecureChannelCredentials; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.Server; +import io.grpc.ServerBuilder; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.grpc.Status; +import io.grpc.StatusException; +import io.grpc.internal.FakeClock; +import io.grpc.testing.TestMethodDescriptors; +import io.grpc.xds.EnvoyServerProtoData.FilterChain; +import io.grpc.xds.Filter.FilterConfig; +import io.grpc.xds.Filter.NamedFilterConfig; +import io.grpc.xds.Filter.ServerInterceptorBuilder; +import io.grpc.xds.FilterChainMatchingProtocolNegotiators.FilterChainMatchingHandler.FilterChainSelector; +import io.grpc.xds.VirtualHost.Route; +import io.grpc.xds.VirtualHost.Route.RouteMatch; +import io.grpc.xds.VirtualHost.Route.RouteMatch.PathMatcher; +import io.grpc.xds.XdsClient.RdsResourceWatcher; +import io.grpc.xds.XdsClient.RdsUpdate; +import io.grpc.xds.XdsServerBuilder.XdsServingStatusListener; +import io.grpc.xds.XdsServerTestHelper.FakeXdsClient; +import io.grpc.xds.XdsServerTestHelper.FakeXdsClientPoolFactory; +import io.grpc.xds.XdsServerWrapper.ConfigApplyingInterceptor; +import io.grpc.xds.XdsServerWrapper.ServerRoutingConfig; +import io.grpc.xds.internal.Matchers.HeaderMatcher; +import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil; +import io.grpc.xds.internal.sds.SslContextProviderSupplier; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +@RunWith(JUnit4.class) +public class XdsServerWrapperTest { + @Rule + public final MockitoRule mocks = MockitoJUnit.rule(); + + @Mock + private ServerBuilder mockBuilder; + @Mock + private Server mockServer; + @Mock + private static TlsContextManager tlsContextManager; + @Mock + private XdsServingStatusListener listener; + + private FilterChainSelectorManager selectorManager = new FilterChainSelectorManager(); + private FakeClock executor = new FakeClock(); + private FakeXdsClient xdsClient = new FakeXdsClient(); + private FilterRegistry filterRegistry = FilterRegistry.getDefaultRegistry(); + private XdsServerWrapper xdsServerWrapper; + private ServerRoutingConfig noopConfig = ServerRoutingConfig.create( + ImmutableList.of(), ImmutableMap.of()); + + @Before + public void setup() { + when(mockBuilder.build()).thenReturn(mockServer); + xdsServerWrapper = new XdsServerWrapper("0.0.0.0:1", mockBuilder, listener, + selectorManager, new FakeXdsClientPoolFactory(xdsClient), + filterRegistry, executor.getScheduledExecutorService()); + } + + @Test + public void testBootstrap_notV3() throws Exception { + Bootstrapper.BootstrapInfo b = + new Bootstrapper.BootstrapInfo( + Arrays.asList( + new Bootstrapper.ServerInfo("uri", InsecureChannelCredentials.create(), false)), + EnvoyProtoData.Node.newBuilder().setId("id").build(), + null, + "grpc/server?udpa.resource.listening_address=%s"); + verifyBootstrapFail(b); + } + + @Test + public void testBootstrap_noTemplate() throws Exception { + Bootstrapper.BootstrapInfo b = + new Bootstrapper.BootstrapInfo( + Arrays.asList( + new Bootstrapper.ServerInfo("uri", InsecureChannelCredentials.create(), true)), + EnvoyProtoData.Node.newBuilder().setId("id").build(), + null, + null); + verifyBootstrapFail(b); + } + + private void verifyBootstrapFail(Bootstrapper.BootstrapInfo b) throws Exception { + XdsClient xdsClient = mock(XdsClient.class); + when(xdsClient.getBootstrapInfo()).thenReturn(b); + xdsServerWrapper = new XdsServerWrapper("0.0.0.0:1", mockBuilder, listener, + selectorManager, new FakeXdsClientPoolFactory(xdsClient), filterRegistry); + final SettableFuture start = SettableFuture.create(); + Executors.newSingleThreadExecutor().execute(new Runnable() { + @Override + public void run() { + try { + start.set(xdsServerWrapper.start()); + } catch (Exception ex) { + start.setException(ex); + } + } + }); + try { + start.get(5000, TimeUnit.MILLISECONDS); + fail("Start should throw exception"); + } catch (ExecutionException ex) { + assertThat(ex.getCause()).isInstanceOf(IOException.class); + Throwable cause = ex.getCause().getCause(); + assertThat(cause).isInstanceOf(StatusException.class); + assertThat(((StatusException)cause).getStatus().getCode()) + .isEqualTo(Status.UNAVAILABLE.getCode()); + } + } + + @Test + public void shutdown() throws Exception { + final SettableFuture start = SettableFuture.create(); + Executors.newSingleThreadExecutor().execute(new Runnable() { + @Override + public void run() { + try { + start.set(xdsServerWrapper.start()); + } catch (Exception ex) { + start.setException(ex); + } + } + }); + String ldsWatched = xdsClient.ldsResource.get(5, TimeUnit.SECONDS); + assertThat(ldsWatched).isEqualTo("grpc/server?udpa.resource.listening_address=0.0.0.0:1"); + HttpConnectionManager hcm_virtual = HttpConnectionManager.forVirtualHosts( + 0L, Collections.singletonList(createVirtualHost("virtual-host-0")), + new ArrayList()); + FilterChain f0 = createFilterChain("filter-chain-0", hcm_virtual); + FilterChain f1 = createFilterChain("filter-chain-1", createRds("rds")); + xdsClient.deliverLdsUpdate(Collections.singletonList(f0), f1); + xdsClient.rdsCount.await(5, TimeUnit.SECONDS); + xdsClient.deliverRdsUpdate("rds", + Collections.singletonList(createVirtualHost("virtual-host-1"))); + start.get(5000, TimeUnit.MILLISECONDS); + verify(mockServer).start(); + xdsServerWrapper.shutdown(); + assertThat(xdsServerWrapper.isShutdown()).isTrue(); + assertThat(xdsClient.ldsResource).isNull(); + assertThat(xdsClient.shutdown).isTrue(); + verify(mockServer).shutdown(); + assertThat(f0.getSslContextProviderSupplier().isShutdown()).isTrue(); + assertThat(f1.getSslContextProviderSupplier().isShutdown()).isTrue(); + when(mockServer.isTerminated()).thenReturn(true); + when(mockServer.awaitTermination(anyLong(), any(TimeUnit.class))).thenReturn(true); + assertThat(xdsServerWrapper.awaitTermination(5, TimeUnit.SECONDS)).isTrue(); + xdsServerWrapper.awaitTermination(); + assertThat(xdsServerWrapper.isTerminated()).isTrue(); + assertThat(start.get()).isSameInstanceAs(xdsServerWrapper); + } + + @Test + public void shutdown_inflight() throws Exception { + final SettableFuture start = SettableFuture.create(); + Executors.newSingleThreadExecutor().execute(new Runnable() { + @Override + public void run() { + try { + start.set(xdsServerWrapper.start()); + } catch (Exception ex) { + start.setException(ex); + } + } + }); + String ldsWatched = xdsClient.ldsResource.get(5, TimeUnit.SECONDS); + assertThat(ldsWatched).isEqualTo("grpc/server?udpa.resource.listening_address=0.0.0.0:1"); + HttpConnectionManager hcmVirtual = HttpConnectionManager.forVirtualHosts( + 0L, Collections.singletonList(createVirtualHost("virtual-host-0")), + new ArrayList()); + FilterChain f0 = createFilterChain("filter-chain-0", createRds("rds")); + FilterChain f1 = createFilterChain("filter-chain-1", hcmVirtual); + xdsClient.deliverLdsUpdate(Collections.singletonList(f0), f1); + xdsServerWrapper.shutdown(); + when(mockServer.isTerminated()).thenReturn(true); + when(mockServer.awaitTermination(anyLong(), any(TimeUnit.class))).thenReturn(true); + assertThat(xdsServerWrapper.awaitTermination(5, TimeUnit.SECONDS)).isTrue(); + xdsServerWrapper.awaitTermination(); + assertThat(xdsServerWrapper.isTerminated()).isTrue(); + verify(mockServer, never()).start(); + assertThat(xdsServerWrapper.isShutdown()).isTrue(); + assertThat(xdsClient.ldsResource).isNull(); + assertThat(xdsClient.shutdown).isTrue(); + verify(mockServer).shutdown(); + assertThat(f0.getSslContextProviderSupplier().isShutdown()).isTrue(); + assertThat(f1.getSslContextProviderSupplier().isShutdown()).isTrue(); + assertThat(start.isDone()).isFalse(); //shall we set initialStatus when shutdown? + } + + @Test + public void shutdown_afterResourceNotExist() throws Exception { + final SettableFuture start = SettableFuture.create(); + Executors.newSingleThreadExecutor().execute(new Runnable() { + @Override + public void run() { + try { + start.set(xdsServerWrapper.start()); + } catch (Exception ex) { + start.setException(ex); + } + } + }); + String ldsResource = xdsClient.ldsResource.get(5, TimeUnit.SECONDS); + xdsClient.ldsWatcher.onResourceDoesNotExist(ldsResource); + try { + start.get(5000, TimeUnit.MILLISECONDS); + fail("Start should throw exception"); + } catch (ExecutionException ex) { + assertThat(ex.getCause()).isInstanceOf(IOException.class); + } + verify(mockBuilder, times(1)).build(); + verify(mockServer, never()).start(); + verify(mockServer).shutdown(); + when(mockServer.isShutdown()).thenReturn(true); + when(mockServer.isTerminated()).thenReturn(true); + verify(listener, times(1)).onNotServing(any(Throwable.class)); + xdsServerWrapper.shutdown(); + assertThat(xdsServerWrapper.isShutdown()).isTrue(); + assertThat(xdsClient.ldsResource).isNull(); + assertThat(xdsClient.shutdown).isTrue(); + verify(mockBuilder, times(1)).build(); + verify(mockServer, times(1)).shutdown(); + xdsServerWrapper.awaitTermination(1, TimeUnit.SECONDS); + assertThat(xdsServerWrapper.isTerminated()).isTrue(); + } + + @Test + public void shutdown_pendingRetry() throws Exception { + final SettableFuture start = SettableFuture.create(); + Executors.newSingleThreadExecutor().execute(new Runnable() { + @Override + public void run() { + try { + start.set(xdsServerWrapper.start()); + } catch (Exception ex) { + start.setException(ex); + } + } + }); + xdsClient.ldsResource.get(5, TimeUnit.SECONDS); + when(mockServer.start()).thenThrow(new IOException("error!")); + FilterChain filterChain = createFilterChain("filter-chain-1", createRds("rds")); + SslContextProviderSupplier sslSupplier = filterChain.getSslContextProviderSupplier(); + xdsClient.deliverLdsUpdate(Collections.singletonList(filterChain), null); + xdsClient.rdsCount.await(5, TimeUnit.SECONDS); + xdsClient.deliverRdsUpdate("rds", + Collections.singletonList(createVirtualHost("virtual-host-1"))); + try { + start.get(5000, TimeUnit.MILLISECONDS); + fail("Start should throw exception"); + } catch (ExecutionException ex) { + assertThat(ex.getCause()).isInstanceOf(IOException.class); + } + assertThat(executor.getPendingTasks().size()).isEqualTo(1); + verify(mockServer).start(); + verify(mockServer, never()).shutdown(); + xdsServerWrapper.shutdown(); + verify(mockServer).shutdown(); + when(mockServer.isTerminated()).thenReturn(true); + assertThat(sslSupplier.isShutdown()).isTrue(); + assertThat(executor.getPendingTasks().size()).isEqualTo(0); + verify(listener, never()).onNotServing(any(Throwable.class)); + verify(listener, never()).onServing(); + xdsServerWrapper.awaitTermination(); + assertThat(xdsServerWrapper.isTerminated()).isTrue(); + } + + @Test + public void initialStartIoException() throws Exception { + final SettableFuture start = SettableFuture.create(); + Executors.newSingleThreadExecutor().execute(new Runnable() { + @Override + public void run() { + try { + start.set(xdsServerWrapper.start()); + } catch (Exception ex) { + start.setException(ex); + } + } + }); + when(mockServer.start()).thenThrow(new IOException("error!")); + xdsClient.ldsResource.get(5, TimeUnit.SECONDS); + FilterChain filterChain = createFilterChain("filter-chain-1", createRds("rds")); + xdsClient.deliverLdsUpdate(Collections.singletonList(filterChain), null); + xdsClient.rdsCount.await(5, TimeUnit.SECONDS); + xdsClient.deliverRdsUpdate("rds", + Collections.singletonList(createVirtualHost("virtual-host-1"))); + try { + start.get(5000, TimeUnit.MILLISECONDS); + fail("Start should throw exception"); + } catch (ExecutionException ex) { + assertThat(ex.getCause()).isInstanceOf(IOException.class); + assertThat(ex.getCause().getMessage()).isEqualTo("error!"); + } + } + + @Test + public void discoverState_virtualhost() throws Exception { + final SettableFuture start = SettableFuture.create(); + Executors.newSingleThreadExecutor().execute(new Runnable() { + @Override + public void run() { + try { + start.set(xdsServerWrapper.start()); + } catch (Exception ex) { + start.setException(ex); + } + } + }); + String ldsWatched = xdsClient.ldsResource.get(5, TimeUnit.SECONDS); + assertThat(ldsWatched).isEqualTo("grpc/server?udpa.resource.listening_address=0.0.0.0:1"); + VirtualHost virtualHost = + VirtualHost.create( + "virtual-host", Collections.singletonList("auth"), new ArrayList(), + ImmutableMap.of()); + HttpConnectionManager httpConnectionManager = HttpConnectionManager.forVirtualHosts( + 0L, Collections.singletonList(virtualHost), new ArrayList()); + EnvoyServerProtoData.FilterChain filterChain = new EnvoyServerProtoData.FilterChain( + "filter-chain-foo", createMatch(), httpConnectionManager, createTls(), + tlsContextManager); + xdsClient.deliverLdsUpdate(Collections.singletonList(filterChain), null); + start.get(5000, TimeUnit.MILLISECONDS); + assertThat(ldsWatched).isEqualTo("grpc/server?udpa.resource.listening_address=0.0.0.0:1"); + assertThat(selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().size()) + .isEqualTo(1); + ServerRoutingConfig realConfig = + selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().get(filterChain).get(); + assertThat(realConfig.virtualHosts()).isEqualTo(httpConnectionManager.virtualHosts()); + assertThat(realConfig.interceptors()).isEqualTo(ImmutableMap.of()); + verify(listener).onServing(); + verify(mockServer).start(); + } + + @Test + public void discoverState_rds() throws Exception { + final SettableFuture start = SettableFuture.create(); + Executors.newSingleThreadExecutor().execute(new Runnable() { + @Override + public void run() { + try { + start.set(xdsServerWrapper.start()); + } catch (Exception ex) { + start.setException(ex); + } + } + }); + String ldsWatched = xdsClient.ldsResource.get(5, TimeUnit.SECONDS); + assertThat(ldsWatched).isEqualTo("grpc/server?udpa.resource.listening_address=0.0.0.0:1"); + VirtualHost virtualHost = createVirtualHost("virtual-host-0"); + HttpConnectionManager hcmVirtual = HttpConnectionManager.forVirtualHosts( + 0L, Collections.singletonList(virtualHost), new ArrayList()); + EnvoyServerProtoData.FilterChain f0 = createFilterChain("filter-chain-0", hcmVirtual); + EnvoyServerProtoData.FilterChain f1 = createFilterChain("filter-chain-1", createRds("r0")); + xdsClient.rdsCount = new CountDownLatch(3); + xdsClient.deliverLdsUpdate(Arrays.asList(f0, f1), null); + assertThat(start.isDone()).isFalse(); + assertThat(selectorManager.getSelectorToUpdateSelector()).isNull(); + verify(mockServer, never()).start(); + verify(listener, never()).onServing(); + + EnvoyServerProtoData.FilterChain f2 = createFilterChain("filter-chain-2", createRds("r1")); + EnvoyServerProtoData.FilterChain f3 = createFilterChain("filter-chain-3", createRds("r2")); + xdsClient.deliverLdsUpdate(Arrays.asList(f0, f2), f3); + verify(mockServer, never()).start(); + verify(listener, never()).onServing(); + xdsClient.rdsCount.await(5, TimeUnit.SECONDS); + + xdsClient.deliverRdsUpdate("r1", + Collections.singletonList(createVirtualHost("virtual-host-1"))); + verify(mockServer, never()).start(); + xdsClient.deliverRdsUpdate("r2", + Collections.singletonList(createVirtualHost("virtual-host-2"))); + start.get(5000, TimeUnit.MILLISECONDS); + verify(mockServer).start(); + ServerRoutingConfig realConfig = + selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().get(f0).get(); + assertThat(realConfig.virtualHosts()).isEqualTo( + Collections.singletonList(createVirtualHost("virtual-host-0"))); + assertThat(realConfig.interceptors()).isEqualTo(ImmutableMap.of()); + assertThat(selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().size()) + .isEqualTo(2); + realConfig = selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().get(f2).get(); + assertThat(realConfig.virtualHosts()).isEqualTo( + Collections.singletonList(createVirtualHost("virtual-host-1"))); + assertThat(realConfig.interceptors()).isEqualTo(ImmutableMap.of()); + realConfig = selectorManager.getSelectorToUpdateSelector().getDefaultRoutingConfig().get(); + assertThat(realConfig.virtualHosts()).isEqualTo( + Collections.singletonList(createVirtualHost("virtual-host-2"))); + assertThat(selectorManager.getSelectorToUpdateSelector().getDefaultSslContextProviderSupplier()) + .isEqualTo(f3.getSslContextProviderSupplier()); + } + + @Test + public void discoverState_oneRdsToMultipleFilterChain() throws Exception { + final SettableFuture start = SettableFuture.create(); + Executors.newSingleThreadExecutor().execute(new Runnable() { + @Override + public void run() { + try { + start.set(xdsServerWrapper.start()); + } catch (Exception ex) { + start.setException(ex); + } + } + }); + String ldsWatched = xdsClient.ldsResource.get(5, TimeUnit.SECONDS); + assertThat(ldsWatched).isEqualTo("grpc/server?udpa.resource.listening_address=0.0.0.0:1"); + EnvoyServerProtoData.FilterChain f0 = createFilterChain("filter-chain-0", createRds("r0")); + EnvoyServerProtoData.FilterChain f1 = createFilterChain("filter-chain-1", createRds("r0")); + EnvoyServerProtoData.FilterChain f2 = createFilterChain("filter-chain-2", createRds("r0")); + + xdsClient.rdsCount = new CountDownLatch(1); + xdsClient.deliverLdsUpdate(Arrays.asList(f0, f1), f2); + assertThat(start.isDone()).isFalse(); + assertThat(selectorManager.getSelectorToUpdateSelector()).isNull(); + + xdsClient.rdsCount.await(5, TimeUnit.SECONDS); + xdsClient.deliverRdsUpdate("r0", + Collections.singletonList(createVirtualHost("virtual-host-0"))); + start.get(5000, TimeUnit.MILLISECONDS); + verify(mockServer, times(1)).start(); + ServerRoutingConfig realConfig = + selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().get(f0).get(); + assertThat(realConfig.virtualHosts()).isEqualTo( + Collections.singletonList(createVirtualHost("virtual-host-0"))); + assertThat(realConfig.interceptors()).isEqualTo(ImmutableMap.of()); + + realConfig = selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().get(f1).get(); + assertThat(realConfig.virtualHosts()).isEqualTo( + Collections.singletonList(createVirtualHost("virtual-host-0"))); + assertThat(realConfig.interceptors()).isEqualTo(ImmutableMap.of()); + + realConfig = selectorManager.getSelectorToUpdateSelector().getDefaultRoutingConfig().get(); + assertThat(realConfig.virtualHosts()).isEqualTo( + Collections.singletonList(createVirtualHost("virtual-host-0"))); + assertThat(realConfig.interceptors()).isEqualTo(ImmutableMap.of()); + assertThat(selectorManager.getSelectorToUpdateSelector().getDefaultSslContextProviderSupplier()) + .isSameInstanceAs(f2.getSslContextProviderSupplier()); + + EnvoyServerProtoData.FilterChain f3 = createFilterChain("filter-chain-3", createRds("r0")); + EnvoyServerProtoData.FilterChain f4 = createFilterChain("filter-chain-4", createRds("r1")); + EnvoyServerProtoData.FilterChain f5 = createFilterChain("filter-chain-4", createRds("r1")); + xdsClient.rdsCount = new CountDownLatch(1); + xdsClient.deliverLdsUpdate(Arrays.asList(f5, f3), f4); + xdsClient.rdsCount.await(5, TimeUnit.SECONDS); + xdsClient.deliverRdsUpdate("r1", + Collections.singletonList(createVirtualHost("virtual-host-1"))); + xdsClient.deliverRdsUpdate("r0", + Collections.singletonList(createVirtualHost("virtual-host-0"))); + + assertThat(selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().size()) + .isEqualTo(2); + realConfig = selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().get(f5).get(); + assertThat(realConfig.virtualHosts()).isEqualTo( + Collections.singletonList(createVirtualHost("virtual-host-1"))); + assertThat(realConfig.interceptors()).isEqualTo(ImmutableMap.of()); + realConfig = selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().get(f3).get(); + assertThat(realConfig.virtualHosts()).isEqualTo( + Collections.singletonList(createVirtualHost("virtual-host-0"))); + assertThat(realConfig.interceptors()).isEqualTo(ImmutableMap.of()); + + realConfig = selectorManager.getSelectorToUpdateSelector().getDefaultRoutingConfig().get(); + assertThat(realConfig.virtualHosts()).isEqualTo( + Collections.singletonList(createVirtualHost("virtual-host-1"))); + assertThat(realConfig.interceptors()).isEqualTo(ImmutableMap.of()); + + assertThat(selectorManager.getSelectorToUpdateSelector().getDefaultSslContextProviderSupplier()) + .isSameInstanceAs(f4.getSslContextProviderSupplier()); + verify(mockServer, times(1)).start(); + xdsServerWrapper.shutdown(); + verify(mockServer, times(1)).shutdown(); + when(mockServer.isTerminated()).thenReturn(true); + xdsServerWrapper.awaitTermination(); + assertThat(xdsServerWrapper.isTerminated()).isTrue(); + } + + @Test + public void discoverState_rds_onError_and_resourceNotExist() throws Exception { + final SettableFuture start = SettableFuture.create(); + Executors.newSingleThreadExecutor().execute(new Runnable() { + @Override + public void run() { + try { + start.set(xdsServerWrapper.start()); + } catch (Exception ex) { + start.setException(ex); + } + } + }); + String ldsWatched = xdsClient.ldsResource.get(5, TimeUnit.SECONDS); + assertThat(ldsWatched).isEqualTo("grpc/server?udpa.resource.listening_address=0.0.0.0:1"); + VirtualHost virtualHost = createVirtualHost("virtual-host-0"); + HttpConnectionManager hcmVirtual = HttpConnectionManager.forVirtualHosts( + 0L, Collections.singletonList(virtualHost), new ArrayList()); + EnvoyServerProtoData.FilterChain f0 = createFilterChain("filter-chain-0", hcmVirtual); + EnvoyServerProtoData.FilterChain f1 = createFilterChain("filter-chain-1", createRds("r0")); + xdsClient.deliverLdsUpdate(Arrays.asList(f0, f1), null); + xdsClient.rdsCount.await(); + xdsClient.rdsWatchers.get("r0").onError(Status.CANCELLED); + start.get(5000, TimeUnit.MILLISECONDS); + assertThat(selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().size()) + .isEqualTo(2); + ServerRoutingConfig realConfig = + selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().get(f1).get(); + assertThat(realConfig.virtualHosts()).isEmpty(); + assertThat(realConfig.interceptors()).isEmpty(); + + realConfig = selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().get(f0).get(); + assertThat(realConfig.virtualHosts()).isEqualTo(hcmVirtual.virtualHosts()); + assertThat(realConfig.interceptors()).isEqualTo(ImmutableMap.of()); + + xdsClient.deliverRdsUpdate("r0", + Collections.singletonList(createVirtualHost("virtual-host-1"))); + realConfig = selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().get(f1).get(); + assertThat(realConfig.virtualHosts()).isEqualTo( + Collections.singletonList(createVirtualHost("virtual-host-1"))); + assertThat(realConfig.interceptors()).isEqualTo(ImmutableMap.of()); + + xdsClient.rdsWatchers.get("r0").onError(Status.CANCELLED); + realConfig = selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().get(f1).get(); + assertThat(realConfig.virtualHosts()).isEqualTo( + Collections.singletonList(createVirtualHost("virtual-host-1"))); + assertThat(realConfig.interceptors()).isEqualTo(ImmutableMap.of()); + + xdsClient.rdsWatchers.get("r0").onResourceDoesNotExist("r0"); + realConfig = selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().get(f1).get(); + assertThat(realConfig.virtualHosts()).isEmpty(); + assertThat(realConfig.interceptors()).isEmpty(); + } + + @Test + public void error() throws Exception { + final SettableFuture start = SettableFuture.create(); + Executors.newSingleThreadExecutor().execute(new Runnable() { + @Override + public void run() { + try { + start.set(xdsServerWrapper.start()); + } catch (Exception ex) { + start.setException(ex); + } + } + }); + String ldsResource = xdsClient.ldsResource.get(5, TimeUnit.SECONDS); + xdsClient.ldsWatcher.onResourceDoesNotExist(ldsResource); + try { + start.get(5000, TimeUnit.MILLISECONDS); + fail("Start should throw exception"); + } catch (ExecutionException ex) { + assertThat(ex.getCause()).isInstanceOf(IOException.class); + } + verify(listener, times(1)).onNotServing(any(StatusException.class)); + verify(mockBuilder, times(1)).build(); + FilterChain filterChain0 = createFilterChain("filter-chain-0", createRds("rds")); + SslContextProviderSupplier sslSupplier0 = filterChain0.getSslContextProviderSupplier(); + xdsClient.deliverLdsUpdate(Collections.singletonList(filterChain0), null); + xdsClient.ldsWatcher.onError(Status.INTERNAL); + assertThat(selectorManager.getSelectorToUpdateSelector()) + .isSameInstanceAs(FilterChainSelector.NO_FILTER_CHAIN); + assertThat(xdsClient.rdsWatchers).isEmpty(); + verify(mockBuilder, times(1)).build(); + verify(listener, times(2)).onNotServing(any(StatusException.class)); + assertThat(sslSupplier0.isShutdown()).isFalse(); + + when(mockServer.start()).thenThrow(new IOException("error!")) + .thenReturn(mockServer); + FilterChain filterChain1 = createFilterChain("filter-chain-1", createRds("rds")); + SslContextProviderSupplier sslSupplier1 = filterChain1.getSslContextProviderSupplier(); + xdsClient.deliverLdsUpdate(Collections.singletonList(filterChain1), null); + assertThat(sslSupplier0.isShutdown()).isTrue(); + xdsClient.deliverRdsUpdate("rds", + Collections.singletonList(createVirtualHost("virtual-host-1"))); + RdsResourceWatcher saveRdsWatcher = xdsClient.rdsWatchers.get("rds"); + assertThat(executor.forwardNanos(RETRY_DELAY_NANOS)).isEqualTo(1); + verify(mockBuilder, times(1)).build(); + verify(mockServer, times(2)).start(); + verify(listener, times(1)).onServing(); + assertThat(selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().size()) + .isEqualTo(1); + ServerRoutingConfig realConfig = + selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().get(filterChain1).get(); + assertThat(realConfig.virtualHosts()).isEqualTo( + Collections.singletonList(createVirtualHost("virtual-host-1"))); + assertThat(realConfig.interceptors()).isEqualTo(ImmutableMap.of()); + + // xds update after start + xdsClient.deliverRdsUpdate("rds", + Collections.singletonList(createVirtualHost("virtual-host-2"))); + assertThat(sslSupplier1.isShutdown()).isFalse(); + xdsClient.ldsWatcher.onError(Status.DEADLINE_EXCEEDED); + verify(mockBuilder, times(1)).build(); + verify(mockServer, times(2)).start(); + verify(listener, times(2)).onNotServing(any(StatusException.class)); + assertThat(selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().size()) + .isEqualTo(1); + realConfig = selectorManager.getSelectorToUpdateSelector().getRoutingConfigs() + .get(filterChain1).get(); + assertThat(realConfig.virtualHosts()).isEqualTo( + Collections.singletonList(createVirtualHost("virtual-host-2"))); + assertThat(realConfig.interceptors()).isEqualTo(ImmutableMap.of()); + + assertThat(sslSupplier1.isShutdown()).isFalse(); + + // not serving after serving + xdsClient.ldsWatcher.onResourceDoesNotExist(ldsResource); + assertThat(xdsClient.rdsWatchers).isEmpty(); + verify(mockServer, times(3)).shutdown(); + when(mockServer.isShutdown()).thenReturn(true); + assertThat(selectorManager.getSelectorToUpdateSelector()) + .isSameInstanceAs(FilterChainSelector.NO_FILTER_CHAIN); + verify(listener, times(3)).onNotServing(any(StatusException.class)); + assertThat(sslSupplier1.isShutdown()).isTrue(); + // no op + saveRdsWatcher.onChanged( + new RdsUpdate(Collections.singletonList(createVirtualHost("virtual-host-1")))); + verify(mockBuilder, times(1)).build(); + verify(mockServer, times(2)).start(); + verify(listener, times(1)).onServing(); + + // cancel retry + when(mockServer.start()).thenThrow(new IOException("error1!")) + .thenThrow(new IOException("error2!")) + .thenReturn(mockServer); + FilterChain filterChain2 = createFilterChain("filter-chain-2", createRds("rds")); + SslContextProviderSupplier sslSupplier2 = filterChain2.getSslContextProviderSupplier(); + xdsClient.deliverLdsUpdate(Collections.singletonList(filterChain2), null); + xdsClient.deliverRdsUpdate("rds", + Collections.singletonList(createVirtualHost("virtual-host-1"))); + assertThat(sslSupplier1.isShutdown()).isTrue(); + verify(mockBuilder, times(2)).build(); + when(mockServer.isShutdown()).thenReturn(false); + verify(mockServer, times(3)).start(); + verify(listener, times(1)).onServing(); + verify(listener, times(3)).onNotServing(any(StatusException.class)); + assertThat(selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().size()) + .isEqualTo(1); + realConfig = selectorManager.getSelectorToUpdateSelector().getRoutingConfigs() + .get(filterChain2).get(); + assertThat(realConfig.virtualHosts()).isEqualTo( + Collections.singletonList(createVirtualHost("virtual-host-1"))); + assertThat(realConfig.interceptors()).isEqualTo(ImmutableMap.of()); + + assertThat(executor.numPendingTasks()).isEqualTo(1); + xdsClient.ldsWatcher.onResourceDoesNotExist(ldsResource); + verify(mockServer, times(4)).shutdown(); + verify(listener, times(4)).onNotServing(any(StatusException.class)); + when(mockServer.isShutdown()).thenReturn(true); + assertThat(executor.numPendingTasks()).isEqualTo(0); + assertThat(sslSupplier2.isShutdown()).isTrue(); + + // serving after not serving + FilterChain filterChain3 = createFilterChain("filter-chain-2", createRds("rds")); + SslContextProviderSupplier sslSupplier3 = filterChain3.getSslContextProviderSupplier(); + xdsClient.deliverLdsUpdate(Collections.singletonList(filterChain3), null); + xdsClient.deliverRdsUpdate("rds", + Collections.singletonList(createVirtualHost("virtual-host-1"))); + verify(mockBuilder, times(3)).build(); + verify(mockServer, times(4)).start(); + verify(listener, times(1)).onServing(); + when(mockServer.isShutdown()).thenReturn(false); + verify(listener, times(4)).onNotServing(any(StatusException.class)); + + assertThat(selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().size()) + .isEqualTo(1); + realConfig = selectorManager.getSelectorToUpdateSelector().getRoutingConfigs() + .get(filterChain3).get(); + assertThat(realConfig.virtualHosts()).isEqualTo( + Collections.singletonList(createVirtualHost("virtual-host-1"))); + assertThat(realConfig.interceptors()).isEqualTo(ImmutableMap.of()); + + xdsServerWrapper.shutdown(); + verify(mockServer, times(5)).shutdown(); + assertThat(sslSupplier3.isShutdown()).isTrue(); + when(mockServer.awaitTermination(anyLong(), any(TimeUnit.class))).thenReturn(true); + assertThat(xdsServerWrapper.awaitTermination(5, TimeUnit.SECONDS)).isTrue(); + } + + @Test + @SuppressWarnings("unchecked") + public void interceptor_success() throws Exception { + ArgumentCaptor interceptorCaptor = + ArgumentCaptor.forClass(ConfigApplyingInterceptor.class); + final SettableFuture start = SettableFuture.create(); + Executors.newSingleThreadExecutor().execute(new Runnable() { + @Override + public void run() { + try { + start.set(xdsServerWrapper.start()); + } catch (Exception ex) { + start.setException(ex); + } + } + }); + xdsClient.ldsResource.get(5, TimeUnit.SECONDS); + verify(mockBuilder).intercept(interceptorCaptor.capture()); + ConfigApplyingInterceptor interceptor = interceptorCaptor.getValue(); + RouteMatch routeMatch = + RouteMatch.create( + PathMatcher.fromPath("/FooService/barMethod", true), + Collections.emptyList(), null); + Route route = Route.forAction(routeMatch, null, + ImmutableMap.of()); + VirtualHost virtualHost = VirtualHost.create( + "v1", Collections.singletonList("foo.google.com"), Arrays.asList(route), + ImmutableMap.of()); + final List interceptorTrace = new ArrayList<>(); + ServerInterceptor interceptor0 = new ServerInterceptor() { + @Override + public ServerCall.Listener interceptCall(ServerCall call, + Metadata headers, ServerCallHandler next) { + interceptorTrace.add(0); + return next.startCall(call, headers); + } + }; + ServerRoutingConfig realConfig = ServerRoutingConfig.create( + ImmutableList.of(virtualHost), ImmutableMap.of(route, interceptor0)); + ServerCall serverCall = mock(ServerCall.class); + when(serverCall.getMethodDescriptor()).thenReturn(createMethod("FooService/barMethod")); + when(serverCall.getAttributes()).thenReturn( + Attributes.newBuilder().set(ATTR_SERVER_ROUTING_CONFIG, + new AtomicReference<>(realConfig)).build()); + when(serverCall.getAuthority()).thenReturn("foo.google.com"); + ServerCallHandler next = mock(ServerCallHandler.class); + interceptor.interceptCall(serverCall, new Metadata(), next); + verify(next).startCall(eq(serverCall), any(Metadata.class)); + assertThat(interceptorTrace).isEqualTo(Arrays.asList(0)); + } + + @Test + @SuppressWarnings("unchecked") + public void interceptor_virtualHostNotMatch() throws Exception { + ArgumentCaptor interceptorCaptor = + ArgumentCaptor.forClass(ConfigApplyingInterceptor.class); + final SettableFuture start = SettableFuture.create(); + Executors.newSingleThreadExecutor().execute(new Runnable() { + @Override + public void run() { + try { + start.set(xdsServerWrapper.start()); + } catch (Exception ex) { + start.setException(ex); + } + } + }); + xdsClient.ldsResource.get(5, TimeUnit.SECONDS); + verify(mockBuilder).intercept(interceptorCaptor.capture()); + ConfigApplyingInterceptor interceptor = interceptorCaptor.getValue(); + ServerRoutingConfig routingConfig = createRoutingConfig("/FooService/barMethod", + "foo.google.com", "filter-type-url"); + ServerCall serverCall = mock(ServerCall.class); + when(serverCall.getAttributes()).thenReturn( + Attributes.newBuilder().set(ATTR_SERVER_ROUTING_CONFIG, + new AtomicReference<>(routingConfig)).build()); + when(serverCall.getAuthority()).thenReturn("not-match.google.com"); + + Filter filter = mock(Filter.class); + when(filter.typeUrls()).thenReturn(new String[]{"filter-type-url"}); + filterRegistry.register(filter); + ServerCallHandler next = mock(ServerCallHandler.class); + interceptor.interceptCall(serverCall, new Metadata(), next); + verify(next, never()).startCall(any(ServerCall.class), any(Metadata.class)); + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); + verify(serverCall).close(statusCaptor.capture(), any(Metadata.class)); + Status status = statusCaptor.getValue(); + assertThat(status.getCode()).isEqualTo(Status.UNAVAILABLE.getCode()); + assertThat(status.getDescription()).isEqualTo("Could not find xDS virtual host matching RPC"); + } + + @Test + @SuppressWarnings("unchecked") + public void interceptor_routeNotMatch() throws Exception { + ArgumentCaptor interceptorCaptor = + ArgumentCaptor.forClass(ConfigApplyingInterceptor.class); + final SettableFuture start = SettableFuture.create(); + Executors.newSingleThreadExecutor().execute(new Runnable() { + @Override + public void run() { + try { + start.set(xdsServerWrapper.start()); + } catch (Exception ex) { + start.setException(ex); + } + } + }); + xdsClient.ldsResource.get(5, TimeUnit.SECONDS); + verify(mockBuilder).intercept(interceptorCaptor.capture()); + ConfigApplyingInterceptor interceptor = interceptorCaptor.getValue(); + ServerRoutingConfig routingConfig = createRoutingConfig("/FooService/barMethod", + "foo.google.com", "filter-type-url"); + ServerCall serverCall = mock(ServerCall.class); + when(serverCall.getAttributes()).thenReturn( + Attributes.newBuilder() + .set(ATTR_SERVER_ROUTING_CONFIG, new AtomicReference<>(routingConfig)).build()); + when(serverCall.getMethodDescriptor()).thenReturn(createMethod("NotMatchMethod")); + when(serverCall.getAuthority()).thenReturn("foo.google.com"); + + Filter filter = mock(Filter.class); + when(filter.typeUrls()).thenReturn(new String[]{"filter-type-url"}); + filterRegistry.register(filter); + ServerCallHandler next = mock(ServerCallHandler.class); + interceptor.interceptCall(serverCall, new Metadata(), next); + verify(next, never()).startCall(any(ServerCall.class), any(Metadata.class)); + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); + verify(serverCall).close(statusCaptor.capture(), any(Metadata.class)); + Status status = statusCaptor.getValue(); + assertThat(status.getCode()).isEqualTo(Status.UNAVAILABLE.getCode()); + assertThat(status.getDescription()).isEqualTo("Could not find xDS route matching RPC"); + } + + @Test + @SuppressWarnings("unchecked") + public void interceptor_invalidRouteAction() throws Exception { + ArgumentCaptor interceptorCaptor = + ArgumentCaptor.forClass(ConfigApplyingInterceptor.class); + final SettableFuture start = SettableFuture.create(); + Executors.newSingleThreadExecutor().execute(new Runnable() { + @Override + public void run() { + try { + start.set(xdsServerWrapper.start()); + } catch (Exception ex) { + start.setException(ex); + } + } + }); + xdsClient.ldsResource.get(5, TimeUnit.SECONDS); + verify(mockBuilder).intercept(interceptorCaptor.capture()); + ConfigApplyingInterceptor interceptor = interceptorCaptor.getValue(); + ServerRoutingConfig routingConfig = createRoutingConfig("/FooService/barMethod", + "foo.google.com", "filter-type-url", Route.RouteAction.forCluster( + "cluster", Collections.emptyList(), null, null + )); + ServerCall serverCall = mock(ServerCall.class); + when(serverCall.getAttributes()).thenReturn( + Attributes.newBuilder() + .set(ATTR_SERVER_ROUTING_CONFIG, new AtomicReference<>(routingConfig)).build()); + when(serverCall.getMethodDescriptor()).thenReturn(createMethod("FooService/barMethod")); + when(serverCall.getAuthority()).thenReturn("foo.google.com"); + + Filter filter = mock(Filter.class); + when(filter.typeUrls()).thenReturn(new String[]{"filter-type-url"}); + filterRegistry.register(filter); + ServerCallHandler next = mock(ServerCallHandler.class); + interceptor.interceptCall(serverCall, new Metadata(), next); + verify(next, never()).startCall(any(ServerCall.class), any(Metadata.class)); + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); + verify(serverCall).close(statusCaptor.capture(), any(Metadata.class)); + Status status = statusCaptor.getValue(); + assertThat(status.getCode()).isEqualTo(Status.UNAVAILABLE.getCode()); + assertThat(status.getDescription()).isEqualTo("Invalid xDS route action for matching " + + "route: only Route.non_forwarding_action should be allowed."); + } + + @Test + @SuppressWarnings("unchecked") + public void interceptor_failingRouterConfig() throws Exception { + ArgumentCaptor interceptorCaptor = + ArgumentCaptor.forClass(ConfigApplyingInterceptor.class); + final SettableFuture start = SettableFuture.create(); + Executors.newSingleThreadExecutor().execute(new Runnable() { + @Override + public void run() { + try { + start.set(xdsServerWrapper.start()); + } catch (Exception ex) { + start.setException(ex); + } + } + }); + xdsClient.ldsResource.get(5, TimeUnit.SECONDS); + verify(mockBuilder).intercept(interceptorCaptor.capture()); + ConfigApplyingInterceptor interceptor = interceptorCaptor.getValue(); + ServerCall serverCall = mock(ServerCall.class); + + when(serverCall.getAttributes()).thenReturn( + Attributes.newBuilder().set(ATTR_SERVER_ROUTING_CONFIG, + new AtomicReference<>(ServerRoutingConfig.FAILING_ROUTING_CONFIG)).build()); + + ServerCallHandler next = mock(ServerCallHandler.class); + interceptor.interceptCall(serverCall, new Metadata(), next); + verify(next, never()).startCall(any(ServerCall.class), any(Metadata.class)); + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); + verify(serverCall).close(statusCaptor.capture(), any(Metadata.class)); + Status status = statusCaptor.getValue(); + assertThat(status.getCode()).isEqualTo(Status.UNAVAILABLE.getCode()); + assertThat(status.getDescription()).isEqualTo( + "Missing or broken xDS routing config: RDS config unavailable."); + } + + @Test + @SuppressWarnings("unchecked") + public void buildInterceptor_inline() throws Exception { + final SettableFuture start = SettableFuture.create(); + Executors.newSingleThreadExecutor().execute(new Runnable() { + @Override + public void run() { + try { + start.set(xdsServerWrapper.start()); + } catch (Exception ex) { + start.setException(ex); + } + } + }); + xdsClient.ldsResource.get(5, TimeUnit.SECONDS); + RouteMatch routeMatch = + RouteMatch.create( + PathMatcher.fromPath("/FooService/barMethod", true), + Collections.emptyList(), null); + Filter filter = mock(Filter.class, withSettings() + .extraInterfaces(ServerInterceptorBuilder.class)); + when(filter.typeUrls()).thenReturn(new String[]{"filter-type-url"}); + filterRegistry.register(filter); + FilterConfig f0 = mock(FilterConfig.class); + FilterConfig f0Override = mock(FilterConfig.class); + when(f0.typeUrl()).thenReturn("filter-type-url"); + final List interceptorTrace = new ArrayList<>(); + ServerInterceptor interceptor0 = new ServerInterceptor() { + @Override + public ServerCall.Listener interceptCall(ServerCall call, + Metadata headers, ServerCallHandler next) { + interceptorTrace.add(0); + return next.startCall(call, headers); + } + }; + ServerInterceptor interceptor1 = new ServerInterceptor() { + @Override + public ServerCall.Listener interceptCall(ServerCall call, + Metadata headers, ServerCallHandler next) { + interceptorTrace.add(1); + return next.startCall(call, headers); + } + }; + when(((ServerInterceptorBuilder)filter).buildServerInterceptor(f0, null)) + .thenReturn(interceptor0); + when(((ServerInterceptorBuilder)filter).buildServerInterceptor(f0, f0Override)) + .thenReturn(interceptor1); + Route route = Route.forAction(routeMatch, null, + ImmutableMap.of()); + VirtualHost virtualHost = VirtualHost.create( + "v1", Collections.singletonList("foo.google.com"), Arrays.asList(route), + ImmutableMap.of("filter-config-name-0", f0Override)); + HttpConnectionManager hcmVirtual = HttpConnectionManager.forVirtualHosts( + 0L, Collections.singletonList(virtualHost), + Arrays.asList(new NamedFilterConfig("filter-config-name-0", f0), + new NamedFilterConfig("filter-config-name-1", f0))); + EnvoyServerProtoData.FilterChain filterChain = createFilterChain("filter-chain-0", hcmVirtual); + xdsClient.deliverLdsUpdate(Collections.singletonList(filterChain), null); + start.get(5000, TimeUnit.MILLISECONDS); + verify(mockServer).start(); + assertThat(selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().size()) + .isEqualTo(1); + ServerInterceptor realInterceptor = selectorManager.getSelectorToUpdateSelector() + .getRoutingConfigs().get(filterChain).get().interceptors().get(route); + assertThat(realInterceptor).isNotNull(); + + ServerCall serverCall = mock(ServerCall.class); + ServerCallHandler mockNext = mock(ServerCallHandler.class); + final ServerCall.Listener listener = new ServerCall.Listener() {}; + when(mockNext.startCall(any(ServerCall.class), any(Metadata.class))).thenReturn(listener); + realInterceptor.interceptCall(serverCall, new Metadata(), mockNext); + assertThat(interceptorTrace).isEqualTo(Arrays.asList(1, 0)); + verify(mockNext).startCall(eq(serverCall), any(Metadata.class)); + } + + @Test + @SuppressWarnings("unchecked") + public void buildInterceptor_rds() throws Exception { + final SettableFuture start = SettableFuture.create(); + Executors.newSingleThreadExecutor().execute(new Runnable() { + @Override + public void run() { + try { + start.set(xdsServerWrapper.start()); + } catch (Exception ex) { + start.setException(ex); + } + } + }); + xdsClient.ldsResource.get(5, TimeUnit.SECONDS); + + Filter filter = mock(Filter.class, withSettings() + .extraInterfaces(ServerInterceptorBuilder.class)); + when(filter.typeUrls()).thenReturn(new String[]{"filter-type-url"}); + filterRegistry.register(filter); + FilterConfig f0 = mock(FilterConfig.class); + FilterConfig f0Override = mock(FilterConfig.class); + when(f0.typeUrl()).thenReturn("filter-type-url"); + final List interceptorTrace = new ArrayList<>(); + ServerInterceptor interceptor0 = new ServerInterceptor() { + @Override + public ServerCall.Listener interceptCall(ServerCall call, + Metadata headers, ServerCallHandler next) { + interceptorTrace.add(0); + return next.startCall(call, headers); + } + }; + ServerInterceptor interceptor1 = new ServerInterceptor() { + @Override + public ServerCall.Listener interceptCall(ServerCall call, + Metadata headers, ServerCallHandler next) { + interceptorTrace.add(1); + return next.startCall(call, headers); + } + }; + when(((ServerInterceptorBuilder)filter).buildServerInterceptor(f0, null)) + .thenReturn(interceptor0); + when(((ServerInterceptorBuilder)filter).buildServerInterceptor(f0, f0Override)) + .thenReturn(interceptor1); + RouteMatch routeMatch = + RouteMatch.create( + PathMatcher.fromPath("/FooService/barMethod", true), + Collections.emptyList(), null); + + HttpConnectionManager rdsHcm = HttpConnectionManager.forRdsName(0L, "r0", + Arrays.asList(new NamedFilterConfig("filter-config-name-0", f0), + new NamedFilterConfig("filter-config-name-1", f0))); + EnvoyServerProtoData.FilterChain filterChain = createFilterChain("filter-chain-0", rdsHcm); + xdsClient.deliverLdsUpdate(Collections.singletonList(filterChain), null); + Route route = Route.forAction(routeMatch, null, + ImmutableMap.of()); + VirtualHost virtualHost = VirtualHost.create( + "v1", Collections.singletonList("foo.google.com"), Arrays.asList(route), + ImmutableMap.of("filter-config-name-0", f0Override)); + xdsClient.rdsCount.await(5, TimeUnit.SECONDS); + xdsClient.deliverRdsUpdate("r0", Collections.singletonList(virtualHost)); + start.get(5000, TimeUnit.MILLISECONDS); + verify(mockServer).start(); + assertThat(selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().size()) + .isEqualTo(1); + ServerInterceptor realInterceptor = selectorManager.getSelectorToUpdateSelector() + .getRoutingConfigs().get(filterChain).get().interceptors().get(route); + assertThat(realInterceptor).isNotNull(); + + ServerCall serverCall = mock(ServerCall.class); + ServerCallHandler mockNext = mock(ServerCallHandler.class); + final ServerCall.Listener listener = new ServerCall.Listener() {}; + when(mockNext.startCall(any(ServerCall.class), any(Metadata.class))).thenReturn(listener); + realInterceptor.interceptCall(serverCall, new Metadata(), mockNext); + assertThat(interceptorTrace).isEqualTo(Arrays.asList(1, 0)); + verify(mockNext).startCall(eq(serverCall), any(Metadata.class)); + + virtualHost = VirtualHost.create( + "v1", Collections.singletonList("foo.google.com"), Arrays.asList(route), + ImmutableMap.of()); + xdsClient.deliverRdsUpdate("r0", Collections.singletonList(virtualHost)); + realInterceptor = selectorManager.getSelectorToUpdateSelector().getRoutingConfigs() + .get(filterChain).get().interceptors().get(route); + assertThat(realInterceptor).isNotNull(); + interceptorTrace.clear(); + realInterceptor.interceptCall(serverCall, new Metadata(), mockNext); + assertThat(interceptorTrace).isEqualTo(Arrays.asList(0, 0)); + verify(mockNext, times(2)).startCall(eq(serverCall), any(Metadata.class)); + + xdsClient.rdsWatchers.get("r0").onResourceDoesNotExist("r0"); + assertThat(selectorManager.getSelectorToUpdateSelector().getRoutingConfigs() + .get(filterChain).get()).isEqualTo(noopConfig); + } + + private static FilterChain createFilterChain(String name, HttpConnectionManager hcm) { + return new EnvoyServerProtoData.FilterChain(name, createMatch(), + hcm, createTls(), tlsContextManager); + } + + private static VirtualHost createVirtualHost(String name) { + return VirtualHost.create( + name, Collections.singletonList("auth"), new ArrayList(), + ImmutableMap.of()); + } + + private static HttpConnectionManager createRds(String name) { + return createRds(name, null); + } + + private static HttpConnectionManager createRds(String name, FilterConfig filterConfig) { + return HttpConnectionManager.forRdsName(0L, name, + Arrays.asList(new NamedFilterConfig("named-config-" + name, filterConfig))); + } + + private static EnvoyServerProtoData.FilterChainMatch createMatch() { + return new EnvoyServerProtoData.FilterChainMatch( + 0, + Arrays.asList(), + Arrays.asList(), + Arrays.asList(), + EnvoyServerProtoData.ConnectionSourceType.ANY, + Arrays.asList(), + Arrays.asList(), + null); + } + + private static ServerRoutingConfig createRoutingConfig(String path, String domain, + String filterType) { + return createRoutingConfig(path, domain, filterType, null); + } + + private static ServerRoutingConfig createRoutingConfig(String path, String domain, + String filterType, Route.RouteAction action) { + RouteMatch routeMatch = + RouteMatch.create( + PathMatcher.fromPath(path, true), + Collections.emptyList(), null); + VirtualHost virtualHost = VirtualHost.create( + "v1", Collections.singletonList(domain), + Arrays.asList(Route.forAction(routeMatch, action, + ImmutableMap.of())), + Collections.emptyMap()); + FilterConfig f0 = mock(FilterConfig.class); + when(f0.typeUrl()).thenReturn(filterType); + return ServerRoutingConfig.create(ImmutableList.of(virtualHost), + ImmutableMap.of() + ); + } + + private static MethodDescriptor createMethod(String path) { + return MethodDescriptor.newBuilder() + .setType(MethodDescriptor.MethodType.UNKNOWN) + .setFullMethodName(path) + .setRequestMarshaller(TestMethodDescriptors.voidMarshaller()) + .setResponseMarshaller(TestMethodDescriptors.voidMarshaller()) + .build(); + } + + private static EnvoyServerProtoData.DownstreamTlsContext createTls() { + return CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); + } +} diff --git a/xds/src/test/java/io/grpc/xds/internal/MatcherTest.java b/xds/src/test/java/io/grpc/xds/internal/MatcherTest.java index 4fb4acc41f6..93a9b7087d6 100644 --- a/xds/src/test/java/io/grpc/xds/internal/MatcherTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/MatcherTest.java @@ -127,45 +127,58 @@ public void headerMatcher() { HeaderMatcher matcher = HeaderMatcher.forExactValue("version", "v1", false); assertThat(matcher.matches("v1")).isTrue(); assertThat(matcher.matches("v2")).isFalse(); + assertThat(matcher.matches(null)).isFalse(); matcher = HeaderMatcher.forExactValue("version", "v1", true); assertThat(matcher.matches("v1")).isFalse(); assertThat(matcher.matches( "v2")).isTrue(); + assertThat(matcher.matches(null)).isFalse(); matcher = HeaderMatcher.forPresent("version", true, false); assertThat(matcher.matches("any")).isTrue(); assertThat(matcher.matches(null)).isFalse(); matcher = HeaderMatcher.forPresent("version", true, true); assertThat(matcher.matches("version")).isFalse(); + assertThat(matcher.matches(null)).isTrue(); matcher = HeaderMatcher.forPresent("version", false, true); assertThat(matcher.matches("tag")).isTrue(); + assertThat(matcher.matches(null)).isFalse(); matcher = HeaderMatcher.forPresent("version", false, false); assertThat(matcher.matches("tag")).isFalse(); + assertThat(matcher.matches(null)).isTrue(); matcher = HeaderMatcher.forPrefix("version", "v2", false); assertThat(matcher.matches("v22")).isTrue(); + assertThat(matcher.matches(null)).isFalse(); matcher = HeaderMatcher.forPrefix("version", "v2", true); assertThat(matcher.matches("v22")).isFalse(); + assertThat(matcher.matches(null)).isFalse(); matcher = HeaderMatcher.forSuffix("version", "v1", false); assertThat(matcher.matches("xv1")).isTrue(); assertThat(matcher.matches("v1x")).isFalse(); + assertThat(matcher.matches(null)).isFalse(); matcher = HeaderMatcher.forSuffix("version", "v2", true); assertThat(matcher.matches("xv1")).isTrue(); assertThat(matcher.matches("1v2")).isFalse(); + assertThat(matcher.matches(null)).isFalse(); matcher = HeaderMatcher.forSafeRegEx("version", Pattern.compile("v2.*"), false); assertThat(matcher.matches("v2..")).isTrue(); assertThat(matcher.matches("v1")).isFalse(); + assertThat(matcher.matches(null)).isFalse(); matcher = HeaderMatcher.forSafeRegEx("version", Pattern.compile("v1\\..*"), true); assertThat(matcher.matches("v1.43")).isFalse(); assertThat(matcher.matches("v2")).isTrue(); + assertThat(matcher.matches(null)).isFalse(); matcher = HeaderMatcher.forRange("version", Range.create(8080L, 8090L), false); assertThat(matcher.matches("8080")).isTrue(); assertThat(matcher.matches("1")).isFalse(); + assertThat(matcher.matches(null)).isFalse(); matcher = HeaderMatcher.forRange("version", Range.create(8080L, 8090L), true); assertThat(matcher.matches("1")).isTrue(); assertThat(matcher.matches("8080")).isFalse(); + assertThat(matcher.matches(null)).isFalse(); } } diff --git a/xds/src/test/java/io/grpc/xds/internal/certprovider/CertProviderClientSslContextProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/certprovider/CertProviderClientSslContextProviderTest.java index 00b29014648..1eed5488aa0 100644 --- a/xds/src/test/java/io/grpc/xds/internal/certprovider/CertProviderClientSslContextProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/certprovider/CertProviderClientSslContextProviderTest.java @@ -87,6 +87,27 @@ private CertProviderClientSslContextProvider getSslContextProvider( bootstrapInfo.getCertProviders()); } + /** Helper method to build CertProviderClientSslContextProvider. */ + private CertProviderClientSslContextProvider getNewSslContextProvider( + String certInstanceName, + String rootInstanceName, + Bootstrapper.BootstrapInfo bootstrapInfo, + Iterable alpnProtocols, + CertificateValidationContext staticCertValidationContext) { + EnvoyServerProtoData.UpstreamTlsContext upstreamTlsContext = + CommonTlsContextTestsUtil.buildNewUpstreamTlsContextForCertProviderInstance( + certInstanceName, + "cert-default", + rootInstanceName, + "root-default", + alpnProtocols, + staticCertValidationContext); + return certProviderClientSslContextProviderFactory.getProvider( + upstreamTlsContext, + bootstrapInfo.getNode().toEnvoyProtoNode(), + bootstrapInfo.getCertProviders()); + } + @Test public void testProviderForClient_mtls() throws Exception { final CertificateProvider.DistributorWatcher[] watcherCaptor = @@ -150,6 +171,69 @@ public void testProviderForClient_mtls() throws Exception { assertThat(testCallback1.updatedSslContext).isNotSameInstanceAs(testCallback.updatedSslContext); } + @Test + public void testProviderForClient_mtls_newXds() throws Exception { + final CertificateProvider.DistributorWatcher[] watcherCaptor = + new CertificateProvider.DistributorWatcher[1]; + TestCertificateProvider.createAndRegisterProviderProvider( + certificateProviderRegistry, watcherCaptor, "testca", 0); + CertProviderClientSslContextProvider provider = + getNewSslContextProvider( + "gcp_id", + "gcp_id", + CommonBootstrapperTestUtils.getTestBootstrapInfo(), + /* alpnProtocols= */ null, + /* staticCertValidationContext= */ null); + + assertThat(provider.savedKey).isNull(); + assertThat(provider.savedCertChain).isNull(); + assertThat(provider.savedTrustedRoots).isNull(); + assertThat(provider.getSslContext()).isNull(); + + // now generate cert update + watcherCaptor[0].updateCertificate( + CommonCertProviderTestUtils.getPrivateKey(CLIENT_KEY_FILE), + ImmutableList.of(getCertFromResourceName(CLIENT_PEM_FILE))); + assertThat(provider.savedKey).isNotNull(); + assertThat(provider.savedCertChain).isNotNull(); + assertThat(provider.getSslContext()).isNull(); + + // now generate root cert update + watcherCaptor[0].updateTrustedRoots(ImmutableList.of(getCertFromResourceName(CA_PEM_FILE))); + assertThat(provider.getSslContext()).isNotNull(); + assertThat(provider.savedKey).isNull(); + assertThat(provider.savedCertChain).isNull(); + assertThat(provider.savedTrustedRoots).isNull(); + + TestCallback testCallback = + CommonTlsContextTestsUtil.getValueThruCallback(provider); + + doChecksOnSslContext(false, testCallback.updatedSslContext, /* expectedApnProtos= */ null); + TestCallback testCallback1 = + CommonTlsContextTestsUtil.getValueThruCallback(provider); + assertThat(testCallback1.updatedSslContext).isSameInstanceAs(testCallback.updatedSslContext); + + // just do root cert update: sslContext should still be the same + watcherCaptor[0].updateTrustedRoots( + ImmutableList.of(getCertFromResourceName(SERVER_0_PEM_FILE))); + assertThat(provider.savedKey).isNull(); + assertThat(provider.savedCertChain).isNull(); + assertThat(provider.savedTrustedRoots).isNotNull(); + testCallback1 = CommonTlsContextTestsUtil.getValueThruCallback(provider); + assertThat(testCallback1.updatedSslContext).isSameInstanceAs(testCallback.updatedSslContext); + + // now update id cert: sslContext should be updated i.e.different from the previous one + watcherCaptor[0].updateCertificate( + CommonCertProviderTestUtils.getPrivateKey(SERVER_1_KEY_FILE), + ImmutableList.of(getCertFromResourceName(SERVER_1_PEM_FILE))); + assertThat(provider.savedKey).isNull(); + assertThat(provider.savedCertChain).isNull(); + assertThat(provider.savedTrustedRoots).isNull(); + assertThat(provider.getSslContext()).isNotNull(); + testCallback1 = CommonTlsContextTestsUtil.getValueThruCallback(provider); + assertThat(testCallback1.updatedSslContext).isNotSameInstanceAs(testCallback.updatedSslContext); + } + @Test public void testProviderForClient_queueExecutor() throws Exception { final CertificateProvider.DistributorWatcher[] watcherCaptor = diff --git a/xds/src/test/java/io/grpc/xds/internal/certprovider/CertProviderServerSslContextProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/certprovider/CertProviderServerSslContextProviderTest.java index ef801ccc2c1..783ce2b11f7 100644 --- a/xds/src/test/java/io/grpc/xds/internal/certprovider/CertProviderServerSslContextProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/certprovider/CertProviderServerSslContextProviderTest.java @@ -31,12 +31,14 @@ import com.google.common.util.concurrent.MoreExecutors; import io.envoyproxy.envoy.config.core.v3.DataSource; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; +import io.envoyproxy.envoy.type.matcher.v3.StringMatcher; import io.grpc.xds.Bootstrapper; import io.grpc.xds.CommonBootstrapperTestUtils; import io.grpc.xds.EnvoyServerProtoData; import io.grpc.xds.internal.certprovider.CertProviderClientSslContextProviderTest.QueuedExecutor; import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil; import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.TestCallback; +import java.util.Arrays; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -81,6 +83,30 @@ private CertProviderServerSslContextProvider getSslContextProvider( bootstrapInfo.getCertProviders()); } + /** Helper method to build CertProviderServerSslContextProvider. */ + private CertProviderServerSslContextProvider getNewSslContextProvider( + String certInstanceName, + String rootInstanceName, + Bootstrapper.BootstrapInfo bootstrapInfo, + Iterable alpnProtocols, + CertificateValidationContext staticCertValidationContext, + boolean requireClientCert) { + EnvoyServerProtoData.DownstreamTlsContext downstreamTlsContext = + CommonTlsContextTestsUtil.buildNewDownstreamTlsContextForCertProviderInstance( + certInstanceName, + "cert-default", + rootInstanceName, + "root-default", + alpnProtocols, + staticCertValidationContext, + requireClientCert); + return certProviderServerSslContextProviderFactory.getProvider( + downstreamTlsContext, + bootstrapInfo.getNode().toEnvoyProtoNode(), + bootstrapInfo.getCertProviders()); + } + + @Test public void testProviderForServer_mtls() throws Exception { final CertificateProvider.DistributorWatcher[] watcherCaptor = @@ -145,6 +171,74 @@ public void testProviderForServer_mtls() throws Exception { assertThat(testCallback1.updatedSslContext).isNotSameInstanceAs(testCallback.updatedSslContext); } + @Test + public void testProviderForServer_mtls_newXds() throws Exception { + final CertificateProvider.DistributorWatcher[] watcherCaptor = + new CertificateProvider.DistributorWatcher[1]; + TestCertificateProvider.createAndRegisterProviderProvider( + certificateProviderRegistry, watcherCaptor, "testca", 0); + CertificateValidationContext staticCertValidationContext = + CertificateValidationContext.newBuilder().addAllMatchSubjectAltNames(Arrays + .asList(StringMatcher.newBuilder().setExact("foo.com").build(), + StringMatcher.newBuilder().setExact("bar.com").build())).build(); + CertProviderServerSslContextProvider provider = + getNewSslContextProvider( + "gcp_id", + "gcp_id", + CommonBootstrapperTestUtils.getTestBootstrapInfo(), + /* alpnProtocols= */ null, + staticCertValidationContext, + /* requireClientCert= */ true); + + assertThat(provider.savedKey).isNull(); + assertThat(provider.savedCertChain).isNull(); + assertThat(provider.savedTrustedRoots).isNull(); + assertThat(provider.getSslContext()).isNull(); + + // now generate cert update + watcherCaptor[0].updateCertificate( + CommonCertProviderTestUtils.getPrivateKey(SERVER_0_KEY_FILE), + ImmutableList.of(getCertFromResourceName(SERVER_0_PEM_FILE))); + assertThat(provider.savedKey).isNotNull(); + assertThat(provider.savedCertChain).isNotNull(); + assertThat(provider.getSslContext()).isNull(); + + // now generate root cert update + watcherCaptor[0].updateTrustedRoots(ImmutableList.of(getCertFromResourceName(CA_PEM_FILE))); + assertThat(provider.getSslContext()).isNotNull(); + assertThat(provider.savedKey).isNull(); + assertThat(provider.savedCertChain).isNull(); + assertThat(provider.savedTrustedRoots).isNull(); + + TestCallback testCallback = + CommonTlsContextTestsUtil.getValueThruCallback(provider); + + doChecksOnSslContext(true, testCallback.updatedSslContext, /* expectedApnProtos= */ null); + TestCallback testCallback1 = + CommonTlsContextTestsUtil.getValueThruCallback(provider); + assertThat(testCallback1.updatedSslContext).isSameInstanceAs(testCallback.updatedSslContext); + + // just do root cert update: sslContext should still be the same + watcherCaptor[0].updateTrustedRoots( + ImmutableList.of(getCertFromResourceName(CLIENT_PEM_FILE))); + assertThat(provider.savedKey).isNull(); + assertThat(provider.savedCertChain).isNull(); + assertThat(provider.savedTrustedRoots).isNotNull(); + testCallback1 = CommonTlsContextTestsUtil.getValueThruCallback(provider); + assertThat(testCallback1.updatedSslContext).isSameInstanceAs(testCallback.updatedSslContext); + + // now update id cert: sslContext should be updated i.e.different from the previous one + watcherCaptor[0].updateCertificate( + CommonCertProviderTestUtils.getPrivateKey(SERVER_1_KEY_FILE), + ImmutableList.of(getCertFromResourceName(SERVER_1_PEM_FILE))); + assertThat(provider.savedKey).isNull(); + assertThat(provider.savedCertChain).isNull(); + assertThat(provider.savedTrustedRoots).isNull(); + assertThat(provider.getSslContext()).isNotNull(); + testCallback1 = CommonTlsContextTestsUtil.getValueThruCallback(provider); + assertThat(testCallback1.updatedSslContext).isNotSameInstanceAs(testCallback.updatedSslContext); + } + @Test public void testProviderForServer_queueExecutor() throws Exception { final CertificateProvider.DistributorWatcher[] watcherCaptor = diff --git a/xds/src/test/java/io/grpc/xds/internal/rbac/engine/GrpcAuthorizationEngineTest.java b/xds/src/test/java/io/grpc/xds/internal/rbac/engine/GrpcAuthorizationEngineTest.java index 504c9e8df2a..626a4cfc275 100644 --- a/xds/src/test/java/io/grpc/xds/internal/rbac/engine/GrpcAuthorizationEngineTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/rbac/engine/GrpcAuthorizationEngineTest.java @@ -16,12 +16,14 @@ package io.grpc.xds.internal.rbac.engine; +import static com.google.common.base.Charsets.US_ASCII; import static com.google.common.truth.Truth.assertThat; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; import com.google.common.collect.ImmutableList; +import com.google.common.io.BaseEncoding; import io.grpc.Attributes; import io.grpc.Grpc; import io.grpc.Metadata; @@ -177,6 +179,71 @@ public void headerMatcher() { assertThat(decision.decision()).isEqualTo(Action.DENY); } + @Test + public void headerMatcher_binaryHeader() { + AuthHeaderMatcher headerMatcher = new AuthHeaderMatcher(Matchers.HeaderMatcher + .forExactValue(HEADER_KEY + Metadata.BINARY_HEADER_SUFFIX, + BaseEncoding.base64().omitPadding().encode(HEADER_VALUE.getBytes(US_ASCII)), false)); + OrMatcher principal = OrMatcher.create(headerMatcher); + OrMatcher permission = OrMatcher.create( + new InvertMatcher(new DestinationPortMatcher(PORT + 1))); + PolicyMatcher policyMatcher = new PolicyMatcher(POLICY_NAME, permission, principal); + GrpcAuthorizationEngine engine = new GrpcAuthorizationEngine( + new AuthConfig(Collections.singletonList(policyMatcher), Action.ALLOW)); + Metadata metadata = new Metadata(); + metadata.put(Metadata.Key.of(HEADER_KEY + Metadata.BINARY_HEADER_SUFFIX, + Metadata.BINARY_BYTE_MARSHALLER), HEADER_VALUE.getBytes(US_ASCII)); + AuthDecision decision = engine.evaluate(metadata, serverCall); + assertThat(decision.decision()).isEqualTo(Action.ALLOW); + assertThat(decision.matchingPolicyName()).isEqualTo(POLICY_NAME); + } + + @Test + public void headerMatcher_hardcodePostMethod() { + AuthHeaderMatcher headerMatcher = new AuthHeaderMatcher(Matchers.HeaderMatcher + .forExactValue(":method", "POST", false)); + OrMatcher principal = OrMatcher.create(headerMatcher); + OrMatcher permission = OrMatcher.create( + new InvertMatcher(new DestinationPortMatcher(PORT + 1))); + PolicyMatcher policyMatcher = new PolicyMatcher(POLICY_NAME, permission, principal); + GrpcAuthorizationEngine engine = new GrpcAuthorizationEngine( + new AuthConfig(Collections.singletonList(policyMatcher), Action.ALLOW)); + AuthDecision decision = engine.evaluate(new Metadata(), serverCall); + assertThat(decision.decision()).isEqualTo(Action.ALLOW); + assertThat(decision.matchingPolicyName()).isEqualTo(POLICY_NAME); + } + + @Test + public void headerMatcher_pathHeader() { + AuthHeaderMatcher headerMatcher = new AuthHeaderMatcher(Matchers.HeaderMatcher + .forExactValue(":path", "/" + PATH, false)); + OrMatcher principal = OrMatcher.create(headerMatcher); + OrMatcher permission = OrMatcher.create( + new InvertMatcher(new DestinationPortMatcher(PORT + 1))); + PolicyMatcher policyMatcher = new PolicyMatcher(POLICY_NAME, permission, principal); + GrpcAuthorizationEngine engine = new GrpcAuthorizationEngine( + new AuthConfig(Collections.singletonList(policyMatcher), Action.ALLOW)); + AuthDecision decision = engine.evaluate(HEADER, serverCall); + assertThat(decision.decision()).isEqualTo(Action.ALLOW); + assertThat(decision.matchingPolicyName()).isEqualTo(POLICY_NAME); + } + + @Test + public void headerMatcher_aliasAuthorityAndHost() { + AuthHeaderMatcher headerMatcher = new AuthHeaderMatcher(Matchers.HeaderMatcher + .forExactValue("Host", "google.com", false)); + OrMatcher principal = OrMatcher.create(headerMatcher); + OrMatcher permission = OrMatcher.create( + new InvertMatcher(new DestinationPortMatcher(PORT + 1))); + PolicyMatcher policyMatcher = new PolicyMatcher(POLICY_NAME, permission, principal); + GrpcAuthorizationEngine engine = new GrpcAuthorizationEngine( + new AuthConfig(Collections.singletonList(policyMatcher), Action.ALLOW)); + when(serverCall.getAuthority()).thenReturn("google.com"); + AuthDecision decision = engine.evaluate(new Metadata(), serverCall); + assertThat(decision.decision()).isEqualTo(Action.ALLOW); + assertThat(decision.matchingPolicyName()).isEqualTo(POLICY_NAME); + } + @Test public void pathMatcher() { PathMatcher pathMatcher = new PathMatcher(STRING_MATCHER); diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactoryTest.java b/xds/src/test/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactoryTest.java index dfadee957c1..06a3198b263 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactoryTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactoryTest.java @@ -258,6 +258,7 @@ static void verifyWatcher( .isSameInstanceAs(sslContextProvider); } + @SuppressWarnings("deprecation") static CommonTlsContext.Builder addFilenames( CommonTlsContext.Builder builder, String certChain, String privateKey, String trustCa) { TlsCertificate tlsCert = diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/CommonTlsContextTestsUtil.java b/xds/src/test/java/io/grpc/xds/internal/sds/CommonTlsContextTestsUtil.java index 2914e5f3937..840cced424f 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/CommonTlsContextTestsUtil.java +++ b/xds/src/test/java/io/grpc/xds/internal/sds/CommonTlsContextTestsUtil.java @@ -22,6 +22,7 @@ import com.google.common.io.CharStreams; import com.google.common.util.concurrent.MoreExecutors; import com.google.protobuf.BoolValue; +import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateProviderPluginInstance; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.CertificateProviderInstance; @@ -62,6 +63,7 @@ public class CommonTlsContextTestsUtil { public static final String BAD_CLIENT_KEY_FILE = "badclient.key"; /** takes additional values and creates CombinedCertificateValidationContext as needed. */ + @SuppressWarnings("deprecation") static CommonTlsContext buildCommonTlsContextWithAdditionalValues( String certInstanceName, String certName, String validationContextCertInstanceName, String validationContextCertName, @@ -146,7 +148,7 @@ public static DownstreamTlsContext buildTestDownstreamTlsContext( if (certName != null || validationContextCertName != null || useSans) { commonTlsContext = buildCommonTlsContextWithAdditionalValues( "cert-instance-name", certName, - "val-cert-instance-name", validationContextCertName, + "cert-instance-name", validationContextCertName, useSans ? Arrays.asList( StringMatcher.newBuilder() .setExact("spiffe://grpc-sds-testing.svc.id.goog/ns/default/sa/bob") @@ -208,6 +210,7 @@ public static String getResourceContents(String resourceName) throws IOException return text; } + @SuppressWarnings("deprecation") private static CommonTlsContext buildCommonTlsContextForCertProviderInstance( String certInstanceName, String certName, @@ -232,6 +235,31 @@ private static CommonTlsContext buildCommonTlsContextForCertProviderInstance( return builder.build(); } + private static CommonTlsContext buildNewCommonTlsContextForCertProviderInstance( + String certInstanceName, + String certName, + String rootInstanceName, + String rootCertName, + Iterable alpnProtocols, + CertificateValidationContext staticCertValidationContext) { + CommonTlsContext.Builder builder = CommonTlsContext.newBuilder(); + if (certInstanceName != null) { + builder = + builder.setTlsCertificateProviderInstance( + CertificateProviderPluginInstance.newBuilder() + .setInstanceName(certInstanceName) + .setCertificateName(certName)); + } + builder = + addNewCertificateValidationContext( + builder, rootInstanceName, rootCertName, staticCertValidationContext); + if (alpnProtocols != null) { + builder.addAllAlpnProtocols(alpnProtocols); + } + return builder.build(); + } + + @SuppressWarnings("deprecation") private static CommonTlsContext.Builder addCertificateValidationContext( CommonTlsContext.Builder builder, String rootInstanceName, @@ -256,6 +284,26 @@ private static CommonTlsContext.Builder addCertificateValidationContext( return builder; } + private static CommonTlsContext.Builder addNewCertificateValidationContext( + CommonTlsContext.Builder builder, + String rootInstanceName, + String rootCertName, + CertificateValidationContext staticCertValidationContext) { + if (rootInstanceName != null) { + CertificateProviderPluginInstance providerInstance = + CertificateProviderPluginInstance.newBuilder() + .setInstanceName(rootInstanceName) + .setCertificateName(rootCertName) + .build(); + CertificateValidationContext.Builder validationContextBuilder = + staticCertValidationContext != null ? staticCertValidationContext.toBuilder() + : CertificateValidationContext.newBuilder(); + return builder.setValidationContext( + validationContextBuilder.setCaCertificateProviderInstance(providerInstance)); + } + return builder; + } + /** Helper method to build UpstreamTlsContext for CertProvider tests. */ public static EnvoyServerProtoData.UpstreamTlsContext buildUpstreamTlsContextForCertProviderInstance( @@ -275,6 +323,25 @@ private static CommonTlsContext.Builder addCertificateValidationContext( staticCertValidationContext)); } + /** Helper method to build UpstreamTlsContext for CertProvider tests. */ + public static EnvoyServerProtoData.UpstreamTlsContext + buildNewUpstreamTlsContextForCertProviderInstance( + @Nullable String certInstanceName, + @Nullable String certName, + @Nullable String rootInstanceName, + @Nullable String rootCertName, + Iterable alpnProtocols, + CertificateValidationContext staticCertValidationContext) { + return buildUpstreamTlsContext( + buildNewCommonTlsContextForCertProviderInstance( + certInstanceName, + certName, + rootInstanceName, + rootCertName, + alpnProtocols, + staticCertValidationContext)); + } + /** Helper method to build DownstreamTlsContext for CertProvider tests. */ public static EnvoyServerProtoData.DownstreamTlsContext buildDownstreamTlsContextForCertProviderInstance( @@ -295,6 +362,25 @@ private static CommonTlsContext.Builder addCertificateValidationContext( staticCertValidationContext), requireClientCert); } + /** Helper method to build DownstreamTlsContext for CertProvider tests. */ + public static EnvoyServerProtoData.DownstreamTlsContext + buildNewDownstreamTlsContextForCertProviderInstance( + @Nullable String certInstanceName, + @Nullable String certName, + @Nullable String rootInstanceName, + @Nullable String rootCertName, + Iterable alpnProtocols, + CertificateValidationContext staticCertValidationContext, + boolean requireClientCert) { + return buildInternalDownstreamTlsContext( + buildNewCommonTlsContextForCertProviderInstance( + certInstanceName, + certName, + rootInstanceName, + rootCertName, + alpnProtocols, + staticCertValidationContext), requireClientCert); + } /** Perform some simple checks on sslContext. */ public static void doChecksOnSslContext(boolean server, SslContext sslContext, diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/SdsProtocolNegotiatorsTest.java b/xds/src/test/java/io/grpc/xds/internal/sds/SdsProtocolNegotiatorsTest.java index a8a7a9c9e30..4c89aa4b79a 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/SdsProtocolNegotiatorsTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/sds/SdsProtocolNegotiatorsTest.java @@ -22,6 +22,7 @@ import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_PEM_FILE; import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_KEY_FILE; import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE; +import static io.grpc.xds.internal.sds.SdsProtocolNegotiators.ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -42,16 +43,13 @@ import io.grpc.netty.InternalProtocolNegotiationEvent; import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator; import io.grpc.netty.InternalProtocolNegotiators; +import io.grpc.netty.ProtocolNegotiationEvent; import io.grpc.xds.Bootstrapper; import io.grpc.xds.CommonBootstrapperTestUtils; -import io.grpc.xds.EnvoyServerProtoData; import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; import io.grpc.xds.InternalXdsAttributes; import io.grpc.xds.TlsContextManager; -import io.grpc.xds.XdsClientWrapperForServerSds; -import io.grpc.xds.XdsClientWrapperForServerSdsTestMisc; -import io.grpc.xds.XdsServerTestHelper; import io.grpc.xds.internal.sds.SdsProtocolNegotiators.ClientSdsHandler; import io.grpc.xds.internal.sds.SdsProtocolNegotiators.ClientSdsProtocolNegotiator; import io.netty.channel.ChannelHandler; @@ -74,7 +72,6 @@ import java.net.InetSocketAddress; import java.net.SocketAddress; import java.security.cert.CertStoreException; -import java.util.Arrays; import java.util.Iterator; import java.util.Map; import java.util.concurrent.ExecutionException; @@ -216,18 +213,19 @@ public SocketAddress remoteAddress() { "google_cloud_private_spiffe-server", true, true); TlsContextManagerImpl tlsContextManager = new TlsContextManagerImpl(bootstrapInfoForServer); - XdsClientWrapperForServerSds xdsClientWrapperForServerSds = - XdsClientWrapperForServerSdsTestMisc.createXdsClientWrapperForServerSds( - 80, downstreamTlsContext, tlsContextManager); SdsProtocolNegotiators.HandlerPickerHandler handlerPickerHandler = - new SdsProtocolNegotiators.HandlerPickerHandler(grpcHandler, xdsClientWrapperForServerSds, - InternalProtocolNegotiators.serverPlaintext()); + new SdsProtocolNegotiators.HandlerPickerHandler(grpcHandler, + InternalProtocolNegotiators.serverPlaintext()); pipeline.addLast(handlerPickerHandler); channelHandlerCtx = pipeline.context(handlerPickerHandler); assertThat(channelHandlerCtx).isNotNull(); // should find HandlerPickerHandler // kick off protocol negotiation: should replace HandlerPickerHandler with ServerSdsHandler - pipeline.fireUserEventTriggered(InternalProtocolNegotiationEvent.getDefault()); + ProtocolNegotiationEvent event = InternalProtocolNegotiationEvent.getDefault(); + Attributes attr = InternalProtocolNegotiationEvent.getAttributes(event) + .toBuilder().set(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER, + new SslContextProviderSupplier(downstreamTlsContext, tlsContextManager)).build(); + pipeline.fireUserEventTriggered(InternalProtocolNegotiationEvent.withAttributes(event, attr)); channelHandlerCtx = pipeline.context(handlerPickerHandler); assertThat(channelHandlerCtx).isNull(); channelHandlerCtx = pipeline.context(SdsProtocolNegotiators.ServerSdsHandler.class); @@ -278,23 +276,19 @@ public SocketAddress localAddress() { } }; pipeline = channel.pipeline(); - DownstreamTlsContext downstreamTlsContext = - DownstreamTlsContext.fromEnvoyProtoDownstreamTlsContext( - io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.DownstreamTlsContext - .getDefaultInstance()); - XdsClientWrapperForServerSds xdsClientWrapperForServerSds = - XdsClientWrapperForServerSdsTestMisc.createXdsClientWrapperForServerSds( - 80, downstreamTlsContext, mock(TlsContextManager.class)); SdsProtocolNegotiators.HandlerPickerHandler handlerPickerHandler = new SdsProtocolNegotiators.HandlerPickerHandler( - grpcHandler, xdsClientWrapperForServerSds, mockProtocolNegotiator); + grpcHandler, mockProtocolNegotiator); pipeline.addLast(handlerPickerHandler); channelHandlerCtx = pipeline.context(handlerPickerHandler); assertThat(channelHandlerCtx).isNotNull(); // should find HandlerPickerHandler // kick off protocol negotiation: should replace HandlerPickerHandler with ServerSdsHandler - pipeline.fireUserEventTriggered(InternalProtocolNegotiationEvent.getDefault()); + ProtocolNegotiationEvent event = InternalProtocolNegotiationEvent.getDefault(); + Attributes attr = InternalProtocolNegotiationEvent.getAttributes(event) + .toBuilder().set(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER, null).build(); + pipeline.fireUserEventTriggered(InternalProtocolNegotiationEvent.withAttributes(event, attr)); channelHandlerCtx = pipeline.context(handlerPickerHandler); assertThat(channelHandlerCtx).isNull(); channel.runPendingTasks(); // need this for tasks to execute on eventLoop @@ -311,8 +305,7 @@ public void serverSdsHandler_nullTlsContext_expectFallbackProtocolNegotiator() { when(mockProtocolNegotiator.newHandler(grpcHandler)).thenReturn(mockChannelHandler); SdsProtocolNegotiators.HandlerPickerHandler handlerPickerHandler = new SdsProtocolNegotiators.HandlerPickerHandler( - grpcHandler, /* xdsClientWrapperForServerSds= */ null, - mockProtocolNegotiator); + grpcHandler, mockProtocolNegotiator); pipeline.addLast(handlerPickerHandler); channelHandlerCtx = pipeline.context(handlerPickerHandler); assertThat(channelHandlerCtx).isNotNull(); // should find HandlerPickerHandler @@ -332,8 +325,7 @@ public void serverSdsHandler_nullTlsContext_expectFallbackProtocolNegotiator() { public void nullTlsContext_nullFallbackProtocolNegotiator_expectException() { SdsProtocolNegotiators.HandlerPickerHandler handlerPickerHandler = new SdsProtocolNegotiators.HandlerPickerHandler( - grpcHandler, /* xdsClientWrapperForServerSds= */ null, - null); + grpcHandler, null); pipeline.addLast(handlerPickerHandler); channelHandlerCtx = pipeline.context(handlerPickerHandler); assertThat(channelHandlerCtx).isNotNull(); // should find HandlerPickerHandler @@ -351,54 +343,6 @@ public void nullTlsContext_nullFallbackProtocolNegotiator_expectException() { } } - @Test - public void noMatchingFilterChain_expectException() { - // we need InetSocketAddress instead of EmbeddedSocketAddress as localAddress for this test - channel = - new EmbeddedChannel() { - @Override - public SocketAddress localAddress() { - return new InetSocketAddress("172.168.1.1", 80); - } - - @Override - public SocketAddress remoteAddress() { - return new InetSocketAddress("172.168.2.2", 90); - } - }; - pipeline = channel.pipeline(); - Bootstrapper.BootstrapInfo bootstrapInfoForServer = CommonBootstrapperTestUtils - .buildBootstrapInfo("google_cloud_private_spiffe-server", SERVER_1_KEY_FILE, - SERVER_1_PEM_FILE, CA_PEM_FILE, null, null, null, null); - - TlsContextManagerImpl tlsContextManager = new TlsContextManagerImpl(bootstrapInfoForServer); - XdsClientWrapperForServerSds xdsClientWrapperForServerSds = - XdsServerTestHelper.createXdsClientWrapperForServerSds(80, tlsContextManager); - xdsClientWrapperForServerSds.start(); - EnvoyServerProtoData.Listener listener = new EnvoyServerProtoData.Listener( - "listener1", "0.0.0.0", Arrays.asList(), null); - XdsServerTestHelper.generateListenerUpdate( - xdsClientWrapperForServerSds.getListenerWatcher(), listener); - - SdsProtocolNegotiators.HandlerPickerHandler handlerPickerHandler = - new SdsProtocolNegotiators.HandlerPickerHandler(grpcHandler, xdsClientWrapperForServerSds, - InternalProtocolNegotiators.serverPlaintext()); - pipeline.addLast(handlerPickerHandler); - channelHandlerCtx = pipeline.context(handlerPickerHandler); - assertThat(channelHandlerCtx).isNotNull(); // should find HandlerPickerHandler - - // kick off protocol negotiation - pipeline.fireUserEventTriggered(InternalProtocolNegotiationEvent.getDefault()); - channelHandlerCtx = pipeline.context(handlerPickerHandler); - assertThat(channelHandlerCtx).isNotNull(); // HandlerPickerHandler still there - try { - channel.checkException(); - fail("exception expected!"); - } catch (Exception e) { - assertThat(e).hasMessageThat().contains("no matching filter chain"); - } - } - @Test public void clientSdsProtocolNegotiatorNewHandler_fireProtocolNegotiationEvent() throws InterruptedException, TimeoutException, ExecutionException { diff --git a/xds/third_party/envoy/import.sh b/xds/third_party/envoy/import.sh index 4c3fd1b3c70..c77ee9272e0 100755 --- a/xds/third_party/envoy/import.sh +++ b/xds/third_party/envoy/import.sh @@ -18,7 +18,7 @@ set -e BRANCH=main # import VERSION from one of the google internal CLs -VERSION=62ca8bd2b5960ed1c6ce2be97d3120cee719ecab +VERSION=c223756b0856f734a6a5cff2d0b95388cd2583d4 GIT_REPO="https://github.com/envoyproxy/envoy.git" GIT_BASE_DIR=envoy SOURCE_PROTO_BASE_DIR=envoy/api diff --git a/xds/third_party/envoy/src/main/proto/envoy/api/v2/listener/listener_components.proto b/xds/third_party/envoy/src/main/proto/envoy/api/v2/listener/listener_components.proto index a6791c86cd0..08738962c5e 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/api/v2/listener/listener_components.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/api/v2/listener/listener_components.proto @@ -230,7 +230,7 @@ message FilterChain { // rules: // - destination_port_range: // start: 3306 -// end: 3306 +// end: 3307 // - destination_port_range: // start: 15000 // end: 15001 diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/accesslog/v3/accesslog.proto b/xds/third_party/envoy/src/main/proto/envoy/config/accesslog/v3/accesslog.proto index ad129a3ed64..bb53286380c 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/accesslog/v3/accesslog.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/accesslog/v3/accesslog.proto @@ -246,6 +246,7 @@ message ResponseFlagFilter { in: "DT" in: "UPE" in: "NC" + in: "OM" } } }]; diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/bootstrap/v3/bootstrap.proto b/xds/third_party/envoy/src/main/proto/envoy/config/bootstrap/v3/bootstrap.proto index 431b45b6617..0e8de366333 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/bootstrap/v3/bootstrap.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/bootstrap/v3/bootstrap.proto @@ -40,7 +40,7 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // ` for more detail. // Bootstrap :ref:`configuration overview `. -// [#next-free-field: 31] +// [#next-free-field: 33] message Bootstrap { option (udpa.annotations.versioning).previous_message_type = "envoy.config.bootstrap.v2.Bootstrap"; @@ -260,8 +260,25 @@ message Bootstrap { // This may be overridden on a per-cluster basis in cds_config, when // :ref:`dns_resolution_config ` // is specified. + // *dns_resolution_config* will be deprecated once + // :ref:'typed_dns_resolver_config ' + // is fully supported. core.v3.DnsResolutionConfig dns_resolution_config = 30; + // DNS resolver type configuration extension. This extension can be used to configure c-ares, apple, + // or any other DNS resolver types and the related parameters. + // For example, an object of :ref:`DnsResolutionConfig ` + // can be packed into this *typed_dns_resolver_config*. This configuration will replace the + // :ref:'dns_resolution_config ' + // configuration eventually. + // TODO(yanjunxiang): Investigate the deprecation plan for *dns_resolution_config*. + // During the transition period when both *dns_resolution_config* and *typed_dns_resolver_config* exists, + // this configuration is optional. + // When *typed_dns_resolver_config* is in place, Envoy will use it and ignore *dns_resolution_config*. + // When *typed_dns_resolver_config* is missing, the default behavior is in place. + // [#not-implemented-hide:] + core.v3.TypedExtensionConfig typed_dns_resolver_config = 31; + // Specifies optional bootstrap extensions to be instantiated at startup time. // Each item contains extension specific configuration. // [#extension-category: envoy.bootstrap] @@ -305,6 +322,13 @@ message Bootstrap { // field. // [#not-implemented-hide:] map certificate_provider_instances = 25; + + // Specifies a set of headers that need to be registered as inline header. This configuration + // allows users to customize the inline headers on-demand at Envoy startup without modifying + // Envoy's source code. + // + // Note that the 'set-cookie' header cannot be registered as inline header. + repeated CustomInlineHeader inline_headers = 32; } // Administration interface :ref:`operations documentation @@ -578,3 +602,43 @@ message LayeredRuntime { // such that later layers in the list overlay earlier entries. repeated RuntimeLayer layers = 1; } + +// Used to specify the header that needs to be registered as an inline header. +// +// If request or response contain multiple headers with the same name and the header +// name is registered as an inline header. Then multiple headers will be folded +// into one, and multiple header values will be concatenated by a suitable delimiter. +// The delimiter is generally a comma. +// +// For example, if 'foo' is registered as an inline header, and the headers contains +// the following two headers: +// +// .. code-block:: text +// +// foo: bar +// foo: eep +// +// Then they will eventually be folded into: +// +// .. code-block:: text +// +// foo: bar, eep +// +// Inline headers provide O(1) search performance, but each inline header imposes +// an additional memory overhead on all instances of the corresponding type of +// HeaderMap or TrailerMap. +message CustomInlineHeader { + enum InlineHeaderType { + REQUEST_HEADER = 0; + REQUEST_TRAILER = 1; + RESPONSE_HEADER = 2; + RESPONSE_TRAILER = 3; + } + + // The name of the header that is expected to be set as the inline header. + string inline_header_name = 1 + [(validate.rules).string = {min_len: 1 well_known_regex: HTTP_HEADER_NAME strict: false}]; + + // The type of the header that is expected to be set as the inline header. + InlineHeaderType inline_header_type = 2 [(validate.rules).enum = {defined_only: true}]; +} diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/cluster/v3/cluster.proto b/xds/third_party/envoy/src/main/proto/envoy/config/cluster/v3/cluster.proto index 5470b1807d4..d6213d6fe94 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/cluster/v3/cluster.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/cluster/v3/cluster.proto @@ -43,7 +43,7 @@ message ClusterCollection { } // Configuration for a single upstream cluster. -// [#next-free-field: 54] +// [#next-free-field: 56] message Cluster { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.Cluster"; @@ -110,7 +110,7 @@ message Cluster { // this option or not. CLUSTER_PROVIDED = 6; - // [#not-implemented-hide:] Use the new :ref:`load_balancing_policy + // Use the new :ref:`load_balancing_policy // ` field to determine the LB policy. // [#next-major-version: In the v3 API, we should consider deprecating the lb_policy field // and instead using the new load_balancing_policy field as the one and only mechanism for @@ -413,8 +413,8 @@ message Cluster { // The table size for Maglev hashing. The Maglev aims for ‘minimal disruption’ rather than an absolute guarantee. // Minimal disruption means that when the set of upstreams changes, a connection will likely be sent to the same // upstream as it was before. Increasing the table size reduces the amount of disruption. - // The table size must be prime number. If it is not specified, the default is 65537. - google.protobuf.UInt64Value table_size = 1; + // The table size must be prime number limited to 5000011. If it is not specified, the default is 65537. + google.protobuf.UInt64Value table_size = 1 [(validate.rules).uint64 = {lte: 5000011}]; } // Specific configuration for the @@ -720,8 +720,7 @@ message Cluster { // The :ref:`load balancer type ` to use // when picking a host in the cluster. - // [#comment:TODO: Remove enum constraint :ref:`LOAD_BALANCING_POLICY_CONFIG` when implemented.] - LbPolicy lb_policy = 6 [(validate.rules).enum = {defined_only: true not_in: 7}]; + LbPolicy lb_policy = 6 [(validate.rules).enum = {defined_only: true}]; // Setting this is required for specifying members of // :ref:`STATIC`, @@ -746,7 +745,11 @@ message Cluster { // is respected by both the HTTP/1.1 and HTTP/2 connection pool // implementations. If not specified, there is no limit. Setting this // parameter to 1 will effectively disable keep alive. - google.protobuf.UInt32Value max_requests_per_connection = 9; + // + // .. attention:: + // This field has been deprecated in favor of the :ref:`max_requests_per_connection ` field. + google.protobuf.UInt32Value max_requests_per_connection = 9 + [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; // Optional :ref:`circuit breaking ` for the cluster. CircuitBreakers circuit_breakers = 10; @@ -778,7 +781,7 @@ message Cluster { [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; // Additional options when handling HTTP1 requests. - // This has been deprecated in favor of http_protocol_options fields in the in the + // This has been deprecated in favor of http_protocol_options fields in the // :ref:`http_protocol_options ` message. // http_protocol_options can be set via the cluster's // :ref:`extension_protocol_options`. @@ -794,7 +797,7 @@ message Cluster { // supports prior knowledge for upstream connections. Even if TLS is used // with ALPN, `http2_protocol_options` must be specified. As an aside this allows HTTP/2 // connections to happen over plain text. - // This has been deprecated in favor of http2_protocol_options fields in the in the + // This has been deprecated in favor of http2_protocol_options fields in the // :ref:`http_protocol_options ` // message. http2_protocol_options can be set via the cluster's // :ref:`extension_protocol_options`. @@ -874,8 +877,32 @@ message Cluster { [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; // DNS resolution configuration which includes the underlying dns resolver addresses and options. + // *dns_resolution_config* will be deprecated once + // :ref:'typed_dns_resolver_config ' + // is fully supported. core.v3.DnsResolutionConfig dns_resolution_config = 53; + // DNS resolver type configuration extension. This extension can be used to configure c-ares, apple, + // or any other DNS resolver types and the related parameters. + // For example, an object of :ref:`DnsResolutionConfig ` + // can be packed into this *typed_dns_resolver_config*. This configuration will replace the + // :ref:'dns_resolution_config ' + // configuration eventually. + // TODO(yanjunxiang): Investigate the deprecation plan for *dns_resolution_config*. + // During the transition period when both *dns_resolution_config* and *typed_dns_resolver_config* exists, + // this configuration is optional. + // When *typed_dns_resolver_config* is in place, Envoy will use it and ignore *dns_resolution_config*. + // When *typed_dns_resolver_config* is missing, the default behavior is in place. + // [#not-implemented-hide:] + core.v3.TypedExtensionConfig typed_dns_resolver_config = 55; + + // Optional configuration for having cluster readiness block on warm-up. Currently, only applicable for + // :ref:`STRICT_DNS`, + // or :ref:`LOGICAL_DNS`. + // If true, cluster readiness blocks on warm-up. If false, the cluster will complete + // initialization whether or not warm-up has completed. Defaults to true. + google.protobuf.BoolValue wait_for_warm_on_init = 54; + // If specified, outlier detection will be enabled for this upstream cluster. // Each of the configuration values can be overridden via // :ref:`runtime values `. @@ -930,7 +957,7 @@ message Cluster { CommonLbConfig common_lb_config = 27; // Optional custom transport socket implementation to use for upstream connections. - // To setup TLS, set a transport socket with name `tls` and + // To setup TLS, set a transport socket with name `envoy.transport_sockets.tls` and // :ref:`UpstreamTlsContexts ` in the `typed_config`. // If no transport socket configuration is specified, new connections // will be set up with plaintext. @@ -980,7 +1007,7 @@ message Cluster { // servers of this cluster. repeated Filter filters = 40; - // [#not-implemented-hide:] New mechanism for LB policy configuration. Used only if the + // New mechanism for LB policy configuration. Used only if the // :ref:`lb_policy` field has the value // :ref:`LOAD_BALANCING_POLICY_CONFIG`. LoadBalancingPolicy load_balancing_policy = 41; @@ -1045,7 +1072,7 @@ message Cluster { bool connection_pool_per_downstream_connection = 51; } -// [#not-implemented-hide:] Extensible load balancing policy configuration. +// Extensible load balancing policy configuration. // // Every LB policy defined via this mechanism will be identified via a unique name using reverse // DNS notation. If the policy needs configuration parameters, it must define a message for its @@ -1071,14 +1098,11 @@ message LoadBalancingPolicy { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.LoadBalancingPolicy.Policy"; - reserved 2; - - reserved "config"; + reserved 2, 1, 3; - // Required. The name of the LB policy. - string name = 1; + reserved "config", "name", "typed_config"; - google.protobuf.Any typed_config = 3; + core.v3.TypedExtensionConfig typed_extension_config = 4; } // Each client will iterate over the list in order and stop at the first policy that it diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/protocol.proto b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/protocol.proto index cf98e537261..8f2347eb551 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/protocol.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/core/v3/protocol.proto @@ -73,7 +73,7 @@ message UpstreamHttpProtocolOptions { // Configures the alternate protocols cache which tracks alternate protocols that can be used to // make an HTTP connection to an origin server. See https://tools.ietf.org/html/rfc7838 for -// HTTP Alternate Services and https://datatracker.ietf.org/doc/html/draft-ietf-dnsop-svcb-https-04 +// HTTP Alternative Services and https://datatracker.ietf.org/doc/html/draft-ietf-dnsop-svcb-https-04 // for the "HTTPS" DNS resource record. message AlternateProtocolsCacheOptions { // The name of the cache. Multiple named caches allow independent alternate protocols cache @@ -93,7 +93,7 @@ message AlternateProtocolsCacheOptions { google.protobuf.UInt32Value max_entries = 2 [(validate.rules).uint32 = {gt: 0}]; } -// [#next-free-field: 6] +// [#next-free-field: 7] message HttpProtocolOptions { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.core.HttpProtocolOptions"; @@ -157,6 +157,12 @@ message HttpProtocolOptions { // If this setting is not specified, the value defaults to ALLOW. // Note: upstream responses are not affected by this setting. HeadersWithUnderscoresAction headers_with_underscores_action = 5; + + // Optional maximum requests for both upstream and downstream connections. + // If not specified, there is no limit. + // Setting this parameter to 1 will effectively disable keep alive. + // For HTTP/2 and HTTP/3, due to concurrent stream processing, the limit is approximate. + google.protobuf.UInt32Value max_requests_per_connection = 6; } // [#next-free-field: 8] @@ -478,3 +484,11 @@ message Http3ProtocolOptions { // `. google.protobuf.BoolValue override_stream_error_on_invalid_http_message = 2; } + +// A message to control transformations to the :scheme header +message SchemeHeaderTransformation { + oneof transformation { + // Overwrite any Scheme header with the contents of this string. + string scheme_to_overwrite = 1 [(validate.rules).string = {in: "http" in: "https"}]; + } +} diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/endpoint/v3/endpoint_components.proto b/xds/third_party/envoy/src/main/proto/envoy/config/endpoint/v3/endpoint_components.proto index 0e10ac3b2fc..0a9aac105e7 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/endpoint/v3/endpoint_components.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/endpoint/v3/endpoint_components.proto @@ -4,10 +4,12 @@ package envoy.config.endpoint.v3; import "envoy/config/core/v3/address.proto"; import "envoy/config/core/v3/base.proto"; +import "envoy/config/core/v3/config_source.proto"; import "envoy/config/core/v3/health_check.proto"; import "google/protobuf/wrappers.proto"; +import "udpa/annotations/migrate.proto"; import "udpa/annotations/status.proto"; import "udpa/annotations/versioning.proto"; import "validate/validate.proto"; @@ -108,21 +110,51 @@ message LbEndpoint { google.protobuf.UInt32Value load_balancing_weight = 4 [(validate.rules).uint32 = {gte: 1}]; } +// [#not-implemented-hide:] +// A configuration for a LEDS collection. +message LedsClusterLocalityConfig { + // Configuration for the source of LEDS updates for a Locality. + core.v3.ConfigSource leds_config = 1; + + // The xDS transport protocol glob collection resource name. + // The service is only supported in delta xDS (incremental) mode. + string leds_collection_name = 2; +} + // A group of endpoints belonging to a Locality. // One can have multiple LocalityLbEndpoints for a locality, but this is // generally only done if the different groups need to have different load // balancing weights or different priorities. -// [#next-free-field: 7] +// [#next-free-field: 9] message LocalityLbEndpoints { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.endpoint.LocalityLbEndpoints"; + // [#not-implemented-hide:] + // A list of endpoints of a specific locality. + message LbEndpointList { + repeated LbEndpoint lb_endpoints = 1; + } + // Identifies location of where the upstream hosts run. core.v3.Locality locality = 1; // The group of endpoints belonging to the locality specified. + // [#comment:TODO(adisuissa): Once LEDS is implemented this field needs to be + // deprecated and replaced by *load_balancer_endpoints*.] repeated LbEndpoint lb_endpoints = 2; + // [#not-implemented-hide:] + oneof lb_config { + // The group of endpoints belonging to the locality. + // [#comment:TODO(adisuissa): Once LEDS is implemented the *lb_endpoints* field + // needs to be deprecated.] + LbEndpointList load_balancer_endpoints = 7; + + // LEDS Configuration for the current locality. + LedsClusterLocalityConfig leds_cluster_locality_config = 8; + } + // Optional: Per priority/region/zone/sub_zone weight; at least 1. The load // balancing weight for a locality is divided by the sum of the weights of all // localities at the same priority level to produce the effective percentage diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/api_listener.proto b/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/api_listener.proto index 1dc94edc74b..77db7caaff5 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/api_listener.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/api_listener.proto @@ -23,6 +23,7 @@ message ApiListener { // The type in this field determines the type of API listener. At present, the following // types are supported: // envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager (HTTP) + // envoy.extensions.filters.network.http_connection_manager.v3.EnvoyMobileHttpConnectionManager (HTTP) // [#next-major-version: In the v3 API, replace this Any field with a oneof containing the // specific config message for each type of API listener. We could not do this in v2 because // it would have caused circular dependencies for go protos: lds.proto depends on this file, diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/listener.proto b/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/listener.proto index b5bda9562ce..a5cd4bfe976 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/listener.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/listener.proto @@ -35,7 +35,7 @@ message ListenerCollection { repeated xds.core.v3.CollectionEntry entries = 1; } -// [#next-free-field: 29] +// [#next-free-field: 30] message Listener { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.Listener"; @@ -255,17 +255,30 @@ message Listener { // enable the balance config in Y1 and Y2 to balance the connections among the workers. ConnectionBalanceConfig connection_balance_config = 20; + // Deprecated. Use `enable_reuse_port` instead. + bool reuse_port = 21 [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; + // When this flag is set to true, listeners set the *SO_REUSEPORT* socket option and // create one socket for each worker thread. This makes inbound connections // distribute among worker threads roughly evenly in cases where there are a high number - // of connections. When this flag is set to false, all worker threads share one socket. + // of connections. When this flag is set to false, all worker threads share one socket. This field + // defaults to true. + // + // .. attention:: + // + // Although this field defaults to true, it has different behavior on different platforms. See + // the following text for more information. // - // Before Linux v4.19-rc1, new TCP connections may be rejected during hot restart - // (see `3rd paragraph in 'soreuseport' commit message - // `_). - // This issue was fixed by `tcp: Avoid TCP syncookie rejected by SO_REUSEPORT socket - // `_. - bool reuse_port = 21; + // * On Linux, reuse_port is respected for both TCP and UDP listeners. It also works correctly + // with hot restart. + // * On macOS, reuse_port for TCP does not do what it does on Linux. Instead of load balancing, + // the last socket wins and receives all connections/packets. For TCP, reuse_port is force + // disabled and the user is warned. For UDP, it is enabled, but only one worker will receive + // packets. For QUIC/H3, SW routing will send packets to other workers. For "raw" UDP, only + // a single worker will currently receive packets. + // * On Windows, reuse_port for TCP has undefined behavior. It is force disabled and the user + // is warned similar to macOS. It is left enabled for UDP with undefined behavior currently. + google.protobuf.BoolValue enable_reuse_port = 29; // Configuration for :ref:`access logs ` // emitted by this listener. diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/listener_components.proto b/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/listener_components.proto index e6d73b791c2..e737b14b174 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/listener_components.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/listener/v3/listener_components.proto @@ -64,9 +64,12 @@ message Filter { // 3. Server name (e.g. SNI for TLS protocol), // 4. Transport protocol. // 5. Application protocols (e.g. ALPN for TLS protocol). -// 6. Source type (e.g. any, local or external network). -// 7. Source IP address. -// 8. Source port. +// 6. Directly connected source IP address (this will only be different from the source IP address +// when using a listener filter that overrides the source address, such as the :ref:`Proxy Protocol +// listener filter `). +// 7. Source type (e.g. any, local or external network). +// 8. Source IP address. +// 9. Source port. // // For criteria that allow ranges or wildcards, the most specific value in any // of the configured filter chains that matches the incoming connection is going @@ -90,7 +93,7 @@ message Filter { // listed at the end, because that's how we want to list them in the docs. // // [#comment:TODO(PiotrSikora): Add support for configurable precedence of the rules] -// [#next-free-field: 13] +// [#next-free-field: 14] message FilterChainMatch { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.listener.FilterChainMatch"; @@ -124,6 +127,11 @@ message FilterChainMatch { // [#not-implemented-hide:] google.protobuf.UInt32Value suffix_len = 5; + // The criteria is satisfied if the directly connected source IP address of the downstream + // connection is contained in at least one of the specified subnets. If the parameter is not + // specified or the list is empty, the directly connected source IP address is ignored. + repeated core.v3.CidrRange direct_source_prefix_ranges = 13; + // Specifies the connection source IP match type. Can be any, local or external network. ConnectionSourceType source_type = 12 [(validate.rules).enum = {defined_only: true}]; @@ -238,7 +246,7 @@ message FilterChain { core.v3.Metadata metadata = 5; // Optional custom transport socket implementation to use for downstream connections. - // To setup TLS, set a transport socket with name `tls` and + // To setup TLS, set a transport socket with name `envoy.transport_sockets.tls` and // :ref:`DownstreamTlsContext ` in the `typed_config`. // If no transport socket configuration is specified, new connections // will be set up with plaintext. @@ -282,7 +290,7 @@ message FilterChain { // rules: // - destination_port_range: // start: 3306 -// end: 3306 +// end: 3307 // - destination_port_range: // start: 15000 // end: 15001 diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/overload/v3/overload.proto b/xds/third_party/envoy/src/main/proto/envoy/config/overload/v3/overload.proto index 4445af63211..85fa761dbdd 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/overload/v3/overload.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/overload/v3/overload.proto @@ -141,6 +141,26 @@ message OverloadAction { google.protobuf.Any typed_config = 3; } +// Configuration for which accounts the WatermarkBuffer Factories should +// track. +message BufferFactoryConfig { + // The minimum power of two at which Envoy starts tracking an account. + // + // Envoy has 8 power of two buckets starting with the provided exponent below. + // Concretely the 1st bucket contains accounts for streams that use + // [2^minimum_account_to_track_power_of_two, + // 2^(minimum_account_to_track_power_of_two + 1)) bytes. + // With the 8th bucket tracking accounts + // >= 128 * 2^minimum_account_to_track_power_of_two. + // + // The maximum value is 56, since we're using uint64_t for bytes counting, + // and that's the last value that would use the 8 buckets. In practice, + // we don't expect the proxy to be holding 2^56 bytes. + // + // If omitted, Envoy should not do any tracking. + uint32 minimum_account_to_track_power_of_two = 1 [(validate.rules).uint32 = {lte: 56 gte: 10}]; +} + message OverloadManager { option (udpa.annotations.versioning).previous_message_type = "envoy.config.overload.v2alpha.OverloadManager"; @@ -153,4 +173,7 @@ message OverloadManager { // The set of overload actions. repeated OverloadAction actions = 3; + + // Configuration for buffer factory. + BufferFactoryConfig buffer_factory_config = 4; } diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/rbac/v3/rbac.proto b/xds/third_party/envoy/src/main/proto/envoy/config/rbac/v3/rbac.proto index 3b7f79d605d..d66f9be2b49 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/rbac/v3/rbac.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/rbac/v3/rbac.proto @@ -7,6 +7,7 @@ import "envoy/config/route/v3/route_components.proto"; import "envoy/type/matcher/v3/metadata.proto"; import "envoy/type/matcher/v3/path.proto"; import "envoy/type/matcher/v3/string.proto"; +import "envoy/type/v3/range.proto"; import "google/api/expr/v1alpha1/checked.proto"; import "google/api/expr/v1alpha1/syntax.proto"; @@ -60,7 +61,10 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // permissions: // - and_rules: // rules: -// - header: { name: ":method", exact_match: "GET" } +// - header: +// name: ":method" +// string_match: +// exact: "GET" // - url_path: // path: { prefix: "/products" } // - or_rules: @@ -142,7 +146,7 @@ message Policy { } // Permission defines an action (or actions) that a principal can take. -// [#next-free-field: 11] +// [#next-free-field: 12] message Permission { option (udpa.annotations.versioning).previous_message_type = "envoy.config.rbac.v2.Permission"; @@ -182,6 +186,9 @@ message Permission { // A port number that describes the destination port connecting to. uint32 destination_port = 6 [(validate.rules).uint32 = {lte: 65535}]; + // A port number range that describes a range of destination ports connecting to. + type.v3.Int32Range destination_port_range = 11; + // Metadata that describes additional information about the action. type.matcher.v3.MetadataMatcher metadata = 7; diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/route/v3/route.proto b/xds/third_party/envoy/src/main/proto/envoy/config/route/v3/route.proto index 80956fdeb4e..e2bf52165be 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/route/v3/route.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/route/v3/route.proto @@ -4,6 +4,7 @@ package envoy.config.route.v3; import "envoy/config/core/v3/base.proto"; import "envoy/config/core/v3/config_source.proto"; +import "envoy/config/core/v3/extension.proto"; import "envoy/config/route/v3/route_components.proto"; import "google/protobuf/wrappers.proto"; @@ -21,7 +22,7 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // * Routing :ref:`architecture overview ` // * HTTP :ref:`router filter ` -// [#next-free-field: 12] +// [#next-free-field: 13] message RouteConfiguration { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.RouteConfiguration"; @@ -119,6 +120,18 @@ message RouteConfiguration { // is not subject to data plane buffering controls. // google.protobuf.UInt32Value max_direct_response_body_size_bytes = 11; + + // [#not-implemented-hide:] + // A list of plugins and their configurations which may be used by a + // :ref:`envoy_v3_api_field_config.route.v3.RouteAction.cluster_specifier_plugin` + // within the route. All *extension.name* fields in this list must be unique. + repeated ClusterSpecifierPlugin cluster_specifier_plugins = 12; +} + +// Configuration for a cluster specifier plugin. +message ClusterSpecifierPlugin { + // The name of the plugin and its opaque configuration. + core.v3.TypedExtensionConfig extension = 1; } message Vhds { diff --git a/xds/third_party/envoy/src/main/proto/envoy/config/route/v3/route_components.proto b/xds/third_party/envoy/src/main/proto/envoy/config/route/v3/route_components.proto index ee82e8f7322..dfb8b8ed1a1 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/config/route/v3/route_components.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/config/route/v3/route_components.proto @@ -311,7 +311,7 @@ message Route { message WeightedCluster { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.route.WeightedCluster"; - // [#next-free-field: 11] + // [#next-free-field: 12] message ClusterWeight { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.route.WeightedCluster.ClusterWeight"; @@ -378,6 +378,13 @@ message WeightedCluster { // :ref:`FilterConfig` // message to specify additional options.] map typed_per_filter_config = 10; + + oneof host_rewrite_specifier { + // Indicates that during forwarding, the host header will be swapped with + // this value. + string host_rewrite_literal = 11 + [(validate.rules).string = {well_known_regex: HTTP_HEADER_VALUE strict: false}]; + } } // Specifies one or more upstream clusters associated with the route. @@ -466,7 +473,7 @@ message RouteMatch { } // Indicates that prefix/path matching should be case sensitive. The default - // is true. + // is true. Ignored for safe_regex matching. google.protobuf.BoolValue case_sensitive = 4; // Indicates that the route should additionally match on a runtime key. Every time the route @@ -563,7 +570,7 @@ message CorsPolicy { core.v3.RuntimeFractionalPercent shadow_enabled = 10; } -// [#next-free-field: 37] +// [#next-free-field: 38] message RouteAction { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.route.RouteAction"; @@ -839,6 +846,14 @@ message RouteAction { // :ref:`traffic splitting ` // for additional documentation. WeightedCluster weighted_clusters = 3; + + // [#not-implemented-hide:] + // Name of the cluster specifier plugin to use to determine the cluster for + // requests on this route. The plugin name must be defined in the associated + // :ref:`envoy_v3_api_field_config.route.v3.RouteConfiguration.cluster_specifier_plugins` + // in the + // :ref:`envoy_v3_api_field_config.core.v3.TypedExtensionConfig.name` field. + string cluster_specifier_plugin = 37; } // The HTTP status code to use when configured cluster is not found. @@ -1850,7 +1865,7 @@ message RateLimit { // value. // // [#next-major-version: HeaderMatcher should be refactored to use StringMatcher.] -// [#next-free-field: 13] +// [#next-free-field: 14] message HeaderMatcher { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.route.HeaderMatcher"; @@ -1865,12 +1880,16 @@ message HeaderMatcher { // Specifies how the header match will be performed to route the request. oneof header_match_specifier { // If specified, header match will be performed based on the value of the header. - string exact_match = 4; + // This field is deprecated. Please use :ref:`string_match `. + string exact_match = 4 + [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; // If specified, this regex string is a regular expression rule which implies the entire request // header value must match the regex. The rule will not match if only a subsequence of the // request header value matches the regex. - type.matcher.v3.RegexMatcher safe_regex_match = 11; + // This field is deprecated. Please use :ref:`string_match `. + type.matcher.v3.RegexMatcher safe_regex_match = 11 + [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; // If specified, header match will be performed based on range. // The rule will match if the request header value is within this range. @@ -1891,28 +1910,46 @@ message HeaderMatcher { // If specified, header match will be performed based on the prefix of the header value. // Note: empty prefix is not allowed, please use present_match instead. + // This field is deprecated. Please use :ref:`string_match `. // // Examples: // // * The prefix *abcd* matches the value *abcdxyz*, but not for *abcxyz*. - string prefix_match = 9 [(validate.rules).string = {min_len: 1}]; + string prefix_match = 9 [ + deprecated = true, + (validate.rules).string = {min_len: 1}, + (envoy.annotations.deprecated_at_minor_version) = "3.0" + ]; // If specified, header match will be performed based on the suffix of the header value. // Note: empty suffix is not allowed, please use present_match instead. + // This field is deprecated. Please use :ref:`string_match `. // // Examples: // // * The suffix *abcd* matches the value *xyzabcd*, but not for *xyzbcd*. - string suffix_match = 10 [(validate.rules).string = {min_len: 1}]; + string suffix_match = 10 [ + deprecated = true, + (validate.rules).string = {min_len: 1}, + (envoy.annotations.deprecated_at_minor_version) = "3.0" + ]; // If specified, header match will be performed based on whether the header value contains // the given value or not. // Note: empty contains match is not allowed, please use present_match instead. + // This field is deprecated. Please use :ref:`string_match `. // // Examples: // // * The value *abcd* matches the value *xyzabcdpqr*, but not for *xyzbcdpqr*. - string contains_match = 12 [(validate.rules).string = {min_len: 1}]; + string contains_match = 12 [ + deprecated = true, + (validate.rules).string = {min_len: 1}, + (envoy.annotations.deprecated_at_minor_version) = "3.0" + ]; + + // If specified, header match will be performed based on the string match of the header value. + type.matcher.v3.StringMatcher string_match = 13; } // If specified, the match result will be inverted before checking. Defaults to false. diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/common/fault/v3/fault.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/common/fault/v3/fault.proto index b5b1dbd463f..62da059e264 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/common/fault/v3/fault.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/common/fault/v3/fault.proto @@ -18,7 +18,7 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // [#protodoc-title: Common fault injection types] // Delay specification is used to inject latency into the -// HTTP/gRPC/Mongo/Redis operation or delay proxying of TCP connections. +// HTTP/Mongo operation. // [#next-free-field: 6] message FaultDelay { option (udpa.annotations.versioning).previous_message_type = @@ -46,10 +46,9 @@ message FaultDelay { // Add a fixed delay before forwarding the operation upstream. See // https://developers.google.com/protocol-buffers/docs/proto3#json for - // the JSON/YAML Duration mapping. For HTTP/Mongo/Redis, the specified - // delay will be injected before a new request/operation. For TCP - // connections, the proxying of the connection upstream will be delayed - // for the specified period. This is required if type is FIXED. + // the JSON/YAML Duration mapping. For HTTP/Mongo, the specified + // delay will be injected before a new request/operation. + // This is required if type is FIXED. google.protobuf.Duration fixed_delay = 3 [(validate.rules).duration = {gt {}}]; // Fault delays are controlled via an HTTP header (if applicable). diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/network/http_connection_manager/v3/http_connection_manager.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/network/http_connection_manager/v3/http_connection_manager.proto index 856249c2a25..3fb4bfa09e2 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/network/http_connection_manager/v3/http_connection_manager.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/network/http_connection_manager/v3/http_connection_manager.proto @@ -19,7 +19,6 @@ import "google/protobuf/any.proto"; import "google/protobuf/duration.proto"; import "google/protobuf/wrappers.proto"; -import "envoy/annotations/deprecation.proto"; import "udpa/annotations/migrate.proto"; import "udpa/annotations/security.proto"; import "udpa/annotations/status.proto"; @@ -35,7 +34,7 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // HTTP connection manager :ref:`configuration overview `. // [#extension: envoy.filters.network.http_connection_manager] -// [#next-free-field: 48] +// [#next-free-field: 49] message HttpConnectionManager { option (udpa.annotations.versioning).previous_message_type = "envoy.config.filter.network.http_connection_manager.v2.HttpConnectionManager"; @@ -371,6 +370,11 @@ message HttpConnectionManager { ServerHeaderTransformation server_header_transformation = 34 [(validate.rules).enum = {defined_only: true}]; + // Allows for explicit transformation of the :scheme header on the request path. + // If not set, Envoy's default :ref:`scheme ` + // handling applies. + config.core.v3.SchemeHeaderTransformation scheme_header_transformation = 48; + // The maximum request headers size for incoming connections. // If unconfigured, the default max request headers allowed is 60 KiB. // Requests that exceed this limit will receive a 431 response. @@ -496,23 +500,7 @@ message HttpConnectionManager { // determining the origin client's IP address. The default is zero if this option // is not specified. See the documentation for // :ref:`config_http_conn_man_headers_x-forwarded-for` for more information. - // - // .. note:: - // This field is deprecated and instead :ref:`original_ip_detection_extensions - // ` - // should be used to configure the :ref:`xff extension ` - // to configure IP detection using the :ref:`config_http_conn_man_headers_x-forwarded-for` header. To replace - // this field use a config like the following: - // - // .. code-block:: yaml - // - // original_ip_detection_extensions: - // typed_config: - // "@type": type.googleapis.com/envoy.extensions.http.original_ip_detection.xff.v3.XffConfig - // xff_num_trusted_hops: 1 - // - uint32 xff_num_trusted_hops = 19 - [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; + uint32 xff_num_trusted_hops = 19; // The configuration for the original IP detection extensions. // @@ -524,6 +512,12 @@ message HttpConnectionManager { // the request. If the request isn't rejected nor any extension succeeds, the HCM will // fallback to using the remote address. // + // .. WARNING:: + // Extensions cannot be used in conjunction with :ref:`use_remote_address + // ` + // nor :ref:`xff_num_trusted_hops + // `. + // // [#extension-category: envoy.http.original_ip_detection] repeated config.core.v3.TypedExtensionConfig original_ip_detection_extensions = 46; @@ -1000,3 +994,12 @@ message RequestIDExtension { // Request ID extension specific configuration. google.protobuf.Any typed_config = 1; } + +// [#protodoc-title: Envoy Mobile HTTP connection manager] +// HTTP connection manager for use in Envoy mobile. +// [#extension: envoy.filters.network.envoy_mobile_http_connection_manager] +message EnvoyMobileHttpConnectionManager { + // The configuration for the underlying HttpConnectionManager which will be + // instantiated for Envoy mobile. + HttpConnectionManager config = 1; +} diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/tls/v3/common.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/tls/v3/common.proto index aa05a31f23d..82dcb37cd7c 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/tls/v3/common.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/tls/v3/common.proto @@ -9,6 +9,7 @@ import "envoy/type/matcher/v3/string.proto"; import "google/protobuf/any.proto"; import "google/protobuf/wrappers.proto"; +import "udpa/annotations/migrate.proto"; import "udpa/annotations/sensitive.proto"; import "udpa/annotations/status.proto"; import "udpa/annotations/versioning.proto"; @@ -232,7 +233,27 @@ message TlsSessionTicketKeys { [(validate.rules).repeated = {min_items: 1}, (udpa.annotations.sensitive) = true]; } -// [#next-free-field: 13] +// Indicates a certificate to be obtained from a named CertificateProvider plugin instance. +// The plugin instances are defined in the client's bootstrap file. +// The plugin allows certificates to be fetched/refreshed over the network asynchronously with +// respect to the TLS handshake. +// [#not-implemented-hide:] +message CertificateProviderPluginInstance { + // Provider instance name. If not present, defaults to "default". + // + // Instance names should generally be defined not in terms of the underlying provider + // implementation (e.g., "file_watcher") but rather in terms of the function of the + // certificates (e.g., "foo_deployment_identity"). + string instance_name = 1; + + // Opaque name used to specify certificate instances or types. For example, "ROOTCA" to specify + // a root-certificate (validation context) or "example.com" to specify a certificate for a + // particular domain. Not all provider instances will actually use this field, so the value + // defaults to the empty string. + string certificate_name = 2; +} + +// [#next-free-field: 14] message CertificateValidationContext { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.auth.CertificateValidationContext"; @@ -279,7 +300,20 @@ message CertificateValidationContext { // directory for any file moves to support rotation. This currently only // applies to dynamic secrets, when the *CertificateValidationContext* is // delivered via SDS. - config.core.v3.DataSource trusted_ca = 1; + // + // Only one of *trusted_ca* and *ca_certificate_provider_instance* may be specified. + // + // [#next-major-version: This field and watched_directory below should ideally be moved into a + // separate sub-message, since there's no point in specifying the latter field without this one.] + config.core.v3.DataSource trusted_ca = 1 + [(udpa.annotations.field_migrate).oneof_promotion = "ca_cert_source"]; + + // Certificate provider instance for fetching TLS certificates. + // + // Only one of *trusted_ca* and *ca_certificate_provider_instance* may be specified. + // [#not-implemented-hide:] + CertificateProviderPluginInstance ca_certificate_provider_instance = 13 + [(udpa.annotations.field_migrate).oneof_promotion = "ca_cert_source"]; // If specified, updates of a file-based *trusted_ca* source will be triggered // by this watch. This allows explicit control over the path watched, by diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/tls/v3/tls.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/tls/v3/tls.proto index 02287de5875..f680207955a 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/tls/v3/tls.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/transport_sockets/tls/v3/tls.proto @@ -9,7 +9,7 @@ import "envoy/extensions/transport_sockets/tls/v3/secret.proto"; import "google/protobuf/duration.proto"; import "google/protobuf/wrappers.proto"; -import "udpa/annotations/migrate.proto"; +import "envoy/annotations/deprecation.proto"; import "udpa/annotations/status.proto"; import "udpa/annotations/versioning.proto"; import "validate/validate.proto"; @@ -125,12 +125,18 @@ message DownstreamTlsContext { } // TLS context shared by both client and server TLS contexts. -// [#next-free-field: 14] +// [#next-free-field: 15] message CommonTlsContext { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.auth.CommonTlsContext"; // Config for Certificate provider to get certificates. This provider should allow certificates to be // fetched/refreshed over the network asynchronously with respect to the TLS handshake. + // + // DEPRECATED: This message is not currently used, but if we ever do need it, we will want to + // move it out of CommonTlsContext and into common.proto, similar to the existing + // CertificateProviderPluginInstance message. + // + // [#not-implemented-hide:] message CertificateProvider { // opaque name used to specify certificate instances or types. For example, "ROOTCA" to specify // a root-certificate (validation context) or "TLS" to specify a new tls-certificate. @@ -151,6 +157,11 @@ message CommonTlsContext { // Similar to CertificateProvider above, but allows the provider instances to be configured on // the client side instead of being sent from the control plane. + // + // DEPRECATED: This message was moved outside of CommonTlsContext + // and now lives in common.proto. + // + // [#not-implemented-hide:] message CertificateProviderInstance { // Provider instance name. This name must be defined in the client's configuration (e.g., a // bootstrap file) to correspond to a provider instance (i.e., the same data in the typed_config @@ -179,26 +190,20 @@ message CommonTlsContext { // Config for fetching validation context via SDS API. Note SDS API allows certificates to be // fetched/refreshed over the network asynchronously with respect to the TLS handshake. - // Only one of validation_context_sds_secret_config, validation_context_certificate_provider, - // or validation_context_certificate_provider_instance may be used. - SdsSecretConfig validation_context_sds_secret_config = 2 [ - (validate.rules).message = {required: true}, - (udpa.annotations.field_migrate).oneof_promotion = "dynamic_validation_context" - ]; + SdsSecretConfig validation_context_sds_secret_config = 2 + [(validate.rules).message = {required: true}]; - // Certificate provider for fetching validation context. - // Only one of validation_context_sds_secret_config, validation_context_certificate_provider, - // or validation_context_certificate_provider_instance may be used. + // Certificate provider for fetching CA certs. This will populate the + // *default_validation_context.trusted_ca* field. // [#not-implemented-hide:] CertificateProvider validation_context_certificate_provider = 3 - [(udpa.annotations.field_migrate).oneof_promotion = "dynamic_validation_context"]; + [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; - // Certificate provider instance for fetching validation context. - // Only one of validation_context_sds_secret_config, validation_context_certificate_provider, - // or validation_context_certificate_provider_instance may be used. + // Certificate provider instance for fetching CA certs. This will populate the + // *default_validation_context.trusted_ca* field. // [#not-implemented-hide:] CertificateProviderInstance validation_context_certificate_provider_instance = 4 - [(udpa.annotations.field_migrate).oneof_promotion = "dynamic_validation_context"]; + [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; } reserved 5; @@ -212,6 +217,12 @@ message CommonTlsContext { // Only a single TLS certificate is supported in client contexts. In server contexts, the first // RSA certificate is used for clients that only support RSA and the first ECDSA certificate is // used for clients that support ECDSA. + // + // Only one of *tls_certificates*, *tls_certificate_sds_secret_configs*, + // and *tls_certificate_provider_instance* may be used. + // [#next-major-version: These mutually exclusive fields should ideally be in a oneof, but it's + // not legal to put a repeated field in a oneof. In the next major version, we should rework + // this to avoid this problem.] repeated TlsCertificate tls_certificates = 2; // Configs for fetching TLS certificates via SDS API. Note SDS API allows certificates to be @@ -220,18 +231,30 @@ message CommonTlsContext { // The same number and types of certificates as :ref:`tls_certificates ` // are valid in the the certificates fetched through this setting. // - // If :ref:`tls_certificates ` - // is non-empty, this field is ignored. + // Only one of *tls_certificates*, *tls_certificate_sds_secret_configs*, + // and *tls_certificate_provider_instance* may be used. + // [#next-major-version: These mutually exclusive fields should ideally be in a oneof, but it's + // not legal to put a repeated field in a oneof. In the next major version, we should rework + // this to avoid this problem.] repeated SdsSecretConfig tls_certificate_sds_secret_configs = 6 [(validate.rules).repeated = {max_items: 2}]; + // Certificate provider instance for fetching TLS certs. + // + // Only one of *tls_certificates*, *tls_certificate_sds_secret_configs*, + // and *tls_certificate_provider_instance* may be used. + // [#not-implemented-hide:] + CertificateProviderPluginInstance tls_certificate_provider_instance = 14; + // Certificate provider for fetching TLS certificates. // [#not-implemented-hide:] - CertificateProvider tls_certificate_certificate_provider = 9; + CertificateProvider tls_certificate_certificate_provider = 9 + [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; // Certificate provider instance for fetching TLS certificates. // [#not-implemented-hide:] - CertificateProviderInstance tls_certificate_certificate_provider_instance = 11; + CertificateProviderInstance tls_certificate_certificate_provider_instance = 11 + [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; oneof validation_context_type { // How to validate peer certificates. @@ -252,11 +275,13 @@ message CommonTlsContext { // Certificate provider for fetching validation context. // [#not-implemented-hide:] - CertificateProvider validation_context_certificate_provider = 10; + CertificateProvider validation_context_certificate_provider = 10 + [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; // Certificate provider instance for fetching validation context. // [#not-implemented-hide:] - CertificateProviderInstance validation_context_certificate_provider_instance = 12; + CertificateProviderInstance validation_context_certificate_provider_instance = 12 + [deprecated = true, (envoy.annotations.deprecated_at_minor_version) = "3.0"]; } // Supplies the list of ALPN protocols that the listener should expose. In diff --git a/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/string.proto b/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/string.proto index 78e1572bf8c..c64edde142f 100644 --- a/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/string.proto +++ b/xds/third_party/envoy/src/main/proto/envoy/type/matcher/v3/string.proto @@ -62,8 +62,8 @@ message StringMatcher { string contains = 7 [(validate.rules).string = {min_len: 1}]; } - // If true, indicates the exact/prefix/suffix matching should be case insensitive. This has no - // effect for the safe_regex match. + // If true, indicates the exact/prefix/suffix/contains matching should be case insensitive. This + // has no effect for the safe_regex match. // For example, the matcher *data* will match both input string *Data* and *data* if set to true. bool ignore_case = 6; }