blob: 7019bc5306d6ca97b6e17afb9aba0e40a57fa7ec [file] [log] [blame]
// Copyright 2013 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "content/browser/websockets/websocket_impl.h"
#include <inttypes.h>
#include <utility>
#include "base/bind.h"
#include "base/bind_helpers.h"
#include "base/location.h"
#include "base/logging.h"
#include "base/macros.h"
#include "base/memory/ptr_util.h"
#include "base/single_thread_task_runner.h"
#include "base/strings/string_util.h"
#include "base/strings/stringprintf.h"
#include "base/threading/thread_task_runner_handle.h"
#include "content/browser/bad_message.h"
#include "content/browser/child_process_security_policy_impl.h"
#include "content/browser/ssl/ssl_error_handler.h"
#include "content/browser/ssl/ssl_manager.h"
#include "content/browser/websockets/websocket_handshake_request_info_impl.h"
#include "content/public/browser/storage_partition.h"
#include "ipc/ipc_message.h"
#include "net/base/io_buffer.h"
#include "net/base/net_errors.h"
#include "net/http/http_request_headers.h"
#include "net/http/http_response_headers.h"
#include "net/http/http_util.h"
#include "net/ssl/ssl_info.h"
#include "net/url_request/url_request_context_getter.h"
#include "net/websockets/websocket_channel.h"
#include "net/websockets/websocket_errors.h"
#include "net/websockets/websocket_event_interface.h"
#include "net/websockets/websocket_frame.h" // for WebSocketFrameHeader::OpCode
#include "net/websockets/websocket_handshake_request_info.h"
#include "net/websockets/websocket_handshake_response_info.h"
#include "url/origin.h"
namespace content {
namespace {

using ChannelState = net::WebSocketEventInterface::ChannelState;

// Maps a blink::mojom::WebSocketMessageType onto the equivalent
// net::WebSocketFrameHeader::OpCode. Only CONTINUATION, TEXT and BINARY are
// accepted (DCHECKed); the mapping is a plain cast because the two enums use
// identical underlying values for those kinds (checked at compile time below).
net::WebSocketFrameHeader::OpCode MessageTypeToOpCode(
    blink::mojom::WebSocketMessageType type) {
  using MessageType = blink::mojom::WebSocketMessageType;
  using OpCode = net::WebSocketFrameHeader::OpCode;

  DCHECK(type == MessageType::CONTINUATION || type == MessageType::TEXT ||
         type == MessageType::BINARY);

  // These compile-time checks are what make the casts in this function and in
  // OpCodeToMessageType() valid.
  static_assert(static_cast<OpCode>(MessageType::CONTINUATION) ==
                    net::WebSocketFrameHeader::kOpCodeContinuation,
                "enum values must match for opcode continuation");
  static_assert(static_cast<OpCode>(MessageType::TEXT) ==
                    net::WebSocketFrameHeader::kOpCodeText,
                "enum values must match for opcode text");
  static_assert(static_cast<OpCode>(MessageType::BINARY) ==
                    net::WebSocketFrameHeader::kOpCodeBinary,
                "enum values must match for opcode binary");

  return static_cast<OpCode>(type);
}

// Inverse of MessageTypeToOpCode(): maps a data-frame opcode (continuation,
// text or binary; DCHECKed) back onto the mojo message type.
blink::mojom::WebSocketMessageType OpCodeToMessageType(
    net::WebSocketFrameHeader::OpCode op_code) {
  DCHECK(op_code == net::WebSocketFrameHeader::kOpCodeContinuation ||
         op_code == net::WebSocketFrameHeader::kOpCodeText ||
         op_code == net::WebSocketFrameHeader::kOpCodeBinary);
  // Valid thanks to the static_asserts in MessageTypeToOpCode().
  return static_cast<blink::mojom::WebSocketMessageType>(op_code);
}

}  // namespace
// Implementation of net::WebSocketEventInterface. Receives events from our
// WebSocketChannel object.
//
// Each event is translated into the corresponding call on the owning
// WebSocketImpl's mojo client (|impl_->client_|). An instance is created in
// WebSocketImpl::AddChannel(), and its std::unique_ptr is moved into the
// net::WebSocketChannel constructor, so the channel owns this handler.
class WebSocketImpl::WebSocketEventHandler final
: public net::WebSocketEventInterface {
public:
// |impl| is a non-owning back pointer. NOTE(review): assumes |impl| outlives
// this handler (|impl| owns the channel that owns the handler) - confirm.
explicit WebSocketEventHandler(WebSocketImpl* impl);
~WebSocketEventHandler() override;
// net::WebSocketEventInterface implementation
void OnCreateURLRequest(net::URLRequest* url_request) override;
ChannelState OnAddChannelResponse(const std::string& selected_subprotocol,
const std::string& extensions) override;
ChannelState OnDataFrame(bool fin,
WebSocketMessageType type,
scoped_refptr<net::IOBuffer> buffer,
size_t buffer_size) override;
ChannelState OnClosingHandshake() override;
ChannelState OnFlowControl(int64_t quota) override;
// OnDropChannel() and OnFailChannel() reset |impl_->channel_| and return
// CHANNEL_DELETED; the other ChannelState-returning events return
// CHANNEL_ALIVE.
ChannelState OnDropChannel(bool was_clean,
uint16_t code,
const std::string& reason) override;
ChannelState OnFailChannel(const std::string& message) override;
// The opening-handshake request/response are forwarded to the client only
// when ChildProcessSecurityPolicyImpl allows the client process to read raw
// cookies.
ChannelState OnStartOpeningHandshake(
std::unique_ptr<net::WebSocketHandshakeRequestInfo> request) override;
ChannelState OnFinishOpeningHandshake(
std::unique_ptr<net::WebSocketHandshakeResponseInfo> response) override;
// Hands the certificate error to SSLManager, which answers asynchronously
// through SSLErrorHandlerDelegate::CancelSSLRequest()/ContinueSSLRequest().
ChannelState OnSSLCertificateError(
std::unique_ptr<net::WebSocketEventInterface::SSLErrorCallbacks>
callbacks,
const GURL& url,
const net::SSLInfo& ssl_info,
bool fatal) override;
private:
// Adapts SSLErrorHandler::Delegate onto the net-level SSLErrorCallbacks it
// owns. SSLManager only receives a WeakPtr to it. NOTE(review): presumably
// this is so that a delegate destroyed first is simply not called.
class SSLErrorHandlerDelegate final : public SSLErrorHandler::Delegate {
public:
SSLErrorHandlerDelegate(
std::unique_ptr<net::WebSocketEventInterface::SSLErrorCallbacks>
callbacks);
~SSLErrorHandlerDelegate() override;
base::WeakPtr<SSLErrorHandler::Delegate> GetWeakPtr();
// SSLErrorHandler::Delegate methods
void CancelSSLRequest(int error, const net::SSLInfo* ssl_info) override;
void ContinueSSLRequest() override;
private:
std::unique_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks_;
base::WeakPtrFactory<SSLErrorHandlerDelegate> weak_ptr_factory_;
DISALLOW_COPY_AND_ASSIGN(SSLErrorHandlerDelegate);
};
// Non-owning back pointer to the WebSocketImpl this handler reports to.
WebSocketImpl* const impl_;
// Delegate for the most recent certificate error; OnSSLCertificateError()
// replaces it on every call.
std::unique_ptr<SSLErrorHandlerDelegate> ssl_error_handler_delegate_;
DISALLOW_COPY_AND_ASSIGN(WebSocketEventHandler);
};
// Stores the non-owning back pointer to |impl| and logs creation.
WebSocketImpl::WebSocketEventHandler::WebSocketEventHandler(WebSocketImpl* impl)
    : impl_(impl) {
  const void* const self = this;
  DVLOG(1) << "WebSocketEventHandler created @" << self;
}

// Only logs destruction; owned state is released by member destructors.
WebSocketImpl::WebSocketEventHandler::~WebSocketEventHandler() {
  const void* const self = this;
  DVLOG(1) << "WebSocketEventHandler destroyed @" << self;
}
// Attaches a WebSocketHandshakeRequestInfoImpl carrying this socket's
// child id and frame id to the freshly created |url_request|.
void WebSocketImpl::WebSocketEventHandler::OnCreateURLRequest(
    net::URLRequest* url_request) {
  const auto child_id = impl_->child_id_;
  const auto frame_id = impl_->frame_id_;
  WebSocketHandshakeRequestInfoImpl::CreateInfoAndAssociateWithRequest(
      child_id, frame_id, url_request);
}
// Called when the opening handshake has completed and the channel is open.
//
// Records that fact in |impl_->handshake_succeeded_| so that
// WebSocketImpl::SendFrame() can tell a frame that races with channel
// closure (dropped quietly) apart from a frame sent before the channel was
// ever opened (reported as a bad message). It then notifies the delegate and
// forwards |selected_protocol| / |extensions| to the client.
ChannelState WebSocketImpl::WebSocketEventHandler::OnAddChannelResponse(
    const std::string& selected_protocol,
    const std::string& extensions) {
  DVLOG(3) << "WebSocketEventHandler::OnAddChannelResponse @"
           << reinterpret_cast<void*>(this)
           << " selected_protocol=\"" << selected_protocol << "\""
           << " extensions=\"" << extensions << "\"";

  // Nothing else sets this flag. Without it, any frame the client sends after
  // the channel has been dropped or failed would go through
  // bad_message::ReceivedBadMessage() in SendFrame().
  impl_->handshake_succeeded_ = true;

  impl_->delegate_->OnReceivedResponseFromServer(impl_);
  impl_->client_->OnAddChannelResponse(selected_protocol, extensions);
  return net::WebSocketEventInterface::CHANNEL_ALIVE;
}
// Copies the |buffer_size| payload bytes out of |buffer| and forwards them to
// the client as a frame of the mapped message type. |buffer| is only
// dereferenced for non-empty frames.
ChannelState WebSocketImpl::WebSocketEventHandler::OnDataFrame(
    bool fin,
    net::WebSocketFrameHeader::OpCode type,
    scoped_refptr<net::IOBuffer> buffer,
    size_t buffer_size) {
  DVLOG(3) << "WebSocketEventHandler::OnDataFrame @"
           << reinterpret_cast<void*>(this) << " fin=" << fin
           << " type=" << type << " data is " << buffer_size << " bytes";

  // TODO(darin): Avoid this copy.
  std::vector<uint8_t> payload;
  if (buffer_size != 0) {
    const char* const first = buffer->data();
    payload.assign(first, first + buffer_size);
  }

  impl_->client_->OnDataFrame(fin, OpCodeToMessageType(type), payload);
  return net::WebSocketEventInterface::CHANNEL_ALIVE;
}
// Tells the client that the closing handshake has started; the channel
// itself stays alive.
ChannelState WebSocketImpl::WebSocketEventHandler::OnClosingHandshake() {
  const void* const self = this;
  DVLOG(3) << "WebSocketEventHandler::OnClosingHandshake @" << self;

  impl_->client_->OnClosingHandshake();
  return net::WebSocketEventInterface::CHANNEL_ALIVE;
}
// Forwards a flow-control quota update from the channel to the client.
ChannelState WebSocketImpl::WebSocketEventHandler::OnFlowControl(
    int64_t quota) {
  const void* const self = this;
  DVLOG(3) << "WebSocketEventHandler::OnFlowControl @" << self
           << " quota=" << quota;

  impl_->client_->OnFlowControl(quota);
  return net::WebSocketEventInterface::CHANNEL_ALIVE;
}
// Forwards a channel closure (|was_clean|, close |code|, |reason|) to the
// client and then destroys the channel, as net::WebSocketChannel requires
// when this event is delivered.
ChannelState WebSocketImpl::WebSocketEventHandler::OnDropChannel(
bool was_clean,
uint16_t code,
const std::string& reason) {
DVLOG(3) << "WebSocketEventHandler::OnDropChannel @"
<< reinterpret_cast<void*>(this)
<< " was_clean=" << was_clean << " code=" << code
<< " reason=\"" << reason << "\"";
// Notify the client first; the reset below tears the channel down.
impl_->client_->OnDropChannel(was_clean, code, reason);
// net::WebSocketChannel requires that we delete it at this point.
// NOTE(review): WebSocketImpl::AddChannel() moved ownership of this handler
// into the channel, so this reset very likely destroys |this|. Nothing after
// it may touch members of this object.
impl_->channel_.reset();
// CHANNEL_DELETED reports that the channel no longer exists.
return net::WebSocketEventInterface::CHANNEL_DELETED;
}
// Forwards a channel failure |message| to the client and then destroys the
// channel, as net::WebSocketChannel requires when this event is delivered.
ChannelState WebSocketImpl::WebSocketEventHandler::OnFailChannel(
const std::string& message) {
DVLOG(3) << "WebSocketEventHandler::OnFailChannel @"
<< reinterpret_cast<void*>(this) << " message=\"" << message << "\"";
// Notify the client before the channel goes away.
impl_->client_->OnFailChannel(message);
// net::WebSocketChannel requires that we delete it at this point.
// NOTE(review): WebSocketImpl::AddChannel() moved ownership of this handler
// into the channel, so this reset very likely destroys |this|. Nothing after
// it may touch members of this object.
impl_->channel_.reset();
// CHANNEL_DELETED reports that the channel no longer exists.
return net::WebSocketEventInterface::CHANNEL_DELETED;
}
// Forwards the outgoing opening-handshake request (URL, individual headers,
// and a reconstructed "GET ... HTTP/1.1" header text) to the client.
//
// This happens only when ChildProcessSecurityPolicyImpl lets the client
// process read raw cookies. NOTE(review): presumably because the headers may
// carry cookies. The channel is always kept alive.
ChannelState WebSocketImpl::WebSocketEventHandler::OnStartOpeningHandshake(
    std::unique_ptr<net::WebSocketHandshakeRequestInfo> request) {
  const bool should_send =
      ChildProcessSecurityPolicyImpl::GetInstance()->CanReadRawCookies(
          impl_->delegate_->GetClientProcessId());
  DVLOG(3) << "WebSocketEventHandler::OnStartOpeningHandshake @"
           << reinterpret_cast<void*>(this) << " should_send=" << should_send;
  if (!should_send)
    return WebSocketEventInterface::CHANNEL_ALIVE;

  auto request_to_pass = blink::mojom::WebSocketHandshakeRequest::New();
  request_to_pass->url.Swap(&request->url);

  for (net::HttpRequestHeaders::Iterator header_it(request->headers);
       header_it.GetNext();) {
    blink::mojom::HttpHeaderPtr header = blink::mojom::HttpHeader::New();
    header->name = header_it.name();
    header->value = header_it.value();
    request_to_pass->headers.push_back(std::move(header));
  }

  // The request line is rebuilt from the (already swapped-in) URL.
  std::string headers_text = base::StringPrintf(
      "GET %s HTTP/1.1\r\n", request_to_pass->url.spec().c_str());
  headers_text += request->headers.ToString();
  request_to_pass->headers_text = std::move(headers_text);

  impl_->client_->OnStartOpeningHandshake(std::move(request_to_pass));
  return WebSocketEventInterface::CHANNEL_ALIVE;
}
// Forwards the server's opening-handshake response (URL, status, each header
// line, and the raw header text) to the client.
//
// As with the request, this happens only when the client process may read
// raw cookies. The channel is always kept alive.
ChannelState WebSocketImpl::WebSocketEventHandler::OnFinishOpeningHandshake(
    std::unique_ptr<net::WebSocketHandshakeResponseInfo> response) {
  const bool should_send =
      ChildProcessSecurityPolicyImpl::GetInstance()->CanReadRawCookies(
          impl_->delegate_->GetClientProcessId());
  DVLOG(3) << "WebSocketEventHandler::OnFinishOpeningHandshake "
           << reinterpret_cast<void*>(this) << " should_send=" << should_send;
  if (!should_send)
    return WebSocketEventInterface::CHANNEL_ALIVE;

  auto response_to_pass = blink::mojom::WebSocketHandshakeResponse::New();
  response_to_pass->url.Swap(&response->url);
  response_to_pass->status_code = response->status_code;
  response_to_pass->status_text = response->status_text;

  // EnumerateHeaderLines() advances |line_iter| and fills the name/value
  // buffers, which are reused across iterations.
  size_t line_iter = 0;
  std::string header_name;
  std::string header_value;
  while (response->headers->EnumerateHeaderLines(&line_iter, &header_name,
                                                 &header_value)) {
    blink::mojom::HttpHeaderPtr header = blink::mojom::HttpHeader::New();
    header->name = header_name;
    header->value = header_value;
    response_to_pass->headers.push_back(std::move(header));
  }

  response_to_pass->headers_text =
      net::HttpUtil::ConvertHeadersBackToHTTPResponse(
          response->headers->raw_headers());

  impl_->client_->OnFinishOpeningHandshake(std::move(response_to_pass));
  return WebSocketEventInterface::CHANNEL_ALIVE;
}
// Routes a certificate error on the WebSocket connection to SSLManager.
//
// |callbacks| is wrapped in a fresh SSLErrorHandlerDelegate, which replaces
// any previous one, and only a WeakPtr to it is handed to SSLManager. The
// outcome arrives later through the delegate's CancelSSLRequest() or
// ContinueSSLRequest(); the channel stays alive meanwhile.
ChannelState WebSocketImpl::WebSocketEventHandler::OnSSLCertificateError(
    std::unique_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks,
    const GURL& url,
    const net::SSLInfo& ssl_info,
    bool fatal) {
  // " @" separates the method name from the pointer, matching every other
  // handler's log line (it was previously missing, gluing the two together).
  DVLOG(3) << "WebSocketEventHandler::OnSSLCertificateError @"
           << reinterpret_cast<void*>(this) << " url=" << url.spec()
           << " cert_status=" << ssl_info.cert_status << " fatal=" << fatal;

  // Replacing the delegate destroys the previous one (if any), which
  // invalidates the weak pointer that was handed out for it.
  ssl_error_handler_delegate_.reset(
      new SSLErrorHandlerDelegate(std::move(callbacks)));
  SSLManager::OnSSLCertificateSubresourceError(
      ssl_error_handler_delegate_->GetWeakPtr(),
      url,
      impl_->delegate_->GetClientProcessId(),
      impl_->frame_id_,
      ssl_info,
      fatal);
  // The above method is always asynchronous.
  return WebSocketEventInterface::CHANNEL_ALIVE;
}
// Takes ownership of the net-level |callbacks| that will receive the outcome
// of the SSL error.
WebSocketImpl::WebSocketEventHandler::SSLErrorHandlerDelegate::
    SSLErrorHandlerDelegate(
        std::unique_ptr<net::WebSocketEventInterface::SSLErrorCallbacks>
            callbacks)
    : callbacks_(std::move(callbacks)), weak_ptr_factory_(this) {}

// Destroying |weak_ptr_factory_| invalidates all weak pointers handed out.
WebSocketImpl::WebSocketEventHandler::SSLErrorHandlerDelegate::
    ~SSLErrorHandlerDelegate() = default;

// Returns a weak reference suitable for SSLManager, which must not own us.
base::WeakPtr<SSLErrorHandler::Delegate>
WebSocketImpl::WebSocketEventHandler::SSLErrorHandlerDelegate::GetWeakPtr() {
  base::WeakPtr<SSLErrorHandler::Delegate> weak_self =
      weak_ptr_factory_.GetWeakPtr();
  return weak_self;
}
// Relays a cancellation (with the net |error| code and optional |ssl_info|)
// to the owned net-level callbacks.
void WebSocketImpl::WebSocketEventHandler::SSLErrorHandlerDelegate::
    CancelSSLRequest(int error, const net::SSLInfo* ssl_info) {
  // static_cast<net::CertStatus>(-1) is logged when no SSLInfo was supplied.
  const net::CertStatus logged_cert_status =
      ssl_info ? ssl_info->cert_status : static_cast<net::CertStatus>(-1);
  DVLOG(3) << "SSLErrorHandlerDelegate::CancelSSLRequest"
           << " error=" << error << " cert_status=" << logged_cert_status;
  callbacks_->CancelSSLRequest(error, ssl_info);
}

// Relays a decision to proceed despite the certificate error.
void WebSocketImpl::WebSocketEventHandler::SSLErrorHandlerDelegate::
    ContinueSSLRequest() {
  DVLOG(3) << "SSLErrorHandlerDelegate::ContinueSSLRequest";
  callbacks_->ContinueSSLRequest();
}
// Binds |request| to |binding_| right away and routes a connection error on it
// to OnConnectionError().
//
// |delegate| is stored as a raw pointer. NOTE(review): presumably not owned
// and required to outlive this object - confirm.
//
// |child_id| and |frame_id| identify the requester: both are attached to the
// handshake URLRequest, and |frame_id| is also passed to SSLManager. A
// positive |delay| makes AddChannelRequest() postpone channel creation by that
// amount.
WebSocketImpl::WebSocketImpl(Delegate* delegate,
blink::mojom::WebSocketRequest request,
int child_id,
int frame_id,
base::TimeDelta delay)
: delegate_(delegate),
binding_(this, std::move(request)),
delay_(delay),
pending_flow_control_quota_(0),
child_id_(child_id),
frame_id_(frame_id),
handshake_succeeded_(false),
weak_ptr_factory_(this) {
// base::Unretained is used because |binding_| is a member of |this|.
// NOTE(review): relies on the error handler never running after |binding_|
// is destroyed - confirm.
binding_.set_connection_error_handler(
base::Bind(&WebSocketImpl::OnConnectionError, base::Unretained(this)));
}
// Members (channel, binding, weak pointer factory) clean up after themselves.
WebSocketImpl::~WebSocketImpl() = default;
// Starts a closing handshake with net::kWebSocketErrorGoingAway and an empty
// reason.
void WebSocketImpl::GoAway() {
  const uint16_t going_away_code =
      static_cast<uint16_t>(net::kWebSocketErrorGoingAway);
  StartClosingHandshake(going_away_code, std::string());
}
// Mojo entry point: attaches |client| and opens the channel, either now or
// after |delay_| when per-renderer throttling is in effect.
//
// A second request, or a null |client|, is reported as a bad message.
void WebSocketImpl::AddChannelRequest(
    const GURL& socket_url,
    const std::vector<std::string>& requested_protocols,
    const url::Origin& origin,
    const GURL& first_party_for_cookies,
    const std::string& user_agent_override,
    blink::mojom::WebSocketClientPtr client) {
  DVLOG(3) << "WebSocketImpl::AddChannelRequest @"
           << reinterpret_cast<void*>(this)
           << " socket_url=\"" << socket_url << "\" requested_protocols=\""
           << base::JoinString(requested_protocols, ", ")
           << "\" origin=\"" << origin
           << "\" first_party_for_cookies=\"" << first_party_for_cookies
           << "\" user_agent_override=\"" << user_agent_override
           << "\"";

  if (!client || client_) {
    bad_message::ReceivedBadMessage(
        delegate_->GetClientProcessId(),
        bad_message::WSI_UNEXPECTED_ADD_CHANNEL_REQUEST);
    return;
  }

  client_ = std::move(client);
  DCHECK(!channel_);

  if (delay_ <= base::TimeDelta()) {
    AddChannel(socket_url, requested_protocols, origin, first_party_for_cookies,
               user_agent_override);
    return;
  }

  // Throttled: defer channel creation. The task is bound to a WeakPtr so it
  // is dropped if |this| is destroyed before the delay elapses.
  base::ThreadTaskRunnerHandle::Get()->PostDelayedTask(
      FROM_HERE,
      base::Bind(&WebSocketImpl::AddChannel, weak_ptr_factory_.GetWeakPtr(),
                 socket_url, requested_protocols, origin,
                 first_party_for_cookies, user_agent_override),
      delay_);
}
void WebSocketImpl::SendFrame(bool fin,
blink::mojom::WebSocketMessageType type,
const std::vector<uint8_t>& data) {
DVLOG(3) << "WebSocketImpl::SendFrame @"
<< reinterpret_cast<void*>(this) << " fin=" << fin
<< " type=" << type << " data is " << data.size() << " bytes";
if (!channel_) {
// The client should not be sending us frames until after we've informed
// it that the channel has been opened (OnAddChannelResponse).
if (handshake_succeeded_) {
DVLOG(1) << "Dropping frame sent to closed websocket";
} else {
bad_message::ReceivedBadMessage(
delegate_->GetClientProcessId(),
bad_message::WSI_UNEXPECTED_SEND_FRAME);
}
return;
}
// TODO(darin): Avoid this copy.
scoped_refptr<net::IOBuffer> data_to_pass(new net::IOBuffer(data.size()));
std::copy(data.begin(), data.end(), data_to_pass->data());
channel_->SendFrame(fin, MessageTypeToOpCode(type), std::move(data_to_pass),
data.size());
}
// Mojo entry point: grants the channel |quota| more receive credit on behalf
// of the client.
//
// Before the channel exists (throttling delay), the quota is accumulated in
// |pending_flow_control_quota_|, and AddChannel() replays it once the channel
// has been created.
void WebSocketImpl::SendFlowControl(int64_t quota) {
  // The log previously named this method "OnFlowControl", which collides with
  // WebSocketEventHandler::OnFlowControl and made traces misleading.
  DVLOG(3) << "WebSocketImpl::SendFlowControl @"
           << reinterpret_cast<void*>(this) << " quota=" << quota;
  if (!channel_) {
    // WebSocketChannel is not yet created due to the delay introduced by
    // per-renderer WebSocket throttling.
    // SendFlowControl() is called after WebSocketChannel is created.
    pending_flow_control_quota_ += quota;
    return;
  }
  ignore_result(channel_->SendFlowControl(quota));
}
// Mojo entry point: begins the closing handshake with |code| and |reason|.
//
// Without a channel, it instead reports an abnormal closure directly to the
// client (if one is attached).
void WebSocketImpl::StartClosingHandshake(uint16_t code,
                                          const std::string& reason) {
  DVLOG(3) << "WebSocketImpl::StartClosingHandshake @"
           << reinterpret_cast<void*>(this)
           << " code=" << code << " reason=\"" << reason << "\"";

  if (channel_) {
    ignore_result(channel_->StartClosingHandshake(code, reason));
    return;
  }

  // WebSocketChannel is not yet created due to the delay introduced by
  // per-renderer WebSocket throttling.
  if (!client_)
    return;
  client_->OnDropChannel(false, net::kWebSocketErrorAbnormalClosure, "");
}
// Connection-error handler for |binding_| (installed in the constructor):
// informs the delegate that the client connection is gone.
void WebSocketImpl::OnConnectionError() {
  const void* const self = this;
  DVLOG(3) << "WebSocketImpl::OnConnectionError @" << self;
  delegate_->OnLostConnectionToClient(this);
}
// Creates the net::WebSocketChannel and starts the opening handshake. It runs
// directly from AddChannelRequest() or after the throttling delay.
//
// |user_agent_override| comes from the client and is reported as a bad
// message when invalid, so it is validated before any state is touched. An
// invalid value therefore leaves |channel_| null and
// |pending_flow_control_quota_| intact. Previously it left behind a channel
// that never had SendAddChannelRequest() called, and the pending quota was
// discarded.
void WebSocketImpl::AddChannel(
    const GURL& socket_url,
    const std::vector<std::string>& requested_protocols,
    const url::Origin& origin,
    const GURL& first_party_for_cookies,
    const std::string& user_agent_override) {
  DVLOG(3) << "WebSocketImpl::AddChannel @"
           << reinterpret_cast<void*>(this)
           << " socket_url=\"" << socket_url
           << "\" requested_protocols=\""
           << base::JoinString(requested_protocols, ", ")
           << "\" origin=\"" << origin
           << "\" first_party_for_cookies=\"" << first_party_for_cookies
           << "\" user_agent_override=\"" << user_agent_override
           << "\"";

  DCHECK(!channel_);

  // Validate client-supplied input before creating anything.
  std::string additional_headers;
  if (!user_agent_override.empty()) {
    if (!net::HttpUtil::IsValidHeaderValue(user_agent_override)) {
      bad_message::ReceivedBadMessage(
          delegate_->GetClientProcessId(),
          bad_message::WSI_INVALID_HEADER_VALUE);
      return;
    }
    additional_headers = base::StringPrintf("%s:%s",
                                            net::HttpRequestHeaders::kUserAgent,
                                            user_agent_override.c_str());
  }

  // The channel takes ownership of the event handler.
  StoragePartition* partition = delegate_->GetStoragePartition();
  std::unique_ptr<net::WebSocketEventInterface> event_interface(
      new WebSocketEventHandler(this));
  channel_.reset(
      new net::WebSocketChannel(
          std::move(event_interface),
          partition->GetURLRequestContext()->GetURLRequestContext()));

  // Quota granted by SendFlowControl() while the channel did not exist yet is
  // replayed below, after the add-channel request has been sent.
  const int64_t quota = pending_flow_control_quota_;
  pending_flow_control_quota_ = 0;

  channel_->SendAddChannelRequest(socket_url, requested_protocols, origin,
                                  first_party_for_cookies, additional_headers);
  if (quota > 0)
    SendFlowControl(quota);
}
} // namespace content