diff --git a/rls/src/main/java/io/grpc/rls/CachingRlsLbClient.java b/rls/src/main/java/io/grpc/rls/CachingRlsLbClient.java index 7bd9f4b68e3..8af93d81e09 100644 --- a/rls/src/main/java/io/grpc/rls/CachingRlsLbClient.java +++ b/rls/src/main/java/io/grpc/rls/CachingRlsLbClient.java @@ -30,17 +30,14 @@ import io.grpc.ChannelLogger; import io.grpc.ChannelLogger.ChannelLogLevel; import io.grpc.ConnectivityState; -import io.grpc.LoadBalancer; import io.grpc.LoadBalancer.Helper; import io.grpc.LoadBalancer.PickResult; import io.grpc.LoadBalancer.PickSubchannelArgs; import io.grpc.LoadBalancer.ResolvedAddresses; import io.grpc.LoadBalancer.SubchannelPicker; -import io.grpc.LoadBalancerProvider; import io.grpc.ManagedChannel; import io.grpc.ManagedChannelBuilder; import io.grpc.Metadata; -import io.grpc.NameResolver.ConfigOrError; import io.grpc.Status; import io.grpc.SynchronizationContext; import io.grpc.SynchronizationContext.ScheduledHandle; @@ -51,7 +48,6 @@ import io.grpc.lookup.v1.RouteLookupServiceGrpc.RouteLookupServiceStub; import io.grpc.rls.ChildLoadBalancerHelper.ChildLoadBalancerHelperProvider; import io.grpc.rls.LbPolicyConfiguration.ChildLbStatusListener; -import io.grpc.rls.LbPolicyConfiguration.ChildLoadBalancingPolicy; import io.grpc.rls.LbPolicyConfiguration.ChildPolicyWrapper; import io.grpc.rls.LbPolicyConfiguration.RefCountedChildPolicyWrapperFactory; import io.grpc.rls.LruCache.EvictionListener; @@ -138,7 +134,8 @@ private CachingRlsLbClient(Builder builder) { rlsConfig.getCacheSizeBytes(), builder.evictionListener, scheduledExecutorService, - timeProvider); + timeProvider, + lock); logger = helper.getChannelLogger(); String serverHost = null; try { @@ -181,7 +178,9 @@ private CachingRlsLbClient(Builder builder) { new ChildLoadBalancerHelperProvider(helper, new SubchannelStateManagerImpl(), rlsPicker); refCountedChildPolicyWrapperFactory = new RefCountedChildPolicyWrapperFactory( - childLbHelperProvider, new BackoffRefreshListener()); + lbPolicyConfig.getLoadBalancingPolicy(), childLbResolvedAddressFactory, + childLbHelperProvider, + new BackoffRefreshListener()); logger.log(ChannelLogLevel.DEBUG, "CachingRlsLbClient created"); } @@ -536,6 +535,7 @@ final class DataCacheEntry extends CacheEntry { private final long staleTime; private final ChildPolicyWrapper childPolicyWrapper; + // GuardedBy CachingRlsLbClient.lock DataCacheEntry(RouteLookupRequest request, final RouteLookupResponse response) { super(request); this.response = checkNotNull(response, "response"); @@ -546,29 +546,6 @@ final class DataCacheEntry extends CacheEntry { long now = timeProvider.currentTimeNanos(); expireTime = now + maxAgeNanos; staleTime = now + staleAgeNanos; - - if (childPolicyWrapper.getPicker() != null) { - childPolicyWrapper.refreshState(); - } else { - createChildLbPolicy(); - } - } - - private void createChildLbPolicy() { - ChildLoadBalancingPolicy childPolicy = lbPolicyConfig.getLoadBalancingPolicy(); - LoadBalancerProvider lbProvider = childPolicy.getEffectiveLbProvider(); - ConfigOrError lbConfig = - lbProvider - .parseLoadBalancingPolicyConfig( - childPolicy.getEffectiveChildPolicy(childPolicyWrapper.getTarget())); - - LoadBalancer lb = lbProvider.newLoadBalancer(childPolicyWrapper.getHelper()); - logger.log( - ChannelLogLevel.DEBUG, - "RLS child lb created. config: {0}", - lbConfig.getConfig()); - lb.handleResolvedAddresses(childLbResolvedAddressFactory.create(lbConfig.getConfig())); - lb.requestConnection(); } /** @@ -637,7 +614,9 @@ boolean isStaled(long now) { @Override void cleanup() { - refCountedChildPolicyWrapperFactory.release(childPolicyWrapper); + synchronized (lock) { + refCountedChildPolicyWrapperFactory.release(childPolicyWrapper); + } } @Override @@ -856,14 +835,15 @@ private static final class RlsAsyncLruCache RlsAsyncLruCache(long maxEstimatedSizeBytes, @Nullable EvictionListener evictionListener, - ScheduledExecutorService ses, TimeProvider timeProvider) { + ScheduledExecutorService ses, TimeProvider timeProvider, Object lock) { super( maxEstimatedSizeBytes, new AutoCleaningEvictionListener(evictionListener), 1, TimeUnit.MINUTES, ses, - timeProvider); + timeProvider, + lock); } @Override @@ -985,27 +965,9 @@ private void startFallbackChildPolicy() { } fallbackChildPolicyWrapper = refCountedChildPolicyWrapperFactory.createOrGet(defaultTarget); } - LoadBalancerProvider lbProvider = - lbPolicyConfig.getLoadBalancingPolicy().getEffectiveLbProvider(); - final LoadBalancer lb = - lbProvider.newLoadBalancer(fallbackChildPolicyWrapper.getHelper()); - final ConfigOrError lbConfig = - lbProvider - .parseLoadBalancingPolicyConfig( - lbPolicyConfig - .getLoadBalancingPolicy() - .getEffectiveChildPolicy(defaultTarget)); - helper.getSynchronizationContext().execute( - new Runnable() { - @Override - public void run() { - lb.handleResolvedAddresses( - childLbResolvedAddressFactory.create(lbConfig.getConfig())); - lb.requestConnection(); - } - }); } + // GuardedBy CachingRlsLbClient.lock void close() { if (fallbackChildPolicyWrapper != null) { refCountedChildPolicyWrapperFactory.release(fallbackChildPolicyWrapper); diff --git a/rls/src/main/java/io/grpc/rls/LbPolicyConfiguration.java b/rls/src/main/java/io/grpc/rls/LbPolicyConfiguration.java index f54e441ffe5..94a9de9801f 100644 --- a/rls/src/main/java/io/grpc/rls/LbPolicyConfiguration.java +++ b/rls/src/main/java/io/grpc/rls/LbPolicyConfiguration.java @@ -22,12 +22,15 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.MoreObjects; +import io.grpc.ChannelLogger.ChannelLogLevel; import io.grpc.ConnectivityState; +import io.grpc.LoadBalancer; import io.grpc.LoadBalancer.Helper; import io.grpc.LoadBalancer.Subchannel; import io.grpc.LoadBalancer.SubchannelPicker; import io.grpc.LoadBalancerProvider; import io.grpc.LoadBalancerRegistry; +import io.grpc.NameResolver.ConfigOrError; import io.grpc.internal.ObjectPool; import io.grpc.rls.ChildLoadBalancerHelper.ChildLoadBalancerHelperProvider; import io.grpc.rls.RlsProtoData.RouteLookupConfig; @@ -191,33 +194,49 @@ public String toString() { /** Factory for {@link ChildPolicyWrapper}. */ static final class RefCountedChildPolicyWrapperFactory { + // GuardedBy CachingRlsLbClient.lock @VisibleForTesting final Map childPolicyMap = new HashMap<>(); private final ChildLoadBalancerHelperProvider childLbHelperProvider; private final ChildLbStatusListener childLbStatusListener; + private final ChildLoadBalancingPolicy childPolicy; + private final ResolvedAddressFactory childLbResolvedAddressFactory; public RefCountedChildPolicyWrapperFactory( + ChildLoadBalancingPolicy childPolicy, + ResolvedAddressFactory childLbResolvedAddressFactory, ChildLoadBalancerHelperProvider childLbHelperProvider, ChildLbStatusListener childLbStatusListener) { + this.childPolicy = checkNotNull(childPolicy, "childPolicy"); + this.childLbResolvedAddressFactory = + checkNotNull(childLbResolvedAddressFactory, "childLbResolvedAddressFactory"); this.childLbHelperProvider = checkNotNull(childLbHelperProvider, "childLbHelperProvider"); this.childLbStatusListener = checkNotNull(childLbStatusListener, "childLbStatusListener"); } + // GuardedBy CachingRlsLbClient.lock ChildPolicyWrapper createOrGet(String target) { // TODO(creamsoup) check if the target is valid or not RefCountedChildPolicyWrapper pooledChildPolicyWrapper = childPolicyMap.get(target); if (pooledChildPolicyWrapper == null) { - ChildPolicyWrapper childPolicyWrapper = - new ChildPolicyWrapper(target, childLbHelperProvider, childLbStatusListener); + ChildPolicyWrapper childPolicyWrapper = new ChildPolicyWrapper( + target, childPolicy, childLbResolvedAddressFactory, childLbHelperProvider, + childLbStatusListener); pooledChildPolicyWrapper = RefCountedChildPolicyWrapper.of(childPolicyWrapper); childPolicyMap.put(target, pooledChildPolicyWrapper); + return pooledChildPolicyWrapper.getObject(); + } else { + ChildPolicyWrapper childPolicyWrapper = pooledChildPolicyWrapper.getObject(); + if (childPolicyWrapper.getPicker() != null) { + childPolicyWrapper.refreshState(); + } + return childPolicyWrapper; } - - return pooledChildPolicyWrapper.getObject(); } + // GuardedBy CachingRlsLbClient.lock void release(ChildPolicyWrapper childPolicyWrapper) { checkNotNull(childPolicyWrapper, "childPolicyWrapper"); String target = childPolicyWrapper.getTarget(); @@ -238,16 +257,36 @@ static final class ChildPolicyWrapper { private final String target; private final ChildPolicyReportingHelper helper; + private final LoadBalancer lb; private volatile SubchannelPicker picker; private ConnectivityState state; public ChildPolicyWrapper( String target, + ChildLoadBalancingPolicy childPolicy, + final ResolvedAddressFactory childLbResolvedAddressFactory, ChildLoadBalancerHelperProvider childLbHelperProvider, ChildLbStatusListener childLbStatusListener) { this.target = target; this.helper = new ChildPolicyReportingHelper(childLbHelperProvider, childLbStatusListener); + LoadBalancerProvider lbProvider = childPolicy.getEffectiveLbProvider(); + final ConfigOrError lbConfig = + lbProvider + .parseLoadBalancingPolicyConfig( + childPolicy.getEffectiveChildPolicy(target)); + this.lb = lbProvider.newLoadBalancer(helper); + helper.getChannelLogger().log( + ChannelLogLevel.DEBUG, "RLS child lb created. config: {0}", lbConfig.getConfig()); + helper.getSynchronizationContext().execute( + new Runnable() { + @Override + public void run() { + lb.handleResolvedAddresses( + childLbResolvedAddressFactory.create(lbConfig.getConfig())); + lb.requestConnection(); + } + }); } String getTarget() { @@ -263,7 +302,25 @@ ChildPolicyReportingHelper getHelper() { } void refreshState() { - helper.updateBalancingState(state, picker); + helper.getSynchronizationContext().execute( + new Runnable() { + @Override + public void run() { + helper.updateBalancingState(state, picker); + } + } + ); + } + + void shutdown() { + helper.getSynchronizationContext().execute( + new Runnable() { + @Override + public void run() { + lb.shutdown(); + } + } + ); } @Override @@ -346,6 +403,7 @@ public ChildPolicyWrapper returnObject(Object object) { long newCnt = refCnt.decrementAndGet(); checkState(newCnt != -1, "Cannot return never pooled childPolicyWrapper"); if (newCnt == 0) { + childPolicyWrapper.shutdown(); childPolicyWrapper = null; } return null; diff --git a/rls/src/main/java/io/grpc/rls/LinkedHashLruCache.java b/rls/src/main/java/io/grpc/rls/LinkedHashLruCache.java index 2d60f103a60..9c1a24a0d1e 100644 --- a/rls/src/main/java/io/grpc/rls/LinkedHashLruCache.java +++ b/rls/src/main/java/io/grpc/rls/LinkedHashLruCache.java @@ -48,7 +48,7 @@ @ThreadSafe abstract class LinkedHashLruCache implements LruCache { - private final Object lock = new Object(); + private final Object lock; @GuardedBy("lock") private final LinkedHashMap delegate; @@ -64,9 +64,11 @@ abstract class LinkedHashLruCache implements LruCache { int cleaningInterval, TimeUnit cleaningIntervalUnit, ScheduledExecutorService ses, - final TimeProvider timeProvider) { + final TimeProvider timeProvider, + Object lock) { checkState(estimatedMaxSizeBytes > 0, "max estimated cache size should be positive"); this.estimatedMaxSizeBytes = estimatedMaxSizeBytes; + this.lock = checkNotNull(lock, "lock"); this.evictionListener = new SizeHandlingEvictionListener(evictionListener); this.timeProvider = checkNotNull(timeProvider, "timeProvider"); delegate = new LinkedHashMap( @@ -200,14 +202,15 @@ private V invalidate(K key, EvictionType cause) { } @Override - public final void invalidateAll(Iterable keys) { - checkNotNull(keys, "keys"); + public final void invalidateAll() { synchronized (lock) { - for (K key : keys) { - SizedValue existing = delegate.remove(key); - if (existing != null) { - evictionListener.onEviction(key, existing, EvictionType.EXPLICIT); + Iterator> iterator = delegate.entrySet().iterator(); + while (iterator.hasNext()) { + Map.Entry entry = iterator.next(); + if (entry.getValue() != null) { + evictionListener.onEviction(entry.getKey(), entry.getValue(), EvictionType.EXPLICIT); } + iterator.remove(); } } } @@ -291,13 +294,10 @@ private boolean cleanupExpiredEntries(int maxExpiredEntries, long now) { public final void close() { synchronized (lock) { periodicCleaner.stop(); - doClose(); - delegate.clear(); + invalidateAll(); } } - protected void doClose() {} - /** Periodically cleans up the AsyncRequestCache. */ private final class PeriodicCleaner { diff --git a/rls/src/main/java/io/grpc/rls/LruCache.java b/rls/src/main/java/io/grpc/rls/LruCache.java index 6ab5c4bcb46..1ad5a958289 100644 --- a/rls/src/main/java/io/grpc/rls/LruCache.java +++ b/rls/src/main/java/io/grpc/rls/LruCache.java @@ -49,10 +49,10 @@ interface LruCache { V invalidate(K key); /** - * Invalidates cache entries for given keys. This operation will trigger {@link EvictionListener} + * Invalidates cache entries for all keys. This operation will trigger {@link EvictionListener} * with {@link EvictionType#EXPLICIT}. */ - void invalidateAll(Iterable keys); + void invalidateAll(); /** Returns {@code true} if given key is cached. */ @CheckReturnValue diff --git a/rls/src/test/java/io/grpc/rls/CachingRlsLbClientTest.java b/rls/src/test/java/io/grpc/rls/CachingRlsLbClientTest.java index c8222a02b8a..d10bc82d071 100644 --- a/rls/src/test/java/io/grpc/rls/CachingRlsLbClientTest.java +++ b/rls/src/test/java/io/grpc/rls/CachingRlsLbClientTest.java @@ -19,6 +19,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.truth.Truth.assertThat; +import static com.google.common.truth.Truth.assertWithMessage; import static io.grpc.rls.CachingRlsLbClient.RLS_DATA_KEY; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; @@ -81,6 +82,7 @@ import java.util.Collections; import java.util.HashSet; import java.util.Map; +import java.util.Set; import java.util.concurrent.ExecutionException; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledFuture; @@ -172,6 +174,9 @@ private void setUpRlsLbClient() { public void tearDown() throws Exception { rlsLbClient.close(); CachingRlsLbClient.enableOobChannelDirectPath = existingEnableOobChannelDirectPath; + assertWithMessage( + "On client shut down, RlsLoadBalancer must shut down with all its child loadbalancers.") + .that(lbProvider.loadBalancers).isEmpty(); } private CachedRouteLookupResponse getInSyncContext( @@ -462,6 +467,7 @@ public BackoffPolicy get() { * immediately fails when using the fallback target. */ private static final class TestLoadBalancerProvider extends LoadBalancerProvider { + final Set loadBalancers = new HashSet<>(); @Override public boolean isAvailable() { @@ -486,7 +492,7 @@ public ConfigOrError parseLoadBalancingPolicyConfig( @Override public LoadBalancer newLoadBalancer(final Helper helper) { - return new LoadBalancer() { + LoadBalancer loadBalancer = new LoadBalancer() { @Override public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { @@ -527,8 +533,12 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { @Override public void shutdown() { + loadBalancers.remove(this); } }; + + loadBalancers.add(loadBalancer); + return loadBalancer; } } diff --git a/rls/src/test/java/io/grpc/rls/LbPolicyConfigurationTest.java b/rls/src/test/java/io/grpc/rls/LbPolicyConfigurationTest.java index 7c6be039c24..36022cb25e6 100644 --- a/rls/src/test/java/io/grpc/rls/LbPolicyConfigurationTest.java +++ b/rls/src/test/java/io/grpc/rls/LbPolicyConfigurationTest.java @@ -18,17 +18,25 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.fail; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import io.grpc.ChannelLogger; import io.grpc.ConnectivityState; +import io.grpc.EquivalentAddressGroup; +import io.grpc.LoadBalancer; import io.grpc.LoadBalancer.Helper; +import io.grpc.LoadBalancer.ResolvedAddresses; import io.grpc.LoadBalancer.SubchannelPicker; import io.grpc.LoadBalancerProvider; import io.grpc.LoadBalancerRegistry; +import io.grpc.NameResolver.ConfigOrError; +import io.grpc.SynchronizationContext; import io.grpc.rls.ChildLoadBalancerHelper.ChildLoadBalancerHelperProvider; import io.grpc.rls.LbPolicyConfiguration.ChildLbStatusListener; import io.grpc.rls.LbPolicyConfiguration.ChildLoadBalancingPolicy; @@ -36,23 +44,58 @@ import io.grpc.rls.LbPolicyConfiguration.ChildPolicyWrapper.ChildPolicyReportingHelper; import io.grpc.rls.LbPolicyConfiguration.InvalidChildPolicyConfigException; import io.grpc.rls.LbPolicyConfiguration.RefCountedChildPolicyWrapperFactory; +import java.lang.Thread.UncaughtExceptionHandler; import java.util.Map; +import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.mockito.ArgumentMatchers; @RunWith(JUnit4.class) public class LbPolicyConfigurationTest { private final Helper helper = mock(Helper.class); + private final LoadBalancerProvider lbProvider = mock(LoadBalancerProvider.class); private final SubchannelStateManager subchannelStateManager = new SubchannelStateManagerImpl(); private final SubchannelPicker picker = mock(SubchannelPicker.class); private final ChildLbStatusListener childLbStatusListener = mock(ChildLbStatusListener.class); + private final ResolvedAddressFactory resolvedAddressFactory = + new ResolvedAddressFactory() { + @Override + public ResolvedAddresses create(Object childLbConfig) { + return ResolvedAddresses.newBuilder() + .setAddresses(ImmutableList.of()) + .build(); + } + }; private final RefCountedChildPolicyWrapperFactory factory = new RefCountedChildPolicyWrapperFactory( + new ChildLoadBalancingPolicy( + "targetFieldName", + ImmutableMap.of("foo", "bar"), + lbProvider), + resolvedAddressFactory, new ChildLoadBalancerHelperProvider(helper, subchannelStateManager, picker), childLbStatusListener); + @Before + public void setUp() { + doReturn(mock(ChannelLogger.class)).when(helper).getChannelLogger(); + doReturn( + new SynchronizationContext( + new UncaughtExceptionHandler() { + @Override + public void uncaughtException(Thread t, Throwable e) { + throw new AssertionError(e); + } + })) + .when(helper).getSynchronizationContext(); + doReturn(mock(LoadBalancer.class)).when(lbProvider).newLoadBalancer(any(Helper.class)); + doReturn(ConfigOrError.fromConfig(new Object())) + .when(lbProvider).parseLoadBalancingPolicyConfig(ArgumentMatchers.>any()); + } + @Test public void childPolicyWrapper_refCounted() { String target = "target"; diff --git a/rls/src/test/java/io/grpc/rls/LinkedHashLruCacheTest.java b/rls/src/test/java/io/grpc/rls/LinkedHashLruCacheTest.java index b7254341f64..60266e15998 100644 --- a/rls/src/test/java/io/grpc/rls/LinkedHashLruCacheTest.java +++ b/rls/src/test/java/io/grpc/rls/LinkedHashLruCacheTest.java @@ -23,7 +23,6 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; -import com.google.common.collect.ImmutableList; import io.grpc.rls.DoNotUseDirectScheduledExecutorService.FakeTimeProvider; import io.grpc.rls.LruCache.EvictionListener; import io.grpc.rls.LruCache.EvictionType; @@ -62,7 +61,8 @@ public void setUp() { 10, TimeUnit.NANOSECONDS, fakeScheduledService, - timeProvider) { + timeProvider, + new Object()) { @Override protected boolean isExpired(Integer key, Entry value, long nowNanos) { return value.expireTime <= nowNanos; @@ -210,7 +210,7 @@ public void invalidateAll() { assertThat(cache.estimatedSize()).isEqualTo(2); - cache.invalidateAll(ImmutableList.of(1, 2)); + cache.invalidateAll(); assertThat(cache.estimatedSize()).isEqualTo(0); }