naiveproxy/src/net/websockets/websocket_stream.cc

500 lines
18 KiB
C++
Raw Normal View History

2022-05-03 13:16:59 +03:00
// Copyright 2013 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "net/websockets/websocket_stream.h"
#include <utility>
#include "base/bind.h"
#include "base/logging.h"
#include "base/memory/raw_ptr.h"
#include "base/memory/weak_ptr.h"
#include "base/metrics/histogram_functions.h"
#include "base/time/time.h"
#include "base/timer/timer.h"
#include "net/base/ip_endpoint.h"
#include "net/base/isolation_info.h"
#include "net/base/load_flags.h"
#include "net/base/url_util.h"
#include "net/http/http_request_headers.h"
#include "net/http/http_response_headers.h"
#include "net/http/http_response_info.h"
#include "net/http/http_status_code.h"
#include "net/traffic_annotation/network_traffic_annotation.h"
#include "net/url_request/redirect_info.h"
#include "net/url_request/url_request.h"
#include "net/url_request/url_request_context.h"
#include "net/url_request/websocket_handshake_userdata_key.h"
#include "net/websockets/websocket_basic_handshake_stream.h"
#include "net/websockets/websocket_errors.h"
#include "net/websockets/websocket_event_interface.h"
#include "net/websockets/websocket_handshake_constants.h"
#include "net/websockets/websocket_handshake_stream_base.h"
#include "net/websockets/websocket_handshake_stream_create_helper.h"
#include "net/websockets/websocket_http2_handshake_stream.h"
#include "third_party/abseil-cpp/absl/types/optional.h"
#include "url/gurl.h"
#include "url/origin.h"
namespace net {
namespace {
// Maximum time, in seconds, the WebSocket opening handshake may take before
// WebSocketStreamRequestImpl cancels it with ERR_TIMED_OUT.
// The value intentionally equals the TCP connection timeout used in
// net/socket/websocket_transport_client_socket_pool.cc, so that JavaScript
// cannot easily tell which of the two limits caused a failure.
constexpr int kHandshakeTimeoutIntervalInSeconds = 240;

class WebSocketStreamRequestImpl;
// URLRequest::Delegate for the handshake request. It relays URLRequest events
// to the WebSocketStreamRequestImpl that owns it, and through that object to
// the embedder's WebSocketStream::ConnectDelegate.
class Delegate : public URLRequest::Delegate {
 public:
  explicit Delegate(WebSocketStreamRequestImpl* request_owner)
      : owner_(request_owner) {}
  ~Delegate() override = default;

  // URLRequest::Delegate overrides.
  void OnReceivedRedirect(URLRequest* request,
                          const RedirectInfo& redirect_info,
                          bool* defer_redirect) override;
  void OnResponseStarted(URLRequest* request, int net_error) override;
  void OnAuthRequired(URLRequest* request,
                      const AuthChallengeInfo& auth_info) override;
  void OnCertificateRequested(URLRequest* request,
                              SSLCertRequestInfo* cert_request_info) override;
  void OnSSLCertificateError(URLRequest* request,
                             int net_error,
                             const SSLInfo& ssl_info,
                             bool fatal) override;
  void OnReadCompleted(URLRequest* request, int bytes_read) override;

 private:
  // Finishes an auth challenge for |request|. Applies |auth_credentials| when
  // non-null; otherwise cancels authentication.
  void OnAuthRequiredComplete(URLRequest* request,
                              const AuthCredentials* auth_credentials);

