blob: a90a896afffa40fc1824d0cd4906bad50e789d68 [file] [log] [blame]
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
4
5#include "net/websockets/websocket_handshake_stream_base.h"
6
7#include <unordered_set>
8
9#include "base/strings/string_util.h"
10#include "net/http/http_request_headers.h"
11#include "net/http/http_response_headers.h"
12#include "net/websockets/websocket_extension.h"
13#include "net/websockets/websocket_extension_parser.h"
14#include "net/websockets/websocket_handshake_constants.h"
15
16namespace net {
17
18// static
19std::string WebSocketHandshakeStreamBase::MultipleHeaderValuesMessage(
20 const std::string& header_name) {
21 return std::string("'") + header_name +
22 "' header must not appear more than once in a response";
23}
24
25// static
26void WebSocketHandshakeStreamBase::AddVectorHeaderIfNonEmpty(
27 const char* name,
28 const std::vector<std::string>& value,
29 HttpRequestHeaders* headers) {
30 if (value.empty())
31 return;
32 headers->SetHeader(name, base::JoinString(value, ", "));
33}
34
35// static
36bool WebSocketHandshakeStreamBase::ValidateSubProtocol(
37 const HttpResponseHeaders* headers,
38 const std::vector<std::string>& requested_sub_protocols,
39 std::string* sub_protocol,
40 std::string* failure_message) {
41 size_t iter = 0;
42 std::string value;
43 std::unordered_set<std::string> requested_set(requested_sub_protocols.begin(),
44 requested_sub_protocols.end());
45 int count = 0;
46 bool has_multiple_protocols = false;
47 bool has_invalid_protocol = false;
48
49 while (!has_invalid_protocol || !has_multiple_protocols) {
50 std::string temp_value;
51 if (!headers->EnumerateHeader(&iter, websockets::kSecWebSocketProtocol,
52 &temp_value))
53 break;
54 value = temp_value;
55 if (requested_set.count(value) == 0)
56 has_invalid_protocol = true;
57 if (++count > 1)
58 has_multiple_protocols = true;
59 }
60
61 if (has_multiple_protocols) {
62 *failure_message =
63 MultipleHeaderValuesMessage(websockets::kSecWebSocketProtocol);
64 return false;
65 } else if (count > 0 && requested_sub_protocols.size() == 0) {
66 *failure_message = std::string(
67 "Response must not include 'Sec-WebSocket-Protocol' "
68 "header if not present in request: ") +
69 value;
70 return false;
71 } else if (has_invalid_protocol) {
72 *failure_message = "'Sec-WebSocket-Protocol' header value '" + value +
73 "' in response does not match any of sent values";
74 return false;
75 } else if (requested_sub_protocols.size() > 0 && count == 0) {
76 *failure_message =
77 "Sent non-empty 'Sec-WebSocket-Protocol' header "
78 "but no response was received";
79 return false;
80 }
81 *sub_protocol = value;
82 return true;
83}
84
85// static
86bool WebSocketHandshakeStreamBase::ValidateExtensions(
87 const HttpResponseHeaders* headers,
88 std::string* accepted_extensions_descriptor,
89 std::string* failure_message,
90 WebSocketExtensionParams* params) {
91 size_t iter = 0;
92 std::string header_value;
93 std::vector<std::string> header_values;
94 // TODO(ricea): If adding support for additional extensions, generalise this
95 // code.
96 bool seen_permessage_deflate = false;
97 while (headers->EnumerateHeader(&iter, websockets::kSecWebSocketExtensions,
98 &header_value)) {
99 WebSocketExtensionParser parser;
100 if (!parser.Parse(header_value)) {
101 // TODO(yhirano) Set appropriate failure message.
102 *failure_message =
103 "'Sec-WebSocket-Extensions' header value is "
104 "rejected by the parser: " +
105 header_value;
106 return false;
107 }
108
109 const std::vector<WebSocketExtension>& extensions = parser.extensions();
110 for (const auto& extension : extensions) {
111 if (extension.name() == "permessage-deflate") {
112 if (seen_permessage_deflate) {
113 *failure_message = "Received duplicate permessage-deflate response";
114 return false;
115 }
116 seen_permessage_deflate = true;
117 auto& deflate_parameters = params->deflate_parameters;
118 if (!deflate_parameters.Initialize(extension, failure_message) ||
119 !deflate_parameters.IsValidAsResponse(failure_message)) {
120 *failure_message = "Error in permessage-deflate: " + *failure_message;
121 return false;
122 }
123 // Note that we don't have to check the request-response compatibility
124 // here because we send a request compatible with any valid responses.
125 // TODO(yhirano): Place a DCHECK here.
126
127 header_values.push_back(header_value);
128 } else {
129 *failure_message = "Found an unsupported extension '" +
130 extension.name() +
131 "' in 'Sec-WebSocket-Extensions' header";
132 return false;
133 }
134 }
135 }
136 *accepted_extensions_descriptor = base::JoinString(header_values, ", ");
137 params->deflate_enabled = seen_permessage_deflate;
138 return true;
139}
140
141} // namespace net