[websocket] Use HttpRequestHeaders, not string, to represent headers

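net::WebSocketChannel and various related classes use std::string to
represent additional request headers. This CL changes them to use
net::HttpRequestHeaders instead.

WebSocketStream now starts from the caller-supplied headers and sets the
WebSocket handshake headers on top of them, so the handshake headers it
sets take precedence (previously the additional headers were applied
last). Any Sec-WebSocket-Extensions, Sec-WebSocket-Key or
Sec-WebSocket-Protocol headers supplied by the caller are removed.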

Bug: 721400
Cq-Include-Trybots: master.tryserver.chromium.linux:linux_mojo
Change-Id: Id730779b36f3a319a61b44516bd3e8389ebdfc23
Reviewed-on: https://chromium-review.googlesource.com/1065713
Reviewed-by: Kinuko Yasuda <[email protected]>
Reviewed-by: Adam Rice <[email protected]>
Commit-Queue: Yutaka Hirano <[email protected]>
Cr-Commit-Position: refs/heads/master@{#560888}
diff --git a/net/websockets/websocket_channel.cc b/net/websockets/websocket_channel.cc
index b5c0639..0b05c050 100644
--- a/net/websockets/websocket_channel.cc
+++ b/net/websockets/websocket_channel.cc
@@ -365,7 +365,7 @@
     const std::vector<std::string>& requested_subprotocols,
     const url::Origin& origin,
     const GURL& site_for_cookies,
-    const std::string& additional_headers) {
+    const HttpRequestHeaders& additional_headers) {
   SendAddChannelRequestWithSuppliedCallback(
       socket_url, requested_subprotocols, origin, site_for_cookies,
       additional_headers, base::Bind(&WebSocketStream::CreateAndConnectStream));
@@ -575,7 +575,7 @@
     const std::vector<std::string>& requested_subprotocols,
     const url::Origin& origin,
     const GURL& site_for_cookies,
-    const std::string& additional_headers,
+    const HttpRequestHeaders& additional_headers,
     const WebSocketStreamRequestCreationCallback& callback) {
   SendAddChannelRequestWithSuppliedCallback(socket_url, requested_subprotocols,
                                             origin, site_for_cookies,
@@ -597,7 +597,7 @@
     const std::vector<std::string>& requested_subprotocols,
     const url::Origin& origin,
     const GURL& site_for_cookies,
-    const std::string& additional_headers,
+    const HttpRequestHeaders& additional_headers,
     const WebSocketStreamRequestCreationCallback& callback) {
   DCHECK_EQ(FRESHLY_CONSTRUCTED, state_);
   if (!socket_url.SchemeIsWSOrWSS()) {
diff --git a/net/websockets/websocket_channel.h b/net/websockets/websocket_channel.h
index 470a77753..666a252 100644
--- a/net/websockets/websocket_channel.h
+++ b/net/websockets/websocket_channel.h
@@ -31,8 +31,9 @@
 
 namespace net {
 
-class NetLogWithSource;
+class HttpRequestHeaders;
 class IOBuffer;
+class NetLogWithSource;
 class URLRequest;
 class URLRequestContext;
 struct WebSocketHandshakeRequestInfo;
@@ -53,7 +54,7 @@
       std::unique_ptr<WebSocketHandshakeStreamCreateHelper>,
       const url::Origin&,
       const GURL&,
-      const std::string&,
+      const HttpRequestHeaders&,
       URLRequestContext*,
       const NetLogWithSource&,
       std::unique_ptr<WebSocketStream::ConnectDelegate>)>
@@ -77,7 +78,7 @@
       const std::vector<std::string>& requested_protocols,
       const url::Origin& origin,
       const GURL& site_for_cookies,
-      const std::string& additional_headers);
+      const HttpRequestHeaders& additional_headers);
 
   // Sends a data frame to the remote side. It is the responsibility of the
   // caller to ensure that they have sufficient send quota to send this data,
@@ -129,7 +130,7 @@
       const std::vector<std::string>& requested_protocols,
       const url::Origin& origin,
       const GURL& site_for_cookies,
-      const std::string& additional_headers,
+      const HttpRequestHeaders& additional_headers,
       const WebSocketStreamRequestCreationCallback& callback);
 
   // The default timout for the closing handshake is a sensible value (see
@@ -188,7 +189,7 @@
       const std::vector<std::string>& requested_protocols,
       const url::Origin& origin,
       const GURL& site_for_cookies,
-      const std::string& additional_headers,
+      const HttpRequestHeaders& additional_headers,
       const WebSocketStreamRequestCreationCallback& callback);
 
   // Called when a URLRequest is created for handshaking.
diff --git a/net/websockets/websocket_channel_test.cc b/net/websockets/websocket_channel_test.cc
index 38a31ed..186d99a 100644
--- a/net/websockets/websocket_channel_test.cc
+++ b/net/websockets/websocket_channel_test.cc
@@ -27,6 +27,7 @@
 #include "base/threading/thread_task_runner_handle.h"
 #include "net/base/net_errors.h"
 #include "net/base/test_completion_callback.h"
+#include "net/http/http_request_headers.h"
 #include "net/http/http_response_headers.h"
 #include "net/log/net_log_with_source.h"
 #include "net/test/test_with_scoped_task_environment.h"
@@ -696,7 +697,7 @@
       std::unique_ptr<WebSocketHandshakeStreamCreateHelper> create_helper,
       const url::Origin& origin,
       const GURL& site_for_cookies,
-      const std::string& additional_headers,
+      const HttpRequestHeaders& additional_headers,
       URLRequestContext* url_request_context,
       const NetLogWithSource& net_log,
       std::unique_ptr<WebSocketStream::ConnectDelegate> connect_delegate) {
@@ -751,7 +752,8 @@
         CreateEventInterface(), &connect_data_.url_request_context);
     channel_->SendAddChannelRequestForTesting(
         connect_data_.socket_url, connect_data_.requested_subprotocols,
-        connect_data_.origin, connect_data_.site_for_cookies, "",
+        connect_data_.origin, connect_data_.site_for_cookies,
+        HttpRequestHeaders(),
         base::Bind(&WebSocketStreamCreationCallbackArgumentSaver::Create,
                    base::Unretained(&connect_data_.argument_saver)));
   }
diff --git a/net/websockets/websocket_end_to_end_test.cc b/net/websockets/websocket_end_to_end_test.cc
index 3c96f2b..df9a8ba 100644
--- a/net/websockets/websocket_end_to_end_test.cc
+++ b/net/websockets/websocket_end_to_end_test.cc
@@ -27,6 +27,7 @@
 #include "build/build_config.h"
 #include "net/base/auth.h"
 #include "net/base/proxy_delegate.h"
+#include "net/http/http_request_headers.h"
 #include "net/proxy_resolution/proxy_resolution_service.h"
 #include "net/test/embedded_test_server/embedded_test_server.h"
 #include "net/test/spawned_test_server/spawned_test_server.h"
@@ -246,7 +247,7 @@
     channel_ = std::make_unique<WebSocketChannel>(
         base::WrapUnique(event_interface_), &context_);
     channel_->SendAddChannelRequest(GURL(socket_url), sub_protocols_, origin,
-                                    site_for_cookies, "");
+                                    site_for_cookies, HttpRequestHeaders());
     event_interface_->WaitForResponse();
     return !event_interface_->failed();
   }
diff --git a/net/websockets/websocket_stream.cc b/net/websockets/websocket_stream.cc
index cc5ab59..4dd39c3 100644
--- a/net/websockets/websocket_stream.cc
+++ b/net/websockets/websocket_stream.cc
@@ -105,7 +105,7 @@
       const URLRequestContext* context,
       const url::Origin& origin,
       const GURL& site_for_cookies,
-      const std::string& additional_headers,
+      const HttpRequestHeaders& additional_headers,
       std::unique_ptr<WebSocketStream::ConnectDelegate> connect_delegate,
       std::unique_ptr<WebSocketHandshakeStreamCreateHelper> create_helper)
       : delegate_(std::make_unique<Delegate>(this)),
@@ -116,7 +116,7 @@
         connect_delegate_(std::move(connect_delegate)),
         handshake_stream_(nullptr) {
     create_helper->set_stream_request(this);
-    HttpRequestHeaders headers;
+    HttpRequestHeaders headers = additional_headers;
     headers.SetHeader(websockets::kUpgrade, websockets::kWebSocketLowercase);
     if (base::FeatureList::IsEnabled(WebSocketBasicHandshakeStream::
                                          kWebSocketHandshakeReuseConnection)) {
@@ -134,7 +134,11 @@
     headers.SetHeader(websockets::kSecWebSocketVersion,
                       websockets::kSupportedVersion);
 
-    headers.AddHeadersFromString(additional_headers);
+    // Remove HTTP headers that are important to websocket connections: they
+    // will be added later.
+    headers.RemoveHeader(websockets::kSecWebSocketExtensions);
+    headers.RemoveHeader(websockets::kSecWebSocketKey);
+    headers.RemoveHeader(websockets::kSecWebSocketProtocol);
 
     url_request_->SetExtraRequestHeaders(headers);
     url_request_->set_initiator(origin);
@@ -400,7 +404,7 @@
     std::unique_ptr<WebSocketHandshakeStreamCreateHelper> create_helper,
     const url::Origin& origin,
     const GURL& site_for_cookies,
-    const std::string& additional_headers,
+    const HttpRequestHeaders& additional_headers,
     URLRequestContext* url_request_context,
     const NetLogWithSource& net_log,
     std::unique_ptr<ConnectDelegate> connect_delegate) {
@@ -418,7 +422,7 @@
     std::unique_ptr<WebSocketHandshakeStreamCreateHelper> create_helper,
     const url::Origin& origin,
     const GURL& site_for_cookies,
-    const std::string& additional_headers,
+    const HttpRequestHeaders& additional_headers,
     URLRequestContext* url_request_context,
     const NetLogWithSource& net_log,
     std::unique_ptr<WebSocketStream::ConnectDelegate> connect_delegate,
diff --git a/net/websockets/websocket_stream.h b/net/websockets/websocket_stream.h
index d315e5ae..fd9e1c0 100644
--- a/net/websockets/websocket_stream.h
+++ b/net/websockets/websocket_stream.h
@@ -31,6 +31,7 @@
 
 namespace net {
 
+class HttpRequestHeaders;
 class NetLogWithSource;
 class URLRequest;
 class URLRequestContext;
@@ -123,7 +124,7 @@
       std::unique_ptr<WebSocketHandshakeStreamCreateHelper> create_helper,
       const url::Origin& origin,
       const GURL& site_for_cookies,
-      const std::string& additional_headers,
+      const HttpRequestHeaders& additional_headers,
       URLRequestContext* url_request_context,
       const NetLogWithSource& net_log,
       std::unique_ptr<ConnectDelegate> connect_delegate);
@@ -136,7 +137,7 @@
       std::unique_ptr<WebSocketHandshakeStreamCreateHelper> create_helper,
       const url::Origin& origin,
       const GURL& site_for_cookies,
-      const std::string& additional_headers,
+      const HttpRequestHeaders& additional_headers,
       URLRequestContext* url_request_context,
       const NetLogWithSource& net_log,
       std::unique_ptr<ConnectDelegate> connect_delegate,
diff --git a/net/websockets/websocket_stream_cookie_test.cc b/net/websockets/websocket_stream_cookie_test.cc
index 9ba8f0c1..d16b3db 100644
--- a/net/websockets/websocket_stream_cookie_test.cc
+++ b/net/websockets/websocket_stream_cookie_test.cc
@@ -14,6 +14,7 @@
 #include "net/cookies/canonical_cookie.h"
 #include "net/cookies/canonical_cookie_test_helpers.h"
 #include "net/cookies/cookie_store.h"
+#include "net/http/http_request_headers.h"
 #include "net/socket/socket_test_util.h"
 #include "net/websockets/websocket_stream_create_test_base.h"
 #include "net/websockets/websocket_test_util.h"
@@ -48,8 +49,8 @@
                                             cookie_header, std::string(),
                                             std::string()),
         response_body);
-    CreateAndConnectStream(url, NoSubProtocols(), origin, site_for_cookies, "",
-                           nullptr);
+    CreateAndConnectStream(url, NoSubProtocols(), origin, site_for_cookies,
+                           HttpRequestHeaders(), nullptr);
   }
 
   std::string AddCRLFIfNotEmpty(const std::string& s) {
diff --git a/net/websockets/websocket_stream_create_test_base.cc b/net/websockets/websocket_stream_create_test_base.cc
index 7098374..11311b98 100644
--- a/net/websockets/websocket_stream_create_test_base.cc
+++ b/net/websockets/websocket_stream_create_test_base.cc
@@ -85,7 +85,7 @@
     const std::vector<std::string>& sub_protocols,
     const url::Origin& origin,
     const GURL& site_for_cookies,
-    const std::string& additional_headers,
+    const HttpRequestHeaders& additional_headers,
     std::unique_ptr<base::Timer> timer) {
   auto connect_delegate = std::make_unique<TestConnectDelegate>(
       this, connect_run_loop_.QuitClosure());
diff --git a/net/websockets/websocket_stream_create_test_base.h b/net/websockets/websocket_stream_create_test_base.h
index f0ac6fa..30a2f4ae 100644
--- a/net/websockets/websocket_stream_create_test_base.h
+++ b/net/websockets/websocket_stream_create_test_base.h
@@ -44,7 +44,7 @@
                               const std::vector<std::string>& sub_protocols,
                               const url::Origin& origin,
                               const GURL& site_for_cookies,
-                              const std::string& additional_headers,
+                              const HttpRequestHeaders& additional_headers,
                               std::unique_ptr<base::Timer> timer);
 
   static std::vector<HeaderKeyValuePair> RequestHeadersToVector(
diff --git a/net/websockets/websocket_stream_test.cc b/net/websockets/websocket_stream_test.cc
index b54badb..99dc9815 100644
--- a/net/websockets/websocket_stream_test.cc
+++ b/net/websockets/websocket_stream_test.cc
@@ -151,10 +151,11 @@
           WebSocketStandardResponse(
               WebSocketExtraHeadersToString(extra_response_headers)) +
               additional_data_);
-      CreateAndConnectStream(
-          socket_url, sub_protocols, Origin(), SiteForCookies(),
-          WebSocketExtraHeadersToString(send_additional_request_headers),
-          std::move(timer_));
+      CreateAndConnectStream(socket_url, sub_protocols, Origin(),
+                             SiteForCookies(),
+                             WebSocketExtraHeadersToHttpRequestHeaders(
+                                 send_additional_request_headers),
+                             std::move(timer_));
       return;
     }
 
@@ -283,10 +284,11 @@
     base::RunLoop().Run();
     EXPECT_FALSE(request->is_pending());
 
-    CreateAndConnectStream(
-        socket_url, sub_protocols, Origin(), SiteForCookies(),
-        WebSocketExtraHeadersToString(send_additional_request_headers),
-        std::move(timer_));
+    CreateAndConnectStream(socket_url, sub_protocols, Origin(),
+                           SiteForCookies(),
+                           WebSocketExtraHeadersToHttpRequestHeaders(
+                               send_additional_request_headers),
+                           std::move(timer_));
   }
 
   // Like CreateAndConnectStandard(), but allow for arbitrary response body.
@@ -309,10 +311,11 @@
             WebSocketExtraHeadersToString(send_additional_request_headers),
             WebSocketExtraHeadersToString(extra_request_headers)),
         response_body);
-    CreateAndConnectStream(
-        socket_url, sub_protocols, Origin(), SiteForCookies(),
-        WebSocketExtraHeadersToString(send_additional_request_headers),
-        nullptr);
+    CreateAndConnectStream(socket_url, sub_protocols, Origin(),
+                           SiteForCookies(),
+                           WebSocketExtraHeadersToHttpRequestHeaders(
+                               send_additional_request_headers),
+                           nullptr);
   }
 
   // Like CreateAndConnectStandard(), but take extra response headers as a
@@ -332,20 +335,20 @@
         WebSocketStandardRequest(socket_path, socket_host, Origin(), "", ""),
         WebSocketStandardResponse(extra_response_headers));
     CreateAndConnectStream(socket_url, sub_protocols, Origin(),
-                           SiteForCookies(), "", nullptr);
+                           SiteForCookies(), HttpRequestHeaders(), nullptr);
   }
 
   // Like CreateAndConnectStandard(), but take raw mock data.
   void CreateAndConnectRawExpectations(
       base::StringPiece url,
       const std::vector<std::string>& sub_protocols,
-      const std::string& send_additional_request_headers,
+      const HttpRequestHeaders& send_additional_request_headers,
       std::unique_ptr<SequencedSocketData> socket_data) {
     ASSERT_EQ(BASIC_HANDSHAKE_STREAM, stream_type_);

     url_request_context_host_.AddRawExpectations(std::move(socket_data));
     CreateAndConnectStream(GURL(url), sub_protocols, Origin(), SiteForCookies(),
                            send_additional_request_headers, std::move(timer_));
   }
 
  private:
@@ -487,7 +490,7 @@
     const std::string request =
         base::StringPrintf(request2format, base64_user_pass.c_str());
     CreateAndConnectRawExpectations(
-        url, NoSubProtocols(), "",
+        url, NoSubProtocols(), HttpRequestHeaders(),
         helper_.BuildSocketData2(request, response2));
   }
 
@@ -661,8 +664,8 @@
 
   std::vector<HeaderKeyValuePair> request_headers =
       RequestHeadersToVector(request_info_->headers);
-  EXPECT_EQ(HeaderKeyValuePair("User-Agent", "OveRrIde"), request_headers[7]);
-  EXPECT_EQ(HeaderKeyValuePair("rAnDomHeader", "foobar"), request_headers[8]);
+  EXPECT_EQ(HeaderKeyValuePair("User-Agent", "OveRrIde"), request_headers[4]);
+  EXPECT_EQ(HeaderKeyValuePair("rAnDomHeader", "foobar"), request_headers[5]);
 }
 
 // Confirm that the stream isn't established until the message loop runs.
@@ -1193,8 +1196,8 @@
   std::unique_ptr<SequencedSocketData> socket_data(BuildNullSocketData());
   socket_data->set_connect_data(
       MockConnect(SYNCHRONOUS, ERR_CONNECTION_REFUSED));
-  CreateAndConnectRawExpectations("ws://www.example.org/", NoSubProtocols(), "",
-                                  std::move(socket_data));
+  CreateAndConnectRawExpectations("ws://www.example.org/", NoSubProtocols(),
+                                  HttpRequestHeaders(), std::move(socket_data));
   WaitUntilConnectDone();
   EXPECT_TRUE(has_failed());
   EXPECT_EQ("Error in connection establishment: net::ERR_CONNECTION_REFUSED",
@@ -1208,8 +1211,8 @@
   std::unique_ptr<SequencedSocketData> socket_data(BuildNullSocketData());
   socket_data->set_connect_data(
       MockConnect(ASYNC, ERR_CONNECTION_TIMED_OUT));
-  CreateAndConnectRawExpectations("ws://www.example.org/", NoSubProtocols(), "",
-                                  std::move(socket_data));
+  CreateAndConnectRawExpectations("ws://www.example.org/", NoSubProtocols(),
+                                  HttpRequestHeaders(), std::move(socket_data));
   WaitUntilConnectDone();
   EXPECT_TRUE(has_failed());
   EXPECT_EQ("Error in connection establishment: net::ERR_CONNECTION_TIMED_OUT",
@@ -1223,8 +1226,8 @@
   auto timer = std::make_unique<MockWeakTimer>(false, false);
   base::WeakPtr<MockWeakTimer> weak_timer = timer->AsWeakPtr();
   SetTimer(std::move(timer));
-  CreateAndConnectRawExpectations("ws://www.example.org/", NoSubProtocols(), "",
-                                  std::move(socket_data));
+  CreateAndConnectRawExpectations("ws://www.example.org/", NoSubProtocols(),
+                                  HttpRequestHeaders(), std::move(socket_data));
   EXPECT_FALSE(has_failed());
   ASSERT_TRUE(weak_timer.get());
   EXPECT_TRUE(weak_timer->IsRunning());
@@ -1264,8 +1267,8 @@
   auto timer = std::make_unique<MockWeakTimer>(false, false);
   base::WeakPtr<MockWeakTimer> weak_timer = timer->AsWeakPtr();
   SetTimer(std::move(timer));
-  CreateAndConnectRawExpectations("ws://www.example.org/", NoSubProtocols(), "",
-                                  std::move(socket_data));
+  CreateAndConnectRawExpectations("ws://www.example.org/", NoSubProtocols(),
+                                  HttpRequestHeaders(), std::move(socket_data));
   ASSERT_TRUE(weak_timer.get());
   EXPECT_TRUE(weak_timer->IsRunning());
 
@@ -1281,8 +1284,8 @@
 TEST_P(WebSocketStreamCreateTest, CancellationDuringConnect) {
   std::unique_ptr<SequencedSocketData> socket_data(BuildNullSocketData());
   socket_data->set_connect_data(MockConnect(SYNCHRONOUS, ERR_IO_PENDING));
-  CreateAndConnectRawExpectations("ws://www.example.org/", NoSubProtocols(), "",
-                                  std::move(socket_data));
+  CreateAndConnectRawExpectations("ws://www.example.org/", NoSubProtocols(),
+                                  HttpRequestHeaders(), std::move(socket_data));
   stream_request_.reset();
   // WaitUntilConnectDone doesn't work in this case.
   base::RunLoop().RunUntilIdle();
@@ -1297,7 +1300,8 @@
   SequencedSocketData* socket_data(
       new SequencedSocketData(base::span<MockRead>(), writes));
   socket_data->set_connect_data(MockConnect(SYNCHRONOUS, OK));
-  CreateAndConnectRawExpectations("ws://www.example.org/", NoSubProtocols(), "",
+  CreateAndConnectRawExpectations("ws://www.example.org/", NoSubProtocols(),
+                                  HttpRequestHeaders(),
                                   base::WrapUnique(socket_data));
   base::RunLoop().RunUntilIdle();
   EXPECT_TRUE(socket_data->AllWriteDataConsumed());
@@ -1321,8 +1325,8 @@
   std::unique_ptr<SequencedSocketData> socket_data(
       BuildSocketData(reads, writes));
   SequencedSocketData* socket_data_raw_ptr = socket_data.get();
-  CreateAndConnectRawExpectations("ws://www.example.org/", NoSubProtocols(), "",
-                                  std::move(socket_data));
+  CreateAndConnectRawExpectations("ws://www.example.org/", NoSubProtocols(),
+                                  HttpRequestHeaders(), std::move(socket_data));
   base::RunLoop().RunUntilIdle();
   EXPECT_TRUE(socket_data_raw_ptr->AllReadDataConsumed());
   stream_request_.reset();
@@ -1374,8 +1378,8 @@
   std::unique_ptr<SequencedSocketData> socket_data(
       BuildSocketData(reads, writes));
   SequencedSocketData* socket_data_raw_ptr = socket_data.get();
-  CreateAndConnectRawExpectations("ws://www.example.org/", NoSubProtocols(), "",
-                                  std::move(socket_data));
+  CreateAndConnectRawExpectations("ws://www.example.org/", NoSubProtocols(),
+                                  HttpRequestHeaders(), std::move(socket_data));
   base::RunLoop().RunUntilIdle();
   EXPECT_TRUE(socket_data_raw_ptr->AllReadDataConsumed());
   EXPECT_TRUE(has_failed());
@@ -1404,7 +1408,8 @@
       std::move(ssl_socket_data));
   std::unique_ptr<SequencedSocketData> raw_socket_data(BuildNullSocketData());
   CreateAndConnectRawExpectations("wss://www.example.org/", NoSubProtocols(),
-                                  "", std::move(raw_socket_data));
+                                  HttpRequestHeaders(),
+                                  std::move(raw_socket_data));
   // WaitUntilConnectDone doesn't work in this case.
   base::RunLoop().RunUntilIdle();
   EXPECT_FALSE(has_failed());
@@ -1475,7 +1480,8 @@
       helper_.BuildSocketData1(kUnauthorizedResponse));
 
   CreateAndConnectRawExpectations(
-      "ws://FooBar:[email protected]/", NoSubProtocols(), "",
+      "ws://FooBar:[email protected]/", NoSubProtocols(),
+      HttpRequestHeaders(),
       helper_.BuildSocketData2(kAuthorizedRequest,
                                WebSocketStandardResponse(std::string())));
   WaitUntilConnectDone();
@@ -1495,7 +1501,8 @@
     MockRead reads[] = {MockRead(ASYNC, ERR_IO_PENDING, 0)};
     MockWrite writes[] = {MockWrite(ASYNC, 1, request.c_str())};
     CreateAndConnectRawExpectations("wss://www.example.org/", NoSubProtocols(),
-                                    "", BuildSocketData(reads, writes));
+                                    HttpRequestHeaders(),
+                                    BuildSocketData(reads, writes));
     base::RunLoop().RunUntilIdle();
     stream_request_.reset();
 
@@ -1571,8 +1578,8 @@
   std::unique_ptr<SequencedSocketData> socket_data(
       BuildSocketData(reads, writes));
   socket_data->set_connect_data(MockConnect(SYNCHRONOUS, OK));
-  CreateAndConnectRawExpectations("ws://www.example.org/", NoSubProtocols(), "",
-                                  std::move(socket_data));
+  CreateAndConnectRawExpectations("ws://www.example.org/", NoSubProtocols(),
+                                  HttpRequestHeaders(), std::move(socket_data));
   WaitUntilConnectDone();
   EXPECT_TRUE(has_failed());
 
@@ -1606,8 +1613,8 @@
   std::unique_ptr<SequencedSocketData> socket_data(
       BuildSocketData(reads, writes));
   url_request_context_host_.SetProxyConfig("https=proxy:8000");
-  CreateAndConnectRawExpectations("ws://www.example.org/", NoSubProtocols(), "",
-                                  std::move(socket_data));
+  CreateAndConnectRawExpectations("ws://www.example.org/", NoSubProtocols(),
+                                  HttpRequestHeaders(), std::move(socket_data));
   WaitUntilConnectDone();
   EXPECT_TRUE(has_failed());
   EXPECT_EQ("Establishing a tunnel via proxy server failed.",
diff --git a/net/websockets/websocket_test_util.cc b/net/websockets/websocket_test_util.cc
index 247b0f8..17a5931 100644
--- a/net/websockets/websocket_test_util.cc
+++ b/net/websockets/websocket_test_util.cc
@@ -48,6 +48,14 @@
   return answer;
 }
 
+HttpRequestHeaders WebSocketExtraHeadersToHttpRequestHeaders(
+    const WebSocketExtraHeaders& headers) {
+  HttpRequestHeaders converted;
+  for (const auto& entry : headers)  // A repeated name keeps its last value.
+    converted.SetHeader(entry.first, entry.second);
+  return converted;
+}
+
 std::string WebSocketStandardRequest(
     const std::string& path,
     const std::string& host,
@@ -77,11 +85,12 @@
   headers.SetHeader("Connection", "Upgrade");
   headers.SetHeader("Pragma", "no-cache");
   headers.SetHeader("Cache-Control", "no-cache");
+  headers.AddHeadersFromString(send_additional_request_headers);
   headers.SetHeader("Upgrade", "websocket");
   headers.SetHeader("Origin", origin.Serialize());
   headers.SetHeader("Sec-WebSocket-Version", "13");
-  headers.SetHeader("User-Agent", "");
-  headers.AddHeadersFromString(send_additional_request_headers);
+  if (!headers.HasHeader("User-Agent"))
+    headers.SetHeader("User-Agent", "");
   headers.SetHeader("Accept-Encoding", "gzip, deflate");
   headers.SetHeader("Accept-Language", "en-us,fr");
   headers.AddHeadersFromString(cookies);
diff --git a/net/websockets/websocket_test_util.h b/net/websockets/websocket_test_util.h
index c9e5d9bb..ba5de5fd 100644
--- a/net/websockets/websocket_test_util.h
+++ b/net/websockets/websocket_test_util.h
@@ -14,6 +14,7 @@
 
 #include "base/macros.h"
 #include "net/http/http_basic_state.h"
+#include "net/http/http_request_headers.h"
 #include "net/http/http_stream_parser.h"
 #include "net/socket/client_socket_handle.h"
 #include "net/third_party/spdy/core/spdy_header_block.h"
@@ -48,6 +49,10 @@
 // Converts a vector of header key-value pairs into a single string.
 std::string WebSocketExtraHeadersToString(const WebSocketExtraHeaders& headers);
 
+// Converts a vector of header key-value pairs into an HttpRequestHeaders.
+HttpRequestHeaders WebSocketExtraHeadersToHttpRequestHeaders(
+    const WebSocketExtraHeaders& headers);
+
 // Generates a standard WebSocket handshake request. The challenge key used is
 // "dGhlIHNhbXBsZSBub25jZQ==". Each header in |extra_headers| must be terminated
 // with "\r\n".