Clear 0-RTT on existing sessions when 0-RTT is rejected

Bug: 1066623
Change-Id: I3925e1d9e394983f6d6ecb8dc97b31222a333ca1
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/2182567
Reviewed-by: Steven Valdez <[email protected]>
Commit-Queue: David Benjamin <[email protected]>
Cr-Commit-Position: refs/heads/master@{#768858}
diff --git a/net/socket/ssl_client_socket_impl.cc b/net/socket/ssl_client_socket_impl.cc
index 568b2db..12fb375 100644
--- a/net/socket/ssl_client_socket_impl.cc
+++ b/net/socket/ssl_client_socket_impl.cc
@@ -1490,19 +1490,19 @@
 }
 
 void SSLClientSocketImpl::DoPeek() {
-  if (ssl_config_.disable_post_handshake_peek_for_testing ||
-      !completed_connect_ || peek_complete_) {
+  if (!completed_connect_) {
     return;
   }
 
   crypto::OpenSSLErrStackTracer err_tracer(FROM_HERE);
 
-  if (ssl_config_.early_data_enabled && !recorded_early_data_result_) {
+  if (ssl_config_.early_data_enabled && !handled_early_data_result_) {
     // |SSL_peek| will implicitly run |SSL_do_handshake| if needed, but run it
     // manually to pick up the reject reason.
     int rv = SSL_do_handshake(ssl_.get());
     int ssl_err = SSL_get_error(ssl_.get(), rv);
-    if (ssl_err == SSL_ERROR_WANT_READ || ssl_err == SSL_ERROR_WANT_WRITE) {
+    int err = rv > 0 ? OK : MapOpenSSLError(ssl_err, err_tracer);
+    if (err == ERR_IO_PENDING) {
       return;
     }
 
@@ -1513,13 +1513,28 @@
     UMA_HISTOGRAM_ENUMERATION("Net.SSLHandshakeEarlyDataReason",
                               SSL_get_early_data_reason(ssl_.get()),
                               ssl_early_data_reason_max_value + 1);
-    recorded_early_data_result_ = true;
-    if (ssl_err != SSL_ERROR_NONE) {
+
+    // On early data reject, clear early data on any other sessions in the
+    // cache, so retries do not get stuck attempting 0-RTT. See
+    // https://crbug.com/1066623.
+    if (err == ERR_EARLY_DATA_REJECTED ||
+        err == ERR_WRONG_VERSION_ON_EARLY_DATA) {
+      context_->ssl_client_session_cache()->ClearEarlyData(
+          GetSessionCacheKey(base::nullopt));
+    }
+
+    handled_early_data_result_ = true;
+
+    if (err != OK) {
       peek_complete_ = true;
       return;
     }
   }
 
+  if (ssl_config_.disable_post_handshake_peek_for_testing || peek_complete_) {
+    return;
+  }
+
   char byte;
   int rv = SSL_peek(ssl_.get(), &byte, 1);
   int ssl_err = SSL_get_error(ssl_.get(), rv);
diff --git a/net/socket/ssl_client_socket_impl.h b/net/socket/ssl_client_socket_impl.h
index 9ffa830..7388274 100644
--- a/net/socket/ssl_client_socket_impl.h
+++ b/net/socket/ssl_client_socket_impl.h
@@ -220,9 +220,8 @@
   int user_write_buf_len_;
   bool first_post_handshake_write_ = true;
 
-  // True if we've already recorded the result of our attempt to
-  // use early data.
-  bool recorded_early_data_result_ = false;
+  // True if we've already handled the result of our attempt to use early data.
+  bool handled_early_data_result_ = false;
 
   // Used by DoPayloadRead() when attempting to fill the caller's buffer with
   // as much data as possible without blocking.
diff --git a/net/socket/ssl_client_socket_unittest.cc b/net/socket/ssl_client_socket_unittest.cc
index 4bfc415..41aea1c 100644
--- a/net/socket/ssl_client_socket_unittest.cc
+++ b/net/socket/ssl_client_socket_unittest.cc
@@ -4979,14 +4979,13 @@
   EXPECT_EQ(SSLInfo::HANDSHAKE_RESUME, ssl_info.handshake_type);
 }
 
-TEST_F(SSLClientSocketZeroRTTTest, ZeroRTTNoZeroRTTOnResume) {
+TEST_F(SSLClientSocketZeroRTTTest, ZeroRTTReject) {
   ASSERT_TRUE(StartServer());
   ASSERT_TRUE(RunInitialConnection());
 
   SSLServerConfig server_config;
   server_config.early_data_enabled = false;
   server_config.version_max = SSL_PROTOCOL_VERSION_TLS1_3;
-
   SetServerConfig(server_config);
 
   // 0-RTT Connection
@@ -5003,6 +5002,50 @@
   EXPECT_EQ(ERR_EARLY_DATA_REJECTED, rv);
   rv = WriteAndWait(kRequest);
   EXPECT_EQ(ERR_EARLY_DATA_REJECTED, rv);
+
+  // Retrying the connection should succeed.
+  socket = MakeClient(true);
+  ASSERT_THAT(Connect(), IsOk());
+  ASSERT_THAT(MakeHTTPRequest(ssl_socket()), IsOk());
+  SSLInfo ssl_info;
+  ASSERT_TRUE(GetSSLInfo(&ssl_info));
+  EXPECT_EQ(SSLInfo::HANDSHAKE_FULL, ssl_info.handshake_type);
+}
+
+TEST_F(SSLClientSocketZeroRTTTest, ZeroRTTWrongVersion) {
+  ASSERT_TRUE(StartServer());
+  ASSERT_TRUE(RunInitialConnection());
+
+  SSLServerConfig server_config;
+  server_config.version_max = SSL_PROTOCOL_VERSION_TLS1_2;
+  SetServerConfig(server_config);
+
+  // 0-RTT Connection
+  FakeBlockingStreamSocket* socket = MakeClient(true);
+  socket->BlockReadResult();
+  ASSERT_THAT(Connect(), IsOk());
+  constexpr base::StringPiece kRequest = "GET /zerortt HTTP/1.0\r\n\r\n";
+  EXPECT_EQ(static_cast<int>(kRequest.size()), WriteAndWait(kRequest));
+  socket->UnblockReadResult();
+
+  // Expect early data to be rejected because the TLS version was incorrect.
+  scoped_refptr<IOBuffer> buf = base::MakeRefCounted<IOBuffer>(4096);
+  int rv = ReadAndWait(buf.get(), 4096);
+  EXPECT_EQ(ERR_WRONG_VERSION_ON_EARLY_DATA, rv);
+  rv = WriteAndWait(kRequest);
+  // TODO(https://crbug.com/1078515): This should be
+  // ERR_WRONG_VERSION_ON_EARLY_DATA. We assert on the current value so that,
+  // when the bug is fixed (likely in BoringSSL), we remember to fix the test to
+  // set a proper test expectation.
+  EXPECT_EQ(ERR_SSL_PROTOCOL_ERROR, rv);
+
+  // Retrying the connection should succeed.
+  socket = MakeClient(true);
+  ASSERT_THAT(Connect(), IsOk());
+  ASSERT_THAT(MakeHTTPRequest(ssl_socket()), IsOk());
+  SSLInfo ssl_info;
+  ASSERT_TRUE(GetSSLInfo(&ssl_info));
+  EXPECT_EQ(SSLInfo::HANDSHAKE_FULL, ssl_info.handshake_type);
 }
 
 // Test that the ConfirmHandshake successfully completes the handshake and that
diff --git a/net/ssl/ssl_client_session_cache.cc b/net/ssl/ssl_client_session_cache.cc
index 7edbe7b..00bd9176 100644
--- a/net/ssl/ssl_client_session_cache.cc
+++ b/net/ssl/ssl_client_session_cache.cc
@@ -93,6 +93,17 @@
   iter->second.Push(std::move(session));
 }
 
+void SSLClientSessionCache::ClearEarlyData(const Key& cache_key) {
+  auto iter = cache_.Get(cache_key);
+  if (iter != cache_.end()) {
+    for (auto& session : iter->second.sessions) {
+      if (session) {
+        session.reset(SSL_SESSION_copy_without_early_data(session.get()));
+      }
+    }
+  }
+}
+
 void SSLClientSessionCache::FlushForServer(const HostPortPair& server) {
   auto iter = cache_.begin();
   while (iter != cache_.end()) {
diff --git a/net/ssl/ssl_client_session_cache.h b/net/ssl/ssl_client_session_cache.h
index f6ccc61..2994c0c8 100644
--- a/net/ssl/ssl_client_session_cache.h
+++ b/net/ssl/ssl_client_session_cache.h
@@ -78,6 +78,11 @@
   // checked for stale entries.
   void Insert(const Key& cache_key, bssl::UniquePtr<SSL_SESSION> session);
 
+  // Clears early data support for all current sessions associated with
+  // |cache_key|. This may be used after a 0-RTT reject to avoid unnecessarily
+  // offering 0-RTT data on retries. See https://crbug.com/1066623.
+  void ClearEarlyData(const Key& cache_key);
+
   // Removes all entries associated with |server|.
   void FlushForServer(const HostPortPair& server);
 
diff --git a/net/ssl/ssl_config.h b/net/ssl/ssl_config.h
index 9855190..40b61396 100644
--- a/net/ssl/ssl_config.h
+++ b/net/ssl/ssl_config.h
@@ -142,8 +142,8 @@
   PrivacyMode privacy_mode = PRIVACY_MODE_DISABLED;
 
   // True if the post-handshake peeking of the transport should be skipped. This
-  // logic ensures 0-RTT and tickets are resolved early, but can interfere with
-  // some unit tests.
+  // logic ensures tickets are resolved early, but can interfere with some unit
+  // tests.
   bool disable_post_handshake_peek_for_testing = false;
 };