diff --git a/.github/workflows/gradle-wrapper-validation.yml b/.github/workflows/gradle-wrapper-validation.yml index a5070c937c6..b827468d719 100644 --- a/.github/workflows/gradle-wrapper-validation.yml +++ b/.github/workflows/gradle-wrapper-validation.yml @@ -9,5 +9,5 @@ jobs: name: "Gradle wrapper validation" runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - uses: gradle/wrapper-validation-action@v1 diff --git a/.github/workflows/lock.yml b/.github/workflows/lock.yml index cb648e7fc73..ba37d4db4be 100644 --- a/.github/workflows/lock.yml +++ b/.github/workflows/lock.yml @@ -13,8 +13,8 @@ jobs: lock: runs-on: ubuntu-latest steps: - - uses: dessant/lock-threads@v2 + - uses: dessant/lock-threads@v3 with: github-token: ${{ github.token }} - issue-lock-inactive-days: 90 - pr-lock-inactive-days: 90 + issue-inactive-days: 90 + pr-inactive-days: 90 diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index 609d0841494..4788ebfc7f0 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -21,14 +21,14 @@ jobs: fail-fast: false # Should swap to true if we grow a large matrix steps: - - uses: actions/checkout@v2 - - uses: actions/setup-java@v2 + - uses: actions/checkout@v3 + - uses: actions/setup-java@v3 with: java-version: ${{ matrix.jre }} distribution: 'temurin' - name: Gradle cache - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: | ~/.gradle/caches @@ -37,7 +37,7 @@ jobs: restore-keys: | ${{ runner.os }}-gradle- - name: Maven cache - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: | ~/.m2/repository @@ -46,7 +46,7 @@ jobs: restore-keys: | ${{ runner.os }}-maven- - name: Protobuf cache - uses: actions/cache@v2 + uses: actions/cache@v3 with: path: /tmp/protobuf-cache key: ${{ runner.os }}-maven-${{ hashFiles('buildscripts/make_dependencies.sh') }} @@ -55,7 +55,7 @@ jobs: run: buildscripts/kokoro/unix.sh - name: Post Failure Upload Test Reports to Artifacts if: 
${{ failure() }} - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v3 with: name: Test Reports (JRE ${{ matrix.jre }}) path: ./*/build/reports/tests/** @@ -67,4 +67,4 @@ jobs: if: matrix.jre == 8 # Upload once, instead of for each job in the matrix run: ./gradlew :grpc-all:coveralls -x compileJava - name: Codecov - uses: codecov/codecov-action@v2 + uses: codecov/codecov-action@v3 diff --git a/COMPILING.md b/COMPILING.md index 1423ec88d78..3c5ad537e07 100644 --- a/COMPILING.md +++ b/COMPILING.md @@ -44,11 +44,11 @@ This section is only necessary if you are making changes to the code generation. Most users only need to use `skipCodegen=true` as discussed above. ### Build Protobuf -The codegen plugin is C++ code and requires protobuf 3.21.1 or later. +The codegen plugin is C++ code and requires protobuf 21.7 or later. For Linux, Mac and MinGW: ``` -$ PROTOBUF_VERSION=3.21.1 +$ PROTOBUF_VERSION=21.7 $ curl -LO https://github.com/protocolbuffers/protobuf/releases/download/v$PROTOBUF_VERSION/protobuf-all-$PROTOBUF_VERSION.tar.gz $ tar xzf protobuf-all-$PROTOBUF_VERSION.tar.gz $ cd protobuf-$PROTOBUF_VERSION diff --git a/README.md b/README.md index 9290082d96b..2c18763b180 100644 --- a/README.md +++ b/README.md @@ -31,8 +31,8 @@ For a guided tour, take a look at the [quick start guide](https://grpc.io/docs/languages/java/quickstart) or the more explanatory [gRPC basics](https://grpc.io/docs/languages/java/basics). 
-The [examples](https://github.com/grpc/grpc-java/tree/v1.49.0/examples) and the -[Android example](https://github.com/grpc/grpc-java/tree/v1.49.0/examples/android) +The [examples](https://github.com/grpc/grpc-java/tree/v1.51.0/examples) and the +[Android example](https://github.com/grpc/grpc-java/tree/v1.51.0/examples/android) are standalone projects that showcase the usage of gRPC. Download @@ -43,18 +43,18 @@ Download [the JARs][]. Or for Maven with non-Android, add to your `pom.xml`: io.grpc grpc-netty-shaded - 1.49.0 + 1.51.0 runtime io.grpc grpc-protobuf - 1.49.0 + 1.51.0 io.grpc grpc-stub - 1.49.0 + 1.51.0 org.apache.tomcat @@ -66,23 +66,23 @@ Download [the JARs][]. Or for Maven with non-Android, add to your `pom.xml`: Or for Gradle with non-Android, add to your dependencies: ```gradle -runtimeOnly 'io.grpc:grpc-netty-shaded:1.49.0' -implementation 'io.grpc:grpc-protobuf:1.49.0' -implementation 'io.grpc:grpc-stub:1.49.0' +runtimeOnly 'io.grpc:grpc-netty-shaded:1.51.0' +implementation 'io.grpc:grpc-protobuf:1.51.0' +implementation 'io.grpc:grpc-stub:1.51.0' compileOnly 'org.apache.tomcat:annotations-api:6.0.53' // necessary for Java 9+ ``` For Android client, use `grpc-okhttp` instead of `grpc-netty-shaded` and `grpc-protobuf-lite` instead of `grpc-protobuf`: ```gradle -implementation 'io.grpc:grpc-okhttp:1.49.0' -implementation 'io.grpc:grpc-protobuf-lite:1.49.0' -implementation 'io.grpc:grpc-stub:1.49.0' +implementation 'io.grpc:grpc-okhttp:1.51.0' +implementation 'io.grpc:grpc-protobuf-lite:1.51.0' +implementation 'io.grpc:grpc-stub:1.51.0' compileOnly 'org.apache.tomcat:annotations-api:6.0.53' // necessary for Java 9+ ``` [the JARs]: -https://search.maven.org/search?q=g:io.grpc%20AND%20v:1.49.0 +https://search.maven.org/search?q=g:io.grpc%20AND%20v:1.51.0 Development snapshots are available in 
[Sonatypes's snapshot repository](https://oss.sonatype.org/content/repositories/snapshots/). @@ -112,9 +112,9 @@ For protobuf-based codegen integrated with the Maven build system, you can use protobuf-maven-plugin 0.6.1 - com.google.protobuf:protoc:3.21.1:exe:${os.detected.classifier} + com.google.protobuf:protoc:3.21.7:exe:${os.detected.classifier} grpc-java - io.grpc:protoc-gen-grpc-java:1.49.0:exe:${os.detected.classifier} + io.grpc:protoc-gen-grpc-java:1.51.0:exe:${os.detected.classifier} @@ -140,11 +140,11 @@ plugins { protobuf { protoc { - artifact = "com.google.protobuf:protoc:3.21.1" + artifact = "com.google.protobuf:protoc:3.21.7" } plugins { grpc { - artifact = 'io.grpc:protoc-gen-grpc-java:1.49.0' + artifact = 'io.grpc:protoc-gen-grpc-java:1.51.0' } } generateProtoTasks { @@ -173,11 +173,11 @@ plugins { protobuf { protoc { - artifact = "com.google.protobuf:protoc:3.21.1" + artifact = "com.google.protobuf:protoc:3.21.7" } plugins { grpc { - artifact = 'io.grpc:protoc-gen-grpc-java:1.49.0' + artifact = 'io.grpc:protoc-gen-grpc-java:1.51.0' } } generateProtoTasks { diff --git a/api/src/main/java/io/grpc/NameResolver.java b/api/src/main/java/io/grpc/NameResolver.java index 62ec4d167a4..cfe2e934c0f 100644 --- a/api/src/main/java/io/grpc/NameResolver.java +++ b/api/src/main/java/io/grpc/NameResolver.java @@ -73,7 +73,9 @@ public abstract class NameResolver { public abstract String getServiceAuthority(); /** - * Starts the resolution. + * Starts the resolution. The method is not supposed to throw any exceptions. That might cause the + * Channel that the name resolver is serving to crash. Errors should be propagated + * through {@link Listener#onError}. * * @param listener used to receive updates on the target * @since 1.0.0 @@ -97,7 +99,9 @@ public void onResult(ResolutionResult resolutionResult) { } /** - * Starts the resolution. + * Starts the resolution. The method is not supposed to throw any exceptions. 
That might cause the + * Channel that the name resolver is serving to crash. Errors should be propagated + * through {@link Listener2#onError}. * * @param listener used to receive updates on the target * @since 1.21.0 diff --git a/build.gradle b/build.gradle index a5243b0d573..0c0861d44d1 100644 --- a/build.gradle +++ b/build.gradle @@ -20,7 +20,7 @@ subprojects { apply plugin: "net.ltgt.errorprone" group = "io.grpc" - version = "1.50.0-SNAPSHOT" // CURRENT_GRPC_VERSION + version = "1.51.0" // CURRENT_GRPC_VERSION repositories { maven { // The google mirror is less flaky than mavenCentral() @@ -300,7 +300,7 @@ subprojects { // depends on core; core's testCompile depends on testing) includeTests = false if (project.hasProperty('jmhIncludeSingleClass')) { - include = [ + includes = [ project.property('jmhIncludeSingleClass') ] } diff --git a/buildscripts/make_dependencies.bat b/buildscripts/make_dependencies.bat index 30e8dd548a1..2bbfd394d46 100644 --- a/buildscripts/make_dependencies.bat +++ b/buildscripts/make_dependencies.bat @@ -1,6 +1,4 @@ -set PROTOBUF_VER=21.1 -@rem Workaround https://github.com/protocolbuffers/protobuf/issues/10172 -set PROTOBUF_VER_ISSUE_10172=3.%PROTOBUF_VER% +set PROTOBUF_VER=21.7 set CMAKE_NAME=cmake-3.3.2-win32-x86 if not exist "protobuf-%PROTOBUF_VER%\build\Release\" ( @@ -25,7 +23,6 @@ set PATH=%PATH%;%cd%\%CMAKE_NAME%\bin powershell -command "$ErrorActionPreference = 'stop'; & { [Net.ServicePointManager]::SecurityProtocol = [Net.SecurityProtocolType]::Tls12 ; iwr https://github.com/google/protobuf/archive/v%PROTOBUF_VER%.zip -OutFile protobuf.zip }" || exit /b 1 powershell -command "$ErrorActionPreference = 'stop'; & { Add-Type -AssemblyName System.IO.Compression.FileSystem; [System.IO.Compression.ZipFile]::ExtractToDirectory('protobuf.zip', '.') }" || exit /b 1 del protobuf.zip -rename protobuf-%PROTOBUF_VER_ISSUE_10172% protobuf-%PROTOBUF_VER% mkdir protobuf-%PROTOBUF_VER%\build 
pushd protobuf-%PROTOBUF_VER%\build diff --git a/buildscripts/make_dependencies.sh b/buildscripts/make_dependencies.sh index 8f0033e8a75..29eed3d5050 100755 --- a/buildscripts/make_dependencies.sh +++ b/buildscripts/make_dependencies.sh @@ -3,9 +3,7 @@ # Build protoc set -evux -o pipefail -PROTOBUF_VERSION=21.1 -# https://github.com/protocolbuffers/protobuf/issues/10172 -PROTOBUF_VERSION_ISSUE_10172=3.$PROTOBUF_VERSION +PROTOBUF_VERSION=21.7 # ARCH is x86_64 bit unless otherwise specified. ARCH="${ARCH:-x86_64}" @@ -30,7 +28,6 @@ if [ -f ${INSTALL_DIR}/bin/protoc ]; then else if [[ ! -d "$DOWNLOAD_DIR"/protobuf-"${PROTOBUF_VERSION}" ]]; then curl -Ls https://github.com/google/protobuf/releases/download/v${PROTOBUF_VERSION}/protobuf-all-${PROTOBUF_VERSION}.tar.gz | tar xz -C $DOWNLOAD_DIR - mv "$DOWNLOAD_DIR/protobuf-${PROTOBUF_VERSION_ISSUE_10172}" "$DOWNLOAD_DIR/protobuf-${PROTOBUF_VERSION}" fi pushd $DOWNLOAD_DIR/protobuf-${PROTOBUF_VERSION} # install here so we don't need sudo diff --git a/compiler/src/test/golden/TestDeprecatedService.java.txt b/compiler/src/test/golden/TestDeprecatedService.java.txt index 9fddc958d81..6229341887a 100644 --- a/compiler/src/test/golden/TestDeprecatedService.java.txt +++ b/compiler/src/test/golden/TestDeprecatedService.java.txt @@ -8,7 +8,7 @@ import static io.grpc.MethodDescriptor.generateFullMethodName; * */ @javax.annotation.Generated( - value = "by gRPC proto compiler (version 1.50.0-SNAPSHOT)", + value = "by gRPC proto compiler (version 1.51.0)", comments = "Source: grpc/testing/compiler/test.proto") @io.grpc.stub.annotations.GrpcGenerated @java.lang.Deprecated diff --git a/compiler/src/test/golden/TestService.java.txt b/compiler/src/test/golden/TestService.java.txt index 0bbe6f65ddc..e99bb98026e 100644 --- a/compiler/src/test/golden/TestService.java.txt +++ b/compiler/src/test/golden/TestService.java.txt @@ -8,7 +8,7 @@ import static 
io.grpc.MethodDescriptor.generateFullMethodName; * */ @javax.annotation.Generated( - value = "by gRPC proto compiler (version 1.50.0-SNAPSHOT)", + value = "by gRPC proto compiler (version 1.51.0)", comments = "Source: grpc/testing/compiler/test.proto") @io.grpc.stub.annotations.GrpcGenerated public final class TestServiceGrpc { diff --git a/compiler/src/testLite/golden/TestDeprecatedService.java.txt b/compiler/src/testLite/golden/TestDeprecatedService.java.txt index b13295c37b4..b783fcf00c3 100644 --- a/compiler/src/testLite/golden/TestDeprecatedService.java.txt +++ b/compiler/src/testLite/golden/TestDeprecatedService.java.txt @@ -8,7 +8,7 @@ import static io.grpc.MethodDescriptor.generateFullMethodName; * */ @javax.annotation.Generated( - value = "by gRPC proto compiler (version 1.50.0-SNAPSHOT)", + value = "by gRPC proto compiler (version 1.51.0)", comments = "Source: grpc/testing/compiler/test.proto") @io.grpc.stub.annotations.GrpcGenerated @java.lang.Deprecated diff --git a/compiler/src/testLite/golden/TestService.java.txt b/compiler/src/testLite/golden/TestService.java.txt index d2f9c12e923..8f2c0e5a0c0 100644 --- a/compiler/src/testLite/golden/TestService.java.txt +++ b/compiler/src/testLite/golden/TestService.java.txt @@ -8,7 +8,7 @@ import static io.grpc.MethodDescriptor.generateFullMethodName; * */ @javax.annotation.Generated( - value = "by gRPC proto compiler (version 1.50.0-SNAPSHOT)", + value = "by gRPC proto compiler (version 1.51.0)", comments = "Source: grpc/testing/compiler/test.proto") @io.grpc.stub.annotations.GrpcGenerated public final class TestServiceGrpc { diff --git a/core/src/main/java/io/grpc/internal/GrpcUtil.java b/core/src/main/java/io/grpc/internal/GrpcUtil.java index 6564862d8af..735b43d5f42 100644 --- a/core/src/main/java/io/grpc/internal/GrpcUtil.java +++ b/core/src/main/java/io/grpc/internal/GrpcUtil.java @@ -217,7 +217,7 @@ public byte[] parseAsciiString(byte[] serialized) { public static final Splitter ACCEPT_ENCODING_SPLITTER = 
Splitter.on(',').trimResults(); - private static final String IMPLEMENTATION_VERSION = "1.50.0-SNAPSHOT"; // CURRENT_GRPC_VERSION + private static final String IMPLEMENTATION_VERSION = "1.51.0"; // CURRENT_GRPC_VERSION /** * The default timeout in nanos for a keepalive ping request. diff --git a/core/src/main/java/io/grpc/internal/ManagedChannelOrphanWrapper.java b/core/src/main/java/io/grpc/internal/ManagedChannelOrphanWrapper.java index 542e84b9c8b..aed3a461fb4 100644 --- a/core/src/main/java/io/grpc/internal/ManagedChannelOrphanWrapper.java +++ b/core/src/main/java/io/grpc/internal/ManagedChannelOrphanWrapper.java @@ -29,6 +29,13 @@ import java.util.logging.LogRecord; import java.util.logging.Logger; +/** + * Best effort detecting channels that has not been properly cleaned up. + * Use {@link WeakReference} to avoid keeping the channel alive and retaining too much memory. + * Check lost references only on new channel creation and log message to indicate + * the previous channel (id and target) that has not been shutdown. This is done to avoid Object + * finalizers. + */ final class ManagedChannelOrphanWrapper extends ForwardingManagedChannel { private static final ReferenceQueue refqueue = new ReferenceQueue<>(); @@ -148,7 +155,7 @@ static int cleanQueue(ReferenceQueue refqueue) { Level level = Level.SEVERE; if (logger.isLoggable(level)) { String fmt = - "*~*~*~ Channel {0} was not shutdown properly!!! ~*~*~*" + "*~*~*~ Previous channel {0} was not shutdown properly!!! 
~*~*~*" + System.getProperty("line.separator") + " Make sure to call shutdown()/shutdownNow() and wait " + "until awaitTermination() returns true."; diff --git a/core/src/main/java/io/grpc/util/RoundRobinLoadBalancer.java b/core/src/main/java/io/grpc/util/RoundRobinLoadBalancer.java index 4a6a1ff5611..b715f756144 100644 --- a/core/src/main/java/io/grpc/util/RoundRobinLoadBalancer.java +++ b/core/src/main/java/io/grpc/util/RoundRobinLoadBalancer.java @@ -69,7 +69,14 @@ final class RoundRobinLoadBalancer extends LoadBalancer { } @Override - public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { + public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { + if (resolvedAddresses.getAddresses().isEmpty()) { + handleNameResolutionError(Status.UNAVAILABLE.withDescription( + "NameResolver returned no usable address. addrs=" + resolvedAddresses.getAddresses() + + ", attrs=" + resolvedAddresses.getAttributes())); + return false; + } + List servers = resolvedAddresses.getAddresses(); Set currentAddrs = subchannels.keySet(); Map latestAddrs = stripAttrs(servers); @@ -126,6 +133,8 @@ public void onSubchannelState(ConnectivityStateInfo state) { for (Subchannel removedSubchannel : removedSubchannels) { shutdownSubchannel(removedSubchannel); } + + return true; } @Override diff --git a/core/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java b/core/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java index 35b04d59825..d4c07e3d50e 100644 --- a/core/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java +++ b/core/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java @@ -49,7 +49,6 @@ import io.grpc.LoadBalancer; import io.grpc.LoadBalancer.CreateSubchannelArgs; import io.grpc.LoadBalancer.Helper; -import io.grpc.LoadBalancer.PickResult; import io.grpc.LoadBalancer.PickSubchannelArgs; import io.grpc.LoadBalancer.ResolvedAddresses; import io.grpc.LoadBalancer.Subchannel; @@ -148,8 +147,9 @@ public void tearDown() throws 
Exception { @Test public void pickAfterResolved() throws Exception { final Subchannel readySubchannel = subchannels.values().iterator().next(); - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(affinity).build()); + assertThat(addressesAccepted).isTrue(); deliverSubchannelState(readySubchannel, ConnectivityStateInfo.forNonError(READY)); verify(mockHelper, times(3)).createSubchannel(createArgsCaptor.capture()); @@ -199,9 +199,10 @@ public void pickAfterResolvedUpdatedHosts() throws Exception { InOrder inOrder = inOrder(mockHelper); - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(currentServers).setAttributes(affinity) .build()); + assertThat(addressesAccepted).isTrue(); inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); @@ -221,8 +222,9 @@ public void pickAfterResolvedUpdatedHosts() throws Exception { // This time with Attributes List latestServers = Lists.newArrayList(oldEag2, newEag); - loadBalancer.handleResolvedAddresses( + addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(latestServers).setAttributes(affinity).build()); + assertThat(addressesAccepted).isTrue(); verify(newSubchannel, times(1)).requestConnection(); verify(oldSubchannel, times(1)).updateAddresses(Arrays.asList(oldEag2)); @@ -240,25 +242,16 @@ public void pickAfterResolvedUpdatedHosts() throws Exception { picker = pickerCaptor.getValue(); assertThat(getList(picker)).containsExactly(oldSubchannel, newSubchannel); - // test going from non-empty to empty - loadBalancer.handleResolvedAddresses( - ResolvedAddresses.newBuilder() - .setAddresses(Collections.emptyList()) - .setAttributes(affinity) - .build()); - - inOrder.verify(mockHelper).updateBalancingState(eq(TRANSIENT_FAILURE), 
pickerCaptor.capture()); - assertEquals(PickResult.withNoResult(), pickerCaptor.getValue().pickSubchannel(mockArgs)); - verifyNoMoreInteractions(mockHelper); } @Test public void pickAfterStateChange() throws Exception { InOrder inOrder = inOrder(mockHelper); - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY) .build()); + assertThat(addressesAccepted).isTrue(); Subchannel subchannel = loadBalancer.getSubchannels().iterator().next(); Ref subchannelStateInfo = subchannel.getAttributes().get( STATE_INFO); @@ -296,9 +289,10 @@ public void pickAfterStateChange() throws Exception { @Test public void ignoreShutdownSubchannelStateChange() { InOrder inOrder = inOrder(mockHelper); - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY) .build()); + assertThat(addressesAccepted).isTrue(); inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); loadBalancer.shutdown(); @@ -315,9 +309,10 @@ public void ignoreShutdownSubchannelStateChange() { @Test public void stayTransientFailureUntilReady() { InOrder inOrder = inOrder(mockHelper); - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY) .build()); + assertThat(addressesAccepted).isTrue(); inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); @@ -353,9 +348,10 @@ public void stayTransientFailureUntilReady() { @Test public void refreshNameResolutionWhenSubchannelConnectionBroken() { InOrder inOrder = inOrder(mockHelper); - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( 
ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY) .build()); + assertThat(addressesAccepted).isTrue(); verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); @@ -420,8 +416,9 @@ public void nameResolutionErrorWithNoChannels() throws Exception { @Test public void nameResolutionErrorWithActiveChannels() throws Exception { final Subchannel readySubchannel = subchannels.values().iterator().next(); - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(affinity).build()); + assertThat(addressesAccepted).isTrue(); deliverSubchannelState(readySubchannel, ConnectivityStateInfo.forNonError(READY)); loadBalancer.handleNameResolutionError(Status.NOT_FOUND.withDescription("nameResolutionError")); @@ -449,9 +446,10 @@ public void subchannelStateIsolation() throws Exception { Subchannel sc2 = subchannelIterator.next(); Subchannel sc3 = subchannelIterator.next(); - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY) .build()); + assertThat(addressesAccepted).isTrue(); verify(sc1, times(1)).requestConnection(); verify(sc2, times(1)).requestConnection(); verify(sc3, times(1)).requestConnection(); @@ -522,6 +520,15 @@ public void internalPickerComparisons() { assertFalse(ready1.isEquivalentTo(emptyOk1)); } + @Test + public void emptyAddresses() { + assertThat(loadBalancer.acceptResolvedAddresses( + ResolvedAddresses.newBuilder() + .setAddresses(Collections.emptyList()) + .setAttributes(affinity) + .build())).isFalse(); + } + private static List getList(SubchannelPicker picker) { return picker instanceof ReadyPicker ? 
((ReadyPicker) picker).getList() : Collections.emptyList(); diff --git a/cronet/README.md b/cronet/README.md index 3c767c4e972..b29d19c99b6 100644 --- a/cronet/README.md +++ b/cronet/README.md @@ -26,7 +26,7 @@ In your app module's `build.gradle` file, include a dependency on both `grpc-cro Google Play Services Client Library for Cronet ``` -implementation 'io.grpc:grpc-cronet:1.49.0' +implementation 'io.grpc:grpc-cronet:1.51.0' implementation 'com.google.android.gms:play-services-cronet:16.0.0' ``` diff --git a/documentation/android-channel-builder.md b/documentation/android-channel-builder.md index 70149fc9b2c..82bfff47e91 100644 --- a/documentation/android-channel-builder.md +++ b/documentation/android-channel-builder.md @@ -36,8 +36,8 @@ In your `build.gradle` file, include a dependency on both `grpc-android` and `grpc-okhttp`: ``` -implementation 'io.grpc:grpc-android:1.49.0' -implementation 'io.grpc:grpc-okhttp:1.49.0' +implementation 'io.grpc:grpc-android:1.51.0' +implementation 'io.grpc:grpc-okhttp:1.51.0' ``` You also need permission to access the device's network state in your diff --git a/examples/android/clientcache/app/build.gradle b/examples/android/clientcache/app/build.gradle index 78cbdd305e0..f21bda07313 100644 --- a/examples/android/clientcache/app/build.gradle +++ b/examples/android/clientcache/app/build.gradle @@ -32,9 +32,9 @@ android { } protobuf { - protoc { artifact = 'com.google.protobuf:protoc:3.21.1' } + protoc { artifact = 'com.google.protobuf:protoc:3.21.7' } plugins { - grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.50.0-SNAPSHOT' // CURRENT_GRPC_VERSION + grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.51.0' // CURRENT_GRPC_VERSION } } generateProtoTasks { @@ -54,12 +54,12 @@ dependencies { implementation 'com.android.support:appcompat-v7:27.0.2' // You need to build grpc-java to obtain these libraries below. 
- implementation 'io.grpc:grpc-okhttp:1.50.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-protobuf-lite:1.50.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-stub:1.50.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-okhttp:1.51.0' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-protobuf-lite:1.51.0' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-stub:1.51.0' // CURRENT_GRPC_VERSION implementation 'org.apache.tomcat:annotations-api:6.0.53' testImplementation 'junit:junit:4.12' testImplementation 'com.google.truth:truth:1.0.1' - testImplementation 'io.grpc:grpc-testing:1.50.0-SNAPSHOT' // CURRENT_GRPC_VERSION + testImplementation 'io.grpc:grpc-testing:1.51.0' // CURRENT_GRPC_VERSION } diff --git a/examples/android/helloworld/app/build.gradle b/examples/android/helloworld/app/build.gradle index f30a1110f12..1b121c6ab94 100644 --- a/examples/android/helloworld/app/build.gradle +++ b/examples/android/helloworld/app/build.gradle @@ -30,9 +30,9 @@ android { } protobuf { - protoc { artifact = 'com.google.protobuf:protoc:3.21.1' } + protoc { artifact = 'com.google.protobuf:protoc:3.21.7' } plugins { - grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.50.0-SNAPSHOT' // CURRENT_GRPC_VERSION + grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.51.0' // CURRENT_GRPC_VERSION } } generateProtoTasks { @@ -52,8 +52,8 @@ dependencies { implementation 'com.android.support:appcompat-v7:27.0.2' // You need to build grpc-java to obtain these libraries below. 
- implementation 'io.grpc:grpc-okhttp:1.50.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-protobuf-lite:1.50.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-stub:1.50.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-okhttp:1.51.0' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-protobuf-lite:1.51.0' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-stub:1.51.0' // CURRENT_GRPC_VERSION implementation 'org.apache.tomcat:annotations-api:6.0.53' } diff --git a/examples/android/routeguide/app/build.gradle b/examples/android/routeguide/app/build.gradle index 5c861238223..2529a491f3c 100644 --- a/examples/android/routeguide/app/build.gradle +++ b/examples/android/routeguide/app/build.gradle @@ -30,9 +30,9 @@ android { } protobuf { - protoc { artifact = 'com.google.protobuf:protoc:3.21.1' } + protoc { artifact = 'com.google.protobuf:protoc:3.21.7' } plugins { - grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.50.0-SNAPSHOT' // CURRENT_GRPC_VERSION + grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.51.0' // CURRENT_GRPC_VERSION } } generateProtoTasks { @@ -52,8 +52,8 @@ dependencies { implementation 'com.android.support:appcompat-v7:27.0.2' // You need to build grpc-java to obtain these libraries below. 
- implementation 'io.grpc:grpc-okhttp:1.50.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-protobuf-lite:1.50.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-stub:1.50.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-okhttp:1.51.0' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-protobuf-lite:1.51.0' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-stub:1.51.0' // CURRENT_GRPC_VERSION implementation 'org.apache.tomcat:annotations-api:6.0.53' } diff --git a/examples/android/strictmode/app/build.gradle b/examples/android/strictmode/app/build.gradle index 788c015823a..98d526f3291 100644 --- a/examples/android/strictmode/app/build.gradle +++ b/examples/android/strictmode/app/build.gradle @@ -31,9 +31,9 @@ android { } protobuf { - protoc { artifact = 'com.google.protobuf:protoc:3.21.1' } + protoc { artifact = 'com.google.protobuf:protoc:3.21.7' } plugins { - grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.50.0-SNAPSHOT' // CURRENT_GRPC_VERSION + grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.51.0' // CURRENT_GRPC_VERSION } } generateProtoTasks { @@ -53,8 +53,8 @@ dependencies { implementation 'com.android.support:appcompat-v7:28.0.0' // You need to build grpc-java to obtain these libraries below. 
- implementation 'io.grpc:grpc-okhttp:1.50.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-protobuf-lite:1.50.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-stub:1.50.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-okhttp:1.51.0' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-protobuf-lite:1.51.0' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-stub:1.51.0' // CURRENT_GRPC_VERSION implementation 'org.apache.tomcat:annotations-api:6.0.53' } diff --git a/examples/build.gradle b/examples/build.gradle index 9f779ff819a..4d7c9e4cf33 100644 --- a/examples/build.gradle +++ b/examples/build.gradle @@ -22,8 +22,8 @@ targetCompatibility = 1.8 // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.50.0-SNAPSHOT' // CURRENT_GRPC_VERSION -def protobufVersion = '3.21.1' +def grpcVersion = '1.51.0' // CURRENT_GRPC_VERSION +def protobufVersion = '3.21.7' def protocVersion = protobufVersion dependencies { diff --git a/examples/example-alts/build.gradle b/examples/example-alts/build.gradle index 4d44a130d4b..1c7dac5ef39 100644 --- a/examples/example-alts/build.gradle +++ b/examples/example-alts/build.gradle @@ -23,8 +23,8 @@ targetCompatibility = 1.8 // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. 
-def grpcVersion = '1.50.0-SNAPSHOT' // CURRENT_GRPC_VERSION -def protocVersion = '3.21.1' +def grpcVersion = '1.51.0' // CURRENT_GRPC_VERSION +def protocVersion = '3.21.7' dependencies { // grpc-alts transitively depends on grpc-netty-shaded, grpc-protobuf, and grpc-stub diff --git a/examples/example-gauth/build.gradle b/examples/example-gauth/build.gradle index 64b870daf28..ca3a1417140 100644 --- a/examples/example-gauth/build.gradle +++ b/examples/example-gauth/build.gradle @@ -23,8 +23,8 @@ targetCompatibility = 1.8 // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.50.0-SNAPSHOT' // CURRENT_GRPC_VERSION -def protobufVersion = '3.21.1' +def grpcVersion = '1.51.0' // CURRENT_GRPC_VERSION +def protobufVersion = '3.21.7' def protocVersion = protobufVersion diff --git a/examples/example-gauth/pom.xml b/examples/example-gauth/pom.xml index 8c9c8673d3f..be1bb554627 100644 --- a/examples/example-gauth/pom.xml +++ b/examples/example-gauth/pom.xml @@ -6,14 +6,14 @@ jar - 1.50.0-SNAPSHOT + 1.51.0 example-gauth https://github.com/grpc/grpc-java UTF-8 - 1.50.0-SNAPSHOT - 3.21.1 + 1.51.0 + 3.21.7 1.7 1.7 diff --git a/examples/example-hostname/build.gradle b/examples/example-hostname/build.gradle index e9db5ef7c61..cbcbcdca3f7 100644 --- a/examples/example-hostname/build.gradle +++ b/examples/example-hostname/build.gradle @@ -21,8 +21,8 @@ targetCompatibility = 1.8 // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. 
-def grpcVersion = '1.50.0-SNAPSHOT' // CURRENT_GRPC_VERSION -def protobufVersion = '3.21.1' +def grpcVersion = '1.51.0' // CURRENT_GRPC_VERSION +def protobufVersion = '3.21.7' dependencies { implementation "io.grpc:grpc-protobuf:${grpcVersion}" diff --git a/examples/example-hostname/pom.xml b/examples/example-hostname/pom.xml index 9ca5bfbacda..ae803c96edb 100644 --- a/examples/example-hostname/pom.xml +++ b/examples/example-hostname/pom.xml @@ -6,14 +6,14 @@ jar - 1.50.0-SNAPSHOT + 1.51.0 example-hostname https://github.com/grpc/grpc-java UTF-8 - 1.50.0-SNAPSHOT - 3.21.1 + 1.51.0 + 3.21.7 1.7 1.7 diff --git a/examples/example-jwt-auth/build.gradle b/examples/example-jwt-auth/build.gradle index 1d892bfb0d7..d90a993698e 100644 --- a/examples/example-jwt-auth/build.gradle +++ b/examples/example-jwt-auth/build.gradle @@ -22,8 +22,8 @@ targetCompatibility = 1.8 // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. 
-def grpcVersion = '1.50.0-SNAPSHOT' // CURRENT_GRPC_VERSION -def protobufVersion = '3.21.1' +def grpcVersion = '1.51.0' // CURRENT_GRPC_VERSION +def protobufVersion = '3.21.7' def protocVersion = protobufVersion dependencies { diff --git a/examples/example-jwt-auth/pom.xml b/examples/example-jwt-auth/pom.xml index b57eb8a801b..dccaf31ea4f 100644 --- a/examples/example-jwt-auth/pom.xml +++ b/examples/example-jwt-auth/pom.xml @@ -7,15 +7,15 @@ jar - 1.50.0-SNAPSHOT + 1.51.0 example-jwt-auth https://github.com/grpc/grpc-java UTF-8 - 1.50.0-SNAPSHOT - 3.21.1 - 3.21.1 + 1.51.0 + 3.21.7 + 3.21.7 1.7 1.7 diff --git a/examples/example-orca/build.gradle b/examples/example-orca/build.gradle index 2bbfc97e834..80eccee6d24 100644 --- a/examples/example-orca/build.gradle +++ b/examples/example-orca/build.gradle @@ -17,8 +17,8 @@ repositories { sourceCompatibility = 1.8 targetCompatibility = 1.8 -def grpcVersion = '1.50.0-SNAPSHOT' // CURRENT_GRPC_VERSION -def protocVersion = '3.21.1' +def grpcVersion = '1.51.0' // CURRENT_GRPC_VERSION +def protocVersion = '3.21.7' dependencies { implementation "io.grpc:grpc-protobuf:${grpcVersion}" diff --git a/examples/example-tls/build.gradle b/examples/example-tls/build.gradle index 895d6c98a7c..df8381481ba 100644 --- a/examples/example-tls/build.gradle +++ b/examples/example-tls/build.gradle @@ -23,8 +23,8 @@ targetCompatibility = 1.8 // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. 
-def grpcVersion = '1.50.0-SNAPSHOT' // CURRENT_GRPC_VERSION -def protocVersion = '3.21.1' +def grpcVersion = '1.51.0' // CURRENT_GRPC_VERSION +def protocVersion = '3.21.7' dependencies { implementation "io.grpc:grpc-protobuf:${grpcVersion}" diff --git a/examples/example-tls/pom.xml b/examples/example-tls/pom.xml index 42b6d2d4594..60cfb44a9d3 100644 --- a/examples/example-tls/pom.xml +++ b/examples/example-tls/pom.xml @@ -6,14 +6,14 @@ jar - 1.50.0-SNAPSHOT + 1.51.0 example-tls https://github.com/grpc/grpc-java UTF-8 - 1.50.0-SNAPSHOT - 3.21.1 + 1.51.0 + 3.21.7 2.0.54.Final 1.7 diff --git a/examples/example-xds/build.gradle b/examples/example-xds/build.gradle index 5f48215e876..f9048a91180 100644 --- a/examples/example-xds/build.gradle +++ b/examples/example-xds/build.gradle @@ -22,9 +22,9 @@ targetCompatibility = 1.8 // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.50.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.51.0' // CURRENT_GRPC_VERSION def nettyTcNativeVersion = '2.0.31.Final' -def protocVersion = '3.21.1' +def protocVersion = '3.21.7' dependencies { implementation "io.grpc:grpc-protobuf:${grpcVersion}" diff --git a/examples/pom.xml b/examples/pom.xml index 7aa05b7db95..91f3e174566 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -6,15 +6,15 @@ jar - 1.50.0-SNAPSHOT + 1.51.0 examples https://github.com/grpc/grpc-java UTF-8 - 1.50.0-SNAPSHOT - 3.21.1 - 3.21.1 + 1.51.0 + 3.21.7 + 3.21.7 1.7 1.7 diff --git a/gcp-observability/src/main/java/io/grpc/gcp/observability/GcpObservability.java b/gcp-observability/src/main/java/io/grpc/gcp/observability/GcpObservability.java index 6587444db59..770d764e0cd 100644 --- a/gcp-observability/src/main/java/io/grpc/gcp/observability/GcpObservability.java +++ b/gcp-observability/src/main/java/io/grpc/gcp/observability/GcpObservability.java @@ -35,7 +35,6 @@ 
import io.grpc.gcp.observability.interceptors.LogHelper; import io.grpc.gcp.observability.logging.GcpLogSink; import io.grpc.gcp.observability.logging.Sink; -import io.grpc.internal.TimeProvider; import io.opencensus.common.Duration; import io.opencensus.contrib.grpc.metrics.RpcViewConstants; import io.opencensus.exporter.stats.stackdriver.StackdriverStatsConfiguration; @@ -77,15 +76,15 @@ public static synchronized GcpObservability grpcInit() throws IOException { if (instance == null) { GlobalLocationTags globalLocationTags = new GlobalLocationTags(); ObservabilityConfigImpl observabilityConfig = ObservabilityConfigImpl.getInstance(); - Sink sink = new GcpLogSink(observabilityConfig.getDestinationProjectId(), + Sink sink = new GcpLogSink(observabilityConfig.getProjectId(), globalLocationTags.getLocationTags(), observabilityConfig.getCustomTags(), - observabilityConfig.getFlushMessageCount(), SERVICES_TO_EXCLUDE); - LogHelper helper = new LogHelper(sink, TimeProvider.SYSTEM_TIME_PROVIDER); - ConfigFilterHelper configFilterHelper = ConfigFilterHelper.factory(observabilityConfig); + SERVICES_TO_EXCLUDE); + LogHelper helper = new LogHelper(sink); + ConfigFilterHelper configFilterHelper = ConfigFilterHelper.getInstance(observabilityConfig); instance = grpcInit(sink, observabilityConfig, new InternalLoggingChannelInterceptor.FactoryImpl(helper, configFilterHelper), new InternalLoggingServerInterceptor.FactoryImpl(helper, configFilterHelper)); - instance.registerStackDriverExporter(observabilityConfig.getDestinationProjectId(), + instance.registerStackDriverExporter(observabilityConfig.getProjectId(), observabilityConfig.getCustomTags()); } return instance; diff --git a/gcp-observability/src/main/java/io/grpc/gcp/observability/ObservabilityConfig.java b/gcp-observability/src/main/java/io/grpc/gcp/observability/ObservabilityConfig.java index 48dd480973b..0489c8b5e3b 100644 --- a/gcp-observability/src/main/java/io/grpc/gcp/observability/ObservabilityConfig.java +++ 
b/gcp-observability/src/main/java/io/grpc/gcp/observability/ObservabilityConfig.java @@ -17,10 +17,11 @@ package io.grpc.gcp.observability; import io.grpc.Internal; -import io.grpc.observabilitylog.v1.GrpcLogRecord.EventType; import io.opencensus.trace.Sampler; import java.util.List; import java.util.Map; +import java.util.Set; +import javax.annotation.concurrent.ThreadSafe; @Internal public interface ObservabilityConfig { @@ -33,17 +34,14 @@ public interface ObservabilityConfig { /** Is Cloud Tracing enabled. */ boolean isEnableCloudTracing(); - /** Get destination project ID - where logs will go. */ - String getDestinationProjectId(); + /** Get project ID - where logs will go. */ + String getProjectId(); - /** Get message count threshold to flush - flush once message count is reached. */ - Long getFlushMessageCount(); + /** Get filters for client logging. */ + List getClientLogFilters(); - /** Get filters set for logging. */ - List getLogFilters(); - - /** Get event types to log. */ - List getEventTypes(); + /** Get filters for server logging. */ + List getServerLogFilters(); /** Get sampler for TraceConfig - when Cloud Tracing is enabled. */ Sampler getSampler(); @@ -54,27 +52,44 @@ public interface ObservabilityConfig { /** * POJO for representing a filter used in configuration. */ + @ThreadSafe class LogFilter { - /** Pattern indicating which service/method to log. */ - public final String pattern; + /** Set of services. */ + public final Set services; + + /* Set of fullMethodNames. */ + public final Set methods; + + /** Boolean to indicate all services and methods. */ + public final boolean matchAll; + + /** Number of bytes of header to log. */ + public final int headerBytes; - /** Number of bytes of each header to log. */ - public final Integer headerBytes; + /** Number of bytes of message to log. */ + public final int messageBytes; - /** Number of bytes of each header to log. 
*/ - public final Integer messageBytes; + /** Boolean to indicate if services and methods matching pattern needs to be excluded. */ + public final boolean excludePattern; /** * Object used to represent filter used in configuration. - * - * @param pattern Pattern indicating which service/method to log - * @param headerBytes Number of bytes of each header to log - * @param messageBytes Number of bytes of each header to log + * @param services Set of services derived from pattern + * @param serviceMethods Set of fullMethodNames derived from pattern + * @param matchAll If true, match all services and methods + * @param headerBytes Total number of bytes of header to log + * @param messageBytes Total number of bytes of message to log + * @param excludePattern If true, services and methods matching pattern be excluded */ - public LogFilter(String pattern, Integer headerBytes, Integer messageBytes) { - this.pattern = pattern; + public LogFilter(Set services, Set serviceMethods, boolean matchAll, + int headerBytes, int messageBytes, + boolean excludePattern) { + this.services = services; + this.methods = serviceMethods; + this.matchAll = matchAll; this.headerBytes = headerBytes; this.messageBytes = messageBytes; + this.excludePattern = excludePattern; } } } diff --git a/gcp-observability/src/main/java/io/grpc/gcp/observability/ObservabilityConfigImpl.java b/gcp-observability/src/main/java/io/grpc/gcp/observability/ObservabilityConfigImpl.java index 1d0505e2818..2b0a44473d0 100644 --- a/gcp-observability/src/main/java/io/grpc/gcp/observability/ObservabilityConfigImpl.java +++ b/gcp-observability/src/main/java/io/grpc/gcp/observability/ObservabilityConfigImpl.java @@ -18,36 +18,47 @@ import static com.google.common.base.Preconditions.checkArgument; +import com.google.cloud.ServiceOptions; import com.google.common.base.Charsets; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import 
io.grpc.internal.JsonParser; import io.grpc.internal.JsonUtil; -import io.grpc.observabilitylog.v1.GrpcLogRecord.EventType; import io.opencensus.trace.Sampler; import io.opencensus.trace.samplers.Samplers; import java.io.IOException; import java.nio.file.Files; import java.nio.file.Paths; +import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.logging.Level; +import java.util.logging.Logger; +import java.util.regex.Matcher; +import java.util.regex.Pattern; /** * gRPC GcpObservability configuration processor. */ final class ObservabilityConfigImpl implements ObservabilityConfig { - private static final String CONFIG_ENV_VAR_NAME = "GRPC_CONFIG_OBSERVABILITY"; - private static final String CONFIG_FILE_ENV_VAR_NAME = "GRPC_CONFIG_OBSERVABILITY_JSON"; + private static final Logger logger = Logger + .getLogger(ObservabilityConfigImpl.class.getName()); + private static final String CONFIG_ENV_VAR_NAME = "GRPC_GCP_OBSERVABILITY_CONFIG"; + private static final String CONFIG_FILE_ENV_VAR_NAME = "GRPC_GCP_OBSERVABILITY_CONFIG_FILE"; // Tolerance for floating-point comparisons. 
private static final double EPSILON = 1e-6; + private static final Pattern METHOD_NAME_REGEX = + Pattern.compile("^([*])|((([\\w.]+)/((?:\\w+)|[*])))$"); + private boolean enableCloudLogging = false; private boolean enableCloudMonitoring = false; private boolean enableCloudTracing = false; - private String destinationProjectId = null; - private Long flushMessageCount = null; - private List logFilters; - private List eventTypes; + private String projectId = null; + + private List clientLogFilters; + private List serverLogFilters; private Sampler sampler; private Map customTags; @@ -63,7 +74,10 @@ static ObservabilityConfigImpl getInstance() throws IOException { } void parseFile(String configFile) throws IOException { - parse(new String(Files.readAllBytes(Paths.get(configFile)), Charsets.UTF_8)); + String configFileContent = + new String(Files.readAllBytes(Paths.get(configFile)), Charsets.UTF_8); + checkArgument(!configFileContent.isEmpty(), CONFIG_FILE_ENV_VAR_NAME + " is empty!"); + parse(configFileContent); } @SuppressWarnings("unchecked") @@ -73,96 +87,151 @@ void parse(String config) throws IOException { } private void parseConfig(Map config) { - if (config != null) { - Boolean value = JsonUtil.getBoolean(config, "enable_cloud_logging"); - if (value != null) { - enableCloudLogging = value; - } - value = JsonUtil.getBoolean(config, "enable_cloud_monitoring"); - if (value != null) { - enableCloudMonitoring = value; - } - value = JsonUtil.getBoolean(config, "enable_cloud_trace"); - if (value != null) { - enableCloudTracing = value; - } - destinationProjectId = JsonUtil.getString(config, "destination_project_id"); - flushMessageCount = JsonUtil.getNumberAsLong(config, "flush_message_count"); - List rawList = JsonUtil.getList(config, "log_filters"); - if (rawList != null) { - List> jsonLogFilters = JsonUtil.checkObjectList(rawList); - ImmutableList.Builder logFiltersBuilder = new ImmutableList.Builder<>(); - for (Map jsonLogFilter : jsonLogFilters) { - 
logFiltersBuilder.add(parseJsonLogFilter(jsonLogFilter)); - } - this.logFilters = logFiltersBuilder.build(); - } - rawList = JsonUtil.getList(config, "event_types"); - if (rawList != null) { - List jsonEventTypes = JsonUtil.checkStringList(rawList); - ImmutableList.Builder eventTypesBuilder = new ImmutableList.Builder<>(); - for (String jsonEventType : jsonEventTypes) { - eventTypesBuilder.add(convertEventType(jsonEventType)); - } - this.eventTypes = eventTypesBuilder.build(); - } - Double samplingRate = JsonUtil.getNumberAsDouble(config, "global_trace_sampling_rate"); - if (samplingRate == null) { - this.sampler = Samplers.probabilitySampler(0.0); - } else { - checkArgument( - samplingRate >= 0.0 && samplingRate <= 1.0, - "'global_trace_sampling_rate' needs to be between [0.0, 1.0]"); - // Using alwaysSample() instead of probabilitySampler() because according to - // {@link io.opencensus.trace.samplers.ProbabilitySampler#shouldSample} - // there is a (very) small chance of *not* sampling if probability = 1.00. 
- if (1 - samplingRate < EPSILON) { - this.sampler = Samplers.alwaysSample(); - } else { - this.sampler = Samplers.probabilitySampler(samplingRate); - } - } - Map rawCustomTags = JsonUtil.getObject(config, "custom_tags"); - if (rawCustomTags != null) { - ImmutableMap.Builder builder = new ImmutableMap.Builder<>(); - for (Map.Entry entry: rawCustomTags.entrySet()) { - checkArgument( - entry.getValue() instanceof String, - "'custom_tags' needs to be a map of "); - builder.put(entry.getKey(), (String) entry.getValue()); - } - customTags = builder.build(); - } + checkArgument(config != null, "Invalid configuration"); + if (config.isEmpty()) { + clientLogFilters = Collections.emptyList(); + serverLogFilters = Collections.emptyList(); + customTags = Collections.emptyMap(); + return; + } + projectId = fetchProjectId(JsonUtil.getString(config, "project_id")); + + Map rawCloudLoggingObject = JsonUtil.getObject(config, "cloud_logging"); + if (rawCloudLoggingObject != null) { + enableCloudLogging = true; + ImmutableList.Builder clientFiltersBuilder = new ImmutableList.Builder<>(); + ImmutableList.Builder serverFiltersBuilder = new ImmutableList.Builder<>(); + parseLoggingObject(rawCloudLoggingObject, clientFiltersBuilder, serverFiltersBuilder); + clientLogFilters = clientFiltersBuilder.build(); + serverLogFilters = serverFiltersBuilder.build(); + } + + Map rawCloudMonitoringObject = JsonUtil.getObject(config, "cloud_monitoring"); + if (rawCloudMonitoringObject != null) { + enableCloudMonitoring = true; + } + + Map rawCloudTracingObject = JsonUtil.getObject(config, "cloud_trace"); + if (rawCloudTracingObject != null) { + enableCloudTracing = true; + sampler = parseTracingObject(rawCloudTracingObject); + } + + Map rawCustomTagsObject = JsonUtil.getObject(config, "labels"); + if (rawCustomTagsObject != null) { + customTags = parseCustomTags(rawCustomTagsObject); + } + + if (clientLogFilters == null) { + clientLogFilters = Collections.emptyList(); + } + if (serverLogFilters == 
null) { + serverLogFilters = Collections.emptyList(); + } + if (customTags == null) { + customTags = Collections.emptyMap(); + } + } + + private static String fetchProjectId(String configProjectId) { + // If project_id is not specified in config, get default GCP project id from the environment + String projectId = configProjectId != null ? configProjectId : getDefaultGcpProjectId(); + checkArgument(projectId != null, "Unable to detect project_id"); + logger.log(Level.FINEST, "Found project ID : ", projectId); + return projectId; + } + + private static String getDefaultGcpProjectId() { + return ServiceOptions.getDefaultProjectId(); + } + + private static void parseLoggingObject( + Map rawLoggingConfig, + ImmutableList.Builder clientFilters, + ImmutableList.Builder serverFilters) { + parseRpcEvents(JsonUtil.getList(rawLoggingConfig, "client_rpc_events"), clientFilters); + parseRpcEvents(JsonUtil.getList(rawLoggingConfig, "server_rpc_events"), serverFilters); + } + + private static Sampler parseTracingObject(Map rawCloudTracingConfig) { + Sampler defaultSampler = Samplers.probabilitySampler(0.0); + Double samplingRate = JsonUtil.getNumberAsDouble(rawCloudTracingConfig, "sampling_rate"); + if (samplingRate == null) { + return defaultSampler; } + checkArgument(samplingRate >= 0.0 && samplingRate <= 1.0, + "'sampling_rate' needs to be between [0.0, 1.0]"); + // Using alwaysSample() instead of probabilitySampler() because according to + // {@link io.opencensus.trace.samplers.ProbabilitySampler#shouldSample} + // there is a (very) small chance of *not* sampling if probability = 1.00. + return 1 - samplingRate < EPSILON ? 
Samplers.alwaysSample() + : Samplers.probabilitySampler(samplingRate); } - private EventType convertEventType(String val) { - switch (val) { - case "GRPC_CALL_UNKNOWN": - return EventType.GRPC_CALL_UNKNOWN; - case "GRPC_CALL_REQUEST_HEADER": - return EventType.GRPC_CALL_REQUEST_HEADER; - case "GRPC_CALL_RESPONSE_HEADER": - return EventType.GRPC_CALL_RESPONSE_HEADER; - case "GRPC_CALL_REQUEST_MESSAGE": - return EventType.GRPC_CALL_REQUEST_MESSAGE; - case "GRPC_CALL_RESPONSE_MESSAGE": - return EventType.GRPC_CALL_RESPONSE_MESSAGE; - case "GRPC_CALL_TRAILER": - return EventType.GRPC_CALL_TRAILER; - case "GRPC_CALL_HALF_CLOSE": - return EventType.GRPC_CALL_HALF_CLOSE; - case "GRPC_CALL_CANCEL": - return EventType.GRPC_CALL_CANCEL; - default: - throw new IllegalArgumentException("Unknown event type value:" + val); + private static Map parseCustomTags(Map rawCustomTags) { + ImmutableMap.Builder builder = new ImmutableMap.Builder<>(); + for (Map.Entry entry: rawCustomTags.entrySet()) { + checkArgument( + entry.getValue() instanceof String, + "'labels' needs to be a map of "); + builder.put(entry.getKey(), (String) entry.getValue()); } + return builder.build(); } - private LogFilter parseJsonLogFilter(Map logFilterMap) { - return new LogFilter(JsonUtil.getString(logFilterMap, "pattern"), - JsonUtil.getNumberAsInteger(logFilterMap, "header_bytes"), - JsonUtil.getNumberAsInteger(logFilterMap, "message_bytes")); + private static void parseRpcEvents(List rpcEvents, ImmutableList.Builder filters) { + if (rpcEvents == null) { + return; + } + List> jsonRpcEvents = JsonUtil.checkObjectList(rpcEvents); + for (Map jsonClientRpcEvent : jsonRpcEvents) { + filters.add(parseJsonLogFilter(jsonClientRpcEvent)); + } + } + + private static LogFilter parseJsonLogFilter(Map logFilterMap) { + ImmutableSet.Builder servicesSetBuilder = new ImmutableSet.Builder<>(); + ImmutableSet.Builder methodsSetBuilder = new ImmutableSet.Builder<>(); + boolean wildCardFilter = false; + + boolean excludeFilter 
= + Boolean.TRUE.equals(JsonUtil.getBoolean(logFilterMap, "exclude")); + List methodsList = JsonUtil.getListOfStrings(logFilterMap, "methods"); + if (methodsList != null) { + wildCardFilter = extractMethodOrServicePattern( + methodsList, excludeFilter, servicesSetBuilder, methodsSetBuilder); + } + Integer maxHeaderBytes = JsonUtil.getNumberAsInteger(logFilterMap, "max_metadata_bytes"); + Integer maxMessageBytes = JsonUtil.getNumberAsInteger(logFilterMap, "max_message_bytes"); + + return new LogFilter( + servicesSetBuilder.build(), + methodsSetBuilder.build(), + wildCardFilter, + maxHeaderBytes != null ? maxHeaderBytes.intValue() : 0, + maxMessageBytes != null ? maxMessageBytes.intValue() : 0, + excludeFilter); + } + + private static boolean extractMethodOrServicePattern(List patternList, boolean exclude, + ImmutableSet.Builder servicesSetBuilder, + ImmutableSet.Builder methodsSetBuilder) { + boolean globalFilter = false; + for (String methodOrServicePattern : patternList) { + Matcher matcher = METHOD_NAME_REGEX.matcher(methodOrServicePattern); + checkArgument( + matcher.matches(), "invalid service or method filter : " + methodOrServicePattern); + if ("*".equals(methodOrServicePattern)) { + checkArgument(!exclude, "cannot have 'exclude' and '*' wildcard in the same filter"); + globalFilter = true; + } else if ("*".equals(matcher.group(5))) { + String service = matcher.group(4); + servicesSetBuilder.add(service); + } else { + methodsSetBuilder.add(methodOrServicePattern); + } + } + return globalFilter; } @Override @@ -181,23 +250,18 @@ public boolean isEnableCloudTracing() { } @Override - public String getDestinationProjectId() { - return destinationProjectId; - } - - @Override - public Long getFlushMessageCount() { - return flushMessageCount; + public String getProjectId() { + return projectId; } @Override - public List getLogFilters() { - return logFilters; + public List getClientLogFilters() { + return clientLogFilters; } @Override - public List getEventTypes() { 
- return eventTypes; + public List getServerLogFilters() { + return serverLogFilters; } @Override diff --git a/gcp-observability/src/main/java/io/grpc/gcp/observability/interceptors/ConfigFilterHelper.java b/gcp-observability/src/main/java/io/grpc/gcp/observability/interceptors/ConfigFilterHelper.java index 38a3c80861a..9b05634dbfe 100644 --- a/gcp-observability/src/main/java/io/grpc/gcp/observability/interceptors/ConfigFilterHelper.java +++ b/gcp-observability/src/main/java/io/grpc/gcp/observability/interceptors/ConfigFilterHelper.java @@ -16,51 +16,27 @@ package io.grpc.gcp.observability.interceptors; +import static com.google.common.base.Preconditions.checkNotNull; + import com.google.auto.value.AutoValue; import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableSet; import io.grpc.Internal; -import io.grpc.MethodDescriptor; import io.grpc.gcp.observability.ObservabilityConfig; import io.grpc.gcp.observability.ObservabilityConfig.LogFilter; -import io.grpc.observabilitylog.v1.GrpcLogRecord.EventType; -import java.util.HashMap; import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.logging.Level; -import java.util.logging.Logger; /** * Parses gRPC GcpObservability configuration filters for interceptors usage. 
*/ @Internal public class ConfigFilterHelper { - - private static final Logger logger = Logger.getLogger(ConfigFilterHelper.class.getName()); - public static final FilterParams NO_FILTER_PARAMS = FilterParams.create(false, 0, 0); - public static final String globalPattern = "*"; private final ObservabilityConfig config; - @VisibleForTesting - boolean methodOrServiceFilterPresent; - // Flag to log every service and method - @VisibleForTesting - Map perServiceFilters; - @VisibleForTesting - Map perMethodFilters; - @VisibleForTesting - Set logEventTypeSet; - - @VisibleForTesting - ConfigFilterHelper(ObservabilityConfig config) { + + private ConfigFilterHelper(ObservabilityConfig config) { this.config = config; - this.methodOrServiceFilterPresent = false; - this.perServiceFilters = new HashMap<>(); - this.perMethodFilters = new HashMap<>(); } /** @@ -69,82 +45,44 @@ public class ConfigFilterHelper { * @param config processed ObservabilityConfig object * @return helper instance for filtering */ - public static ConfigFilterHelper factory(ObservabilityConfig config) { - ConfigFilterHelper filterHelper = new ConfigFilterHelper(config); - if (config.isEnableCloudLogging()) { - filterHelper.setMethodOrServiceFilterMaps(); - filterHelper.setEventFilterSet(); - } - return filterHelper; + public static ConfigFilterHelper getInstance(ObservabilityConfig config) { + return new ConfigFilterHelper(config); } - @VisibleForTesting - void setMethodOrServiceFilterMaps() { - List logFilters = config.getLogFilters(); - if (logFilters == null) { - return; - } - Map perServiceFilters = new HashMap<>(); - Map perMethodFilters = new HashMap<>(); - - for (LogFilter currentFilter : logFilters) { - // '*' for global, 'service/*' for service glob, or 'service/method' for fully qualified - String methodOrServicePattern = currentFilter.pattern; - int currentHeaderBytes - = currentFilter.headerBytes != null ? 
currentFilter.headerBytes : 0; - int currentMessageBytes - = currentFilter.messageBytes != null ? currentFilter.messageBytes : 0; - if (methodOrServicePattern.equals("*")) { - // parse config for global, e.g. "*" - if (perServiceFilters.containsKey(globalPattern)) { - logger.log(Level.WARNING, "Duplicate entry : {0}", methodOrServicePattern); - continue; - } - FilterParams params = FilterParams.create(true, - currentHeaderBytes, currentMessageBytes); - perServiceFilters.put(globalPattern, params); - } else if (methodOrServicePattern.endsWith("/*")) { - // TODO(DNVindhya): check if service name is a valid string for a service name - // parse config for a service, e.g. "service/*" - String service = MethodDescriptor.extractFullServiceName(methodOrServicePattern); - if (perServiceFilters.containsKey(service)) { - logger.log(Level.WARNING, "Duplicate entry : {0)", methodOrServicePattern); - continue; - } - FilterParams params = FilterParams.create(true, - currentHeaderBytes, currentMessageBytes); - perServiceFilters.put(service, params); - } else { - // TODO(DNVVindhya): check if methodOrServicePattern is a valid full qualified method name - // parse pattern for a fully qualified method, e.g "service/method" - if (perMethodFilters.containsKey(methodOrServicePattern)) { - logger.log(Level.WARNING, "Duplicate entry : {0}", methodOrServicePattern); - continue; + /** + * Checks if the corresponding service/method passed needs to be logged according to user provided + * observability configuration. + * Filters are evaluated in text order, first match is used. + * + * @param fullMethodName the fully qualified name of the method + * @param client set to true if method being checked is a client method; false otherwise + * @return FilterParams object 1. specifies if the corresponding method needs to be logged + * (log field will be set to true) 2. 
values of payload limits retrieved from configuration + */ + public FilterParams logRpcMethod(String fullMethodName, boolean client) { + FilterParams params = NO_FILTER_PARAMS; + + int index = checkNotNull(fullMethodName, "fullMethodName").lastIndexOf('/'); + String serviceName = fullMethodName.substring(0, index); + + List logFilters = + client ? config.getClientLogFilters() : config.getServerLogFilters(); + + // TODO (dnvindhya): Optimize by caching results for fullMethodName. + for (LogFilter logFilter : logFilters) { + if (logFilter.matchAll + || logFilter.services.contains(serviceName) + || logFilter.methods.contains(fullMethodName)) { + if (logFilter.excludePattern) { + return params; } - FilterParams params = FilterParams.create(true, - currentHeaderBytes, currentMessageBytes); - perMethodFilters.put(methodOrServicePattern, params); + int currentHeaderBytes = logFilter.headerBytes; + int currentMessageBytes = logFilter.messageBytes; + return FilterParams.create(true, currentHeaderBytes, currentMessageBytes); } } - this.perServiceFilters = ImmutableMap.copyOf(perServiceFilters); - this.perMethodFilters = ImmutableMap.copyOf(perMethodFilters); - if (!perServiceFilters.isEmpty() || !perMethodFilters.isEmpty()) { - this.methodOrServiceFilterPresent = true; - } - } - - @VisibleForTesting - void setEventFilterSet() { - List eventFilters = config.getEventTypes(); - if (eventFilters == null) { - return; - } - if (eventFilters.isEmpty()) { - this.logEventTypeSet = ImmutableSet.of(); - return; - } - this.logEventTypeSet = ImmutableSet.copyOf(eventFilters); + return params; } /** @@ -166,50 +104,4 @@ public static FilterParams create(boolean log, int headerBytes, int messageBytes log, headerBytes, messageBytes); } } - - /** - * Checks if the corresponding service/method passed needs to be logged as per the user provided - * configuration. - * - * @param method the fully qualified name of the method - * @return MethodFilterParams object 1. 
specifies if the corresponding method needs to be logged - * (log field will be set to true) 2. values of payload limits retrieved from configuration - */ - public FilterParams isMethodToBeLogged(MethodDescriptor method) { - FilterParams params = NO_FILTER_PARAMS; - if (methodOrServiceFilterPresent) { - String fullMethodName = method.getFullMethodName(); - if (perMethodFilters.containsKey(fullMethodName)) { - params = perMethodFilters.get(fullMethodName); - } else { - String serviceName = method.getServiceName(); - if (perServiceFilters.containsKey(serviceName)) { - params = perServiceFilters.get(serviceName); - } else if (perServiceFilters.containsKey(globalPattern)) { - params = perServiceFilters.get(globalPattern); - } - } - } - return params; - } - - /** - * Checks if the corresponding event passed needs to be logged as per the user provided - * configuration. - * - *

All events are logged by default if event_types is not specified or {} in configuration. - * If event_types is specified as [], no events will be logged. - * If events types is specified as a non-empty list, only the events specified in the - * list will be logged. - *

- * - * @param event gRPC observability event - * @return true if event needs to be logged, false otherwise - */ - public boolean isEventToBeLogged(EventType event) { - if (logEventTypeSet == null) { - return true; - } - return logEventTypeSet.contains(event); - } } diff --git a/gcp-observability/src/main/java/io/grpc/gcp/observability/interceptors/InternalLoggingChannelInterceptor.java b/gcp-observability/src/main/java/io/grpc/gcp/observability/interceptors/InternalLoggingChannelInterceptor.java index 81e0a9819af..517745a5afc 100644 --- a/gcp-observability/src/main/java/io/grpc/gcp/observability/interceptors/InternalLoggingChannelInterceptor.java +++ b/gcp-observability/src/main/java/io/grpc/gcp/observability/interceptors/InternalLoggingChannelInterceptor.java @@ -85,7 +85,7 @@ public ClientCall interceptCall(MethodDescriptor ClientCall interceptCall(MethodDescriptor ClientCall interceptCall(MethodDescriptor responseListener, Metadata headers) { - // Event: EventType.GRPC_CALL_REQUEST_HEADER + // Event: EventType.CLIENT_HEADER // The timeout should reflect the time remaining when the call is started, so compute // remaining time here. final Duration timeout = deadline == null ? null : Durations.fromNanos(deadline.timeRemaining(TimeUnit.NANOSECONDS)); - if (filterHelper.isEventToBeLogged(EventType.GRPC_CALL_REQUEST_HEADER)) { - try { - helper.logRequestHeader( - seq.getAndIncrement(), - serviceName, - methodName, - authority, - timeout, - headers, - maxHeaderBytes, - EventLogger.LOGGER_CLIENT, - rpcId, - null); - } catch (Exception e) { - // Catching generic exceptions instead of specific ones for all the events. - // This way we can catch both expected and unexpected exceptions instead of re-throwing - // exceptions to callers which will lead to RPC getting aborted. - // Expected exceptions to be caught: - // 1. IllegalArgumentException - // 2. 
NullPointerException - logger.log(Level.SEVERE, "Unable to log request header", e); - } + try { + helper.logClientHeader( + seq.getAndIncrement(), + serviceName, + methodName, + authority, + timeout, + headers, + maxHeaderBytes, + EventLogger.CLIENT, + callId, + null); + } catch (Exception e) { + // Catching generic exceptions instead of specific ones for all the events. + // This way we can catch both expected and unexpected exceptions instead of re-throwing + // exceptions to callers which will lead to RPC getting aborted. + // Expected exceptions to be caught: + // 1. IllegalArgumentException + // 2. NullPointerException + logger.log(Level.SEVERE, "Unable to log request header", e); } Listener observabilityListener = new SimpleForwardingClientCallListener(responseListener) { @Override public void onMessage(RespT message) { - // Event: EventType.GRPC_CALL_RESPONSE_MESSAGE - EventType responseMessageType = EventType.GRPC_CALL_RESPONSE_MESSAGE; - if (filterHelper.isEventToBeLogged(responseMessageType)) { - try { - helper.logRpcMessage( - seq.getAndIncrement(), - serviceName, - methodName, - responseMessageType, - message, - maxMessageBytes, - EventLogger.LOGGER_CLIENT, - rpcId); - } catch (Exception e) { - logger.log(Level.SEVERE, "Unable to log response message", e); - } + // Event: EventType.SERVER_MESSAGE + try { + helper.logRpcMessage( + seq.getAndIncrement(), + serviceName, + methodName, + authority, + EventType.SERVER_MESSAGE, + message, + maxMessageBytes, + EventLogger.CLIENT, + callId); + } catch (Exception e) { + logger.log(Level.SEVERE, "Unable to log response message", e); } super.onMessage(message); } @Override public void onHeaders(Metadata headers) { - // Event: EventType.GRPC_CALL_RESPONSE_HEADER - if (filterHelper.isEventToBeLogged(EventType.GRPC_CALL_RESPONSE_HEADER)) { - try { - helper.logResponseHeader( - seq.getAndIncrement(), - serviceName, - methodName, - headers, - maxHeaderBytes, - EventLogger.LOGGER_CLIENT, - rpcId, - 
LogHelper.getPeerAddress(getAttributes())); - } catch (Exception e) { - logger.log(Level.SEVERE, "Unable to log response header", e); - } + // Event: EventType.SERVER_HEADER + try { + helper.logServerHeader( + seq.getAndIncrement(), + serviceName, + methodName, + authority, + headers, + maxHeaderBytes, + EventLogger.CLIENT, + callId, + LogHelper.getPeerAddress(getAttributes())); + } catch (Exception e) { + logger.log(Level.SEVERE, "Unable to log response header", e); } super.onHeaders(headers); } @Override public void onClose(Status status, Metadata trailers) { - // Event: EventType.GRPC_CALL_TRAILER - if (filterHelper.isEventToBeLogged(EventType.GRPC_CALL_TRAILER)) { - try { - helper.logTrailer( - seq.getAndIncrement(), - serviceName, - methodName, - status, - trailers, - maxHeaderBytes, - EventLogger.LOGGER_CLIENT, - rpcId, - LogHelper.getPeerAddress(getAttributes())); - } catch (Exception e) { - logger.log(Level.SEVERE, "Unable to log trailer", e); - } + // Event: EventType.SERVER_TRAILER + try { + helper.logTrailer( + seq.getAndIncrement(), + serviceName, + methodName, + authority, + status, + trailers, + maxHeaderBytes, + EventLogger.CLIENT, + callId, + LogHelper.getPeerAddress(getAttributes())); + } catch (Exception e) { + logger.log(Level.SEVERE, "Unable to log trailer", e); } super.onClose(status, trailers); } @@ -207,58 +201,54 @@ public void onClose(Status status, Metadata trailers) { @Override public void sendMessage(ReqT message) { - // Event: EventType.GRPC_CALL_REQUEST_MESSAGE - EventType requestMessageType = EventType.GRPC_CALL_REQUEST_MESSAGE; - if (filterHelper.isEventToBeLogged(requestMessageType)) { - try { - helper.logRpcMessage( - seq.getAndIncrement(), - serviceName, - methodName, - requestMessageType, - message, - maxMessageBytes, - EventLogger.LOGGER_CLIENT, - rpcId); - } catch (Exception e) { - logger.log(Level.SEVERE, "Unable to log request message", e); - } + // Event: EventType.CLIENT_MESSAGE + try { + helper.logRpcMessage( + 
seq.getAndIncrement(), + serviceName, + methodName, + authority, + EventType.CLIENT_MESSAGE, + message, + maxMessageBytes, + EventLogger.CLIENT, + callId); + } catch (Exception e) { + logger.log(Level.SEVERE, "Unable to log request message", e); } super.sendMessage(message); } @Override public void halfClose() { - // Event: EventType.GRPC_CALL_HALF_CLOSE - if (filterHelper.isEventToBeLogged(EventType.GRPC_CALL_HALF_CLOSE)) { - try { - helper.logHalfClose( - seq.getAndIncrement(), - serviceName, - methodName, - EventLogger.LOGGER_CLIENT, - rpcId); - } catch (Exception e) { - logger.log(Level.SEVERE, "Unable to log half close", e); - } + // Event: EventType.CLIENT_HALF_CLOSE + try { + helper.logHalfClose( + seq.getAndIncrement(), + serviceName, + methodName, + authority, + EventLogger.CLIENT, + callId); + } catch (Exception e) { + logger.log(Level.SEVERE, "Unable to log half close", e); } super.halfClose(); } @Override public void cancel(String message, Throwable cause) { - // Event: EventType.GRPC_CALL_CANCEL - if (filterHelper.isEventToBeLogged(EventType.GRPC_CALL_CANCEL)) { - try { - helper.logCancel( - seq.getAndIncrement(), - serviceName, - methodName, - EventLogger.LOGGER_CLIENT, - rpcId); - } catch (Exception e) { - logger.log(Level.SEVERE, "Unable to log cancel", e); - } + // Event: EventType.CANCEL + try { + helper.logCancel( + seq.getAndIncrement(), + serviceName, + methodName, + authority, + EventLogger.CLIENT, + callId); + } catch (Exception e) { + logger.log(Level.SEVERE, "Unable to log cancel", e); } super.cancel(message, cause); } diff --git a/gcp-observability/src/main/java/io/grpc/gcp/observability/interceptors/InternalLoggingServerInterceptor.java b/gcp-observability/src/main/java/io/grpc/gcp/observability/interceptors/InternalLoggingServerInterceptor.java index 112a1c067b1..fe98fbdc6d5 100644 --- a/gcp-observability/src/main/java/io/grpc/gcp/observability/interceptors/InternalLoggingServerInterceptor.java +++ 
b/gcp-observability/src/main/java/io/grpc/gcp/observability/interceptors/InternalLoggingServerInterceptor.java @@ -84,7 +84,7 @@ private InternalLoggingServerInterceptor(LogHelper helper, ConfigFilterHelper fi public ServerCall.Listener interceptCall(ServerCall call, Metadata headers, ServerCallHandler next) { final AtomicLong seq = new AtomicLong(1); - final String rpcId = UUID.randomUUID().toString(); + final String callId = UUID.randomUUID().toString(); final String authority = call.getAuthority(); final String serviceName = call.getMethodDescriptor().getServiceName(); final String methodName = call.getMethodDescriptor().getBareMethodName(); @@ -93,7 +93,8 @@ public ServerCall.Listener interceptCall(ServerCall ServerCall.Listener interceptCall(ServerCall wrapperCall = new SimpleForwardingServerCall(call) { @Override public void sendHeaders(Metadata headers) { - // Event: EventType.GRPC_CALL_RESPONSE_HEADER - if (filterHelper.isEventToBeLogged(EventType.GRPC_CALL_RESPONSE_HEADER)) { - try { - helper.logResponseHeader( - seq.getAndIncrement(), - serviceName, - methodName, - headers, - maxHeaderBytes, - EventLogger.LOGGER_SERVER, - rpcId, - null); - } catch (Exception e) { - logger.log(Level.SEVERE, "Unable to log response header", e); - } + // Event: EventType.SERVER_HEADER + try { + helper.logServerHeader( + seq.getAndIncrement(), + serviceName, + methodName, + authority, + headers, + maxHeaderBytes, + EventLogger.SERVER, + callId, + null); + } catch (Exception e) { + logger.log(Level.SEVERE, "Unable to log response header", e); } super.sendHeaders(headers); } @Override public void sendMessage(RespT message) { - // Event: EventType.GRPC_CALL_RESPONSE_MESSAGE - EventType responseMessageType = EventType.GRPC_CALL_RESPONSE_MESSAGE; - if (filterHelper.isEventToBeLogged(responseMessageType)) { - try { - helper.logRpcMessage( - seq.getAndIncrement(), - serviceName, - methodName, - responseMessageType, - message, - maxMessageBytes, - EventLogger.LOGGER_SERVER, - rpcId); 
- } catch (Exception e) { - logger.log(Level.SEVERE, "Unable to log response message", e); - } + // Event: EventType.SERVER_MESSAGE + EventType responseMessageType = EventType.SERVER_MESSAGE; + try { + helper.logRpcMessage( + seq.getAndIncrement(), + serviceName, + methodName, + authority, + responseMessageType, + message, + maxMessageBytes, + EventLogger.SERVER, + callId); + } catch (Exception e) { + logger.log(Level.SEVERE, "Unable to log response message", e); } super.sendMessage(message); } @Override public void close(Status status, Metadata trailers) { - // Event: EventType.GRPC_CALL_TRAILER - if (filterHelper.isEventToBeLogged(EventType.GRPC_CALL_TRAILER)) { - try { - helper.logTrailer( - seq.getAndIncrement(), - serviceName, - methodName, - status, - trailers, - maxHeaderBytes, - EventLogger.LOGGER_SERVER, - rpcId, - null); - } catch (Exception e) { - logger.log(Level.SEVERE, "Unable to log trailer", e); - } + // Event: EventType.SERVER_TRAILER + try { + helper.logTrailer( + seq.getAndIncrement(), + serviceName, + methodName, + authority, + status, + trailers, + maxHeaderBytes, + EventLogger.SERVER, + callId, + null); + } catch (Exception e) { + logger.log(Level.SEVERE, "Unable to log trailer", e); } super.close(status, trailers); } @@ -198,58 +194,56 @@ public void close(Status status, Metadata trailers) { return new SimpleForwardingServerCallListener(listener) { @Override public void onMessage(ReqT message) { - // Event: EventType.GRPC_CALL_REQUEST_MESSAGE - EventType requestMessageType = EventType.GRPC_CALL_REQUEST_MESSAGE; - if (filterHelper.isEventToBeLogged(requestMessageType)) { - try { - helper.logRpcMessage( - seq.getAndIncrement(), - serviceName, - methodName, - requestMessageType, - message, - maxMessageBytes, - EventLogger.LOGGER_SERVER, - rpcId); - } catch (Exception e) { - logger.log(Level.SEVERE, "Unable to log request message", e); - } + + // Event: EventType.CLIENT_MESSAGE + EventType requestMessageType = EventType.CLIENT_MESSAGE; + try { + 
helper.logRpcMessage( + seq.getAndIncrement(), + serviceName, + methodName, + authority, + requestMessageType, + message, + maxMessageBytes, + EventLogger.SERVER, + callId); + } catch (Exception e) { + logger.log(Level.SEVERE, "Unable to log request message", e); } super.onMessage(message); } @Override public void onHalfClose() { - // Event: EventType.GRPC_CALL_HALF_CLOSE - if (filterHelper.isEventToBeLogged(EventType.GRPC_CALL_HALF_CLOSE)) { - try { - helper.logHalfClose( - seq.getAndIncrement(), - serviceName, - methodName, - EventLogger.LOGGER_SERVER, - rpcId); - } catch (Exception e) { - logger.log(Level.SEVERE, "Unable to log half close", e); - } + // Event: EventType.CLIENT_HALF_CLOSE + try { + helper.logHalfClose( + seq.getAndIncrement(), + serviceName, + methodName, + authority, + EventLogger.SERVER, + callId); + } catch (Exception e) { + logger.log(Level.SEVERE, "Unable to log half close", e); } super.onHalfClose(); } @Override public void onCancel() { - // Event: EventType.GRPC_CALL_CANCEL - if (filterHelper.isEventToBeLogged(EventType.GRPC_CALL_CANCEL)) { - try { - helper.logCancel( - seq.getAndIncrement(), - serviceName, - methodName, - EventLogger.LOGGER_SERVER, - rpcId); - } catch (Exception e) { - logger.log(Level.SEVERE, "Unable to log cancel", e); - } + // Event: EventType.CANCEL + try { + helper.logCancel( + seq.getAndIncrement(), + serviceName, + methodName, + authority, + EventLogger.SERVER, + callId); + } catch (Exception e) { + logger.log(Level.SEVERE, "Unable to log cancel", e); } super.onCancel(); } diff --git a/gcp-observability/src/main/java/io/grpc/gcp/observability/interceptors/LogHelper.java b/gcp-observability/src/main/java/io/grpc/gcp/observability/interceptors/LogHelper.java index 46589f93845..9b46699efaf 100644 --- a/gcp-observability/src/main/java/io/grpc/gcp/observability/interceptors/LogHelper.java +++ b/gcp-observability/src/main/java/io/grpc/gcp/observability/interceptors/LogHelper.java @@ -18,32 +18,32 @@ import static 
com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; +import static io.grpc.InternalMetadata.BASE64_ENCODING_OMIT_PADDING; -import com.google.common.base.Charsets; +import com.google.common.base.Joiner; import com.google.protobuf.ByteString; import com.google.protobuf.Duration; -import com.google.protobuf.util.Timestamps; import io.grpc.Attributes; import io.grpc.Deadline; import io.grpc.Grpc; import io.grpc.Internal; -import io.grpc.InternalMetadata; import io.grpc.Metadata; import io.grpc.Status; import io.grpc.gcp.observability.logging.Sink; -import io.grpc.internal.TimeProvider; +import io.grpc.observabilitylog.v1.Address; import io.grpc.observabilitylog.v1.GrpcLogRecord; -import io.grpc.observabilitylog.v1.GrpcLogRecord.Address; import io.grpc.observabilitylog.v1.GrpcLogRecord.EventLogger; import io.grpc.observabilitylog.v1.GrpcLogRecord.EventType; -import io.grpc.observabilitylog.v1.GrpcLogRecord.LogLevel; +import io.grpc.observabilitylog.v1.Payload; import java.net.Inet4Address; import java.net.Inet6Address; import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.SocketAddress; +import java.util.ArrayList; import java.util.Collections; import java.util.HashSet; +import java.util.List; import java.util.Set; import java.util.logging.Level; import java.util.logging.Logger; @@ -63,23 +63,20 @@ public class LogHelper { Metadata.BINARY_BYTE_MARSHALLER); private final Sink sink; - private final TimeProvider timeProvider; /** * Creates a LogHelper instance. + * @param sink sink * - * @param sink sink - * @param timeProvider timeprovider */ - public LogHelper(Sink sink, TimeProvider timeProvider) { + public LogHelper(Sink sink) { this.sink = sink; - this.timeProvider = timeProvider; } /** * Logs the request header. Binary logging equivalent of logClientHeader. 
*/ - void logRequestHeader( + void logClientHeader( long seqId, String serviceName, String methodName, @@ -88,35 +85,33 @@ void logRequestHeader( Metadata metadata, int maxHeaderBytes, GrpcLogRecord.EventLogger eventLogger, - String rpcId, + String callId, // null on client side @Nullable SocketAddress peerAddress) { checkNotNull(serviceName, "serviceName"); checkNotNull(methodName, "methodName"); - checkNotNull(rpcId, "rpcId"); + checkNotNull(authority, "authority"); + checkNotNull(callId, "callId"); checkArgument( - peerAddress == null || eventLogger == GrpcLogRecord.EventLogger.LOGGER_SERVER, + peerAddress == null || eventLogger == GrpcLogRecord.EventLogger.SERVER, "peerAddress can only be specified by server"); - - PayloadBuilder pair = + PayloadBuilderHelper pair = createMetadataProto(metadata, maxHeaderBytes); - GrpcLogRecord.Builder logEntryBuilder = createTimestamp() + if (timeout != null) { + pair.payloadBuilder.setTimeout(timeout); + } + GrpcLogRecord.Builder logEntryBuilder = GrpcLogRecord.newBuilder() .setSequenceId(seqId) .setServiceName(serviceName) .setMethodName(methodName) .setAuthority(authority) - .setEventType(EventType.GRPC_CALL_REQUEST_HEADER) - .setEventLogger(eventLogger) - .setLogLevel(LogLevel.LOG_LEVEL_DEBUG) - .setMetadata(pair.payload) - .setPayloadSize(pair.size) + .setType(EventType.CLIENT_HEADER) + .setLogger(eventLogger) + .setPayload(pair.payloadBuilder) .setPayloadTruncated(pair.truncated) - .setRpcId(rpcId); - if (timeout != null) { - logEntryBuilder.setTimeout(timeout); - } + .setCallId(callId); if (peerAddress != null) { - logEntryBuilder.setPeerAddress(socketAddressToProto(peerAddress)); + logEntryBuilder.setPeer(socketAddressToProto(peerAddress)); } sink.write(logEntryBuilder.build()); } @@ -124,39 +119,40 @@ void logRequestHeader( /** * Logs the response header. Binary logging equivalent of logServerHeader. 
*/ - void logResponseHeader( + void logServerHeader( long seqId, String serviceName, String methodName, + String authority, Metadata metadata, int maxHeaderBytes, GrpcLogRecord.EventLogger eventLogger, - String rpcId, + String callId, @Nullable SocketAddress peerAddress) { checkNotNull(serviceName, "serviceName"); checkNotNull(methodName, "methodName"); - checkNotNull(rpcId, "rpcId"); + checkNotNull(authority, "authority"); + checkNotNull(callId, "callId"); // Logging peer address only on the first incoming event. On server side, peer address will // of logging request header checkArgument( - peerAddress == null || eventLogger == GrpcLogRecord.EventLogger.LOGGER_CLIENT, + peerAddress == null || eventLogger == GrpcLogRecord.EventLogger.CLIENT, "peerAddress can only be specified for client"); - PayloadBuilder pair = + PayloadBuilderHelper pair = createMetadataProto(metadata, maxHeaderBytes); - GrpcLogRecord.Builder logEntryBuilder = createTimestamp() + GrpcLogRecord.Builder logEntryBuilder = GrpcLogRecord.newBuilder() .setSequenceId(seqId) .setServiceName(serviceName) .setMethodName(methodName) - .setEventType(EventType.GRPC_CALL_RESPONSE_HEADER) - .setEventLogger(eventLogger) - .setLogLevel(LogLevel.LOG_LEVEL_DEBUG) - .setMetadata(pair.payload) - .setPayloadSize(pair.size) + .setAuthority(authority) + .setType(EventType.SERVER_HEADER) + .setLogger(eventLogger) + .setPayload(pair.payloadBuilder) .setPayloadTruncated(pair.truncated) - .setRpcId(rpcId); + .setCallId(callId); if (peerAddress != null) { - logEntryBuilder.setPeerAddress(socketAddressToProto(peerAddress)); + logEntryBuilder.setPeer(socketAddressToProto(peerAddress)); } sink.write(logEntryBuilder.build()); } @@ -168,44 +164,45 @@ void logTrailer( long seqId, String serviceName, String methodName, + String authority, Status status, Metadata metadata, int maxHeaderBytes, GrpcLogRecord.EventLogger eventLogger, - String rpcId, + String callId, @Nullable SocketAddress peerAddress) { checkNotNull(serviceName, 
"serviceName"); checkNotNull(methodName, "methodName"); + checkNotNull(authority, "authority"); checkNotNull(status, "status"); - checkNotNull(rpcId, "rpcId"); + checkNotNull(callId, "callId"); checkArgument( - peerAddress == null || eventLogger == GrpcLogRecord.EventLogger.LOGGER_CLIENT, + peerAddress == null || eventLogger == GrpcLogRecord.EventLogger.CLIENT, "peerAddress can only be specified for client"); - PayloadBuilder pair = + PayloadBuilderHelper pair = createMetadataProto(metadata, maxHeaderBytes); - GrpcLogRecord.Builder logEntryBuilder = createTimestamp() - .setSequenceId(seqId) - .setServiceName(serviceName) - .setMethodName(methodName) - .setEventType(EventType.GRPC_CALL_TRAILER) - .setEventLogger(eventLogger) - .setLogLevel(LogLevel.LOG_LEVEL_DEBUG) - .setMetadata(pair.payload) - .setPayloadSize(pair.size) - .setPayloadTruncated(pair.truncated) - .setStatusCode(status.getCode().value()) - .setRpcId(rpcId); + pair.payloadBuilder.setStatusCode(status.getCode().value()); String statusDescription = status.getDescription(); if (statusDescription != null) { - logEntryBuilder.setStatusMessage(statusDescription); + pair.payloadBuilder.setStatusMessage(statusDescription); } byte[] statusDetailBytes = metadata.get(STATUS_DETAILS_KEY); if (statusDetailBytes != null) { - logEntryBuilder.setStatusDetails(ByteString.copyFrom(statusDetailBytes)); + pair.payloadBuilder.setStatusDetails(ByteString.copyFrom(statusDetailBytes)); } + GrpcLogRecord.Builder logEntryBuilder = GrpcLogRecord.newBuilder() + .setSequenceId(seqId) + .setServiceName(serviceName) + .setMethodName(methodName) + .setAuthority(authority) + .setType(EventType.SERVER_TRAILER) + .setLogger(eventLogger) + .setPayload(pair.payloadBuilder) + .setPayloadTruncated(pair.truncated) + .setCallId(callId); if (peerAddress != null) { - logEntryBuilder.setPeerAddress(socketAddressToProto(peerAddress)); + logEntryBuilder.setPeer(socketAddressToProto(peerAddress)); } sink.write(logEntryBuilder.build()); } @@ -217,17 
+214,19 @@ void logRpcMessage( long seqId, String serviceName, String methodName, + String authority, EventType eventType, T message, int maxMessageBytes, EventLogger eventLogger, - String rpcId) { + String callId) { checkNotNull(serviceName, "serviceName"); checkNotNull(methodName, "methodName"); - checkNotNull(rpcId, "rpcId"); + checkNotNull(authority, "authority"); + checkNotNull(callId, "callId"); checkArgument( - eventType == EventType.GRPC_CALL_REQUEST_MESSAGE - || eventType == EventType.GRPC_CALL_RESPONSE_MESSAGE, + eventType == EventType.CLIENT_MESSAGE + || eventType == EventType.SERVER_MESSAGE, "event type must correspond to client message or server message"); checkNotNull(message, "message"); @@ -241,27 +240,23 @@ void logRpcMessage( } else if (message instanceof byte[]) { messageBytesArray = (byte[]) message; } else { - logger.log(Level.WARNING, "message is of UNKNOWN type, message and payload_size fields" + logger.log(Level.WARNING, "message is of UNKNOWN type, message and payload_size fields " + "of GrpcLogRecord proto will not be logged"); } - PayloadBuilder pair = null; + PayloadBuilderHelper pair = null; if (messageBytesArray != null) { pair = createMessageProto(messageBytesArray, maxMessageBytes); } - - GrpcLogRecord.Builder logEntryBuilder = createTimestamp() + GrpcLogRecord.Builder logEntryBuilder = GrpcLogRecord.newBuilder() .setSequenceId(seqId) .setServiceName(serviceName) .setMethodName(methodName) - .setEventType(eventType) - .setEventLogger(eventLogger) - .setLogLevel(LogLevel.LOG_LEVEL_DEBUG) - .setRpcId(rpcId); - if (pair != null && pair.size != 0) { - logEntryBuilder.setPayloadSize(pair.size); - } - if (pair != null && pair.payload != null) { - logEntryBuilder.setMessage(pair.payload) + .setAuthority(authority) + .setType(eventType) + .setLogger(eventLogger) + .setCallId(callId); + if (pair != null) { + logEntryBuilder.setPayload(pair.payloadBuilder) .setPayloadTruncated(pair.truncated); } sink.write(logEntryBuilder.build()); @@ -274,20 
+269,22 @@ void logHalfClose( long seqId, String serviceName, String methodName, + String authority, GrpcLogRecord.EventLogger eventLogger, - String rpcId) { + String callId) { checkNotNull(serviceName, "serviceName"); checkNotNull(methodName, "methodName"); - checkNotNull(rpcId, "rpcId"); + checkNotNull(authority, "authority"); + checkNotNull(callId, "callId"); - GrpcLogRecord.Builder logEntryBuilder = createTimestamp() + GrpcLogRecord.Builder logEntryBuilder = GrpcLogRecord.newBuilder() .setSequenceId(seqId) .setServiceName(serviceName) .setMethodName(methodName) - .setEventType(EventType.GRPC_CALL_HALF_CLOSE) - .setEventLogger(eventLogger) - .setLogLevel(LogLevel.LOG_LEVEL_DEBUG) - .setRpcId(rpcId); + .setAuthority(authority) + .setType(EventType.CLIENT_HALF_CLOSE) + .setLogger(eventLogger) + .setCallId(callId); sink.write(logEntryBuilder.build()); } @@ -298,28 +295,25 @@ void logCancel( long seqId, String serviceName, String methodName, + String authority, GrpcLogRecord.EventLogger eventLogger, - String rpcId) { + String callId) { checkNotNull(serviceName, "serviceName"); checkNotNull(methodName, "methodName"); - checkNotNull(rpcId, "rpcId"); + checkNotNull(authority, "authority"); + checkNotNull(callId, "callId"); - GrpcLogRecord.Builder logEntryBuilder = createTimestamp() + GrpcLogRecord.Builder logEntryBuilder = GrpcLogRecord.newBuilder() .setSequenceId(seqId) .setServiceName(serviceName) .setMethodName(methodName) - .setEventType(EventType.GRPC_CALL_CANCEL) - .setEventLogger(eventLogger) - .setLogLevel(LogLevel.LOG_LEVEL_DEBUG) - .setRpcId(rpcId); + .setAuthority(authority) + .setType(EventType.CANCEL) + .setLogger(eventLogger) + .setCallId(callId); sink.write(logEntryBuilder.build()); } - GrpcLogRecord.Builder createTimestamp() { - long nanos = timeProvider.currentTimeNanos(); - return GrpcLogRecord.newBuilder().setTimestamp(Timestamps.fromNanos(nanos)); - } - // TODO(DNVindhya): Evaluate if we need following clause for metadata logging in GcpObservability 
// Leaving the implementation for now as is to have same behavior across Java and Go private static final Set NEVER_INCLUDED_METADATA = new HashSet<>( @@ -331,58 +325,65 @@ GrpcLogRecord.Builder createTimestamp() { Collections.singletonList( "grpc-trace-bin")); - static final class PayloadBuilder { - T payload; - int size; + static final class PayloadBuilderHelper { + T payloadBuilder; boolean truncated; - private PayloadBuilder(T payload, int size, boolean truncated) { - this.payload = payload; - this.size = size; + private PayloadBuilderHelper(T payload, boolean truncated) { + this.payloadBuilder = payload; this.truncated = truncated; } } - static PayloadBuilder createMetadataProto(Metadata metadata, + static PayloadBuilderHelper createMetadataProto(Metadata metadata, int maxHeaderBytes) { checkNotNull(metadata, "metadata"); checkArgument(maxHeaderBytes >= 0, "maxHeaderBytes must be non negative"); - GrpcLogRecord.Metadata.Builder metadataBuilder = GrpcLogRecord.Metadata.newBuilder(); - // This code is tightly coupled with io.grpc.observabilitylog.v1.GrpcLogRecord.Metadata - // implementation - byte[][] serialized = InternalMetadata.serialize(metadata); + Joiner joiner = Joiner.on(",").skipNulls(); + Payload.Builder payloadBuilder = Payload.newBuilder(); boolean truncated = false; int totalMetadataBytes = 0; - if (serialized != null) { - // Calculate bytes for each GrpcLogRecord.Metadata.MetadataEntry - for (int i = 0; i < serialized.length; i += 2) { - String key = new String(serialized[i], Charsets.UTF_8); - byte[] value = serialized[i + 1]; - if (NEVER_INCLUDED_METADATA.contains(key)) { - continue; - } - boolean forceInclude = ALWAYS_INCLUDED_METADATA.contains(key); - int metadataBytesAfterAdd = totalMetadataBytes + key.length() + value.length; - if (!forceInclude && metadataBytesAfterAdd > maxHeaderBytes) { - truncated = true; - continue; - } - metadataBuilder.addEntryBuilder() - .setKey(key) - .setValue(ByteString.copyFrom(value)); - if (!forceInclude) { - 
// force included keys do not count towards the size limit - totalMetadataBytes = metadataBytesAfterAdd; - } + for (String key : metadata.keys()) { + if (NEVER_INCLUDED_METADATA.contains(key)) { + continue; + } + boolean forceInclude = ALWAYS_INCLUDED_METADATA.contains(key); + String metadataValue; + if (key.endsWith(Metadata.BINARY_HEADER_SUFFIX)) { + Iterable metadataValues = + metadata.getAll(Metadata.Key.of(key, Metadata.BINARY_BYTE_MARSHALLER)); + List numList = new ArrayList(); + metadataValues.forEach( + (element) -> { + numList.add(BASE64_ENCODING_OMIT_PADDING.encode(element)); + }); + metadataValue = joiner.join(numList); + } else { + Iterable metadataValues = metadata.getAll( + Metadata.Key.of(key, Metadata.ASCII_STRING_MARSHALLER)); + metadataValue = joiner.join(metadataValues); + } + + int metadataBytesAfterAdd = totalMetadataBytes + key.length() + metadataValue.length(); + if (!forceInclude && metadataBytesAfterAdd > maxHeaderBytes) { + truncated = true; + continue; + } + payloadBuilder.putMetadata(key, metadataValue); + if (!forceInclude) { + // force included keys do not count towards the size limit + totalMetadataBytes = metadataBytesAfterAdd; } } - return new PayloadBuilder<>(metadataBuilder, totalMetadataBytes, truncated); + return new PayloadBuilderHelper<>(payloadBuilder, truncated); } - static PayloadBuilder createMessageProto(byte[] message, int maxMessageBytes) { + static PayloadBuilderHelper createMessageProto( + byte[] message, int maxMessageBytes) { checkArgument(maxMessageBytes >= 0, "maxMessageBytes must be non negative"); + Payload.Builder payloadBuilder = Payload.newBuilder(); int desiredBytes = 0; int messageLength = message.length; if (maxMessageBytes > 0) { @@ -390,8 +391,10 @@ static PayloadBuilder createMessageProto(byte[] message, int maxMess } ByteString messageData = ByteString.copyFrom(message, 0, desiredBytes); + payloadBuilder.setMessage(messageData); + payloadBuilder.setMessageLength(messageLength); - return new 
PayloadBuilder<>(messageData, messageLength, + return new PayloadBuilderHelper<>(payloadBuilder, maxMessageBytes < message.length); } diff --git a/gcp-observability/src/main/java/io/grpc/gcp/observability/logging/GcpLogSink.java b/gcp-observability/src/main/java/io/grpc/gcp/observability/logging/GcpLogSink.java index 0209677aae9..e91f310e647 100644 --- a/gcp-observability/src/main/java/io/grpc/gcp/observability/logging/GcpLogSink.java +++ b/gcp-observability/src/main/java/io/grpc/gcp/observability/logging/GcpLogSink.java @@ -33,6 +33,7 @@ import io.grpc.internal.JsonParser; import io.grpc.observabilitylog.v1.GrpcLogRecord; import java.io.IOException; +import java.time.Instant; import java.util.Collection; import java.util.Collections; import java.util.Map; @@ -50,41 +51,37 @@ public class GcpLogSink implements Sink { private static final String DEFAULT_LOG_NAME = "microservices.googleapis.com%2Fobservability%2Fgrpc"; + private static final Severity DEFAULT_LOG_LEVEL = Severity.DEBUG; private static final String K8S_MONITORED_RESOURCE_TYPE = "k8s_container"; private static final Set kubernetesResourceLabelSet = ImmutableSet.of("project_id", "location", "cluster_name", "namespace_name", "pod_name", "container_name"); - private static final long FALLBACK_FLUSH_LIMIT = 100L; private final String projectId; private final Map customTags; private final MonitoredResource kubernetesResource; - private final Long flushLimit; /** Lazily initialize cloud logging client to avoid circular initialization. Because cloud * logging APIs also uses gRPC. 
*/ private volatile Logging gcpLoggingClient; - private long flushCounter; private final Collection servicesToExclude; @VisibleForTesting - GcpLogSink(Logging loggingClient, String destinationProjectId, Map locationTags, - Map customTags, Long flushLimit, Collection servicesToExclude) { - this(destinationProjectId, locationTags, customTags, flushLimit, servicesToExclude); + GcpLogSink(Logging loggingClient, String projectId, Map locationTags, + Map customTags, Collection servicesToExclude) { + this(projectId, locationTags, customTags, servicesToExclude); this.gcpLoggingClient = loggingClient; } /** * Retrieves a single instance of GcpLogSink. * - * @param destinationProjectId cloud project id to write logs + * @param projectId GCP project id to write logs * @param servicesToExclude service names for which log entries should not be generated */ - public GcpLogSink(String destinationProjectId, Map locationTags, - Map customTags, Long flushLimit, Collection servicesToExclude) { - this.projectId = destinationProjectId; - this.customTags = getCustomTags(customTags, locationTags, destinationProjectId); + public GcpLogSink(String projectId, Map locationTags, + Map customTags, Collection servicesToExclude) { + this.projectId = projectId; + this.customTags = getCustomTags(customTags, locationTags, projectId); this.kubernetesResource = getResource(locationTags); - this.flushLimit = flushLimit != null ? 
flushLimit : FALLBACK_FLUSH_LIMIT; - this.flushCounter = 0L; this.servicesToExclude = checkNotNull(servicesToExclude, "servicesToExclude"); } @@ -106,28 +103,24 @@ public void write(GrpcLogRecord logProto) { return; } try { - GrpcLogRecord.EventType event = logProto.getEventType(); - Severity logEntrySeverity = getCloudLoggingLevel(logProto.getLogLevel()); + GrpcLogRecord.EventType eventType = logProto.getType(); // TODO(DNVindhya): make sure all (int, long) values are not displayed as double // For now, every value is being converted as string because of JsonFormat.printer().print + Map logProtoMap = protoToMapConverter(logProto); LogEntry.Builder grpcLogEntryBuilder = - LogEntry.newBuilder(JsonPayload.of(protoToMapConverter(logProto))) - .setSeverity(logEntrySeverity) + LogEntry.newBuilder(JsonPayload.of(logProtoMap)) + .setSeverity(DEFAULT_LOG_LEVEL) .setLogName(DEFAULT_LOG_NAME) - .setResource(kubernetesResource); + .setResource(kubernetesResource) + .setTimestamp(Instant.now()); if (!customTags.isEmpty()) { grpcLogEntryBuilder.setLabels(customTags); } LogEntry grpcLogEntry = grpcLogEntryBuilder.build(); synchronized (this) { - logger.log(Level.FINEST, "Writing gRPC event : {0} to Cloud Logging", event); + logger.log(Level.FINEST, "Writing gRPC event : {0} to Cloud Logging", eventType); gcpLoggingClient.write(Collections.singleton(grpcLogEntry)); - flushCounter = ++flushCounter; - if (flushCounter >= flushLimit) { - gcpLoggingClient.flush(); - flushCounter = 0L; - } } } catch (Exception e) { logger.log(Level.SEVERE, "Caught exception while writing to Cloud Logging", e); @@ -144,12 +137,12 @@ Logging createLoggingClient() { @VisibleForTesting static Map getCustomTags(Map customTags, - Map locationTags, String destinationProjectId) { + Map locationTags, String projectId) { ImmutableMap.Builder tagsBuilder = ImmutableMap.builder(); String sourceProjectId = locationTags.get("project_id"); - if (!Strings.isNullOrEmpty(destinationProjectId) + if 
(!Strings.isNullOrEmpty(projectId) && !Strings.isNullOrEmpty(sourceProjectId) - && !Objects.equals(sourceProjectId, destinationProjectId)) { + && !Objects.equals(sourceProjectId, projectId)) { tagsBuilder.put("source_project_id", sourceProjectId); } if (customTags != null) { @@ -175,29 +168,11 @@ static MonitoredResource getResource(Map resourceTags) { @SuppressWarnings("unchecked") private Map protoToMapConverter(GrpcLogRecord logProto) throws IOException { - JsonFormat.Printer printer = JsonFormat.printer().preservingProtoFieldNames(); + JsonFormat.Printer printer = JsonFormat.printer(); String recordJson = printer.print(logProto); return (Map) JsonParser.parse(recordJson); } - private Severity getCloudLoggingLevel(GrpcLogRecord.LogLevel recordLevel) { - switch (recordLevel.getNumber()) { - case 1: // GrpcLogRecord.LogLevel.LOG_LEVEL_TRACE - case 2: // GrpcLogRecord.LogLevel.LOG_LEVEL_DEBUG - return Severity.DEBUG; - case 3: // GrpcLogRecord.LogLevel.LOG_LEVEL_INFO - return Severity.INFO; - case 4: // GrpcLogRecord.LogLevel.LOG_LEVEL_WARN - return Severity.WARNING; - case 5: // GrpcLogRecord.LogLevel.LOG_LEVEL_ERROR - return Severity.ERROR; - case 6: // GrpcLogRecord.LogLevel.LOG_LEVEL_CRITICAL - return Severity.CRITICAL; - default: - return Severity.DEFAULT; - } - } - /** * Closes Cloud Logging Client. 
*/ diff --git a/gcp-observability/src/main/proto/grpc/observabilitylog/v1/observabilitylog.proto b/gcp-observability/src/main/proto/grpc/observabilitylog/v1/observabilitylog.proto index a37ac6f43d0..85ef00ac2dd 100644 --- a/gcp-observability/src/main/proto/grpc/observabilitylog/v1/observabilitylog.proto +++ b/gcp-observability/src/main/proto/grpc/observabilitylog/v1/observabilitylog.proto @@ -28,151 +28,99 @@ option java_outer_classname = "ObservabilityLogProto"; message GrpcLogRecord { // List of event types enum EventType { - GRPC_CALL_UNKNOWN = 0; + EVENT_TYPE_UNKNOWN = 0; // Header sent from client to server - GRPC_CALL_REQUEST_HEADER = 1; + CLIENT_HEADER = 1; // Header sent from server to client - GRPC_CALL_RESPONSE_HEADER = 2; + SERVER_HEADER = 2; // Message sent from client to server - GRPC_CALL_REQUEST_MESSAGE = 3; + CLIENT_MESSAGE = 3; // Message sent from server to client - GRPC_CALL_RESPONSE_MESSAGE = 4; - // Trailer indicates the end of the gRPC call - GRPC_CALL_TRAILER = 5; + SERVER_MESSAGE = 4; // A signal that client is done sending - GRPC_CALL_HALF_CLOSE = 6; + CLIENT_HALF_CLOSE = 5; + // Trailer indicates the end of the gRPC call + SERVER_TRAILER = 6; // A signal that the rpc is canceled - GRPC_CALL_CANCEL = 7; + CANCEL = 7; } + // The entity that generates the log entry enum EventLogger { LOGGER_UNKNOWN = 0; - LOGGER_CLIENT = 1; - LOGGER_SERVER = 2; - } - // The log severity level of the log entry - enum LogLevel { - LOG_LEVEL_UNKNOWN = 0; - LOG_LEVEL_TRACE = 1; - LOG_LEVEL_DEBUG = 2; - LOG_LEVEL_INFO = 3; - LOG_LEVEL_WARN = 4; - LOG_LEVEL_ERROR = 5; - LOG_LEVEL_CRITICAL = 6; + CLIENT = 1; + SERVER = 2; } - // The timestamp of the log event - google.protobuf.Timestamp timestamp = 1; - - // Uniquely identifies a call. The value must not be 0 in order to disambiguate - // from an unset value. - // Each call may have several log entries. They will all have the same rpc_id. + // Uniquely identifies a call. + // Each call may have several log entries. 
They will all have the same call_id. // Nothing is guaranteed about their value other than they are unique across // different RPCs in the same gRPC process. - string rpc_id = 2; + string call_id = 2; - EventType event_type = 3; // one of the above EventType enum - EventLogger event_logger = 4; // one of the above EventLogger enum + // The entry sequence ID for this call. The first message has a value of 1, + // to disambiguate from an unset value. The purpose of this field is to + // detect missing entries in environments where durability or ordering is + // not guaranteed. + uint64 sequence_id = 3; - // the name of the service - string service_name = 5; - // the name of the RPC method - string method_name = 6; + EventType type = 4; // one of the above EventType enum + EventLogger logger = 5; // one of the above EventLogger enum - LogLevel log_level = 7; // one of the above LogLevel enum + // Payload for log entry. + // It can include a combination of {metadata, message, status based on type of + // the event event being logged and config options. + Payload payload = 6; + // true if message or metadata field is either truncated or omitted due + // to config options + bool payload_truncated = 7; // Peer address information. On client side, peer is logged on server // header event or trailer event (if trailer-only). On server side, peer // is always logged on the client header event. - Address peer_address = 8; - - // the RPC timeout value - google.protobuf.Duration timeout = 11; + Address peer = 8; // A single process may be used to run multiple virtual servers with // different identities. // The authority is the name of such a server identify. It is typically a // portion of the URI in the form of or :. - string authority = 12; - - // Size of the message or metadata, depending on the event type, - // regardless of whether the full message or metadata is being logged - // (i.e. could be truncated or omitted). 
- uint32 payload_size = 13; - - // true if message or metadata field is either truncated or omitted due - // to config options - bool payload_truncated = 14; - - // Used by header event or trailer event - Metadata metadata = 15; - - // The entry sequence ID for this call. The first message has a value of 1, - // to disambiguate from an unset value. The purpose of this field is to - // detect missing entries in environments where durability or ordering is - // not guaranteed. - uint64 sequence_id = 16; - - // Used by message event - bytes message = 17; + string authority = 10; + // the name of the service + string service_name = 11; + // the name of the RPC method + string method_name = 12; +} +message Payload { + // A list of metadata pairs + map metadata = 1; + // the RPC timeout value + google.protobuf.Duration timeout = 2; // The gRPC status code - uint32 status_code = 18; - + uint32 status_code = 3; // The gRPC status message - string status_message = 19; - + string status_message = 4; // The value of the grpc-status-details-bin metadata key, if any. // This is always an encoded google.rpc.Status message - bytes status_details = 20; - - // Attributes of the environment generating log record. The purpose of this - // field is to identify the source environment. 
- EnvironmentTags env_tags = 21; - - // A list of non-gRPC custom values specified by the application - repeated CustomTags custom_tags = 22; - - // A list of metadata pairs - message Metadata { - repeated MetadataEntry entry = 1; - } - - // One metadata key value pair - message MetadataEntry { - string key = 1; - bytes value = 2; - } - - // Address information - message Address { - enum Type { - TYPE_UNKNOWN = 0; - TYPE_IPV4 = 1; // in 1.2.3.4 form - TYPE_IPV6 = 2; // IPv6 canonical form (RFC5952 section 4) - TYPE_UNIX = 3; // UDS string - } - Type type = 1; - string address = 2; - // only for TYPE_IPV4 and TYPE_IPV6 - uint32 ip_port = 3; - } - - // Source Environment information - message EnvironmentTags { - string gcp_project_id = 1; - string gcp_numeric_project_id = 2; - string gce_instance_id = 3; - string gce_instance_hostname = 4; - string gce_instance_zone = 5; - string gke_cluster_uid = 6; - string gke_cluster_name = 7; - string gke_cluster_location = 8; - } + bytes status_details = 5; + // Size of the message or metadata, depending on the event type, + // regardless of whether the full message or metadata is being logged + // (i.e. could be truncated or omitted). 
+ uint32 message_length = 6; + // Used by message event + bytes message = 7; +} - // Custom key value pair - message CustomTags { - string key = 1; - string value = 2; +// Address information +message Address { + enum Type { + TYPE_UNKNOWN = 0; + TYPE_IPV4 = 1; // in 1.2.3.4 form + TYPE_IPV6 = 2; // IPv6 canonical form (RFC5952 section 4) + TYPE_UNIX = 3; // UDS string } + Type type = 1; + string address = 2; + // only for TYPE_IPV4 and TYPE_IPV6 + uint32 ip_port = 3; } diff --git a/gcp-observability/src/test/java/io/grpc/gcp/observability/LoggingTest.java b/gcp-observability/src/test/java/io/grpc/gcp/observability/LoggingTest.java index aa6c2d55d8b..992ccc5dbf5 100644 --- a/gcp-observability/src/test/java/io/grpc/gcp/observability/LoggingTest.java +++ b/gcp-observability/src/test/java/io/grpc/gcp/observability/LoggingTest.java @@ -17,15 +17,17 @@ package io.grpc.gcp.observability; import static com.google.common.truth.Truth.assertThat; -import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; import com.google.common.collect.ImmutableMap; import io.grpc.ManagedChannelBuilder; -import io.grpc.MethodDescriptor; import io.grpc.Server; import io.grpc.ServerBuilder; import io.grpc.StaticTestingClassLoader; @@ -36,9 +38,7 @@ import io.grpc.gcp.observability.interceptors.LogHelper; import io.grpc.gcp.observability.logging.GcpLogSink; import io.grpc.gcp.observability.logging.Sink; -import io.grpc.internal.TimeProvider; import io.grpc.observabilitylog.v1.GrpcLogRecord; -import io.grpc.observabilitylog.v1.GrpcLogRecord.EventType; import io.grpc.testing.GrpcCleanupRule; import io.grpc.testing.protobuf.SimpleServiceGrpc; import 
java.io.IOException; @@ -49,6 +49,7 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; import org.mockito.Mockito; @RunWith(JUnit4.class) @@ -67,7 +68,6 @@ public class LoggingTest { private static final ImmutableMap CUSTOM_TAGS = ImmutableMap.of( "KEY1", "Value1", "KEY2", "VALUE2"); - private static final long FLUSH_LIMIT = 100L; private final StaticTestingClassLoader classLoader = new StaticTestingClassLoader(getClass().getClassLoader(), Pattern.compile("io\\.grpc\\..*")); @@ -100,9 +100,9 @@ public void clientServer_interceptorCalled_logNever() throws Exception { } @Test - public void clientServer_interceptorCalled_logFewEvents() throws Exception { + public void clientServer_interceptorCalled_logEvents_usingMockSink() throws Exception { Class runnable = - classLoader.loadClass(LoggingTest.StaticTestingClassLogFewEvents.class.getName()); + classLoader.loadClass(StaticTestingClassLogEventsUsingMockSink.class.getName()); ((Runnable) runnable.getDeclaredConstructor().newInstance()).run(); } @@ -113,9 +113,9 @@ public static final class StaticTestingClassEndtoEndLogging implements Runnable public void run() { Sink sink = new GcpLogSink( - PROJECT_ID, LOCATION_TAGS, CUSTOM_TAGS, FLUSH_LIMIT, Collections.emptySet()); + PROJECT_ID, LOCATION_TAGS, CUSTOM_TAGS, Collections.emptySet()); ObservabilityConfig config = mock(ObservabilityConfig.class); - LogHelper spyLogHelper = spy(new LogHelper(sink, TimeProvider.SYSTEM_TIME_PROVIDER)); + LogHelper spyLogHelper = spy(new LogHelper(sink)); ConfigFilterHelper mockFilterHelper = mock(ConfigFilterHelper.class); InternalLoggingChannelInterceptor.Factory channelInterceptorFactory = new InternalLoggingChannelInterceptor.FactoryImpl(spyLogHelper, mockFilterHelper); @@ -123,17 +123,18 @@ public void run() { new InternalLoggingServerInterceptor.FactoryImpl(spyLogHelper, mockFilterHelper); when(config.isEnableCloudLogging()).thenReturn(true); - FilterParams 
logAlwaysFilterParams = FilterParams.create(true, 0, 0); - when(mockFilterHelper.isMethodToBeLogged(any(MethodDescriptor.class))) + FilterParams logAlwaysFilterParams = FilterParams.create(true, 1024, 10); + when(mockFilterHelper.logRpcMethod(anyString(), eq(true))) + .thenReturn(logAlwaysFilterParams); + when(mockFilterHelper.logRpcMethod(anyString(), eq(false))) .thenReturn(logAlwaysFilterParams); - when(mockFilterHelper.isEventToBeLogged(any(GrpcLogRecord.EventType.class))).thenReturn(true); try (GcpObservability unused = GcpObservability.grpcInit( sink, config, channelInterceptorFactory, serverInterceptorFactory)) { Server server = ServerBuilder.forPort(0) - .addService(new LoggingTestHelper.SimpleServiceImpl()) + .addService(new ObservabilityTestHelper.SimpleServiceImpl()) .build() .start(); int port = cleanupRule.register(server).getPort(); @@ -141,7 +142,7 @@ public void run() { SimpleServiceGrpc.newBlockingStub( cleanupRule.register( ManagedChannelBuilder.forAddress("localhost", port).usePlaintext().build())); - assertThat(LoggingTestHelper.makeUnaryRpcViaClientStub("buddy", stub)) + assertThat(ObservabilityTestHelper.makeUnaryRpcViaClientStub("buddy", stub)) .isEqualTo("Hello buddy"); assertThat(Mockito.mockingDetails(spyLogHelper).getInvocations().size()).isGreaterThan(11); } catch (IOException e) { @@ -156,7 +157,7 @@ public static final class StaticTestingClassLogNever implements Runnable { public void run() { Sink mockSink = mock(GcpLogSink.class); ObservabilityConfig config = mock(ObservabilityConfig.class); - LogHelper spyLogHelper = spy(new LogHelper(mockSink, TimeProvider.SYSTEM_TIME_PROVIDER)); + LogHelper spyLogHelper = spy(new LogHelper(mockSink)); ConfigFilterHelper mockFilterHelper = mock(ConfigFilterHelper.class); InternalLoggingChannelInterceptor.Factory channelInterceptorFactory = new InternalLoggingChannelInterceptor.FactoryImpl(spyLogHelper, mockFilterHelper); @@ -165,16 +166,17 @@ public void run() { 
when(config.isEnableCloudLogging()).thenReturn(true); FilterParams logNeverFilterParams = FilterParams.create(false, 0, 0); - when(mockFilterHelper.isMethodToBeLogged(any(MethodDescriptor.class))) + when(mockFilterHelper.logRpcMethod(anyString(), eq(true))) + .thenReturn(logNeverFilterParams); + when(mockFilterHelper.logRpcMethod(anyString(), eq(false))) .thenReturn(logNeverFilterParams); - when(mockFilterHelper.isEventToBeLogged(any(GrpcLogRecord.EventType.class))).thenReturn(true); try (GcpObservability unused = GcpObservability.grpcInit( mockSink, config, channelInterceptorFactory, serverInterceptorFactory)) { Server server = ServerBuilder.forPort(0) - .addService(new LoggingTestHelper.SimpleServiceImpl()) + .addService(new ObservabilityTestHelper.SimpleServiceImpl()) .build() .start(); int port = cleanupRule.register(server).getPort(); @@ -182,7 +184,7 @@ public void run() { SimpleServiceGrpc.newBlockingStub( cleanupRule.register( ManagedChannelBuilder.forAddress("localhost", port).usePlaintext().build())); - assertThat(LoggingTestHelper.makeUnaryRpcViaClientStub("buddy", stub)) + assertThat(ObservabilityTestHelper.makeUnaryRpcViaClientStub("buddy", stub)) .isEqualTo("Hello buddy"); verifyNoInteractions(spyLogHelper); verifyNoInteractions(mockSink); @@ -192,41 +194,32 @@ public void run() { } } - public static final class StaticTestingClassLogFewEvents implements Runnable { + public static final class StaticTestingClassLogEventsUsingMockSink implements Runnable { @Override public void run() { Sink mockSink = mock(GcpLogSink.class); ObservabilityConfig config = mock(ObservabilityConfig.class); - LogHelper mockLogHelper = mock(LogHelper.class); + LogHelper spyLogHelper = spy(new LogHelper(mockSink)); ConfigFilterHelper mockFilterHelper2 = mock(ConfigFilterHelper.class); InternalLoggingChannelInterceptor.Factory channelInterceptorFactory = - new InternalLoggingChannelInterceptor.FactoryImpl(mockLogHelper, mockFilterHelper2); + new 
InternalLoggingChannelInterceptor.FactoryImpl(spyLogHelper, mockFilterHelper2); InternalLoggingServerInterceptor.Factory serverInterceptorFactory = - new InternalLoggingServerInterceptor.FactoryImpl(mockLogHelper, mockFilterHelper2); + new InternalLoggingServerInterceptor.FactoryImpl(spyLogHelper, mockFilterHelper2); when(config.isEnableCloudLogging()).thenReturn(true); FilterParams logAlwaysFilterParams = FilterParams.create(true, 0, 0); - when(mockFilterHelper2.isMethodToBeLogged(any(MethodDescriptor.class))) + when(mockFilterHelper2.logRpcMethod(anyString(), eq(true))) + .thenReturn(logAlwaysFilterParams); + when(mockFilterHelper2.logRpcMethod(anyString(), eq(false))) .thenReturn(logAlwaysFilterParams); - when(mockFilterHelper2.isEventToBeLogged(EventType.GRPC_CALL_REQUEST_HEADER)) - .thenReturn(true); - when(mockFilterHelper2.isEventToBeLogged(EventType.GRPC_CALL_RESPONSE_HEADER)) - .thenReturn(true); - when(mockFilterHelper2.isEventToBeLogged(EventType.GRPC_CALL_HALF_CLOSE)).thenReturn(true); - when(mockFilterHelper2.isEventToBeLogged(EventType.GRPC_CALL_TRAILER)).thenReturn(true); - when(mockFilterHelper2.isEventToBeLogged(EventType.GRPC_CALL_CANCEL)).thenReturn(true); - when(mockFilterHelper2.isEventToBeLogged(EventType.GRPC_CALL_REQUEST_MESSAGE)) - .thenReturn(false); - when(mockFilterHelper2.isEventToBeLogged(EventType.GRPC_CALL_RESPONSE_MESSAGE)) - .thenReturn(false); try (GcpObservability observability = GcpObservability.grpcInit( mockSink, config, channelInterceptorFactory, serverInterceptorFactory)) { Server server = ServerBuilder.forPort(0) - .addService(new LoggingTestHelper.SimpleServiceImpl()) + .addService(new ObservabilityTestHelper.SimpleServiceImpl()) .build() .start(); int port = cleanupRule.register(server).getPort(); @@ -234,7 +227,7 @@ public void run() { SimpleServiceGrpc.newBlockingStub( cleanupRule.register( ManagedChannelBuilder.forAddress("localhost", port).usePlaintext().build())); - 
assertThat(LoggingTestHelper.makeUnaryRpcViaClientStub("buddy", stub)) + assertThat(ObservabilityTestHelper.makeUnaryRpcViaClientStub("buddy", stub)) .isEqualTo("Hello buddy"); // Total number of calls should have been 14 (6 from client and 6 from server) // Since cancel is not invoked, it will be 12. @@ -242,9 +235,15 @@ public void run() { // message(count:2) // events are not in the event_types list, i.e 14 - 2(cancel) - 2(req_msg) - 2(resp_msg) // = 8 - assertThat(Mockito.mockingDetails(mockLogHelper).getInvocations().size()).isEqualTo(8); + assertThat(Mockito.mockingDetails(mockSink).getInvocations().size()).isEqualTo(12); + ArgumentCaptor captor = ArgumentCaptor.forClass(GrpcLogRecord.class); + verify(mockSink, times(12)).write(captor.capture()); + for (GrpcLogRecord record : captor.getAllValues()) { + assertThat(record.getType()).isInstanceOf(GrpcLogRecord.EventType.class); + assertThat(record.getLogger()).isInstanceOf(GrpcLogRecord.EventLogger.class); + } } catch (IOException e) { - throw new AssertionError("Exception while testing logging event filter", e); + throw new AssertionError("Exception while testing logging using mock sink", e); } } } diff --git a/gcp-observability/src/test/java/io/grpc/gcp/observability/MetricsTest.java b/gcp-observability/src/test/java/io/grpc/gcp/observability/MetricsTest.java index f967b99fbcb..046799cc9d2 100644 --- a/gcp-observability/src/test/java/io/grpc/gcp/observability/MetricsTest.java +++ b/gcp-observability/src/test/java/io/grpc/gcp/observability/MetricsTest.java @@ -97,7 +97,7 @@ public void run() { mock(InternalLoggingServerInterceptor.Factory.class); when(mockConfig.isEnableCloudMonitoring()).thenReturn(true); - when(mockConfig.getDestinationProjectId()).thenReturn(PROJECT_ID); + when(mockConfig.getProjectId()).thenReturn(PROJECT_ID); try { GcpObservability observability = @@ -107,7 +107,7 @@ public void run() { Server server = ServerBuilder.forPort(0) - .addService(new LoggingTestHelper.SimpleServiceImpl()) + 
.addService(new ObservabilityTestHelper.SimpleServiceImpl()) .build() .start(); int port = cleanupRule.register(server).getPort(); @@ -115,7 +115,7 @@ public void run() { SimpleServiceGrpc.newBlockingStub( cleanupRule.register( ManagedChannelBuilder.forAddress("localhost", port).usePlaintext().build())); - assertThat(LoggingTestHelper.makeUnaryRpcViaClientStub("buddy", stub)) + assertThat(ObservabilityTestHelper.makeUnaryRpcViaClientStub("buddy", stub)) .isEqualTo("Hello buddy"); // Adding sleep to ensure metrics are exported before querying cloud monitoring backend TimeUnit.SECONDS.sleep(40); diff --git a/gcp-observability/src/test/java/io/grpc/gcp/observability/ObservabilityConfigImplTest.java b/gcp-observability/src/test/java/io/grpc/gcp/observability/ObservabilityConfigImplTest.java index 821dcd43ee4..d305583d53f 100644 --- a/gcp-observability/src/test/java/io/grpc/gcp/observability/ObservabilityConfigImplTest.java +++ b/gcp-observability/src/test/java/io/grpc/gcp/observability/ObservabilityConfigImplTest.java @@ -18,23 +18,25 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import com.google.common.base.Charsets; -import com.google.common.collect.ImmutableList; import io.grpc.gcp.observability.ObservabilityConfig.LogFilter; -import io.grpc.observabilitylog.v1.GrpcLogRecord.EventType; import io.opencensus.trace.Sampler; import io.opencensus.trace.samplers.Samplers; import java.io.File; import java.io.IOException; import java.nio.file.Files; import java.nio.file.Paths; +import java.util.Collections; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; @@ -43,88 +45,172 @@ 
@RunWith(JUnit4.class) public class ObservabilityConfigImplTest { - private static final String EVENT_TYPES = "{\n" - + " \"enable_cloud_logging\": false,\n" - + " \"event_types\": " - + "[\"GRPC_CALL_REQUEST_HEADER\", \"GRPC_CALL_HALF_CLOSE\", \"GRPC_CALL_TRAILER\"]\n" + private static final String LOG_FILTERS = "{\n" + + " \"project_id\": \"grpc-testing\",\n" + + " \"cloud_logging\": {\n" + + " \"client_rpc_events\": [{\n" + + " \"methods\": [\"*\"],\n" + + " \"max_metadata_bytes\": 4096\n" + + " }" + + " ],\n" + + " \"server_rpc_events\": [{\n" + + " \"methods\": [\"*\"],\n" + + " \"max_metadata_bytes\": 32,\n" + + " \"max_message_bytes\": 64\n" + + " }" + + " ]\n" + + " }\n" + "}"; - private static final String LOG_FILTERS = "{\n" - + " \"enable_cloud_logging\": true,\n" - + " \"destination_project_id\": \"grpc-testing\",\n" - + " \"flush_message_count\": 1000,\n" - + " \"log_filters\": [{\n" - + " \"pattern\": \"*/*\",\n" - + " \"header_bytes\": 4096,\n" - + " \"message_bytes\": 2048\n" + private static final String CLIENT_LOG_FILTERS = "{\n" + + " \"project_id\": \"grpc-testing\",\n" + + " \"cloud_logging\": {\n" + + " \"client_rpc_events\": [{\n" + + " \"methods\": [\"*\"],\n" + + " \"max_metadata_bytes\": 4096,\n" + + " \"max_message_bytes\": 2048\n" + " }," + " {\n" - + " \"pattern\": \"service1/Method2\"\n" + + " \"methods\": [\"service1/Method2\", \"Service2/*\"],\n" + + " \"exclude\": true\n" + " }" + " ]\n" + + " }\n" + "}"; - private static final String DEST_PROJECT_ID = "{\n" - + " \"enable_cloud_logging\": true,\n" - + " \"destination_project_id\": \"grpc-testing\"\n" - + "}"; + private static final String SERVER_LOG_FILTERS = "{\n" + + " \"project_id\": \"grpc-testing\",\n" + + " \"cloud_logging\": {\n" + + " \"server_rpc_events\": [{\n" + + " \"methods\": [\"service1/method4\", \"service2/method234\"],\n" + + " \"max_metadata_bytes\": 32,\n" + + " \"max_message_bytes\": 64\n" + + " }," + + " {\n" + + " \"methods\": [\"service4/*\", 
\"Service2/*\"],\n" + + " \"exclude\": true\n" + + " }" + + " ]\n" + + " }\n" + + "}"; - private static final String FLUSH_MESSAGE_COUNT = "{\n" - + " \"enable_cloud_logging\": true,\n" - + " \"flush_message_count\": 500\n" + private static final String VALID_LOG_FILTERS = "{\n" + + " \"project_id\": \"grpc-testing\",\n" + + " \"cloud_logging\": {\n" + + " \"server_rpc_events\": [{\n" + + " \"methods\": [\"service.Service1/*\", \"service2.Service4/method4\"],\n" + + " \"max_metadata_bytes\": 16,\n" + + " \"max_message_bytes\": 64\n" + + " }" + + " ]\n" + + " }\n" + "}"; - private static final String DISABLE_CLOUD_LOGGING = "{\n" - + " \"enable_cloud_logging\": false\n" + + private static final String PROJECT_ID = "{\n" + + " \"project_id\": \"grpc-testing\",\n" + + " \"cloud_logging\": {},\n" + + " \"project_id\": \"grpc-testing\"\n" + "}"; + private static final String EMPTY_CONFIG = "{}"; + private static final String ENABLE_CLOUD_MONITORING_AND_TRACING = "{\n" - + " \"enable_cloud_monitoring\": true,\n" - + " \"enable_cloud_trace\": true\n" + + " \"project_id\": \"grpc-testing\",\n" + + " \"cloud_monitoring\": {},\n" + + " \"cloud_trace\": {}\n" + + "}"; + + private static final String ENABLE_CLOUD_MONITORING = "{\n" + + " \"project_id\": \"grpc-testing\",\n" + + " \"cloud_monitoring\": {}\n" + "}"; - private static final String GLOBAL_TRACING_ALWAYS_SAMPLER = "{\n" - + " \"enable_cloud_trace\": true,\n" - + " \"global_trace_sampling_rate\": 1.00\n" + private static final String ENABLE_CLOUD_TRACE = "{\n" + + " \"project_id\": \"grpc-testing\",\n" + + " \"cloud_trace\": {}\n" + "}"; - private static final String GLOBAL_TRACING_NEVER_SAMPLER = "{\n" - + " \"enable_cloud_trace\": true,\n" - + " \"global_trace_sampling_rate\": 0.00\n" + private static final String TRACING_ALWAYS_SAMPLER = "{\n" + + " \"project_id\": \"grpc-testing\",\n" + + " \"cloud_trace\": {\n" + + " \"sampling_rate\": 1.00\n" + + " }\n" + "}"; - private static final String 
GLOBAL_TRACING_PROBABILISTIC_SAMPLER = "{\n" - + " \"enable_cloud_trace\": true,\n" - + " \"global_trace_sampling_rate\": 0.75\n" + private static final String TRACING_NEVER_SAMPLER = "{\n" + + " \"project_id\": \"grpc-testing\",\n" + + " \"cloud_trace\": {\n" + + " \"sampling_rate\": 0.00\n" + + " }\n" + "}"; - private static final String GLOBAL_TRACING_DEFAULT_SAMPLER = "{\n" - + " \"enable_cloud_trace\": true\n" + private static final String TRACING_PROBABILISTIC_SAMPLER = "{\n" + + " \"project_id\": \"grpc-testing\",\n" + + " \"cloud_trace\": {\n" + + " \"sampling_rate\": 0.75\n" + + " }\n" + "}"; - private static final String GLOBAL_TRACING_BADPROBABILISTIC_SAMPLER = "{\n" - + " \"enable_cloud_tracing\": true,\n" - + " \"global_trace_sampling_rate\": -0.75\n" + private static final String TRACING_DEFAULT_SAMPLER = "{\n" + + " \"project_id\": \"grpc-testing\",\n" + + " \"cloud_trace\": {}\n" + + "}"; + + private static final String GLOBAL_TRACING_BAD_PROBABILISTIC_SAMPLER = "{\n" + + " \"project_id\": \"grpc-testing\",\n" + + " \"cloud_trace\": {\n" + + " \"sampling_rate\": -0.75\n" + + " }\n" + "}"; private static final String CUSTOM_TAGS = "{\n" - + " \"enable_cloud_logging\": true,\n" - + " \"custom_tags\": {\n" + + " \"project_id\": \"grpc-testing\",\n" + + " \"cloud_logging\": {},\n" + + " \"labels\": {\n" + " \"SOURCE_VERSION\" : \"J2e1Cf\",\n" + " \"SERVICE_NAME\" : \"payment-service\",\n" + " \"ENTRYPOINT_SCRIPT\" : \"entrypoint.sh\"\n" + " }\n" + "}"; - private static final String BAD_CUSTOM_TAGS = "{\n" - + " \"enable_cloud_monitoring\": true,\n" - + " \"custom_tags\": {\n" + private static final String BAD_CUSTOM_TAGS = + "{\n" + + " \"project_id\": \"grpc-testing\",\n" + + " \"cloud_monitoring\": {},\n" + + " \"labels\": {\n" + " \"SOURCE_VERSION\" : \"J2e1Cf\",\n" + " \"SERVICE_NAME\" : { \"SUB_SERVICE_NAME\" : \"payment-service\"},\n" + " \"ENTRYPOINT_SCRIPT\" : \"entrypoint.sh\"\n" + " }\n" + "}"; + private static final String 
LOG_FILTER_GLOBAL_EXCLUDE = + "{\n" + + " \"project_id\": \"grpc-testing\",\n" + + " \"cloud_logging\": {\n" + + " \"client_rpc_events\": [{\n" + + " \"methods\": [\"service1/Method2\", \"*\"],\n" + + " \"max_metadata_bytes\": 20,\n" + + " \"max_message_bytes\": 15,\n" + + " \"exclude\": true\n" + + " }" + + " ]\n" + + " }\n" + + "}"; + + private static final String LOG_FILTER_INVALID_METHOD = + "{\n" + + " \"project_id\": \"grpc-testing\",\n" + + " \"cloud_logging\": {\n" + + " \"client_rpc_events\": [{\n" + + " \"methods\": [\"s*&%ervice1/Method2\", \"*\"],\n" + + " \"max_metadata_bytes\": 20\n" + + " }" + + " ]\n" + + " }\n" + + "}"; + ObservabilityConfigImpl observabilityConfig = new ObservabilityConfigImpl(); @Rule public TemporaryFolder tempFolder = new TemporaryFolder(); @@ -135,72 +221,126 @@ public void nullConfig() throws IOException { observabilityConfig.parse(null); fail("exception expected!"); } catch (IllegalArgumentException iae) { - assertThat(iae.getMessage()).isEqualTo("GRPC_CONFIG_OBSERVABILITY value is null!"); + assertThat(iae.getMessage()).isEqualTo("GRPC_GCP_OBSERVABILITY_CONFIG value is null!"); } } @Test public void emptyConfig() throws IOException { - observabilityConfig.parse("{}"); + observabilityConfig.parse(EMPTY_CONFIG); assertFalse(observabilityConfig.isEnableCloudLogging()); assertFalse(observabilityConfig.isEnableCloudMonitoring()); assertFalse(observabilityConfig.isEnableCloudTracing()); - assertNull(observabilityConfig.getDestinationProjectId()); - assertNull(observabilityConfig.getFlushMessageCount()); - assertNull(observabilityConfig.getLogFilters()); - assertNull(observabilityConfig.getEventTypes()); + assertThat(observabilityConfig.getClientLogFilters()).isEmpty(); + assertThat(observabilityConfig.getServerLogFilters()).isEmpty(); + assertThat(observabilityConfig.getSampler()).isNull(); + assertThat(observabilityConfig.getProjectId()).isNull(); + assertThat(observabilityConfig.getCustomTags()).isEmpty(); } @Test - public void 
disableCloudLogging() throws IOException { - observabilityConfig.parse(DISABLE_CLOUD_LOGGING); - assertFalse(observabilityConfig.isEnableCloudLogging()); - assertFalse(observabilityConfig.isEnableCloudMonitoring()); - assertFalse(observabilityConfig.isEnableCloudTracing()); - assertNull(observabilityConfig.getDestinationProjectId()); - assertNull(observabilityConfig.getFlushMessageCount()); - assertNull(observabilityConfig.getLogFilters()); - assertNull(observabilityConfig.getEventTypes()); + public void emptyConfigFile() throws IOException { + File configFile = tempFolder.newFile(); + try { + observabilityConfig.parseFile(configFile.getAbsolutePath()); + fail("exception expected!"); + } catch (IllegalArgumentException iae) { + assertThat(iae.getMessage()).isEqualTo( + "GRPC_GCP_OBSERVABILITY_CONFIG_FILE is empty!"); + } } @Test - public void destProjectId() throws IOException { - observabilityConfig.parse(DEST_PROJECT_ID); + public void setProjectId() throws IOException { + observabilityConfig.parse(PROJECT_ID); assertTrue(observabilityConfig.isEnableCloudLogging()); - assertThat(observabilityConfig.getDestinationProjectId()).isEqualTo("grpc-testing"); + assertThat(observabilityConfig.getProjectId()).isEqualTo("grpc-testing"); } @Test - public void flushMessageCount() throws Exception { - observabilityConfig.parse(FLUSH_MESSAGE_COUNT); + public void logFilters() throws IOException { + observabilityConfig.parse(LOG_FILTERS); assertTrue(observabilityConfig.isEnableCloudLogging()); - assertThat(observabilityConfig.getFlushMessageCount()).isEqualTo(500L); + assertThat(observabilityConfig.getProjectId()).isEqualTo("grpc-testing"); + + List clientLogFilters = observabilityConfig.getClientLogFilters(); + assertThat(clientLogFilters).hasSize(1); + assertThat(clientLogFilters.get(0).headerBytes).isEqualTo(4096); + assertThat(clientLogFilters.get(0).messageBytes).isEqualTo(0); + assertThat(clientLogFilters.get(0).excludePattern).isFalse(); + 
assertThat(clientLogFilters.get(0).matchAll).isTrue(); + assertThat(clientLogFilters.get(0).services).isEmpty(); + assertThat(clientLogFilters.get(0).methods).isEmpty(); + + List serverLogFilters = observabilityConfig.getServerLogFilters(); + assertThat(serverLogFilters).hasSize(1); + assertThat(serverLogFilters.get(0).headerBytes).isEqualTo(32); + assertThat(serverLogFilters.get(0).messageBytes).isEqualTo(64); + assertThat(serverLogFilters.get(0).excludePattern).isFalse(); + assertThat(serverLogFilters.get(0).matchAll).isTrue(); + assertThat(serverLogFilters.get(0).services).isEmpty(); + assertThat(serverLogFilters.get(0).methods).isEmpty(); } @Test - public void logFilters() throws IOException { - observabilityConfig.parse(LOG_FILTERS); + public void setClientLogFilters() throws IOException { + observabilityConfig.parse(CLIENT_LOG_FILTERS); assertTrue(observabilityConfig.isEnableCloudLogging()); - assertThat(observabilityConfig.getDestinationProjectId()).isEqualTo("grpc-testing"); - assertThat(observabilityConfig.getFlushMessageCount()).isEqualTo(1000L); - List logFilters = observabilityConfig.getLogFilters(); - assertThat(logFilters).hasSize(2); - assertThat(logFilters.get(0).pattern).isEqualTo("*/*"); - assertThat(logFilters.get(0).headerBytes).isEqualTo(4096); - assertThat(logFilters.get(0).messageBytes).isEqualTo(2048); - assertThat(logFilters.get(1).pattern).isEqualTo("service1/Method2"); - assertThat(logFilters.get(1).headerBytes).isNull(); - assertThat(logFilters.get(1).messageBytes).isNull(); + assertThat(observabilityConfig.getProjectId()).isEqualTo("grpc-testing"); + List logFilterList = observabilityConfig.getClientLogFilters(); + assertThat(logFilterList).hasSize(2); + assertThat(logFilterList.get(0).headerBytes).isEqualTo(4096); + assertThat(logFilterList.get(0).messageBytes).isEqualTo(2048); + assertThat(logFilterList.get(0).excludePattern).isFalse(); + assertThat(logFilterList.get(0).matchAll).isTrue(); + 
assertThat(logFilterList.get(0).services).isEmpty(); + assertThat(logFilterList.get(0).methods).isEmpty(); + + assertThat(logFilterList.get(1).headerBytes).isEqualTo(0); + assertThat(logFilterList.get(1).messageBytes).isEqualTo(0); + assertThat(logFilterList.get(1).excludePattern).isTrue(); + assertThat(logFilterList.get(1).matchAll).isFalse(); + assertThat(logFilterList.get(1).services).isEqualTo(Collections.singleton("Service2")); + assertThat(logFilterList.get(1).methods) + .isEqualTo(Collections.singleton("service1/Method2")); } @Test - public void eventTypes() throws IOException { - observabilityConfig.parse(EVENT_TYPES); - assertFalse(observabilityConfig.isEnableCloudLogging()); - List eventTypes = observabilityConfig.getEventTypes(); - assertThat(eventTypes).isEqualTo( - ImmutableList.of(EventType.GRPC_CALL_REQUEST_HEADER, EventType.GRPC_CALL_HALF_CLOSE, - EventType.GRPC_CALL_TRAILER)); + public void setServerLogFilters() throws IOException { + Set expectedMethods = Stream.of("service1/method4", "service2/method234") + .collect(Collectors.toCollection(HashSet::new)); + observabilityConfig.parse(SERVER_LOG_FILTERS); + assertTrue(observabilityConfig.isEnableCloudLogging()); + List logFilterList = observabilityConfig.getServerLogFilters(); + assertThat(logFilterList).hasSize(2); + assertThat(logFilterList.get(0).headerBytes).isEqualTo(32); + assertThat(logFilterList.get(0).messageBytes).isEqualTo(64); + assertThat(logFilterList.get(0).excludePattern).isFalse(); + assertThat(logFilterList.get(0).matchAll).isFalse(); + assertThat(logFilterList.get(0).services).isEmpty(); + assertThat(logFilterList.get(0).methods) + .isEqualTo(expectedMethods); + + Set expectedServices = Stream.of("service4", "Service2") + .collect(Collectors.toCollection(HashSet::new)); + assertThat(logFilterList.get(1).headerBytes).isEqualTo(0); + assertThat(logFilterList.get(1).messageBytes).isEqualTo(0); + assertThat(logFilterList.get(1).excludePattern).isTrue(); + 
assertThat(logFilterList.get(1).matchAll).isFalse(); + assertThat(logFilterList.get(1).services).isEqualTo(expectedServices); + assertThat(logFilterList.get(1).methods).isEmpty(); + } + + @Test + public void enableCloudMonitoring() throws IOException { + observabilityConfig.parse(ENABLE_CLOUD_MONITORING); + assertTrue(observabilityConfig.isEnableCloudMonitoring()); + } + + @Test + public void enableCloudTracing() throws IOException { + observabilityConfig.parse(ENABLE_CLOUD_TRACE); + assertTrue(observabilityConfig.isEnableCloudTracing()); } @Test @@ -213,7 +353,7 @@ public void enableCloudMonitoringAndTracing() throws IOException { @Test public void alwaysSampler() throws IOException { - observabilityConfig.parse(GLOBAL_TRACING_ALWAYS_SAMPLER); + observabilityConfig.parse(TRACING_ALWAYS_SAMPLER); assertTrue(observabilityConfig.isEnableCloudTracing()); Sampler sampler = observabilityConfig.getSampler(); assertThat(sampler).isNotNull(); @@ -222,7 +362,7 @@ public void alwaysSampler() throws IOException { @Test public void neverSampler() throws IOException { - observabilityConfig.parse(GLOBAL_TRACING_NEVER_SAMPLER); + observabilityConfig.parse(TRACING_NEVER_SAMPLER); assertTrue(observabilityConfig.isEnableCloudTracing()); Sampler sampler = observabilityConfig.getSampler(); assertThat(sampler).isNotNull(); @@ -231,7 +371,7 @@ public void neverSampler() throws IOException { @Test public void probabilisticSampler() throws IOException { - observabilityConfig.parse(GLOBAL_TRACING_PROBABILISTIC_SAMPLER); + observabilityConfig.parse(TRACING_PROBABILISTIC_SAMPLER); assertTrue(observabilityConfig.isEnableCloudTracing()); Sampler sampler = observabilityConfig.getSampler(); assertThat(sampler).isNotNull(); @@ -240,7 +380,7 @@ public void probabilisticSampler() throws IOException { @Test public void defaultSampler() throws IOException { - observabilityConfig.parse(GLOBAL_TRACING_DEFAULT_SAMPLER); + observabilityConfig.parse(TRACING_DEFAULT_SAMPLER); 
assertTrue(observabilityConfig.isEnableCloudTracing()); Sampler sampler = observabilityConfig.getSampler(); assertThat(sampler).isNotNull(); @@ -250,30 +390,44 @@ public void defaultSampler() throws IOException { @Test public void badProbabilisticSampler_error() throws IOException { try { - observabilityConfig.parse(GLOBAL_TRACING_BADPROBABILISTIC_SAMPLER); + observabilityConfig.parse(GLOBAL_TRACING_BAD_PROBABILISTIC_SAMPLER); fail("exception expected!"); } catch (IllegalArgumentException iae) { assertThat(iae.getMessage()).isEqualTo( - "'global_trace_sampling_rate' needs to be between [0.0, 1.0]"); + "'sampling_rate' needs to be between [0.0, 1.0]"); } } @Test public void configFileLogFilters() throws Exception { File configFile = tempFolder.newFile(); - Files.write(Paths.get(configFile.getAbsolutePath()), LOG_FILTERS.getBytes(Charsets.US_ASCII)); + Files.write( + Paths.get(configFile.getAbsolutePath()), CLIENT_LOG_FILTERS.getBytes(Charsets.US_ASCII)); observabilityConfig.parseFile(configFile.getAbsolutePath()); assertTrue(observabilityConfig.isEnableCloudLogging()); - assertThat(observabilityConfig.getDestinationProjectId()).isEqualTo("grpc-testing"); - assertThat(observabilityConfig.getFlushMessageCount()).isEqualTo(1000L); - List logFilters = observabilityConfig.getLogFilters(); + assertThat(observabilityConfig.getProjectId()).isEqualTo("grpc-testing"); + List logFilters = observabilityConfig.getClientLogFilters(); assertThat(logFilters).hasSize(2); - assertThat(logFilters.get(0).pattern).isEqualTo("*/*"); assertThat(logFilters.get(0).headerBytes).isEqualTo(4096); assertThat(logFilters.get(0).messageBytes).isEqualTo(2048); - assertThat(logFilters.get(1).pattern).isEqualTo("service1/Method2"); - assertThat(logFilters.get(1).headerBytes).isNull(); - assertThat(logFilters.get(1).messageBytes).isNull(); + assertThat(logFilters.get(1).headerBytes).isEqualTo(0); + assertThat(logFilters.get(1).messageBytes).isEqualTo(0); + + assertThat(logFilters).hasSize(2); + 
assertThat(logFilters.get(0).headerBytes).isEqualTo(4096); + assertThat(logFilters.get(0).messageBytes).isEqualTo(2048); + assertThat(logFilters.get(0).excludePattern).isFalse(); + assertThat(logFilters.get(0).matchAll).isTrue(); + assertThat(logFilters.get(0).services).isEmpty(); + assertThat(logFilters.get(0).methods).isEmpty(); + + assertThat(logFilters.get(1).headerBytes).isEqualTo(0); + assertThat(logFilters.get(1).messageBytes).isEqualTo(0); + assertThat(logFilters.get(1).excludePattern).isTrue(); + assertThat(logFilters.get(1).matchAll).isFalse(); + assertThat(logFilters.get(1).services).isEqualTo(Collections.singleton("Service2")); + assertThat(logFilters.get(1).methods) + .isEqualTo(Collections.singleton("service1/Method2")); } @Test @@ -294,7 +448,45 @@ public void badCustomTags() throws IOException { fail("exception expected!"); } catch (IllegalArgumentException iae) { assertThat(iae.getMessage()).isEqualTo( - "'custom_tags' needs to be a map of "); + "'labels' needs to be a map of "); } } -} \ No newline at end of file + + @Test + public void globalLogFilterExclude() throws IOException { + try { + observabilityConfig.parse(LOG_FILTER_GLOBAL_EXCLUDE); + fail("exception expected!"); + } catch (IllegalArgumentException iae) { + assertThat(iae.getMessage()).isEqualTo( + "cannot have 'exclude' and '*' wildcard in the same filter"); + } + } + + @Test + public void logFilterInvalidMethod() throws IOException { + try { + observabilityConfig.parse(LOG_FILTER_INVALID_METHOD); + fail("exception expected!"); + } catch (IllegalArgumentException iae) { + assertThat(iae.getMessage()).contains( + "invalid service or method filter"); + } + } + + @Test + public void validLogFilter() throws Exception { + observabilityConfig.parse(VALID_LOG_FILTERS); + assertTrue(observabilityConfig.isEnableCloudLogging()); + assertThat(observabilityConfig.getProjectId()).isEqualTo("grpc-testing"); + List logFilterList = observabilityConfig.getServerLogFilters(); + 
assertThat(logFilterList).hasSize(1); + assertThat(logFilterList.get(0).headerBytes).isEqualTo(16); + assertThat(logFilterList.get(0).messageBytes).isEqualTo(64); + assertThat(logFilterList.get(0).excludePattern).isFalse(); + assertThat(logFilterList.get(0).matchAll).isFalse(); + assertThat(logFilterList.get(0).services).isEqualTo(Collections.singleton("service.Service1")); + assertThat(logFilterList.get(0).methods) + .isEqualTo(Collections.singleton("service2.Service4/method4")); + } +} diff --git a/gcp-observability/src/test/java/io/grpc/gcp/observability/LoggingTestHelper.java b/gcp-observability/src/test/java/io/grpc/gcp/observability/ObservabilityTestHelper.java similarity index 97% rename from gcp-observability/src/test/java/io/grpc/gcp/observability/LoggingTestHelper.java rename to gcp-observability/src/test/java/io/grpc/gcp/observability/ObservabilityTestHelper.java index 529ec2503fd..ebb73ec76a1 100644 --- a/gcp-observability/src/test/java/io/grpc/gcp/observability/LoggingTestHelper.java +++ b/gcp-observability/src/test/java/io/grpc/gcp/observability/ObservabilityTestHelper.java @@ -21,7 +21,7 @@ import io.grpc.testing.protobuf.SimpleResponse; import io.grpc.testing.protobuf.SimpleServiceGrpc; -public class LoggingTestHelper { +public class ObservabilityTestHelper { static String makeUnaryRpcViaClientStub( String requestMessage, SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub) { diff --git a/gcp-observability/src/test/java/io/grpc/gcp/observability/TracesTest.java b/gcp-observability/src/test/java/io/grpc/gcp/observability/TracesTest.java index ec759827737..ae7aa63befc 100644 --- a/gcp-observability/src/test/java/io/grpc/gcp/observability/TracesTest.java +++ b/gcp-observability/src/test/java/io/grpc/gcp/observability/TracesTest.java @@ -100,7 +100,7 @@ public void run() { when(mockConfig.isEnableCloudTracing()).thenReturn(true); when(mockConfig.getSampler()).thenReturn(Samplers.alwaysSample()); - 
when(mockConfig.getDestinationProjectId()).thenReturn(PROJECT_ID); + when(mockConfig.getProjectId()).thenReturn(PROJECT_ID); try { GcpObservability observability = @@ -110,7 +110,7 @@ public void run() { Server server = ServerBuilder.forPort(0) - .addService(new LoggingTestHelper.SimpleServiceImpl()) + .addService(new ObservabilityTestHelper.SimpleServiceImpl()) .build() .start(); int port = cleanupRule.register(server).getPort(); @@ -118,7 +118,7 @@ public void run() { SimpleServiceGrpc.newBlockingStub( cleanupRule.register( ManagedChannelBuilder.forAddress("localhost", port).usePlaintext().build())); - assertThat(LoggingTestHelper.makeUnaryRpcViaClientStub("buddy", stub)) + assertThat(ObservabilityTestHelper.makeUnaryRpcViaClientStub("buddy", stub)) .isEqualTo("Hello buddy"); // Adding sleep to ensure traces are exported before querying cloud tracing backend TimeUnit.SECONDS.sleep(10); diff --git a/gcp-observability/src/test/java/io/grpc/gcp/observability/interceptors/ConfigFilterHelperTest.java b/gcp-observability/src/test/java/io/grpc/gcp/observability/interceptors/ConfigFilterHelperTest.java index ba6e05e2dcd..971e6070777 100644 --- a/gcp-observability/src/test/java/io/grpc/gcp/observability/interceptors/ConfigFilterHelperTest.java +++ b/gcp-observability/src/test/java/io/grpc/gcp/observability/interceptors/ConfigFilterHelperTest.java @@ -17,44 +17,32 @@ package io.grpc.gcp.observability.interceptors; import static com.google.common.truth.Truth.assertThat; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableSet; -import io.grpc.MethodDescriptor; import io.grpc.gcp.observability.ObservabilityConfig; import io.grpc.gcp.observability.ObservabilityConfig.LogFilter; import io.grpc.gcp.observability.interceptors.ConfigFilterHelper.FilterParams; -import 
io.grpc.observabilitylog.v1.GrpcLogRecord.EventType; -import io.grpc.testing.TestMethodDescriptors; -import java.util.ArrayList; -import java.util.HashMap; +import java.util.Collections; import java.util.List; -import java.util.Map; -import java.util.Set; import org.junit.Before; import org.junit.Test; public class ConfigFilterHelperTest { private static final ImmutableList configLogFilters = ImmutableList.of( - new LogFilter("service1/Method2",1024,1024), - new LogFilter("service2/*",2048,1024), - new LogFilter("*",128,128), - new LogFilter("service2/*",2048,1024)); - - private static final ImmutableList configEventTypes = - ImmutableList.of( - EventType.GRPC_CALL_REQUEST_HEADER, - EventType.GRPC_CALL_HALF_CLOSE, - EventType.GRPC_CALL_TRAILER); - - private final MethodDescriptor.Builder builder = TestMethodDescriptors.voidMethod() - .toBuilder(); - private MethodDescriptor method; + new LogFilter(Collections.emptySet(), Collections.singleton("service1/Method2"), false, + 1024, 1024, false), + new LogFilter( + Collections.singleton("service2"), Collections.singleton("service4/method2"), false, + 2048, 1024, false), + new LogFilter( + Collections.singleton("service2"), Collections.singleton("service4/method3"), false, + 2048, 1024, true), + new LogFilter( + Collections.emptySet(), Collections.emptySet(), true, + 128, 128, false)); private ObservabilityConfig mockConfig; private ConfigFilterHelper configFilterHelper; @@ -62,157 +50,100 @@ public class ConfigFilterHelperTest { @Before public void setup() { mockConfig = mock(ObservabilityConfig.class); - configFilterHelper = new ConfigFilterHelper(mockConfig); - } - - @Test - public void disableCloudLogging_emptyLogFilters() { - when(mockConfig.isEnableCloudLogging()).thenReturn(false); - assertFalse(configFilterHelper.methodOrServiceFilterPresent); - assertThat(configFilterHelper.perServiceFilters).isEmpty(); - assertThat(configFilterHelper.perServiceFilters).isEmpty(); - 
assertThat(configFilterHelper.perMethodFilters).isEmpty(); - assertThat(configFilterHelper.logEventTypeSet).isNull(); + configFilterHelper = ConfigFilterHelper.getInstance(mockConfig); } @Test - public void enableCloudLogging_emptyLogFilters() { + public void enableCloudLogging_withoutLogFilters() { when(mockConfig.isEnableCloudLogging()).thenReturn(true); - when(mockConfig.getLogFilters()).thenReturn(null); - when(mockConfig.getEventTypes()).thenReturn(null); - configFilterHelper.setMethodOrServiceFilterMaps(); - configFilterHelper.setEventFilterSet(); - - assertFalse(configFilterHelper.methodOrServiceFilterPresent); - assertThat(configFilterHelper.perServiceFilters).isEmpty(); - assertThat(configFilterHelper.perServiceFilters).isEmpty(); - assertThat(configFilterHelper.perMethodFilters).isEmpty(); - assertThat(configFilterHelper.logEventTypeSet).isNull(); + assertThat(mockConfig.getClientLogFilters()).isEmpty(); + assertThat(mockConfig.getServerLogFilters()).isEmpty(); } @Test - public void enableCloudLogging_withLogFilters() { + public void checkMethodLog_withoutLogFilters() { when(mockConfig.isEnableCloudLogging()).thenReturn(true); - when(mockConfig.getLogFilters()).thenReturn(configLogFilters); - when(mockConfig.getEventTypes()).thenReturn(configEventTypes); - - configFilterHelper.setMethodOrServiceFilterMaps(); - configFilterHelper.setEventFilterSet(); - - assertTrue(configFilterHelper.methodOrServiceFilterPresent); + assertThat(mockConfig.getClientLogFilters()).isEmpty(); + assertThat(mockConfig.getServerLogFilters()).isEmpty(); - Map expectedServiceFilters = new HashMap<>(); - expectedServiceFilters.put("*", - FilterParams.create(true, 128, 128)); - expectedServiceFilters.put("service2", - FilterParams.create(true, 2048, 1024)); - assertThat(configFilterHelper.perServiceFilters).isEqualTo(expectedServiceFilters); - - Map expectedMethodFilters = new HashMap<>(); - expectedMethodFilters.put("service1/Method2", - FilterParams.create(true, 1024, 1024)); - 
assertThat(configFilterHelper.perMethodFilters).isEqualTo(expectedMethodFilters); - - Set expectedLogEventTypeSet = ImmutableSet.copyOf(configEventTypes); - assertThat(configFilterHelper.logEventTypeSet).isEqualTo(expectedLogEventTypeSet); + FilterParams expectedParams = + FilterParams.create(false, 0, 0); + FilterParams clientResultParams + = configFilterHelper.logRpcMethod("service3/Method3", true); + assertThat(clientResultParams).isEqualTo(expectedParams); + FilterParams serverResultParams + = configFilterHelper.logRpcMethod("service3/Method3", false); + assertThat(serverResultParams).isEqualTo(expectedParams); } @Test public void checkMethodAlwaysLogged() { - List sampleLogFilters = ImmutableList.of( - new LogFilter("*", 4096, 4096)); - when(mockConfig.getLogFilters()).thenReturn(sampleLogFilters); - configFilterHelper.setMethodOrServiceFilterMaps(); + List sampleLogFilters = + ImmutableList.of( + new LogFilter( + Collections.emptySet(), Collections.emptySet(), true, + 4096, 4096, false)); + when(mockConfig.getClientLogFilters()).thenReturn(sampleLogFilters); + when(mockConfig.getServerLogFilters()).thenReturn(sampleLogFilters); FilterParams expectedParams = FilterParams.create(true, 4096, 4096); - method = builder.setFullMethodName("service1/Method6").build(); - FilterParams resultParams - = configFilterHelper.isMethodToBeLogged(method); - assertThat(resultParams).isEqualTo(expectedParams); + FilterParams clientResultParams + = configFilterHelper.logRpcMethod("service1/Method6", true); + assertThat(clientResultParams).isEqualTo(expectedParams); + FilterParams serverResultParams + = configFilterHelper.logRpcMethod("service1/Method6", false); + assertThat(serverResultParams).isEqualTo(expectedParams); } @Test public void checkMethodNotToBeLogged() { - List sampleLogFilters = ImmutableList.of( - new LogFilter("service1/Method2", 1024, 1024), - new LogFilter("service2/*", 2048, 1024)); - when(mockConfig.getLogFilters()).thenReturn(sampleLogFilters); - 
configFilterHelper.setMethodOrServiceFilterMaps(); + List sampleLogFilters = + ImmutableList.of( + new LogFilter(Collections.emptySet(), Collections.singleton("service2/*"), false, + 1024, 1024, true), + new LogFilter( + Collections.singleton("service2/Method1"), Collections.emptySet(), false, + 2048, 1024, false)); + when(mockConfig.getClientLogFilters()).thenReturn(sampleLogFilters); + when(mockConfig.getServerLogFilters()).thenReturn(sampleLogFilters); FilterParams expectedParams = FilterParams.create(false, 0, 0); - method = builder.setFullMethodName("service3/Method3").build(); - FilterParams resultParams - = configFilterHelper.isMethodToBeLogged(method); - assertThat(resultParams).isEqualTo(expectedParams); + FilterParams clientResultParams1 + = configFilterHelper.logRpcMethod("service3/Method3", true); + assertThat(clientResultParams1).isEqualTo(expectedParams); + + FilterParams clientResultParams2 + = configFilterHelper.logRpcMethod("service2/Method1", true); + assertThat(clientResultParams2).isEqualTo(expectedParams); + + FilterParams serverResultParams + = configFilterHelper.logRpcMethod("service2/Method1", false); + assertThat(serverResultParams).isEqualTo(expectedParams); } @Test public void checkMethodToBeLoggedConditional() { - when(mockConfig.getLogFilters()).thenReturn(configLogFilters); - configFilterHelper.setMethodOrServiceFilterMaps(); + when(mockConfig.getClientLogFilters()).thenReturn(configLogFilters); + when(mockConfig.getServerLogFilters()).thenReturn(configLogFilters); FilterParams expectedParams = FilterParams.create(true, 1024, 1024); - method = builder.setFullMethodName("service1/Method2").build(); FilterParams resultParams - = configFilterHelper.isMethodToBeLogged(method); + = configFilterHelper.logRpcMethod("service1/Method2", true); assertThat(resultParams).isEqualTo(expectedParams); FilterParams expectedParamsWildCard = FilterParams.create(true, 2048, 1024); - method = builder.setFullMethodName("service2/Method1").build(); 
FilterParams resultParamsWildCard - = configFilterHelper.isMethodToBeLogged(method); + = configFilterHelper.logRpcMethod("service2/Method1", true); assertThat(resultParamsWildCard).isEqualTo(expectedParamsWildCard); - } - - @Test - public void checkEventToBeLogged_noFilter_defaultLogAllEventTypes() { - List eventList = new ArrayList<>(); - eventList.add(EventType.GRPC_CALL_REQUEST_HEADER); - eventList.add(EventType.GRPC_CALL_RESPONSE_HEADER); - eventList.add(EventType.GRPC_CALL_REQUEST_MESSAGE); - eventList.add(EventType.GRPC_CALL_RESPONSE_MESSAGE); - eventList.add(EventType.GRPC_CALL_HALF_CLOSE); - eventList.add(EventType.GRPC_CALL_TRAILER); - eventList.add(EventType.GRPC_CALL_CANCEL); - - for (EventType event : eventList) { - assertTrue(configFilterHelper.isEventToBeLogged(event)); - } - } - - @Test - public void checkEventToBeLogged_emptyFilter_doNotLogEventTypes() { - when(mockConfig.getEventTypes()).thenReturn(new ArrayList<>()); - configFilterHelper.setEventFilterSet(); - - List eventList = new ArrayList<>(); - eventList.add(EventType.GRPC_CALL_REQUEST_HEADER); - eventList.add(EventType.GRPC_CALL_RESPONSE_HEADER); - eventList.add(EventType.GRPC_CALL_REQUEST_MESSAGE); - eventList.add(EventType.GRPC_CALL_RESPONSE_MESSAGE); - eventList.add(EventType.GRPC_CALL_HALF_CLOSE); - eventList.add(EventType.GRPC_CALL_TRAILER); - eventList.add(EventType.GRPC_CALL_CANCEL); - - for (EventType event : eventList) { - assertFalse(configFilterHelper.isEventToBeLogged(event)); - } - } - - @Test - public void checkEventToBeLogged_withEventTypesFromConfig() { - when(mockConfig.getEventTypes()).thenReturn(configEventTypes); - configFilterHelper.setEventFilterSet(); - - EventType logEventType = EventType.GRPC_CALL_REQUEST_HEADER; - assertTrue(configFilterHelper.isEventToBeLogged(logEventType)); - - EventType doNotLogEventType = EventType.GRPC_CALL_RESPONSE_MESSAGE; - assertFalse(configFilterHelper.isEventToBeLogged(doNotLogEventType)); + FilterParams excludeParams = + 
FilterParams.create(false, 0, 0); + FilterParams serverResultParams + = configFilterHelper.logRpcMethod("service4/method3", false); + assertThat(serverResultParams).isEqualTo(excludeParams); } } diff --git a/gcp-observability/src/test/java/io/grpc/gcp/observability/interceptors/InternalLoggingChannelInterceptorTest.java b/gcp-observability/src/test/java/io/grpc/gcp/observability/interceptors/InternalLoggingChannelInterceptorTest.java index 025c99e5b6a..2a2e1d4c229 100644 --- a/gcp-observability/src/test/java/io/grpc/gcp/observability/interceptors/InternalLoggingChannelInterceptorTest.java +++ b/gcp-observability/src/test/java/io/grpc/gcp/observability/interceptors/InternalLoggingChannelInterceptorTest.java @@ -26,7 +26,6 @@ import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.same; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; import static org.mockito.Mockito.reset; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -108,8 +107,6 @@ public void setup() throws Exception { cancelCalled = SettableFuture.create(); peer = new InetSocketAddress(InetAddress.getByName("127.0.0.1"), 1234); filterParams = FilterParams.create(true, 0, 0); - when(mockFilterHelper.isEventToBeLogged(any(GrpcLogRecord.EventType.class))) - .thenReturn(true); } @Test @@ -164,7 +161,7 @@ public String authority() { .setRequestMarshaller(BYTEARRAY_MARSHALLER) .setResponseMarshaller(BYTEARRAY_MARSHALLER) .build(); - when(mockFilterHelper.isMethodToBeLogged(method)) + when(mockFilterHelper.logRpcMethod(method.getFullMethodName(), true)) .thenReturn(filterParams); ClientCall interceptedLoggingCall = @@ -185,7 +182,7 @@ public String authority() { clientInitial.put(keyA, dataA); clientInitial.put(keyB, dataB); interceptedLoggingCall.start(mockListener, clientInitial); - verify(mockLogHelper).logRequestHeader( + verify(mockLogHelper).logClientHeader( /*seq=*/ eq(1L), eq("service"), eq("method"), @@ 
-193,7 +190,7 @@ public String authority() { ArgumentMatchers.isNull(), same(clientInitial), eq(filterParams.headerBytes()), - eq(EventLogger.LOGGER_CLIENT), + eq(EventLogger.CLIENT), anyString(), ArgumentMatchers.isNull()); verifyNoMoreInteractions(mockLogHelper); @@ -207,13 +204,14 @@ public String authority() { { Metadata serverInitial = new Metadata(); interceptedListener.get().onHeaders(serverInitial); - verify(mockLogHelper).logResponseHeader( + verify(mockLogHelper).logServerHeader( /*seq=*/ eq(2L), eq("service"), eq("method"), + eq("the-authority"), same(serverInitial), eq(filterParams.headerBytes()), - eq(EventLogger.LOGGER_CLIENT), + eq(EventLogger.CLIENT), anyString(), same(peer)); verifyNoMoreInteractions(mockLogHelper); @@ -231,10 +229,11 @@ public String authority() { /*seq=*/ eq(3L), eq("service"), eq("method"), - eq(EventType.GRPC_CALL_REQUEST_MESSAGE), + eq("the-authority"), + eq(EventType.CLIENT_MESSAGE), same(request), eq(filterParams.messageBytes()), - eq(EventLogger.LOGGER_CLIENT), + eq(EventLogger.CLIENT), anyString()); verifyNoMoreInteractions(mockLogHelper); assertSame(request, actualRequest.get()); @@ -250,7 +249,8 @@ public String authority() { /*seq=*/ eq(4L), eq("service"), eq("method"), - eq(EventLogger.LOGGER_CLIENT), + eq("the-authority"), + eq(EventLogger.CLIENT), anyString()); halfCloseCalled.get(1, TimeUnit.MILLISECONDS); verifyNoMoreInteractions(mockLogHelper); @@ -267,10 +267,11 @@ public String authority() { /*seq=*/ eq(5L), eq("service"), eq("method"), - eq(EventType.GRPC_CALL_RESPONSE_MESSAGE), + eq("the-authority"), + eq(EventType.SERVER_MESSAGE), same(response), eq(filterParams.messageBytes()), - eq(EventLogger.LOGGER_CLIENT), + eq(EventLogger.CLIENT), anyString()); verifyNoMoreInteractions(mockLogHelper); verify(mockListener).onMessage(same(response)); @@ -288,10 +289,11 @@ public String authority() { /*seq=*/ eq(6L), eq("service"), eq("method"), + eq("the-authority"), same(status), same(trailers), 
eq(filterParams.headerBytes()), - eq(EventLogger.LOGGER_CLIENT), + eq(EventLogger.CLIENT), anyString(), same(peer)); verifyNoMoreInteractions(mockLogHelper); @@ -308,7 +310,8 @@ public String authority() { /*seq=*/ eq(7L), eq("service"), eq("method"), - eq(EventLogger.LOGGER_CLIENT), + eq("the-authority"), + eq(EventLogger.CLIENT), anyString()); cancelCalled.get(1, TimeUnit.MILLISECONDS); } @@ -323,7 +326,7 @@ public void clientDeadLineLogged_deadlineSetViaCallOption() { .setRequestMarshaller(BYTEARRAY_MARSHALLER) .setResponseMarshaller(BYTEARRAY_MARSHALLER) .build(); - when(mockFilterHelper.isMethodToBeLogged(method)) + when(mockFilterHelper.logRpcMethod(method.getFullMethodName(), true)) .thenReturn(filterParams); @SuppressWarnings("unchecked") ClientCall.Listener mockListener = mock(ClientCall.Listener.class); @@ -349,7 +352,7 @@ public String authority() { interceptedLoggingCall.start(mockListener, new Metadata()); ArgumentCaptor callOptTimeoutCaptor = ArgumentCaptor.forClass(Duration.class); verify(mockLogHelper, times(1)) - .logRequestHeader( + .logClientHeader( anyLong(), AdditionalMatchers.or(ArgumentMatchers.isNull(), anyString()), AdditionalMatchers.or(ArgumentMatchers.isNull(), anyString()), @@ -381,7 +384,7 @@ public void clientDeadlineLogged_deadlineSetViaContext() throws Exception { .setRequestMarshaller(BYTEARRAY_MARSHALLER) .setResponseMarshaller(BYTEARRAY_MARSHALLER) .build(); - when(mockFilterHelper.isMethodToBeLogged(method)) + when(mockFilterHelper.logRpcMethod(method.getFullMethodName(), true)) .thenReturn(filterParams); callFuture.set( @@ -408,7 +411,7 @@ public String authority() { Objects.requireNonNull(callFuture.get()).start(mockListener, new Metadata()); ArgumentCaptor contextTimeoutCaptor = ArgumentCaptor.forClass(Duration.class); verify(mockLogHelper, times(1)) - .logRequestHeader( + .logClientHeader( anyLong(), AdditionalMatchers.or(ArgumentMatchers.isNull(), anyString()), AdditionalMatchers.or(ArgumentMatchers.isNull(), anyString()), 
@@ -443,7 +446,7 @@ public void clientDeadlineLogged_deadlineSetViaContextAndCallOptions() throws Ex .setRequestMarshaller(BYTEARRAY_MARSHALLER) .setResponseMarshaller(BYTEARRAY_MARSHALLER) .build(); - when(mockFilterHelper.isMethodToBeLogged(method)) + when(mockFilterHelper.logRpcMethod(method.getFullMethodName(), true)) .thenReturn(filterParams); callFuture.set( @@ -470,7 +473,7 @@ public String authority() { Objects.requireNonNull(callFuture.get()).start(mockListener, new Metadata()); ArgumentCaptor timeoutCaptor = ArgumentCaptor.forClass(Duration.class); verify(mockLogHelper, times(1)) - .logRequestHeader( + .logClientHeader( anyLong(), AdditionalMatchers.or(ArgumentMatchers.isNull(), anyString()), AdditionalMatchers.or(ArgumentMatchers.isNull(), anyString()), @@ -541,7 +544,7 @@ public String authority() { .setRequestMarshaller(BYTEARRAY_MARSHALLER) .setResponseMarshaller(BYTEARRAY_MARSHALLER) .build(); - when(mockFilterHelper.isMethodToBeLogged(method)) + when(mockFilterHelper.logRpcMethod(method.getFullMethodName(), true)) .thenReturn(FilterParams.create(false, 0, 0)); ClientCall interceptedLoggingCall = @@ -606,7 +609,7 @@ public String authority() { .setRequestMarshaller(BYTEARRAY_MARSHALLER) .setResponseMarshaller(BYTEARRAY_MARSHALLER) .build(); - when(mockFilterHelper.isMethodToBeLogged(method)) + when(mockFilterHelper.logRpcMethod(method.getFullMethodName(), true)) .thenReturn(FilterParams.create(true, 10, 10)); ClientCall interceptedLoggingCall = @@ -630,105 +633,4 @@ public String authority() { assertThat(Mockito.mockingDetails(mockLogHelper).getInvocations().size()).isEqualTo(7); } } - - @Test - public void eventFilter_enabled() { - when(mockFilterHelper.isEventToBeLogged(EventType.GRPC_CALL_REQUEST_HEADER)).thenReturn(false); - when(mockFilterHelper.isEventToBeLogged(EventType.GRPC_CALL_RESPONSE_HEADER)).thenReturn(false); - - Channel channel = new Channel() { - @Override - public ClientCall newCall( - MethodDescriptor methodDescriptor, CallOptions 
callOptions) { - return new NoopClientCall() { - @Override - @SuppressWarnings("unchecked") - public void start(Listener responseListener, Metadata headers) { - interceptedListener.set((Listener) responseListener); - actualClientInitial.set(headers); - } - - @Override - public void sendMessage(RequestT message) { - actualRequest.set(message); - } - - @Override - public void cancel(String message, Throwable cause) { - cancelCalled.set(null); - } - - @Override - public void halfClose() { - halfCloseCalled.set(null); - } - - @Override - public Attributes getAttributes() { - return Attributes.newBuilder().set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, peer).build(); - } - }; - } - - @Override - public String authority() { - return "the-authority"; - } - }; - - @SuppressWarnings("unchecked") - ClientCall.Listener mockListener = mock(ClientCall.Listener.class); - - MethodDescriptor method = - MethodDescriptor.newBuilder() - .setType(MethodType.UNKNOWN) - .setFullMethodName("service/method") - .setRequestMarshaller(BYTEARRAY_MARSHALLER) - .setResponseMarshaller(BYTEARRAY_MARSHALLER) - .build(); - when(mockFilterHelper.isMethodToBeLogged(method)) - .thenReturn(FilterParams.create(true, 10, 10)); - - ClientCall interceptedLoggingCall = - factory.create() - .interceptCall(method, - CallOptions.DEFAULT, - channel); - - { - interceptedLoggingCall.start(mockListener, new Metadata()); - verify(mockLogHelper, never()).logRequestHeader( - anyLong(), - AdditionalMatchers.or(ArgumentMatchers.isNull(), anyString()), - AdditionalMatchers.or(ArgumentMatchers.isNull(), anyString()), - AdditionalMatchers.or(ArgumentMatchers.isNull(), anyString()), - any(Duration.class), - any(Metadata.class), - anyInt(), - any(GrpcLogRecord.EventLogger.class), - anyString(), - AdditionalMatchers.or(ArgumentMatchers.isNull(), - ArgumentMatchers.any())); - interceptedListener.get().onHeaders(new Metadata()); - verify(mockLogHelper, never()).logResponseHeader( - anyLong(), - 
AdditionalMatchers.or(ArgumentMatchers.isNull(), anyString()), - AdditionalMatchers.or(ArgumentMatchers.isNull(), anyString()), - any(Metadata.class), - anyInt(), - any(GrpcLogRecord.EventLogger.class), - anyString(), - ArgumentMatchers.any()); - byte[] request = "this is a request".getBytes(US_ASCII); - interceptedLoggingCall.sendMessage(request); - interceptedLoggingCall.halfClose(); - byte[] response = "this is a response".getBytes(US_ASCII); - interceptedListener.get().onMessage(response); - Status status = Status.INTERNAL.withDescription("trailer description"); - Metadata trailers = new Metadata(); - interceptedListener.get().onClose(status, trailers); - interceptedLoggingCall.cancel(null, null); - assertThat(Mockito.mockingDetails(mockLogHelper).getInvocations().size()).isEqualTo(5); - } - } } diff --git a/gcp-observability/src/test/java/io/grpc/gcp/observability/interceptors/InternalLoggingServerInterceptorTest.java b/gcp-observability/src/test/java/io/grpc/gcp/observability/interceptors/InternalLoggingServerInterceptorTest.java index a222da4c4d3..fee936dfbca 100644 --- a/gcp-observability/src/test/java/io/grpc/gcp/observability/interceptors/InternalLoggingServerInterceptorTest.java +++ b/gcp-observability/src/test/java/io/grpc/gcp/observability/interceptors/InternalLoggingServerInterceptorTest.java @@ -20,12 +20,10 @@ import static io.grpc.gcp.observability.interceptors.LogHelperTest.BYTEARRAY_MARSHALLER; import static org.junit.Assert.assertSame; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.same; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; import static org.mockito.Mockito.reset; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -45,7 +43,6 @@ import io.grpc.Status; import 
io.grpc.gcp.observability.interceptors.ConfigFilterHelper.FilterParams; import io.grpc.internal.NoopServerCall; -import io.grpc.observabilitylog.v1.GrpcLogRecord; import io.grpc.observabilitylog.v1.GrpcLogRecord.EventLogger; import io.grpc.observabilitylog.v1.GrpcLogRecord.EventType; import java.net.InetAddress; @@ -61,7 +58,6 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; -import org.mockito.AdditionalMatchers; import org.mockito.ArgumentCaptor; import org.mockito.ArgumentMatchers; import org.mockito.Mockito; @@ -105,8 +101,6 @@ public void setup() throws Exception { actualStatus = new AtomicReference<>(); actualTrailers = new AtomicReference<>(); peer = new InetSocketAddress(InetAddress.getByName("127.0.0.1"), 1234); - when(mockFilterHelper.isEventToBeLogged(any(GrpcLogRecord.EventType.class))) - .thenReturn(true); } @Test @@ -121,7 +115,7 @@ public void internalLoggingServerInterceptor() { .setResponseMarshaller(BYTEARRAY_MARSHALLER) .build(); FilterParams filterParams = FilterParams.create(true, 0, 0); - when(mockFilterHelper.isMethodToBeLogged(method)).thenReturn(filterParams); + when(mockFilterHelper.logRpcMethod(method.getFullMethodName(), false)).thenReturn(filterParams); capturedListener = factory.create() .interceptCall( @@ -167,7 +161,7 @@ public String getAuthority() { }); // receive request header { - verify(mockLogHelper).logRequestHeader( + verify(mockLogHelper).logClientHeader( /*seq=*/ eq(1L), eq("service"), eq("method"), @@ -175,7 +169,7 @@ public String getAuthority() { ArgumentMatchers.isNull(), same(clientInitial), eq(filterParams.headerBytes()), - eq(EventLogger.LOGGER_SERVER), + eq(EventLogger.SERVER), anyString(), same(peer)); verifyNoMoreInteractions(mockLogHelper); @@ -188,13 +182,14 @@ public String getAuthority() { { Metadata serverInitial = new Metadata(); interceptedLoggingCall.get().sendHeaders(serverInitial); - verify(mockLogHelper).logResponseHeader( + 
verify(mockLogHelper).logServerHeader( /*seq=*/ eq(2L), eq("service"), eq("method"), + eq("the-authority"), same(serverInitial), eq(filterParams.headerBytes()), - eq(EventLogger.LOGGER_SERVER), + eq(EventLogger.SERVER), anyString(), ArgumentMatchers.isNull()); verifyNoMoreInteractions(mockLogHelper); @@ -212,10 +207,11 @@ public String getAuthority() { /*seq=*/ eq(3L), eq("service"), eq("method"), - eq(EventType.GRPC_CALL_REQUEST_MESSAGE), + eq("the-authority"), + eq(EventType.CLIENT_MESSAGE), same(request), eq(filterParams.messageBytes()), - eq(EventLogger.LOGGER_SERVER), + eq(EventLogger.SERVER), anyString()); verifyNoMoreInteractions(mockLogHelper); verify(mockListener).onMessage(same(request)); @@ -231,7 +227,8 @@ public String getAuthority() { /*seq=*/ eq(4L), eq("service"), eq("method"), - eq(EventLogger.LOGGER_SERVER), + eq("the-authority"), + eq(EventLogger.SERVER), anyString()); verifyNoMoreInteractions(mockLogHelper); verify(mockListener).onHalfClose(); @@ -248,10 +245,11 @@ public String getAuthority() { /*seq=*/ eq(5L), eq("service"), eq("method"), - eq(EventType.GRPC_CALL_RESPONSE_MESSAGE), + eq("the-authority"), + eq(EventType.SERVER_MESSAGE), same(response), eq(filterParams.messageBytes()), - eq(EventLogger.LOGGER_SERVER), + eq(EventLogger.SERVER), anyString()); verifyNoMoreInteractions(mockLogHelper); assertSame(response, actualResponse.get()); @@ -269,10 +267,11 @@ public String getAuthority() { /*seq=*/ eq(6L), eq("service"), eq("method"), + eq("the-authority"), same(status), same(trailers), eq(filterParams.headerBytes()), - eq(EventLogger.LOGGER_SERVER), + eq(EventLogger.SERVER), anyString(), ArgumentMatchers.isNull()); verifyNoMoreInteractions(mockLogHelper); @@ -290,7 +289,8 @@ public String getAuthority() { /*seq=*/ eq(7L), eq("service"), eq("method"), - eq(EventLogger.LOGGER_SERVER), + eq("the-authority"), + eq(EventLogger.SERVER), anyString()); verify(mockListener).onCancel(); } @@ -306,7 +306,7 @@ public void serverDeadlineLogged() { 
.setResponseMarshaller(BYTEARRAY_MARSHALLER) .build(); FilterParams filterParams = FilterParams.create(true, 0, 0); - when(mockFilterHelper.isMethodToBeLogged(method)).thenReturn(filterParams); + when(mockFilterHelper.logRpcMethod(method.getFullMethodName(), false)).thenReturn(filterParams); final ServerCall noopServerCall = new NoopServerCall() { @Override public MethodDescriptor getMethodDescriptor() { @@ -332,7 +332,7 @@ public String getAuthority() { }); ArgumentCaptor timeoutCaptor = ArgumentCaptor.forClass(Duration.class); verify(mockLogHelper, times(1)) - .logRequestHeader( + .logClientHeader( /*seq=*/ eq(1L), eq("service"), eq("method"), @@ -340,7 +340,7 @@ public String getAuthority() { timeoutCaptor.capture(), any(Metadata.class), eq(filterParams.headerBytes()), - eq(EventLogger.LOGGER_SERVER), + eq(EventLogger.SERVER), anyString(), ArgumentMatchers.isNull()); verifyNoMoreInteractions(mockLogHelper); @@ -359,7 +359,8 @@ public void serverMethodOrServiceFilter_disabled() { .setRequestMarshaller(BYTEARRAY_MARSHALLER) .setResponseMarshaller(BYTEARRAY_MARSHALLER) .build(); - when(mockFilterHelper.isMethodToBeLogged(method)).thenReturn(FilterParams.create(false, 0, 0)); + when(mockFilterHelper.logRpcMethod(method.getFullMethodName(), false)) + .thenReturn(FilterParams.create(false, 0, 0)); capturedListener = factory.create() .interceptCall( @@ -416,7 +417,7 @@ public void serverMethodOrServiceFilter_enabled() { .setRequestMarshaller(BYTEARRAY_MARSHALLER) .setResponseMarshaller(BYTEARRAY_MARSHALLER) .build(); - when(mockFilterHelper.isMethodToBeLogged(method)) + when(mockFilterHelper.logRpcMethod(method.getFullMethodName(), false)) .thenReturn(FilterParams.create(true, 10, 10)); capturedListener = @@ -477,84 +478,4 @@ public String getAuthority() { assertThat(Mockito.mockingDetails(mockLogHelper).getInvocations().size()).isEqualTo(7); } } - - @Test - public void eventFilter_enabled() { - 
when(mockFilterHelper.isEventToBeLogged(EventType.GRPC_CALL_HALF_CLOSE)).thenReturn(false); - - Metadata clientInitial = new Metadata(); - final MethodDescriptor method = - MethodDescriptor.newBuilder() - .setType(MethodType.UNKNOWN) - .setFullMethodName("service/method") - .setRequestMarshaller(BYTEARRAY_MARSHALLER) - .setResponseMarshaller(BYTEARRAY_MARSHALLER) - .build(); - when(mockFilterHelper.isMethodToBeLogged(method)) - .thenReturn(FilterParams.create(true, 10, 10)); - - capturedListener = - factory.create() - .interceptCall( - new NoopServerCall() { - @Override - public void sendHeaders(Metadata headers) { - actualServerInitial.set(headers); - } - - @Override - public void sendMessage(byte[] message) { - actualResponse.set(message); - } - - @Override - public void close(Status status, Metadata trailers) { - actualStatus.set(status); - actualTrailers.set(trailers); - } - - @Override - public MethodDescriptor getMethodDescriptor() { - return method; - } - - @Override - public Attributes getAttributes() { - return Attributes - .newBuilder() - .set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, peer) - .build(); - } - - @Override - public String getAuthority() { - return "the-authority"; - } - }, - clientInitial, - (call, headers) -> { - interceptedLoggingCall.set(call); - return mockListener; - }); - - { - interceptedLoggingCall.get().sendHeaders(new Metadata()); - byte[] request = "this is a request".getBytes(US_ASCII); - capturedListener.onMessage(request); - capturedListener.onHalfClose(); - byte[] response = "this is a response".getBytes(US_ASCII); - interceptedLoggingCall.get().sendMessage(response); - Status status = Status.INTERNAL.withDescription("trailer description"); - Metadata trailers = new Metadata(); - interceptedLoggingCall.get().close(status, trailers); - verify(mockLogHelper, never()).logHalfClose( - anyLong(), - AdditionalMatchers.or(ArgumentMatchers.isNull(), anyString()), - AdditionalMatchers.or(ArgumentMatchers.isNull(), anyString()), - 
any(GrpcLogRecord.EventLogger.class), - anyString()); - capturedListener.onCancel(); - assertThat(Mockito.mockingDetails(mockLogHelper).getInvocations().size()).isEqualTo(6); - } - } } diff --git a/gcp-observability/src/test/java/io/grpc/gcp/observability/interceptors/LogHelperTest.java b/gcp-observability/src/test/java/io/grpc/gcp/observability/interceptors/LogHelperTest.java index 209543595d6..73704eb4181 100644 --- a/gcp-observability/src/test/java/io/grpc/gcp/observability/interceptors/LogHelperTest.java +++ b/gcp-observability/src/test/java/io/grpc/gcp/observability/interceptors/LogHelperTest.java @@ -26,26 +26,22 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; -import com.google.common.collect.Iterables; import com.google.protobuf.ByteString; import com.google.protobuf.Duration; -import com.google.protobuf.Timestamp; import com.google.protobuf.util.Durations; import io.grpc.Attributes; import io.grpc.Grpc; import io.grpc.Metadata; import io.grpc.MethodDescriptor.Marshaller; import io.grpc.Status; -import io.grpc.gcp.observability.interceptors.LogHelper.PayloadBuilder; +import io.grpc.gcp.observability.interceptors.LogHelper.PayloadBuilderHelper; import io.grpc.gcp.observability.logging.GcpLogSink; import io.grpc.gcp.observability.logging.Sink; -import io.grpc.internal.TimeProvider; +import io.grpc.observabilitylog.v1.Address; import io.grpc.observabilitylog.v1.GrpcLogRecord; -import io.grpc.observabilitylog.v1.GrpcLogRecord.Address; import io.grpc.observabilitylog.v1.GrpcLogRecord.EventLogger; import io.grpc.observabilitylog.v1.GrpcLogRecord.EventType; -import io.grpc.observabilitylog.v1.GrpcLogRecord.LogLevel; -import io.grpc.observabilitylog.v1.GrpcLogRecord.MetadataEntry; +import io.grpc.observabilitylog.v1.Payload; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; @@ -55,8 +51,7 @@ import java.net.InetSocketAddress; import java.net.SocketAddress; import 
java.nio.charset.StandardCharsets; -import java.util.Objects; -import java.util.concurrent.TimeUnit; +import java.util.HashMap; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -77,33 +72,12 @@ public class LogHelperTest { Metadata.Key.of("b", Metadata.ASCII_STRING_MARSHALLER); private static final Metadata.Key KEY_C = Metadata.Key.of("c", Metadata.ASCII_STRING_MARSHALLER); - private static final MetadataEntry ENTRY_A = - MetadataEntry - .newBuilder() - .setKey(KEY_A.name()) - .setValue(ByteString.copyFrom(DATA_A.getBytes(StandardCharsets.US_ASCII))) - .build(); - private static final MetadataEntry ENTRY_B = - MetadataEntry - .newBuilder() - .setKey(KEY_B.name()) - .setValue(ByteString.copyFrom(DATA_B.getBytes(StandardCharsets.US_ASCII))) - .build(); - private static final MetadataEntry ENTRY_C = - MetadataEntry - .newBuilder() - .setKey(KEY_C.name()) - .setValue(ByteString.copyFrom(DATA_C.getBytes(StandardCharsets.US_ASCII))) - .build(); private static final int HEADER_LIMIT = 10; private static final int MESSAGE_LIMIT = Integer.MAX_VALUE; private final Metadata nonEmptyMetadata = new Metadata(); private final Sink sink = mock(GcpLogSink.class); - private final Timestamp timestamp - = Timestamp.newBuilder().setSeconds(9876).setNanos(54321).build(); - private final TimeProvider timeProvider = () -> TimeUnit.SECONDS.toNanos(9876) + 54321; - private final LogHelper logHelper = new LogHelper(sink, timeProvider); + private final LogHelper logHelper = new LogHelper(sink); @Before public void setUp() { @@ -159,29 +133,26 @@ public String toString() { @Test public void metadataToProto_empty() { assertThat(metadataToProtoTestHelper( - EventType.GRPC_CALL_REQUEST_HEADER, new Metadata(), Integer.MAX_VALUE)) + EventType.CLIENT_HEADER, new Metadata(), Integer.MAX_VALUE)) .isEqualTo(GrpcLogRecord.newBuilder() - .setEventType(EventType.GRPC_CALL_REQUEST_HEADER) - .setMetadata( - GrpcLogRecord.Metadata.getDefaultInstance()) + 
.setType(EventType.CLIENT_HEADER) + .setPayload( + Payload.newBuilder().putAllMetadata(new HashMap<>())) .build()); } @Test public void metadataToProto() { - int nonEmptyMetadataSize = 30; + Payload.Builder payloadBuilder = Payload.newBuilder() + .putMetadata("a", DATA_A) + .putMetadata("b", DATA_B) + .putMetadata("c", DATA_C); + assertThat(metadataToProtoTestHelper( - EventType.GRPC_CALL_REQUEST_HEADER, nonEmptyMetadata, Integer.MAX_VALUE)) + EventType.CLIENT_HEADER, nonEmptyMetadata, Integer.MAX_VALUE)) .isEqualTo(GrpcLogRecord.newBuilder() - .setEventType(EventType.GRPC_CALL_REQUEST_HEADER) - .setMetadata( - GrpcLogRecord.Metadata - .newBuilder() - .addEntry(ENTRY_A) - .addEntry(ENTRY_B) - .addEntry(ENTRY_C) - .build()) - .setPayloadSize(nonEmptyMetadataSize) + .setType(EventType.CLIENT_HEADER) + .setPayload(payloadBuilder) .build()); } @@ -193,44 +164,45 @@ public void metadataToProto_setsTruncated() { @Test public void metadataToProto_truncated() { // 0 byte limit not enough for any metadata - assertThat(LogHelper.createMetadataProto(nonEmptyMetadata, 0).payload.build()) - .isEqualTo(io.grpc.observabilitylog.v1.GrpcLogRecord.Metadata.getDefaultInstance()); + assertThat(LogHelper.createMetadataProto(nonEmptyMetadata, 0).payloadBuilder.build()) + .isEqualTo( + io.grpc.observabilitylog.v1.Payload.newBuilder() + .putAllMetadata(new HashMap<>()) + .build()); // not enough bytes for first key value - assertThat(LogHelper.createMetadataProto(nonEmptyMetadata, 9).payload.build()) - .isEqualTo(io.grpc.observabilitylog.v1.GrpcLogRecord.Metadata.getDefaultInstance()); + assertThat(LogHelper.createMetadataProto(nonEmptyMetadata, 9).payloadBuilder.build()) + .isEqualTo( + io.grpc.observabilitylog.v1.Payload.newBuilder() + .putAllMetadata(new HashMap<>()) + .build()); // enough for first key value - assertThat(LogHelper.createMetadataProto(nonEmptyMetadata, 10).payload.build()) - .isEqualTo(io.grpc.observabilitylog.v1.GrpcLogRecord.Metadata - .newBuilder() - 
.addEntry(ENTRY_A) - .build()); + assertThat(LogHelper.createMetadataProto(nonEmptyMetadata, 10).payloadBuilder.build()) + .isEqualTo( + io.grpc.observabilitylog.v1.Payload.newBuilder().putMetadata("a", DATA_A).build()); // Test edge cases for >= 2 key values - assertThat(LogHelper.createMetadataProto(nonEmptyMetadata, 19).payload.build()) - .isEqualTo(io.grpc.observabilitylog.v1.GrpcLogRecord.Metadata - .newBuilder() - .addEntry(ENTRY_A) - .build()); - assertThat(LogHelper.createMetadataProto(nonEmptyMetadata, 20).payload.build()) - .isEqualTo(io.grpc.observabilitylog.v1.GrpcLogRecord.Metadata - .newBuilder() - .addEntry(ENTRY_A) - .addEntry(ENTRY_B) - .build()); - assertThat(LogHelper.createMetadataProto(nonEmptyMetadata, 29).payload.build()) - .isEqualTo(io.grpc.observabilitylog.v1.GrpcLogRecord.Metadata - .newBuilder() - .addEntry(ENTRY_A) - .addEntry(ENTRY_B) - .build()); - + assertThat(LogHelper.createMetadataProto(nonEmptyMetadata, 19).payloadBuilder.build()) + .isEqualTo( + io.grpc.observabilitylog.v1.Payload.newBuilder().putMetadata("a", DATA_A).build()); + assertThat(LogHelper.createMetadataProto(nonEmptyMetadata, 20).payloadBuilder.build()) + .isEqualTo( + io.grpc.observabilitylog.v1.Payload.newBuilder() + .putMetadata("a", DATA_A) + .putMetadata("b", DATA_B) + .build()); + assertThat(LogHelper.createMetadataProto(nonEmptyMetadata, 29).payloadBuilder.build()) + .isEqualTo( + io.grpc.observabilitylog.v1.Payload.newBuilder() + .putMetadata("a", DATA_A) + .putMetadata("b", DATA_B) + .build()); // not truncated: enough for all keys - assertThat(LogHelper.createMetadataProto(nonEmptyMetadata, 30).payload.build()) - .isEqualTo(io.grpc.observabilitylog.v1.GrpcLogRecord.Metadata - .newBuilder() - .addEntry(ENTRY_A) - .addEntry(ENTRY_B) - .addEntry(ENTRY_C) - .build()); + assertThat(LogHelper.createMetadataProto(nonEmptyMetadata, 30).payloadBuilder.build()) + .isEqualTo( + io.grpc.observabilitylog.v1.Payload.newBuilder() + .putMetadata("a", DATA_A) + 
.putMetadata("b", DATA_B) + .putMetadata("c", DATA_C) + .build()); } @Test @@ -240,8 +212,11 @@ public void messageToProto() { StandardCharsets.US_ASCII); assertThat(messageTestHelper(bytes, Integer.MAX_VALUE)) .isEqualTo(GrpcLogRecord.newBuilder() - .setMessage(ByteString.copyFrom(bytes)) - .setPayloadSize(bytes.length) + .setPayload( + Payload.newBuilder() + .setMessage( + ByteString.copyFrom(bytes)) + .setMessageLength(bytes.length)) .build()); } @@ -252,18 +227,25 @@ public void messageToProto_truncated() { StandardCharsets.US_ASCII); assertThat(messageTestHelper(bytes, 0)) .isEqualTo(GrpcLogRecord.newBuilder() - .setPayloadSize(bytes.length) + .setPayload( + Payload.newBuilder() + .setMessageLength(bytes.length)) .setPayloadTruncated(true) .build()); int limit = 10; String truncatedMessage = "this is a "; assertThat(messageTestHelper(bytes, limit)) - .isEqualTo(GrpcLogRecord.newBuilder() - .setMessage(ByteString.copyFrom(truncatedMessage.getBytes(StandardCharsets.US_ASCII))) - .setPayloadSize(bytes.length) - .setPayloadTruncated(true) - .build()); + .isEqualTo( + GrpcLogRecord.newBuilder() + .setPayload( + Payload.newBuilder() + .setMessage( + ByteString.copyFrom( + truncatedMessage.getBytes(StandardCharsets.US_ASCII))) + .setMessageLength(bytes.length)) + .setPayloadTruncated(true) + .build()); } @@ -274,30 +256,28 @@ public void logRequestHeader() throws Exception { String methodName = "method"; String authority = "authority"; Duration timeout = Durations.fromMillis(1234); - String rpcId = "d155e885-9587-4e77-81f7-3aa5a443d47f"; + String callId = "d155e885-9587-4e77-81f7-3aa5a443d47f"; InetAddress address = InetAddress.getByName("127.0.0.1"); int port = 12345; InetSocketAddress peerAddress = new InetSocketAddress(address, port); GrpcLogRecord.Builder builder = - metadataToProtoTestHelper(EventType.GRPC_CALL_REQUEST_HEADER, nonEmptyMetadata, + metadataToProtoTestHelper(EventType.CLIENT_HEADER, nonEmptyMetadata, HEADER_LIMIT) .toBuilder() - 
.setTimestamp(timestamp) .setSequenceId(seqId) .setServiceName(serviceName) .setMethodName(methodName) - .setEventType(EventType.GRPC_CALL_REQUEST_HEADER) - .setEventLogger(EventLogger.LOGGER_CLIENT) - .setLogLevel(LogLevel.LOG_LEVEL_DEBUG) - .setRpcId(rpcId); - builder.setAuthority(authority) - .setTimeout(timeout); + .setType(EventType.CLIENT_HEADER) + .setLogger(EventLogger.CLIENT) + .setCallId(callId); + builder.setAuthority(authority); + builder.setPayload(builder.getPayload().toBuilder().setTimeout(timeout).build()); GrpcLogRecord base = builder.build(); // logged on client { - logHelper.logRequestHeader( + logHelper.logClientHeader( seqId, serviceName, methodName, @@ -305,15 +285,15 @@ public void logRequestHeader() throws Exception { timeout, nonEmptyMetadata, HEADER_LIMIT, - EventLogger.LOGGER_CLIENT, - rpcId, + EventLogger.CLIENT, + callId, null); verify(sink).write(base); } // logged on server { - logHelper.logRequestHeader( + logHelper.logClientHeader( seqId, serviceName, methodName, @@ -321,19 +301,19 @@ public void logRequestHeader() throws Exception { timeout, nonEmptyMetadata, HEADER_LIMIT, - EventLogger.LOGGER_SERVER, - rpcId, + EventLogger.SERVER, + callId, peerAddress); verify(sink).write( base.toBuilder() - .setPeerAddress(LogHelper.socketAddressToProto(peerAddress)) - .setEventLogger(EventLogger.LOGGER_SERVER) + .setPeer(LogHelper.socketAddressToProto(peerAddress)) + .setLogger(EventLogger.SERVER) .build()); } // timeout is null { - logHelper.logRequestHeader( + logHelper.logClientHeader( seqId, serviceName, methodName, @@ -341,18 +321,18 @@ public void logRequestHeader() throws Exception { null, nonEmptyMetadata, HEADER_LIMIT, - EventLogger.LOGGER_CLIENT, - rpcId, + EventLogger.CLIENT, + callId, null); verify(sink).write( base.toBuilder() - .clearTimeout() + .setPayload(base.getPayload().toBuilder().clearTimeout().build()) .build()); } // peerAddress is not null (error on client) try { - logHelper.logRequestHeader( + logHelper.logClientHeader( 
seqId, serviceName, methodName, @@ -360,8 +340,8 @@ public void logRequestHeader() throws Exception { timeout, nonEmptyMetadata, HEADER_LIMIT, - EventLogger.LOGGER_CLIENT, - rpcId, + EventLogger.CLIENT, + callId, peerAddress); fail(); } catch (IllegalArgumentException expected) { @@ -374,68 +354,71 @@ public void logResponseHeader() throws Exception { long seqId = 1; String serviceName = "service"; String methodName = "method"; - String rpcId = "d155e885-9587-4e77-81f7-3aa5a443d47f"; + String authority = "authority"; + String callId = "d155e885-9587-4e77-81f7-3aa5a443d47f"; InetAddress address = InetAddress.getByName("127.0.0.1"); int port = 12345; InetSocketAddress peerAddress = new InetSocketAddress(address, port); GrpcLogRecord.Builder builder = - metadataToProtoTestHelper(EventType.GRPC_CALL_RESPONSE_HEADER, nonEmptyMetadata, + metadataToProtoTestHelper(EventType.SERVER_HEADER, nonEmptyMetadata, HEADER_LIMIT) .toBuilder() - .setTimestamp(timestamp) .setSequenceId(seqId) .setServiceName(serviceName) .setMethodName(methodName) - .setEventType(EventType.GRPC_CALL_RESPONSE_HEADER) - .setEventLogger(EventLogger.LOGGER_CLIENT) - .setLogLevel(LogLevel.LOG_LEVEL_DEBUG) - .setRpcId(rpcId); - builder.setPeerAddress(LogHelper.socketAddressToProto(peerAddress)); + .setAuthority(authority) + .setType(EventType.SERVER_HEADER) + .setLogger(EventLogger.CLIENT) + .setCallId(callId); + builder.setPeer(LogHelper.socketAddressToProto(peerAddress)); GrpcLogRecord base = builder.build(); // logged on client { - logHelper.logResponseHeader( + logHelper.logServerHeader( seqId, serviceName, methodName, + authority, nonEmptyMetadata, HEADER_LIMIT, - EventLogger.LOGGER_CLIENT, - rpcId, + EventLogger.CLIENT, + callId, peerAddress); verify(sink).write(base); } // logged on server { - logHelper.logResponseHeader( + logHelper.logServerHeader( seqId, serviceName, methodName, + authority, nonEmptyMetadata, HEADER_LIMIT, - EventLogger.LOGGER_SERVER, - rpcId, + EventLogger.SERVER, + callId, 
null); verify(sink).write( base.toBuilder() - .setEventLogger(EventLogger.LOGGER_SERVER) - .clearPeerAddress() + .setLogger(EventLogger.SERVER) + .clearPeer() .build()); } // peerAddress is not null (error on server) try { - logHelper.logResponseHeader( + logHelper.logServerHeader( seqId, serviceName, methodName, + authority, nonEmptyMetadata, HEADER_LIMIT, - EventLogger.LOGGER_SERVER, - rpcId, + EventLogger.SERVER, + callId, peerAddress); fail(); @@ -450,27 +433,30 @@ public void logTrailer() throws Exception { long seqId = 1; String serviceName = "service"; String methodName = "method"; - String rpcId = "d155e885-9587-4e77-81f7-3aa5a443d47f"; + String authority = "authority"; + String callId = "d155e885-9587-4e77-81f7-3aa5a443d47f"; InetAddress address = InetAddress.getByName("127.0.0.1"); int port = 12345; - InetSocketAddress peerAddress = new InetSocketAddress(address, port); + InetSocketAddress peer = new InetSocketAddress(address, port); Status statusDescription = Status.INTERNAL.withDescription("test description"); GrpcLogRecord.Builder builder = - metadataToProtoTestHelper(EventType.GRPC_CALL_RESPONSE_HEADER, nonEmptyMetadata, + metadataToProtoTestHelper(EventType.SERVER_HEADER, nonEmptyMetadata, HEADER_LIMIT) .toBuilder() - .setTimestamp(timestamp) .setSequenceId(seqId) .setServiceName(serviceName) .setMethodName(methodName) - .setEventType(EventType.GRPC_CALL_TRAILER) - .setEventLogger(EventLogger.LOGGER_CLIENT) - .setLogLevel(LogLevel.LOG_LEVEL_DEBUG) + .setAuthority(authority) + .setType(EventType.SERVER_TRAILER) + .setLogger(EventLogger.CLIENT) + .setCallId(callId); + builder.setPeer(LogHelper.socketAddressToProto(peer)); + builder.setPayload( + builder.getPayload().toBuilder() .setStatusCode(Status.INTERNAL.getCode().value()) .setStatusMessage("test description") - .setRpcId(rpcId); - builder.setPeerAddress(LogHelper.socketAddressToProto(peerAddress)); + .build()); GrpcLogRecord base = builder.build(); // logged on client @@ -479,12 +465,13 @@ public 
void logTrailer() throws Exception { seqId, serviceName, methodName, + authority, statusDescription, nonEmptyMetadata, HEADER_LIMIT, - EventLogger.LOGGER_CLIENT, - rpcId, - peerAddress); + EventLogger.CLIENT, + callId, + peer); verify(sink).write(base); } @@ -494,16 +481,17 @@ public void logTrailer() throws Exception { seqId, serviceName, methodName, + authority, statusDescription, nonEmptyMetadata, HEADER_LIMIT, - EventLogger.LOGGER_SERVER, - rpcId, + EventLogger.SERVER, + callId, null); verify(sink).write( base.toBuilder() - .clearPeerAddress() - .setEventLogger(EventLogger.LOGGER_SERVER) + .clearPeer() + .setLogger(EventLogger.SERVER) .build()); } @@ -513,15 +501,16 @@ public void logTrailer() throws Exception { seqId, serviceName, methodName, + authority, statusDescription, nonEmptyMetadata, HEADER_LIMIT, - EventLogger.LOGGER_CLIENT, - rpcId, + EventLogger.CLIENT, + callId, null); verify(sink).write( base.toBuilder() - .clearPeerAddress() + .clearPeer() .build()); } @@ -531,15 +520,16 @@ public void logTrailer() throws Exception { seqId, serviceName, methodName, + authority, statusDescription.getCode().toStatus(), nonEmptyMetadata, HEADER_LIMIT, - EventLogger.LOGGER_CLIENT, - rpcId, - peerAddress); + EventLogger.CLIENT, + callId, + peer); verify(sink).write( base.toBuilder() - .clearStatusMessage() + .setPayload(base.getPayload().toBuilder().clearStatusMessage().build()) .build()); } } @@ -551,10 +541,9 @@ public void alwaysLoggedMetadata_grpcTraceBin() { Metadata metadata = new Metadata(); metadata.put(key, new byte[1]); int zeroHeaderBytes = 0; - PayloadBuilder pair = + PayloadBuilderHelper pair = LogHelper.createMetadataProto(metadata, zeroHeaderBytes); - assertThat(Objects.requireNonNull(Iterables.getOnlyElement(pair.payload.getEntryBuilderList())) - .getKey()).isEqualTo(key.name()); + assertThat(pair.payloadBuilder.build().getMetadataMap().containsKey(key.name())).isTrue(); assertFalse(pair.truncated); } @@ -565,9 +554,9 @@ public void 
neverLoggedMetadata_grpcStatusDetailsBin() { Metadata metadata = new Metadata(); metadata.put(key, new byte[1]); int unlimitedHeaderBytes = Integer.MAX_VALUE; - PayloadBuilder pair + PayloadBuilderHelper pair = LogHelper.createMetadataProto(metadata, unlimitedHeaderBytes); - assertThat(pair.payload.getEntryBuilderList()).isEmpty(); + assertThat(pair.payloadBuilder.getMetadataMap()).isEmpty(); assertFalse(pair.truncated); } @@ -576,19 +565,19 @@ public void logRpcMessage() { long seqId = 1; String serviceName = "service"; String methodName = "method"; - String rpcId = "d155e885-9587-4e77-81f7-3aa5a443d47f"; + String authority = "authority"; + String callId = "d155e885-9587-4e77-81f7-3aa5a443d47f"; byte[] message = new byte[100]; GrpcLogRecord.Builder builder = messageTestHelper(message, MESSAGE_LIMIT) .toBuilder() - .setTimestamp(timestamp) .setSequenceId(seqId) .setServiceName(serviceName) .setMethodName(methodName) - .setEventType(EventType.GRPC_CALL_REQUEST_MESSAGE) - .setEventLogger(EventLogger.LOGGER_CLIENT) - .setLogLevel(LogLevel.LOG_LEVEL_DEBUG) - .setRpcId(rpcId); + .setAuthority(authority) + .setType(EventType.CLIENT_MESSAGE) + .setLogger(EventLogger.CLIENT) + .setCallId(callId); GrpcLogRecord base = builder.build(); // request message { @@ -596,11 +585,12 @@ public void logRpcMessage() { seqId, serviceName, methodName, - EventType.GRPC_CALL_REQUEST_MESSAGE, + authority, + EventType.CLIENT_MESSAGE, message, MESSAGE_LIMIT, - EventLogger.LOGGER_CLIENT, - rpcId); + EventLogger.CLIENT, + callId); verify(sink).write(base); } // response message, logged on client @@ -609,14 +599,15 @@ public void logRpcMessage() { seqId, serviceName, methodName, - EventType.GRPC_CALL_RESPONSE_MESSAGE, + authority, + EventType.SERVER_MESSAGE, message, MESSAGE_LIMIT, - EventLogger.LOGGER_CLIENT, - rpcId); + EventLogger.CLIENT, + callId); verify(sink).write( base.toBuilder() - .setEventType(EventType.GRPC_CALL_RESPONSE_MESSAGE) + .setType(EventType.SERVER_MESSAGE) .build()); } // 
request message, logged on server @@ -625,14 +616,15 @@ public void logRpcMessage() { seqId, serviceName, methodName, - EventType.GRPC_CALL_REQUEST_MESSAGE, + authority, + EventType.CLIENT_MESSAGE, message, MESSAGE_LIMIT, - EventLogger.LOGGER_SERVER, - rpcId); + EventLogger.SERVER, + callId); verify(sink).write( base.toBuilder() - .setEventLogger(EventLogger.LOGGER_SERVER) + .setLogger(EventLogger.SERVER) .build()); } // response message, logged on server @@ -641,15 +633,34 @@ public void logRpcMessage() { seqId, serviceName, methodName, - EventType.GRPC_CALL_RESPONSE_MESSAGE, + authority, + EventType.SERVER_MESSAGE, message, MESSAGE_LIMIT, - EventLogger.LOGGER_SERVER, - rpcId); + EventLogger.SERVER, + callId); + verify(sink).write( + base.toBuilder() + .setType(EventType.SERVER_MESSAGE) + .setLogger(EventLogger.SERVER) + .build()); + } + // message is not of type : com.google.protobuf.Message or byte[] + { + logHelper.logRpcMessage( + seqId, + serviceName, + methodName, + authority, + EventType.CLIENT_MESSAGE, + "message", + MESSAGE_LIMIT, + EventLogger.CLIENT, + callId); verify(sink).write( base.toBuilder() - .setEventType(EventType.GRPC_CALL_RESPONSE_MESSAGE) - .setEventLogger(EventLogger.LOGGER_SERVER) + .clearPayload() + .clearPayloadTruncated() .build()); } } @@ -667,21 +678,19 @@ public void getPeerAddressTest() throws Exception { private static GrpcLogRecord metadataToProtoTestHelper( EventType type, Metadata metadata, int maxHeaderBytes) { GrpcLogRecord.Builder builder = GrpcLogRecord.newBuilder(); - PayloadBuilder pair + PayloadBuilderHelper pair = LogHelper.createMetadataProto(metadata, maxHeaderBytes); - builder.setMetadata(pair.payload); - builder.setPayloadSize(pair.size); + builder.setPayload(pair.payloadBuilder); builder.setPayloadTruncated(pair.truncated); - builder.setEventType(type); + builder.setType(type); return builder.build(); } private static GrpcLogRecord messageTestHelper(byte[] message, int maxMessageBytes) { GrpcLogRecord.Builder 
builder = GrpcLogRecord.newBuilder(); - PayloadBuilder pair + PayloadBuilderHelper pair = LogHelper.createMessageProto(message, maxMessageBytes); - builder.setMessage(pair.payload); - builder.setPayloadSize(pair.size); + builder.setPayload(pair.payloadBuilder); builder.setPayloadTruncated(pair.truncated); return builder.build(); } diff --git a/gcp-observability/src/test/java/io/grpc/gcp/observability/logging/GcpLogSinkTest.java b/gcp-observability/src/test/java/io/grpc/gcp/observability/logging/GcpLogSinkTest.java index cea081b9b55..e02cc6dd4eb 100644 --- a/gcp-observability/src/test/java/io/grpc/gcp/observability/logging/GcpLogSinkTest.java +++ b/gcp-observability/src/test/java/io/grpc/gcp/observability/logging/GcpLogSinkTest.java @@ -18,7 +18,6 @@ import static com.google.common.truth.Truth.assertThat; import static org.mockito.ArgumentMatchers.anyIterable; -import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; @@ -67,7 +66,6 @@ public class GcpLogSinkTest { private static final ImmutableMap CUSTOM_TAGS = ImmutableMap.of("KEY1", "Value1", "KEY2", "VALUE2"); - private static final long FLUSH_LIMIT = 10L; // gRPC is expected to always use this log name when reporting to GCP cloud logging. 
private static final String EXPECTED_LOG_NAME = "microservices.googleapis.com%2Fobservability%2Fgrpc"; @@ -77,28 +75,33 @@ public class GcpLogSinkTest { private static final String METHOD_NAME = "method"; private static final String AUTHORITY = "authority"; private static final Duration TIMEOUT = Durations.fromMillis(1234); - private static final String RPC_ID = "d155e885-9587-4e77-81f7-3aa5a443d47f"; + private static final String CALL_ID = "d155e885-9587-4e77-81f7-3aa5a443d47f"; private static final GrpcLogRecord LOG_PROTO = GrpcLogRecord.newBuilder() .setSequenceId(SEQ_ID) .setServiceName(SERVICE_NAME) .setMethodName(METHOD_NAME) .setAuthority(AUTHORITY) - .setTimeout(TIMEOUT) - .setEventType(EventType.GRPC_CALL_REQUEST_HEADER) - .setEventLogger(EventLogger.LOGGER_CLIENT) - .setRpcId(RPC_ID) + .setPayload(io.grpc.observabilitylog.v1.Payload.newBuilder().setTimeout(TIMEOUT)) + .setType(EventType.CLIENT_HEADER) + .setLogger(EventLogger.CLIENT) + .setCallId(CALL_ID) .build(); + // .putFields("timeout", Value.newBuilder().setStringValue("1.234s").build()) + private static final Struct struct = + Struct.newBuilder() + .putFields("timeout", Value.newBuilder().setStringValue("1.234s").build()) + .build(); private static final Struct EXPECTED_STRUCT_LOG_PROTO = Struct.newBuilder() - .putFields("sequence_id", Value.newBuilder().setStringValue(String.valueOf(SEQ_ID)).build()) - .putFields("service_name", Value.newBuilder().setStringValue(SERVICE_NAME).build()) - .putFields("method_name", Value.newBuilder().setStringValue(METHOD_NAME).build()) + .putFields("sequenceId", Value.newBuilder().setStringValue(String.valueOf(SEQ_ID)).build()) + .putFields("serviceName", Value.newBuilder().setStringValue(SERVICE_NAME).build()) + .putFields("methodName", Value.newBuilder().setStringValue(METHOD_NAME).build()) .putFields("authority", Value.newBuilder().setStringValue(AUTHORITY).build()) - .putFields("timeout", Value.newBuilder().setStringValue("1.234s").build()) - 
.putFields("event_type", Value.newBuilder().setStringValue( - String.valueOf(EventType.GRPC_CALL_REQUEST_HEADER)).build()) - .putFields("event_logger", Value.newBuilder().setStringValue( - String.valueOf(EventLogger.LOGGER_CLIENT)).build()) - .putFields("rpc_id", Value.newBuilder().setStringValue(RPC_ID).build()) + .putFields("payload", Value.newBuilder().setStructValue(struct).build()) + .putFields("type", Value.newBuilder().setStringValue( + String.valueOf(EventType.CLIENT_HEADER)).build()) + .putFields("logger", Value.newBuilder().setStringValue( + String.valueOf(EventLogger.CLIENT)).build()) + .putFields("callId", Value.newBuilder().setStringValue(CALL_ID).build()) .build(); @Mock private Logging mockLogging; @@ -107,7 +110,7 @@ public class GcpLogSinkTest { @SuppressWarnings("unchecked") public void verifyWrite() throws Exception { GcpLogSink sink = new GcpLogSink(mockLogging, DEST_PROJECT_NAME, LOCATION_TAGS, - CUSTOM_TAGS, FLUSH_LIMIT, Collections.emptySet()); + CUSTOM_TAGS, Collections.emptySet()); sink.write(LOG_PROTO); ArgumentCaptor> logEntrySetCaptor = ArgumentCaptor.forClass( @@ -125,7 +128,7 @@ public void verifyWrite() throws Exception { @SuppressWarnings("unchecked") public void verifyWriteWithTags() { GcpLogSink sink = new GcpLogSink(mockLogging, DEST_PROJECT_NAME, LOCATION_TAGS, - CUSTOM_TAGS, FLUSH_LIMIT, Collections.emptySet()); + CUSTOM_TAGS, Collections.emptySet()); MonitoredResource expectedMonitoredResource = GcpLogSink.getResource(LOCATION_TAGS); sink.write(LOG_PROTO); @@ -149,7 +152,7 @@ public void emptyCustomTags_labelsNotSet() { Map emptyCustomTags = null; Map expectedEmptyLabels = new HashMap<>(); GcpLogSink sink = new GcpLogSink(mockLogging, DEST_PROJECT_NAME, LOCATION_TAGS, - emptyCustomTags, FLUSH_LIMIT, Collections.emptySet()); + emptyCustomTags, Collections.emptySet()); sink.write(LOG_PROTO); ArgumentCaptor> logEntrySetCaptor = ArgumentCaptor.forClass( @@ -166,11 +169,11 @@ public void emptyCustomTags_labelsNotSet() { 
@SuppressWarnings("unchecked") public void emptyCustomTags_setSourceProject() { Map emptyCustomTags = null; - String destinationProjectId = "DESTINATION_PROJECT"; + String projectId = "PROJECT"; Map expectedLabels = GcpLogSink.getCustomTags(emptyCustomTags, LOCATION_TAGS, - destinationProjectId); - GcpLogSink sink = new GcpLogSink(mockLogging, destinationProjectId, LOCATION_TAGS, - emptyCustomTags, FLUSH_LIMIT, Collections.emptySet()); + projectId); + GcpLogSink sink = new GcpLogSink(mockLogging, projectId, LOCATION_TAGS, + emptyCustomTags, Collections.emptySet()); sink.write(LOG_PROTO); ArgumentCaptor> logEntrySetCaptor = ArgumentCaptor.forClass( @@ -183,24 +186,10 @@ public void emptyCustomTags_setSourceProject() { } } - @Test - public void verifyFlush() { - long lowerFlushLimit = 2L; - GcpLogSink sink = new GcpLogSink(mockLogging, DEST_PROJECT_NAME, LOCATION_TAGS, - CUSTOM_TAGS, lowerFlushLimit, Collections.emptySet()); - sink.write(LOG_PROTO); - verify(mockLogging, never()).flush(); - sink.write(LOG_PROTO); - verify(mockLogging, times(1)).flush(); - sink.write(LOG_PROTO); - sink.write(LOG_PROTO); - verify(mockLogging, times(2)).flush(); - } - @Test public void verifyClose() throws Exception { GcpLogSink sink = new GcpLogSink(mockLogging, DEST_PROJECT_NAME, LOCATION_TAGS, - CUSTOM_TAGS, FLUSH_LIMIT, Collections.emptySet()); + CUSTOM_TAGS, Collections.emptySet()); sink.write(LOG_PROTO); verify(mockLogging, times(1)).write(anyIterable()); sink.close(); @@ -211,7 +200,7 @@ public void verifyClose() throws Exception { @Test public void verifyExclude() throws Exception { Sink mockSink = new GcpLogSink(mockLogging, DEST_PROJECT_NAME, LOCATION_TAGS, - CUSTOM_TAGS, FLUSH_LIMIT, Collections.singleton("service")); + CUSTOM_TAGS, Collections.singleton("service")); mockSink.write(LOG_PROTO); verifyNoInteractions(mockLogging); } diff --git a/googleapis/src/main/java/io/grpc/googleapis/GoogleCloudToProdNameResolver.java 
b/googleapis/src/main/java/io/grpc/googleapis/GoogleCloudToProdNameResolver.java index 759af51ca63..df592802fa8 100644 --- a/googleapis/src/main/java/io/grpc/googleapis/GoogleCloudToProdNameResolver.java +++ b/googleapis/src/main/java/io/grpc/googleapis/GoogleCloudToProdNameResolver.java @@ -21,6 +21,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Charsets; import com.google.common.base.Preconditions; +import com.google.common.base.Strings; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.io.CharStreams; @@ -54,6 +55,7 @@ final class GoogleCloudToProdNameResolver extends NameResolver { @VisibleForTesting static final String METADATA_URL_SUPPORT_IPV6 = "http://metadata.google.internal/computeMetadata/v1/instance/network-interfaces/0/ipv6s"; + static final String C2P_AUTHORITY = "traffic-director-c2p.xds.googleapis.com"; @VisibleForTesting static boolean isOnGcp = InternalCheckGcpEnvironment.isOnGcp(); @VisibleForTesting @@ -62,6 +64,10 @@ final class GoogleCloudToProdNameResolver extends NameResolver { || System.getProperty("io.grpc.xds.bootstrap") != null || System.getenv("GRPC_XDS_BOOTSTRAP_CONFIG") != null || System.getProperty("io.grpc.xds.bootstrapConfig") != null; + @VisibleForTesting + static boolean enableFederation = + !Strings.isNullOrEmpty(System.getenv("GRPC_EXPERIMENTAL_XDS_FEDERATION")) + && Boolean.parseBoolean(System.getenv("GRPC_EXPERIMENTAL_XDS_FEDERATION")); private static final String serverUriOverride = System.getenv("GRPC_TEST_ONLY_GOOGLE_C2P_RESOLVER_TRAFFIC_DIRECTOR_URI"); @@ -76,7 +82,10 @@ final class GoogleCloudToProdNameResolver extends NameResolver { private final boolean usingExecutorResource; // It's not possible to use both PSM and DirectPath C2P in the same application. // Delegate to DNS if user-provided bootstrap is found. 
- private final String schemeOverride = !isOnGcp || xdsBootstrapProvided ? "dns" : "xds"; + private final String schemeOverride = + !isOnGcp + || (xdsBootstrapProvided && !enableFederation) + ? "dns" : "xds"; private Executor executor; private Listener2 listener; private boolean succeeded; @@ -103,8 +112,12 @@ final class GoogleCloudToProdNameResolver extends NameResolver { targetUri); authority = GrpcUtil.checkAuthority(targetPath.substring(1)); syncContext = checkNotNull(args, "args").getSynchronizationContext(); + targetUri = overrideUriScheme(targetUri, schemeOverride); + if (schemeOverride.equals("xds") && enableFederation) { + targetUri = overrideUriAuthority(targetUri, C2P_AUTHORITY); + } delegate = checkNotNull(nameResolverFactory, "nameResolverFactory").newNameResolver( - overrideUriScheme(targetUri, schemeOverride), args); + targetUri, args); executor = args.getOffloadExecutor(); usingExecutorResource = executor == null; } @@ -191,9 +204,14 @@ public void run() { serverBuilder.put("channel_creds", ImmutableList.of(ImmutableMap.of("type", "google_default"))); serverBuilder.put("server_features", ImmutableList.of("xds_v3")); + ImmutableMap.Builder authoritiesBuilder = ImmutableMap.builder(); + authoritiesBuilder.put( + C2P_AUTHORITY, + ImmutableMap.of("xds_servers", ImmutableList.of(serverBuilder.buildOrThrow()))); return ImmutableMap.of( "node", nodeBuilder.buildOrThrow(), - "xds_servers", ImmutableList.of(serverBuilder.buildOrThrow())); + "xds_servers", ImmutableList.of(serverBuilder.buildOrThrow()), + "authorities", authoritiesBuilder.buildOrThrow()); } @Override @@ -266,6 +284,16 @@ private static URI overrideUriScheme(URI uri, String scheme) { return res; } + private static URI overrideUriAuthority(URI uri, String authority) { + URI res; + try { + res = new URI(uri.getScheme(), authority, uri.getPath(), uri.getQuery(), uri.getFragment()); + } catch (URISyntaxException ex) { + throw new IllegalArgumentException("Invalid authority: " + authority, ex); + 
} + return res; + } + private enum HttpConnectionFactory implements HttpConnectionProvider { INSTANCE; diff --git a/googleapis/src/test/java/io/grpc/googleapis/GoogleCloudToProdNameResolverTest.java b/googleapis/src/test/java/io/grpc/googleapis/GoogleCloudToProdNameResolverTest.java index 562798c0183..a7c4ab059b6 100644 --- a/googleapis/src/test/java/io/grpc/googleapis/GoogleCloudToProdNameResolverTest.java +++ b/googleapis/src/test/java/io/grpc/googleapis/GoogleCloudToProdNameResolverTest.java @@ -66,7 +66,7 @@ public class GoogleCloudToProdNameResolverTest { @Rule public final MockitoRule mocks = MockitoJUnit.rule(); - private static final URI TARGET_URI = URI.create("google-c2p-experimental:///googleapis.com"); + private static final URI TARGET_URI = URI.create("google-c2p:///googleapis.com"); private static final String ZONE = "us-central1-a"; private static final int DEFAULT_PORT = 887; @@ -187,6 +187,40 @@ public void onGcpAndNoProvidedBootstrapDelegateToXds() { "server_uri", "directpath-pa.googleapis.com", "channel_creds", ImmutableList.of(ImmutableMap.of("type", "google_default")), "server_features", ImmutableList.of("xds_v3")); + Map authorities = (Map) bootstrap.get("authorities"); + assertThat(authorities).containsExactly( + "traffic-director-c2p.xds.googleapis.com", + ImmutableMap.of("xds_servers", ImmutableList.of(server))); + } + + @SuppressWarnings("unchecked") + @Test + public void onGcpAndProvidedBootstrapAndFederationEnabledDelegateToXds() { + GoogleCloudToProdNameResolver.isOnGcp = true; + GoogleCloudToProdNameResolver.xdsBootstrapProvided = true; + GoogleCloudToProdNameResolver.enableFederation = true; + createResolver(); + resolver.start(mockListener); + fakeExecutor.runDueTasks(); + assertThat(delegatedResolver.keySet()).containsExactly("xds"); + verify(Iterables.getOnlyElement(delegatedResolver.values())).start(mockListener); + // check bootstrap + Map bootstrap = fakeBootstrapSetter.bootstrapRef.get(); + Map node = (Map) 
bootstrap.get("node"); + assertThat(node).containsExactly( + "id", "C2P-991614323", + "locality", ImmutableMap.of("zone", ZONE), + "metadata", ImmutableMap.of("TRAFFICDIRECTOR_DIRECTPATH_C2P_IPV6_CAPABLE", true)); + Map server = Iterables.getOnlyElement( + (List>) bootstrap.get("xds_servers")); + assertThat(server).containsExactly( + "server_uri", "directpath-pa.googleapis.com", + "channel_creds", ImmutableList.of(ImmutableMap.of("type", "google_default")), + "server_features", ImmutableList.of("xds_v3")); + Map authorities = (Map) bootstrap.get("authorities"); + assertThat(authorities).containsExactly( + "traffic-director-c2p.xds.googleapis.com", + ImmutableMap.of("xds_servers", ImmutableList.of(server))); } @Test diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 3351eba9cb7..37d443cd479 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -8,7 +8,7 @@ guava = "31.1-android" netty = '4.1.79.Final' nettytcnative = '2.0.54.Final' opencensus = "0.31.0" -protobuf = "3.21.1" +protobuf = "3.21.7" [libraries] android-annotations = "com.google.android:annotations:4.1.1.4" diff --git a/grpclb/src/main/java/io/grpc/grpclb/GrpclbLoadBalancer.java b/grpclb/src/main/java/io/grpc/grpclb/GrpclbLoadBalancer.java index 65293d24511..14b06a3b57c 100644 --- a/grpclb/src/main/java/io/grpc/grpclb/GrpclbLoadBalancer.java +++ b/grpclb/src/main/java/io/grpc/grpclb/GrpclbLoadBalancer.java @@ -76,7 +76,7 @@ class GrpclbLoadBalancer extends LoadBalancer { } @Override - public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { + public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { Attributes attributes = resolvedAddresses.getAttributes(); List newLbAddresses = attributes.get(GrpclbConstants.ATTR_LB_ADDRS); if (newLbAddresses == null) { @@ -85,7 +85,7 @@ public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { if (newLbAddresses.isEmpty() && resolvedAddresses.getAddresses().isEmpty()) { 
handleNameResolutionError( Status.UNAVAILABLE.withDescription("No backend or balancer addresses found")); - return; + return false; } List overrideAuthorityLbAddresses = new ArrayList<>(newLbAddresses.size()); @@ -114,6 +114,8 @@ public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { } grpclbState.handleAddresses(Collections.unmodifiableList(overrideAuthorityLbAddresses), newBackendServers); + + return true; } @Override diff --git a/grpclb/src/main/java/io/grpc/grpclb/GrpclbState.java b/grpclb/src/main/java/io/grpc/grpclb/GrpclbState.java index faae729a074..49b74645ec8 100644 --- a/grpclb/src/main/java/io/grpc/grpclb/GrpclbState.java +++ b/grpclb/src/main/java/io/grpc/grpclb/GrpclbState.java @@ -262,7 +262,7 @@ void handleAddresses( List newBackendServers) { logger.log( ChannelLogLevel.DEBUG, - "[grpclb-<{0}>] Resolved addresses: lb addresses {0}, backends: {1}", + "[grpclb-<{0}>] Resolved addresses: lb addresses {1}, backends: {2}", serviceName, newLbAddressGroups, newBackendServers); diff --git a/grpclb/src/test/java/io/grpc/grpclb/GrpclbLoadBalancerTest.java b/grpclb/src/test/java/io/grpc/grpclb/GrpclbLoadBalancerTest.java index 54ec832297f..66fd850802c 100644 --- a/grpclb/src/test/java/io/grpc/grpclb/GrpclbLoadBalancerTest.java +++ b/grpclb/src/test/java/io/grpc/grpclb/GrpclbLoadBalancerTest.java @@ -2735,7 +2735,7 @@ private void deliverResolvedAddresses( syncContext.execute(new Runnable() { @Override public void run() { - balancer.handleResolvedAddresses( + balancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(backendAddrs) .setAttributes(attrs) diff --git a/netty/src/main/java/io/grpc/netty/GrpcHttp2ConnectionHandler.java b/netty/src/main/java/io/grpc/netty/GrpcHttp2ConnectionHandler.java index 25f4f9232cf..13f55226483 100644 --- a/netty/src/main/java/io/grpc/netty/GrpcHttp2ConnectionHandler.java +++ b/netty/src/main/java/io/grpc/netty/GrpcHttp2ConnectionHandler.java @@ -34,6 +34,9 @@ */ @Internal public abstract 
class GrpcHttp2ConnectionHandler extends Http2ConnectionHandler { + static final int ADAPTIVE_CUMULATOR_COMPOSE_MIN_SIZE_DEFAULT = 1024; + static final Cumulator ADAPTIVE_CUMULATOR = + new NettyAdaptiveCumulator(ADAPTIVE_CUMULATOR_COMPOSE_MIN_SIZE_DEFAULT); @Nullable protected final ChannelPromise channelUnused; @@ -48,6 +51,7 @@ protected GrpcHttp2ConnectionHandler( super(decoder, encoder, initialSettings); this.channelUnused = channelUnused; this.negotiationLogger = negotiationLogger; + setCumulator(ADAPTIVE_CUMULATOR); } /** diff --git a/netty/src/main/java/io/grpc/netty/NettyAdaptiveCumulator.java b/netty/src/main/java/io/grpc/netty/NettyAdaptiveCumulator.java new file mode 100644 index 00000000000..b3a28c55c79 --- /dev/null +++ b/netty/src/main/java/io/grpc/netty/NettyAdaptiveCumulator.java @@ -0,0 +1,224 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.netty; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.handler.codec.ByteToMessageDecoder.Cumulator; + +class NettyAdaptiveCumulator implements Cumulator { + private final int composeMinSize; + + /** + * "Adaptive" cumulator: cumulate {@link ByteBuf}s by dynamically switching between merge and + * compose strategies. 
+ * + * @param composeMinSize Determines the minimal size of the buffer that should be composed (added + * as a new component of the {@link CompositeByteBuf}). If the total size + * of the last component (tail) and the incoming buffer is below this value, + * the incoming buffer is appended to the tail, and the new component is not + * added. + */ + NettyAdaptiveCumulator(int composeMinSize) { + Preconditions.checkArgument(composeMinSize >= 0, "composeMinSize must be non-negative"); + this.composeMinSize = composeMinSize; + } + + /** + * "Adaptive" cumulator: cumulate {@link ByteBuf}s by dynamically switching between merge and + * compose strategies. + * + *

This cumulator applies a heuristic to make a decision whether to track a reference to the + * buffer with bytes received from the network stack in an array ("zero-copy"), or to merge into + * the last component (the tail) by performing a memory copy. + * + *

It is necessary as a protection from a potential attack on the {@link + * io.netty.handler.codec.ByteToMessageDecoder#COMPOSITE_CUMULATOR}. Consider a pathological case + * when an attacker sends TCP packages containing a single byte of data, and forcing the cumulator + * to track each one in a separate buffer. The cost is memory overhead for each buffer, and extra + * compute to read the cumulation. + * + *

Implemented heuristic establishes a minimal threshold for the total size of the tail and + * incoming buffer, below which they are merged. The sum of the tail and the incoming buffer is + * used to avoid a case where attacker alternates the size of data packets to trick the cumulator + * into always selecting compose strategy. + * + *

Merging strategy attempts to minimize unnecessary memory writes. When possible, it expands + * the tail capacity and only copies the incoming buffer into available memory. Otherwise, when + * both tail and the buffer must be copied, the tail is reallocated (or fully replaced) with a new + * buffer of exponentially increasing capacity (bounded to {@link #composeMinSize}) to ensure + * runtime {@code O(n^2)} is amortized to {@code O(n)}. + */ + @Override + @SuppressWarnings("ReferenceEquality") + public final ByteBuf cumulate(ByteBufAllocator alloc, ByteBuf cumulation, ByteBuf in) { + if (!cumulation.isReadable()) { + cumulation.release(); + return in; + } + CompositeByteBuf composite = null; + try { + if (cumulation instanceof CompositeByteBuf && cumulation.refCnt() == 1) { + composite = (CompositeByteBuf) cumulation; + // Writer index must equal capacity if we are going to "write" + // new components to the end + if (composite.writerIndex() != composite.capacity()) { + composite.capacity(composite.writerIndex()); + } + } else { + composite = alloc.compositeBuffer(Integer.MAX_VALUE) + .addFlattenedComponents(true, cumulation); + } + addInput(alloc, composite, in); + in = null; + return composite; + } finally { + if (in != null) { + // We must release if the ownership was not transferred as otherwise it may produce a leak + in.release(); + // Also release any new buffer allocated if we're not returning it + if (composite != null && composite != cumulation) { + composite.release(); + } + } + } + } + + @VisibleForTesting + void addInput(ByteBufAllocator alloc, CompositeByteBuf composite, ByteBuf in) { + if (shouldCompose(composite, in, composeMinSize)) { + composite.addFlattenedComponents(true, in); + } else { + // The total size of the new data and the last component are below the threshold. Merge them. 
+ mergeWithCompositeTail(alloc, composite, in); + } + } + + @VisibleForTesting + static boolean shouldCompose(CompositeByteBuf composite, ByteBuf in, int composeMinSize) { + int componentCount = composite.numComponents(); + if (composite.numComponents() == 0) { + return true; + } + int inputSize = in.readableBytes(); + int tailStart = composite.toByteIndex(componentCount - 1); + int tailSize = composite.writerIndex() - tailStart; + return tailSize + inputSize >= composeMinSize; + } + + /** + * Append the given {@link ByteBuf} {@code in} to {@link CompositeByteBuf} {@code composite} by + * expanding or replacing the tail component of the {@link CompositeByteBuf}. + * + *

The goal is to prevent {@code O(n^2)} runtime in a pathological case, that forces copying + * the tail component into a new buffer, for each incoming single-byte buffer. We append the new + * bytes to the tail, when a write (or a fast write) is possible. + * + *

Otherwise, the tail is replaced with a new buffer, with the capacity increased enough to + * achieve runtime amortization. + * + *

We assume that implementations of {@link ByteBufAllocator#calculateNewCapacity(int, int)}, + * are similar to {@link io.netty.buffer.AbstractByteBufAllocator#calculateNewCapacity(int, int)}, + * which doubles buffer capacity by normalizing it to the closest power of two. This assumption + * is verified in unit tests for this method. + */ + @VisibleForTesting + static void mergeWithCompositeTail( + ByteBufAllocator alloc, CompositeByteBuf composite, ByteBuf in) { + int inputSize = in.readableBytes(); + int tailComponentIndex = composite.numComponents() - 1; + int tailStart = composite.toByteIndex(tailComponentIndex); + int tailSize = composite.writerIndex() - tailStart; + int newTailSize = inputSize + tailSize; + ByteBuf tail = composite.component(tailComponentIndex); + ByteBuf newTail = null; + try { + if (tail.refCnt() == 1 && !tail.isReadOnly() && newTailSize <= tail.maxCapacity()) { + // Ideal case: the tail isn't shared, and can be expanded to the required capacity. + // Take ownership of the tail. + newTail = tail.retain(); + + // TODO(https://github.com/netty/netty/issues/12844): remove when we use Netty with + // the issue fixed. + // In certain cases, removing the CompositeByteBuf component, and then adding it back + // isn't idempotent. An example is provided in https://github.com/netty/netty/issues/12844. + // This happens because the buffer returned by composite.component() has out-of-sync + // indexes. Under the hood the CompositeByteBuf returns a duplicate() of the underlying + // buffer, but doesn't set the indexes. + // + // To get the right indexes we use the fact that composite.internalComponent() returns + // the slice() into the readable portion of the underlying buffer. 
+ // We use this implementation detail (internalComponent() returning a *SlicedByteBuf), + // and combine it with the fact that SlicedByteBuf duplicates have their indexes + // adjusted so they correspond to the to the readable portion of the slice. + // + // Hence composite.internalComponent().duplicate() returns a buffer with the + // indexes that should've been on the composite.component() in the first place. + // Until the issue is fixed, we manually adjust the indexes of the removed component. + ByteBuf sliceDuplicate = composite.internalComponent(tailComponentIndex).duplicate(); + newTail.setIndex(sliceDuplicate.readerIndex(), sliceDuplicate.writerIndex()); + + /* + * The tail is a readable non-composite buffer, so writeBytes() handles everything for us. + * + * - ensureWritable() performs a fast resize when possible (f.e. PooledByteBuf simply + * updates its boundary to the end of consecutive memory run assigned to this buffer) + * - when the required size doesn't fit into writableBytes(), a new buffer is + * allocated, and the capacity calculated with alloc.calculateNewCapacity() + * - note that maxFastWritableBytes() would normally allow a fast expansion of PooledByteBuf + * is not called because CompositeByteBuf.component() returns a duplicate, wrapped buffer. + * Unwrapping buffers is unsafe, and potential benefit of fast writes may not be + * as pronounced because the capacity is doubled with each reallocation. + */ + newTail.writeBytes(in); + } else { + // The tail is shared, or not expandable. Replace it with a new buffer of desired capacity. + newTail = alloc.buffer(alloc.calculateNewCapacity(newTailSize, Integer.MAX_VALUE)); + newTail.setBytes(0, composite, tailStart, tailSize) + .setBytes(tailSize, in, in.readerIndex(), inputSize) + .writerIndex(newTailSize); + in.readerIndex(in.writerIndex()); + } + // Store readerIndex to avoid out of bounds writerIndex during component replacement. 
+ int prevReader = composite.readerIndex(); + // Remove the old tail, reset writer index. + composite.removeComponent(tailComponentIndex).setIndex(0, tailStart); + // Add back the new tail. + composite.addFlattenedComponents(true, newTail); + // New tail's ownership transferred to the composite buf. + newTail = null; + in.release(); + in = null; + // Restore the reader. In case it fails we restore the reader after releasing/forgetting + // the input and the new tail so that finally block can handles them properly. + composite.readerIndex(prevReader); + } finally { + // Input buffer was merged with the tail. + if (in != null) { + in.release(); + } + // If new tail's ownership isn't transferred to the composite buf. + // Release it to prevent a leak. + if (newTail != null) { + newTail.release(); + } + } + } +} diff --git a/netty/src/test/java/io/grpc/netty/NettyAdaptiveCumulatorTest.java b/netty/src/test/java/io/grpc/netty/NettyAdaptiveCumulatorTest.java new file mode 100644 index 00000000000..6a0c00bac0e --- /dev/null +++ b/netty/src/test/java/io/grpc/netty/NettyAdaptiveCumulatorTest.java @@ -0,0 +1,641 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.grpc.netty; + +import static com.google.common.truth.Truth.assertThat; +import static com.google.common.truth.Truth.assertWithMessage; +import static com.google.common.truth.TruthJUnit.assume; +import static io.netty.util.CharsetUtil.US_ASCII; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.google.common.base.Strings; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.CompositeByteBuf; +import io.netty.buffer.PooledByteBufAllocator; +import io.netty.buffer.UnpooledByteBufAllocator; +import java.util.Collection; +import java.util.List; +import java.util.stream.Collectors; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.experimental.runners.Enclosed; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; + +@RunWith(Enclosed.class) +public class NettyAdaptiveCumulatorTest { + + private static Collection cartesianProductParams(List... lists) { + return Lists.cartesianProduct(lists).stream().map(List::toArray).collect(Collectors.toList()); + } + + @RunWith(JUnit4.class) + public static class CumulateTests { + // Represent data as immutable ASCII Strings for easy and readable ByteBuf equality assertions. 
+ private static final String DATA_INITIAL = "0123"; + private static final String DATA_INCOMING = "456789"; + private static final String DATA_CUMULATED = "0123456789"; + + private static final ByteBufAllocator alloc = new UnpooledByteBufAllocator(false); + private NettyAdaptiveCumulator cumulator; + private NettyAdaptiveCumulator throwingCumulator; + private final UnsupportedOperationException throwingCumulatorError = + new UnsupportedOperationException(); + + // Buffers for testing + private final ByteBuf contiguous = ByteBufUtil.writeAscii(alloc, DATA_INITIAL); + private final ByteBuf in = ByteBufUtil.writeAscii(alloc, DATA_INCOMING); + + @Before + public void setUp() { + cumulator = new NettyAdaptiveCumulator(0) { + @Override + void addInput(ByteBufAllocator alloc, CompositeByteBuf composite, ByteBuf in) { + // To limit the testing scope to NettyAdaptiveCumulator.cumulate(), always compose + composite.addFlattenedComponents(true, in); + } + }; + + // Throws an error on adding incoming buffer. + throwingCumulator = new NettyAdaptiveCumulator(0) { + @Override + void addInput(ByteBufAllocator alloc, CompositeByteBuf composite, ByteBuf in) { + throw throwingCumulatorError; + } + }; + } + + @Test + public void cumulate_notReadableCumulation_replacedWithInputAndReleased() { + contiguous.readerIndex(contiguous.writerIndex()); + assertFalse(contiguous.isReadable()); + ByteBuf cumulation = cumulator.cumulate(alloc, contiguous, in); + assertEquals(DATA_INCOMING, cumulation.toString(US_ASCII)); + assertEquals(0, contiguous.refCnt()); + // In retained by cumulation. 
+ assertEquals(1, in.refCnt()); + assertEquals(1, cumulation.refCnt()); + cumulation.release(); + } + + @Test + public void cumulate_contiguousCumulation_newCompositeFromContiguousAndInput() { + CompositeByteBuf cumulation = (CompositeByteBuf) cumulator.cumulate(alloc, contiguous, in); + assertEquals(DATA_INITIAL, cumulation.component(0).toString(US_ASCII)); + assertEquals(DATA_INCOMING, cumulation.component(1).toString(US_ASCII)); + assertEquals(DATA_CUMULATED, cumulation.toString(US_ASCII)); + // Both in and contiguous are retained by cumulation. + assertEquals(1, contiguous.refCnt()); + assertEquals(1, in.refCnt()); + assertEquals(1, cumulation.refCnt()); + cumulation.release(); + } + + @Test + public void cumulate_compositeCumulation_inputAppendedAsANewComponent() { + CompositeByteBuf composite = alloc.compositeBuffer().addComponent(true, contiguous); + assertSame(composite, cumulator.cumulate(alloc, composite, in)); + assertEquals(DATA_INITIAL, composite.component(0).toString(US_ASCII)); + assertEquals(DATA_INCOMING, composite.component(1).toString(US_ASCII)); + assertEquals(DATA_CUMULATED, composite.toString(US_ASCII)); + // Both in and contiguous are retained by cumulation. + assertEquals(1, contiguous.refCnt()); + assertEquals(1, in.refCnt()); + assertEquals(1, composite.refCnt()); + composite.release(); + } + + @Test + public void cumulate_compositeCumulation_inputReleasedOnError() { + CompositeByteBuf composite = alloc.compositeBuffer().addComponent(true, contiguous); + try { + throwingCumulator.cumulate(alloc, composite, in); + fail("Cumulator didn't throw"); + } catch (UnsupportedOperationException actualError) { + assertSame(throwingCumulatorError, actualError); + // Input must be released unless its ownership has been to the composite cumulation. + assertEquals(0, in.refCnt()); + // Initial composite cumulation owned by the caller in this case, so it isn't released. 
+ assertEquals(1, composite.refCnt()); + // Contiguous still managed by the cumulation + assertEquals(1, contiguous.refCnt()); + } finally { + composite.release(); + } + } + + @Test + public void cumulate_contiguousCumulation_inputAndNewCompositeReleasedOnError() { + // Return our instance of new composite to ensure it's released. + CompositeByteBuf newComposite = alloc.compositeBuffer(Integer.MAX_VALUE); + ByteBufAllocator mockAlloc = mock(ByteBufAllocator.class); + when(mockAlloc.compositeBuffer(anyInt())).thenReturn(newComposite); + + try { + // Previous cumulation is non-composite, so cumulator will create anew composite and add + // both buffers to it. + throwingCumulator.cumulate(mockAlloc, contiguous, in); + fail("Cumulator didn't throw"); + } catch (UnsupportedOperationException actualError) { + assertSame(throwingCumulatorError, actualError); + // Input must be released unless its ownership has been to the composite cumulation. + assertEquals(0, in.refCnt()); + // New composite cumulation hasn't been returned to the caller, so it must be released. + assertEquals(0, newComposite.refCnt()); + // Previous cumulation released because it was owned by the new composite cumulation. + assertEquals(0, contiguous.refCnt()); + } + } + } + + @RunWith(Parameterized.class) + public static class ShouldComposeTests { + // Represent data as immutable ASCII Strings for easy and readable ByteBuf equality assertions. + private static final String DATA_INITIAL = "0123"; + private static final String DATA_INCOMING = "456789"; + + /** + * Cartesian product of the test values. 
+ */ + @Parameters(name = "composeMinSize={0}, tailData=\"{1}\", inData=\"{2}\"") + public static Collection params() { + List composeMinSize = ImmutableList.of(0, 9, 10, 11, Integer.MAX_VALUE); + List tailData = ImmutableList.of("", DATA_INITIAL); + List inData = ImmutableList.of("", DATA_INCOMING); + return cartesianProductParams(composeMinSize, tailData, inData); + } + + @Parameter public int composeMinSize; + @Parameter(1) public String tailData; + @Parameter(2) public String inData; + + private CompositeByteBuf composite; + private ByteBuf tail; + private ByteBuf in; + + @Before + public void setUp() { + ByteBufAllocator alloc = new UnpooledByteBufAllocator(false); + in = ByteBufUtil.writeAscii(alloc, inData); + tail = ByteBufUtil.writeAscii(alloc, tailData); + composite = alloc.compositeBuffer(Integer.MAX_VALUE); + // Note that addFlattenedComponents() will not add a new component when tail is not readable. + composite.addFlattenedComponents(true, tail); + } + + @After + public void tearDown() { + in.release(); + composite.release(); + } + + @Test + public void shouldCompose_emptyComposite() { + assume().that(composite.numComponents()).isEqualTo(0); + assertTrue(NettyAdaptiveCumulator.shouldCompose(composite, in, composeMinSize)); + } + + @Test + public void shouldCompose_composeMinSizeReached() { + assume().that(composite.numComponents()).isGreaterThan(0); + assume().that(tail.readableBytes() + in.readableBytes()).isAtLeast(composeMinSize); + assertTrue(NettyAdaptiveCumulator.shouldCompose(composite, in, composeMinSize)); + } + + @Test + public void shouldCompose_composeMinSizeNotReached() { + assume().that(composite.numComponents()).isGreaterThan(0); + assume().that(tail.readableBytes() + in.readableBytes()).isLessThan(composeMinSize); + assertFalse(NettyAdaptiveCumulator.shouldCompose(composite, in, composeMinSize)); + } + } + + @RunWith(Parameterized.class) + public static class MergeWithCompositeTailTests { + private static final String 
INCOMING_DATA_READABLE = "+incoming"; + private static final String INCOMING_DATA_DISCARDABLE = "discard"; + + private static final String TAIL_DATA_DISCARDABLE = "---"; + private static final String TAIL_DATA_READABLE = "tail"; + private static final String TAIL_DATA = TAIL_DATA_DISCARDABLE + TAIL_DATA_READABLE; + private static final int TAIL_READER_INDEX = TAIL_DATA_DISCARDABLE.length(); + private static final int TAIL_MAX_CAPACITY = 128; + + // DRY sacrificed to improve readability. + private static final String EXPECTED_TAIL_DATA = "tail+incoming"; + + /** + * Cartesian product of the test values. + * + *

Test cases when the cumulation contains components, other than tail, and could be + * partially read. This is needed to verify the correctness of reader and writer indexes of the + * composite cumulation after the merge. + */ + @Parameters(name = "compositeHeadData=\"{0}\", compositeReaderIndex={1}") + public static Collection params() { + String headData = "head"; + + List compositeHeadData = ImmutableList.of( + // Test without the "head" component. Empty string is equivalent of fully read buffer, + // so it's not added to the composite byte buf. The tail is added as the first component. + "", + // Test with the "head" component, so the tail is added as the second component. + headData + ); + + // After the tail is added to the composite cumulator, advance the reader index to + // cover different cases. + // The reader index only looks at what's readable in the composite byte buf, so + // discardable bytes of head and tail don't count. + List compositeReaderIndex = ImmutableList.of( + // Reader in the beginning + 0, + // Within the head (when present) or the tail + headData.length() - 2, + // Within the tail, even if the head is present + headData.length() + 2 + ); + return cartesianProductParams(compositeHeadData, compositeReaderIndex); + } + + @Parameter public String compositeHeadData; + @Parameter(1) public int compositeReaderIndex; + + // Use pooled allocator to have maxFastWritableBytes() behave differently than writableBytes(). + private final ByteBufAllocator alloc = new PooledByteBufAllocator(); + + // Composite buffer to be used in tests. + private CompositeByteBuf composite; + private ByteBuf tail; + private ByteBuf in; + + @Before + public void setUp() { + composite = alloc.compositeBuffer(); + + // The "head" component. It represents existing data in the cumulator. + // Note that addFlattenedComponents() does not add completely read buffer, which covers + // the case when compositeHeadData parameter is an empty string. 
+ ByteBuf head = alloc.buffer().writeBytes(compositeHeadData.getBytes(US_ASCII)); + composite.addFlattenedComponents(true, head); + + // The "tail" component. It also represents existing data in the cumulator, but it's + // not added to the cumulator during setUp() stage. It is to be manipulated by tests to + // produce different buffer write scenarios based on different tail's capacity. + // After tail is changed for each test scenario, it's added to the composite buffer. + // + // The default state of the tail before each test: tail is full, but expandable (the data uses + // all initial capacity, but not maximum capacity). + // Tail data and indexes: + // ----tail + // r w + tail = alloc.buffer(TAIL_DATA.length(), TAIL_MAX_CAPACITY) + .writeBytes(TAIL_DATA.getBytes(US_ASCII)) + .readerIndex(TAIL_READER_INDEX); + + // Incoming data and indexes: + // discard+incoming + // r w + in = alloc.buffer() + .writeBytes(INCOMING_DATA_DISCARDABLE.getBytes(US_ASCII)) + .writeBytes(INCOMING_DATA_READABLE.getBytes(US_ASCII)) + .readerIndex(INCOMING_DATA_DISCARDABLE.length()); + } + + @After + public void tearDown() { + composite.release(); + } + + @Test + public void mergeWithCompositeTail_tailExpandable_write() { + // Make incoming data fit into tail capacity. + int fitCapacity = tail.capacity() + INCOMING_DATA_READABLE.length(); + tail.capacity(fitCapacity); + // Confirm it fits. + assertThat(in.readableBytes()).isAtMost(tail.writableBytes()); + + // All fits, so tail capacity must stay the same. + composite.addFlattenedComponents(true, tail); + assertTailExpanded(EXPECTED_TAIL_DATA, fitCapacity); + } + + @Test + public void mergeWithCompositeTail_tailExpandable_fastWrite() { + // Confirm that the tail can be expanded fast to fit the incoming data. + assertThat(in.readableBytes()).isAtMost(tail.maxFastWritableBytes()); + + // To avoid undesirable buffer unwrapping, at the moment adaptive cumulator is set not + apply fastWrite technique. 
Even when fast write is possible, it will fall back to + // reallocating a larger buffer. + // int tailFastCapacity = tail.writerIndex() + tail.maxFastWritableBytes(); + int tailFastCapacity = + alloc.calculateNewCapacity(EXPECTED_TAIL_DATA.length(), Integer.MAX_VALUE); + + // Tail capacity is extended to its fast capacity. + composite.addFlattenedComponents(true, tail); + assertTailExpanded(EXPECTED_TAIL_DATA, tailFastCapacity); + } + + @Test + public void mergeWithCompositeTail_tailExpandable_reallocateInMemory() { + int tailFastCapacity = tail.writerIndex() + tail.maxFastWritableBytes(); + String inSuffixOverFastBytes = Strings.repeat("a", tailFastCapacity + 1); + int newTailSize = tail.readableBytes() + inSuffixOverFastBytes.length(); + composite.addFlattenedComponents(true, tail); + + // Make input larger than tailFastCapacity + in.writeCharSequence(inSuffixOverFastBytes, US_ASCII); + // Confirm that the tail can only fit incoming data via reallocation. + assertThat(in.readableBytes()).isGreaterThan(tail.maxFastWritableBytes()); + assertThat(in.readableBytes()).isAtMost(tail.maxWritableBytes()); + + // Confirm the assumption that new capacity is produced by alloc.calculateNewCapacity(). + int expectedTailCapacity = alloc.calculateNewCapacity(newTailSize, Integer.MAX_VALUE); + assertTailExpanded(EXPECTED_TAIL_DATA.concat(inSuffixOverFastBytes), expectedTailCapacity); + } + + private void assertTailExpanded(String expectedTailReadableData, int expectedNewTailCapacity) { + int originalNumComponents = composite.numComponents(); + + // Handle the case when reader index is beyond all readable bytes of the cumulation. + int compositeReaderIndexBounded = Math.min(compositeReaderIndex, composite.writerIndex()); + composite.readerIndex(compositeReaderIndexBounded); + + // Execute the merge logic. + NettyAdaptiveCumulator.mergeWithCompositeTail(alloc, composite, in); + + // Composite component count shouldn't change. 
+ assertWithMessage( + "When tail is expanded, the number of components in the cumulation must not change") + .that(composite.numComponents()).isEqualTo(originalNumComponents); + + ByteBuf newTail = composite.component(composite.numComponents() - 1); + + // Verify the readable part of the expanded tail: + // 1. Initial readable bytes of the tail not changed + // 2. Discardable bytes (0 < discardable < readerIndex) of the incoming buffer are discarded. + // 3. Readable bytes of the incoming buffer are fully read and appended to the tail. + assertEquals(expectedTailReadableData, newTail.toString(US_ASCII)); + // Verify expanded capacity. + assertEquals(expectedNewTailCapacity, newTail.capacity()); + + // Discardable bytes (0 < discardable < readerIndex) of the tail are kept as is. + String newTailDataDiscardable = newTail.toString(0, newTail.readerIndex(), US_ASCII); + assertWithMessage("After tail expansion, its discardable bytes should be unchanged") + .that(newTailDataDiscardable).isEqualTo(TAIL_DATA_DISCARDABLE); + + // Reader index must stay where it was + assertEquals(TAIL_READER_INDEX, newTail.readerIndex()); + // Writer index at the end + assertEquals(TAIL_READER_INDEX + expectedTailReadableData.length(), + newTail.writerIndex()); + + // Verify resulting cumulation. + assertExpectedCumulation(newTail, expectedTailReadableData, compositeReaderIndexBounded); + + // Verify incoming buffer. + assertWithMessage("Incoming buffer is fully read").that(in.isReadable()).isFalse(); + assertWithMessage("Incoming buffer is released").that(in.refCnt()).isEqualTo(0); + } + + @Test + public void mergeWithCompositeTail_tailNotExpandable_maxCapacityReached() { + // Fill in tail to the maxCapacity. 
+ String tailSuffixFullCapacity = Strings.repeat("a", tail.maxWritableBytes()); + tail.writeCharSequence(tailSuffixFullCapacity, US_ASCII); + composite.addFlattenedComponents(true, tail); + assertTailReplaced(); + } + + @Test + public void mergeWithCompositeTail_tailNotExpandable_shared() { + tail.retain(); + composite.addFlattenedComponents(true, tail); + assertTailReplaced(); + tail.release(); + } + + @Test + public void mergeWithCompositeTail_tailNotExpandable_readOnly() { + composite.addFlattenedComponents(true, tail.asReadOnly()); + assertTailReplaced(); + } + + private void assertTailReplaced() { + int cumulationOriginalComponentsNum = composite.numComponents(); + int taiOriginalRefCount = tail.refCnt(); + String expectedTailReadable = tail.toString(US_ASCII) + in.toString(US_ASCII); + int expectedReallocatedTailCapacity = alloc + .calculateNewCapacity(expectedTailReadable.length(), Integer.MAX_VALUE); + + int compositeReaderIndexBounded = Math.min(compositeReaderIndex, composite.writerIndex()); + composite.readerIndex(compositeReaderIndexBounded); + NettyAdaptiveCumulator.mergeWithCompositeTail(alloc, composite, in); + + // Composite component count shouldn't change. + assertEquals(cumulationOriginalComponentsNum, composite.numComponents()); + ByteBuf replacedTail = composite.component(composite.numComponents() - 1); + + // Verify the readable part of the expanded tail: + // 1. Discardable bytes (0 < discardable < readerIndex) of the tail are discarded. + // 2. Readable bytes of the tail are kept as is + // 3. Discardable bytes (0 < discardable < readerIndex) of the incoming buffer are discarded. + // 4. Readable bytes of the incoming buffer are fully read and appended to the tail. + assertEquals(0, in.readableBytes()); + assertEquals(expectedTailReadable, replacedTail.toString(US_ASCII)); + + // Since tail discardable bytes are discarded, new reader index must be reset to 0. 
+ assertEquals(0, replacedTail.readerIndex()); + // And new writer index at the new data's length. + assertEquals(expectedTailReadable.length(), replacedTail.writerIndex()); + // Verify the capacity of reallocated tail. + assertEquals(expectedReallocatedTailCapacity, replacedTail.capacity()); + + // Verify resulting cumulation. + assertExpectedCumulation(replacedTail, expectedTailReadable, compositeReaderIndexBounded); + + // Verify incoming buffer. + assertWithMessage("Incoming buffer is fully read").that(in.isReadable()).isFalse(); + assertWithMessage("Incoming buffer is released").that(in.refCnt()).isEqualTo(0); + + // The old tail must be released once (have one less reference). + assertWithMessage("Replaced tail released once.") + .that(tail.refCnt()).isEqualTo(taiOriginalRefCount - 1); + } + + private void assertExpectedCumulation( + ByteBuf newTail, String expectedTailReadable, int expectedReaderIndex) { + // Verify the readable part of the cumulation: + // 1. Readable composite head (initial) data + // 2. Readable part of the tail + // 3. Readable part of the incoming data + String expectedCumulationData = + compositeHeadData.concat(expectedTailReadable).substring(expectedReaderIndex); + assertEquals(expectedCumulationData, composite.toString(US_ASCII)); + + // Cumulation capacity includes: + // 1. Full composite head, including discardable bytes + // 2. Expanded tail readable bytes + int expectedCumulationCapacity = compositeHeadData.length() + expectedTailReadable.length(); + assertEquals(expectedCumulationCapacity, composite.capacity()); + + // Composite Reader index must stay where it was. + assertEquals(expectedReaderIndex, composite.readerIndex()); + // Composite writer index must be at the end. + assertEquals(expectedCumulationCapacity, composite.writerIndex()); + + // Composite cumulation is retained and owns the new tail. 
+ assertEquals(1, composite.refCnt()); + assertEquals(1, newTail.refCnt()); + } + + @Test + public void mergeWithCompositeTail_tailExpandable_mergedReleaseOnThrow() { + final UnsupportedOperationException expectedError = new UnsupportedOperationException(); + CompositeByteBuf compositeThrows = new CompositeByteBuf(alloc, false, Integer.MAX_VALUE, + tail) { + @Override + public CompositeByteBuf addFlattenedComponents(boolean increaseWriterIndex, + ByteBuf buffer) { + throw expectedError; + } + }; + + try { + NettyAdaptiveCumulator.mergeWithCompositeTail(alloc, compositeThrows, in); + fail("Cumulator didn't throw"); + } catch (UnsupportedOperationException actualError) { + assertSame(expectedError, actualError); + // Input must be released unless its ownership has been to the composite cumulation. + assertEquals(0, in.refCnt()); + // Tail released + assertEquals(0, tail.refCnt()); + // Composite cumulation is retained + assertEquals(1, compositeThrows.refCnt()); + // Composite cumulation loses the tail + assertEquals(0, compositeThrows.numComponents()); + } finally { + compositeThrows.release(); + } + } + + @Test + public void mergeWithCompositeTail_tailNotExpandable_mergedReleaseOnThrow() { + final UnsupportedOperationException expectedError = new UnsupportedOperationException(); + CompositeByteBuf compositeRo = new CompositeByteBuf(alloc, false, Integer.MAX_VALUE, + tail.asReadOnly()) { + @Override + public CompositeByteBuf addFlattenedComponents(boolean increaseWriterIndex, + ByteBuf buffer) { + throw expectedError; + } + }; + + // Return our instance of the new buffer to ensure it's released. 
+ int newTailSize = tail.readableBytes() + in.readableBytes(); + ByteBuf newTail = alloc.buffer(alloc.calculateNewCapacity(newTailSize, Integer.MAX_VALUE)); + ByteBufAllocator mockAlloc = mock(ByteBufAllocator.class); + when(mockAlloc.buffer(anyInt())).thenReturn(newTail); + + try { + NettyAdaptiveCumulator.mergeWithCompositeTail(mockAlloc, compositeRo, in); + fail("Cumulator didn't throw"); + } catch (UnsupportedOperationException actualError) { + assertSame(expectedError, actualError); + // Input must be released unless its ownership has been to the composite cumulation. + assertEquals(0, in.refCnt()); + // New buffer released + assertEquals(0, newTail.refCnt()); + // Composite cumulation is retained + assertEquals(1, compositeRo.refCnt()); + // Composite cumulation loses the tail + assertEquals(0, compositeRo.numComponents()); + } finally { + compositeRo.release(); + } + } + } + + /** + * Miscellaneous tests for {@link NettyAdaptiveCumulator#mergeWithCompositeTail} that don't + * fit into {@link MergeWithCompositeTailTests}, and require custom-crafted scenarios. + */ + @RunWith(JUnit4.class) + public static class MergeWithCompositeTailMiscTests { + private final ByteBufAllocator alloc = new PooledByteBufAllocator(); + + /** + * Test the issue with {@link CompositeByteBuf#component(int)} returning a ByteBuf with + * the indexes out-of-sync with {@code CompositeByteBuf.Component} offsets. + */ + @Test + public void mergeWithCompositeTail_outOfSyncComposite() { + NettyAdaptiveCumulator cumulator = new NettyAdaptiveCumulator(1024); + + // Create underlying buffer spacious enough for the test data. + ByteBuf buf = alloc.buffer(32).writeBytes("---01234".getBytes(US_ASCII)); + + // Start with a regular cumulation and add the buf as the only component. + CompositeByteBuf composite1 = alloc.compositeBuffer(8).addFlattenedComponents(true, buf); + // Read composite1 buf to the beginning of the numbers. 
+ assertThat(composite1.readCharSequence(3, US_ASCII).toString()).isEqualTo("---"); + + // Wrap composite1 into another cumulation. This is similar to + // what NettyAdaptiveCumulator.cumulate() does in the case the cumulation has refCnt != 1. + CompositeByteBuf composite2 = + alloc.compositeBuffer(8).addFlattenedComponents(true, composite1); + assertThat(composite2.toString(US_ASCII)).isEqualTo("01234"); + + // The previous operation does not adjust the read indexes of the underlying buffers, + // only the internal Component offsets. When the cumulator attempts to append the input to + // the tail buffer, it extracts it from the cumulation, writes to it, and then adds it back. + // Because the readerIndex on the tail buffer is not adjusted during the read operation + // on the CompositeByteBuf, adding the tail back results in the discarded bytes of the tail + // to be added back to the cumulator as if they were never read. + // + // If the reader index of the tail is not manually corrected, the resulting + // cumulation will contain the discarded part of the tail: "---". + // If it's corrected, it will only contain the numbers. + CompositeByteBuf cumulation = (CompositeByteBuf) cumulator.cumulate(alloc, composite2, + ByteBufUtil.writeAscii(alloc, "56789")); + assertThat(cumulation.toString(US_ASCII)).isEqualTo("0123456789"); + + // Correctness check: we still have a single component, and this component is still the + // original underlying buffer. + assertThat(cumulation.numComponents()).isEqualTo(1); + // Replace '2' with '*', and '8' with '$'. 
+ buf.setByte(5, '*').setByte(11, '$'); + assertThat(cumulation.toString(US_ASCII)).isEqualTo("01*34567$9"); + } + } +} diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java index 81af4f30d48..02b1c3254fa 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java @@ -221,6 +221,9 @@ protected void handleNotInUse() { @Nullable final HttpConnectProxiedSocketAddress proxiedAddr; + @VisibleForTesting + int proxySocketTimeout = 30000; + // The following fields should only be used for test. Runnable connectingCallback; SettableFuture connectedFuture; @@ -626,8 +629,8 @@ private void sendConnectionPrefaceAndSettings() { private Socket createHttpProxySocket(InetSocketAddress address, InetSocketAddress proxyAddress, String proxyUsername, String proxyPassword) throws StatusException { + Socket sock = null; try { - Socket sock; // The proxy address may not be resolved if (proxyAddress.getAddress() != null) { sock = socketFactory.createSocket(proxyAddress.getAddress(), proxyAddress.getPort()); @@ -636,6 +639,9 @@ private Socket createHttpProxySocket(InetSocketAddress address, InetSocketAddres socketFactory.createSocket(proxyAddress.getHostName(), proxyAddress.getPort()); } sock.setTcpNoDelay(true); + // A socket timeout is needed because lost network connectivity while reading from the proxy, + // can cause reading from the socket to hang. + sock.setSoTimeout(proxySocketTimeout); Source source = Okio.source(sock); BufferedSink sink = Okio.buffer(Okio.sink(sock)); @@ -682,8 +688,13 @@ private Socket createHttpProxySocket(InetSocketAddress address, InetSocketAddres statusLine.code, statusLine.message, body.readUtf8()); throw Status.UNAVAILABLE.withDescription(message).asException(); } + // As the socket will be used for RPCs from here on, we want the socket timeout back to zero. 
+ sock.setSoTimeout(0); return sock; } catch (IOException e) { + if (sock != null) { + GrpcUtil.closeQuietly(sock); + } throw Status.UNAVAILABLE.withDescription("Failed trying to connect with proxy").withCause(e) .asException(); } diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerBuilder.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerBuilder.java index d3ea82894b0..f0e8bf41ff9 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerBuilder.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerBuilder.java @@ -68,6 +68,9 @@ public final class OkHttpServerBuilder extends ForwardingServerBuilder DEFAULT_TRANSPORT_EXECUTOR_POOL = @@ -120,6 +123,8 @@ public static OkHttpServerBuilder forPort(SocketAddress address, ServerCredentia long maxConnectionIdleInNanos = MAX_CONNECTION_IDLE_NANOS_DISABLED; boolean permitKeepAliveWithoutCalls; long permitKeepAliveTimeInNanos = TimeUnit.MINUTES.toNanos(5); + long maxConnectionAgeInNanos = MAX_CONNECTION_AGE_NANOS_DISABLED; + long maxConnectionAgeGraceInNanos = MAX_CONNECTION_AGE_GRACE_NANOS_INFINITE; @VisibleForTesting OkHttpServerBuilder( @@ -209,6 +214,45 @@ public OkHttpServerBuilder maxConnectionIdle(long maxConnectionIdle, TimeUnit ti return this; } + /** + * Sets a custom max connection age, connection lasting longer than which will be gracefully + * terminated. An unreasonably small value might be increased. A random jitter of +/-10% will be + * added to it. {@code Long.MAX_VALUE} nano seconds or an unreasonably large value will disable + * max connection age. 
+ */ + @Override + public OkHttpServerBuilder maxConnectionAge(long maxConnectionAge, TimeUnit timeUnit) { + checkArgument(maxConnectionAge > 0L, "max connection age must be positive: %s", + maxConnectionAge); + maxConnectionAgeInNanos = timeUnit.toNanos(maxConnectionAge); + if (maxConnectionAgeInNanos >= AS_LARGE_AS_INFINITE) { + maxConnectionAgeInNanos = MAX_CONNECTION_AGE_NANOS_DISABLED; + } + if (maxConnectionAgeInNanos < MIN_MAX_CONNECTION_AGE_NANO) { + maxConnectionAgeInNanos = MIN_MAX_CONNECTION_AGE_NANO; + } + return this; + } + + /** + * Sets a custom grace time for the graceful connection termination. Once the max connection age + * is reached, RPCs have the grace time to complete. RPCs that do not complete in time will be + * cancelled, allowing the connection to terminate. {@code Long.MAX_VALUE} nano seconds or an + * unreasonably large value are considered infinite. + * + * @see #maxConnectionAge(long, TimeUnit) + */ + @Override + public OkHttpServerBuilder maxConnectionAgeGrace(long maxConnectionAgeGrace, TimeUnit timeUnit) { + checkArgument(maxConnectionAgeGrace >= 0L, "max connection age grace must be non-negative: %s", + maxConnectionAgeGrace); + maxConnectionAgeGraceInNanos = timeUnit.toNanos(maxConnectionAgeGrace); + if (maxConnectionAgeGraceInNanos >= AS_LARGE_AS_INFINITE) { + maxConnectionAgeGraceInNanos = MAX_CONNECTION_AGE_GRACE_NANOS_INFINITE; + } + return this; + } + /** * Sets a time waiting for read activity after sending a keepalive ping. If the time expires * without any read activity on the connection, the connection is considered dead. 
An unreasonably diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerTransport.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerTransport.java index f6099bec17a..1fd98079ede 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerTransport.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerTransport.java @@ -16,6 +16,7 @@ package io.grpc.okhttp; +import static io.grpc.okhttp.OkHttpServerBuilder.MAX_CONNECTION_AGE_NANOS_DISABLED; import static io.grpc.okhttp.OkHttpServerBuilder.MAX_CONNECTION_IDLE_NANOS_DISABLED; import com.google.common.base.Preconditions; @@ -31,6 +32,7 @@ import io.grpc.internal.GrpcUtil; import io.grpc.internal.KeepAliveEnforcer; import io.grpc.internal.KeepAliveManager; +import io.grpc.internal.LogExceptionRunnable; import io.grpc.internal.MaxConnectionIdleManager; import io.grpc.internal.ObjectPool; import io.grpc.internal.SerializingExecutor; @@ -96,6 +98,7 @@ final class OkHttpServerTransport implements ServerTransport, private Attributes attributes; private KeepAliveManager keepAliveManager; private MaxConnectionIdleManager maxConnectionIdleManager; + private ScheduledFuture maxConnectionAgeMonitor; private final KeepAliveEnforcer keepAliveEnforcer; private final Object lock = new Object(); @@ -223,6 +226,15 @@ public void data(boolean outFinished, int streamId, Buffer source, int byteCount maxConnectionIdleManager.start(this::shutdown, scheduledExecutorService); } + if (config.maxConnectionAgeInNanos != MAX_CONNECTION_AGE_NANOS_DISABLED) { + long maxConnectionAgeInNanos = + (long) ((.9D + Math.random() * .2D) * config.maxConnectionAgeInNanos); + maxConnectionAgeMonitor = scheduledExecutorService.schedule( + new LogExceptionRunnable(() -> shutdown(config.maxConnectionAgeGraceInNanos)), + maxConnectionAgeInNanos, + TimeUnit.NANOSECONDS); + } + transportExecutor.execute( new FrameHandler(variant.newReader(Okio.buffer(Okio.source(socket)), false))); } catch (Error | IOException | RuntimeException ex) { @@ -238,6 
+250,10 @@ public void data(boolean outFinished, int streamId, Buffer source, int byteCount @Override public void shutdown() { + shutdown(TimeUnit.SECONDS.toNanos(1L)); + } + + private void shutdown(Long graceTimeInNanos) { synchronized (lock) { if (gracefulShutdown || abruptShutdown) { return; @@ -251,7 +267,7 @@ public void shutdown() { // we also set a timer to limit the upper bound in case the PING is excessively stalled or // the client is malicious. secondGoawayTimer = scheduledExecutorService.schedule( - this::triggerGracefulSecondGoaway, 1, TimeUnit.SECONDS); + this::triggerGracefulSecondGoaway, graceTimeInNanos, TimeUnit.NANOSECONDS); frameWriter.goAway(Integer.MAX_VALUE, ErrorCode.NO_ERROR, new byte[0]); frameWriter.ping(false, 0, GRACEFUL_SHUTDOWN_PING); frameWriter.flush(); @@ -348,6 +364,10 @@ private void terminated() { if (maxConnectionIdleManager != null) { maxConnectionIdleManager.onTransportTermination(); } + + if (maxConnectionAgeMonitor != null) { + maxConnectionAgeMonitor.cancel(false); + } transportExecutor = config.transportExecutorPool.returnObject(transportExecutor); scheduledExecutorService = config.scheduledExecutorServicePool.returnObject(scheduledExecutorService); @@ -479,6 +499,8 @@ static final class Config { final long maxConnectionIdleNanos; final boolean permitKeepAliveWithoutCalls; final long permitKeepAliveTimeInNanos; + final long maxConnectionAgeInNanos; + final long maxConnectionAgeGraceInNanos; public Config( OkHttpServerBuilder builder, @@ -501,6 +523,8 @@ public Config( maxConnectionIdleNanos = builder.maxConnectionIdleInNanos; permitKeepAliveWithoutCalls = builder.permitKeepAliveWithoutCalls; permitKeepAliveTimeInNanos = builder.permitKeepAliveTimeInNanos; + maxConnectionAgeInNanos = builder.maxConnectionAgeInNanos; + maxConnectionAgeGraceInNanos = builder.maxConnectionAgeGraceInNanos; } } diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java 
b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java index e632a6c2946..fcc5e0d2381 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java @@ -1877,6 +1877,37 @@ public void proxy_immediateServerClose() throws Exception { verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); } + @Test + public void proxy_serverHangs() throws Exception { + ServerSocket serverSocket = new ServerSocket(0); + InetSocketAddress targetAddress = InetSocketAddress.createUnresolved("theservice", 80); + clientTransport = new OkHttpClientTransport( + channelBuilder.buildTransportFactory(), + targetAddress, + "authority", + "userAgent", + EAG_ATTRS, + HttpConnectProxiedSocketAddress.newBuilder() + .setTargetAddress(targetAddress) + .setProxyAddress(new InetSocketAddress("localhost", serverSocket.getLocalPort())) + .build(), + tooManyPingsRunnable); + clientTransport.proxySocketTimeout = 10; + clientTransport.start(transportListener); + + Socket sock = serverSocket.accept(); + serverSocket.close(); + + BufferedReader reader = new BufferedReader(new InputStreamReader(sock.getInputStream(), UTF_8)); + assertEquals("CONNECT theservice:80 HTTP/1.1", reader.readLine()); + assertEquals("Host: theservice:80", reader.readLine()); + while (!"".equals(reader.readLine())) {} + + verify(transportListener, timeout(200)).transportShutdown(any(Status.class)); + verify(transportListener, timeout(TIME_OUT_MS)).transportTerminated(); + sock.close(); + } + @Test public void goAway_notUtf8() throws Exception { initTransport(); diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpServerTransportTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpServerTransportTest.java index a52045011ae..af9b7c12d54 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpServerTransportTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpServerTransportTest.java @@ -152,6 +152,27 @@ public void 
startThenShutdown() throws Exception { shutdownAndTerminate(/*lastStreamId=*/ 0); } + @Test + public void maxConnectionAge() throws Exception { + serverBuilder.maxConnectionAge(5, TimeUnit.SECONDS) + .maxConnectionAgeGrace(1, TimeUnit.SECONDS); + initTransport(); + handshake(); + clientFrameWriter.headers(1, Arrays.asList( + HTTP_SCHEME_HEADER, + METHOD_HEADER, + new Header(Header.TARGET_AUTHORITY, "example.com:80"), + new Header(Header.TARGET_PATH, "/com.example/SimpleService.doit"), + CONTENT_TYPE_HEADER, + TE_HEADER)); + clientFrameWriter.synStream(true, false, 1, -1, Arrays.asList( + new Header("some-client-sent-trailer", "trailer-value"))); + pingPong(); + fakeClock.forwardNanos(TimeUnit.SECONDS.toNanos(6)); // > 1.1 * 5 + fakeClock.forwardNanos(TimeUnit.SECONDS.toNanos(1)); + verifyGracefulShutdown(1); + } + @Test public void maxConnectionIdleTimer() throws Exception { initTransport(); diff --git a/repositories.bzl b/repositories.bzl index cdf6d69e23e..cced1d29bee 100644 --- a/repositories.bzl +++ b/repositories.bzl @@ -150,18 +150,18 @@ def com_google_protobuf(): # This statement defines the @com_google_protobuf repo. 
http_archive( name = "com_google_protobuf", - sha256 = "2d9084d3dd13b86ca2e811d2331f780eb86f6d7cb02b405426e3c80dcbfabf25", - strip_prefix = "protobuf-3.21.1", - urls = ["https://github.com/protocolbuffers/protobuf/archive/v3.21.1.zip"], + sha256 = "c72840a5081484c4ac20789ea5bb5d5de6bc7c477ad76e7109fda2bc4e630fe6", + strip_prefix = "protobuf-3.21.7", + urls = ["https://github.com/protocolbuffers/protobuf/archive/v3.21.7.zip"], ) def com_google_protobuf_javalite(): # java_lite_proto_library rules implicitly depend on @com_google_protobuf_javalite http_archive( name = "com_google_protobuf_javalite", - sha256 = "2d9084d3dd13b86ca2e811d2331f780eb86f6d7cb02b405426e3c80dcbfabf25", - strip_prefix = "protobuf-3.21.1", - urls = ["https://github.com/protocolbuffers/protobuf/archive/v3.21.1.zip"], + sha256 = "c72840a5081484c4ac20789ea5bb5d5de6bc7c477ad76e7109fda2bc4e630fe6", + strip_prefix = "protobuf-3.21.7", + urls = ["https://github.com/protocolbuffers/protobuf/archive/v3.21.7.zip"], ) def io_grpc_grpc_proto(): diff --git a/rls/src/main/java/io/grpc/rls/RlsLoadBalancer.java b/rls/src/main/java/io/grpc/rls/RlsLoadBalancer.java index 2aac96cadcf..5d4e749087d 100644 --- a/rls/src/main/java/io/grpc/rls/RlsLoadBalancer.java +++ b/rls/src/main/java/io/grpc/rls/RlsLoadBalancer.java @@ -49,7 +49,7 @@ final class RlsLoadBalancer extends LoadBalancer { } @Override - public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { + public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { logger.log(ChannelLogLevel.DEBUG, "Received resolution result: {0}", resolvedAddresses); LbPolicyConfiguration lbPolicyConfiguration = (LbPolicyConfiguration) resolvedAddresses.getLoadBalancingPolicyConfig(); @@ -78,6 +78,7 @@ public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { // not required. 
this.lbPolicyConfiguration = lbPolicyConfiguration; } + return true; } @Override diff --git a/rls/src/test/java/io/grpc/rls/RlsLoadBalancerTest.java b/rls/src/test/java/io/grpc/rls/RlsLoadBalancerTest.java index b469ee6fe32..9f95200d503 100644 --- a/rls/src/test/java/io/grpc/rls/RlsLoadBalancerTest.java +++ b/rls/src/test/java/io/grpc/rls/RlsLoadBalancerTest.java @@ -445,7 +445,7 @@ private void deliverResolvedAddresses() throws Exception { ConfigOrError parsedConfigOrError = provider.parseLoadBalancingPolicyConfig(getServiceConfig()); assertThat(parsedConfigOrError.getConfig()).isNotNull(); - rlsLb.handleResolvedAddresses(ResolvedAddresses.newBuilder() + rlsLb.acceptResolvedAddresses(ResolvedAddresses.newBuilder() .setAddresses(ImmutableList.of(new EquivalentAddressGroup(mock(SocketAddress.class)))) .setLoadBalancingPolicyConfig(parsedConfigOrError.getConfig()) .build()); diff --git a/services/build.gradle b/services/build.gradle index fdc0f35322b..b6d945c7e98 100644 --- a/services/build.gradle +++ b/services/build.gradle @@ -40,6 +40,7 @@ configureProtoCompilation() tasks.named("javadoc").configure { exclude 'io/grpc/services/Internal*.java' exclude 'io/grpc/services/internal/*' + exclude 'io/grpc/protobuf/services/internal/*' } tasks.named("jacocoTestReport").configure { diff --git a/xds/src/main/java/io/grpc/xds/AbstractXdsClient.java b/xds/src/main/java/io/grpc/xds/AbstractXdsClient.java index 9fa25ba2cf8..d7eb9383f07 100644 --- a/xds/src/main/java/io/grpc/xds/AbstractXdsClient.java +++ b/xds/src/main/java/io/grpc/xds/AbstractXdsClient.java @@ -19,14 +19,6 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; -import static io.grpc.xds.XdsClusterResource.ADS_TYPE_URL_CDS; -import static io.grpc.xds.XdsClusterResource.ADS_TYPE_URL_CDS_V2; -import static io.grpc.xds.XdsEndpointResource.ADS_TYPE_URL_EDS; -import static 
io.grpc.xds.XdsEndpointResource.ADS_TYPE_URL_EDS_V2; -import static io.grpc.xds.XdsListenerResource.ADS_TYPE_URL_LDS; -import static io.grpc.xds.XdsListenerResource.ADS_TYPE_URL_LDS_V2; -import static io.grpc.xds.XdsRouteConfigureResource.ADS_TYPE_URL_RDS; -import static io.grpc.xds.XdsRouteConfigureResource.ADS_TYPE_URL_RDS_V2; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Stopwatch; @@ -46,17 +38,18 @@ import io.grpc.internal.BackoffPolicy; import io.grpc.stub.StreamObserver; import io.grpc.xds.Bootstrapper.ServerInfo; -import io.grpc.xds.ClientXdsClient.XdsChannelFactory; import io.grpc.xds.EnvoyProtoData.Node; import io.grpc.xds.XdsClient.ResourceStore; -import io.grpc.xds.XdsClient.ResourceUpdate; import io.grpc.xds.XdsClient.XdsResponseHandler; +import io.grpc.xds.XdsClientImpl.XdsChannelFactory; import io.grpc.xds.XdsLogger.XdsLogLevel; import java.util.Collection; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import javax.annotation.Nullable; @@ -82,7 +75,7 @@ final class AbstractXdsClient { // Last successfully applied version_info for each resource type. Starts with empty string. // A version_info is used to update management server with client's most recent knowledge of // resources. 
- private final Map versions = new HashMap<>(); + private final Map, String> versions = new HashMap<>(); private boolean shutdown; @Nullable @@ -160,8 +153,7 @@ void adjustResourceSubscription(XdsResourceType resourceType) { if (adsStream == null) { startRpcStream(); } - Collection resources = resourceStore.getSubscribedResources(serverInfo, - resourceType.typeName()); + Collection resources = resourceStore.getSubscribedResources(serverInfo, resourceType); if (resources != null) { adsStream.sendDiscoveryRequest(resourceType, resources); } @@ -172,11 +164,10 @@ void adjustResourceSubscription(XdsResourceType resourceType) { * and sends an ACK request to the management server. */ // Must be synchronized. - void ackResponse(XdsResourceType xdsResourceType, String versionInfo, String nonce) { - ResourceType type = xdsResourceType.typeName(); + void ackResponse(XdsResourceType type, String versionInfo, String nonce) { versions.put(type, versionInfo); logger.log(XdsLogLevel.INFO, "Sending ACK for {0} update, nonce: {1}, current version: {2}", - type, nonce, versionInfo); + type.typeName(), nonce, versionInfo); Collection resources = resourceStore.getSubscribedResources(serverInfo, type); if (resources == null) { resources = Collections.emptyList(); @@ -189,11 +180,10 @@ void ackResponse(XdsResourceType xdsResourceType, String versionInfo, String * accepted version) to the management server. */ // Must be synchronized. 
- void nackResponse(XdsResourceType xdsResourceType, String nonce, String errorDetail) { - ResourceType type = xdsResourceType.typeName(); + void nackResponse(XdsResourceType type, String nonce, String errorDetail) { String versionInfo = versions.getOrDefault(type, ""); logger.log(XdsLogLevel.INFO, "Sending NACK for {0} update, nonce: {1}, current version: {2}", - type, nonce, versionInfo); + type.typeName(), nonce, versionInfo); Collection resources = resourceStore.getSubscribedResources(serverInfo, type); if (resources == null) { resources = Collections.emptyList(); @@ -239,78 +229,22 @@ public void run() { return; } startRpcStream(); - for (ResourceType type : ResourceType.values()) { - if (type == ResourceType.UNKNOWN) { - continue; - } + Set> subscribedResourceTypes = + new HashSet<>(resourceStore.getSubscribedResourceTypesWithTypeUrl().values()); + for (XdsResourceType type : subscribedResourceTypes) { Collection resources = resourceStore.getSubscribedResources(serverInfo, type); if (resources != null) { - adsStream.sendDiscoveryRequest(resourceStore.getXdsResourceType(type), resources); + adsStream.sendDiscoveryRequest(type, resources); } } xdsResponseHandler.handleStreamRestarted(serverInfo); } } - // TODO(zivy) : remove and replace with XdsResourceType - enum ResourceType { - UNKNOWN, LDS, RDS, CDS, EDS; - - String typeUrl() { - switch (this) { - case LDS: - return ADS_TYPE_URL_LDS; - case RDS: - return ADS_TYPE_URL_RDS; - case CDS: - return ADS_TYPE_URL_CDS; - case EDS: - return ADS_TYPE_URL_EDS; - case UNKNOWN: - default: - throw new AssertionError("Unknown or missing case in enum switch: " + this); - } - } - - String typeUrlV2() { - switch (this) { - case LDS: - return ADS_TYPE_URL_LDS_V2; - case RDS: - return ADS_TYPE_URL_RDS_V2; - case CDS: - return ADS_TYPE_URL_CDS_V2; - case EDS: - return ADS_TYPE_URL_EDS_V2; - case UNKNOWN: - default: - throw new AssertionError("Unknown or missing case in enum switch: " + this); - } - } - - @VisibleForTesting - 
static ResourceType fromTypeUrl(String typeUrl) { - switch (typeUrl) { - case ADS_TYPE_URL_LDS: - // fall trough - case ADS_TYPE_URL_LDS_V2: - return LDS; - case ADS_TYPE_URL_RDS: - // fall through - case ADS_TYPE_URL_RDS_V2: - return RDS; - case ADS_TYPE_URL_CDS: - // fall through - case ADS_TYPE_URL_CDS_V2: - return CDS; - case ADS_TYPE_URL_EDS: - // fall through - case ADS_TYPE_URL_EDS_V2: - return EDS; - default: - return UNKNOWN; - } - } + @VisibleForTesting + @Nullable + XdsResourceType fromTypeUrl(String typeUrl) { + return resourceStore.getSubscribedResourceTypesWithTypeUrl().get(typeUrl); } private abstract class AbstractAdsStream { @@ -322,7 +256,7 @@ private abstract class AbstractAdsStream { // used for management server to identify which response the client is ACKing/NACking. // To avoid confusion, client-initiated requests will always use the nonce in // most recently received responses of each resource type. - private final Map respNonces = new HashMap<>(); + private final Map, String> respNonces = new HashMap<>(); abstract void start(); @@ -334,27 +268,27 @@ private abstract class AbstractAdsStream { * client-initiated discovery requests, use {@link * #sendDiscoveryRequest(XdsResourceType, Collection)}. */ - abstract void sendDiscoveryRequest(ResourceType type, String version, + abstract void sendDiscoveryRequest(XdsResourceType type, String version, Collection resources, String nonce, @Nullable String errorDetail); /** * Sends a client-initiated discovery request. 
*/ - final void sendDiscoveryRequest(XdsResourceType xdsResourceType, - Collection resources) { - ResourceType type = xdsResourceType.typeName(); + final void sendDiscoveryRequest(XdsResourceType type, Collection resources) { logger.log(XdsLogLevel.INFO, "Sending {0} request for resources: {1}", type, resources); sendDiscoveryRequest(type, versions.getOrDefault(type, ""), resources, respNonces.getOrDefault(type, ""), null); } - final void handleRpcResponse( - ResourceType type, String versionInfo, List resources, String nonce) { + final void handleRpcResponse(XdsResourceType type, String versionInfo, List resources, + String nonce) { if (closed) { return; } responseReceived = true; - respNonces.put(type, nonce); + if (type != null) { + respNonces.put(type, nonce); + } xdsResponseHandler.handleResourceResponse(type, serverInfo, versionInfo, resources, nonce); } @@ -422,7 +356,7 @@ public void onNext(final io.envoyproxy.envoy.api.v2.DiscoveryResponse response) syncContext.execute(new Runnable() { @Override public void run() { - ResourceType type = ResourceType.fromTypeUrl(response.getTypeUrl()); + XdsResourceType type = fromTypeUrl(response.getTypeUrl()); if (logger.isLoggable(XdsLogLevel.DEBUG)) { logger.log( XdsLogLevel.DEBUG, "Received {0} response:\n{1}", type, @@ -458,8 +392,9 @@ public void run() { } @Override - void sendDiscoveryRequest(ResourceType type, String versionInfo, Collection resources, - String nonce, @Nullable String errorDetail) { + void sendDiscoveryRequest(XdsResourceType type, String versionInfo, + Collection resources, String nonce, + @Nullable String errorDetail) { checkState(requestWriter != null, "ADS stream has not been started"); io.envoyproxy.envoy.api.v2.DiscoveryRequest.Builder builder = io.envoyproxy.envoy.api.v2.DiscoveryRequest.newBuilder() @@ -502,7 +437,7 @@ public void onNext(final DiscoveryResponse response) { syncContext.execute(new Runnable() { @Override public void run() { - ResourceType type = 
ResourceType.fromTypeUrl(response.getTypeUrl()); + XdsResourceType type = fromTypeUrl(response.getTypeUrl()); if (logger.isLoggable(XdsLogLevel.DEBUG)) { logger.log( XdsLogLevel.DEBUG, "Received {0} response:\n{1}", type, @@ -538,8 +473,9 @@ public void run() { } @Override - void sendDiscoveryRequest(ResourceType type, String versionInfo, Collection resources, - String nonce, @Nullable String errorDetail) { + void sendDiscoveryRequest(XdsResourceType type, String versionInfo, + Collection resources, String nonce, + @Nullable String errorDetail) { checkState(requestWriter != null, "ADS stream has not been started"); DiscoveryRequest.Builder builder = DiscoveryRequest.newBuilder() diff --git a/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java b/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java index afbb21008ef..0db0f59eaa2 100644 --- a/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java +++ b/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java @@ -79,9 +79,9 @@ final class CdsLoadBalancer2 extends LoadBalancer { } @Override - public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { + public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { if (this.resolvedAddresses != null) { - return; + return true; } logger.log(XdsLogLevel.DEBUG, "Received resolution result: {0}", resolvedAddresses); this.resolvedAddresses = resolvedAddresses; @@ -91,6 +91,7 @@ public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { logger.log(XdsLogLevel.INFO, "Config: {0}", config); cdsLbState = new CdsLbState(config.name); cdsLbState.start(); + return true; } @Override diff --git a/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java b/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java index 5241c55a4ab..2482085adfb 100644 --- a/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java @@ -102,7 +102,7 @@ final class ClusterImplLoadBalancer extends 
LoadBalancer { } @Override - public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { + public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { logger.log(XdsLogLevel.DEBUG, "Received resolution result: {0}", resolvedAddresses); Attributes attributes = resolvedAddresses.getAttributes(); if (xdsClientPool == null) { @@ -134,6 +134,7 @@ public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { .setAttributes(attributes) .setLoadBalancingPolicyConfig(config.childPolicy.getConfig()) .build()); + return true; } @Override @@ -162,11 +163,6 @@ public void shutdown() { } } - @Override - public boolean canHandleEmptyAddressListFromNameResolution() { - return true; - } - /** * A decorated {@link LoadBalancer.Helper} that applies configurations for connections * or requests to endpoints in the cluster. diff --git a/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancer.java b/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancer.java index 9b3fc83bad6..cce32c68246 100644 --- a/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancer.java @@ -70,16 +70,16 @@ class ClusterManagerLoadBalancer extends LoadBalancer { } @Override - public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { + public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { try { resolvingAddresses = true; - handleResolvedAddressesInternal(resolvedAddresses); + return acceptResolvedAddressesInternal(resolvedAddresses); } finally { resolvingAddresses = false; } } - public void handleResolvedAddressesInternal(ResolvedAddresses resolvedAddresses) { + public boolean acceptResolvedAddressesInternal(ResolvedAddresses resolvedAddresses) { logger.log(XdsLogLevel.DEBUG, "Received resolution result: {0}", resolvedAddresses); ClusterManagerConfig config = (ClusterManagerConfig) resolvedAddresses.getLoadBalancingPolicyConfig(); @@ -109,6 +109,7 @@ public 
void handleResolvedAddressesInternal(ResolvedAddresses resolvedAddresses) // Must update channel picker before return so that new RPCs will not be routed to deleted // clusters and resolver can remove them in service config. updateOverallBalancingState(); + return true; } @Override @@ -126,11 +127,6 @@ public void handleNameResolutionError(Status error) { } } - @Override - public boolean canHandleEmptyAddressListFromNameResolution() { - return true; - } - @Override public void shutdown() { logger.log(XdsLogLevel.INFO, "Shutdown"); diff --git a/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancer.java b/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancer.java index f4aaf9426bc..ca481e5691e 100644 --- a/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancer.java @@ -113,7 +113,7 @@ final class ClusterResolverLoadBalancer extends LoadBalancer { } @Override - public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { + public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { logger.log(XdsLogLevel.DEBUG, "Received resolution result: {0}", resolvedAddresses); if (xdsClientPool == null) { xdsClientPool = resolvedAddresses.getAttributes().get(InternalXdsAttributes.XDS_CLIENT_POOL); @@ -127,6 +127,7 @@ public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { this.config = config; delegate.handleResolvedAddresses(resolvedAddresses); } + return true; } @Override @@ -170,7 +171,7 @@ private final class ClusterResolverLbState extends LoadBalancer { } @Override - public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { + public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { this.resolvedAddresses = resolvedAddresses; ClusterResolverConfig config = (ClusterResolverConfig) resolvedAddresses.getLoadBalancingPolicyConfig(); @@ -189,6 +190,7 @@ public void handleResolvedAddresses(ResolvedAddresses 
resolvedAddresses) { clusterStates.put(instance.cluster, state); state.start(); } + return true; } @Override diff --git a/xds/src/main/java/io/grpc/xds/CsdsService.java b/xds/src/main/java/io/grpc/xds/CsdsService.java index edee01f95f1..3aab66d94c9 100644 --- a/xds/src/main/java/io/grpc/xds/CsdsService.java +++ b/xds/src/main/java/io/grpc/xds/CsdsService.java @@ -33,7 +33,6 @@ import io.grpc.StatusException; import io.grpc.internal.ObjectPool; import io.grpc.stub.StreamObserver; -import io.grpc.xds.AbstractXdsClient.ResourceType; import io.grpc.xds.XdsClient.ResourceMetadata; import io.grpc.xds.XdsClient.ResourceMetadata.ResourceMetadataStatus; import io.grpc.xds.XdsClient.ResourceMetadata.UpdateFailureState; @@ -156,12 +155,12 @@ static ClientConfig getClientConfigForXdsClient(XdsClient xdsClient) throws Inte ClientConfig.Builder builder = ClientConfig.newBuilder() .setNode(xdsClient.getBootstrapInfo().node().toEnvoyProtoNode()); - Map> metadataByType = + Map, Map> metadataByType = awaitSubscribedResourcesMetadata(xdsClient.getSubscribedResourcesMetadataSnapshot()); - for (Map.Entry> metadataByTypeEntry + for (Map.Entry, Map> metadataByTypeEntry : metadataByType.entrySet()) { - ResourceType type = metadataByTypeEntry.getKey(); + XdsResourceType type = metadataByTypeEntry.getKey(); Map metadataByResourceName = metadataByTypeEntry.getValue(); for (Map.Entry metadataEntry : metadataByResourceName.entrySet()) { String resourceName = metadataEntry.getKey(); @@ -187,8 +186,9 @@ static ClientConfig getClientConfigForXdsClient(XdsClient xdsClient) throws Inte return builder.build(); } - private static Map> awaitSubscribedResourcesMetadata( - ListenableFuture>> future) + private static Map, Map> + awaitSubscribedResourcesMetadata( + ListenableFuture, Map>> future) throws InterruptedException { try { // Normally this shouldn't take long, but add some slack for cases like a cold JVM. 
diff --git a/xds/src/main/java/io/grpc/xds/LeastRequestLoadBalancer.java b/xds/src/main/java/io/grpc/xds/LeastRequestLoadBalancer.java index 584ac2dd16f..b4aa39821d2 100644 --- a/xds/src/main/java/io/grpc/xds/LeastRequestLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/LeastRequestLoadBalancer.java @@ -88,7 +88,13 @@ final class LeastRequestLoadBalancer extends LoadBalancer { } @Override - public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { + public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { + if (resolvedAddresses.getAddresses().isEmpty()) { + handleNameResolutionError(Status.UNAVAILABLE.withDescription( + "NameResolver returned no usable address. addrs=" + resolvedAddresses.getAddresses() + + ", attrs=" + resolvedAddresses.getAttributes())); + return false; + } LeastRequestConfig config = (LeastRequestConfig) resolvedAddresses.getLoadBalancingPolicyConfig(); // Config may be null if least_request is used outside xDS @@ -146,6 +152,8 @@ public void onSubchannelState(ConnectivityStateInfo state) { for (Subchannel removedSubchannel : removedSubchannels) { shutdownSubchannel(removedSubchannel); } + + return true; } @Override diff --git a/xds/src/main/java/io/grpc/xds/LoadBalancerConfigFactory.java b/xds/src/main/java/io/grpc/xds/LoadBalancerConfigFactory.java index 5dc73a53135..ce3e95f03d1 100644 --- a/xds/src/main/java/io/grpc/xds/LoadBalancerConfigFactory.java +++ b/xds/src/main/java/io/grpc/xds/LoadBalancerConfigFactory.java @@ -35,8 +35,8 @@ import io.grpc.InternalLogId; import io.grpc.LoadBalancerRegistry; import io.grpc.internal.JsonParser; -import io.grpc.xds.ClientXdsClient.ResourceInvalidException; import io.grpc.xds.LoadBalancerConfigFactory.LoadBalancingPolicyConverter.MaxRecursionReachedException; +import io.grpc.xds.XdsClientImpl.ResourceInvalidException; import io.grpc.xds.XdsLogger.XdsLogLevel; import java.io.IOException; import java.util.Map; diff --git 
a/xds/src/main/java/io/grpc/xds/PriorityLoadBalancer.java b/xds/src/main/java/io/grpc/xds/PriorityLoadBalancer.java index 5cae9139ae6..0a9def370e3 100644 --- a/xds/src/main/java/io/grpc/xds/PriorityLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/PriorityLoadBalancer.java @@ -58,7 +58,7 @@ final class PriorityLoadBalancer extends LoadBalancer { private final XdsLogger logger; // Includes all active and deactivated children. Mutable. New entries are only added from priority - // 0 up to the selected priority. An entry is only deleted 15 minutes after the its deactivation. + // 0 up to the selected priority. An entry is only deleted 15 minutes after its deactivation. private final Map children = new HashMap<>(); // Following fields are only null initially. @@ -70,6 +70,8 @@ final class PriorityLoadBalancer extends LoadBalancer { @Nullable private String currentPriority; private ConnectivityState currentConnectivityState; private SubchannelPicker currentPicker; + // Set to true if currently in the process of handling resolved addresses. 
+ private boolean handlingResolvedAddresses; PriorityLoadBalancer(Helper helper) { this.helper = checkNotNull(helper, "helper"); @@ -81,7 +83,7 @@ final class PriorityLoadBalancer extends LoadBalancer { } @Override - public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { + public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { logger.log(XdsLogLevel.DEBUG, "Received resolution result: {0}", resolvedAddresses); this.resolvedAddresses = resolvedAddresses; PriorityLbConfig config = (PriorityLbConfig) resolvedAddresses.getLoadBalancingPolicyConfig(); @@ -94,12 +96,15 @@ public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { children.get(priority).deactivate(); } } + handlingResolvedAddresses = true; for (String priority : priorityNames) { if (children.containsKey(priority)) { children.get(priority).updateResolvedAddresses(); } } + handlingResolvedAddresses = false; tryNextPriority(); + return true; } @Override @@ -133,8 +138,11 @@ private void tryNextPriority() { ChildLbState child = new ChildLbState(priority, priorityConfigs.get(priority).ignoreReresolution); children.put(priority, child); - child.updateResolvedAddresses(); updateOverallState(priority, CONNECTING, BUFFER_PICKER); + // Calling the child's updateResolvedAddresses() can result in tryNextPriority() being + // called recursively. We need to be sure to be done with processing here before it is + // called. + child.updateResolvedAddresses(); return; // Give priority i time to connect. 
} ChildLbState child = children.get(priority); @@ -297,32 +305,33 @@ public void refreshNameResolution() { @Override public void updateBalancingState(final ConnectivityState newState, final SubchannelPicker newPicker) { - syncContext.execute(new Runnable() { - @Override - public void run() { - if (!children.containsKey(priority)) { - return; - } - connectivityState = newState; - picker = newPicker; - if (deletionTimer != null && deletionTimer.isPending()) { - return; - } - if (newState.equals(CONNECTING) ) { - if (!failOverTimer.isPending() && seenReadyOrIdleSinceTransientFailure) { - failOverTimer = syncContext.schedule(new FailOverTask(), 10, TimeUnit.SECONDS, - executor); - } - } else if (newState.equals(READY) || newState.equals(IDLE)) { - seenReadyOrIdleSinceTransientFailure = true; - failOverTimer.cancel(); - } else if (newState.equals(TRANSIENT_FAILURE)) { - seenReadyOrIdleSinceTransientFailure = false; - failOverTimer.cancel(); - } - tryNextPriority(); + if (!children.containsKey(priority)) { + return; + } + connectivityState = newState; + picker = newPicker; + + if (deletionTimer != null && deletionTimer.isPending()) { + return; + } + if (newState.equals(CONNECTING)) { + if (!failOverTimer.isPending() && seenReadyOrIdleSinceTransientFailure) { + failOverTimer = syncContext.schedule(new FailOverTask(), 10, TimeUnit.SECONDS, + executor); } - }); + } else if (newState.equals(READY) || newState.equals(IDLE)) { + seenReadyOrIdleSinceTransientFailure = true; + failOverTimer.cancel(); + } else if (newState.equals(TRANSIENT_FAILURE)) { + seenReadyOrIdleSinceTransientFailure = false; + failOverTimer.cancel(); + } + + // If we are currently handling newly resolved addresses, let's not try to reconfigure as + // the address handling process will take care of that to provide an atomic config update. 
+ if (!handlingResolvedAddresses) { + tryNextPriority(); + } } @Override diff --git a/xds/src/main/java/io/grpc/xds/RingHashLoadBalancer.java b/xds/src/main/java/io/grpc/xds/RingHashLoadBalancer.java index 4b365230009..55440fde8eb 100644 --- a/xds/src/main/java/io/grpc/xds/RingHashLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/RingHashLoadBalancer.java @@ -81,13 +81,13 @@ final class RingHashLoadBalancer extends LoadBalancer { } @Override - public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { + public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { logger.log(XdsLogLevel.DEBUG, "Received resolution result: {0}", resolvedAddresses); List addrList = resolvedAddresses.getAddresses(); if (addrList.isEmpty()) { handleNameResolutionError(Status.UNAVAILABLE.withDescription("Ring hash lb error: EDS " + "resolution was successful, but returned server addresses are empty.")); - return; + return false; } Map latestAddrs = stripAttrs(addrList); Set removedAddrs = @@ -162,6 +162,8 @@ public void onSubchannelState(ConnectivityStateInfo newState) { for (Subchannel subchann : removedSubchannels) { shutdownSubchannel(subchann); } + + return true; } private static List buildRing( diff --git a/xds/src/main/java/io/grpc/xds/SharedXdsClientPoolProvider.java b/xds/src/main/java/io/grpc/xds/SharedXdsClientPoolProvider.java index e70919a9ca4..5aabd976085 100644 --- a/xds/src/main/java/io/grpc/xds/SharedXdsClientPoolProvider.java +++ b/xds/src/main/java/io/grpc/xds/SharedXdsClientPoolProvider.java @@ -26,7 +26,7 @@ import io.grpc.internal.SharedResourceHolder; import io.grpc.internal.TimeProvider; import io.grpc.xds.Bootstrapper.BootstrapInfo; -import io.grpc.xds.ClientXdsClient.XdsChannelFactory; +import io.grpc.xds.XdsClientImpl.XdsChannelFactory; import io.grpc.xds.XdsNameResolverProvider.XdsClientPoolFactory; import io.grpc.xds.internal.security.TlsContextManagerImpl; import java.util.Map; @@ -123,7 +123,7 @@ public XdsClient 
getObject() { synchronized (lock) { if (refCount == 0) { scheduler = SharedResourceHolder.get(GrpcUtil.TIMER_SERVICE); - xdsClient = new ClientXdsClient( + xdsClient = new XdsClientImpl( XdsChannelFactory.DEFAULT_XDS_CHANNEL_FACTORY, bootstrapInfo, context, diff --git a/xds/src/main/java/io/grpc/xds/WeightedTargetLoadBalancer.java b/xds/src/main/java/io/grpc/xds/WeightedTargetLoadBalancer.java index 5edc7bbf20f..825e4a8eca0 100644 --- a/xds/src/main/java/io/grpc/xds/WeightedTargetLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/WeightedTargetLoadBalancer.java @@ -61,16 +61,16 @@ final class WeightedTargetLoadBalancer extends LoadBalancer { } @Override - public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { + public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { try { resolvingAddresses = true; - handleResolvedAddressesInternal(resolvedAddresses); + return acceptResolvedAddressesInternal(resolvedAddresses); } finally { resolvingAddresses = false; } } - public void handleResolvedAddressesInternal(ResolvedAddresses resolvedAddresses) { + public boolean acceptResolvedAddressesInternal(ResolvedAddresses resolvedAddresses) { logger.log(XdsLogLevel.DEBUG, "Received resolution result: {0}", resolvedAddresses); Object lbConfig = resolvedAddresses.getLoadBalancingPolicyConfig(); checkNotNull(lbConfig, "missing weighted_target lb config"); @@ -109,6 +109,7 @@ public void handleResolvedAddressesInternal(ResolvedAddresses resolvedAddresses) childBalancers.keySet().retainAll(targets.keySet()); childHelpers.keySet().retainAll(targets.keySet()); updateOverallBalancingState(); + return true; } @Override diff --git a/xds/src/main/java/io/grpc/xds/WrrLocalityLoadBalancer.java b/xds/src/main/java/io/grpc/xds/WrrLocalityLoadBalancer.java index f320bc340ee..a961db02cce 100644 --- a/xds/src/main/java/io/grpc/xds/WrrLocalityLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/WrrLocalityLoadBalancer.java @@ -61,7 +61,7 @@ final class 
WrrLocalityLoadBalancer extends LoadBalancer { } @Override - public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { + public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { logger.log(XdsLogLevel.DEBUG, "Received resolution result: {0}", resolvedAddresses); // The configuration with the child policy is combined with the locality weights @@ -76,7 +76,7 @@ public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { Status unavailable = Status.UNAVAILABLE.withDescription("wrr_locality error: no locality weights provided"); helper.updateBalancingState(TRANSIENT_FAILURE, new ErrorPicker(unavailable)); - return; + return false; } // Weighted target LB expects a WeightedPolicySelection for each locality as it will create a @@ -101,6 +101,8 @@ public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { resolvedAddresses.toBuilder() .setLoadBalancingPolicyConfig(new WeightedTargetConfig(weightedPolicySelections)) .build()); + + return true; } @Override diff --git a/xds/src/main/java/io/grpc/xds/XdsClient.java b/xds/src/main/java/io/grpc/xds/XdsClient.java index feb3afa3e98..82af8651159 100644 --- a/xds/src/main/java/io/grpc/xds/XdsClient.java +++ b/xds/src/main/java/io/grpc/xds/XdsClient.java @@ -25,7 +25,6 @@ import com.google.common.util.concurrent.ListenableFuture; import com.google.protobuf.Any; import io.grpc.Status; -import io.grpc.xds.AbstractXdsClient.ResourceType; import io.grpc.xds.Bootstrapper.ServerInfo; import io.grpc.xds.LoadStatsManager2.ClusterDropStats; import io.grpc.xds.LoadStatsManager2.ClusterLocalityStats; @@ -296,7 +295,7 @@ TlsContextManager getTlsContextManager() { * a map ("resource name": "resource metadata"). */ // Must be synchronized. 
- ListenableFuture>> + ListenableFuture, Map>> getSubscribedResourcesMetadataSnapshot() { throw new UnsupportedOperationException(); } @@ -347,8 +346,8 @@ ClusterLocalityStats addClusterLocalityStats( interface XdsResponseHandler { /** Called when a xds response is received. */ void handleResourceResponse( - ResourceType resourceType, ServerInfo serverInfo, String versionInfo, List resources, - String nonce); + XdsResourceType resourceType, ServerInfo serverInfo, String versionInfo, + List resources, String nonce); /** Called when the ADS stream is closed passively. */ // Must be synchronized. @@ -369,9 +368,9 @@ interface ResourceStore { */ // Must be synchronized. @Nullable - Collection getSubscribedResources(ServerInfo serverInfo, ResourceType type); + Collection getSubscribedResources(ServerInfo serverInfo, + XdsResourceType type); - @Nullable - XdsResourceType getXdsResourceType(ResourceType type); + Map> getSubscribedResourceTypesWithTypeUrl(); } } diff --git a/xds/src/main/java/io/grpc/xds/ClientXdsClient.java b/xds/src/main/java/io/grpc/xds/XdsClientImpl.java similarity index 90% rename from xds/src/main/java/io/grpc/xds/ClientXdsClient.java rename to xds/src/main/java/io/grpc/xds/XdsClientImpl.java index f4dccc8a10b..a30b4755a16 100644 --- a/xds/src/main/java/io/grpc/xds/ClientXdsClient.java +++ b/xds/src/main/java/io/grpc/xds/XdsClientImpl.java @@ -18,10 +18,6 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; -import static io.grpc.xds.AbstractXdsClient.ResourceType.CDS; -import static io.grpc.xds.AbstractXdsClient.ResourceType.EDS; -import static io.grpc.xds.AbstractXdsClient.ResourceType.LDS; -import static io.grpc.xds.AbstractXdsClient.ResourceType.RDS; import static io.grpc.xds.Bootstrapper.XDSTP_SCHEME; import static io.grpc.xds.XdsResourceType.ParsedResource; import static io.grpc.xds.XdsResourceType.ValidatedResourceUpdate; @@ -46,7 +42,6 @@ import 
io.grpc.SynchronizationContext.ScheduledHandle; import io.grpc.internal.BackoffPolicy; import io.grpc.internal.TimeProvider; -import io.grpc.xds.AbstractXdsClient.ResourceType; import io.grpc.xds.Bootstrapper.AuthorityInfo; import io.grpc.xds.Bootstrapper.ServerInfo; import io.grpc.xds.LoadStatsManager2.ClusterDropStats; @@ -72,7 +67,7 @@ /** * XdsClient implementation for client side usages. */ -final class ClientXdsClient extends XdsClient implements XdsResponseHandler, ResourceStore { +final class XdsClientImpl extends XdsClient implements XdsResponseHandler, ResourceStore { // Longest time to wait, since the subscription to some resource, for concluding its absence. @VisibleForTesting @@ -96,12 +91,7 @@ public void uncaughtException(Thread t, Throwable e) { private final Map, Map>> resourceSubscribers = new HashMap<>(); - private final Map> xdsResourceTypeMap = - ImmutableMap.of( - LDS, XdsListenerResource.getInstance(), - RDS, XdsRouteConfigureResource.getInstance(), - CDS, XdsClusterResource.getInstance(), - EDS, XdsEndpointResource.getInstance()); + private final Map> subscribedResourceTypeUrls = new HashMap<>(); private final LoadStatsManager2 loadStatsManager; private final Map serverLrsClientMap = new HashMap<>(); private final XdsChannelFactory xdsChannelFactory; @@ -118,7 +108,7 @@ public void uncaughtException(Thread t, Throwable e) { private volatile boolean isShutdown; // TODO(zdapeng): rename to XdsClientImpl - ClientXdsClient( + XdsClientImpl( XdsChannelFactory xdsChannelFactory, Bootstrapper.BootstrapInfo bootstrapInfo, Context context, @@ -166,17 +156,16 @@ private void maybeCreateXdsChannelWithLrs(ServerInfo serverInfo) { @Override public void handleResourceResponse( - ResourceType resourceType, ServerInfo serverInfo, String versionInfo, List resources, - String nonce) { + XdsResourceType xdsResourceType, ServerInfo serverInfo, String versionInfo, + List resources, String nonce) { syncContext.throwIfNotInThisSynchronizationContext(); - 
XdsResourceType xdsResourceType = - xdsResourceTypeMap.get(resourceType); if (xdsResourceType == null) { logger.log(XdsLogLevel.WARNING, "Ignore an unknown type of DiscoveryResponse"); return; } Set toParseResourceNames = null; - if (!(resourceType == LDS || resourceType == RDS) + if (!(xdsResourceType == XdsListenerResource.getInstance() + || xdsResourceType == XdsRouteConfigureResource.getInstance()) && resourceSubscribers.containsKey(xdsResourceType)) { toParseResourceNames = resourceSubscribers.get(xdsResourceType).keySet(); } @@ -239,23 +228,17 @@ boolean isShutDown() { return isShutdown; } - private Map> getSubscribedResourcesMap( - ResourceType type) { - return resourceSubscribers.getOrDefault(xdsResourceTypeMap.get(type), Collections.emptyMap()); - } - - @Nullable @Override - public XdsResourceType getXdsResourceType(ResourceType type) { - return xdsResourceTypeMap.get(type); + public Map> getSubscribedResourceTypesWithTypeUrl() { + return Collections.unmodifiableMap(subscribedResourceTypeUrls); } @Nullable @Override public Collection getSubscribedResources(ServerInfo serverInfo, - ResourceType type) { + XdsResourceType type) { Map> resources = - getSubscribedResourcesMap(type); + resourceSubscribers.getOrDefault(type, Collections.emptyMap()); ImmutableSet.Builder builder = ImmutableSet.builder(); for (String key : resources.keySet()) { if (resources.get(key).serverInfo.equals(serverInfo)) { @@ -266,26 +249,26 @@ public Collection getSubscribedResources(ServerInfo serverInfo, return retVal.isEmpty() ? null : retVal; } + // As XdsClient APIs become resource agnostic, subscribed resource types are dynamic. + // ResourceTypes that do not have subscribers do not show up in the snapshot keys.
@Override - ListenableFuture>> + ListenableFuture, Map>> getSubscribedResourcesMetadataSnapshot() { - final SettableFuture>> future = + final SettableFuture, Map>> future = SettableFuture.create(); syncContext.execute(new Runnable() { @Override public void run() { // A map from a "resource type" to a map ("resource name": "resource metadata") - ImmutableMap.Builder> metadataSnapshot = + ImmutableMap.Builder, Map> metadataSnapshot = ImmutableMap.builder(); - for (XdsResourceType resourceType: xdsResourceTypeMap.values()) { + for (XdsResourceType resourceType: resourceSubscribers.keySet()) { ImmutableMap.Builder metadataMap = ImmutableMap.builder(); - Map> resourceSubscriberMap = - resourceSubscribers.getOrDefault(resourceType, Collections.emptyMap()); for (Map.Entry> resourceEntry - : resourceSubscriberMap.entrySet()) { + : resourceSubscribers.get(resourceType).entrySet()) { metadataMap.put(resourceEntry.getKey(), resourceEntry.getValue().metadata); } - metadataSnapshot.put(resourceType.typeName(), metadataMap.buildOrThrow()); + metadataSnapshot.put(resourceType, metadataMap.buildOrThrow()); } future.set(metadataSnapshot.buildOrThrow()); } @@ -307,12 +290,14 @@ void watchXdsResource(XdsResourceType type, String public void run() { if (!resourceSubscribers.containsKey(type)) { resourceSubscribers.put(type, new HashMap<>()); + subscribedResourceTypeUrls.put(type.typeUrl(), type); + subscribedResourceTypeUrls.put(type.typeUrlV2(), type); } ResourceSubscriber subscriber = (ResourceSubscriber) resourceSubscribers.get(type).get(resourceName);; if (subscriber == null) { logger.log(XdsLogLevel.INFO, "Subscribe {0} resource {1}", type, resourceName); - subscriber = new ResourceSubscriber<>(type.typeName(), resourceName); + subscriber = new ResourceSubscriber<>(type, resourceName); resourceSubscribers.get(type).put(resourceName, subscriber); if (subscriber.xdsChannel != null) { subscriber.xdsChannel.adjustResourceSubscription(type); @@ -337,6 +322,8 @@ public void run() { if 
(!subscriber.isWatched()) { subscriber.cancelResourceWatch(); resourceSubscribers.get(type).remove(resourceName); + subscribedResourceTypeUrls.remove(type.typeUrl()); + subscribedResourceTypeUrls.remove(type.typeUrlV2()); if (subscriber.xdsChannel != null) { subscriber.xdsChannel.adjustResourceSubscription(type); } @@ -427,8 +414,9 @@ private void handleResourceUpdate(XdsResourceType.Arg } long updateTime = timeProvider.currentTimeNanos(); - for (Map.Entry> entry : - getSubscribedResourcesMap(xdsResourceType.typeName()).entrySet()) { + Map> subscribedResources = + resourceSubscribers.getOrDefault(xdsResourceType, Collections.emptyMap()); + for (Map.Entry> entry : subscribedResources.entrySet()) { String resourceName = entry.getKey(); ResourceSubscriber subscriber = (ResourceSubscriber) entry.getValue(); @@ -473,7 +461,7 @@ private void handleResourceUpdate(XdsResourceType.Arg // LDS/CDS responses represents the state of the world, RDS/EDS resources not referenced in // LDS/CDS resources should be deleted. 
if (xdsResourceType.dependentResource() != null) { - XdsResourceType dependency = xdsResourceTypeMap.get(xdsResourceType.dependentResource()); + XdsResourceType dependency = xdsResourceType.dependentResource(); Map> dependentSubscribers = resourceSubscribers.get(dependency); if (dependentSubscribers == null) { @@ -493,13 +481,13 @@ private void retainDependentResource( return; } String resourceName = null; - if (subscriber.type == LDS) { + if (subscriber.type == XdsListenerResource.getInstance()) { LdsUpdate ldsUpdate = (LdsUpdate) subscriber.data; io.grpc.xds.HttpConnectionManager hcm = ldsUpdate.httpConnectionManager(); if (hcm != null) { resourceName = hcm.rdsName(); } - } else if (subscriber.type == CDS) { + } else if (subscriber.type == XdsClusterResource.getInstance()) { CdsUpdate cdsUpdate = (CdsUpdate) subscriber.data; resourceName = cdsUpdate.edsServiceName(); } @@ -515,7 +503,7 @@ private void retainDependentResource( private final class ResourceSubscriber { @Nullable private final ServerInfo serverInfo; @Nullable private final AbstractXdsClient xdsChannel; - private final ResourceType type; + private final XdsResourceType type; private final String resource; private final Set> watchers = new HashSet<>(); @Nullable private T data; @@ -527,7 +515,7 @@ private final class ResourceSubscriber { @Nullable private ResourceMetadata metadata; @Nullable private String errorDescription; - ResourceSubscriber(ResourceType type, String resource) { + ResourceSubscriber(XdsResourceType type, String resource) { syncContext.throwIfNotInThisSynchronizationContext(); this.type = type; this.resource = resource; @@ -669,7 +657,8 @@ void onAbsent() { // and the resource is reusable. 
boolean ignoreResourceDeletionEnabled = serverInfo != null && serverInfo.ignoreResourceDeletion(); - boolean isStateOfTheWorld = (type == LDS || type == CDS); + boolean isStateOfTheWorld = (type == XdsListenerResource.getInstance() + || type == XdsClusterResource.getInstance()); if (ignoreResourceDeletionEnabled && isStateOfTheWorld && data != null) { if (!resourceDeletionIgnored) { logger.log(XdsLogLevel.FORCE_WARNING, @@ -719,7 +708,6 @@ private void notifyWatcher(ResourceWatcher watcher, T update) { } } - @VisibleForTesting static final class ResourceInvalidException extends Exception { private static final long serialVersionUID = 0L; diff --git a/xds/src/main/java/io/grpc/xds/XdsClusterResource.java b/xds/src/main/java/io/grpc/xds/XdsClusterResource.java index 4dc3095efa9..82a977f7df4 100644 --- a/xds/src/main/java/io/grpc/xds/XdsClusterResource.java +++ b/xds/src/main/java/io/grpc/xds/XdsClusterResource.java @@ -17,9 +17,6 @@ package io.grpc.xds; import static com.google.common.base.Preconditions.checkNotNull; -import static io.grpc.xds.AbstractXdsClient.ResourceType; -import static io.grpc.xds.AbstractXdsClient.ResourceType.CDS; -import static io.grpc.xds.AbstractXdsClient.ResourceType.EDS; import static io.grpc.xds.Bootstrapper.ServerInfo; import com.google.auto.value.AutoValue; @@ -42,10 +39,10 @@ import io.grpc.NameResolver; import io.grpc.internal.ServiceConfigUtil; import io.grpc.internal.ServiceConfigUtil.LbConfig; -import io.grpc.xds.ClientXdsClient.ResourceInvalidException; import io.grpc.xds.EnvoyServerProtoData.OutlierDetection; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; import io.grpc.xds.XdsClient.ResourceUpdate; +import io.grpc.xds.XdsClientImpl.ResourceInvalidException; import io.grpc.xds.XdsClusterResource.CdsUpdate; import java.util.List; import java.util.Locale; @@ -77,8 +74,8 @@ String extractResourceName(Message unpackedResource) { } @Override - ResourceType typeName() { - return CDS; + String typeName() { + return "CDS"; } 
@Override @@ -93,8 +90,8 @@ String typeUrlV2() { @Nullable @Override - ResourceType dependentResource() { - return EDS; + XdsResourceType dependentResource() { + return XdsEndpointResource.getInstance(); } @Override diff --git a/xds/src/main/java/io/grpc/xds/XdsEndpointResource.java b/xds/src/main/java/io/grpc/xds/XdsEndpointResource.java index db1e93d13f3..c126d643311 100644 --- a/xds/src/main/java/io/grpc/xds/XdsEndpointResource.java +++ b/xds/src/main/java/io/grpc/xds/XdsEndpointResource.java @@ -17,8 +17,6 @@ package io.grpc.xds; import static com.google.common.base.Preconditions.checkNotNull; -import static io.grpc.xds.AbstractXdsClient.ResourceType; -import static io.grpc.xds.AbstractXdsClient.ResourceType.EDS; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.MoreObjects; @@ -27,10 +25,10 @@ import io.envoyproxy.envoy.config.endpoint.v3.ClusterLoadAssignment; import io.envoyproxy.envoy.type.v3.FractionalPercent; import io.grpc.EquivalentAddressGroup; -import io.grpc.xds.ClientXdsClient.ResourceInvalidException; import io.grpc.xds.Endpoints.DropOverload; import io.grpc.xds.Endpoints.LocalityLbEndpoints; import io.grpc.xds.XdsClient.ResourceUpdate; +import io.grpc.xds.XdsClientImpl.ResourceInvalidException; import io.grpc.xds.XdsEndpointResource.EdsUpdate; import java.net.InetSocketAddress; import java.util.ArrayList; @@ -66,8 +64,8 @@ String extractResourceName(Message unpackedResource) { } @Override - ResourceType typeName() { - return EDS; + String typeName() { + return "EDS"; } @Override @@ -82,7 +80,7 @@ String typeUrlV2() { @Nullable @Override - ResourceType dependentResource() { + XdsResourceType dependentResource() { return null; } diff --git a/xds/src/main/java/io/grpc/xds/XdsListenerResource.java b/xds/src/main/java/io/grpc/xds/XdsListenerResource.java index 397ba32dc6e..5f7d6a27aa4 100644 --- a/xds/src/main/java/io/grpc/xds/XdsListenerResource.java +++ b/xds/src/main/java/io/grpc/xds/XdsListenerResource.java @@ 
-17,11 +17,8 @@ package io.grpc.xds; import static com.google.common.base.Preconditions.checkNotNull; -import static io.grpc.xds.AbstractXdsClient.ResourceType; -import static io.grpc.xds.AbstractXdsClient.ResourceType.LDS; -import static io.grpc.xds.AbstractXdsClient.ResourceType.RDS; -import static io.grpc.xds.ClientXdsClient.ResourceInvalidException; import static io.grpc.xds.XdsClient.ResourceUpdate; +import static io.grpc.xds.XdsClientImpl.ResourceInvalidException; import static io.grpc.xds.XdsClusterResource.validateCommonTlsContext; import static io.grpc.xds.XdsRouteConfigureResource.extractVirtualHosts; @@ -81,8 +78,8 @@ String extractResourceName(Message unpackedResource) { } @Override - ResourceType typeName() { - return LDS; + String typeName() { + return "LDS"; } @Override @@ -102,8 +99,8 @@ String typeUrlV2() { @Nullable @Override - ResourceType dependentResource() { - return RDS; + XdsResourceType dependentResource() { + return XdsRouteConfigureResource.getInstance(); } @Override diff --git a/xds/src/main/java/io/grpc/xds/XdsNameResolver.java b/xds/src/main/java/io/grpc/xds/XdsNameResolver.java index 702c217a3ee..74e35ca3e7d 100644 --- a/xds/src/main/java/io/grpc/xds/XdsNameResolver.java +++ b/xds/src/main/java/io/grpc/xds/XdsNameResolver.java @@ -47,7 +47,6 @@ import io.grpc.SynchronizationContext; import io.grpc.internal.GrpcUtil; import io.grpc.internal.ObjectPool; -import io.grpc.xds.AbstractXdsClient.ResourceType; import io.grpc.xds.Bootstrapper.AuthorityInfo; import io.grpc.xds.Bootstrapper.BootstrapInfo; import io.grpc.xds.ClusterSpecifierPlugin.PluginConfig; @@ -202,8 +201,9 @@ public void start(Listener2 listener) { replacement = XdsClient.percentEncodePath(replacement); } String ldsResourceName = expandPercentS(listenerNameTemplate, replacement); - if (!XdsClient.isResourceNameValid(ldsResourceName, ResourceType.LDS.typeUrl()) - && !XdsClient.isResourceNameValid(ldsResourceName, ResourceType.LDS.typeUrlV2())) { + if 
(!XdsClient.isResourceNameValid(ldsResourceName, XdsListenerResource.getInstance().typeUrl()) + && !XdsClient.isResourceNameValid(ldsResourceName, + XdsListenerResource.getInstance().typeUrlV2())) { listener.onError(Status.INVALID_ARGUMENT.withDescription( "invalid listener resource URI for service authority: " + serviceAuthority)); return; diff --git a/xds/src/main/java/io/grpc/xds/XdsResourceType.java b/xds/src/main/java/io/grpc/xds/XdsResourceType.java index 52c143934e6..a377ee35d7a 100644 --- a/xds/src/main/java/io/grpc/xds/XdsResourceType.java +++ b/xds/src/main/java/io/grpc/xds/XdsResourceType.java @@ -17,12 +17,11 @@ package io.grpc.xds; import static com.google.common.base.Preconditions.checkNotNull; -import static io.grpc.xds.AbstractXdsClient.ResourceType; import static io.grpc.xds.Bootstrapper.ServerInfo; -import static io.grpc.xds.ClientXdsClient.ResourceInvalidException; import static io.grpc.xds.XdsClient.ResourceUpdate; import static io.grpc.xds.XdsClient.canonifyResourceName; import static io.grpc.xds.XdsClient.isResourceNameValid; +import static io.grpc.xds.XdsClientImpl.ResourceInvalidException; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Strings; @@ -80,7 +79,7 @@ abstract class XdsResourceType { abstract Class unpackedClassName(); - abstract ResourceType typeName(); + abstract String typeName(); abstract String typeUrl(); @@ -88,7 +87,7 @@ abstract class XdsResourceType { // Non-null for State of the World resources. 
@Nullable - abstract ResourceType dependentResource(); + abstract XdsResourceType dependentResource(); static class Args { final ServerInfo serverInfo; @@ -158,7 +157,7 @@ ValidatedResourceUpdate parse(Args args, List resources) { T resourceUpdate; try { resourceUpdate = doParse(args, unpackedMessage, retainedResources, isResourceV3); - } catch (ClientXdsClient.ResourceInvalidException e) { + } catch (XdsClientImpl.ResourceInvalidException e) { errors.add(String.format("%s response %s '%s' validation error: %s", typeName(), unpackedClassName().getSimpleName(), cname, e.getMessage())); invalidResources.add(cname); diff --git a/xds/src/main/java/io/grpc/xds/XdsRouteConfigureResource.java b/xds/src/main/java/io/grpc/xds/XdsRouteConfigureResource.java index f2fca0b1bf7..166809c87e1 100644 --- a/xds/src/main/java/io/grpc/xds/XdsRouteConfigureResource.java +++ b/xds/src/main/java/io/grpc/xds/XdsRouteConfigureResource.java @@ -17,9 +17,6 @@ package io.grpc.xds; import static com.google.common.base.Preconditions.checkNotNull; -import static io.grpc.xds.AbstractXdsClient.ResourceType.RDS; -import static io.grpc.xds.AbstractXdsClient.ResourceType; -import static io.grpc.xds.XdsRouteConfigureResource.RdsUpdate; import com.github.udpa.udpa.type.v1.TypedStruct; import com.google.common.annotations.VisibleForTesting; @@ -40,7 +37,6 @@ import io.envoyproxy.envoy.config.route.v3.RouteConfiguration; import io.envoyproxy.envoy.type.v3.FractionalPercent; import io.grpc.Status; -import io.grpc.xds.ClientXdsClient.ResourceInvalidException; import io.grpc.xds.ClusterSpecifierPlugin.NamedPluginConfig; import io.grpc.xds.ClusterSpecifierPlugin.PluginConfig; import io.grpc.xds.Filter.FilterConfig; @@ -52,9 +48,11 @@ import io.grpc.xds.VirtualHost.Route.RouteMatch; import io.grpc.xds.VirtualHost.Route.RouteMatch.PathMatcher; import io.grpc.xds.XdsClient.ResourceUpdate; +import io.grpc.xds.XdsClientImpl.ResourceInvalidException; +import io.grpc.xds.XdsRouteConfigureResource.RdsUpdate; 
+import io.grpc.xds.internal.Matchers; import io.grpc.xds.internal.Matchers.FractionMatcher; import io.grpc.xds.internal.Matchers.HeaderMatcher; -import io.grpc.xds.internal.Matchers; import java.util.ArrayList; import java.util.Collections; import java.util.EnumSet; @@ -95,8 +93,8 @@ String extractResourceName(Message unpackedResource) { } @Override - ResourceType typeName() { - return RDS; + String typeName() { + return "RDS"; } @Override @@ -111,7 +109,7 @@ String typeUrlV2() { @Nullable @Override - ResourceType dependentResource() { + XdsResourceType dependentResource() { return null; } diff --git a/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java b/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java index 0c1271d1dc7..08f2e86fb69 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java @@ -208,10 +208,10 @@ protected void handlerAdded0(final ChannelHandlerContext ctx) { new SslContextProvider.Callback(ctx.executor()) { @Override - public void updateSecret(SslContext sslContext) { + public void updateSslContext(SslContext sslContext) { logger.log( Level.FINEST, - "ClientSdsHandler.updateSecret authority={0}, ctx.name={1}", + "ClientSdsHandler.updateSslContext authority={0}, ctx.name={1}", new Object[]{grpcHandler.getAuthority(), ctx.name()}); ChannelHandler handler = InternalProtocolNegotiators.tls(sslContext).newHandler(grpcHandler); @@ -347,7 +347,7 @@ protected void handlerAdded0(final ChannelHandlerContext ctx) { new SslContextProvider.Callback(ctx.executor()) { @Override - public void updateSecret(SslContext sslContext) { + public void updateSslContext(SslContext sslContext) { ChannelHandler handler = InternalProtocolNegotiators.serverTls(sslContext).newHandler(grpcHandler); diff --git a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProvider.java 
b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProvider.java index 7544f5d9fc3..a0c4ed37dfb 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProvider.java @@ -57,7 +57,7 @@ protected Callback(Executor executor) { } /** Informs callee of new/updated SslContext. */ - @VisibleForTesting public abstract void updateSecret(SslContext sslContext); + @VisibleForTesting public abstract void updateSslContext(SslContext sslContext); /** Informs callee of an exception that was generated. */ @VisibleForTesting protected abstract void onException(Throwable throwable); @@ -120,7 +120,7 @@ protected final void performCallback( public void run() { try { SslContext sslContext = sslContextGetter.get(); - callback.updateSecret(sslContext); + callback.updateSslContext(sslContext); } catch (Throwable e) { callback.onException(e); } diff --git a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java index e429eff44a0..5f629273179 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java @@ -66,8 +66,8 @@ public synchronized void updateSslContext(final SslContextProvider.Callback call new SslContextProvider.Callback(callback.getExecutor()) { @Override - public void updateSecret(SslContext sslContext) { - callback.updateSecret(sslContext); + public void updateSslContext(SslContext sslContext) { + callback.updateSslContext(sslContext); releaseSslContextProvider(toRelease); } diff --git a/xds/src/test/java/io/grpc/xds/CdsLoadBalancer2Test.java b/xds/src/test/java/io/grpc/xds/CdsLoadBalancer2Test.java index 1fd81e606a5..60ddb9f3da8 100644 --- a/xds/src/test/java/io/grpc/xds/CdsLoadBalancer2Test.java +++ 
b/xds/src/test/java/io/grpc/xds/CdsLoadBalancer2Test.java @@ -17,7 +17,6 @@ package io.grpc.xds; import static com.google.common.truth.Truth.assertThat; -import static io.grpc.xds.AbstractXdsClient.ResourceType.CDS; import static io.grpc.xds.XdsLbPolicies.CLUSTER_RESOLVER_POLICY_NAME; import static org.junit.Assert.fail; import static org.mockito.ArgumentMatchers.any; @@ -136,7 +135,7 @@ public void setUp() { lbRegistry.register(new FakeLoadBalancerProvider("least_request_experimental", new LeastRequestLoadBalancerProvider())); loadBalancer = new CdsLoadBalancer2(helper, lbRegistry); - loadBalancer.handleResolvedAddresses( + loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(Collections.emptyList()) .setAttributes( @@ -657,7 +656,7 @@ private final class FakeXdsClient extends XdsClient { @SuppressWarnings("unchecked") void watchXdsResource(XdsResourceType type, String resourceName, ResourceWatcher watcher) { - assertThat(type.typeName()).isEqualTo(CDS); + assertThat(type.typeName()).isEqualTo("CDS"); assertThat(watchers).doesNotContainKey(resourceName); watchers.put(resourceName, (ResourceWatcher)watcher); } @@ -667,7 +666,7 @@ void watchXdsResource(XdsResourceType type, String void cancelXdsResourceWatch(XdsResourceType type, String resourceName, ResourceWatcher watcher) { - assertThat(type.typeName()).isEqualTo(CDS); + assertThat(type.typeName()).isEqualTo("CDS"); assertThat(watchers).containsKey(resourceName); watchers.remove(resourceName); } diff --git a/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java index 65af00da7a9..142786280a8 100644 --- a/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java @@ -282,7 +282,7 @@ public void dropRpcsWithRespectToLbConfigDropCategories() { config = new ClusterImplConfig(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, null, 
Collections.singletonList(DropOverload.create("lb", 1_000_000)), new PolicySelection(weightedTargetProvider, weightedTargetConfig), null); - loadBalancer.handleResolvedAddresses( + loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(Collections.singletonList(endpoint)) .setAttributes( @@ -571,7 +571,7 @@ private void subtest_endpointAddressesAttachedWithTlsConfig(boolean enableSecuri private void deliverAddressesAndConfig(List addresses, ClusterImplConfig config) { - loadBalancer.handleResolvedAddresses( + loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(addresses) .setAttributes( @@ -677,10 +677,11 @@ private final class FakeLoadBalancer extends LoadBalancer { } @Override - public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { + public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { addresses = resolvedAddresses.getAddresses(); config = resolvedAddresses.getLoadBalancingPolicyConfig(); attributes = resolvedAddresses.getAttributes(); + return true; } @Override diff --git a/xds/src/test/java/io/grpc/xds/ClusterManagerLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/ClusterManagerLoadBalancerTest.java index c83c9c4060a..f2b80cfff0b 100644 --- a/xds/src/test/java/io/grpc/xds/ClusterManagerLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/ClusterManagerLoadBalancerTest.java @@ -267,7 +267,7 @@ private void deliverResolvedAddresses(final Map childPolicies) { } private void deliverResolvedAddresses(final Map childPolicies, boolean failing) { - clusterManagerLoadBalancer.handleResolvedAddresses( + clusterManagerLoadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(Collections.emptyList()) .setLoadBalancingPolicyConfig(buildConfig(childPolicies, failing)) @@ -348,12 +348,13 @@ private final class FakeLoadBalancer extends LoadBalancer { } @Override - public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { + 
public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { config = resolvedAddresses.getLoadBalancingPolicyConfig(); if (failing) { helper.updateBalancingState(TRANSIENT_FAILURE, new ErrorPicker(Status.INTERNAL)); } + return true; } @Override diff --git a/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java index 741c2ba8cdb..b4b709d2ae2 100644 --- a/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java @@ -17,7 +17,6 @@ package io.grpc.xds; import static com.google.common.truth.Truth.assertThat; -import static io.grpc.xds.AbstractXdsClient.ResourceType.EDS; import static io.grpc.xds.XdsLbPolicies.CLUSTER_IMPL_POLICY_NAME; import static io.grpc.xds.XdsLbPolicies.PRIORITY_POLICY_NAME; import static io.grpc.xds.XdsLbPolicies.WEIGHTED_TARGET_POLICY_NAME; @@ -1082,7 +1081,7 @@ public void handleNameResolutionErrorFromUpstream_afterChildLbCreated_fallThroug } private void deliverLbConfig(ClusterResolverConfig config) { - loadBalancer.handleResolvedAddresses( + loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(Collections.emptyList()) .setAttributes( @@ -1176,7 +1175,7 @@ private static final class FakeXdsClient extends XdsClient { @SuppressWarnings("unchecked") void watchXdsResource(XdsResourceType type, String resourceName, ResourceWatcher watcher) { - assertThat(type.typeName()).isEqualTo(EDS); + assertThat(type.typeName()).isEqualTo("EDS"); assertThat(watchers).doesNotContainKey(resourceName); watchers.put(resourceName, (ResourceWatcher) watcher); } @@ -1186,7 +1185,7 @@ void watchXdsResource(XdsResourceType type, String void cancelXdsResourceWatch(XdsResourceType type, String resourceName, ResourceWatcher watcher) { - assertThat(type.typeName()).isEqualTo(EDS); + assertThat(type.typeName()).isEqualTo("EDS"); 
assertThat(watchers).containsKey(resourceName); watchers.remove(resourceName); } @@ -1326,10 +1325,11 @@ private final class FakeLoadBalancer extends LoadBalancer { } @Override - public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { + public boolean acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { addresses = resolvedAddresses.getAddresses(); config = resolvedAddresses.getLoadBalancingPolicyConfig(); attributes = resolvedAddresses.getAttributes(); + return true; } @Override diff --git a/xds/src/test/java/io/grpc/xds/ControlPlaneRule.java b/xds/src/test/java/io/grpc/xds/ControlPlaneRule.java new file mode 100644 index 00000000000..ea764e67e40 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/ControlPlaneRule.java @@ -0,0 +1,301 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.grpc.xds; + +import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_CDS; +import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_EDS; +import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_LDS; +import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_RDS; + +import com.google.common.collect.ImmutableMap; +import com.google.protobuf.Any; +import com.google.protobuf.Message; +import com.google.protobuf.UInt32Value; +import io.envoyproxy.envoy.config.cluster.v3.Cluster; +import io.envoyproxy.envoy.config.core.v3.Address; +import io.envoyproxy.envoy.config.core.v3.AggregatedConfigSource; +import io.envoyproxy.envoy.config.core.v3.ConfigSource; +import io.envoyproxy.envoy.config.core.v3.HealthStatus; +import io.envoyproxy.envoy.config.core.v3.SocketAddress; +import io.envoyproxy.envoy.config.core.v3.TrafficDirection; +import io.envoyproxy.envoy.config.endpoint.v3.ClusterLoadAssignment; +import io.envoyproxy.envoy.config.endpoint.v3.Endpoint; +import io.envoyproxy.envoy.config.endpoint.v3.LbEndpoint; +import io.envoyproxy.envoy.config.endpoint.v3.LocalityLbEndpoints; +import io.envoyproxy.envoy.config.listener.v3.ApiListener; +import io.envoyproxy.envoy.config.listener.v3.Filter; +import io.envoyproxy.envoy.config.listener.v3.FilterChain; +import io.envoyproxy.envoy.config.listener.v3.FilterChainMatch; +import io.envoyproxy.envoy.config.listener.v3.Listener; +import io.envoyproxy.envoy.config.route.v3.NonForwardingAction; +import io.envoyproxy.envoy.config.route.v3.Route; +import io.envoyproxy.envoy.config.route.v3.RouteAction; +import io.envoyproxy.envoy.config.route.v3.RouteConfiguration; +import io.envoyproxy.envoy.config.route.v3.RouteMatch; +import io.envoyproxy.envoy.config.route.v3.VirtualHost; +import io.envoyproxy.envoy.extensions.filters.http.router.v3.Router; +import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager; +import 
io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.HttpFilter; +import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.Rds; +import io.grpc.NameResolverRegistry; +import io.grpc.Server; +import io.grpc.netty.NettyServerBuilder; +import java.util.Collections; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.TimeUnit; +import java.util.logging.Level; +import java.util.logging.Logger; +import org.junit.rules.TestWatcher; +import org.junit.runner.Description; + +/** + * Starts a control plane server and sets up the test to use it. Initialized with a default + * configuration, but also provides methods for updating the configuration. + */ +public class ControlPlaneRule extends TestWatcher { + private static final Logger logger = Logger.getLogger(ControlPlaneRule.class.getName()); + + private static final String SCHEME = "test-xds"; + private static final String RDS_NAME = "route-config.googleapis.com"; + private static final String CLUSTER_NAME = "cluster0"; + private static final String EDS_NAME = "eds-service-0"; + private static final String SERVER_LISTENER_TEMPLATE_NO_REPLACEMENT = + "grpc/server?udpa.resource.listening_address="; + private static final String SERVER_HOST_NAME = "test-server"; + private static final String HTTP_CONNECTION_MANAGER_TYPE_URL = + "type.googleapis.com/envoy.extensions.filters.network.http_connection_manager.v3" + + ".HttpConnectionManager"; + + private Server server; + private XdsTestControlPlaneService controlPlaneService; + private XdsNameResolverProvider nameResolverProvider; + + /** + * Returns the test control plane service interface. + */ + public XdsTestControlPlaneService getService() { + return controlPlaneService; + } + + /** + * Returns the server instance. + */ + public Server getServer() { + return server; + } + + @Override protected void starting(Description description) { + // Start the control plane server. 
+ try { + controlPlaneService = new XdsTestControlPlaneService(); + NettyServerBuilder controlPlaneServerBuilder = NettyServerBuilder.forPort(0) + .addService(controlPlaneService); + server = controlPlaneServerBuilder.build().start(); + } catch (Exception e) { + throw new AssertionError("unable to start the control plane server", e); + } + + // Configure and register an xDS name resolver so that gRPC knows how to connect to the server. + nameResolverProvider = XdsNameResolverProvider.createForTest(SCHEME, + defaultBootstrapOverride()); + NameResolverRegistry.getDefaultRegistry().register(nameResolverProvider); + } + + @Override protected void finished(Description description) { + if (server != null) { + server.shutdownNow(); + try { + if (!server.awaitTermination(5, TimeUnit.SECONDS)) { + logger.log(Level.SEVERE, "Timed out waiting for server shutdown"); + } + } catch (InterruptedException e) { + throw new AssertionError("unable to shut down control plane server", e); + } + } + NameResolverRegistry.getDefaultRegistry().deregister(nameResolverProvider); + } + + /** + * For test purpose, use boostrapOverride to programmatically provide bootstrap info. 
+ */ + public Map defaultBootstrapOverride() { + return ImmutableMap.of( + "node", ImmutableMap.of( + "id", UUID.randomUUID().toString(), + "cluster", "cluster0"), + "xds_servers", Collections.singletonList( + + ImmutableMap.of( + "server_uri", "localhost:" + server.getPort(), + "channel_creds", Collections.singletonList( + ImmutableMap.of("type", "insecure") + ), + "server_features", Collections.singletonList("xds_v3") + ) + ), + "server_listener_resource_name_template", SERVER_LISTENER_TEMPLATE_NO_REPLACEMENT + ); + } + + void setLdsConfig(Listener serverListener, Listener clientListener) { + getService().setXdsConfig(ADS_TYPE_URL_LDS, + ImmutableMap.of(SERVER_LISTENER_TEMPLATE_NO_REPLACEMENT, serverListener, + SERVER_HOST_NAME, clientListener)); + } + + void setRdsConfig(RouteConfiguration routeConfiguration) { + getService().setXdsConfig(ADS_TYPE_URL_RDS, ImmutableMap.of(RDS_NAME, routeConfiguration)); + } + + void setCdsConfig(Cluster cluster) { + getService().setXdsConfig(ADS_TYPE_URL_CDS, + ImmutableMap.of(CLUSTER_NAME, cluster)); + } + + void setEdsConfig(ClusterLoadAssignment clusterLoadAssignment) { + getService().setXdsConfig(ADS_TYPE_URL_EDS, + ImmutableMap.of(EDS_NAME, clusterLoadAssignment)); + } + + /** + * Builds a new default RDS configuration. + */ + static RouteConfiguration buildRouteConfiguration(String authority) { + io.envoyproxy.envoy.config.route.v3.VirtualHost virtualHost = VirtualHost.newBuilder() + .addDomains(authority) + .addRoutes( + Route.newBuilder() + .setMatch( + RouteMatch.newBuilder().setPrefix("/").build()) + .setRoute( + RouteAction.newBuilder().setCluster(CLUSTER_NAME).build()).build()).build(); + return RouteConfiguration.newBuilder().setName(RDS_NAME).addVirtualHosts(virtualHost).build(); + } + + /** + * Builds a new default CDS configuration. 
+ */ + static Cluster buildCluster() { + return Cluster.newBuilder() + .setName(CLUSTER_NAME) + .setType(Cluster.DiscoveryType.EDS) + .setEdsClusterConfig( + Cluster.EdsClusterConfig.newBuilder() + .setServiceName(EDS_NAME) + .setEdsConfig( + ConfigSource.newBuilder() + .setAds(AggregatedConfigSource.newBuilder().build()) + .build()) + .build()) + .setLbPolicy(Cluster.LbPolicy.ROUND_ROBIN) + .build(); + } + + /** + * Builds a new default EDS configuration. + */ + static ClusterLoadAssignment buildClusterLoadAssignment(String hostName, int port) { + Address address = Address.newBuilder() + .setSocketAddress( + SocketAddress.newBuilder().setAddress(hostName).setPortValue(port).build()).build(); + LocalityLbEndpoints endpoints = LocalityLbEndpoints.newBuilder() + .setLoadBalancingWeight(UInt32Value.of(10)) + .setPriority(0) + .addLbEndpoints( + LbEndpoint.newBuilder() + .setEndpoint( + Endpoint.newBuilder().setAddress(address).build()) + .setHealthStatus(HealthStatus.HEALTHY) + .build()).build(); + return ClusterLoadAssignment.newBuilder() + .setClusterName(EDS_NAME) + .addEndpoints(endpoints) + .build(); + } + + /** + * Builds a new client listener. 
+ */ + static Listener buildClientListener(String name) { + HttpFilter httpFilter = HttpFilter.newBuilder() + .setName("terminal-filter") + .setTypedConfig(Any.pack(Router.newBuilder().build())) + .setIsOptional(true) + .build(); + ApiListener apiListener = ApiListener.newBuilder().setApiListener(Any.pack( + io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3 + .HttpConnectionManager.newBuilder() + .setRds( + Rds.newBuilder() + .setRouteConfigName(RDS_NAME) + .setConfigSource( + ConfigSource.newBuilder() + .setAds(AggregatedConfigSource.getDefaultInstance()))) + .addAllHttpFilters(Collections.singletonList(httpFilter)) + .build(), + HTTP_CONNECTION_MANAGER_TYPE_URL)).build(); + return Listener.newBuilder() + .setName(name) + .setApiListener(apiListener).build(); + } + + /** + * Builds a new server listener. + */ + static Listener buildServerListener() { + HttpFilter routerFilter = HttpFilter.newBuilder() + .setName("terminal-filter") + .setTypedConfig( + Any.pack(Router.newBuilder().build())) + .setIsOptional(true) + .build(); + VirtualHost virtualHost = io.envoyproxy.envoy.config.route.v3.VirtualHost.newBuilder() + .setName("virtual-host-0") + .addDomains("*") + .addRoutes( + Route.newBuilder() + .setMatch( + RouteMatch.newBuilder().setPrefix("/").build()) + .setNonForwardingAction(NonForwardingAction.newBuilder().build()) + .build()).build(); + RouteConfiguration routeConfig = RouteConfiguration.newBuilder() + .addVirtualHosts(virtualHost) + .build(); + io.envoyproxy.envoy.config.listener.v3.Filter filter = Filter.newBuilder() + .setName("network-filter-0") + .setTypedConfig( + Any.pack( + HttpConnectionManager.newBuilder() + .setRouteConfig(routeConfig) + .addAllHttpFilters(Collections.singletonList(routerFilter)) + .build())).build(); + FilterChainMatch filterChainMatch = FilterChainMatch.newBuilder() + .setSourceType(FilterChainMatch.ConnectionSourceType.ANY) + .build(); + FilterChain filterChain = FilterChain.newBuilder() + 
.setName("filter-chain-0") + .setFilterChainMatch(filterChainMatch) + .addFilters(filter) + .build(); + return Listener.newBuilder() + .setName(SERVER_LISTENER_TEMPLATE_NO_REPLACEMENT) + .setTrafficDirection(TrafficDirection.INBOUND) + .addFilterChains(filterChain) + .build(); + } +} diff --git a/xds/src/test/java/io/grpc/xds/CsdsServiceTest.java b/xds/src/test/java/io/grpc/xds/CsdsServiceTest.java index 8bf9981947f..5272a6d297e 100644 --- a/xds/src/test/java/io/grpc/xds/CsdsServiceTest.java +++ b/xds/src/test/java/io/grpc/xds/CsdsServiceTest.java @@ -17,10 +17,6 @@ package io.grpc.xds; import static com.google.common.truth.Truth.assertThat; -import static io.grpc.xds.AbstractXdsClient.ResourceType.CDS; -import static io.grpc.xds.AbstractXdsClient.ResourceType.EDS; -import static io.grpc.xds.AbstractXdsClient.ResourceType.LDS; -import static io.grpc.xds.AbstractXdsClient.ResourceType.RDS; import static org.junit.Assert.fail; import com.google.common.collect.ImmutableList; @@ -49,13 +45,13 @@ import io.grpc.internal.testing.StreamRecorder; import io.grpc.stub.StreamObserver; import io.grpc.testing.GrpcServerRule; -import io.grpc.xds.AbstractXdsClient.ResourceType; import io.grpc.xds.Bootstrapper.BootstrapInfo; import io.grpc.xds.Bootstrapper.ServerInfo; import io.grpc.xds.XdsClient.ResourceMetadata; import io.grpc.xds.XdsClient.ResourceMetadata.ResourceMetadataStatus; import io.grpc.xds.XdsNameResolverProvider.XdsClientPoolFactory; -import java.util.EnumMap; +import java.util.Collection; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.Callable; @@ -80,7 +76,11 @@ public class CsdsServiceTest { ServerInfo.create(SERVER_URI, InsecureChannelCredentials.create(), true))) .node(BOOTSTRAP_NODE) .build(); - private static final XdsClient XDS_CLIENT_NO_RESOURCES = new FakeXdsClient(); + private static final FakeXdsClient XDS_CLIENT_NO_RESOURCES = new FakeXdsClient(); + private static final XdsResourceType LDS = 
XdsListenerResource.getInstance(); + private static final XdsResourceType CDS = XdsClusterResource.getInstance(); + private static final XdsResourceType RDS = XdsRouteConfigureResource.getInstance(); + private static final XdsResourceType EDS = XdsEndpointResource.getInstance(); @RunWith(JUnit4.class) public static class ServiceTests { @@ -126,7 +126,7 @@ public void fetchClientConfig_invalidArgument() { public void fetchClientConfig_unexpectedException() { XdsClient throwingXdsClient = new FakeXdsClient() { @Override - ListenableFuture>> + ListenableFuture, Map>> getSubscribedResourcesMetadataSnapshot() { return Futures.immediateFailedFuture( new IllegalArgumentException("IllegalArgumentException")); @@ -150,12 +150,12 @@ public void fetchClientConfig_unexpectedException() { public void fetchClientConfig_interruptedException() { XdsClient throwingXdsClient = new FakeXdsClient() { @Override - ListenableFuture>> + ListenableFuture, Map>> getSubscribedResourcesMetadataSnapshot() { return Futures.submit( - new Callable>>() { + new Callable, Map>>() { @Override - public Map> call() { + public Map, Map> call() { Thread.currentThread().interrupt(); return null; } @@ -264,7 +264,7 @@ private void verifyResponse(ClientStatusResponse response) { assertThat(response.getConfigCount()).isEqualTo(1); ClientConfig clientConfig = response.getConfig(0); verifyClientConfigNode(clientConfig); - verifyClientConfigNoResources(clientConfig); + verifyClientConfigNoResources(XDS_CLIENT_NO_RESOURCES, clientConfig); } private void verifyRequestInvalidResponseStatus(Status status) { @@ -321,18 +321,29 @@ public void metadataStatusToClientStatus() { @Test public void getClientConfigForXdsClient_subscribedResourcesToGenericXdsConfig() throws InterruptedException { - ClientConfig clientConfig = CsdsService.getClientConfigForXdsClient(new FakeXdsClient() { + FakeXdsClient fakeXdsClient = new FakeXdsClient() { @Override - protected Map> + protected Map, Map> getSubscribedResourcesMetadata() { - 
return new ImmutableMap.Builder>() + return new ImmutableMap.Builder, Map>() .put(LDS, ImmutableMap.of("subscribedResourceName.LDS", METADATA_ACKED_LDS)) .put(RDS, ImmutableMap.of("subscribedResourceName.RDS", METADATA_ACKED_RDS)) .put(CDS, ImmutableMap.of("subscribedResourceName.CDS", METADATA_ACKED_CDS)) .put(EDS, ImmutableMap.of("subscribedResourceName.EDS", METADATA_ACKED_EDS)) .buildOrThrow(); } - }); + + @Override + public Map> getSubscribedResourceTypesWithTypeUrl() { + return ImmutableMap.of( + LDS.typeUrl(), LDS, + RDS.typeUrl(), RDS, + CDS.typeUrl(), CDS, + EDS.typeUrl(), EDS + ); + } + }; + ClientConfig clientConfig = CsdsService.getClientConfigForXdsClient(fakeXdsClient); verifyClientConfigNode(clientConfig); @@ -340,7 +351,8 @@ public void getClientConfigForXdsClient_subscribedResourcesToGenericXdsConfig() // is propagated to the correct resource types. int xdsConfigCount = clientConfig.getGenericXdsConfigsCount(); assertThat(xdsConfigCount).isEqualTo(4); - EnumMap configDumps = mapConfigDumps(clientConfig); + Map, GenericXdsConfig> configDumps = mapConfigDumps(fakeXdsClient, + clientConfig); assertThat(configDumps.keySet()).containsExactly(LDS, RDS, CDS, EDS); // LDS. @@ -373,7 +385,7 @@ public void getClientConfigForXdsClient_subscribedResourcesToGenericXdsConfig() public void getClientConfigForXdsClient_noSubscribedResources() throws InterruptedException { ClientConfig clientConfig = CsdsService.getClientConfigForXdsClient(XDS_CLIENT_NO_RESOURCES); verifyClientConfigNode(clientConfig); - verifyClientConfigNoResources(clientConfig); + verifyClientConfigNoResources(XDS_CLIENT_NO_RESOURCES, clientConfig); } } @@ -381,10 +393,11 @@ public void getClientConfigForXdsClient_noSubscribedResources() throws Interrupt * Assuming {@link MetadataToProtoTests} passes, and metadata converted to corresponding * config dumps correctly, perform a minimal verification of the general shape of ClientConfig. 
*/ - private static void verifyClientConfigNoResources(ClientConfig clientConfig) { + private static void verifyClientConfigNoResources(FakeXdsClient xdsClient, + ClientConfig clientConfig) { int xdsConfigCount = clientConfig.getGenericXdsConfigsCount(); assertThat(xdsConfigCount).isEqualTo(0); - EnumMap configDumps = mapConfigDumps(clientConfig); + Map, GenericXdsConfig> configDumps = mapConfigDumps(xdsClient, clientConfig); assertThat(configDumps).isEmpty(); } @@ -398,25 +411,28 @@ private static void verifyClientConfigNode(ClientConfig clientConfig) { assertThat(node).isEqualTo(BOOTSTRAP_NODE.toEnvoyProtoNode()); } - private static EnumMap mapConfigDumps(ClientConfig config) { - EnumMap xdsConfigMap = new EnumMap<>(ResourceType.class); + private static Map, GenericXdsConfig> mapConfigDumps(FakeXdsClient client, + ClientConfig config) { + Map, GenericXdsConfig> xdsConfigMap = new HashMap<>(); List xdsConfigList = config.getGenericXdsConfigsList(); for (GenericXdsConfig genericXdsConfig : xdsConfigList) { - ResourceType type = ResourceType.fromTypeUrl(genericXdsConfig.getTypeUrl()); - assertThat(type).isNotEqualTo(ResourceType.UNKNOWN); + XdsResourceType type = client.getSubscribedResourceTypesWithTypeUrl() + .get(genericXdsConfig.getTypeUrl()); + assertThat(type).isNotNull(); assertThat(xdsConfigMap).doesNotContainKey(type); xdsConfigMap.put(type, genericXdsConfig); } return xdsConfigMap; } - private static class FakeXdsClient extends XdsClient { - protected Map> getSubscribedResourcesMetadata() { + private static class FakeXdsClient extends XdsClient implements XdsClient.ResourceStore { + protected Map, Map> + getSubscribedResourcesMetadata() { return ImmutableMap.of(); } @Override - ListenableFuture>> + ListenableFuture, Map>> getSubscribedResourcesMetadataSnapshot() { return Futures.immediateFuture(getSubscribedResourcesMetadata()); } @@ -425,6 +441,18 @@ protected Map> getSubscribedResource BootstrapInfo getBootstrapInfo() { return BOOTSTRAP_INFO; } + + 
@Nullable + @Override + public Collection getSubscribedResources(ServerInfo serverInfo, + XdsResourceType type) { + return null; + } + + @Override + public Map> getSubscribedResourceTypesWithTypeUrl() { + return ImmutableMap.of(); + } } private static class FakeXdsClientPoolFactory implements XdsClientPoolFactory { diff --git a/xds/src/test/java/io/grpc/xds/DataPlaneRule.java b/xds/src/test/java/io/grpc/xds/DataPlaneRule.java new file mode 100644 index 00000000000..faa79444071 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/DataPlaneRule.java @@ -0,0 +1,173 @@ +/* + * Copyright 2022 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.grpc.xds; + +import io.grpc.ForwardingServerCall.SimpleForwardingServerCall; +import io.grpc.Grpc; +import io.grpc.InsecureChannelCredentials; +import io.grpc.InsecureServerCredentials; +import io.grpc.ManagedChannel; +import io.grpc.Metadata; +import io.grpc.Server; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.grpc.Status; +import io.grpc.stub.StreamObserver; +import io.grpc.testing.protobuf.SimpleRequest; +import io.grpc.testing.protobuf.SimpleResponse; +import io.grpc.testing.protobuf.SimpleServiceGrpc; +import java.net.InetSocketAddress; +import java.util.HashSet; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import java.util.logging.Level; +import java.util.logging.Logger; +import org.junit.rules.TestWatcher; +import org.junit.runner.Description; + +/** + * This rule creates a new server instance in the "data plane" that is configured by a "control + * plane" xDS server. + */ +public class DataPlaneRule extends TestWatcher { + private static final Logger logger = Logger.getLogger(DataPlaneRule.class.getName()); + + private static final String SERVER_HOST_NAME = "test-server"; + private static final String SCHEME = "test-xds"; + + private final ControlPlaneRule controlPlane; + private Server server; + private HashSet channels = new HashSet<>(); + + /** + * Creates a new {@link DataPlaneRule} that is connected to the given {@link ControlPlaneRule}. + */ + public DataPlaneRule(ControlPlaneRule controlPlane) { + this.controlPlane = controlPlane; + } + + /** + * Returns the server instance. + */ + public Server getServer() { + return server; + } + + /** + * Returns a newly created {@link ManagedChannel} to the server. 
+ */ + public ManagedChannel getManagedChannel() { + ManagedChannel channel = Grpc.newChannelBuilder(SCHEME + ":///" + SERVER_HOST_NAME, + InsecureChannelCredentials.create()).build(); + channels.add(channel); + return channel; + } + + @Override + protected void starting(Description description) { + // Let the control plane know about our new server. + controlPlane.setLdsConfig(ControlPlaneRule.buildServerListener(), + ControlPlaneRule.buildClientListener(SERVER_HOST_NAME) + ); + + // Start up the server. + try { + startServer(controlPlane.defaultBootstrapOverride()); + } catch (Exception e) { + throw new AssertionError("unable to start the data plane server", e); + } + + // Provide the rest of the configuration to the control plane. + controlPlane.setRdsConfig(ControlPlaneRule.buildRouteConfiguration(SERVER_HOST_NAME)); + controlPlane.setCdsConfig(ControlPlaneRule.buildCluster()); + InetSocketAddress edsInetSocketAddress = (InetSocketAddress) server.getListenSockets().get(0); + controlPlane.setEdsConfig( + ControlPlaneRule.buildClusterLoadAssignment(edsInetSocketAddress.getHostName(), + edsInetSocketAddress.getPort())); + } + + @Override + protected void finished(Description description) { + if (server != null) { + // Shut down any lingering open channels to the server. + for (ManagedChannel channel : channels) { + if (!channel.isShutdown()) { + channel.shutdownNow(); + } + } + + // Shut down the server itself. 
+ server.shutdownNow(); + try { + if (!server.awaitTermination(5, TimeUnit.SECONDS)) { + logger.log(Level.SEVERE, "Timed out waiting for server shutdown"); + } + } catch (InterruptedException e) { + throw new AssertionError("unable to shut down data plane server", e); + } + } + } + + private void startServer(Map bootstrapOverride) throws Exception { + ServerInterceptor metadataInterceptor = new ServerInterceptor() { + @Override + public ServerCall.Listener interceptCall(ServerCall call, + Metadata requestHeaders, ServerCallHandler next) { + logger.fine("Received following metadata: " + requestHeaders); + + // Make a copy of the headers so that it can be read in a thread-safe manner when copying + // it to the response headers. + Metadata headersToReturn = new Metadata(); + headersToReturn.merge(requestHeaders); + + return next.startCall(new SimpleForwardingServerCall(call) { + @Override + public void sendHeaders(Metadata responseHeaders) { + responseHeaders.merge(headersToReturn); + super.sendHeaders(responseHeaders); + } + + @Override + public void close(Status status, Metadata trailers) { + super.close(status, trailers); + } + }, requestHeaders); + } + }; + + SimpleServiceGrpc.SimpleServiceImplBase simpleServiceImpl = + new SimpleServiceGrpc.SimpleServiceImplBase() { + @Override + public void unaryRpc( + SimpleRequest request, StreamObserver responseObserver) { + SimpleResponse response = + SimpleResponse.newBuilder().setResponseMessage("Hi, xDS!").build(); + responseObserver.onNext(response); + responseObserver.onCompleted(); + } + }; + + XdsServerBuilder serverBuilder = XdsServerBuilder.forPort( + 0, InsecureServerCredentials.create()) + .addService(simpleServiceImpl) + .intercept(metadataInterceptor) + .overrideBootstrapForTest(bootstrapOverride); + server = serverBuilder.build().start(); + logger.log(Level.FINE, "data plane server started"); + } +} diff --git a/xds/src/test/java/io/grpc/xds/FakeControlPlaneXdsIntegrationTest.java 
b/xds/src/test/java/io/grpc/xds/FakeControlPlaneXdsIntegrationTest.java index dfe4fb4953f..3f0927fb8d3 100644 --- a/xds/src/test/java/io/grpc/xds/FakeControlPlaneXdsIntegrationTest.java +++ b/xds/src/test/java/io/grpc/xds/FakeControlPlaneXdsIntegrationTest.java @@ -18,48 +18,16 @@ package io.grpc.xds; import static com.google.common.truth.Truth.assertThat; -import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_CDS; -import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_EDS; -import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_LDS; -import static io.grpc.xds.XdsTestControlPlaneService.ADS_TYPE_URL_RDS; import static org.junit.Assert.assertEquals; import com.github.xds.type.v3.TypedStruct; -import com.google.common.collect.ImmutableMap; import com.google.protobuf.Any; -import com.google.protobuf.Message; import com.google.protobuf.Struct; -import com.google.protobuf.UInt32Value; import com.google.protobuf.Value; -import io.envoyproxy.envoy.config.cluster.v3.Cluster; +import io.envoyproxy.envoy.config.cluster.v3.Cluster.LbPolicy; import io.envoyproxy.envoy.config.cluster.v3.LoadBalancingPolicy; import io.envoyproxy.envoy.config.cluster.v3.LoadBalancingPolicy.Policy; -import io.envoyproxy.envoy.config.core.v3.Address; -import io.envoyproxy.envoy.config.core.v3.AggregatedConfigSource; -import io.envoyproxy.envoy.config.core.v3.ConfigSource; -import io.envoyproxy.envoy.config.core.v3.HealthStatus; -import io.envoyproxy.envoy.config.core.v3.SocketAddress; -import io.envoyproxy.envoy.config.core.v3.TrafficDirection; import io.envoyproxy.envoy.config.core.v3.TypedExtensionConfig; -import io.envoyproxy.envoy.config.endpoint.v3.ClusterLoadAssignment; -import io.envoyproxy.envoy.config.endpoint.v3.Endpoint; -import io.envoyproxy.envoy.config.endpoint.v3.LbEndpoint; -import io.envoyproxy.envoy.config.endpoint.v3.LocalityLbEndpoints; -import io.envoyproxy.envoy.config.listener.v3.ApiListener; -import 
io.envoyproxy.envoy.config.listener.v3.Filter; -import io.envoyproxy.envoy.config.listener.v3.FilterChain; -import io.envoyproxy.envoy.config.listener.v3.FilterChainMatch; -import io.envoyproxy.envoy.config.listener.v3.Listener; -import io.envoyproxy.envoy.config.route.v3.NonForwardingAction; -import io.envoyproxy.envoy.config.route.v3.Route; -import io.envoyproxy.envoy.config.route.v3.RouteAction; -import io.envoyproxy.envoy.config.route.v3.RouteConfiguration; -import io.envoyproxy.envoy.config.route.v3.RouteMatch; -import io.envoyproxy.envoy.config.route.v3.VirtualHost; -import io.envoyproxy.envoy.extensions.filters.http.router.v3.Router; -import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager; -import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.HttpFilter; -import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.Rds; import io.envoyproxy.envoy.extensions.load_balancing_policies.wrr_locality.v3.WrrLocality; import io.grpc.CallOptions; import io.grpc.Channel; @@ -67,35 +35,16 @@ import io.grpc.ClientInterceptor; import io.grpc.ForwardingClientCall.SimpleForwardingClientCall; import io.grpc.ForwardingClientCallListener; -import io.grpc.ForwardingServerCall.SimpleForwardingServerCall; -import io.grpc.Grpc; -import io.grpc.InsecureChannelCredentials; -import io.grpc.InsecureServerCredentials; import io.grpc.LoadBalancerRegistry; import io.grpc.ManagedChannel; import io.grpc.Metadata; import io.grpc.MethodDescriptor; -import io.grpc.NameResolverRegistry; -import io.grpc.Server; -import io.grpc.ServerCall; -import io.grpc.ServerCallHandler; -import io.grpc.ServerInterceptor; -import io.grpc.Status; -import io.grpc.netty.NettyServerBuilder; -import io.grpc.stub.StreamObserver; import io.grpc.testing.protobuf.SimpleRequest; import io.grpc.testing.protobuf.SimpleResponse; import io.grpc.testing.protobuf.SimpleServiceGrpc; -import java.net.InetSocketAddress; -import 
java.util.Collections; -import java.util.Map; -import java.util.UUID; -import java.util.concurrent.TimeUnit; -import java.util.logging.Level; -import java.util.logging.Logger; -import org.junit.After; -import org.junit.Before; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.RuleChain; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -106,104 +55,24 @@ @RunWith(JUnit4.class) public class FakeControlPlaneXdsIntegrationTest { - private static final Logger logger = - Logger.getLogger(FakeControlPlaneXdsIntegrationTest.class.getName()); - private static final String SCHEME = "test-xds"; - private static final String SERVER_LISTENER_TEMPLATE_NO_REPLACEMENT = - "grpc/server?udpa.resource.listening_address="; - private static final String RDS_NAME = "route-config.googleapis.com"; - private static final String CLUSTER_NAME = "cluster0"; - private static final String EDS_NAME = "eds-service-0"; - private static final String HTTP_CONNECTION_MANAGER_TYPE_URL = - "type.googleapis.com/envoy.extensions.filters.network.http_connection_manager.v3" - + ".HttpConnectionManager"; - - private Server server; - private Server controlPlane; - private XdsTestControlPlaneService controlPlaneService; - private XdsNameResolverProvider nameResolverProvider; - private MetadataLoadBalancerProvider metadataLoadBalancerProvider; - - protected int testServerPort = 0; - protected int controlPlaneServicePort; - protected SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub; - - /** - * For test purpose, use boostrapOverride to programmatically provide bootstrap info. 
- */ - private Map defaultBootstrapOverride() { - return ImmutableMap.of( - "node", ImmutableMap.of( - "id", UUID.randomUUID().toString(), - "cluster", "cluster0"), - "xds_servers", Collections.singletonList( - - ImmutableMap.of( - "server_uri", "localhost:" + controlPlaneServicePort, - "channel_creds", Collections.singletonList( - ImmutableMap.of("type", "insecure") - ), - "server_features", Collections.singletonList("xds_v3") - ) - ), - "server_listener_resource_name_template", SERVER_LISTENER_TEMPLATE_NO_REPLACEMENT - ); - } + public ControlPlaneRule controlPlane; + public DataPlaneRule dataPlane; /** - * 1. Start control plane server and get control plane port. 2. Start xdsServer using no - * replacement server template, because we do not know the server port yet. Then get the server - * port. 3. Update control plane config using the port in 2 for necessary rds and eds resources to - * set up client and server communication for test cases. + * The {@link ControlPlaneRule} should run before the {@link DataPlaneRule}. 
*/ - @Before - public void setUp() throws Exception { - startControlPlane(); - nameResolverProvider = XdsNameResolverProvider.createForTest(SCHEME, - defaultBootstrapOverride()); - NameResolverRegistry.getDefaultRegistry().register(nameResolverProvider); - metadataLoadBalancerProvider = new MetadataLoadBalancerProvider(); - LoadBalancerRegistry.getDefaultRegistry().register(metadataLoadBalancerProvider); - } - - @After - public void tearDown() throws Exception { - if (server != null) { - server.shutdownNow(); - if (!server.awaitTermination(5, TimeUnit.SECONDS)) { - logger.log(Level.SEVERE, "Timed out waiting for server shutdown"); - } - } - if (controlPlane != null) { - controlPlane.shutdownNow(); - if (!controlPlane.awaitTermination(5, TimeUnit.SECONDS)) { - logger.log(Level.SEVERE, "Timed out waiting for server shutdown"); - } - } - NameResolverRegistry.getDefaultRegistry().deregister(nameResolverProvider); - LoadBalancerRegistry.getDefaultRegistry().deregister(metadataLoadBalancerProvider); + @Rule + public RuleChain ruleChain() { + controlPlane = new ControlPlaneRule(); + dataPlane = new DataPlaneRule(controlPlane); + return RuleChain.outerRule(controlPlane).around(dataPlane); } @Test public void pingPong() throws Exception { - String tcpListenerName = SERVER_LISTENER_TEMPLATE_NO_REPLACEMENT; - String serverHostName = "test-server"; - controlPlaneService.setXdsConfig(ADS_TYPE_URL_LDS, ImmutableMap.of( - tcpListenerName, serverListener(tcpListenerName), - serverHostName, clientListener(serverHostName) - )); - startServer(defaultBootstrapOverride()); - controlPlaneService.setXdsConfig(ADS_TYPE_URL_RDS, - ImmutableMap.of(RDS_NAME, rds(serverHostName))); - controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, - ImmutableMap.of(CLUSTER_NAME, cds())); - InetSocketAddress edsInetSocketAddress = (InetSocketAddress) server.getListenSockets().get(0); - controlPlaneService.setXdsConfig(ADS_TYPE_URL_EDS, - ImmutableMap.of(EDS_NAME, eds(edsInetSocketAddress.getHostName(), - 
edsInetSocketAddress.getPort()))); - ManagedChannel channel = Grpc.newChannelBuilder(SCHEME + ":///" + serverHostName, - InsecureChannelCredentials.create()).build(); - blockingStub = SimpleServiceGrpc.newBlockingStub(channel); + ManagedChannel channel = dataPlane.getManagedChannel(); + SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = SimpleServiceGrpc.newBlockingStub( + channel); SimpleRequest request = SimpleRequest.newBuilder() .build(); SimpleResponse goldenResponse = SimpleResponse.newBuilder() @@ -214,59 +83,51 @@ serverHostName, clientListener(serverHostName) @Test public void pingPong_metadataLoadBalancer() throws Exception { - String tcpListenerName = SERVER_LISTENER_TEMPLATE_NO_REPLACEMENT; - String serverHostName = "test-server"; - controlPlaneService.setXdsConfig(ADS_TYPE_URL_LDS, ImmutableMap.of( - tcpListenerName, serverListener(tcpListenerName), - serverHostName, clientListener(serverHostName) - )); - startServer(defaultBootstrapOverride()); - controlPlaneService.setXdsConfig(ADS_TYPE_URL_RDS, - ImmutableMap.of(RDS_NAME, rds(serverHostName))); - - // Use the LoadBalancingPolicy to configure a custom LB that adds a header to server calls. 
- Policy metadataLbPolicy = Policy.newBuilder().setTypedExtensionConfig( - TypedExtensionConfig.newBuilder().setTypedConfig(Any.pack( - TypedStruct.newBuilder().setTypeUrl("type.googleapis.com/test.MetadataLoadBalancer") - .setValue(Struct.newBuilder() - .putFields("metadataKey", Value.newBuilder().setStringValue("foo").build()) - .putFields("metadataValue", Value.newBuilder().setStringValue("bar").build())) - .build()))).build(); - Policy wrrLocalityPolicy = Policy.newBuilder() - .setTypedExtensionConfig(TypedExtensionConfig.newBuilder().setTypedConfig( - Any.pack(WrrLocality.newBuilder().setEndpointPickingPolicy( - LoadBalancingPolicy.newBuilder().addPolicies(metadataLbPolicy)).build()))).build(); - controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, - ImmutableMap.of(CLUSTER_NAME, cds().toBuilder().setLoadBalancingPolicy( - LoadBalancingPolicy.newBuilder() - .addPolicies(wrrLocalityPolicy)).build())); - - InetSocketAddress edsInetSocketAddress = (InetSocketAddress) server.getListenSockets().get(0); - controlPlaneService.setXdsConfig(ADS_TYPE_URL_EDS, - ImmutableMap.of(EDS_NAME, eds(edsInetSocketAddress.getHostName(), - edsInetSocketAddress.getPort()))); - ManagedChannel channel = Grpc.newChannelBuilder(SCHEME + ":///" + serverHostName, - InsecureChannelCredentials.create()).build(); - ResponseHeaderClientInterceptor responseHeaderInterceptor - = new ResponseHeaderClientInterceptor(); - - // We add an interceptor to catch the response headers from the server. - blockingStub = SimpleServiceGrpc.newBlockingStub(channel) - .withInterceptors(responseHeaderInterceptor); - SimpleRequest request = SimpleRequest.newBuilder() - .build(); - SimpleResponse goldenResponse = SimpleResponse.newBuilder() - .setResponseMessage("Hi, xDS!") - .build(); - assertEquals(goldenResponse, blockingStub.unaryRpc(request)); - - // Make sure we got back the header we configured the LB with. 
- assertThat(responseHeaderInterceptor.reponseHeaders.get( - Metadata.Key.of("foo", Metadata.ASCII_STRING_MARSHALLER))).isEqualTo("bar"); + MetadataLoadBalancerProvider metadataLbProvider = new MetadataLoadBalancerProvider(); + try { + LoadBalancerRegistry.getDefaultRegistry().register(metadataLbProvider); + + // Use the LoadBalancingPolicy to configure a custom LB that adds a header to server calls. + Policy metadataLbPolicy = Policy.newBuilder().setTypedExtensionConfig( + TypedExtensionConfig.newBuilder().setTypedConfig(Any.pack( + TypedStruct.newBuilder().setTypeUrl("type.googleapis.com/test.MetadataLoadBalancer") + .setValue(Struct.newBuilder() + .putFields("metadataKey", Value.newBuilder().setStringValue("foo").build()) + .putFields("metadataValue", Value.newBuilder().setStringValue("bar").build())) + .build()))).build(); + Policy wrrLocalityPolicy = Policy.newBuilder() + .setTypedExtensionConfig(TypedExtensionConfig.newBuilder().setTypedConfig( + Any.pack(WrrLocality.newBuilder().setEndpointPickingPolicy( + LoadBalancingPolicy.newBuilder().addPolicies(metadataLbPolicy)).build()))) + .build(); + controlPlane.setCdsConfig( + ControlPlaneRule.buildCluster().toBuilder().setLoadBalancingPolicy( + LoadBalancingPolicy.newBuilder() + .addPolicies(wrrLocalityPolicy)).build()); + + ResponseHeaderClientInterceptor responseHeaderInterceptor + = new ResponseHeaderClientInterceptor(); + + // We add an interceptor to catch the response headers from the server. + SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = SimpleServiceGrpc.newBlockingStub( + dataPlane.getManagedChannel()).withInterceptors(responseHeaderInterceptor); + SimpleRequest request = SimpleRequest.newBuilder() + .build(); + SimpleResponse goldenResponse = SimpleResponse.newBuilder() + .setResponseMessage("Hi, xDS!") + .build(); + assertEquals(goldenResponse, blockingStub.unaryRpc(request)); + + // Make sure we got back the header we configured the LB with. 
+ assertThat(responseHeaderInterceptor.reponseHeaders.get( + Metadata.Key.of("foo", Metadata.ASCII_STRING_MARSHALLER))).isEqualTo("bar"); + } finally { + LoadBalancerRegistry.getDefaultRegistry().deregister(metadataLbProvider); + } } // Captures response headers from the server. - private class ResponseHeaderClientInterceptor implements ClientInterceptor { + private static class ResponseHeaderClientInterceptor implements ClientInterceptor { Metadata reponseHeaders; @Override @@ -292,173 +153,23 @@ public void onHeaders(Metadata headers) { } } - private void startServer(Map bootstrapOverride) throws Exception { - ServerInterceptor metadataInterceptor = new ServerInterceptor() { - @Override - public ServerCall.Listener interceptCall(ServerCall call, - Metadata requestHeaders, ServerCallHandler next) { - logger.fine("Received following metadata: " + requestHeaders); - - // Make a copy of the headers so that it can be read in a thread-safe manner when copying - // it to the response headers. 
- Metadata headersToReturn = new Metadata(); - headersToReturn.merge(requestHeaders); - - return next.startCall(new SimpleForwardingServerCall(call) { - @Override - public void sendHeaders(Metadata responseHeaders) { - responseHeaders.merge(headersToReturn); - super.sendHeaders(responseHeaders); - } - - @Override - public void close(Status status, Metadata trailers) { - super.close(status, trailers); - } - }, requestHeaders); - } - }; - - SimpleServiceGrpc.SimpleServiceImplBase simpleServiceImpl = - new SimpleServiceGrpc.SimpleServiceImplBase() { - @Override - public void unaryRpc( - SimpleRequest request, StreamObserver responseObserver) { - SimpleResponse response = - SimpleResponse.newBuilder().setResponseMessage("Hi, xDS!").build(); - responseObserver.onNext(response); - responseObserver.onCompleted(); - } - }; - - XdsServerBuilder serverBuilder = XdsServerBuilder.forPort( - 0, InsecureServerCredentials.create()) - .addService(simpleServiceImpl) - .intercept(metadataInterceptor) - .overrideBootstrapForTest(bootstrapOverride); - server = serverBuilder.build().start(); - testServerPort = server.getPort(); - logger.log(Level.FINE, "server started"); - } - - private void startControlPlane() throws Exception { - controlPlaneService = new XdsTestControlPlaneService(); - NettyServerBuilder controlPlaneServerBuilder = - NettyServerBuilder.forPort(0) - .addService(controlPlaneService); - controlPlane = controlPlaneServerBuilder.build().start(); - controlPlaneServicePort = controlPlane.getPort(); - } - - private static Listener clientListener(String name) { - HttpFilter httpFilter = HttpFilter.newBuilder() - .setName("terminal-filter") - .setTypedConfig(Any.pack(Router.newBuilder().build())) - .setIsOptional(true) - .build(); - ApiListener apiListener = ApiListener.newBuilder().setApiListener(Any.pack( - io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3 - .HttpConnectionManager.newBuilder() - .setRds( - Rds.newBuilder() - 
.setRouteConfigName(RDS_NAME) - .setConfigSource( - ConfigSource.newBuilder() - .setAds(AggregatedConfigSource.getDefaultInstance()))) - .addAllHttpFilters(Collections.singletonList(httpFilter)) - .build(), - HTTP_CONNECTION_MANAGER_TYPE_URL)).build(); - return Listener.newBuilder() - .setName(name) - .setApiListener(apiListener).build(); - } - - private static Listener serverListener(String name) { - HttpFilter routerFilter = HttpFilter.newBuilder() - .setName("terminal-filter") - .setTypedConfig( - Any.pack(Router.newBuilder().build())) - .setIsOptional(true) - .build(); - VirtualHost virtualHost = io.envoyproxy.envoy.config.route.v3.VirtualHost.newBuilder() - .setName("virtual-host-0") - .addDomains("*") - .addRoutes( - Route.newBuilder() - .setMatch( - RouteMatch.newBuilder().setPrefix("/").build()) - .setNonForwardingAction(NonForwardingAction.newBuilder().build()) - .build()).build(); - RouteConfiguration routeConfig = RouteConfiguration.newBuilder() - .addVirtualHosts(virtualHost) - .build(); - Filter filter = Filter.newBuilder() - .setName("network-filter-0") - .setTypedConfig( - Any.pack( - HttpConnectionManager.newBuilder() - .setRouteConfig(routeConfig) - .addAllHttpFilters(Collections.singletonList(routerFilter)) - .build())).build(); - FilterChainMatch filterChainMatch = FilterChainMatch.newBuilder() - .setSourceType(FilterChainMatch.ConnectionSourceType.ANY) - .build(); - FilterChain filterChain = FilterChain.newBuilder() - .setName("filter-chain-0") - .setFilterChainMatch(filterChainMatch) - .addFilters(filter) - .build(); - return Listener.newBuilder() - .setName(name) - .setTrafficDirection(TrafficDirection.INBOUND) - .addFilterChains(filterChain) - .build(); - } - - private static RouteConfiguration rds(String authority) { - VirtualHost virtualHost = VirtualHost.newBuilder() - .addDomains(authority) - .addRoutes( - Route.newBuilder() - .setMatch( - RouteMatch.newBuilder().setPrefix("/").build()) - .setRoute( - 
RouteAction.newBuilder().setCluster(CLUSTER_NAME).build()).build()).build(); - return RouteConfiguration.newBuilder().setName(RDS_NAME).addVirtualHosts(virtualHost).build(); - } - - private static Cluster cds() { - return Cluster.newBuilder() - .setName(CLUSTER_NAME) - .setType(Cluster.DiscoveryType.EDS) - .setEdsClusterConfig( - Cluster.EdsClusterConfig.newBuilder() - .setServiceName(EDS_NAME) - .setEdsConfig( - ConfigSource.newBuilder() - .setAds(AggregatedConfigSource.newBuilder().build()) - .build()) - .build()) - .setLbPolicy(Cluster.LbPolicy.ROUND_ROBIN) + /** + * Basic test to make sure RING_HASH configuration works. + */ + @Test + public void pingPong_ringHash() { + controlPlane.setCdsConfig( + ControlPlaneRule.buildCluster().toBuilder() + .setLbPolicy(LbPolicy.RING_HASH).build()); + + ManagedChannel channel = dataPlane.getManagedChannel(); + SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = SimpleServiceGrpc.newBlockingStub( + channel); + SimpleRequest request = SimpleRequest.newBuilder() .build(); - } - - private static ClusterLoadAssignment eds(String hostName, int port) { - Address address = Address.newBuilder() - .setSocketAddress( - SocketAddress.newBuilder().setAddress(hostName).setPortValue(port).build()).build(); - LocalityLbEndpoints endpoints = LocalityLbEndpoints.newBuilder() - .setLoadBalancingWeight(UInt32Value.of(10)) - .setPriority(0) - .addLbEndpoints( - LbEndpoint.newBuilder() - .setEndpoint( - Endpoint.newBuilder().setAddress(address).build()) - .setHealthStatus(HealthStatus.HEALTHY) - .build()).build(); - return ClusterLoadAssignment.newBuilder() - .setClusterName(EDS_NAME) - .addEndpoints(endpoints) + SimpleResponse goldenResponse = SimpleResponse.newBuilder() + .setResponseMessage("Hi, xDS!") .build(); + assertEquals(goldenResponse, blockingStub.unaryRpc(request)); } } diff --git a/xds/src/test/java/io/grpc/xds/LeastRequestLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/LeastRequestLoadBalancerTest.java index 
f3d9acda234..e7a3a28e6aa 100644 --- a/xds/src/test/java/io/grpc/xds/LeastRequestLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/LeastRequestLoadBalancerTest.java @@ -155,8 +155,9 @@ public void tearDown() throws Exception { @Test public void pickAfterResolved() throws Exception { final Subchannel readySubchannel = subchannels.values().iterator().next(); - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(affinity).build()); + assertThat(addressesAccepted).isTrue(); deliverSubchannelState(readySubchannel, ConnectivityStateInfo.forNonError(READY)); verify(mockHelper, times(3)).createSubchannel(createArgsCaptor.capture()); @@ -206,9 +207,10 @@ public void pickAfterResolvedUpdatedHosts() throws Exception { InOrder inOrder = inOrder(mockHelper); - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(currentServers).setAttributes(affinity) .build()); + assertThat(addressesAccepted).isTrue(); inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); @@ -228,8 +230,9 @@ public void pickAfterResolvedUpdatedHosts() throws Exception { // This time with Attributes List latestServers = Lists.newArrayList(oldEag2, newEag); - loadBalancer.handleResolvedAddresses( + addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(latestServers).setAttributes(affinity).build()); + assertThat(addressesAccepted).isTrue(); verify(newSubchannel, times(1)).requestConnection(); verify(oldSubchannel, times(1)).updateAddresses(Arrays.asList(oldEag2)); @@ -247,25 +250,16 @@ public void pickAfterResolvedUpdatedHosts() throws Exception { picker = pickerCaptor.getValue(); assertThat(getList(picker)).containsExactly(oldSubchannel, newSubchannel); - // test going from non-empty to empty - 
loadBalancer.handleResolvedAddresses( - ResolvedAddresses.newBuilder() - .setAddresses(Collections.emptyList()) - .setAttributes(affinity) - .build()); - - inOrder.verify(mockHelper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); - assertEquals(PickResult.withNoResult(), pickerCaptor.getValue().pickSubchannel(mockArgs)); - verifyNoMoreInteractions(mockHelper); } @Test public void pickAfterStateChange() throws Exception { InOrder inOrder = inOrder(mockHelper); - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY) .build()); + assertThat(addressesAccepted).isTrue(); Subchannel subchannel = loadBalancer.getSubchannels().iterator().next(); Ref subchannelStateInfo = subchannel.getAttributes().get( STATE_INFO); @@ -305,9 +299,10 @@ public void pickAfterConfigChange() { final LeastRequestConfig oldConfig = new LeastRequestConfig(4); final LeastRequestConfig newConfig = new LeastRequestConfig(6); final Subchannel readySubchannel = subchannels.values().iterator().next(); - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(affinity) .setLoadBalancingPolicyConfig(oldConfig).build()); + assertThat(addressesAccepted).isTrue(); deliverSubchannelState(readySubchannel, ConnectivityStateInfo.forNonError(READY)); verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); verify(mockHelper, times(2)) @@ -317,9 +312,10 @@ public void pickAfterConfigChange() { pickerCaptor.getValue().pickSubchannel(mockArgs); verify(mockRandom, times(oldConfig.choiceCount)).nextInt(1); - loadBalancer.handleResolvedAddresses( + addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(affinity) 
.setLoadBalancingPolicyConfig(newConfig).build()); + assertThat(addressesAccepted).isTrue(); verify(mockHelper, times(3)) .updateBalancingState(any(ConnectivityState.class), pickerCaptor.capture()); @@ -332,9 +328,10 @@ public void pickAfterConfigChange() { @Test public void ignoreShutdownSubchannelStateChange() { InOrder inOrder = inOrder(mockHelper); - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY) .build()); + assertThat(addressesAccepted).isTrue(); inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); loadBalancer.shutdown(); @@ -351,9 +348,10 @@ public void ignoreShutdownSubchannelStateChange() { @Test public void stayTransientFailureUntilReady() { InOrder inOrder = inOrder(mockHelper); - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY) .build()); + assertThat(addressesAccepted).isTrue(); inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); @@ -389,9 +387,10 @@ public void stayTransientFailureUntilReady() { @Test public void refreshNameResolutionWhenSubchannelConnectionBroken() { InOrder inOrder = inOrder(mockHelper); - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY) .build()); + assertThat(addressesAccepted).isTrue(); verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); @@ -419,10 +418,11 @@ public void refreshNameResolutionWhenSubchannelConnectionBroken() { public void pickerLeastRequest() throws Exception { int choiceCount = 2; // This should add 
inFlight counters to all subchannels. - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY) .setLoadBalancingPolicyConfig(new LeastRequestConfig(choiceCount)) .build()); + assertThat(addressesAccepted).isTrue(); assertEquals(3, loadBalancer.getSubchannels().size()); @@ -505,10 +505,11 @@ public void nameResolutionErrorWithNoChannels() throws Exception { public void nameResolutionErrorWithActiveChannels() throws Exception { int choiceCount = 8; final Subchannel readySubchannel = subchannels.values().iterator().next(); - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setLoadBalancingPolicyConfig(new LeastRequestConfig(choiceCount)) .setAddresses(servers).setAttributes(affinity).build()); + assertThat(addressesAccepted).isTrue(); deliverSubchannelState(readySubchannel, ConnectivityStateInfo.forNonError(READY)); loadBalancer.handleNameResolutionError(Status.NOT_FOUND.withDescription("nameResolutionError")); @@ -538,9 +539,10 @@ public void subchannelStateIsolation() throws Exception { Subchannel sc2 = subchannelIterator.next(); Subchannel sc3 = subchannelIterator.next(); - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(Attributes.EMPTY) .build()); + assertThat(addressesAccepted).isTrue(); verify(sc1, times(1)).requestConnection(); verify(sc2, times(1)).requestConnection(); verify(sc3, times(1)).requestConnection(); @@ -613,6 +615,15 @@ public void internalPickerComparisons() { assertFalse(ready5.isEquivalentTo(ready6)); } + @Test + public void emptyAddresses() { + assertThat(loadBalancer.acceptResolvedAddresses( + ResolvedAddresses.newBuilder() + .setAddresses(Collections.emptyList()) + .setAttributes(affinity) + 
.build())).isFalse(); + } + private static List getList(SubchannelPicker picker) { return picker instanceof ReadyPicker ? ((ReadyPicker) picker).getList() : Collections.emptyList(); diff --git a/xds/src/test/java/io/grpc/xds/LoadBalancerConfigFactoryTest.java b/xds/src/test/java/io/grpc/xds/LoadBalancerConfigFactoryTest.java index 75aaa00452b..c7217cb45e3 100644 --- a/xds/src/test/java/io/grpc/xds/LoadBalancerConfigFactoryTest.java +++ b/xds/src/test/java/io/grpc/xds/LoadBalancerConfigFactoryTest.java @@ -47,7 +47,7 @@ import io.grpc.internal.JsonUtil; import io.grpc.internal.ServiceConfigUtil; import io.grpc.internal.ServiceConfigUtil.LbConfig; -import io.grpc.xds.ClientXdsClient.ResourceInvalidException; +import io.grpc.xds.XdsClientImpl.ResourceInvalidException; import java.util.List; import org.junit.After; import org.junit.Test; diff --git a/xds/src/test/java/io/grpc/xds/PriorityLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/PriorityLoadBalancerTest.java index 6b1fae48f1d..a005f40fad7 100644 --- a/xds/src/test/java/io/grpc/xds/PriorityLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/PriorityLoadBalancerTest.java @@ -22,6 +22,7 @@ import static io.grpc.ConnectivityState.READY; import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; import static io.grpc.xds.XdsSubchannelPickers.BUFFER_PICKER; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.atLeastOnce; @@ -686,6 +687,37 @@ public void raceBetweenShutdownAndChildLbBalancingStateUpdate() { verifyNoMoreInteractions(helper); } + @Test + public void noDuplicateOverallBalancingStateUpdate() { + FakeLoadBalancerProvider fakeLbProvider = new FakeLoadBalancerProvider(); + + PriorityChildConfig priorityChildConfig0 = + new PriorityChildConfig(new PolicySelection(fakeLbProvider, new Object()), true); + PriorityChildConfig priorityChildConfig1 = + new PriorityChildConfig(new 
PolicySelection(fakeLbProvider, new Object()), false); + PriorityLbConfig priorityLbConfig = + new PriorityLbConfig( + ImmutableMap.of("p0", priorityChildConfig0), + ImmutableList.of("p0")); + priorityLb.handleResolvedAddresses( + ResolvedAddresses.newBuilder() + .setAddresses(ImmutableList.of()) + .setLoadBalancingPolicyConfig(priorityLbConfig) + .build()); + + priorityLbConfig = + new PriorityLbConfig( + ImmutableMap.of("p0", priorityChildConfig0, "p1", priorityChildConfig1), + ImmutableList.of("p0", "p1")); + priorityLb.handleResolvedAddresses( + ResolvedAddresses.newBuilder() + .setAddresses(ImmutableList.of()) + .setLoadBalancingPolicyConfig(priorityLbConfig) + .build()); + + verify(helper, times(6)).updateBalancingState(any(), any()); + } + private void assertLatestConnectivityState(ConnectivityState expectedState) { verify(helper, atLeastOnce()) .updateBalancingState(connectivityStateCaptor.capture(), pickerCaptor.capture()); @@ -714,4 +746,49 @@ private void assertCurrentPickerIsBufferPicker() { PickResult pickResult = pickerCaptor.getValue().pickSubchannel(mock(PickSubchannelArgs.class)); assertThat(pickResult).isEqualTo(PickResult.withNoResult()); } + + private static class FakeLoadBalancerProvider extends LoadBalancerProvider { + + @Override + public boolean isAvailable() { + return true; + } + + @Override + public int getPriority() { + return 5; + } + + @Override + public String getPolicyName() { + return "foo"; + } + + @Override + public LoadBalancer newLoadBalancer(Helper helper) { + return new FakeLoadBalancer(helper); + } + } + + static class FakeLoadBalancer extends LoadBalancer { + + private Helper helper; + + FakeLoadBalancer(Helper helper) { + this.helper = helper; + } + + @Override + public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { + helper.updateBalancingState(TRANSIENT_FAILURE, new ErrorPicker(Status.INTERNAL)); + } + + @Override + public void handleNameResolutionError(Status error) { + } + + @Override + public void 
shutdown() { + } + } } diff --git a/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerTest.java index 9cfa00bc848..aae297d1acc 100644 --- a/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerTest.java @@ -156,9 +156,10 @@ public void tearDown() { public void subchannelLazyConnectUntilPicked() { RingHashConfig config = new RingHashConfig(10, 100); List servers = createWeightedServerAddrs(1); // one server - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); + assertThat(addressesAccepted).isTrue(); verify(helper).createSubchannel(any(CreateSubchannelArgs.class)); Subchannel subchannel = Iterables.getOnlyElement(subchannels.values()); verify(subchannel, never()).requestConnection(); @@ -187,9 +188,10 @@ public void subchannelLazyConnectUntilPicked() { public void subchannelNotAutoReconnectAfterReenteringIdle() { RingHashConfig config = new RingHashConfig(10, 100); List servers = createWeightedServerAddrs(1); // one server - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); + assertThat(addressesAccepted).isTrue(); Subchannel subchannel = Iterables.getOnlyElement(subchannels.values()); InOrder inOrder = Mockito.inOrder(helper, subchannel); inOrder.verify(helper).updateBalancingState(eq(IDLE), pickerCaptor.capture()); @@ -217,9 +219,10 @@ public void aggregateSubchannelStates_connectingReadyIdleFailure() { RingHashConfig config = new RingHashConfig(10, 100); List servers = createWeightedServerAddrs(1, 1); InOrder inOrder = Mockito.inOrder(helper); - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = 
loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); + assertThat(addressesAccepted).isTrue(); inOrder.verify(helper, times(2)).createSubchannel(any(CreateSubchannelArgs.class)); inOrder.verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); @@ -278,9 +281,10 @@ public void aggregateSubchannelStates_twoOrMoreSubchannelsInTransientFailure() { RingHashConfig config = new RingHashConfig(10, 100); List servers = createWeightedServerAddrs(1, 1, 1, 1); InOrder inOrder = Mockito.inOrder(helper); - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); + assertThat(addressesAccepted).isTrue(); inOrder.verify(helper, times(4)).createSubchannel(any(CreateSubchannelArgs.class)); inOrder.verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); @@ -336,9 +340,10 @@ public void aggregateSubchannelStates_twoOrMoreSubchannelsInTransientFailure() { public void subchannelStayInTransientFailureUntilBecomeReady() { RingHashConfig config = new RingHashConfig(10, 100); List servers = createWeightedServerAddrs(1, 1, 1); - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); + assertThat(addressesAccepted).isTrue(); verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); reset(helper); @@ -378,9 +383,10 @@ public void updateConnectionIterator() { RingHashConfig config = new RingHashConfig(10, 100); List servers = createWeightedServerAddrs(1, 1, 1); InOrder inOrder = Mockito.inOrder(helper); - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = 
loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); + assertThat(addressesAccepted).isTrue(); verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); @@ -394,9 +400,10 @@ public void updateConnectionIterator() { verifyConnection(1); servers = createWeightedServerAddrs(1,1); - loadBalancer.handleResolvedAddresses( + addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); + assertThat(addressesAccepted).isTrue(); inOrder.verify(helper) .updateBalancingState(eq(CONNECTING), any(SubchannelPicker.class)); verifyConnection(1); @@ -422,9 +429,10 @@ public void updateConnectionIterator() { public void ignoreShutdownSubchannelStateChange() { RingHashConfig config = new RingHashConfig(10, 100); List servers = createWeightedServerAddrs(1, 1, 1); - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); + assertThat(addressesAccepted).isTrue(); verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); @@ -442,9 +450,10 @@ public void ignoreShutdownSubchannelStateChange() { public void deterministicPickWithHostsPartiallyRemoved() { RingHashConfig config = new RingHashConfig(10, 100); List servers = createWeightedServerAddrs(1, 1, 1, 1, 1); - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); + assertThat(addressesAccepted).isTrue(); InOrder inOrder = Mockito.inOrder(helper); inOrder.verify(helper, 
times(5)).createSubchannel(any(CreateSubchannelArgs.class)); inOrder.verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); @@ -470,9 +479,10 @@ public void deterministicPickWithHostsPartiallyRemoved() { Attributes attr = addr.getAttributes().toBuilder().set(CUSTOM_KEY, "custom value").build(); updatedServers.add(new EquivalentAddressGroup(addr.getAddresses(), attr)); } - loadBalancer.handleResolvedAddresses( + addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(updatedServers).setLoadBalancingPolicyConfig(config).build()); + assertThat(addressesAccepted).isTrue(); verify(subchannels.get(Collections.singletonList(servers.get(0)))) .updateAddresses(Collections.singletonList(updatedServers.get(0))); verify(subchannels.get(Collections.singletonList(servers.get(1)))) @@ -487,9 +497,10 @@ public void deterministicPickWithHostsPartiallyRemoved() { public void deterministicPickWithNewHostsAdded() { RingHashConfig config = new RingHashConfig(10, 100); List servers = createWeightedServerAddrs(1, 1); // server0 and server1 - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); + assertThat(addressesAccepted).isTrue(); InOrder inOrder = Mockito.inOrder(helper); inOrder.verify(helper, times(2)).createSubchannel(any(CreateSubchannelArgs.class)); inOrder.verify(helper).updateBalancingState(eq(IDLE), pickerCaptor.capture()); @@ -511,9 +522,10 @@ public void deterministicPickWithNewHostsAdded() { assertThat(subchannel.getAddresses()).isEqualTo(servers.get(1)); servers = createWeightedServerAddrs(1, 1, 1, 1, 1); // server2, server3, server4 added - loadBalancer.handleResolvedAddresses( + addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); + 
assertThat(addressesAccepted).isTrue(); inOrder.verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture()); assertThat(pickerCaptor.getValue().pickSubchannel(args).getSubchannel()) @@ -526,9 +538,10 @@ public void skipFailingHosts_pickNextNonFailingHostInFirstTwoHosts() { // Map each server address to exactly one ring entry. RingHashConfig config = new RingHashConfig(3, 3); List servers = createWeightedServerAddrs(1, 1, 1); - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); + assertThat(addressesAccepted).isTrue(); verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); // initial IDLE reset(helper); @@ -583,9 +596,10 @@ public void skipFailingHosts_firstTwoHostsFailed_pickNextFirstReady() { // Map each server address to exactly one ring entry. RingHashConfig config = new RingHashConfig(3, 3); List servers = createWeightedServerAddrs(1, 1, 1); - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); + assertThat(addressesAccepted).isTrue(); verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); // initial IDLE reset(helper); @@ -649,9 +663,10 @@ public void allSubchannelsInTransientFailure() { // Map each server address to exactly one ring entry. 
RingHashConfig config = new RingHashConfig(3, 3); List servers = createWeightedServerAddrs(1, 1, 1); - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); + assertThat(addressesAccepted).isTrue(); verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); @@ -687,9 +702,10 @@ public void firstSubchannelIdle() { // Map each server address to exactly one ring entry. RingHashConfig config = new RingHashConfig(3, 3); List servers = createWeightedServerAddrs(1, 1, 1); - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); + assertThat(addressesAccepted).isTrue(); verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); @@ -718,9 +734,10 @@ public void firstSubchannelConnecting() { // Map each server address to exactly one ring entry. RingHashConfig config = new RingHashConfig(3, 3); List servers = createWeightedServerAddrs(1, 1, 1); - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); + assertThat(addressesAccepted).isTrue(); verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); @@ -749,9 +766,10 @@ public void firstSubchannelFailure() { // Map each server address to exactly one ring entry. 
RingHashConfig config = new RingHashConfig(3, 3); List servers = createWeightedServerAddrs(1, 1, 1); - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); + assertThat(addressesAccepted).isTrue(); verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); // ring: @@ -784,9 +802,10 @@ public void secondSubchannelConnecting() { // Map each server address to exactly one ring entry. RingHashConfig config = new RingHashConfig(3, 3); List servers = createWeightedServerAddrs(1, 1, 1); - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); + assertThat(addressesAccepted).isTrue(); verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); // ring: @@ -822,9 +841,10 @@ public void secondSubchannelFailure() { // Map each server address to exactly one ring entry. RingHashConfig config = new RingHashConfig(3, 3); List servers = createWeightedServerAddrs(1, 1, 1); - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); + assertThat(addressesAccepted).isTrue(); verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); // ring: @@ -864,9 +884,10 @@ public void thirdSubchannelConnecting() { // Map each server address to exactly one ring entry. 
RingHashConfig config = new RingHashConfig(3, 3); List servers = createWeightedServerAddrs(1, 1, 1); - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); + assertThat(addressesAccepted).isTrue(); verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); // ring: @@ -908,9 +929,10 @@ public void stickyTransientFailure() { // Map each server address to exactly one ring entry. RingHashConfig config = new RingHashConfig(3, 3); List servers = createWeightedServerAddrs(1, 1, 1); - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); + assertThat(addressesAccepted).isTrue(); verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); @@ -943,9 +965,10 @@ public void stickyTransientFailure() { public void hostSelectionProportionalToWeights() { RingHashConfig config = new RingHashConfig(10000, 100000); // large ring List servers = createWeightedServerAddrs(1, 10, 100); // 1:10:100 - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); + assertThat(addressesAccepted).isTrue(); verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); @@ -979,9 +1002,10 @@ public void hostSelectionProportionalToWeights() { public void hostSelectionProportionalToRepeatedAddressCount() { RingHashConfig config = new RingHashConfig(10000, 100000); List servers = 
createRepeatedServerAddrs(1, 10, 100); // 1:10:100 - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); + assertThat(addressesAccepted).isTrue(); verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); @@ -1027,9 +1051,10 @@ public void nameResolutionErrorWithNoActiveSubchannels() { public void nameResolutionErrorWithActiveSubchannels() { RingHashConfig config = new RingHashConfig(10, 100); List servers = createWeightedServerAddrs(1); - loadBalancer.handleResolvedAddresses( + boolean addressesAccepted = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); + assertThat(addressesAccepted).isTrue(); verify(helper).createSubchannel(any(CreateSubchannelArgs.class)); verify(helper).updateBalancingState(eq(IDLE), pickerCaptor.capture()); diff --git a/xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java b/xds/src/test/java/io/grpc/xds/XdsClientImplDataTest.java similarity index 99% rename from xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java rename to xds/src/test/java/io/grpc/xds/XdsClientImplDataTest.java index dae17bfb89d..993fb910f3e 100644 --- a/xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsClientImplDataTest.java @@ -115,9 +115,7 @@ import io.grpc.lookup.v1.NameMatcher; import io.grpc.lookup.v1.RouteLookupClusterSpecifier; import io.grpc.lookup.v1.RouteLookupConfig; -import io.grpc.xds.AbstractXdsClient.ResourceType; import io.grpc.xds.Bootstrapper.ServerInfo; -import io.grpc.xds.ClientXdsClient.ResourceInvalidException; import io.grpc.xds.ClusterSpecifierPlugin.NamedPluginConfig; import io.grpc.xds.ClusterSpecifierPlugin.PluginConfig; import 
io.grpc.xds.Endpoints.LbEndpoint; @@ -130,6 +128,7 @@ import io.grpc.xds.VirtualHost.Route.RouteAction.HashPolicy; import io.grpc.xds.VirtualHost.Route.RouteMatch; import io.grpc.xds.VirtualHost.Route.RouteMatch.PathMatcher; +import io.grpc.xds.XdsClientImpl.ResourceInvalidException; import io.grpc.xds.XdsClusterResource.CdsUpdate; import io.grpc.xds.XdsResourceType.StructOrError; import io.grpc.xds.internal.Matchers.FractionMatcher; @@ -152,7 +151,7 @@ import org.junit.runners.JUnit4; @RunWith(JUnit4.class) -public class ClientXdsClientDataTest { +public class XdsClientImplDataTest { private static final ServerInfo LRS_SERVER_INFO = ServerInfo.create("lrs.googleapis.com", InsecureChannelCredentials.create(), true); @@ -2699,21 +2698,28 @@ public void validateUpstreamTlsContext_noCommonTlsContext() throws ResourceInval @Test public void validateResourceName() { String traditionalResource = "cluster1.google.com"; - assertThat(XdsClient.isResourceNameValid(traditionalResource, ResourceType.CDS.typeUrl())) + assertThat(XdsClient.isResourceNameValid(traditionalResource, + XdsClusterResource.getInstance().typeUrl())) .isTrue(); - assertThat(XdsClient.isResourceNameValid(traditionalResource, ResourceType.RDS.typeUrlV2())) + assertThat(XdsClient.isResourceNameValid(traditionalResource, + XdsRouteConfigureResource.getInstance().typeUrlV2())) .isTrue(); String invalidPath = "xdstp:/abc/efg"; - assertThat(XdsClient.isResourceNameValid(invalidPath, ResourceType.CDS.typeUrl())).isFalse(); + assertThat(XdsClient.isResourceNameValid(invalidPath, + XdsClusterResource.getInstance().typeUrl())).isFalse(); String invalidPath2 = "xdstp:///envoy.config.route.v3.RouteConfiguration"; - assertThat(XdsClient.isResourceNameValid(invalidPath2, ResourceType.RDS.typeUrl())).isFalse(); + assertThat(XdsClient.isResourceNameValid(invalidPath2, + XdsRouteConfigureResource.getInstance().typeUrl())).isFalse(); String typeMatch = "xdstp:///envoy.config.route.v3.RouteConfiguration/foo/route1"; - 
assertThat(XdsClient.isResourceNameValid(typeMatch, ResourceType.LDS.typeUrl())).isFalse(); - assertThat(XdsClient.isResourceNameValid(typeMatch, ResourceType.RDS.typeUrl())).isTrue(); - assertThat(XdsClient.isResourceNameValid(typeMatch, ResourceType.RDS.typeUrlV2())).isFalse(); + assertThat(XdsClient.isResourceNameValid(typeMatch, + XdsListenerResource.getInstance().typeUrl())).isFalse(); + assertThat(XdsClient.isResourceNameValid(typeMatch, + XdsRouteConfigureResource.getInstance().typeUrl())).isTrue(); + assertThat(XdsClient.isResourceNameValid(typeMatch, + XdsRouteConfigureResource.getInstance().typeUrlV2())).isFalse(); } @Test diff --git a/xds/src/test/java/io/grpc/xds/ClientXdsClientTestBase.java b/xds/src/test/java/io/grpc/xds/XdsClientImplTestBase.java similarity index 97% rename from xds/src/test/java/io/grpc/xds/ClientXdsClientTestBase.java rename to xds/src/test/java/io/grpc/xds/XdsClientImplTestBase.java index 788110e1c20..31d10abd841 100644 --- a/xds/src/test/java/io/grpc/xds/ClientXdsClientTestBase.java +++ b/xds/src/test/java/io/grpc/xds/XdsClientImplTestBase.java @@ -18,10 +18,6 @@ import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertWithMessage; -import static io.grpc.xds.AbstractXdsClient.ResourceType.CDS; -import static io.grpc.xds.AbstractXdsClient.ResourceType.EDS; -import static io.grpc.xds.AbstractXdsClient.ResourceType.LDS; -import static io.grpc.xds.AbstractXdsClient.ResourceType.RDS; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.mock; @@ -66,12 +62,9 @@ import io.grpc.internal.ServiceConfigUtil.LbConfig; import io.grpc.internal.TimeProvider; import io.grpc.testing.GrpcCleanupRule; -import io.grpc.xds.AbstractXdsClient.ResourceType; import io.grpc.xds.Bootstrapper.AuthorityInfo; import io.grpc.xds.Bootstrapper.CertificateProviderInfo; import io.grpc.xds.Bootstrapper.ServerInfo; -import 
io.grpc.xds.ClientXdsClient.ResourceInvalidException; -import io.grpc.xds.ClientXdsClient.XdsChannelFactory; import io.grpc.xds.Endpoints.DropOverload; import io.grpc.xds.Endpoints.LbEndpoint; import io.grpc.xds.Endpoints.LocalityLbEndpoints; @@ -86,6 +79,8 @@ import io.grpc.xds.XdsClient.ResourceMetadata.UpdateFailureState; import io.grpc.xds.XdsClient.ResourceUpdate; import io.grpc.xds.XdsClient.ResourceWatcher; +import io.grpc.xds.XdsClientImpl.ResourceInvalidException; +import io.grpc.xds.XdsClientImpl.XdsChannelFactory; import io.grpc.xds.XdsClusterResource.CdsUpdate; import io.grpc.xds.XdsClusterResource.CdsUpdate.ClusterType; import io.grpc.xds.XdsEndpointResource.EdsUpdate; @@ -117,10 +112,10 @@ import org.mockito.MockitoAnnotations; /** - * Tests for {@link ClientXdsClient}. + * Tests for {@link XdsClientImpl}. */ @RunWith(JUnit4.class) -public abstract class ClientXdsClientTestBase { +public abstract class XdsClientImplTestBase { private static final String SERVER_URI = "trafficdirector.googleapis.com"; private static final String SERVER_URI_CUSTOME_AUTHORITY = "trafficdirector2.googleapis.com"; private static final String SERVER_URI_EMPTY_AUTHORITY = "trafficdirector3.googleapis.com"; @@ -137,6 +132,10 @@ public abstract class ClientXdsClientTestBase { private static final Node NODE = Node.newBuilder().setId(NODE_ID).build(); private static final Any FAILING_ANY = MessageFactory.FAILING_ANY; private static final ChannelCredentials CHANNEL_CREDENTIALS = InsecureChannelCredentials.create(); + private static final XdsResourceType LDS = XdsListenerResource.getInstance(); + private static final XdsResourceType CDS = XdsClusterResource.getInstance(); + private static final XdsResourceType RDS = XdsRouteConfigureResource.getInstance(); + private static final XdsResourceType EDS = XdsEndpointResource.getInstance(); // xDS control plane server info. 
private ServerInfo xdsServerInfo; @@ -264,7 +263,7 @@ public long currentTimeNanos() { private ManagedChannel channel; private ManagedChannel channelForCustomAuthority; private ManagedChannel channelForEmptyAuthority; - private ClientXdsClient xdsClient; + private XdsClientImpl xdsClient; private boolean originalEnableFaultInjection; private boolean originalEnableRbac; private boolean originalEnableLeastRequest; @@ -342,7 +341,7 @@ SERVER_URI_EMPTY_AUTHORITY, CHANNEL_CREDENTIALS, useProtocolV3()))))) CertificateProviderInfo.create("file-watcher", ImmutableMap.of()))) .build(); xdsClient = - new ClientXdsClient( + new XdsClientImpl( xdsChannelFactory, bootstrapInfo, Context.ROOT, @@ -399,15 +398,34 @@ protected static boolean matchErrorDetail( private void verifySubscribedResourcesMetadataSizes( int ldsSize, int cdsSize, int rdsSize, int edsSize) { - Map> subscribedResourcesMetadata = + Map, Map> subscribedResourcesMetadata = awaitSubscribedResourcesMetadata(); - assertThat(subscribedResourcesMetadata.get(LDS)).hasSize(ldsSize); - assertThat(subscribedResourcesMetadata.get(CDS)).hasSize(cdsSize); - assertThat(subscribedResourcesMetadata.get(RDS)).hasSize(rdsSize); - assertThat(subscribedResourcesMetadata.get(EDS)).hasSize(edsSize); + Map> subscribedTypeUrls = + xdsClient.getSubscribedResourceTypesWithTypeUrl(); + verifyResourceCount(subscribedTypeUrls, subscribedResourcesMetadata, LDS, ldsSize); + verifyResourceCount(subscribedTypeUrls, subscribedResourcesMetadata, CDS, cdsSize); + verifyResourceCount(subscribedTypeUrls, subscribedResourcesMetadata, RDS, rdsSize); + verifyResourceCount(subscribedTypeUrls, subscribedResourcesMetadata, EDS, edsSize); + } + + private void verifyResourceCount( + Map> subscribedTypeUrls, + Map, Map> subscribedResourcesMetadata, + XdsResourceType type, + int size) { + if (size == 0) { + assertThat(subscribedTypeUrls.containsKey(type.typeUrl())).isFalse(); + assertThat(subscribedTypeUrls.containsKey(type.typeUrlV2())).isFalse(); + 
assertThat(subscribedResourcesMetadata.containsKey(type)).isFalse(); + } else { + assertThat(subscribedTypeUrls.containsKey(type.typeUrl())).isTrue(); + assertThat(subscribedTypeUrls.containsKey(type.typeUrlV2())).isTrue(); + assertThat(subscribedResourcesMetadata.get(type)).hasSize(size); + } } - private Map> awaitSubscribedResourcesMetadata() { + private Map, Map> + awaitSubscribedResourcesMetadata() { try { return xdsClient.getSubscribedResourcesMetadataSnapshot().get(20, TimeUnit.SECONDS); } catch (Exception e) { @@ -419,20 +437,20 @@ private Map> awaitSubscribedResource } /** Verify the resource requested, but not updated. */ - private void verifyResourceMetadataRequested(ResourceType type, String resourceName) { + private void verifyResourceMetadataRequested(XdsResourceType type, String resourceName) { verifyResourceMetadata( type, resourceName, null, ResourceMetadataStatus.REQUESTED, "", 0, false); } /** Verify that the requested resource does not exist. */ - private void verifyResourceMetadataDoesNotExist(ResourceType type, String resourceName) { + private void verifyResourceMetadataDoesNotExist(XdsResourceType type, String resourceName) { verifyResourceMetadata( type, resourceName, null, ResourceMetadataStatus.DOES_NOT_EXIST, "", 0, false); } /** Verify the resource to be acked. */ private void verifyResourceMetadataAcked( - ResourceType type, String resourceName, Any rawResource, String versionInfo, + XdsResourceType type, String resourceName, Any rawResource, String versionInfo, long updateTimeNanos) { verifyResourceMetadata(type, resourceName, rawResource, ResourceMetadataStatus.ACKED, versionInfo, updateTimeNanos, false); @@ -443,7 +461,7 @@ private void verifyResourceMetadataAcked( * corresponding i-th element of {@code List failedDetails}. 
*/ private void verifyResourceMetadataNacked( - ResourceType type, String resourceName, Any rawResource, String versionInfo, + XdsResourceType type, String resourceName, Any rawResource, String versionInfo, long updateTime, String failedVersion, long failedUpdateTimeNanos, List failedDetails) { ResourceMetadata resourceMetadata = @@ -465,7 +483,7 @@ private void verifyResourceMetadataNacked( } private ResourceMetadata verifyResourceMetadata( - ResourceType type, String resourceName, Any rawResource, ResourceMetadataStatus status, + XdsResourceType type, String resourceName, Any rawResource, ResourceMetadataStatus status, String versionInfo, long updateTimeNanos, boolean hasErrorState) { ResourceMetadata metadata = awaitSubscribedResourcesMetadata().get(type).get(resourceName); assertThat(metadata).isNotNull(); @@ -589,7 +607,7 @@ public void ldsResourceNotFound() { verifyResourceMetadataRequested(LDS, LDS_RESOURCE); verifySubscribedResourcesMetadataSizes(1, 0, 0, 0); // Server failed to return subscribed resource within expected time window. - fakeClock.forwardTime(ClientXdsClient.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); + fakeClock.forwardTime(XdsClientImpl.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); verify(ldsResourceWatcher).onResourceDoesNotExist(LDS_RESOURCE); assertThat(fakeClock.getPendingTasks(LDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)).isEmpty(); verifyResourceMetadataDoesNotExist(LDS, LDS_RESOURCE); @@ -872,7 +890,7 @@ public void cachedLdsResource_data() { public void cachedLdsResource_absent() { DiscoveryRpcCall call = startResourceWatcher(XdsListenerResource.getInstance(), LDS_RESOURCE, ldsResourceWatcher); - fakeClock.forwardTime(ClientXdsClient.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); + fakeClock.forwardTime(XdsClientImpl.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); verify(ldsResourceWatcher).onResourceDoesNotExist(LDS_RESOURCE); // Add another watcher. 
ResourceWatcher watcher = mock(ResourceWatcher.class); @@ -1264,7 +1282,7 @@ public void multipleLdsWatchers() { verifyResourceMetadataRequested(LDS, ldsResourceTwo); verifySubscribedResourcesMetadataSizes(2, 0, 0, 0); - fakeClock.forwardTime(ClientXdsClient.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); + fakeClock.forwardTime(XdsClientImpl.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); verify(ldsResourceWatcher).onResourceDoesNotExist(LDS_RESOURCE); verify(watcher1).onResourceDoesNotExist(ldsResourceTwo); verify(watcher2).onResourceDoesNotExist(ldsResourceTwo); @@ -1297,7 +1315,7 @@ public void rdsResourceNotFound() { RDS_RESOURCE, rdsResourceWatcher); Any routeConfig = Any.pack(mf.buildRouteConfiguration("route-bar.googleapis.com", mf.buildOpaqueVirtualHosts(2))); - call.sendResponse(ResourceType.RDS, routeConfig, VERSION_1, "0000"); + call.sendResponse(RDS, routeConfig, VERSION_1, "0000"); // Client sends an ACK RDS request. call.verifyRequest(RDS, RDS_RESOURCE, VERSION_1, "0000", NODE); @@ -1305,7 +1323,7 @@ public void rdsResourceNotFound() { verifyResourceMetadataRequested(RDS, RDS_RESOURCE); verifySubscribedResourcesMetadataSizes(0, 0, 1, 0); // Server failed to return subscribed resource within expected time window. 
- fakeClock.forwardTime(ClientXdsClient.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); + fakeClock.forwardTime(XdsClientImpl.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); verify(rdsResourceWatcher).onResourceDoesNotExist(RDS_RESOURCE); assertThat(fakeClock.getPendingTasks(RDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)).isEmpty(); verifyResourceMetadataDoesNotExist(RDS, RDS_RESOURCE); @@ -1468,7 +1486,7 @@ public void cachedRdsResource_data() { public void cachedRdsResource_absent() { DiscoveryRpcCall call = startResourceWatcher(XdsRouteConfigureResource.getInstance(), RDS_RESOURCE, rdsResourceWatcher); - fakeClock.forwardTime(ClientXdsClient.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); + fakeClock.forwardTime(XdsClientImpl.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); verify(rdsResourceWatcher).onResourceDoesNotExist(RDS_RESOURCE); // Add another watcher. ResourceWatcher watcher = mock(ResourceWatcher.class); @@ -1630,7 +1648,7 @@ public void multipleRdsWatchers() { verifyResourceMetadataRequested(RDS, rdsResourceTwo); verifySubscribedResourcesMetadataSizes(0, 0, 2, 0); - fakeClock.forwardTime(ClientXdsClient.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); + fakeClock.forwardTime(XdsClientImpl.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); verify(rdsResourceWatcher).onResourceDoesNotExist(RDS_RESOURCE); verify(watcher1).onResourceDoesNotExist(rdsResourceTwo); verify(watcher2).onResourceDoesNotExist(rdsResourceTwo); @@ -1677,7 +1695,7 @@ public void cdsResourceNotFound() { verifyResourceMetadataRequested(CDS, CDS_RESOURCE); verifySubscribedResourcesMetadataSizes(0, 1, 0, 0); // Server failed to return subscribed resource within expected time window. 
- fakeClock.forwardTime(ClientXdsClient.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); + fakeClock.forwardTime(XdsClientImpl.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); verify(cdsResourceWatcher).onResourceDoesNotExist(CDS_RESOURCE); assertThat(fakeClock.getPendingTasks(CDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)).isEmpty(); verifyResourceMetadataDoesNotExist(CDS, CDS_RESOURCE); @@ -1934,7 +1952,7 @@ public void cdsResourceFound_leastRequestLbPolicy() { mf.buildEdsCluster(CDS_RESOURCE, null, "least_request_experimental", null, leastRequestConfig, false, null, "envoy.transport_sockets.tls", null, null )); - call.sendResponse(ResourceType.CDS, clusterRingHash, VERSION_1, "0000"); + call.sendResponse(CDS, clusterRingHash, VERSION_1, "0000"); // Client sent an ACK CDS request. call.verifyRequest(CDS, CDS_RESOURCE, VERSION_1, "0000", NODE); @@ -1966,7 +1984,7 @@ public void cdsResourceFound_ringHashLbPolicy() { mf.buildEdsCluster(CDS_RESOURCE, null, "ring_hash_experimental", ringHashConfig, null, false, null, "envoy.transport_sockets.tls", null, null )); - call.sendResponse(ResourceType.CDS, clusterRingHash, VERSION_1, "0000"); + call.sendResponse(CDS, clusterRingHash, VERSION_1, "0000"); // Client sent an ACK CDS request. call.verifyRequest(CDS, CDS_RESOURCE, VERSION_1, "0000", NODE); @@ -2138,7 +2156,7 @@ public void cdsResponseErrorHandling_badUpstreamTlsContext() { // The response NACKed with errors indicating indices of the failed resources. 
String errorMsg = "CDS response Cluster 'cluster.googleapis.com' validation error: " + "Cluster cluster.googleapis.com: malformed UpstreamTlsContext: " - + "io.grpc.xds.ClientXdsClient$ResourceInvalidException: " + + "io.grpc.xds.XdsClientImpl$ResourceInvalidException: " + "ca_certificate_provider_instance is required in upstream-tls-context"; call.verifyRequestNack(CDS, CDS_RESOURCE, "", "0000", NODE, ImmutableList.of(errorMsg)); verify(cdsResourceWatcher).onError(errorCaptor.capture()); @@ -2288,7 +2306,7 @@ public void cdsResponseWithInvalidOutlierDetectionNacks() { String errorMsg = "CDS response Cluster 'cluster.googleapis.com' validation error: " + "Cluster cluster.googleapis.com: malformed outlier_detection: " - + "io.grpc.xds.ClientXdsClient$ResourceInvalidException: outlier_detection " + + "io.grpc.xds.XdsClientImpl$ResourceInvalidException: outlier_detection " + "max_ejection_percent is > 100"; call.verifyRequestNack(CDS, CDS_RESOURCE, "", "0000", NODE, ImmutableList.of(errorMsg)); verify(cdsResourceWatcher).onError(errorCaptor.capture()); @@ -2417,7 +2435,7 @@ public void cachedCdsResource_data() { public void cachedCdsResource_absent() { DiscoveryRpcCall call = startResourceWatcher(XdsClusterResource.getInstance(), CDS_RESOURCE, cdsResourceWatcher); - fakeClock.forwardTime(ClientXdsClient.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); + fakeClock.forwardTime(XdsClientImpl.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); verify(cdsResourceWatcher).onResourceDoesNotExist(CDS_RESOURCE); ResourceWatcher watcher = mock(ResourceWatcher.class); xdsClient.watchXdsResource(XdsClusterResource.getInstance(),CDS_RESOURCE, watcher); @@ -2612,7 +2630,7 @@ public void multipleCdsWatchers() { verifyResourceMetadataRequested(CDS, cdsResourceTwo); verifySubscribedResourcesMetadataSizes(0, 2, 0, 0); - fakeClock.forwardTime(ClientXdsClient.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); + 
fakeClock.forwardTime(XdsClientImpl.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); verify(cdsResourceWatcher).onResourceDoesNotExist(CDS_RESOURCE); verify(watcher1).onResourceDoesNotExist(cdsResourceTwo); verify(watcher2).onResourceDoesNotExist(cdsResourceTwo); @@ -2690,7 +2708,7 @@ public void edsResourceNotFound() { verifyResourceMetadataRequested(EDS, EDS_RESOURCE); verifySubscribedResourcesMetadataSizes(0, 0, 0, 1); // Server failed to return subscribed resource within expected time window. - fakeClock.forwardTime(ClientXdsClient.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); + fakeClock.forwardTime(XdsClientImpl.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); verify(edsResourceWatcher).onResourceDoesNotExist(EDS_RESOURCE); assertThat(fakeClock.getPendingTasks(EDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)).isEmpty(); verifyResourceMetadataDoesNotExist(EDS, EDS_RESOURCE); @@ -2860,7 +2878,7 @@ public void cachedEdsResource_data() { public void cachedEdsResource_absent() { DiscoveryRpcCall call = startResourceWatcher(XdsEndpointResource.getInstance(), EDS_RESOURCE, edsResourceWatcher); - fakeClock.forwardTime(ClientXdsClient.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); + fakeClock.forwardTime(XdsClientImpl.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); verify(edsResourceWatcher).onResourceDoesNotExist(EDS_RESOURCE); ResourceWatcher watcher = mock(ResourceWatcher.class); xdsClient.watchXdsResource(XdsEndpointResource.getInstance(),EDS_RESOURCE, watcher); @@ -3034,7 +3052,7 @@ public void multipleEdsWatchers() { verifyResourceMetadataRequested(EDS, edsResourceTwo); verifySubscribedResourcesMetadataSizes(0, 0, 0, 2); - fakeClock.forwardTime(ClientXdsClient.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); + fakeClock.forwardTime(XdsClientImpl.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); verify(edsResourceWatcher).onResourceDoesNotExist(EDS_RESOURCE); verify(watcher1).onResourceDoesNotExist(edsResourceTwo); 
verify(watcher2).onResourceDoesNotExist(edsResourceTwo); @@ -3339,7 +3357,7 @@ public void reportLoadStatsToServer() { @Test public void serverSideListenerFound() { Assume.assumeTrue(useProtocolV3()); - ClientXdsClientTestBase.DiscoveryRpcCall call = + XdsClientImplTestBase.DiscoveryRpcCall call = startResourceWatcher(XdsListenerResource.getInstance(), LISTENER_RESOURCE, ldsResourceWatcher); Message hcmFilter = mf.buildHttpConnectionManagerFilter( @@ -3353,10 +3371,9 @@ public void serverSideListenerFound() { Message listener = mf.buildListenerWithFilterChain(LISTENER_RESOURCE, 7000, "0.0.0.0", filterChain); List listeners = ImmutableList.of(Any.pack(listener)); - call.sendResponse(ResourceType.LDS, listeners, "0", "0000"); + call.sendResponse(LDS, listeners, "0", "0000"); // Client sends an ACK LDS request. - call.verifyRequest( - ResourceType.LDS, Collections.singletonList(LISTENER_RESOURCE), "0", "0000", NODE); + call.verifyRequest(LDS, Collections.singletonList(LISTENER_RESOURCE), "0", "0000", NODE); verify(ldsResourceWatcher).onChanged(ldsUpdateCaptor.capture()); EnvoyServerProtoData.Listener parsedListener = ldsUpdateCaptor.getValue().listener(); assertThat(parsedListener.name()).isEqualTo(LISTENER_RESOURCE); @@ -3376,7 +3393,7 @@ public void serverSideListenerFound() { @Test public void serverSideListenerNotFound() { Assume.assumeTrue(useProtocolV3()); - ClientXdsClientTestBase.DiscoveryRpcCall call = + XdsClientImplTestBase.DiscoveryRpcCall call = startResourceWatcher(XdsListenerResource.getInstance(), LISTENER_RESOURCE, ldsResourceWatcher); Message hcmFilter = mf.buildHttpConnectionManagerFilter( @@ -3390,13 +3407,12 @@ public void serverSideListenerNotFound() { Message listener = mf.buildListenerWithFilterChain( "grpc/server?xds.resource.listening_address=0.0.0.0:8000", 7000, "0.0.0.0", filterChain); List listeners = ImmutableList.of(Any.pack(listener)); - call.sendResponse(ResourceType.LDS, listeners, "0", "0000"); + call.sendResponse(LDS, listeners, 
"0", "0000"); // Client sends an ACK LDS request. - call.verifyRequest( - ResourceType.LDS, Collections.singletonList(LISTENER_RESOURCE), "0", "0000", NODE); + call.verifyRequest(LDS, Collections.singletonList(LISTENER_RESOURCE), "0", "0000", NODE); verifyNoInteractions(ldsResourceWatcher); - fakeClock.forwardTime(ClientXdsClient.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); + fakeClock.forwardTime(XdsClientImpl.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC, TimeUnit.SECONDS); verify(ldsResourceWatcher).onResourceDoesNotExist(LISTENER_RESOURCE); assertThat(fakeClock.getPendingTasks(LDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)).isEmpty(); } @@ -3404,7 +3420,7 @@ public void serverSideListenerNotFound() { @Test public void serverSideListenerResponseErrorHandling_badDownstreamTlsContext() { Assume.assumeTrue(useProtocolV3()); - ClientXdsClientTestBase.DiscoveryRpcCall call = + XdsClientImplTestBase.DiscoveryRpcCall call = startResourceWatcher(XdsListenerResource.getInstance(), LISTENER_RESOURCE, ldsResourceWatcher); Message hcmFilter = mf.buildHttpConnectionManagerFilter( @@ -3418,7 +3434,7 @@ public void serverSideListenerResponseErrorHandling_badDownstreamTlsContext() { Message listener = mf.buildListenerWithFilterChain(LISTENER_RESOURCE, 7000, "0.0.0.0", filterChain); List listeners = ImmutableList.of(Any.pack(listener)); - call.sendResponse(ResourceType.LDS, listeners, "0", "0000"); + call.sendResponse(LDS, listeners, "0", "0000"); // The response NACKed with errors indicating indices of the failed resources. 
String errorMsg = "LDS response Listener \'grpc/server?xds.resource.listening_address=" + "0.0.0.0:7000\' validation error: " @@ -3431,7 +3447,7 @@ public void serverSideListenerResponseErrorHandling_badDownstreamTlsContext() { @Test public void serverSideListenerResponseErrorHandling_badTransportSocketName() { Assume.assumeTrue(useProtocolV3()); - ClientXdsClientTestBase.DiscoveryRpcCall call = + XdsClientImplTestBase.DiscoveryRpcCall call = startResourceWatcher(XdsListenerResource.getInstance(), LISTENER_RESOURCE, ldsResourceWatcher); Message hcmFilter = mf.buildHttpConnectionManagerFilter( @@ -3445,7 +3461,7 @@ public void serverSideListenerResponseErrorHandling_badTransportSocketName() { Message listener = mf.buildListenerWithFilterChain(LISTENER_RESOURCE, 7000, "0.0.0.0", filterChain); List listeners = ImmutableList.of(Any.pack(listener)); - call.sendResponse(ResourceType.LDS, listeners, "0", "0000"); + call.sendResponse(LDS, listeners, "0", "0000"); // The response NACKed with errors indicating indices of the failed resources. 
String errorMsg = "LDS response Listener \'grpc/server?xds.resource.listening_address=" + "0.0.0.0:7000\' validation error: " @@ -3460,51 +3476,51 @@ private DiscoveryRpcCall startResourceWatcher( XdsResourceType type, String name, ResourceWatcher watcher) { FakeClock.TaskFilter timeoutTaskFilter; switch (type.typeName()) { - case LDS: + case "LDS": timeoutTaskFilter = LDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER; xdsClient.watchXdsResource(type, name, watcher); break; - case RDS: + case "RDS": timeoutTaskFilter = RDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER; xdsClient.watchXdsResource(type, name, watcher); break; - case CDS: + case "CDS": timeoutTaskFilter = CDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER; xdsClient.watchXdsResource(type, name, watcher); break; - case EDS: + case "EDS": timeoutTaskFilter = EDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER; xdsClient.watchXdsResource(type, name, watcher); break; - case UNKNOWN: default: throw new AssertionError("should never be here"); } DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); - call.verifyRequest(type.typeName(), Collections.singletonList(name), "", "", NODE); + call.verifyRequest(type, Collections.singletonList(name), "", "", NODE); ScheduledTask timeoutTask = Iterables.getOnlyElement(fakeClock.getPendingTasks(timeoutTaskFilter)); assertThat(timeoutTask.getDelay(TimeUnit.SECONDS)) - .isEqualTo(ClientXdsClient.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC); + .isEqualTo(XdsClientImpl.INITIAL_RESOURCE_FETCH_TIMEOUT_SEC); return call; } protected abstract static class DiscoveryRpcCall { protected abstract void verifyRequest( - ResourceType type, List resources, String versionInfo, String nonce, Node node); + XdsResourceType type, List resources, String versionInfo, String nonce, + Node node); protected void verifyRequest( - ResourceType type, String resource, String versionInfo, String nonce, Node node) { + XdsResourceType type, String resource, String versionInfo, String nonce, Node node) { verifyRequest(type, ImmutableList.of(resource), 
versionInfo, nonce, node); } protected abstract void verifyRequestNack( - ResourceType type, List resources, String versionInfo, String nonce, Node node, - List errorMessages); + XdsResourceType type, List resources, String versionInfo, String nonce, + Node node, List errorMessages); protected void verifyRequestNack( - ResourceType type, String resource, String versionInfo, String nonce, Node node, + XdsResourceType type, String resource, String versionInfo, String nonce, Node node, List errorMessages) { verifyRequestNack(type, ImmutableList.of(resource), versionInfo, nonce, node, errorMessages); } @@ -3512,9 +3528,10 @@ protected void verifyRequestNack( protected abstract void verifyNoMoreRequest(); protected abstract void sendResponse( - ResourceType type, List resources, String versionInfo, String nonce); + XdsResourceType type, List resources, String versionInfo, String nonce); - protected void sendResponse(ResourceType type, Any resource, String versionInfo, String nonce) { + protected void sendResponse(XdsResourceType type, Any resource, String versionInfo, + String nonce) { sendResponse(type, ImmutableList.of(resource), versionInfo, nonce); } diff --git a/xds/src/test/java/io/grpc/xds/ClientXdsClientV2Test.java b/xds/src/test/java/io/grpc/xds/XdsClientImplV2Test.java similarity index 98% rename from xds/src/test/java/io/grpc/xds/ClientXdsClientV2Test.java rename to xds/src/test/java/io/grpc/xds/XdsClientImplV2Test.java index 3bb6e421388..347a2dd2d39 100644 --- a/xds/src/test/java/io/grpc/xds/ClientXdsClientV2Test.java +++ b/xds/src/test/java/io/grpc/xds/XdsClientImplV2Test.java @@ -92,7 +92,6 @@ import io.grpc.Context.CancellationListener; import io.grpc.Status; import io.grpc.stub.StreamObserver; -import io.grpc.xds.AbstractXdsClient.ResourceType; import java.util.ArrayList; import java.util.Collections; import java.util.HashSet; @@ -108,10 +107,10 @@ import org.mockito.InOrder; /** - * Tests for {@link ClientXdsClient} with protocol version v2. 
+ * Tests for {@link XdsClientImpl} with protocol version v2. */ @RunWith(Parameterized.class) -public class ClientXdsClientV2Test extends ClientXdsClientTestBase { +public class XdsClientImplV2Test extends XdsClientImplTestBase { /** Parameterized test cases. */ @Parameters(name = "ignoreResourceDeletion={0}") @@ -197,7 +196,7 @@ private DiscoveryRpcCallV2(StreamObserver requestObserver, @Override protected void verifyRequest( - ResourceType type, List resources, String versionInfo, String nonce, + XdsResourceType type, List resources, String versionInfo, String nonce, EnvoyProtoData.Node node) { verify(requestObserver).onNext(argThat(new DiscoveryRequestMatcher( node.toEnvoyProtoNodeV2(), versionInfo, resources, type.typeUrlV2(), nonce, null, null))); @@ -205,7 +204,7 @@ protected void verifyRequest( @Override protected void verifyRequestNack( - ResourceType type, List resources, String versionInfo, String nonce, + XdsResourceType type, List resources, String versionInfo, String nonce, EnvoyProtoData.Node node, List errorMessages) { verify(requestObserver).onNext(argThat(new DiscoveryRequestMatcher( node.toEnvoyProtoNodeV2(), versionInfo, resources, type.typeUrlV2(), nonce, @@ -219,7 +218,7 @@ protected void verifyNoMoreRequest() { @Override protected void sendResponse( - ResourceType type, List resources, String versionInfo, String nonce) { + XdsResourceType type, List resources, String versionInfo, String nonce) { DiscoveryResponse response = DiscoveryResponse.newBuilder() .setVersionInfo(versionInfo) diff --git a/xds/src/test/java/io/grpc/xds/ClientXdsClientV3Test.java b/xds/src/test/java/io/grpc/xds/XdsClientImplV3Test.java similarity index 98% rename from xds/src/test/java/io/grpc/xds/ClientXdsClientV3Test.java rename to xds/src/test/java/io/grpc/xds/XdsClientImplV3Test.java index 6eb48e5bb00..55f03566c97 100644 --- a/xds/src/test/java/io/grpc/xds/ClientXdsClientV3Test.java +++ b/xds/src/test/java/io/grpc/xds/XdsClientImplV3Test.java @@ -99,7 +99,6 @@ import 
io.grpc.Context.CancellationListener; import io.grpc.Status; import io.grpc.stub.StreamObserver; -import io.grpc.xds.AbstractXdsClient.ResourceType; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -116,10 +115,10 @@ import org.mockito.InOrder; /** - * Tests for {@link ClientXdsClient} with protocol version v3. + * Tests for {@link XdsClientImpl} with protocol version v3. */ @RunWith(Parameterized.class) -public class ClientXdsClientV3Test extends ClientXdsClientTestBase { +public class XdsClientImplV3Test extends XdsClientImplTestBase { /** Parameterized test cases. */ @Parameters(name = "ignoreResourceDeletion={0}") @@ -205,7 +204,7 @@ private DiscoveryRpcCallV3(StreamObserver requestObserver, @Override protected void verifyRequest( - ResourceType type, List resources, String versionInfo, String nonce, + XdsResourceType type, List resources, String versionInfo, String nonce, EnvoyProtoData.Node node) { verify(requestObserver).onNext(argThat(new DiscoveryRequestMatcher( node.toEnvoyProtoNode(), versionInfo, resources, type.typeUrl(), nonce, null, null))); @@ -213,7 +212,7 @@ protected void verifyRequest( @Override protected void verifyRequestNack( - ResourceType type, List resources, String versionInfo, String nonce, + XdsResourceType type, List resources, String versionInfo, String nonce, EnvoyProtoData.Node node, List errorMessages) { verify(requestObserver).onNext(argThat(new DiscoveryRequestMatcher( node.toEnvoyProtoNode(), versionInfo, resources, type.typeUrl(), nonce, @@ -227,7 +226,7 @@ protected void verifyNoMoreRequest() { @Override protected void sendResponse( - ResourceType type, List resources, String versionInfo, String nonce) { + XdsResourceType type, List resources, String versionInfo, String nonce) { DiscoveryResponse response = DiscoveryResponse.newBuilder() .setVersionInfo(versionInfo) diff --git a/xds/src/test/java/io/grpc/xds/XdsClientTestHelper.java b/xds/src/test/java/io/grpc/xds/XdsClientTestHelper.java 
deleted file mode 100644 index 9ea9b1f8dc2..00000000000 --- a/xds/src/test/java/io/grpc/xds/XdsClientTestHelper.java +++ /dev/null @@ -1,137 +0,0 @@ -/* - * Copyright 2019 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.xds; - -import com.google.protobuf.Any; -import io.envoyproxy.envoy.config.core.v3.Address; -import io.envoyproxy.envoy.config.listener.v3.ApiListener; -import io.envoyproxy.envoy.config.listener.v3.FilterChain; -import io.envoyproxy.envoy.config.listener.v3.Listener; -import io.envoyproxy.envoy.config.route.v3.Route; -import io.envoyproxy.envoy.config.route.v3.RouteAction; -import io.envoyproxy.envoy.config.route.v3.RouteConfiguration; -import io.envoyproxy.envoy.config.route.v3.RouteMatch; -import io.envoyproxy.envoy.config.route.v3.VirtualHost; -import io.envoyproxy.envoy.service.discovery.v3.DiscoveryRequest; -import io.envoyproxy.envoy.service.discovery.v3.DiscoveryResponse; -import io.grpc.xds.EnvoyProtoData.Node; -import java.util.List; - -/** - * Helper methods for building protobuf messages with custom data for xDS protocols. - */ -// TODO(chengyuanzhang, sanjaypujare): delete this class, should not dump everything here. 
-class XdsClientTestHelper { - static DiscoveryResponse buildDiscoveryResponse(String versionInfo, - List resources, String typeUrl, String nonce) { - return - DiscoveryResponse.newBuilder() - .setVersionInfo(versionInfo) - .setTypeUrl(typeUrl) - .addAllResources(resources) - .setNonce(nonce) - .build(); - } - - static io.envoyproxy.envoy.api.v2.DiscoveryResponse buildDiscoveryResponseV2(String versionInfo, - List resources, String typeUrl, String nonce) { - return - io.envoyproxy.envoy.api.v2.DiscoveryResponse.newBuilder() - .setVersionInfo(versionInfo) - .setTypeUrl(typeUrl) - .addAllResources(resources) - .setNonce(nonce) - .build(); - } - - static DiscoveryRequest buildDiscoveryRequest(Node node, String versionInfo, - List resourceNames, String typeUrl, String nonce) { - return - DiscoveryRequest.newBuilder() - .setVersionInfo(versionInfo) - .setNode(node.toEnvoyProtoNode()) - .setTypeUrl(typeUrl) - .addAllResourceNames(resourceNames) - .setResponseNonce(nonce) - .build(); - } - - static Listener buildListener(String name, com.google.protobuf.Any apiListener) { - return - Listener.newBuilder() - .setName(name) - .setAddress(Address.getDefaultInstance()) - .addFilterChains(FilterChain.getDefaultInstance()) - .setApiListener(ApiListener.newBuilder().setApiListener(apiListener)) - .build(); - } - - static io.envoyproxy.envoy.api.v2.Listener buildListenerV2( - String name, com.google.protobuf.Any apiListener) { - return - io.envoyproxy.envoy.api.v2.Listener.newBuilder() - .setName(name) - .setAddress(io.envoyproxy.envoy.api.v2.core.Address.getDefaultInstance()) - .addFilterChains(io.envoyproxy.envoy.api.v2.listener.FilterChain.getDefaultInstance()) - .setApiListener(io.envoyproxy.envoy.config.listener.v2.ApiListener.newBuilder() - .setApiListener(apiListener)) - .build(); - } - - static RouteConfiguration buildRouteConfiguration(String name, - List virtualHosts) { - return - RouteConfiguration.newBuilder() - .setName(name) - .addAllVirtualHosts(virtualHosts) - 
.build(); - } - - static io.envoyproxy.envoy.api.v2.RouteConfiguration buildRouteConfigurationV2(String name, - List virtualHosts) { - return - io.envoyproxy.envoy.api.v2.RouteConfiguration.newBuilder() - .setName(name) - .addAllVirtualHosts(virtualHosts) - .build(); - } - - static VirtualHost buildVirtualHost(List domains, String clusterName) { - return VirtualHost.newBuilder() - .setName("virtualhost00.googleapis.com") // don't care - .addAllDomains(domains) - .addRoutes( - Route.newBuilder() - .setRoute(RouteAction.newBuilder().setCluster(clusterName)) - .setMatch(RouteMatch.newBuilder().setPrefix(""))) - .build(); - } - - static io.envoyproxy.envoy.api.v2.route.VirtualHost buildVirtualHostV2( - List domains, String clusterName) { - return io.envoyproxy.envoy.api.v2.route.VirtualHost.newBuilder() - .setName("virtualhost00.googleapis.com") // don't care - .addAllDomains(domains) - .addRoutes( - io.envoyproxy.envoy.api.v2.route.Route.newBuilder() - .setRoute( - io.envoyproxy.envoy.api.v2.route.RouteAction.newBuilder() - .setCluster(clusterName)) - .setMatch(io.envoyproxy.envoy.api.v2.route.RouteMatch.newBuilder().setPrefix(""))) - .build(); - } -} diff --git a/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java b/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java index f2eddf00fe6..2deab4ae688 100644 --- a/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java @@ -2089,14 +2089,14 @@ void watchXdsResource(XdsResourceType resourceType ResourceWatcher watcher) { switch (resourceType.typeName()) { - case LDS: + case "LDS": assertThat(ldsResource).isNull(); assertThat(ldsWatcher).isNull(); assertThat(resourceName).isEqualTo(expectedLdsResourceName); ldsResource = resourceName; ldsWatcher = (ResourceWatcher) watcher; break; - case RDS: + case "RDS": assertThat(rdsResource).isNull(); assertThat(rdsWatcher).isNull(); rdsResource = resourceName; @@ -2111,14 +2111,14 @@ void 
cancelXdsResourceWatch(XdsResourceType type, String resourceName, ResourceWatcher watcher) { switch (type.typeName()) { - case LDS: + case "LDS": assertThat(ldsResource).isNotNull(); assertThat(ldsWatcher).isNotNull(); assertThat(resourceName).isEqualTo(expectedLdsResourceName); ldsResource = null; ldsWatcher = null; break; - case RDS: + case "RDS": assertThat(rdsResource).isNotNull(); assertThat(rdsWatcher).isNotNull(); rdsResource = null; diff --git a/xds/src/test/java/io/grpc/xds/XdsServerTestHelper.java b/xds/src/test/java/io/grpc/xds/XdsServerTestHelper.java index c13be0361df..bce71c1c2ba 100644 --- a/xds/src/test/java/io/grpc/xds/XdsServerTestHelper.java +++ b/xds/src/test/java/io/grpc/xds/XdsServerTestHelper.java @@ -184,12 +184,12 @@ void watchXdsResource(XdsResourceType resourceType String resourceName, ResourceWatcher watcher) { switch (resourceType.typeName()) { - case LDS: + case "LDS": assertThat(ldsWatcher).isNull(); ldsWatcher = (ResourceWatcher) watcher; ldsResource.set(resourceName); break; - case RDS: + case "RDS": //re-register is not allowed. 
assertThat(rdsWatchers.put(resourceName, (ResourceWatcher)watcher)).isNull(); rdsCount.countDown(); @@ -203,12 +203,12 @@ void cancelXdsResourceWatch(XdsResourceType type, String resourceName, ResourceWatcher watcher) { switch (type.typeName()) { - case LDS: + case "LDS": assertThat(ldsWatcher).isNotNull(); ldsResource = null; ldsWatcher = null; break; - case RDS: + case "RDS": rdsWatchers.remove(resourceName); break; default: diff --git a/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java b/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java index 81b267587fd..728bb06efec 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java @@ -424,7 +424,7 @@ public TestCallback(Executor executor) { } @Override - public void updateSecret(SslContext sslContext) { + public void updateSslContext(SslContext sslContext) { updatedSslContext = sslContext; } diff --git a/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java b/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java index 0531189f2ac..863e4dcfecf 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java @@ -168,7 +168,7 @@ public void clientSdsHandler_addLast() sslContextProviderSupplier .updateSslContext(new SslContextProvider.Callback(MoreExecutors.directExecutor()) { @Override - public void updateSecret(SslContext sslContext) { + public void updateSslContext(SslContext sslContext) { future.set(sslContext); } @@ -245,7 +245,7 @@ public SocketAddress remoteAddress() { sslContextProviderSupplier .updateSslContext(new SslContextProvider.Callback(MoreExecutors.directExecutor()) { @Override - public void updateSecret(SslContext sslContext) { + public void 
updateSslContext(SslContext sslContext) { future.set(sslContext); } @@ -381,7 +381,7 @@ public void clientSdsProtocolNegotiatorNewHandler_fireProtocolNegotiationEvent() sslContextProviderSupplier .updateSslContext(new SslContextProvider.Callback(MoreExecutors.directExecutor()) { @Override - public void updateSecret(SslContext sslContext) { + public void updateSslContext(SslContext sslContext) { future.set(sslContext); } diff --git a/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java b/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java index 35c1437d34c..8030b458d27 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java @@ -85,8 +85,8 @@ public void get_updateSecret() { SslContextProvider.Callback capturedCallback = callbackCaptor.getValue(); assertThat(capturedCallback).isNotNull(); SslContext mockSslContext = mock(SslContext.class); - capturedCallback.updateSecret(mockSslContext); - verify(mockCallback, times(1)).updateSecret(eq(mockSslContext)); + capturedCallback.updateSslContext(mockSslContext); + verify(mockCallback, times(1)).updateSslContext(eq(mockSslContext)); verify(mockTlsContextManager, times(1)) .releaseClientSslContextProvider(eq(mockSslContextProvider)); SslContextProvider.Callback mockCallback = mock(SslContextProvider.Callback.class);