| // Copyright 2013 The Chromium Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include "content/browser/websockets/websocket_impl.h" |
| |
| #include <inttypes.h> |
| |
| #include <utility> |
| |
| #include "base/bind.h" |
| #include "base/bind_helpers.h" |
| #include "base/location.h" |
| #include "base/logging.h" |
| #include "base/macros.h" |
| #include "base/memory/ptr_util.h" |
| #include "base/single_thread_task_runner.h" |
| #include "base/strings/string_util.h" |
| #include "base/strings/stringprintf.h" |
| #include "base/threading/thread_task_runner_handle.h" |
| #include "content/browser/bad_message.h" |
| #include "content/browser/child_process_security_policy_impl.h" |
| #include "content/browser/ssl/ssl_error_handler.h" |
| #include "content/browser/ssl/ssl_manager.h" |
| #include "content/browser/websockets/websocket_handshake_request_info_impl.h" |
| #include "content/public/browser/storage_partition.h" |
| #include "ipc/ipc_message.h" |
| #include "net/base/io_buffer.h" |
| #include "net/base/net_errors.h" |
| #include "net/http/http_request_headers.h" |
| #include "net/http/http_response_headers.h" |
| #include "net/http/http_util.h" |
| #include "net/ssl/ssl_info.h" |
| #include "net/url_request/url_request_context_getter.h" |
| #include "net/websockets/websocket_channel.h" |
| #include "net/websockets/websocket_errors.h" |
| #include "net/websockets/websocket_event_interface.h" |
| #include "net/websockets/websocket_frame.h" // for WebSocketFrameHeader::OpCode |
| #include "net/websockets/websocket_handshake_request_info.h" |
| #include "net/websockets/websocket_handshake_response_info.h" |
| #include "url/origin.h" |
| |
| namespace content { |
| namespace { |
| |
| typedef net::WebSocketEventInterface::ChannelState ChannelState; |
| |
| // Convert a blink::mojom::WebSocketMessageType to a |
| // net::WebSocketFrameHeader::OpCode |
| net::WebSocketFrameHeader::OpCode MessageTypeToOpCode( |
| blink::mojom::WebSocketMessageType type) { |
| DCHECK(type == blink::mojom::WebSocketMessageType::CONTINUATION || |
| type == blink::mojom::WebSocketMessageType::TEXT || |
| type == blink::mojom::WebSocketMessageType::BINARY); |
| typedef net::WebSocketFrameHeader::OpCode OpCode; |
| // These compile asserts verify that the same underlying values are used for |
| // both types, so we can simply cast between them. |
| static_assert( |
| static_cast<OpCode>(blink::mojom::WebSocketMessageType::CONTINUATION) == |
| net::WebSocketFrameHeader::kOpCodeContinuation, |
| "enum values must match for opcode continuation"); |
| static_assert( |
| static_cast<OpCode>(blink::mojom::WebSocketMessageType::TEXT) == |
| net::WebSocketFrameHeader::kOpCodeText, |
| "enum values must match for opcode text"); |
| static_assert( |
| static_cast<OpCode>(blink::mojom::WebSocketMessageType::BINARY) == |
| net::WebSocketFrameHeader::kOpCodeBinary, |
| "enum values must match for opcode binary"); |
| return static_cast<OpCode>(type); |
| } |
| |
| blink::mojom::WebSocketMessageType OpCodeToMessageType( |
| net::WebSocketFrameHeader::OpCode opCode) { |
| DCHECK(opCode == net::WebSocketFrameHeader::kOpCodeContinuation || |
| opCode == net::WebSocketFrameHeader::kOpCodeText || |
| opCode == net::WebSocketFrameHeader::kOpCodeBinary); |
| // This cast is guaranteed valid by the static_assert() statements above. |
| return static_cast<blink::mojom::WebSocketMessageType>(opCode); |
| } |
| |
| } // namespace |
| |
| // Implementation of net::WebSocketEventInterface. Receives events from our |
| // WebSocketChannel object. |
| class WebSocketImpl::WebSocketEventHandler final |
| : public net::WebSocketEventInterface { |
| public: |
| explicit WebSocketEventHandler(WebSocketImpl* impl); |
| ~WebSocketEventHandler() override; |
| |
| // net::WebSocketEventInterface implementation |
| |
| void OnCreateURLRequest(net::URLRequest* url_request) override; |
| ChannelState OnAddChannelResponse(const std::string& selected_subprotocol, |
| const std::string& extensions) override; |
| ChannelState OnDataFrame(bool fin, |
| WebSocketMessageType type, |
| scoped_refptr<net::IOBuffer> buffer, |
| size_t buffer_size) override; |
| ChannelState OnClosingHandshake() override; |
| ChannelState OnFlowControl(int64_t quota) override; |
| ChannelState OnDropChannel(bool was_clean, |
| uint16_t code, |
| const std::string& reason) override; |
| ChannelState OnFailChannel(const std::string& message) override; |
| ChannelState OnStartOpeningHandshake( |
| std::unique_ptr<net::WebSocketHandshakeRequestInfo> request) override; |
| ChannelState OnFinishOpeningHandshake( |
| std::unique_ptr<net::WebSocketHandshakeResponseInfo> response) override; |
| ChannelState OnSSLCertificateError( |
| std::unique_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> |
| callbacks, |
| const GURL& url, |
| const net::SSLInfo& ssl_info, |
| bool fatal) override; |
| |
| private: |
| class SSLErrorHandlerDelegate final : public SSLErrorHandler::Delegate { |
| public: |
| SSLErrorHandlerDelegate( |
| std::unique_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> |
| callbacks); |
| ~SSLErrorHandlerDelegate() override; |
| |
| base::WeakPtr<SSLErrorHandler::Delegate> GetWeakPtr(); |
| |
| // SSLErrorHandler::Delegate methods |
| void CancelSSLRequest(int error, const net::SSLInfo* ssl_info) override; |
| void ContinueSSLRequest() override; |
| |
| private: |
| std::unique_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks_; |
| base::WeakPtrFactory<SSLErrorHandlerDelegate> weak_ptr_factory_; |
| |
| DISALLOW_COPY_AND_ASSIGN(SSLErrorHandlerDelegate); |
| }; |
| |
| WebSocketImpl* const impl_; |
| std::unique_ptr<SSLErrorHandlerDelegate> ssl_error_handler_delegate_; |
| |
| DISALLOW_COPY_AND_ASSIGN(WebSocketEventHandler); |
| }; |
| |
| WebSocketImpl::WebSocketEventHandler::WebSocketEventHandler(WebSocketImpl* impl) |
| : impl_(impl) { |
| DVLOG(1) << "WebSocketEventHandler created @" |
| << reinterpret_cast<void*>(this); |
| } |
| |
| WebSocketImpl::WebSocketEventHandler::~WebSocketEventHandler() { |
| DVLOG(1) << "WebSocketEventHandler destroyed @" |
| << reinterpret_cast<void*>(this); |
| } |
| |
| void WebSocketImpl::WebSocketEventHandler::OnCreateURLRequest( |
| net::URLRequest* url_request) { |
| WebSocketHandshakeRequestInfoImpl::CreateInfoAndAssociateWithRequest( |
| impl_->child_id_, impl_->frame_id_, url_request); |
| } |
| |
| ChannelState WebSocketImpl::WebSocketEventHandler::OnAddChannelResponse( |
| const std::string& selected_protocol, |
| const std::string& extensions) { |
| DVLOG(3) << "WebSocketEventHandler::OnAddChannelResponse @" |
| << reinterpret_cast<void*>(this) |
| << " selected_protocol=\"" << selected_protocol << "\"" |
| << " extensions=\"" << extensions << "\""; |
| |
| impl_->delegate_->OnReceivedResponseFromServer(impl_); |
| |
| impl_->client_->OnAddChannelResponse(selected_protocol, extensions); |
| |
| return net::WebSocketEventInterface::CHANNEL_ALIVE; |
| } |
| |
| ChannelState WebSocketImpl::WebSocketEventHandler::OnDataFrame( |
| bool fin, |
| net::WebSocketFrameHeader::OpCode type, |
| scoped_refptr<net::IOBuffer> buffer, |
| size_t buffer_size) { |
| DVLOG(3) << "WebSocketEventHandler::OnDataFrame @" |
| << reinterpret_cast<void*>(this) |
| << " fin=" << fin |
| << " type=" << type << " data is " << buffer_size << " bytes"; |
| |
| // TODO(darin): Avoid this copy. |
| std::vector<uint8_t> data_to_pass(buffer_size); |
| if (buffer_size > 0) { |
| std::copy(buffer->data(), buffer->data() + buffer_size, |
| data_to_pass.begin()); |
| } |
| |
| impl_->client_->OnDataFrame(fin, OpCodeToMessageType(type), data_to_pass); |
| |
| return net::WebSocketEventInterface::CHANNEL_ALIVE; |
| } |
| |
| ChannelState WebSocketImpl::WebSocketEventHandler::OnClosingHandshake() { |
| DVLOG(3) << "WebSocketEventHandler::OnClosingHandshake @" |
| << reinterpret_cast<void*>(this); |
| |
| impl_->client_->OnClosingHandshake(); |
| |
| return net::WebSocketEventInterface::CHANNEL_ALIVE; |
| } |
| |
| ChannelState WebSocketImpl::WebSocketEventHandler::OnFlowControl( |
| int64_t quota) { |
| DVLOG(3) << "WebSocketEventHandler::OnFlowControl @" |
| << reinterpret_cast<void*>(this) |
| << " quota=" << quota; |
| |
| impl_->client_->OnFlowControl(quota); |
| |
| return net::WebSocketEventInterface::CHANNEL_ALIVE; |
| } |
| |
| ChannelState WebSocketImpl::WebSocketEventHandler::OnDropChannel( |
| bool was_clean, |
| uint16_t code, |
| const std::string& reason) { |
| DVLOG(3) << "WebSocketEventHandler::OnDropChannel @" |
| << reinterpret_cast<void*>(this) |
| << " was_clean=" << was_clean << " code=" << code |
| << " reason=\"" << reason << "\""; |
| |
| impl_->client_->OnDropChannel(was_clean, code, reason); |
| |
| // net::WebSocketChannel requires that we delete it at this point. |
| impl_->channel_.reset(); |
| |
| return net::WebSocketEventInterface::CHANNEL_DELETED; |
| } |
| |
| ChannelState WebSocketImpl::WebSocketEventHandler::OnFailChannel( |
| const std::string& message) { |
| DVLOG(3) << "WebSocketEventHandler::OnFailChannel @" |
| << reinterpret_cast<void*>(this) << " message=\"" << message << "\""; |
| |
| impl_->client_->OnFailChannel(message); |
| |
| // net::WebSocketChannel requires that we delete it at this point. |
| impl_->channel_.reset(); |
| |
| return net::WebSocketEventInterface::CHANNEL_DELETED; |
| } |
| |
| ChannelState WebSocketImpl::WebSocketEventHandler::OnStartOpeningHandshake( |
| std::unique_ptr<net::WebSocketHandshakeRequestInfo> request) { |
| bool should_send = |
| ChildProcessSecurityPolicyImpl::GetInstance()->CanReadRawCookies( |
| impl_->delegate_->GetClientProcessId()); |
| |
| DVLOG(3) << "WebSocketEventHandler::OnStartOpeningHandshake @" |
| << reinterpret_cast<void*>(this) << " should_send=" << should_send; |
| |
| if (!should_send) |
| return WebSocketEventInterface::CHANNEL_ALIVE; |
| |
| blink::mojom::WebSocketHandshakeRequestPtr request_to_pass( |
| blink::mojom::WebSocketHandshakeRequest::New()); |
| request_to_pass->url.Swap(&request->url); |
| net::HttpRequestHeaders::Iterator it(request->headers); |
| while (it.GetNext()) { |
| blink::mojom::HttpHeaderPtr header(blink::mojom::HttpHeader::New()); |
| header->name = it.name(); |
| header->value = it.value(); |
| request_to_pass->headers.push_back(std::move(header)); |
| } |
| request_to_pass->headers_text = |
| base::StringPrintf("GET %s HTTP/1.1\r\n", |
| request_to_pass->url.spec().c_str()) + |
| request->headers.ToString(); |
| |
| impl_->client_->OnStartOpeningHandshake(std::move(request_to_pass)); |
| |
| return WebSocketEventInterface::CHANNEL_ALIVE; |
| } |
| |
| ChannelState WebSocketImpl::WebSocketEventHandler::OnFinishOpeningHandshake( |
| std::unique_ptr<net::WebSocketHandshakeResponseInfo> response) { |
| bool should_send = |
| ChildProcessSecurityPolicyImpl::GetInstance()->CanReadRawCookies( |
| impl_->delegate_->GetClientProcessId()); |
| |
| DVLOG(3) << "WebSocketEventHandler::OnFinishOpeningHandshake " |
| << reinterpret_cast<void*>(this) << " should_send=" << should_send; |
| |
| if (!should_send) |
| return WebSocketEventInterface::CHANNEL_ALIVE; |
| |
| blink::mojom::WebSocketHandshakeResponsePtr response_to_pass( |
| blink::mojom::WebSocketHandshakeResponse::New()); |
| response_to_pass->url.Swap(&response->url); |
| response_to_pass->status_code = response->status_code; |
| response_to_pass->status_text = response->status_text; |
| size_t iter = 0; |
| std::string name, value; |
| while (response->headers->EnumerateHeaderLines(&iter, &name, &value)) { |
| blink::mojom::HttpHeaderPtr header(blink::mojom::HttpHeader::New()); |
| header->name = name; |
| header->value = value; |
| response_to_pass->headers.push_back(std::move(header)); |
| } |
| response_to_pass->headers_text = |
| net::HttpUtil::ConvertHeadersBackToHTTPResponse( |
| response->headers->raw_headers()); |
| |
| impl_->client_->OnFinishOpeningHandshake(std::move(response_to_pass)); |
| |
| return WebSocketEventInterface::CHANNEL_ALIVE; |
| } |
| |
| ChannelState WebSocketImpl::WebSocketEventHandler::OnSSLCertificateError( |
| std::unique_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks, |
| const GURL& url, |
| const net::SSLInfo& ssl_info, |
| bool fatal) { |
| DVLOG(3) << "WebSocketEventHandler::OnSSLCertificateError" |
| << reinterpret_cast<void*>(this) << " url=" << url.spec() |
| << " cert_status=" << ssl_info.cert_status << " fatal=" << fatal; |
| ssl_error_handler_delegate_.reset( |
| new SSLErrorHandlerDelegate(std::move(callbacks))); |
| SSLManager::OnSSLCertificateSubresourceError( |
| ssl_error_handler_delegate_->GetWeakPtr(), |
| url, |
| impl_->delegate_->GetClientProcessId(), |
| impl_->frame_id_, |
| ssl_info, |
| fatal); |
| // The above method is always asynchronous. |
| return WebSocketEventInterface::CHANNEL_ALIVE; |
| } |
| |
| WebSocketImpl::WebSocketEventHandler::SSLErrorHandlerDelegate:: |
| SSLErrorHandlerDelegate( |
| std::unique_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> |
| callbacks) |
| : callbacks_(std::move(callbacks)), weak_ptr_factory_(this) {} |
| |
| WebSocketImpl::WebSocketEventHandler::SSLErrorHandlerDelegate:: |
| ~SSLErrorHandlerDelegate() {} |
| |
| base::WeakPtr<SSLErrorHandler::Delegate> |
| WebSocketImpl::WebSocketEventHandler::SSLErrorHandlerDelegate::GetWeakPtr() { |
| return weak_ptr_factory_.GetWeakPtr(); |
| } |
| |
| void WebSocketImpl::WebSocketEventHandler::SSLErrorHandlerDelegate:: |
| CancelSSLRequest(int error, const net::SSLInfo* ssl_info) { |
| DVLOG(3) << "SSLErrorHandlerDelegate::CancelSSLRequest" |
| << " error=" << error |
| << " cert_status=" << (ssl_info ? ssl_info->cert_status |
| : static_cast<net::CertStatus>(-1)); |
| callbacks_->CancelSSLRequest(error, ssl_info); |
| } |
| |
| void WebSocketImpl::WebSocketEventHandler::SSLErrorHandlerDelegate:: |
| ContinueSSLRequest() { |
| DVLOG(3) << "SSLErrorHandlerDelegate::ContinueSSLRequest"; |
| callbacks_->ContinueSSLRequest(); |
| } |
| |
| WebSocketImpl::WebSocketImpl(Delegate* delegate, |
| blink::mojom::WebSocketRequest request, |
| int child_id, |
| int frame_id, |
| base::TimeDelta delay) |
| : delegate_(delegate), |
| binding_(this, std::move(request)), |
| delay_(delay), |
| pending_flow_control_quota_(0), |
| child_id_(child_id), |
| frame_id_(frame_id), |
| handshake_succeeded_(false), |
| weak_ptr_factory_(this) { |
| binding_.set_connection_error_handler( |
| base::Bind(&WebSocketImpl::OnConnectionError, base::Unretained(this))); |
| } |
| |
| WebSocketImpl::~WebSocketImpl() {} |
| |
| void WebSocketImpl::GoAway() { |
| StartClosingHandshake(static_cast<uint16_t>(net::kWebSocketErrorGoingAway), |
| ""); |
| } |
| |
| void WebSocketImpl::AddChannelRequest( |
| const GURL& socket_url, |
| const std::vector<std::string>& requested_protocols, |
| const url::Origin& origin, |
| const GURL& first_party_for_cookies, |
| const std::string& user_agent_override, |
| blink::mojom::WebSocketClientPtr client) { |
| DVLOG(3) << "WebSocketImpl::AddChannelRequest @" |
| << reinterpret_cast<void*>(this) |
| << " socket_url=\"" << socket_url << "\" requested_protocols=\"" |
| << base::JoinString(requested_protocols, ", ") |
| << "\" origin=\"" << origin |
| << "\" first_party_for_cookies=\"" << first_party_for_cookies |
| << "\" user_agent_override=\"" << user_agent_override |
| << "\""; |
| |
| if (client_ || !client) { |
| bad_message::ReceivedBadMessage( |
| delegate_->GetClientProcessId(), |
| bad_message::WSI_UNEXPECTED_ADD_CHANNEL_REQUEST); |
| return; |
| } |
| |
| client_ = std::move(client); |
| |
| DCHECK(!channel_); |
| if (delay_ > base::TimeDelta()) { |
| base::ThreadTaskRunnerHandle::Get()->PostDelayedTask( |
| FROM_HERE, |
| base::Bind(&WebSocketImpl::AddChannel, |
| weak_ptr_factory_.GetWeakPtr(), |
| socket_url, |
| requested_protocols, |
| origin, |
| first_party_for_cookies, |
| user_agent_override), |
| delay_); |
| } else { |
| AddChannel(socket_url, requested_protocols, origin, first_party_for_cookies, |
| user_agent_override); |
| } |
| } |
| |
| void WebSocketImpl::SendFrame(bool fin, |
| blink::mojom::WebSocketMessageType type, |
| const std::vector<uint8_t>& data) { |
| DVLOG(3) << "WebSocketImpl::SendFrame @" |
| << reinterpret_cast<void*>(this) << " fin=" << fin |
| << " type=" << type << " data is " << data.size() << " bytes"; |
| |
| if (!channel_) { |
| // The client should not be sending us frames until after we've informed |
| // it that the channel has been opened (OnAddChannelResponse). |
| if (handshake_succeeded_) { |
| DVLOG(1) << "Dropping frame sent to closed websocket"; |
| } else { |
| bad_message::ReceivedBadMessage( |
| delegate_->GetClientProcessId(), |
| bad_message::WSI_UNEXPECTED_SEND_FRAME); |
| } |
| return; |
| } |
| |
| // TODO(darin): Avoid this copy. |
| scoped_refptr<net::IOBuffer> data_to_pass(new net::IOBuffer(data.size())); |
| std::copy(data.begin(), data.end(), data_to_pass->data()); |
| |
| channel_->SendFrame(fin, MessageTypeToOpCode(type), std::move(data_to_pass), |
| data.size()); |
| } |
| |
| void WebSocketImpl::SendFlowControl(int64_t quota) { |
| DVLOG(3) << "WebSocketImpl::OnFlowControl @" |
| << reinterpret_cast<void*>(this) << " quota=" << quota; |
| |
| if (!channel_) { |
| // WebSocketChannel is not yet created due to the delay introduced by |
| // per-renderer WebSocket throttling. |
| // SendFlowControl() is called after WebSocketChannel is created. |
| pending_flow_control_quota_ += quota; |
| return; |
| } |
| |
| ignore_result(channel_->SendFlowControl(quota)); |
| } |
| |
| void WebSocketImpl::StartClosingHandshake(uint16_t code, |
| const std::string& reason) { |
| DVLOG(3) << "WebSocketImpl::StartClosingHandshake @" |
| << reinterpret_cast<void*>(this) |
| << " code=" << code << " reason=\"" << reason << "\""; |
| |
| if (!channel_) { |
| // WebSocketChannel is not yet created due to the delay introduced by |
| // per-renderer WebSocket throttling. |
| if (client_) |
| client_->OnDropChannel(false, net::kWebSocketErrorAbnormalClosure, ""); |
| return; |
| } |
| |
| ignore_result(channel_->StartClosingHandshake(code, reason)); |
| } |
| |
| void WebSocketImpl::OnConnectionError() { |
| DVLOG(3) << "WebSocketImpl::OnConnectionError @" |
| << reinterpret_cast<void*>(this); |
| |
| delegate_->OnLostConnectionToClient(this); |
| } |
| |
| void WebSocketImpl::AddChannel( |
| const GURL& socket_url, |
| const std::vector<std::string>& requested_protocols, |
| const url::Origin& origin, |
| const GURL& first_party_for_cookies, |
| const std::string& user_agent_override) { |
| DVLOG(3) << "WebSocketImpl::AddChannel @" |
| << reinterpret_cast<void*>(this) |
| << " socket_url=\"" << socket_url |
| << "\" requested_protocols=\"" |
| << base::JoinString(requested_protocols, ", ") |
| << "\" origin=\"" << origin |
| << "\" first_party_for_cookies=\"" << first_party_for_cookies |
| << "\" user_agent_override=\"" << user_agent_override |
| << "\""; |
| |
| DCHECK(!channel_); |
| |
| StoragePartition* partition = delegate_->GetStoragePartition(); |
| |
| std::unique_ptr<net::WebSocketEventInterface> event_interface( |
| new WebSocketEventHandler(this)); |
| channel_.reset( |
| new net::WebSocketChannel( |
| std::move(event_interface), |
| partition->GetURLRequestContext()->GetURLRequestContext())); |
| |
| int64_t quota = pending_flow_control_quota_; |
| pending_flow_control_quota_ = 0; |
| |
| std::string additional_headers; |
| if (!user_agent_override.empty()) { |
| if (!net::HttpUtil::IsValidHeaderValue(user_agent_override)) { |
| bad_message::ReceivedBadMessage( |
| delegate_->GetClientProcessId(), |
| bad_message::WSI_INVALID_HEADER_VALUE); |
| return; |
| } |
| additional_headers = base::StringPrintf("%s:%s", |
| net::HttpRequestHeaders::kUserAgent, |
| user_agent_override.c_str()); |
| } |
| channel_->SendAddChannelRequest(socket_url, requested_protocols, origin, |
| first_party_for_cookies, additional_headers); |
| if (quota > 0) |
| SendFlowControl(quota); |
| } |
| |
| } // namespace content |