Skip to content

Commit c813d55

Browse files
fix: timing of stale token refreshes on ComputeEngine (#749)
* fix: timing of stale token refreshes on ComputeEngine ComputeEngine metadata server has its own token caching mechanism that will return a cached token until the last 5 minutes of its expiration. This has a negative interaction with stale token refreshes because stale token refresh kicks in T-6mins until T-5mins. This will cause every stale refresh to return the same stale token. This PR updates the timing for ComputeEngineCredentials to start a stale refresh at T-4mins and consider the token expired at T-3 mins. The implementation is a bit noisy because it includes a change OAuth2Credentials to make the thresholds configureable and now that we targeting java8, I migrated to using java8 time data types * fmt * fix test * fix test again * remove debug code * comments
1 parent e1cbce1 commit c813d55

File tree

5 files changed

+181
-51
lines changed

5 files changed

+181
-51
lines changed

oauth2_http/java/com/google/auth/oauth2/ComputeEngineCredentials.java

+11
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
import java.io.ObjectInputStream;
5151
import java.net.SocketTimeoutException;
5252
import java.net.UnknownHostException;
53+
import java.time.Duration;
5354
import java.util.ArrayList;
5455
import java.util.Arrays;
5556
import java.util.Collection;
@@ -71,6 +72,14 @@
7172
public class ComputeEngineCredentials extends GoogleCredentials
7273
implements ServiceAccountSigner, IdTokenProvider {
7374

75+
// Decrease timing margins on GCE.
76+
// This is needed because GCE VMs maintain their own OAuth cache that expires T-5mins, attempting
77+
// to refresh a token before then, will yield the same stale token. To enable pre-emptive
78+
// refreshes, the margins must be shortened. This shouldn't cause problems since the clock skew
79+
// on the VM and metadata proxy should be non-existent.
80+
static final Duration COMPUTE_EXPIRATION_MARGIN = Duration.ofMinutes(3);
81+
static final Duration COMPUTE_REFRESH_MARGIN = Duration.ofMinutes(4);
82+
7483
private static final Logger LOGGER = Logger.getLogger(ComputeEngineCredentials.class.getName());
7584

7685
static final String DEFAULT_METADATA_SERVER_URL = "http://metadata.google.internal";
@@ -116,6 +125,8 @@ private ComputeEngineCredentials(
116125
HttpTransportFactory transportFactory,
117126
Collection<String> scopes,
118127
Collection<String> defaultScopes) {
128+
super(/* accessToken= */ null, COMPUTE_REFRESH_MARGIN, COMPUTE_EXPIRATION_MARGIN);
129+
119130
this.transportFactory =
120131
firstNonNull(
121132
transportFactory,

oauth2_http/java/com/google/auth/oauth2/GoogleCredentials.java

+11
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
import java.io.IOException;
4141
import java.io.InputStream;
4242
import java.nio.charset.StandardCharsets;
43+
import java.time.Duration;
4344
import java.util.Collection;
4445
import java.util.Collections;
4546
import java.util.HashMap;
@@ -213,6 +214,16 @@ public GoogleCredentials(AccessToken accessToken) {
213214
super(accessToken);
214215
}
215216

217+
/**
218+
* Constructor with explicit access token and refresh times
219+
*
220+
* @param accessToken initial or temporary access token
221+
*/
222+
protected GoogleCredentials(
223+
AccessToken accessToken, Duration refreshMargin, Duration expirationMargin) {
224+
super(accessToken, refreshMargin, expirationMargin);
225+
}
226+
216227
public static Builder newBuilder() {
217228
return new Builder();
218229
}

oauth2_http/java/com/google/auth/oauth2/OAuth2Credentials.java

+44-9
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,13 @@
3131

3232
package com.google.auth.oauth2;
3333

34-
import static java.util.concurrent.TimeUnit.MINUTES;
35-
3634
import com.google.api.client.util.Clock;
3735
import com.google.auth.Credentials;
3836
import com.google.auth.RequestMetadataCallback;
3937
import com.google.auth.http.AuthHttpConstants;
4038
import com.google.common.annotations.VisibleForTesting;
4139
import com.google.common.base.MoreObjects;
40+
import com.google.common.base.Preconditions;
4241
import com.google.common.collect.ImmutableList;
4342
import com.google.common.collect.ImmutableMap;
4443
import com.google.common.collect.Iterables;
@@ -51,6 +50,7 @@
5150
import java.io.ObjectInputStream;
5251
import java.io.Serializable;
5352
import java.net.URI;
53+
import java.time.Duration;
5454
import java.util.ArrayList;
5555
import java.util.Date;
5656
import java.util.List;
@@ -67,10 +67,13 @@
6767
public class OAuth2Credentials extends Credentials {
6868

6969
private static final long serialVersionUID = 4556936364828217687L;
70-
static final long MINIMUM_TOKEN_MILLISECONDS = MINUTES.toMillis(5);
71-
static final long REFRESH_MARGIN_MILLISECONDS = MINIMUM_TOKEN_MILLISECONDS + MINUTES.toMillis(1);
70+
static final Duration DEFAULT_EXPIRATION_MARGIN = Duration.ofMinutes(5);
71+
static final Duration DEFAULT_REFRESH_MARGIN = Duration.ofMinutes(6);
7272
private static final ImmutableMap<String, List<String>> EMPTY_EXTRA_HEADERS = ImmutableMap.of();
7373

74+
private final Duration expirationMargin;
75+
private final Duration refreshMargin;
76+
7477
// byte[] is serializable, so the lock variable can be final
7578
@VisibleForTesting final Object lock = new byte[0];
7679
private volatile OAuthValue value = null;
@@ -102,9 +105,20 @@ protected OAuth2Credentials() {
102105
* @param accessToken initial or temporary access token
103106
*/
104107
protected OAuth2Credentials(AccessToken accessToken) {
108+
this(accessToken, DEFAULT_REFRESH_MARGIN, DEFAULT_EXPIRATION_MARGIN);
109+
}
110+
111+
protected OAuth2Credentials(
112+
AccessToken accessToken, Duration refreshMargin, Duration expirationMargin) {
105113
if (accessToken != null) {
106114
this.value = OAuthValue.create(accessToken, EMPTY_EXTRA_HEADERS);
107115
}
116+
117+
this.refreshMargin = Preconditions.checkNotNull(refreshMargin, "refreshMargin");
118+
Preconditions.checkArgument(!refreshMargin.isNegative(), "refreshMargin can't be negative");
119+
this.expirationMargin = Preconditions.checkNotNull(expirationMargin, "expirationMargin");
120+
Preconditions.checkArgument(
121+
!expirationMargin.isNegative(), "expirationMargin can't be negative");
108122
}
109123

110124
@Override
@@ -324,13 +338,12 @@ private CacheState getState() {
324338
return CacheState.FRESH;
325339
}
326340

327-
long remainingMillis = expirationTime.getTime() - clock.currentTimeMillis();
328-
329-
if (remainingMillis <= MINIMUM_TOKEN_MILLISECONDS) {
341+
Duration remaining = Duration.ofMillis(expirationTime.getTime() - clock.currentTimeMillis());
342+
if (remaining.compareTo(expirationMargin) <= 0) {
330343
return CacheState.EXPIRED;
331344
}
332345

333-
if (remainingMillis <= REFRESH_MARGIN_MILLISECONDS) {
346+
if (remaining.compareTo(refreshMargin) <= 0) {
334347
return CacheState.STALE;
335348
}
336349

@@ -572,24 +585,46 @@ void executeIfNew(Executor executor) {
572585
public static class Builder {
573586

574587
private AccessToken accessToken;
588+
private Duration refreshMargin = DEFAULT_REFRESH_MARGIN;
589+
private Duration expirationMargin = DEFAULT_EXPIRATION_MARGIN;
575590

576591
protected Builder() {}
577592

578593
protected Builder(OAuth2Credentials credentials) {
579594
this.accessToken = credentials.getAccessToken();
595+
this.refreshMargin = credentials.refreshMargin;
596+
this.expirationMargin = credentials.expirationMargin;
580597
}
581598

582599
public Builder setAccessToken(AccessToken token) {
583600
this.accessToken = token;
584601
return this;
585602
}
586603

604+
public Builder setRefreshMargin(Duration refreshMargin) {
605+
this.refreshMargin = refreshMargin;
606+
return this;
607+
}
608+
609+
public Duration getRefreshMargin() {
610+
return refreshMargin;
611+
}
612+
613+
public Builder setExpirationMargin(Duration expirationMargin) {
614+
this.expirationMargin = expirationMargin;
615+
return this;
616+
}
617+
618+
public Duration getExpirationMargin() {
619+
return expirationMargin;
620+
}
621+
587622
public AccessToken getAccessToken() {
588623
return accessToken;
589624
}
590625

591626
public OAuth2Credentials build() {
592-
return new OAuth2Credentials(accessToken);
627+
return new OAuth2Credentials(accessToken, refreshMargin, expirationMargin);
593628
}
594629
}
595630
}

oauth2_http/javatests/com/google/auth/oauth2/MockRequestMetadataCallback.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@
3939

4040
/** Mock RequestMetadataCallback */
4141
public final class MockRequestMetadataCallback implements RequestMetadataCallback {
42-
Map<String, List<String>> metadata;
43-
Throwable exception;
42+
volatile Map<String, List<String>> metadata;
43+
volatile Throwable exception;
4444
CountDownLatch latch = new CountDownLatch(1);
4545

4646
/** Called when metadata is successfully produced. */

0 commit comments

Comments
 (0)