// Copyright (c) 2012 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "net/server/web_socket.h" #include #include "base/base64.h" #include "base/logging.h" #include "base/sha1.h" #include "base/strings/string_number_conversions.h" #include "base/strings/stringprintf.h" #include "base/sys_byteorder.h" #include "net/server/http_connection.h" #include "net/server/http_server.h" #include "net/server/http_server_request_info.h" #include "net/server/http_server_response_info.h" #include "net/server/web_socket_encoder.h" #include "net/websockets/websocket_deflate_parameters.h" #include "net/websockets/websocket_extension.h" #include "net/websockets/websocket_handshake_constants.h" namespace net { namespace { std::string ExtensionsHeaderString( const std::vector& extensions) { if (extensions.empty()) return std::string(); std::string result = "Sec-WebSocket-Extensions: " + extensions[0].ToString(); for (size_t i = 1; i < extensions.size(); ++i) result += ", " + extensions[i].ToString(); return result + "\r\n"; } std::string ValidResponseString( const std::string& accept_hash, const std::vector extensions) { return base::StringPrintf( "HTTP/1.1 101 WebSocket Protocol Handshake\r\n" "Upgrade: WebSocket\r\n" "Connection: Upgrade\r\n" "Sec-WebSocket-Accept: %s\r\n" "%s" "\r\n", accept_hash.c_str(), ExtensionsHeaderString(extensions).c_str()); } } // namespace WebSocket::WebSocket(HttpServer* server, HttpConnection* connection) : server_(server), connection_(connection), closed_(false) {} WebSocket::~WebSocket() = default; void WebSocket::Accept(const HttpServerRequestInfo& request) { std::string version = request.GetHeaderValue("sec-websocket-version"); if (version != "8" && version != "13") { SendErrorResponse("Invalid request format. The version is not valid."); return; } std::string key = request.GetHeaderValue("sec-websocket-key"); if (key.empty()) { SendErrorResponse( "Invalid request format. Sec-WebSocket-Key is empty or isn't " "specified."); return; } std::string encoded_hash; base::Base64Encode(base::SHA1HashString(key + websockets::kWebSocketGuid), &encoded_hash); std::vector response_extensions; auto i = request.headers.find("sec-websocket-extensions"); if (i == request.headers.end()) { encoder_ = WebSocketEncoder::CreateServer(); } else { WebSocketDeflateParameters params; encoder_ = WebSocketEncoder::CreateServer(i->second, ¶ms); if (!encoder_) { Fail(); return; } if (encoder_->deflate_enabled()) { DCHECK(params.IsValidAsResponse()); response_extensions.push_back(params.AsExtension()); } } server_->SendRaw(connection_->id(), ValidResponseString(encoded_hash, response_extensions)); } WebSocket::ParseResult WebSocket::Read(std::string* message) { if (closed_) return FRAME_CLOSE; if (!encoder_) { // RFC6455, section 4.1 says "Once the client's opening handshake has been // sent, the client MUST wait for a response from the server before sending // any further data". If |encoder_| is null here, ::Accept either has not // been called at all, or has rejected a request rather than producing // a server handshake. Either way, the client clearly couldn't have gotten // a proper server handshake, so error out, especially since this method // can't proceed without an |encoder_|. return FRAME_ERROR; } HttpConnection::ReadIOBuffer* read_buf = connection_->read_buf(); base::StringPiece frame(read_buf->StartOfBuffer(), read_buf->GetSize()); int bytes_consumed = 0; ParseResult result = encoder_->DecodeFrame(frame, &bytes_consumed, message); if (result == FRAME_OK) read_buf->DidConsume(bytes_consumed); if (result == FRAME_CLOSE) closed_ = true; return result; } void WebSocket::Send(const std::string& message) { if (closed_) return; std::string encoded; encoder_->EncodeFrame(message, 0, &encoded); server_->SendRaw(connection_->id(), encoded); } void WebSocket::Fail() { closed_ = true; // TODO(yhirano): The server SHOULD log the problem. server_->Close(connection_->id()); } void WebSocket::SendErrorResponse(const std::string& message) { if (closed_) return; closed_ = true; server_->Send500(connection_->id(), message); } } // namespace net