  // Not owned. The WebSocketStreamRequestImpl keeps this Delegate as a data
  // member, so the owner always outlives it.
  raw_ptr<WebSocketStreamRequestImpl> owner_;
};
// Drives a single WebSocket opening handshake.
//
// It owns two things:
// - the URLRequest that carries the handshake;
// - the ConnectDelegate that learns the outcome: OnSuccess() after an
//   upgrade, OnFailure() otherwise.
//
// It is also the WebSocketStreamRequestAPI given to
// WebSocketHandshakeStreamCreateHelper in the constructor. The
// On*HandshakeStreamCreated()/OnFailure() overrides below record that
// information. When |api_delegate_| is non-null (only via
// CreateAndConnectStreamForTesting()), they also forward it there.
class WebSocketStreamRequestImpl : public WebSocketStreamRequestAPI {
public:
// Creates the handshake URLRequest on |context| and configures it:
// - copies the caller's |additional_headers|;
// - forces the Upgrade, Connection, Origin and Sec-WebSocket-Version headers;
// - drops any caller-supplied Sec-WebSocket-Extensions/-Key/-Protocol headers;
// - disables caching;
// - notifies |connect_delegate| via OnCreateRequest().
// The request is not started until Start() is called. |isolation_info| must
// have RequestType::kOther (DCHECKed).
WebSocketStreamRequestImpl(
const GURL& url,
const std::vector<std::string>& requested_subprotocols,
const URLRequestContext* context,
const url::Origin& origin,
const SiteForCookies& site_for_cookies,
const IsolationInfo& isolation_info,
const HttpRequestHeaders& additional_headers,
NetworkTrafficAnnotationTag traffic_annotation,
std::unique_ptr<WebSocketStream::ConnectDelegate> connect_delegate,
std::unique_ptr<WebSocketStreamRequestAPI> api_delegate)
: delegate_(this),
url_request_(context->CreateRequest(url,
DEFAULT_PRIORITY,
&delegate_,
traffic_annotation,
/*is_for_websockets=*/true)),
connect_delegate_(std::move(connect_delegate)),
api_delegate_(std::move(api_delegate)) {
DCHECK_EQ(IsolationInfo::RequestType::kOther,
isolation_info.request_type());
HttpRequestHeaders headers = additional_headers;
headers.SetHeader(websockets::kUpgrade, websockets::kWebSocketLowercase);
headers.SetHeader(HttpRequestHeaders::kConnection, websockets::kUpgrade);
headers.SetHeader(HttpRequestHeaders::kOrigin, origin.Serialize());
headers.SetHeader(websockets::kSecWebSocketVersion,
websockets::kSupportedVersion);
// Remove HTTP headers that are important to websocket connections: they
// will be added later.
headers.RemoveHeader(websockets::kSecWebSocketExtensions);
headers.RemoveHeader(websockets::kSecWebSocketKey);
headers.RemoveHeader(websockets::kSecWebSocketProtocol);
url_request_->SetExtraRequestHeaders(headers);
url_request_->set_initiator(origin);
url_request_->set_site_for_cookies(site_for_cookies);
url_request_->set_isolation_info(isolation_info);
// The helper is attached to the URLRequest as user data under
// kWebSocketHandshakeUserDataKey. Presumably the HTTP stack looks it up there
// to create the handshake stream; it reports back to |this|.
auto create_helper = std::make_unique<WebSocketHandshakeStreamCreateHelper>(
connect_delegate_.get(), requested_subprotocols, this);
url_request_->SetUserData(kWebSocketHandshakeUserDataKey,
std::move(create_helper));
url_request_->SetLoadFlags(LOAD_DISABLE_CACHE | LOAD_BYPASS_CACHE);
connect_delegate_->OnCreateRequest(url_request_.get());
}
// Destroying this object destroys the URLRequest, which cancels the request
// and so terminates the handshake if it is incomplete.
~WebSocketStreamRequestImpl() override = default;
// Records the newly created handshake stream (presumably the HTTP/1.1 one,
// given the type name) and mirrors the notification to |api_delegate_|.
void OnBasicHandshakeStreamCreated(
WebSocketBasicHandshakeStream* handshake_stream) override {
if (api_delegate_) {
api_delegate_->OnBasicHandshakeStreamCreated(handshake_stream);
}
OnHandshakeStreamCreated(handshake_stream);
}
// Same as above for the HTTP/2 handshake stream.
void OnHttp2HandshakeStreamCreated(
WebSocketHttp2HandshakeStream* handshake_stream) override {
if (api_delegate_) {
api_delegate_->OnHttp2HandshakeStreamCreated(handshake_stream);
}
OnHandshakeStreamCreated(handshake_stream);
}
// Stores failure details supplied by the handshake stream so that a later
// ReportFailure() prefers them over a generic message. Does not notify
// |connect_delegate_| by itself; only |api_delegate_| is told immediately.
void OnFailure(const std::string& message,
int net_error,
absl::optional<int> response_code) override {
if (api_delegate_)
api_delegate_->OnFailure(message, net_error, response_code);
failure_message_ = message;
failure_net_error_ = net_error;
failure_response_code_ = response_code;
}
// Arms |timer| to cancel the handshake with ERR_TIMED_OUT after
// kHandshakeTimeoutIntervalInSeconds, then starts the URLRequest.
// base::Unretained(this) is safe because |timer_| is owned by this object:
// destroying |this| destroys the timer and its pending callback.
void Start(std::unique_ptr<base::OneShotTimer> timer) {
DCHECK(timer);
base::TimeDelta timeout(base::Seconds(kHandshakeTimeoutIntervalInSeconds));
timer_ = std::move(timer);
timer_->Start(FROM_HERE, timeout,
base::BindOnce(&WebSocketStreamRequestImpl::OnTimeout,
base::Unretained(this)));
url_request_->Start();
}
// Completes a successful handshake. Stops the timeout and hands the upgraded
// stream plus response metadata to |connect_delegate_|. If the handshake
// stream has already gone away, reports a failure instead.
void PerformUpgrade() {
DCHECK(timer_);
DCHECK(connect_delegate_);
timer_->Stop();
if (!handshake_stream_) {
ReportFailureWithMessage(
"No handshake stream has been created or handshake stream is already "
"destroyed.",
ERR_FAILED, absl::nullopt);
return;
}
// Move the URLRequest into a local and clear |handshake_stream_| before
// calling OnSuccess(). As the comment further down notes, |this| may already
// be deleted when OnSuccess() returns, so no member may be touched after it.
std::unique_ptr<URLRequest> url_request = std::move(url_request_);
WebSocketHandshakeStreamBase* handshake_stream = handshake_stream_.get();
handshake_stream_.reset();
auto handshake_response_info =
std::make_unique<WebSocketHandshakeResponseInfo>(
url_request->url(), url_request->response_headers(),
url_request->GetResponseRemoteEndpoint(),
url_request->response_time());
connect_delegate_->OnSuccess(handshake_stream->Upgrade(),
std::move(handshake_response_info));
// This is safe even if |this| has already been deleted.
url_request->CancelWithError(ERR_WS_UPGRADE);
}
// Returns the human-readable failure text for a generic |net_error|.
std::string FailureMessageFromNetError(int net_error) {
if (net_error == ERR_TUNNEL_CONNECTION_FAILED) {
// This error is common and confusing, so special-case it.
// TODO(ricea): Include the HostPortPair of the selected proxy server in
// the error message. This is not currently possible because it isn't set
// in HttpResponseInfo when a ERR_TUNNEL_CONNECTION_FAILED error happens.
return "Establishing a tunnel via proxy server failed.";
} else {
return std::string("Error in connection establishment: ") +
ErrorToString(net_error);
}
}
// Stops the timeout and reports a failure to |connect_delegate_|.
// Details recorded earlier by OnFailure() take precedence: a non-empty stored
// message, the stored net error and the stored response code override the
// arguments. Otherwise the message is derived from |net_error|; OK and
// ERR_IO_PENDING leave it empty.
void ReportFailure(int net_error, absl::optional<int> response_code) {
DCHECK(timer_);
timer_->Stop();
if (failure_message_.empty()) {
switch (net_error) {
case OK:
case ERR_IO_PENDING:
break;
case ERR_ABORTED:
failure_message_ = "WebSocket opening handshake was canceled";
break;
case ERR_TIMED_OUT:
failure_message_ = "WebSocket opening handshake timed out";
break;
default:
failure_message_ = FailureMessageFromNetError(net_error);
break;
}
}
ReportFailureWithMessage(
failure_message_, failure_net_error_.value_or(net_error),
failure_response_code_ ? failure_response_code_ : response_code);
}
// Forwards a failure directly to |connect_delegate_|. Unlike ReportFailure(),
// it neither stops the timer nor consults stored failure details.
void ReportFailureWithMessage(const std::string& failure_message,
int net_error,
absl::optional<int> response_code) {
connect_delegate_->OnFailure(failure_message, net_error, response_code);
}
// Non-owning accessor used by Delegate.
WebSocketStream::ConnectDelegate* connect_delegate() const {
return connect_delegate_.get();
}
// Timer callback: aborts the handshake with ERR_TIMED_OUT. The resulting
// error presumably reaches Delegate::OnResponseStarted() and then
// ReportFailure() — confirm against URLRequest cancellation semantics.
void OnTimeout() {
url_request_->CancelWithError(ERR_TIMED_OUT);
}
private:
// Keeps only a weak reference; the stream is owned elsewhere (see
// |handshake_stream_| below).
void OnHandshakeStreamCreated(
WebSocketHandshakeStreamBase* handshake_stream) {
DCHECK(handshake_stream);
handshake_stream_ = handshake_stream->GetWeakPtr();
}
// |delegate_| needs to be declared before |url_request_| so that it gets
// initialised first.
// (Members are destroyed in reverse order, so |delegate_| also outlives
// |url_request_|.)
Delegate delegate_;
// Deleting the WebSocketStreamRequestImpl object deletes this URLRequest
// object, cancelling the whole connection.
std::unique_ptr<URLRequest> url_request_;
// NOTE(review): |connect_delegate_| is declared after |url_request_| and is
// therefore destroyed before it. The create helper stored in the URLRequest's
// user data holds a raw pointer to it. This is fine only if nothing
// dereferences that pointer during URLRequest teardown — confirm.
std::unique_ptr<WebSocketStream::ConnectDelegate> connect_delegate_;
// This is owned by the caller of
// WebsocketHandshakeStreamCreateHelper::CreateBasicStream() or
// CreateHttp2Stream(). Both the stream and this object will be destroyed
// during the destruction of the URLRequest object associated with the
// handshake. This is only guaranteed to be a valid pointer if the handshake
// succeeded.
base::WeakPtr<WebSocketHandshakeStreamBase> handshake_stream_;
// The failure information supplied by WebSocketBasicHandshakeStream, if any.
std::string failure_message_;
absl::optional<int> failure_net_error_;
absl::optional<int> failure_response_code_;
// A timer for handshake timeout.
std::unique_ptr<base::OneShotTimer> timer_;
// A delegate for On*HandshakeCreated and OnFailure calls.
std::unique_ptr<WebSocketStreamRequestAPI> api_delegate_;
};
class SSLErrorCallbacks : public WebSocketEventInterface::SSLErrorCallbacks {
public:
explicit SSLErrorCallbacks(URLRequest* url_request)
: url_request_(url_request->GetWeakPtr()) {}
void CancelSSLRequest(int error, const SSLInfo* ssl_info) override {
if (!url_request_)
return;
if (ssl_info) {
url_request_->CancelWithSSLError(error, *ssl_info);
} else {
url_request_->CancelWithError(error);
}
}
void ContinueSSLRequest() override {
if (url_request_)
url_request_->ContinueDespiteLastError();
}
private:
base::WeakPtr<URLRequest> url_request_;
};
void Delegate::OnReceivedRedirect(URLRequest* request,
                                  const RedirectInfo& redirect_info,
                                  bool* defer_redirect) {
  // Redirects sent over the network never reach this point, because
  // WebSocketBasicHandshakeStream rejects every response code other than 101,
  // 401 and 407. Internally generated redirects, such as those caused by HSTS,
  // do arrive here.
  //
  // Following an externally generated redirect would be a security hole. So
  // accept only what an internal upgrade looks like: a GET to the original URL
  // with its scheme switched to "wss". Cancel anything else.
  GURL::Replacements to_wss;
  to_wss.SetSchemeStr("wss");
  const GURL expected_url = request->original_url().ReplaceComponents(to_wss);
  const bool is_internal_upgrade = redirect_info.new_method == "GET" &&
                                   redirect_info.new_url == expected_url;
  if (is_internal_upgrade)
    return;

  // This should not happen.
  DLOG(FATAL) << "Unauthorized WebSocket redirect to "
              << redirect_info.new_method << " "
              << redirect_info.new_url.spec();
  request->Cancel();
}
void Delegate::OnResponseStarted(URLRequest* request, int net_error) {
  DCHECK_NE(ERR_IO_PENDING, net_error);

  // Record every result, OK and ERR_ABORTED included, the same way as
  // Net.ErrorCodesForMainFrame4. The second histogram splits the results by
  // whether the destination is localhost.
  const int sample = -net_error;
  base::UmaHistogramSparse("Net.WebSocket.ErrorCodes", sample);
  const char* const split_histogram =
      net::IsLocalhost(request->url())
          ? "Net.WebSocket.ErrorCodes_Localhost"
          : "Net.WebSocket.ErrorCodes_NotLocalhost";
  base::UmaHistogramSparse(split_histogram, sample);

  if (net_error != OK) {
    DVLOG(3) << "OnResponseStarted (request failed)";
    owner_->ReportFailure(net_error, absl::nullopt);
    return;
  }

  const int response_code = request->GetResponseCode();
  DVLOG(3) << "OnResponseStarted (response code " << response_code << ")";

  // Over HTTP/2 the handshake succeeds with 200 rather than 101.
  const bool is_http2 = request->response_info().connection_info ==
                        HttpResponseInfo::CONNECTION_INFO_HTTP2;
  if (is_http2) {
    if (response_code == HTTP_OK)
      owner_->PerformUpgrade();
    else
      owner_->ReportFailure(net_error, absl::nullopt);
    return;
  }

  // Each branch may end up destroying |owner_| (and with it this Delegate),
  // so nothing follows the chain.
  if (response_code == HTTP_SWITCHING_PROTOCOLS) {
    owner_->PerformUpgrade();
  } else if (response_code == HTTP_UNAUTHORIZED) {
    owner_->ReportFailureWithMessage(
        "HTTP Authentication failed; no valid credentials available",
        net_error, response_code);
  } else if (response_code == HTTP_PROXY_AUTHENTICATION_REQUIRED) {
    owner_->ReportFailureWithMessage("Proxy authentication failed", net_error,
                                     response_code);
  } else {
    owner_->ReportFailure(net_error, response_code);
  }
}
// Asks the embedder how to answer an HTTP/proxy auth challenge on |request|.
//
// The embedder (ConnectDelegate::OnAuthRequired) can respond in three ways:
// - ERR_IO_PENDING: it answers later by running the bound callback, which ends
//   up in OnAuthRequiredComplete().
// - Another error: the handshake fails with that error.
// - OK: it answers synchronously. It either fills |credentials| (continue with
//   them) or leaves it empty (cancel auth).
void Delegate::OnAuthRequired(URLRequest* request,
                              const AuthChallengeInfo& auth_info) {
  absl::optional<AuthCredentials> credentials;
  // base::Unretained(this) relies on the assumption that |callback| is only
  // run while the opening handshake is still in progress. In that window the
  // owning WebSocketStreamRequestImpl, and therefore this Delegate (one of its
  // members), is still alive.
  int rv = owner_->connect_delegate()->OnAuthRequired(
      auth_info, request->response_headers(),
      request->GetResponseRemoteEndpoint(),
      base::BindOnce(&Delegate::OnAuthRequiredComplete, base::Unretained(this),
                     request),
      &credentials);
  request->LogBlockedBy("WebSocketStream::Delegate::OnAuthRequired");
  if (rv == ERR_IO_PENDING)
    return;  // The embedder will run the callback later.
  if (rv != OK) {
    request->LogUnblocked();
    owner_->ReportFailure(rv, absl::nullopt);
    return;
  }
  // Synchronous answer: honour credentials the embedder supplied. Passing
  // nullptr unconditionally would silently drop them and cancel auth.
  OnAuthRequiredComplete(request, credentials ? &*credentials : nullptr);
}
void Delegate::OnAuthRequiredComplete(URLRequest* request,
                                      const AuthCredentials* credentials) {
  // Balances the LogBlockedBy() issued in OnAuthRequired().
  request->LogUnblocked();
  if (credentials) {
    // Continue the handshake with the credentials the embedder supplied.
    request->SetAuth(*credentials);
  } else {
    // No credentials were provided: abandon authentication.
    request->CancelAuth();
  }
}
void Delegate::OnCertificateRequested(URLRequest* request,
                                      SSLCertRequestInfo* cert_request_info) {
  // The server asked for a client certificate, and the request context holds
  // no stored certificate choice for this endpoint. A main-frame load could
  // show selection UI at this point. WebSockets are sub-resources and must
  // never trigger UI, so the only remaining option is to abort the request.
  request->Cancel();
}
void Delegate::OnSSLCertificateError(URLRequest* request,
                                     int net_error,
                                     const SSLInfo& ssl_info,
                                     bool fatal) {
  // Defer the decision to the embedder. It can cancel or continue |request|
  // through the callbacks object, which only tracks the request weakly.
  WebSocketStream::ConnectDelegate* const connect_delegate =
      owner_->connect_delegate();
  auto callbacks = std::make_unique<SSLErrorCallbacks>(request);
  connect_delegate->OnSSLCertificateError(std::move(callbacks), net_error,
                                          ssl_info, fatal);
}
void Delegate::OnReadCompleted(URLRequest* request, int bytes_read) {
  // Nothing in this file ever reads the handshake response body. The request
  // is either upgraded or failed from OnResponseStarted(), so a read can never
  // complete.
  NOTREACHED();
}
} // namespace
// Out-of-line, defaulted special members for the types declared in the
// header.
WebSocketStream::WebSocketStream() = default;
WebSocketStream::~WebSocketStream() = default;

WebSocketStream::ConnectDelegate::~ConnectDelegate() = default;

WebSocketStreamRequest::~WebSocketStreamRequest() = default;
std::unique_ptr<WebSocketStreamRequest> WebSocketStream::CreateAndConnectStream(
    const GURL& socket_url,
    const std::vector<std::string>& requested_subprotocols,
    const url::Origin& origin,
    const SiteForCookies& site_for_cookies,
    const IsolationInfo& isolation_info,
    const HttpRequestHeaders& additional_headers,
    URLRequestContext* url_request_context,
    const NetLogWithSource& net_log,
    NetworkTrafficAnnotationTag traffic_annotation,
    std::unique_ptr<ConnectDelegate> connect_delegate) {
  // Production entry point. There is no WebSocketStreamRequestAPI observer,
  // and a real OneShotTimer enforces the handshake timeout. |net_log| is not
  // used here.
  auto stream_request = std::make_unique<WebSocketStreamRequestImpl>(
      socket_url, requested_subprotocols, url_request_context, origin,
      site_for_cookies, isolation_info, additional_headers, traffic_annotation,
      std::move(connect_delegate), /*api_delegate=*/nullptr);
  auto handshake_timer = std::make_unique<base::OneShotTimer>();
  stream_request->Start(std::move(handshake_timer));
  // Implicitly moved and converted to std::unique_ptr<WebSocketStreamRequest>.
  return stream_request;
}
std::unique_ptr<WebSocketStreamRequest>
WebSocketStream::CreateAndConnectStreamForTesting(
    const GURL& socket_url,
    const std::vector<std::string>& requested_subprotocols,
    const url::Origin& origin,
    const SiteForCookies& site_for_cookies,
    const IsolationInfo& isolation_info,
    const HttpRequestHeaders& additional_headers,
    URLRequestContext* url_request_context,
    const NetLogWithSource& net_log,
    NetworkTrafficAnnotationTag traffic_annotation,
    std::unique_ptr<WebSocketStream::ConnectDelegate> connect_delegate,
    std::unique_ptr<base::OneShotTimer> timer,
    std::unique_ptr<WebSocketStreamRequestAPI> api_delegate) {
  // Test entry point. It behaves like CreateAndConnectStream(), except that
  // the caller supplies the timeout timer and can observe handshake-stream
  // creation and failures through |api_delegate|. |net_log| is not used here.
  auto stream_request = std::make_unique<WebSocketStreamRequestImpl>(
      socket_url, requested_subprotocols, url_request_context, origin,
      site_for_cookies, isolation_info, additional_headers, traffic_annotation,
      std::move(connect_delegate), std::move(api_delegate));
  stream_request->Start(std::move(timer));
  // Implicitly moved and converted to std::unique_ptr<WebSocketStreamRequest>.
  return stream_request;
}
} // namespace net