diff options
author | Jocelyn Turcotte <jocelyn.turcotte@digia.com> | 2014-08-08 14:30:41 +0200 |
---|---|---|
committer | Jocelyn Turcotte <jocelyn.turcotte@digia.com> | 2014-08-12 13:49:54 +0200 |
commit | ab0a50979b9eb4dfa3320eff7e187e41efedf7a9 (patch) | |
tree | 498dfb8a97ff3361a9f7486863a52bb4e26bb898 /chromium/net/websockets | |
parent | 4ce69f7403811819800e7c5ae1318b2647e778d1 (diff) | |
download | qtwebengine-chromium-ab0a50979b9eb4dfa3320eff7e187e41efedf7a9.tar.gz |
Update Chromium to beta version 37.0.2062.68
Change-Id: I188e3b5aff1bec75566014291b654eb19f5bc8ca
Reviewed-by: Andras Becsi <andras.becsi@digia.com>
Diffstat (limited to 'chromium/net/websockets')
38 files changed, 4177 insertions, 782 deletions
diff --git a/chromium/net/websockets/OWNERS b/chromium/net/websockets/OWNERS index 8ef489a002e..6c81c618dad 100644 --- a/chromium/net/websockets/OWNERS +++ b/chromium/net/websockets/OWNERS @@ -1,8 +1,3 @@ tyoshino@chromium.org - -# Have been inactive for a while. -yutak@chromium.org -toyoshim@chromium.org - -# On leave -bashi@chromium.org +ricea@chromium.org +yhirano@chromium.org diff --git a/chromium/net/websockets/README b/chromium/net/websockets/README index fab4c203bc8..c428efbb851 100644 --- a/chromium/net/websockets/README +++ b/chromium/net/websockets/README @@ -64,6 +64,10 @@ websocket_handshake_stream_base.h websocket_handshake_stream_create_helper.cc websocket_handshake_stream_create_helper.h websocket_handshake_stream_create_helper_test.cc +websocket_handshake_request_info.cc +websocket_handshake_request_info.h +websocket_handshake_response_info.cc +websocket_handshake_response_info.h websocket_inflater.cc websocket_inflater.h websocket_inflater_test.cc diff --git a/chromium/net/websockets/websocket_basic_handshake_stream.cc b/chromium/net/websockets/websocket_basic_handshake_stream.cc index 73a10453be0..810ac75a6cb 100644 --- a/chromium/net/websockets/websocket_basic_handshake_stream.cc +++ b/chromium/net/websockets/websocket_basic_handshake_stream.cc @@ -6,13 +6,23 @@ #include <algorithm> #include <iterator> +#include <set> +#include <string> +#include <vector> #include "base/base64.h" #include "base/basictypes.h" #include "base/bind.h" #include "base/containers/hash_tables.h" +#include "base/logging.h" +#include "base/metrics/histogram.h" +#include "base/metrics/sparse_histogram.h" #include "base/stl_util.h" +#include "base/strings/string_number_conversions.h" +#include "base/strings/string_piece.h" #include "base/strings/string_util.h" +#include "base/strings/stringprintf.h" +#include "base/time/time.h" #include "crypto/random.h" #include "net/http/http_request_headers.h" #include "net/http/http_request_info.h" @@ -22,13 +32,51 @@ #include "net/http/http_stream_parser.h" #include "net/socket/client_socket_handle.h" #include "net/websockets/websocket_basic_stream.h" +#include "net/websockets/websocket_deflate_predictor.h" +#include "net/websockets/websocket_deflate_predictor_impl.h" +#include "net/websockets/websocket_deflate_stream.h" +#include "net/websockets/websocket_deflater.h" +#include "net/websockets/websocket_extension_parser.h" #include "net/websockets/websocket_handshake_constants.h" #include "net/websockets/websocket_handshake_handler.h" +#include "net/websockets/websocket_handshake_request_info.h" +#include "net/websockets/websocket_handshake_response_info.h" #include "net/websockets/websocket_stream.h" namespace net { + +// TODO(ricea): If more extensions are added, replace this with a more general +// mechanism. +struct WebSocketExtensionParams { + WebSocketExtensionParams() + : deflate_enabled(false), + client_window_bits(15), + deflate_mode(WebSocketDeflater::TAKE_OVER_CONTEXT) {} + + bool deflate_enabled; + int client_window_bits; + WebSocketDeflater::ContextTakeOverMode deflate_mode; +}; + namespace { +enum GetHeaderResult { + GET_HEADER_OK, + GET_HEADER_MISSING, + GET_HEADER_MULTIPLE, +}; + +std::string MissingHeaderMessage(const std::string& header_name) { + return std::string("'") + header_name + "' header is missing"; +} + +std::string MultipleHeaderValuesMessage(const std::string& header_name) { + return + std::string("'") + + header_name + + "' header must not appear more than once in a response"; +} + std::string GenerateHandshakeChallenge() { std::string raw_challenge(websockets::kRawChallengeLength, '\0'); crypto::RandBytes(string_as_array(&raw_challenge), raw_challenge.length()); @@ -45,59 +93,248 @@ void AddVectorHeaderIfNonEmpty(const char* name, headers->SetHeader(name, JoinString(value, ", ")); } -// If |case_sensitive| is false, then |value| must be in lower-case. -bool ValidateSingleTokenHeader( - const scoped_refptr<HttpResponseHeaders>& headers, - const base::StringPiece& name, - const std::string& value, - bool case_sensitive) { +GetHeaderResult GetSingleHeaderValue(const HttpResponseHeaders* headers, + const base::StringPiece& name, + std::string* value) { void* state = NULL; - std::string token; - int tokens = 0; - bool has_value = false; - while (headers->EnumerateHeader(&state, name, &token)) { - if (++tokens > 1) - return false; - has_value = case_sensitive ? value == token - : LowerCaseEqualsASCII(token, value.c_str()); + size_t num_values = 0; + std::string temp_value; + while (headers->EnumerateHeader(&state, name, &temp_value)) { + if (++num_values > 1) + return GET_HEADER_MULTIPLE; + *value = temp_value; + } + return num_values > 0 ? GET_HEADER_OK : GET_HEADER_MISSING; +} + +bool ValidateHeaderHasSingleValue(GetHeaderResult result, + const std::string& header_name, + std::string* failure_message) { + if (result == GET_HEADER_MISSING) { + *failure_message = MissingHeaderMessage(header_name); + return false; + } + if (result == GET_HEADER_MULTIPLE) { + *failure_message = MultipleHeaderValuesMessage(header_name); + return false; + } + DCHECK_EQ(result, GET_HEADER_OK); + return true; +} + +bool ValidateUpgrade(const HttpResponseHeaders* headers, + std::string* failure_message) { + std::string value; + GetHeaderResult result = + GetSingleHeaderValue(headers, websockets::kUpgrade, &value); + if (!ValidateHeaderHasSingleValue(result, + websockets::kUpgrade, + failure_message)) { + return false; } - return has_value; + + if (!LowerCaseEqualsASCII(value, websockets::kWebSocketLowercase)) { + *failure_message = + "'Upgrade' header value is not 'WebSocket': " + value; + return false; + } + return true; +} + +bool ValidateSecWebSocketAccept(const HttpResponseHeaders* headers, + const std::string& expected, + std::string* failure_message) { + std::string actual; + GetHeaderResult result = + GetSingleHeaderValue(headers, websockets::kSecWebSocketAccept, &actual); + if (!ValidateHeaderHasSingleValue(result, + websockets::kSecWebSocketAccept, + failure_message)) { + return false; + } + + if (expected != actual) { + *failure_message = "Incorrect 'Sec-WebSocket-Accept' header value"; + return false; + } + return true; +} + +bool ValidateConnection(const HttpResponseHeaders* headers, + std::string* failure_message) { + // Connection header is permitted to contain other tokens. + if (!headers->HasHeader(HttpRequestHeaders::kConnection)) { + *failure_message = MissingHeaderMessage(HttpRequestHeaders::kConnection); + return false; + } + if (!headers->HasHeaderValue(HttpRequestHeaders::kConnection, + websockets::kUpgrade)) { + *failure_message = "'Connection' header value must contain 'Upgrade'"; + return false; + } + return true; } bool ValidateSubProtocol( - const scoped_refptr<HttpResponseHeaders>& headers, + const HttpResponseHeaders* headers, const std::vector<std::string>& requested_sub_protocols, - std::string* sub_protocol) { + std::string* sub_protocol, + std::string* failure_message) { void* state = NULL; - std::string token; + std::string value; base::hash_set<std::string> requested_set(requested_sub_protocols.begin(), requested_sub_protocols.end()); - int accepted = 0; - while (headers->EnumerateHeader( - &state, websockets::kSecWebSocketProtocol, &token)) { - if (requested_set.count(token) == 0) - return false; + int count = 0; + bool has_multiple_protocols = false; + bool has_invalid_protocol = false; + + while (!has_invalid_protocol || !has_multiple_protocols) { + std::string temp_value; + if (!headers->EnumerateHeader( + &state, websockets::kSecWebSocketProtocol, &temp_value)) + break; + value = temp_value; + if (requested_set.count(value) == 0) + has_invalid_protocol = true; + if (++count > 1) + has_multiple_protocols = true; + } - *sub_protocol = token; - // The server is only allowed to accept one protocol. - if (++accepted > 1) - return false; + if (has_multiple_protocols) { + *failure_message = + MultipleHeaderValuesMessage(websockets::kSecWebSocketProtocol); + return false; + } else if (count > 0 && requested_sub_protocols.size() == 0) { + *failure_message = + std::string("Response must not include 'Sec-WebSocket-Protocol' " + "header if not present in request: ") + + value; + return false; + } else if (has_invalid_protocol) { + *failure_message = + "'Sec-WebSocket-Protocol' header value '" + + value + + "' in response does not match any of sent values"; + return false; + } else if (requested_sub_protocols.size() > 0 && count == 0) { + *failure_message = + "Sent non-empty 'Sec-WebSocket-Protocol' header " + "but no response was received"; + return false; + } + *sub_protocol = value; + return true; +} + +bool DeflateError(std::string* message, const base::StringPiece& piece) { + *message = "Error in permessage-deflate: "; + piece.AppendToString(message); + return false; +} + +bool ValidatePerMessageDeflateExtension(const WebSocketExtension& extension, + std::string* failure_message, + WebSocketExtensionParams* params) { + static const char kClientPrefix[] = "client_"; + static const char kServerPrefix[] = "server_"; + static const char kNoContextTakeover[] = "no_context_takeover"; + static const char kMaxWindowBits[] = "max_window_bits"; + const size_t kPrefixLen = arraysize(kClientPrefix) - 1; + COMPILE_ASSERT(kPrefixLen == arraysize(kServerPrefix) - 1, + the_strings_server_and_client_must_be_the_same_length); + typedef std::vector<WebSocketExtension::Parameter> ParameterVector; + + DCHECK_EQ("permessage-deflate", extension.name()); + const ParameterVector& parameters = extension.parameters(); + std::set<std::string> seen_names; + for (ParameterVector::const_iterator it = parameters.begin(); + it != parameters.end(); ++it) { + const std::string& name = it->name(); + if (seen_names.count(name) != 0) { + return DeflateError( + failure_message, + "Received duplicate permessage-deflate extension parameter " + name); + } + seen_names.insert(name); + const std::string client_or_server(name, 0, kPrefixLen); + const bool is_client = (client_or_server == kClientPrefix); + if (!is_client && client_or_server != kServerPrefix) { + return DeflateError( + failure_message, + "Received an unexpected permessage-deflate extension parameter"); + } + const std::string rest(name, kPrefixLen); + if (rest == kNoContextTakeover) { + if (it->HasValue()) { + return DeflateError(failure_message, + "Received invalid " + name + " parameter"); + } + if (is_client) + params->deflate_mode = WebSocketDeflater::DO_NOT_TAKE_OVER_CONTEXT; + } else if (rest == kMaxWindowBits) { + if (!it->HasValue()) + return DeflateError(failure_message, name + " must have value"); + int bits = 0; + if (!base::StringToInt(it->value(), &bits) || bits < 8 || bits > 15 || + it->value()[0] == '0' || + it->value().find_first_not_of("0123456789") != std::string::npos) { + return DeflateError(failure_message, + "Received invalid " + name + " parameter"); + } + if (is_client) + params->client_window_bits = bits; + } else { + return DeflateError( + failure_message, + "Received an unexpected permessage-deflate extension parameter"); + } } - // If the browser requested > 0 protocols, the server is required to accept - // one. - return requested_set.empty() || accepted == 1; + params->deflate_enabled = true; + return true; } -bool ValidateExtensions(const scoped_refptr<HttpResponseHeaders>& headers, +bool ValidateExtensions(const HttpResponseHeaders* headers, const std::vector<std::string>& requested_extensions, - std::string* extensions) { + std::string* extensions, + std::string* failure_message, + WebSocketExtensionParams* params) { void* state = NULL; - std::string token; + std::string value; + std::vector<std::string> accepted_extensions; + // TODO(ricea): If adding support for additional extensions, generalise this + // code. + bool seen_permessage_deflate = false; while (headers->EnumerateHeader( - &state, websockets::kSecWebSocketExtensions, &token)) { - // TODO(ricea): Accept permessage-deflate with valid parameters. - return false; + &state, websockets::kSecWebSocketExtensions, &value)) { + WebSocketExtensionParser parser; + parser.Parse(value); + if (parser.has_error()) { + // TODO(yhirano) Set appropriate failure message. + *failure_message = + "'Sec-WebSocket-Extensions' header value is " + "rejected by the parser: " + + value; + return false; + } + if (parser.extension().name() == "permessage-deflate") { + if (seen_permessage_deflate) { + *failure_message = "Received duplicate permessage-deflate response"; + return false; + } + seen_permessage_deflate = true; + if (!ValidatePerMessageDeflateExtension( + parser.extension(), failure_message, params)) + return false; + } else { + *failure_message = + "Found an unsupported extension '" + + parser.extension().name() + + "' in 'Sec-WebSocket-Extensions' header"; + return false; + } + accepted_extensions.push_back(value); } + *extensions = JoinString(accepted_extensions, ", "); return true; } @@ -105,13 +342,20 @@ bool ValidateExtensions(const scoped_refptr<HttpResponseHeaders>& headers, WebSocketBasicHandshakeStream::WebSocketBasicHandshakeStream( scoped_ptr<ClientSocketHandle> connection, + WebSocketStream::ConnectDelegate* connect_delegate, bool using_proxy, std::vector<std::string> requested_sub_protocols, - std::vector<std::string> requested_extensions) + std::vector<std::string> requested_extensions, + std::string* failure_message) : state_(connection.release(), using_proxy), + connect_delegate_(connect_delegate), http_response_info_(NULL), requested_sub_protocols_(requested_sub_protocols), - requested_extensions_(requested_extensions) {} + requested_extensions_(requested_extensions), + failure_message_(failure_message) { + DCHECK(connect_delegate); + DCHECK(failure_message); +} WebSocketBasicHandshakeStream::~WebSocketBasicHandshakeStream() {} @@ -120,6 +364,7 @@ int WebSocketBasicHandshakeStream::InitializeStream( RequestPriority priority, const BoundNetLog& net_log, const CompletionCallback& callback) { + url_ = request_info->url; state_.Initialize(request_info, priority, net_log, callback); return OK; } @@ -152,16 +397,22 @@ int WebSocketBasicHandshakeStream::SendRequest( } enriched_headers.SetHeader(websockets::kSecWebSocketKey, handshake_challenge); - AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketProtocol, - requested_sub_protocols_, - &enriched_headers); AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketExtensions, requested_extensions_, &enriched_headers); + AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketProtocol, + requested_sub_protocols_, + &enriched_headers); ComputeSecWebSocketAccept(handshake_challenge, &handshake_challenge_response_); + DCHECK(connect_delegate_); + scoped_ptr<WebSocketHandshakeRequestInfo> request( + new WebSocketHandshakeRequestInfo(url_, base::Time::Now())); + request->headers.CopyFrom(enriched_headers); + connect_delegate_->OnStartOpeningHandshake(request.Pass()); + return parser()->SendRequest( state_.GenerateRequestLine(), enriched_headers, response, callback); } @@ -176,11 +427,9 @@ int WebSocketBasicHandshakeStream::ReadResponseHeaders( base::Bind(&WebSocketBasicHandshakeStream::ReadResponseHeadersCallback, base::Unretained(this), callback)); - return rv == OK ? ValidateResponse() : rv; -} - -const HttpResponseInfo* WebSocketBasicHandshakeStream::GetResponseInfo() const { - return parser()->GetResponseInfo(); + if (rv == ERR_IO_PENDING) + return rv; + return ValidateResponse(rv); } int WebSocketBasicHandshakeStream::ReadResponseBody( @@ -250,16 +499,30 @@ void WebSocketBasicHandshakeStream::SetPriority(RequestPriority priority) { } scoped_ptr<WebSocketStream> WebSocketBasicHandshakeStream::Upgrade() { - // TODO(ricea): Add deflate support. - // The HttpStreamParser object has a pointer to our ClientSocketHandle. Make // sure it does not touch it again before it is destroyed. state_.DeleteParser(); - return scoped_ptr<WebSocketStream>( + scoped_ptr<WebSocketStream> basic_stream( new WebSocketBasicStream(state_.ReleaseConnection(), state_.read_buf(), sub_protocol_, extensions_)); + DCHECK(extension_params_.get()); + if (extension_params_->deflate_enabled) { + UMA_HISTOGRAM_ENUMERATION( + "Net.WebSocket.DeflateMode", + extension_params_->deflate_mode, + WebSocketDeflater::NUM_CONTEXT_TAKEOVER_MODE_TYPES); + + return scoped_ptr<WebSocketStream>( + new WebSocketDeflateStream(basic_stream.Pass(), + extension_params_->deflate_mode, + extension_params_->client_window_bits, + scoped_ptr<WebSocketDeflatePredictor>( + new WebSocketDeflatePredictorImpl))); + } else { + return basic_stream.Pass(); + } } void WebSocketBasicHandshakeStream::SetWebSocketKeyForTesting( @@ -270,49 +533,102 @@ void WebSocketBasicHandshakeStream::SetWebSocketKeyForTesting( void WebSocketBasicHandshakeStream::ReadResponseHeadersCallback( const CompletionCallback& callback, int result) { - if (result == OK) - result = ValidateResponse(); - callback.Run(result); + callback.Run(ValidateResponse(result)); } -int WebSocketBasicHandshakeStream::ValidateResponse() { +void WebSocketBasicHandshakeStream::OnFinishOpeningHandshake() { + DCHECK(connect_delegate_); DCHECK(http_response_info_); - const scoped_refptr<HttpResponseHeaders>& headers = - http_response_info_->headers; - - switch (headers->response_code()) { - case HTTP_SWITCHING_PROTOCOLS: - return ValidateUpgradeResponse(headers); - - // We need to pass these through for authentication to work. - case HTTP_UNAUTHORIZED: - case HTTP_PROXY_AUTHENTICATION_REQUIRED: - return OK; - - // Other status codes are potentially risky (see the warnings in the - // WHATWG WebSocket API spec) and so are dropped by default. - default: - return ERR_INVALID_RESPONSE; + scoped_refptr<HttpResponseHeaders> headers = http_response_info_->headers; + // If the headers are too large, HttpStreamParser will just not parse them at + // all. + if (headers) { + scoped_ptr<WebSocketHandshakeResponseInfo> response( + new WebSocketHandshakeResponseInfo(url_, + headers->response_code(), + headers->GetStatusText(), + headers, + http_response_info_->response_time)); + connect_delegate_->OnFinishOpeningHandshake(response.Pass()); + } +} + +int WebSocketBasicHandshakeStream::ValidateResponse(int rv) { + DCHECK(http_response_info_); + // Most net errors happen during connection, so they are not seen by this + // method. The histogram for error codes is created in + // Delegate::OnResponseStarted in websocket_stream.cc instead. + if (rv >= 0) { + const HttpResponseHeaders* headers = http_response_info_->headers.get(); + const int response_code = headers->response_code(); + UMA_HISTOGRAM_SPARSE_SLOWLY("Net.WebSocket.ResponseCode", response_code); + switch (response_code) { + case HTTP_SWITCHING_PROTOCOLS: + OnFinishOpeningHandshake(); + return ValidateUpgradeResponse(headers); + + // We need to pass these through for authentication to work. + case HTTP_UNAUTHORIZED: + case HTTP_PROXY_AUTHENTICATION_REQUIRED: + return OK; + + // Other status codes are potentially risky (see the warnings in the + // WHATWG WebSocket API spec) and so are dropped by default. + default: + // A WebSocket server cannot be using HTTP/0.9, so if we see version + // 0.9, it means the response was garbage. + // Reporting "Unexpected response code: 200" in this case is not + // helpful, so use a different error message. + if (headers->GetHttpVersion() == HttpVersion(0, 9)) { + set_failure_message( + "Error during WebSocket handshake: Invalid status line"); + } else { + set_failure_message(base::StringPrintf( + "Error during WebSocket handshake: Unexpected response code: %d", + headers->response_code())); + } + OnFinishOpeningHandshake(); + return ERR_INVALID_RESPONSE; + } + } else { + if (rv == ERR_EMPTY_RESPONSE) { + set_failure_message( + "Connection closed before receiving a handshake response"); + return rv; + } + set_failure_message(std::string("Error during WebSocket handshake: ") + + ErrorToString(rv)); + OnFinishOpeningHandshake(); + return rv; } } int WebSocketBasicHandshakeStream::ValidateUpgradeResponse( - const scoped_refptr<HttpResponseHeaders>& headers) { - if (ValidateSingleTokenHeader(headers, - websockets::kUpgrade, - websockets::kWebSocketLowercase, - false) && - ValidateSingleTokenHeader(headers, - websockets::kSecWebSocketAccept, - handshake_challenge_response_, - true) && - headers->HasHeaderValue(HttpRequestHeaders::kConnection, - websockets::kUpgrade) && - ValidateSubProtocol(headers, requested_sub_protocols_, &sub_protocol_) && - ValidateExtensions(headers, requested_extensions_, &extensions_)) { + const HttpResponseHeaders* headers) { + extension_params_.reset(new WebSocketExtensionParams); + std::string failure_message; + if (ValidateUpgrade(headers, &failure_message) && + ValidateSecWebSocketAccept( + headers, handshake_challenge_response_, &failure_message) && + ValidateConnection(headers, &failure_message) && + ValidateSubProtocol(headers, + requested_sub_protocols_, + &sub_protocol_, + &failure_message) && + ValidateExtensions(headers, + requested_extensions_, + &extensions_, + &failure_message, + extension_params_.get())) { return OK; } + set_failure_message("Error during WebSocket handshake: " + failure_message); return ERR_INVALID_RESPONSE; } +void WebSocketBasicHandshakeStream::set_failure_message( + const std::string& failure_message) { + *failure_message_ = failure_message; +} + } // namespace net diff --git a/chromium/net/websockets/websocket_basic_handshake_stream.h b/chromium/net/websockets/websocket_basic_handshake_stream.h index 2e5b628cde6..51f0c4db1c5 100644 --- a/chromium/net/websockets/websocket_basic_handshake_stream.h +++ b/chromium/net/websockets/websocket_basic_handshake_stream.h @@ -13,6 +13,7 @@ #include "net/base/net_export.h" #include "net/http/http_basic_state.h" #include "net/websockets/websocket_handshake_stream_base.h" +#include "url/gurl.h" namespace net { @@ -21,14 +22,19 @@ class HttpResponseHeaders; class HttpResponseInfo; class HttpStreamParser; +struct WebSocketExtensionParams; + class NET_EXPORT_PRIVATE WebSocketBasicHandshakeStream : public WebSocketHandshakeStreamBase { public: + // |connect_delegate| and |failure_message| must out-live this object. WebSocketBasicHandshakeStream( scoped_ptr<ClientSocketHandle> connection, + WebSocketStream::ConnectDelegate* connect_delegate, bool using_proxy, std::vector<std::string> requested_sub_protocols, - std::vector<std::string> requested_extensions); + std::vector<std::string> requested_extensions, + std::string* failure_message); virtual ~WebSocketBasicHandshakeStream(); @@ -41,7 +47,6 @@ class NET_EXPORT_PRIVATE WebSocketBasicHandshakeStream HttpResponseInfo* response, const CompletionCallback& callback) OVERRIDE; virtual int ReadResponseHeaders(const CompletionCallback& callback) OVERRIDE; - virtual const HttpResponseInfo* GetResponseInfo() const OVERRIDE; virtual int ReadResponseBody(IOBuffer* buf, int buf_len, const CompletionCallback& callback) OVERRIDE; @@ -78,20 +83,29 @@ class NET_EXPORT_PRIVATE WebSocketBasicHandshakeStream void ReadResponseHeadersCallback(const CompletionCallback& callback, int result); - // Validates the response from the server and returns OK or - // ERR_INVALID_RESPONSE. - int ValidateResponse(); + void OnFinishOpeningHandshake(); + + // Validates the response and sends the finished handshake event. + int ValidateResponse(int rv); // Check that the headers are well-formed for a 101 response, and returns // OK if they are, otherwise returns ERR_INVALID_RESPONSE. - int ValidateUpgradeResponse( - const scoped_refptr<HttpResponseHeaders>& headers); + int ValidateUpgradeResponse(const HttpResponseHeaders* headers); HttpStreamParser* parser() const { return state_.parser(); } + void set_failure_message(const std::string& failure_message); + + // The request URL. + GURL url_; + // HttpBasicState holds most of the handshake-related state. HttpBasicState state_; + // Owned by another object. + // |connect_delegate| will live during the lifetime of this object. + WebSocketStream::ConnectDelegate* connect_delegate_; + // This is stored in SendRequest() for use by ReadResponseHeaders(). HttpResponseInfo* http_response_info_; @@ -114,6 +128,12 @@ class NET_EXPORT_PRIVATE WebSocketBasicHandshakeStream // The extension(s) selected by the server. std::string extensions_; + // The extension parameters. The class is defined in the implementation file + // to avoid including extension-related header files here. + scoped_ptr<WebSocketExtensionParams> extension_params_; + + std::string* failure_message_; + DISALLOW_COPY_AND_ASSIGN(WebSocketBasicHandshakeStream); }; diff --git a/chromium/net/websockets/websocket_basic_stream.cc b/chromium/net/websockets/websocket_basic_stream.cc index 12472b48d1d..fd2766bcb94 100644 --- a/chromium/net/websockets/websocket_basic_stream.cc +++ b/chromium/net/websockets/websocket_basic_stream.cc @@ -12,7 +12,7 @@ #include "base/basictypes.h" #include "base/bind.h" #include "base/logging.h" -#include "base/safe_numerics.h" +#include "base/numerics/safe_conversions.h" #include "net/base/io_buffer.h" #include "net/base/net_errors.h" #include "net/socket/client_socket_handle.h" @@ -157,13 +157,15 @@ int WebSocketBasicStream::WriteFrames(ScopedVector<WebSocketFrame>* frames, dest += result; remaining_size -= result; - const char* const frame_data = frame->data->data(); const int frame_size = frame->header.payload_length; - CHECK_GE(remaining_size, frame_size); - std::copy(frame_data, frame_data + frame_size, dest); - MaskWebSocketFramePayload(mask, 0, dest, frame_size); - dest += frame_size; - remaining_size -= frame_size; + if (frame_size > 0) { + CHECK_GE(remaining_size, frame_size); + const char* const frame_data = frame->data->data(); + std::copy(frame_data, frame_data + frame_size, dest); + MaskWebSocketFramePayload(mask, 0, dest, frame_size); + dest += frame_size; + remaining_size -= frame_size; + } } DCHECK_EQ(0, remaining_size) << "Buffer size calculation was wrong; " << remaining_size << " bytes left over."; @@ -347,10 +349,10 @@ int WebSocketBasicStream::ConvertChunkToFrame( // header. A check for exact equality can only be used when the whole frame // arrives in one chunk. DCHECK_GE(current_frame_header_->payload_length, - base::checked_numeric_cast<uint64>(chunk_size)); + base::checked_cast<uint64>(chunk_size)); DCHECK(!is_first_chunk || !is_final_chunk || current_frame_header_->payload_length == - base::checked_numeric_cast<uint64>(chunk_size)); + base::checked_cast<uint64>(chunk_size)); // Convert the chunk to a complete frame. *frame = CreateFrame(is_final_chunk, data_buffer); @@ -376,9 +378,16 @@ scoped_ptr<WebSocketFrame> WebSocketBasicStream::CreateFrame( result_frame->header.payload_length = data_size; result_frame->data = data; // Ensure that opcodes Text and Binary are only used for the first frame in - // the message. - if (WebSocketFrameHeader::IsKnownDataOpCode(opcode)) + // the message. Also clear the reserved bits. + // TODO(ricea): If a future extension requires the reserved bits to be + // retained on continuation frames, make this behaviour conditional on a + // flag set at construction time. + if (!is_final_chunk && WebSocketFrameHeader::IsKnownDataOpCode(opcode)) { current_frame_header_->opcode = WebSocketFrameHeader::kOpCodeContinuation; + current_frame_header_->reserved1 = false; + current_frame_header_->reserved2 = false; + current_frame_header_->reserved3 = false; + } } // Make sure that a frame header is not applied to any chunks that do not // belong to it. diff --git a/chromium/net/websockets/websocket_basic_stream_test.cc b/chromium/net/websockets/websocket_basic_stream_test.cc index cb936a5f01e..71af0797818 100644 --- a/chromium/net/websockets/websocket_basic_stream_test.cc +++ b/chromium/net/websockets/websocket_basic_stream_test.cc @@ -13,8 +13,8 @@ #include <string> #include "base/basictypes.h" +#include "base/big_endian.h" #include "base/port.h" -#include "net/base/big_endian.h" #include "net/base/capturing_net_log.h" #include "net/base/test_completion_callback.h" #include "net/socket/socket_test_util.h" @@ -57,6 +57,8 @@ WEBSOCKET_BASIC_STREAM_TEST_DEFINE_CONSTANT(CloseFrame, "\x88\x09\x03\xe8occludo"); WEBSOCKET_BASIC_STREAM_TEST_DEFINE_CONSTANT(WriteFrame, "\x81\x85\x00\x00\x00\x00Write"); +WEBSOCKET_BASIC_STREAM_TEST_DEFINE_CONSTANT(MaskedEmptyPong, + "\x8A\x80\x00\x00\x00\x00"); const WebSocketMaskingKey kNulMaskingKey = {{'\0', '\0', '\0', '\0'}}; const WebSocketMaskingKey kNonNulMaskingKey = { {'\x0d', '\x1b', '\x06', '\x17'}}; @@ -801,7 +803,7 @@ TEST_F(WebSocketBasicStreamSocketChunkedReadTest, OneMegFrame) { (kWireSize + kReadBufferSize - 1) / kReadBufferSize; scoped_ptr<char[]> big_frame(new char[kWireSize]); memcpy(big_frame.get(), "\x81\x7F", 2); - WriteBigEndian(big_frame.get() + 2, kPayloadSize); + base::WriteBigEndian(big_frame.get() + 2, kPayloadSize); memset(big_frame.get() + kLargeFrameHeaderSize, 'A', kPayloadSize); CreateChunkedRead(ASYNC, @@ -826,6 +828,33 @@ TEST_F(WebSocketBasicStreamSocketChunkedReadTest, OneMegFrame) { } } +// A frame with reserved flag(s) set that arrives in chunks should only have the +// reserved flag(s) set on the first chunk when split. +TEST_F(WebSocketBasicStreamSocketChunkedReadTest, ReservedFlagCleared) { + static const char kReservedFlagFrame[] = "\x41\x05Hello"; + const size_t kReservedFlagFrameSize = arraysize(kReservedFlagFrame) - 1; + const size_t kChunkSize = 5; + + CreateChunkedRead(ASYNC, + kReservedFlagFrame, + kReservedFlagFrameSize, + kChunkSize, + 2, + LAST_FRAME_BIG); + + TestCompletionCallback cb[2]; + ASSERT_EQ(ERR_IO_PENDING, stream_->ReadFrames(&frames_, cb[0].callback())); + EXPECT_EQ(OK, cb[0].WaitForResult()); + ASSERT_EQ(1U, frames_.size()); + EXPECT_TRUE(frames_[0]->header.reserved1); + + frames_.clear(); + ASSERT_EQ(ERR_IO_PENDING, stream_->ReadFrames(&frames_, cb[1].callback())); + EXPECT_EQ(OK, cb[1].WaitForResult()); + ASSERT_EQ(1U, frames_.size()); + EXPECT_FALSE(frames_[0]->header.reserved1); +} + // Check that writing a frame all at once works. TEST_F(WebSocketBasicStreamSocketWriteTest, WriteAtOnce) { MockWrite writes[] = {MockWrite(SYNCHRONOUS, kWriteFrame, kWriteFrameSize)}; @@ -856,6 +885,23 @@ TEST_F(WebSocketBasicStreamSocketWriteTest, WriteInBits) { EXPECT_EQ(OK, cb_.WaitForResult()); } +// Check that writing a Pong frame with a NULL body works. +TEST_F(WebSocketBasicStreamSocketWriteTest, WriteNullPong) { + MockWrite writes[] = { + MockWrite(SYNCHRONOUS, kMaskedEmptyPong, kMaskedEmptyPongSize)}; + CreateWriteOnly(writes); + + scoped_ptr<WebSocketFrame> frame( + new WebSocketFrame(WebSocketFrameHeader::kOpCodePong)); + WebSocketFrameHeader& header = frame->header; + header.final = true; + header.masked = true; + header.payload_length = 0; + ScopedVector<WebSocketFrame> frames; + frames.push_back(frame.release()); + EXPECT_EQ(OK, stream_->WriteFrames(&frames, cb_.callback())); +} + // Check that writing with a non-NULL mask works correctly. TEST_F(WebSocketBasicStreamSocketTest, WriteNonNulMask) { std::string masked_frame = std::string("\x81\x88"); diff --git a/chromium/net/websockets/websocket_channel.cc b/chromium/net/websockets/websocket_channel.cc index 61a47b71eb8..12f152f5248 100644 --- a/chromium/net/websockets/websocket_channel.cc +++ b/chromium/net/websockets/websocket_channel.cc @@ -4,28 +4,43 @@ #include "net/websockets/websocket_channel.h" +#include <limits.h> // for INT_MAX + #include <algorithm> +#include <deque> #include "base/basictypes.h" // for size_t +#include "base/big_endian.h" #include "base/bind.h" #include "base/compiler_specific.h" -#include "base/safe_numerics.h" -#include "base/strings/string_util.h" +#include "base/memory/ref_counted.h" +#include "base/memory/weak_ptr.h" +#include "base/message_loop/message_loop.h" +#include "base/metrics/histogram.h" +#include "base/numerics/safe_conversions.h" +#include "base/stl_util.h" +#include "base/strings/stringprintf.h" #include "base/time/time.h" -#include "net/base/big_endian.h" #include "net/base/io_buffer.h" #include "net/base/net_log.h" +#include "net/http/http_request_headers.h" +#include "net/http/http_response_headers.h" #include "net/http/http_util.h" #include "net/websockets/websocket_errors.h" #include "net/websockets/websocket_event_interface.h" #include "net/websockets/websocket_frame.h" +#include "net/websockets/websocket_handshake_request_info.h" +#include "net/websockets/websocket_handshake_response_info.h" #include "net/websockets/websocket_mux.h" #include "net/websockets/websocket_stream.h" +#include "url/origin.h" namespace net { namespace { +using base::StreamingUtf8Validator; + const int kDefaultSendQuotaLowWaterMark = 1 << 16; const int kDefaultSendQuotaHighWaterMark = 1 << 17; const size_t kWebSocketCloseCodeLength = 2; @@ -46,12 +61,14 @@ const size_t kMaximumCloseReasonLength = 125 - kWebSocketCloseCodeLength; // used for close codes received from a renderer that we are intending to send // out over the network. See ParseClose() for the restrictions on incoming close // codes. The |code| parameter is type int for convenience of implementation; -// the real type is uint16. +// the real type is uint16. Code 1005 is treated specially; it cannot be set +// explicitly by Javascript but the renderer uses it to indicate we should send +// a Close frame with no payload. bool IsStrictlyValidCloseStatusCode(int code) { static const int kInvalidRanges[] = { // [BAD, OK) 0, 1000, // 1000 is the first valid code - 1005, 1007, // 1005 and 1006 MUST NOT be set. + 1006, 1007, // 1006 MUST NOT be set. 1014, 3000, // 1014 unassigned; 1015 up to 2999 are reserved. 5000, 65536, // Codes above 5000 are invalid. }; @@ -71,6 +88,38 @@ bool IsStrictlyValidCloseStatusCode(int code) { // This function avoids a bunch of boilerplate code. void AllowUnused(ChannelState ALLOW_UNUSED unused) {} +// Sets |name| to the name of the frame type for the given |opcode|. Note that +// for all of Text, Binary and Continuation opcode, this method returns +// "Data frame". +void GetFrameTypeForOpcode(WebSocketFrameHeader::OpCode opcode, + std::string* name) { + switch (opcode) { + case WebSocketFrameHeader::kOpCodeText: // fall-thru + case WebSocketFrameHeader::kOpCodeBinary: // fall-thru + case WebSocketFrameHeader::kOpCodeContinuation: + *name = "Data frame"; + break; + + case WebSocketFrameHeader::kOpCodePing: + *name = "Ping"; + break; + + case WebSocketFrameHeader::kOpCodePong: + *name = "Pong"; + break; + + case WebSocketFrameHeader::kOpCodeClose: + *name = "Close"; + break; + + default: + *name = "Unknown frame type"; + break; + } + + return; +} + } // namespace // A class to encapsulate a set of frames and information about the size of @@ -112,11 +161,30 @@ class WebSocketChannel::ConnectDelegate // |this| may have been deleted. } - virtual void OnFailure(uint16 websocket_error) OVERRIDE { - creator_->OnConnectFailure(websocket_error); + virtual void OnFailure(const std::string& message) OVERRIDE { + creator_->OnConnectFailure(message); // |this| has been deleted. } + virtual void OnStartOpeningHandshake( + scoped_ptr<WebSocketHandshakeRequestInfo> request) OVERRIDE { + creator_->OnStartOpeningHandshake(request.Pass()); + } + + virtual void OnFinishOpeningHandshake( + scoped_ptr<WebSocketHandshakeResponseInfo> response) OVERRIDE { + creator_->OnFinishOpeningHandshake(response.Pass()); + } + + virtual void OnSSLCertificateError( + scoped_ptr<WebSocketEventInterface::SSLErrorCallbacks> + ssl_error_callbacks, + const SSLInfo& ssl_info, + bool fatal) OVERRIDE { + creator_->OnSSLCertificateError( + ssl_error_callbacks.Pass(), ssl_info, fatal); + } + private: // A pointer to the WebSocketChannel that created this object. There is no // danger of this pointer being stale, because deleting the WebSocketChannel @@ -127,6 +195,101 @@ class WebSocketChannel::ConnectDelegate DISALLOW_COPY_AND_ASSIGN(ConnectDelegate); }; +class WebSocketChannel::HandshakeNotificationSender + : public base::SupportsWeakPtr<HandshakeNotificationSender> { + public: + explicit HandshakeNotificationSender(WebSocketChannel* channel); + ~HandshakeNotificationSender(); + + static void Send(base::WeakPtr<HandshakeNotificationSender> sender); + + ChannelState SendImmediately(WebSocketEventInterface* event_interface); + + const WebSocketHandshakeRequestInfo* handshake_request_info() const { + return handshake_request_info_.get(); + } + + void set_handshake_request_info( + scoped_ptr<WebSocketHandshakeRequestInfo> request_info) { + handshake_request_info_ = request_info.Pass(); + } + + const WebSocketHandshakeResponseInfo* handshake_response_info() const { + return handshake_response_info_.get(); + } + + void set_handshake_response_info( + scoped_ptr<WebSocketHandshakeResponseInfo> response_info) { + handshake_response_info_ = response_info.Pass(); + } + + private: + WebSocketChannel* owner_; + scoped_ptr<WebSocketHandshakeRequestInfo> handshake_request_info_; + scoped_ptr<WebSocketHandshakeResponseInfo> handshake_response_info_; +}; + +WebSocketChannel::HandshakeNotificationSender::HandshakeNotificationSender( + WebSocketChannel* channel) + : owner_(channel) {} + +WebSocketChannel::HandshakeNotificationSender::~HandshakeNotificationSender() {} + +void WebSocketChannel::HandshakeNotificationSender::Send( + base::WeakPtr<HandshakeNotificationSender> sender) { + // Do nothing if |sender| is already destructed. + if (sender) { + WebSocketChannel* channel = sender->owner_; + AllowUnused(sender->SendImmediately(channel->event_interface_.get())); + } +} + +ChannelState WebSocketChannel::HandshakeNotificationSender::SendImmediately( + WebSocketEventInterface* event_interface) { + + if (handshake_request_info_.get()) { + if (CHANNEL_DELETED == event_interface->OnStartOpeningHandshake( + handshake_request_info_.Pass())) + return CHANNEL_DELETED; + } + + if (handshake_response_info_.get()) { + if (CHANNEL_DELETED == event_interface->OnFinishOpeningHandshake( + handshake_response_info_.Pass())) + return CHANNEL_DELETED; + + // TODO(yhirano): We can release |this| to save memory because + // there will be no more opening handshake notification. + } + + return CHANNEL_ALIVE; +} + +WebSocketChannel::PendingReceivedFrame::PendingReceivedFrame( + bool final, + WebSocketFrameHeader::OpCode opcode, + const scoped_refptr<IOBuffer>& data, + size_t offset, + size_t size) + : final_(final), + opcode_(opcode), + data_(data), + offset_(offset), + size_(size) {} + +WebSocketChannel::PendingReceivedFrame::~PendingReceivedFrame() {} + +void WebSocketChannel::PendingReceivedFrame::ResetOpcode() { + DCHECK(WebSocketFrameHeader::IsKnownDataOpCode(opcode_)); + opcode_ = WebSocketFrameHeader::kOpCodeContinuation; +} + +void WebSocketChannel::PendingReceivedFrame::DidConsume(size_t bytes) { + DCHECK_LE(offset_, size_); + DCHECK_LE(bytes, size_ - offset_); + offset_ += bytes; +} + WebSocketChannel::WebSocketChannel( scoped_ptr<WebSocketEventInterface> event_interface, URLRequestContext* url_request_context) @@ -135,9 +298,15 @@ WebSocketChannel::WebSocketChannel( send_quota_low_water_mark_(kDefaultSendQuotaLowWaterMark), send_quota_high_water_mark_(kDefaultSendQuotaHighWaterMark), current_send_quota_(0), + current_receive_quota_(0), timeout_(base::TimeDelta::FromSeconds(kClosingHandshakeTimeoutSeconds)), - closing_code_(0), - state_(FRESHLY_CONSTRUCTED) {} + received_close_code_(0), + state_(FRESHLY_CONSTRUCTED), + notification_sender_(new HandshakeNotificationSender(this)), + sending_text_message_(false), + receiving_text_message_(false), + expecting_to_handle_continuation_(false), + initial_frame_forwarded_(false) {} WebSocketChannel::~WebSocketChannel() { // The stream may hold a pointer to read_frames_, and so it needs to be @@ -151,7 +320,7 @@ WebSocketChannel::~WebSocketChannel() { void WebSocketChannel::SendAddChannelRequest( const GURL& socket_url, const std::vector<std::string>& requested_subprotocols, - const GURL& origin) { + const url::Origin& origin) { // Delegate to the tested version. SendAddChannelRequestWithSuppliedCreator( socket_url, @@ -160,6 +329,19 @@ void WebSocketChannel::SendAddChannelRequest( base::Bind(&WebSocketStream::CreateAndConnectStream)); } +void WebSocketChannel::SetState(State new_state) { + DCHECK_NE(state_, new_state); + + if (new_state == CONNECTED) + established_on_ = base::TimeTicks::Now(); + if (state_ == CONNECTED && !established_on_.is_null()) { + UMA_HISTOGRAM_LONG_TIMES( + "Net.WebSocket.Duration", base::TimeTicks::Now() - established_on_); + } + + state_ = new_state; +} + bool WebSocketChannel::InClosingState() const { // The state RECV_CLOSED is not supported here, because it is only used in one // code path and should not leak into the code in general. @@ -182,18 +364,18 @@ void WebSocketChannel::SendFrame(bool fin, return; } if (InClosingState()) { - VLOG(1) << "SendFrame called in state " << state_ - << ". This may be a bug, or a harmless race."; + DVLOG(1) << "SendFrame called in state " << state_ + << ". This may be a bug, or a harmless race."; return; } if (state_ != CONNECTED) { NOTREACHED() << "SendFrame() called in state " << state_; return; } - if (data.size() > base::checked_numeric_cast<size_t>(current_send_quota_)) { - AllowUnused(FailChannel(SEND_GOING_AWAY, - kWebSocketMuxErrorSendQuotaViolation, - "Send quota exceeded")); + if (data.size() > base::checked_cast<size_t>(current_send_quota_)) { + // TODO(ricea): Kill renderer. + AllowUnused( + FailChannel("Send quota exceeded", kWebSocketErrorGoingAway, "")); // |this| has been deleted. return; } @@ -203,30 +385,95 @@ void WebSocketChannel::SendFrame(bool fin, << " data.size()=" << data.size(); return; } + if (op_code == WebSocketFrameHeader::kOpCodeText || + (op_code == WebSocketFrameHeader::kOpCodeContinuation && + sending_text_message_)) { + StreamingUtf8Validator::State state = + outgoing_utf8_validator_.AddBytes(vector_as_array(&data), data.size()); + if (state == StreamingUtf8Validator::INVALID || + (state == StreamingUtf8Validator::VALID_MIDPOINT && fin)) { + // TODO(ricea): Kill renderer. + AllowUnused( + FailChannel("Browser sent a text frame containing invalid UTF-8", + kWebSocketErrorGoingAway, + "")); + // |this| has been deleted. + return; + } + sending_text_message_ = !fin; + DCHECK(!fin || state == StreamingUtf8Validator::VALID_ENDPOINT); + } current_send_quota_ -= data.size(); // TODO(ricea): If current_send_quota_ has dropped below // send_quota_low_water_mark_, it might be good to increase the "low // water mark" and "high water mark", but only if the link to the WebSocket // server is not saturated. - // TODO(ricea): For kOpCodeText, do UTF-8 validation? scoped_refptr<IOBuffer> buffer(new IOBuffer(data.size())); std::copy(data.begin(), data.end(), buffer->data()); - AllowUnused(SendIOBuffer(fin, op_code, buffer, data.size())); + AllowUnused(SendFrameFromIOBuffer(fin, op_code, buffer, data.size())); // |this| may have been deleted. } void WebSocketChannel::SendFlowControl(int64 quota) { DCHECK(state_ == CONNECTING || state_ == CONNECTED || state_ == SEND_CLOSED || state_ == CLOSE_WAIT); - // TODO(ricea): Add interface to WebSocketStream and implement. - // stream_->SendFlowControl(quota); + // TODO(ricea): Kill the renderer if it tries to send us a negative quota + // value or > INT_MAX. + DCHECK_GE(quota, 0); + DCHECK_LE(quota, INT_MAX); + if (!pending_received_frames_.empty()) { + DCHECK_EQ(0, current_receive_quota_); + } + while (!pending_received_frames_.empty() && quota > 0) { + PendingReceivedFrame& front = pending_received_frames_.front(); + const size_t data_size = front.size() - front.offset(); + const size_t bytes_to_send = + std::min(base::checked_cast<size_t>(quota), data_size); + const bool final = front.final() && data_size == bytes_to_send; + const char* data = front.data() ? + front.data()->data() + front.offset() : NULL; + DCHECK(!bytes_to_send || data) << "Non empty data should not be null."; + const std::vector<char> data_vector(data, data + bytes_to_send); + DVLOG(3) << "Sending frame previously split due to quota to the " + << "renderer: quota=" << quota << " data_size=" << data_size + << " bytes_to_send=" << bytes_to_send; + if (event_interface_->OnDataFrame(final, front.opcode(), data_vector) == + CHANNEL_DELETED) + return; + if (bytes_to_send < data_size) { + front.DidConsume(bytes_to_send); + front.ResetOpcode(); + return; + } + const int64 signed_bytes_to_send = base::checked_cast<int64>(bytes_to_send); + DCHECK_GE(quota, signed_bytes_to_send); + quota -= signed_bytes_to_send; + + pending_received_frames_.pop(); + } + // If current_receive_quota_ == 0 then there is no pending ReadFrames() + // operation. + const bool start_read = + current_receive_quota_ == 0 && quota > 0 && + (state_ == CONNECTED || state_ == SEND_CLOSED || state_ == CLOSE_WAIT); + current_receive_quota_ += base::checked_cast<int>(quota); + if (start_read) + AllowUnused(ReadFrames()); + // |this| may have been deleted. } void WebSocketChannel::StartClosingHandshake(uint16 code, const std::string& reason) { if (InClosingState()) { - VLOG(1) << "StartClosingHandshake called in state " << state_ - << ". This may be a bug, or a harmless race."; + DVLOG(1) << "StartClosingHandshake called in state " << state_ + << ". This may be a bug, or a harmless race."; + return; + } + if (state_ == CONNECTING) { + // Abort the in-progress handshake and drop the connection immediately. + stream_request_.reset(); + SetState(CLOSED); + AllowUnused(DoDropChannel(false, kWebSocketErrorAbnormalClosure, "")); return; } if (state_ != CONNECTED) { @@ -242,19 +489,25 @@ void WebSocketChannel::StartClosingHandshake(uint16 code, // errata 3227 to RFC6455. If the renderer is sending us an invalid code or // reason it must be malfunctioning in some way, and based on that we // interpret this as an internal error. - AllowUnused( - SendClose(kWebSocketErrorInternalServerError, "Internal Error")); - // |this| may have been deleted. + if (SendClose(kWebSocketErrorInternalServerError, "") != CHANNEL_DELETED) { + DCHECK_EQ(CONNECTED, state_); + SetState(SEND_CLOSED); + } return; } - AllowUnused(SendClose(code, IsStringUTF8(reason) ? reason : std::string())); - // |this| may have been deleted. + if (SendClose( + code, + StreamingUtf8Validator::Validate(reason) ? reason : std::string()) == + CHANNEL_DELETED) + return; + DCHECK_EQ(CONNECTED, state_); + SetState(SEND_CLOSED); } void WebSocketChannel::SendAddChannelRequestForTesting( const GURL& socket_url, const std::vector<std::string>& requested_subprotocols, - const GURL& origin, + const url::Origin& origin, const WebSocketStreamCreator& creator) { SendAddChannelRequestWithSuppliedCreator( socket_url, requested_subprotocols, origin, creator); @@ -268,13 +521,13 @@ void WebSocketChannel::SetClosingHandshakeTimeoutForTesting( void WebSocketChannel::SendAddChannelRequestWithSuppliedCreator( const GURL& socket_url, const std::vector<std::string>& requested_subprotocols, - const GURL& origin, + const url::Origin& origin, const WebSocketStreamCreator& creator) { DCHECK_EQ(FRESHLY_CONSTRUCTED, state_); if (!socket_url.SchemeIsWSOrWSS()) { // TODO(ricea): Kill the renderer (this error should have been caught by // Javascript). - AllowUnused(event_interface_->OnAddChannelResponse(true, "")); + AllowUnused(event_interface_->OnAddChannelResponse(true, "", "")); // |this| is deleted here. return; } @@ -287,16 +540,20 @@ void WebSocketChannel::SendAddChannelRequestWithSuppliedCreator( url_request_context_, BoundNetLog(), connect_delegate.Pass()); - state_ = CONNECTING; + SetState(CONNECTING); } void WebSocketChannel::OnConnectSuccess(scoped_ptr<WebSocketStream> stream) { DCHECK(stream); DCHECK_EQ(CONNECTING, state_); + stream_ = stream.Pass(); - state_ = CONNECTED; + + SetState(CONNECTED); + if (event_interface_->OnAddChannelResponse( - false, stream_->GetSubProtocol()) == CHANNEL_DELETED) + false, stream_->GetSubProtocol(), stream_->GetExtensions()) == + CHANNEL_DELETED) return; // TODO(ricea): Get flow control information from the WebSocketStream once we @@ -308,18 +565,64 @@ void WebSocketChannel::OnConnectSuccess(scoped_ptr<WebSocketStream> stream) { // |stream_request_| is not used once the connection has succeeded. stream_request_.reset(); + AllowUnused(ReadFrames()); // |this| may have been deleted. } -void WebSocketChannel::OnConnectFailure(uint16 websocket_error) { +void WebSocketChannel::OnConnectFailure(const std::string& message) { DCHECK_EQ(CONNECTING, state_); - state_ = CLOSED; + + // Copy the message before we delete its owner. + std::string message_copy = message; + + SetState(CLOSED); stream_request_.reset(); - AllowUnused(event_interface_->OnAddChannelResponse(true, "")); + + if (CHANNEL_DELETED == + notification_sender_->SendImmediately(event_interface_.get())) { + // |this| has been deleted. + return; + } + AllowUnused(event_interface_->OnFailChannel(message_copy)); // |this| has been deleted. } +void WebSocketChannel::OnSSLCertificateError( + scoped_ptr<WebSocketEventInterface::SSLErrorCallbacks> ssl_error_callbacks, + const SSLInfo& ssl_info, + bool fatal) { + AllowUnused(event_interface_->OnSSLCertificateError( + ssl_error_callbacks.Pass(), socket_url_, ssl_info, fatal)); +} + +void WebSocketChannel::OnStartOpeningHandshake( + scoped_ptr<WebSocketHandshakeRequestInfo> request) { + DCHECK(!notification_sender_->handshake_request_info()); + + // Because it is hard to handle an IPC error synchronously is difficult, + // we asynchronously notify the information. + notification_sender_->set_handshake_request_info(request.Pass()); + ScheduleOpeningHandshakeNotification(); +} + +void WebSocketChannel::OnFinishOpeningHandshake( + scoped_ptr<WebSocketHandshakeResponseInfo> response) { + DCHECK(!notification_sender_->handshake_response_info()); + + // Because it is hard to handle an IPC error synchronously is difficult, + // we asynchronously notify the information. + notification_sender_->set_handshake_response_info(response.Pass()); + ScheduleOpeningHandshakeNotification(); +} + +void WebSocketChannel::ScheduleOpeningHandshakeNotification() { + base::MessageLoop::current()->PostTask( + FROM_HERE, + base::Bind(HandshakeNotificationSender::Send, + notification_sender_->AsWeakPtr())); +} + ChannelState WebSocketChannel::WriteFrames() { int result = OK; do { @@ -333,6 +636,8 @@ ChannelState WebSocketChannel::WriteFrames() { if (result != ERR_IO_PENDING) { if (OnWriteDone(true, result) == CHANNEL_DELETED) return CHANNEL_DELETED; + // OnWriteDone() returns CHANNEL_DELETED on error. Here |state_| is + // guaranteed to be the same as before OnWriteDone() call. } } while (result == OK && data_being_sent_); return CHANNEL_ALIVE; @@ -373,17 +678,16 @@ ChannelState WebSocketChannel::OnWriteDone(bool synchronous, int result) { default: DCHECK_LT(result, 0) << "WriteFrames() should only return OK or ERR_ codes"; + stream_->Close(); - DCHECK_NE(CLOSED, state_); - state_ = CLOSED; - return event_interface_->OnDropChannel(kWebSocketErrorAbnormalClosure, - "Abnormal Closure"); + SetState(CLOSED); + return DoDropChannel(false, kWebSocketErrorAbnormalClosure, ""); } } ChannelState WebSocketChannel::ReadFrames() { int result = OK; - do { + while (result == OK && current_receive_quota_ > 0) { // This use of base::Unretained is safe because this object owns the // WebSocketStream, and any pending reads will be cancelled when it is // destroyed. @@ -397,7 +701,7 @@ ChannelState WebSocketChannel::ReadFrames() { return CHANNEL_DELETED; } DCHECK_NE(CLOSED, state_); - } while (result == OK); + } return CHANNEL_ALIVE; } @@ -414,7 +718,7 @@ ChannelState WebSocketChannel::OnReadDone(bool synchronous, int result) { for (size_t i = 0; i < read_frames_.size(); ++i) { scoped_ptr<WebSocketFrame> frame(read_frames_[i]); read_frames_[i] = NULL; - if (ProcessFrame(frame.Pass()) == CHANNEL_DELETED) + if (HandleFrame(frame.Pass()) == CHANNEL_DELETED) return CHANNEL_DELETED; } read_frames_.clear(); @@ -426,48 +730,64 @@ ChannelState WebSocketChannel::OnReadDone(bool synchronous, int result) { return CHANNEL_ALIVE; case ERR_WS_PROTOCOL_ERROR: - return FailChannel(SEND_REAL_ERROR, + // This could be kWebSocketErrorProtocolError (specifically, non-minimal + // encoding of payload length) or kWebSocketErrorMessageTooBig, or an + // extension-specific error. + return FailChannel("Invalid frame header", kWebSocketErrorProtocolError, "WebSocket Protocol Error"); default: DCHECK_LT(result, 0) << "ReadFrames() should only return OK or ERR_ codes"; + stream_->Close(); - DCHECK_NE(CLOSED, state_); - state_ = CLOSED; + SetState(CLOSED); + uint16 code = kWebSocketErrorAbnormalClosure; - std::string reason = "Abnormal Closure"; - if (closing_code_ != 0) { - code = closing_code_; - reason = closing_reason_; + std::string reason = ""; + bool was_clean = false; + if (received_close_code_ != 0) { + code = received_close_code_; + reason = received_close_reason_; + was_clean = (result == ERR_CONNECTION_CLOSED); } - return event_interface_->OnDropChannel(code, reason); + + return DoDropChannel(was_clean, code, reason); } } -ChannelState WebSocketChannel::ProcessFrame(scoped_ptr<WebSocketFrame> frame) { +ChannelState WebSocketChannel::HandleFrame(scoped_ptr<WebSocketFrame> frame) { if (frame->header.masked) { // RFC6455 Section 5.1 "A client MUST close a connection if it detects a // masked frame." - return FailChannel(SEND_REAL_ERROR, - kWebSocketErrorProtocolError, - "Masked frame from server"); + return FailChannel( + "A server must not mask any frames that it sends to the " + "client.", + kWebSocketErrorProtocolError, + "Masked frame from server"); } const WebSocketFrameHeader::OpCode opcode = frame->header.opcode; - if (WebSocketFrameHeader::IsKnownControlOpCode(opcode) && - !frame->header.final) { - return FailChannel(SEND_REAL_ERROR, + DCHECK(!WebSocketFrameHeader::IsKnownControlOpCode(opcode) || + frame->header.final); + if (frame->header.reserved1 || frame->header.reserved2 || + frame->header.reserved3) { + return FailChannel(base::StringPrintf( + "One or more reserved bits are on: reserved1 = %d, " + "reserved2 = %d, reserved3 = %d", + static_cast<int>(frame->header.reserved1), + static_cast<int>(frame->header.reserved2), + static_cast<int>(frame->header.reserved3)), kWebSocketErrorProtocolError, - "Control message with FIN bit unset received"); + "Invalid reserved bit"); } // Respond to the frame appropriately to its type. - return HandleFrame( + return HandleFrameByState( opcode, frame->header.final, frame->data, frame->header.payload_length); } -ChannelState WebSocketChannel::HandleFrame( +ChannelState WebSocketChannel::HandleFrameByState( const WebSocketFrameHeader::OpCode opcode, bool final, const scoped_refptr<IOBuffer>& data_buffer, @@ -478,96 +798,66 @@ ChannelState WebSocketChannel::HandleFrame( DCHECK_NE(CLOSED, state_); if (state_ == CLOSE_WAIT) { std::string frame_name; - switch (opcode) { - case WebSocketFrameHeader::kOpCodeText: // fall-thru - case WebSocketFrameHeader::kOpCodeBinary: // fall-thru - case WebSocketFrameHeader::kOpCodeContinuation: - frame_name = "Data frame"; - break; - - case WebSocketFrameHeader::kOpCodePing: - frame_name = "Ping"; - break; - - case WebSocketFrameHeader::kOpCodePong: - frame_name = "Pong"; - break; - - case WebSocketFrameHeader::kOpCodeClose: - frame_name = "Close"; - break; - - default: - frame_name = "Unknown frame type"; - break; - } - // SEND_REAL_ERROR makes no difference here, as FailChannel() won't send - // another Close frame. - return FailChannel(SEND_REAL_ERROR, - kWebSocketErrorProtocolError, - frame_name + " received after close"); + GetFrameTypeForOpcode(opcode, &frame_name); + + // FailChannel() won't send another Close frame. + return FailChannel( + frame_name + " received after close", kWebSocketErrorProtocolError, ""); } switch (opcode) { - case WebSocketFrameHeader::kOpCodeText: // fall-thru - case WebSocketFrameHeader::kOpCodeBinary: // fall-thru + case WebSocketFrameHeader::kOpCodeText: // fall-thru + case WebSocketFrameHeader::kOpCodeBinary: case WebSocketFrameHeader::kOpCodeContinuation: - if (state_ == CONNECTED) { - // TODO(ricea): Need to fail the connection if UTF-8 is invalid - // post-reassembly. Requires a streaming UTF-8 validator. - // TODO(ricea): Can this copy be eliminated? - const char* const data_begin = size ? data_buffer->data() : NULL; - const char* const data_end = data_begin + size; - const std::vector<char> data(data_begin, data_end); - // TODO(ricea): Handle the case when ReadFrames returns far - // more data at once than should be sent in a single IPC. This needs to - // be handled carefully, as an overloaded IO thread is one possible - // cause of receiving very large chunks. - - // Sends the received frame to the renderer process. - return event_interface_->OnDataFrame(final, opcode, data); - } - VLOG(3) << "Ignored data packet received in state " << state_; - return CHANNEL_ALIVE; + return HandleDataFrame(opcode, final, data_buffer, size); case WebSocketFrameHeader::kOpCodePing: - VLOG(1) << "Got Ping of size " << size; + DVLOG(1) << "Got Ping of size " << size; if (state_ == CONNECTED) - return SendIOBuffer( + return SendFrameFromIOBuffer( true, WebSocketFrameHeader::kOpCodePong, data_buffer, size); - VLOG(3) << "Ignored ping in state " << state_; + DVLOG(3) << "Ignored ping in state " << state_; return CHANNEL_ALIVE; case WebSocketFrameHeader::kOpCodePong: - VLOG(1) << "Got Pong of size " << size; + DVLOG(1) << "Got Pong of size " << size; // There is no need to do anything with pong messages. return CHANNEL_ALIVE; case WebSocketFrameHeader::kOpCodeClose: { + // TODO(ricea): If there is a message which is queued for transmission to + // the renderer, then the renderer should not receive an + // OnClosingHandshake or OnDropChannel IPC until the queued message has + // been completedly transmitted. uint16 code = kWebSocketNormalClosure; std::string reason; - ParseClose(data_buffer, size, &code, &reason); + std::string message; + if (!ParseClose(data_buffer, size, &code, &reason, &message)) { + return FailChannel(message, code, reason); + } // TODO(ricea): Find a way to safely log the message from the close // message (escape control codes and so on). - VLOG(1) << "Got Close with code " << code; + DVLOG(1) << "Got Close with code " << code; switch (state_) { case CONNECTED: - state_ = RECV_CLOSED; - if (SendClose(code, reason) == // Sets state_ to CLOSE_WAIT - CHANNEL_DELETED) + SetState(RECV_CLOSED); + if (SendClose(code, reason) == CHANNEL_DELETED) return CHANNEL_DELETED; + DCHECK_EQ(RECV_CLOSED, state_); + SetState(CLOSE_WAIT); + if (event_interface_->OnClosingHandshake() == CHANNEL_DELETED) return CHANNEL_DELETED; - closing_code_ = code; - closing_reason_ = reason; + received_close_code_ = code; + received_close_reason_ = reason; break; case SEND_CLOSED: - state_ = CLOSE_WAIT; + SetState(CLOSE_WAIT); // From RFC6455 section 7.1.5: "Each endpoint // will see the status code sent by the other end as _The WebSocket // Connection Close Code_." - closing_code_ = code; - closing_reason_ = reason; + received_close_code_ = code; + received_close_reason_ = reason; break; default: @@ -579,23 +869,105 @@ ChannelState WebSocketChannel::HandleFrame( default: return FailChannel( - SEND_REAL_ERROR, kWebSocketErrorProtocolError, "Unknown opcode"); + base::StringPrintf("Unrecognized frame opcode: %d", opcode), + kWebSocketErrorProtocolError, + "Unknown opcode"); } } -ChannelState WebSocketChannel::SendIOBuffer( +ChannelState WebSocketChannel::HandleDataFrame( + WebSocketFrameHeader::OpCode opcode, + bool final, + const scoped_refptr<IOBuffer>& data_buffer, + size_t size) { + if (state_ != CONNECTED) { + DVLOG(3) << "Ignored data packet received in state " << state_; + return CHANNEL_ALIVE; + } + DCHECK(opcode == WebSocketFrameHeader::kOpCodeContinuation || + opcode == WebSocketFrameHeader::kOpCodeText || + opcode == WebSocketFrameHeader::kOpCodeBinary); + const bool got_continuation = + (opcode == WebSocketFrameHeader::kOpCodeContinuation); + if (got_continuation != expecting_to_handle_continuation_) { + const std::string console_log = got_continuation + ? "Received unexpected continuation frame." + : "Received start of new message but previous message is unfinished."; + const std::string reason = got_continuation + ? "Unexpected continuation" + : "Previous data frame unfinished"; + return FailChannel(console_log, kWebSocketErrorProtocolError, reason); + } + expecting_to_handle_continuation_ = !final; + WebSocketFrameHeader::OpCode opcode_to_send = opcode; + if (!initial_frame_forwarded_ && + opcode == WebSocketFrameHeader::kOpCodeContinuation) { + opcode_to_send = receiving_text_message_ + ? WebSocketFrameHeader::kOpCodeText + : WebSocketFrameHeader::kOpCodeBinary; + } + if (opcode == WebSocketFrameHeader::kOpCodeText || + (opcode == WebSocketFrameHeader::kOpCodeContinuation && + receiving_text_message_)) { + // This call is not redundant when size == 0 because it tells us what + // the current state is. + StreamingUtf8Validator::State state = incoming_utf8_validator_.AddBytes( + size ? data_buffer->data() : NULL, size); + if (state == StreamingUtf8Validator::INVALID || + (state == StreamingUtf8Validator::VALID_MIDPOINT && final)) { + return FailChannel("Could not decode a text frame as UTF-8.", + kWebSocketErrorProtocolError, + "Invalid UTF-8 in text frame"); + } + receiving_text_message_ = !final; + DCHECK(!final || state == StreamingUtf8Validator::VALID_ENDPOINT); + } + if (size == 0U && !final) + return CHANNEL_ALIVE; + + initial_frame_forwarded_ = !final; + if (size > base::checked_cast<size_t>(current_receive_quota_) || + !pending_received_frames_.empty()) { + const bool no_quota = (current_receive_quota_ == 0); + DCHECK(no_quota || pending_received_frames_.empty()); + DVLOG(3) << "Queueing frame to renderer due to quota. quota=" + << current_receive_quota_ << " size=" << size; + WebSocketFrameHeader::OpCode opcode_to_queue = + no_quota ? opcode_to_send : WebSocketFrameHeader::kOpCodeContinuation; + pending_received_frames_.push(PendingReceivedFrame( + final, opcode_to_queue, data_buffer, current_receive_quota_, size)); + if (no_quota) + return CHANNEL_ALIVE; + size = current_receive_quota_; + final = false; + } + + // TODO(ricea): Can this copy be eliminated? + const char* const data_begin = size ? data_buffer->data() : NULL; + const char* const data_end = data_begin + size; + const std::vector<char> data(data_begin, data_end); + current_receive_quota_ -= size; + DCHECK_GE(current_receive_quota_, 0); + + // Sends the received frame to the renderer process. + return event_interface_->OnDataFrame(final, opcode_to_send, data); +} + +ChannelState WebSocketChannel::SendFrameFromIOBuffer( bool fin, WebSocketFrameHeader::OpCode op_code, const scoped_refptr<IOBuffer>& buffer, size_t size) { DCHECK(state_ == CONNECTED || state_ == RECV_CLOSED); DCHECK(stream_); + scoped_ptr<WebSocketFrame> frame(new WebSocketFrame(op_code)); WebSocketFrameHeader& header = frame->header; header.final = fin; header.masked = true; header.payload_length = size; frame->data = buffer; + if (data_being_sent_) { // Either the link to the WebSocket server is saturated, or several messages // are being sent in a batch. @@ -606,36 +978,31 @@ ChannelState WebSocketChannel::SendIOBuffer( data_to_send_next_->AddFrame(frame.Pass()); return CHANNEL_ALIVE; } + data_being_sent_.reset(new SendBuffer); data_being_sent_->AddFrame(frame.Pass()); return WriteFrames(); } -ChannelState WebSocketChannel::FailChannel(ExposeError expose, +ChannelState WebSocketChannel::FailChannel(const std::string& message, uint16 code, const std::string& reason) { DCHECK_NE(FRESHLY_CONSTRUCTED, state_); DCHECK_NE(CONNECTING, state_); DCHECK_NE(CLOSED, state_); + // TODO(ricea): Logging. if (state_ == CONNECTED) { - uint16 send_code = kWebSocketErrorGoingAway; - std::string send_reason = "Internal Error"; - if (expose == SEND_REAL_ERROR) { - send_code = code; - send_reason = reason; - } - if (SendClose(send_code, send_reason) == // Sets state_ to SEND_CLOSED - CHANNEL_DELETED) + if (SendClose(code, reason) == CHANNEL_DELETED) return CHANNEL_DELETED; } + // Careful study of RFC6455 section 7.1.7 and 7.1.1 indicates the browser // should close the connection itself without waiting for the closing // handshake. stream_->Close(); - state_ = CLOSED; - - return event_interface_->OnDropChannel(code, reason); + SetState(CLOSED); + return event_interface_->OnFailChannel(message); } ChannelState WebSocketChannel::SendClose(uint16 code, @@ -647,12 +1014,13 @@ ChannelState WebSocketChannel::SendClose(uint16 code, if (code == kWebSocketErrorNoStatusReceived) { // Special case: translate kWebSocketErrorNoStatusReceived into a Close // frame with no payload. + DCHECK(reason.empty()); body = new IOBuffer(0); } else { const size_t payload_length = kWebSocketCloseCodeLength + reason.length(); body = new IOBuffer(payload_length); size = payload_length; - WriteBigEndian(body->data(), code); + base::WriteBigEndian(body->data(), code); COMPILE_ASSERT(sizeof(code) == kWebSocketCloseCodeLength, they_should_both_be_two); std::copy( @@ -664,57 +1032,82 @@ ChannelState WebSocketChannel::SendClose(uint16 code, FROM_HERE, timeout_, base::Bind(&WebSocketChannel::CloseTimeout, base::Unretained(this))); - if (SendIOBuffer(true, WebSocketFrameHeader::kOpCodeClose, body, size) == + if (SendFrameFromIOBuffer( + true, WebSocketFrameHeader::kOpCodeClose, body, size) == CHANNEL_DELETED) return CHANNEL_DELETED; - // SendIOBuffer() checks |state_|, so it is best not to change it until after - // SendIOBuffer() returns. - state_ = (state_ == CONNECTED) ? SEND_CLOSED : CLOSE_WAIT; return CHANNEL_ALIVE; } -void WebSocketChannel::ParseClose(const scoped_refptr<IOBuffer>& buffer, +bool WebSocketChannel::ParseClose(const scoped_refptr<IOBuffer>& buffer, size_t size, uint16* code, - std::string* reason) { + std::string* reason, + std::string* message) { reason->clear(); if (size < kWebSocketCloseCodeLength) { - *code = kWebSocketErrorNoStatusReceived; - if (size != 0) { - VLOG(1) << "Close frame with payload size " << size << " received " - << "(the first byte is " << std::hex - << static_cast<int>(buffer->data()[0]) << ")"; + if (size == 0U) { + *code = kWebSocketErrorNoStatusReceived; + return true; } - return; + + DVLOG(1) << "Close frame with payload size " << size << " received " + << "(the first byte is " << std::hex + << static_cast<int>(buffer->data()[0]) << ")"; + *code = kWebSocketErrorProtocolError; + *message = + "Received a broken close frame containing an invalid size body."; + return false; } + const char* data = buffer->data(); uint16 unchecked_code = 0; - ReadBigEndian(data, &unchecked_code); + base::ReadBigEndian(data, &unchecked_code); COMPILE_ASSERT(sizeof(unchecked_code) == kWebSocketCloseCodeLength, they_should_both_be_two_bytes); - if (unchecked_code >= static_cast<uint16>(kWebSocketNormalClosure) && - unchecked_code <= - static_cast<uint16>(kWebSocketErrorPrivateReservedMax)) { - *code = unchecked_code; - } else { - VLOG(1) << "Close frame contained code outside of the valid range: " - << unchecked_code; - *code = kWebSocketErrorAbnormalClosure; + + switch (unchecked_code) { + case kWebSocketErrorNoStatusReceived: + case kWebSocketErrorAbnormalClosure: + case kWebSocketErrorTlsHandshake: + *code = kWebSocketErrorProtocolError; + *message = + "Received a broken close frame containing a reserved status code."; + return false; + + default: + *code = unchecked_code; + break; } + std::string text(data + kWebSocketCloseCodeLength, data + size); - // IsStringUTF8() blocks surrogate pairs and non-characters, so it is strictly - // stronger than required by RFC3629. - if (IsStringUTF8(text)) { + if (StreamingUtf8Validator::Validate(text)) { reason->swap(text); + return true; } + + *code = kWebSocketErrorProtocolError; + *reason = "Invalid UTF-8 in Close frame"; + *message = "Received a broken close frame containing invalid UTF-8."; + return false; +} + +ChannelState WebSocketChannel::DoDropChannel(bool was_clean, + uint16 code, + const std::string& reason) { + if (CHANNEL_DELETED == + notification_sender_->SendImmediately(event_interface_.get())) + return CHANNEL_DELETED; + ChannelState result = + event_interface_->OnDropChannel(was_clean, code, reason); + DCHECK_EQ(CHANNEL_DELETED, result); + return result; } void WebSocketChannel::CloseTimeout() { stream_->Close(); - DCHECK_NE(CLOSED, state_); - state_ = CLOSED; - AllowUnused(event_interface_->OnDropChannel(kWebSocketErrorAbnormalClosure, - "Abnormal Closure")); + SetState(CLOSED); + AllowUnused(DoDropChannel(false, kWebSocketErrorAbnormalClosure, "")); // |this| has been deleted. } diff --git a/chromium/net/websockets/websocket_channel.h b/chromium/net/websockets/websocket_channel.h index 61f191af978..6d5640e8be0 100644 --- a/chromium/net/websockets/websocket_channel.h +++ b/chromium/net/websockets/websocket_channel.h @@ -5,12 +5,15 @@ #ifndef NET_WEBSOCKETS_WEBSOCKET_CHANNEL_H_ #define NET_WEBSOCKETS_WEBSOCKET_CHANNEL_H_ +#include <queue> #include <string> #include <vector> #include "base/basictypes.h" #include "base/callback.h" #include "base/compiler_specific.h" // for WARN_UNUSED_RESULT +#include "base/i18n/streaming_utf8_validator.h" +#include "base/memory/ref_counted.h" #include "base/memory/scoped_ptr.h" #include "base/memory/scoped_vector.h" #include "base/time/time.h" @@ -21,11 +24,17 @@ #include "net/websockets/websocket_stream.h" #include "url/gurl.h" +namespace url { +class Origin; +} // namespace url + namespace net { class BoundNetLog; class IOBuffer; class URLRequestContext; +struct WebSocketHandshakeRequestInfo; +struct WebSocketHandshakeResponseInfo; // Transport-independent implementation of WebSockets. Implements protocol // semantics that do not depend on the underlying transport. Provides the @@ -39,7 +48,7 @@ class NET_EXPORT WebSocketChannel { typedef base::Callback<scoped_ptr<WebSocketStreamRequest>( const GURL&, const std::vector<std::string>&, - const GURL&, + const url::Origin&, URLRequestContext*, const BoundNetLog&, scoped_ptr<WebSocketStream::ConnectDelegate>)> WebSocketStreamCreator; @@ -55,7 +64,7 @@ class NET_EXPORT WebSocketChannel { void SendAddChannelRequest( const GURL& socket_url, const std::vector<std::string>& requested_protocols, - const GURL& origin); + const url::Origin& origin); // Sends a data frame to the remote side. The frame should usually be no // larger than 32KB to prevent the time required to copy the buffers from from @@ -93,7 +102,7 @@ class NET_EXPORT WebSocketChannel { void SendAddChannelRequestForTesting( const GURL& socket_url, const std::vector<std::string>& requested_protocols, - const GURL& origin, + const url::Origin& origin, const WebSocketStreamCreator& creator); // The default timout for the closing handshake is a sensible value (see @@ -101,7 +110,55 @@ class NET_EXPORT WebSocketChannel { // set it to a very small value for testing purposes. void SetClosingHandshakeTimeoutForTesting(base::TimeDelta delay); + // Called when the stream starts the WebSocket Opening Handshake. + // This method is public for testing. + void OnStartOpeningHandshake( + scoped_ptr<WebSocketHandshakeRequestInfo> request); + + // Called when the stream ends the WebSocket Opening Handshake. + // This method is public for testing. + void OnFinishOpeningHandshake( + scoped_ptr<WebSocketHandshakeResponseInfo> response); + private: + class HandshakeNotificationSender; + + // The Windows implementation of std::queue requires that this declaration be + // visible in the header. + class PendingReceivedFrame { + public: + PendingReceivedFrame(bool final, + WebSocketFrameHeader::OpCode opcode, + const scoped_refptr<IOBuffer>& data, + size_t offset, + size_t size); + ~PendingReceivedFrame(); + + bool final() const { return final_; } + WebSocketFrameHeader::OpCode opcode() const { return opcode_; } + // ResetOpcode() to Continuation. + void ResetOpcode(); + const scoped_refptr<IOBuffer>& data() const { return data_; } + size_t offset() const { return offset_; } + size_t size() const { return size_; } + // Increase |offset_| by |bytes|. + void DidConsume(size_t bytes); + + // This object needs to be copyable and assignable, since it will be placed + // in a std::queue. The compiler-generated copy constructor and assignment + // operator will do the right thing. + + private: + bool final_; + WebSocketFrameHeader::OpCode opcode_; + scoped_refptr<IOBuffer> data_; + // Where to start reading from data_. Everything prior to offset_ has + // already been sent to the browser. + size_t offset_; + // The size of data_. + size_t size_; + }; + // Methods which return a value of type ChannelState may delete |this|. If the // return value is CHANNEL_DELETED, then the caller must return without making // any further access to member variables or methods. @@ -124,14 +181,6 @@ class NET_EXPORT WebSocketChannel { // has been closed; or the connection is failed. }; - // When failing a channel, sometimes it is inappropriate to expose the real - // reason for failing to the remote server. This enum is used by FailChannel() - // to select between sending the real status or a "Going Away" status. - enum ExposeError { - SEND_REAL_ERROR, - SEND_GOING_AWAY, - }; - // Implementation of WebSocketStream::ConnectDelegate for // WebSocketChannel. WebSocketChannel does not inherit from // WebSocketStream::ConnectDelegate directly to avoid cluttering the public @@ -144,7 +193,7 @@ class NET_EXPORT WebSocketChannel { void SendAddChannelRequestWithSuppliedCreator( const GURL& socket_url, const std::vector<std::string>& requested_protocols, - const GURL& origin, + const url::Origin& origin, const WebSocketStreamCreator& creator); // Success callback from WebSocketStream::CreateAndConnectStream(). Reports @@ -153,7 +202,23 @@ class NET_EXPORT WebSocketChannel { // Failure callback from WebSocketStream::CreateAndConnectStream(). Reports // failure to the event interface. May delete |this|. - void OnConnectFailure(uint16 websocket_error); + void OnConnectFailure(const std::string& message); + + // SSL certificate error callback from + // WebSocketStream::CreateAndConnectStream(). Forwards the request to the + // event interface. + void OnSSLCertificateError( + scoped_ptr<WebSocketEventInterface::SSLErrorCallbacks> + ssl_error_callbacks, + const SSLInfo& ssl_info, + bool fatal); + + // Posts a task that sends pending notifications relating WebSocket Opening + // Handshake to the renderer. + void ScheduleOpeningHandshakeNotification(); + + // Sets |state_| to |new_state| and updates UMA if necessary. + void SetState(State new_state); // Returns true if state_ is SEND_CLOSED, CLOSE_WAIT or CLOSED. bool InClosingState() const; @@ -168,7 +233,8 @@ class NET_EXPORT WebSocketChannel { // WriteFrames() itself. ChannelState OnWriteDone(bool synchronous, int result) WARN_UNUSED_RESULT; - // Calls WebSocketStream::ReadFrames() with the appropriate arguments. + // Calls WebSocketStream::ReadFrames() with the appropriate arguments. Stops + // calling ReadFrames if current_receive_quota_ is 0. ChannelState ReadFrames() WARN_UNUSED_RESULT; // Callback from WebSocketStream::ReadFrames. Handles any errors and processes @@ -177,37 +243,52 @@ class NET_EXPORT WebSocketChannel { // within the ReadFrames() loop and does not need to call ReadFrames() itself. ChannelState OnReadDone(bool synchronous, int result) WARN_UNUSED_RESULT; - // Processes a single frame that has been read from the stream. - ChannelState ProcessFrame( + // Handles a single frame that the object has received enough of to process. + // May call |event_interface_| methods, send responses to the server, and + // change the value of |state_|. + // + // This method performs sanity checks on the frame that are needed regardless + // of the current state. Then, calls the HandleFrameByState() method below + // which performs the appropriate action(s) depending on the current state. + ChannelState HandleFrame( scoped_ptr<WebSocketFrame> frame) WARN_UNUSED_RESULT; - // Handles a frame that the object has received enough of to process. May call - // |event_interface_| methods, send responses to the server, and change the - // value of |state_|. - ChannelState HandleFrame(const WebSocketFrameHeader::OpCode opcode, - bool final, - const scoped_refptr<IOBuffer>& data_buffer, - size_t size) WARN_UNUSED_RESULT; + // Handles a single frame depending on the current state. It's used by the + // HandleFrame() method. + ChannelState HandleFrameByState( + const WebSocketFrameHeader::OpCode opcode, + bool final, + const scoped_refptr<IOBuffer>& data_buffer, + size_t size) WARN_UNUSED_RESULT; + + // Forward a received data frame to the renderer, if connected. If + // |expecting_continuation| is not equal to |expecting_to_read_continuation_|, + // will fail the channel. Also checks the UTF-8 validity of text frames. + ChannelState HandleDataFrame(WebSocketFrameHeader::OpCode opcode, + bool final, + const scoped_refptr<IOBuffer>& data_buffer, + size_t size) WARN_UNUSED_RESULT; // Low-level method to send a single frame. Used for both data and control // frames. Either sends the frame immediately or buffers it to be scheduled // when the current write finishes. |fin| and |op_code| are defined as for // SendFrame() above, except that |op_code| may also be a control frame // opcode. - ChannelState SendIOBuffer(bool fin, - WebSocketFrameHeader::OpCode op_code, - const scoped_refptr<IOBuffer>& buffer, - size_t size) WARN_UNUSED_RESULT; + ChannelState SendFrameFromIOBuffer(bool fin, + WebSocketFrameHeader::OpCode op_code, + const scoped_refptr<IOBuffer>& buffer, + size_t size) WARN_UNUSED_RESULT; // Performs the "Fail the WebSocket Connection" operation as defined in - // RFC6455. The supplied code and reason are sent back to the renderer in an - // OnDropChannel message. If state_ is CONNECTED then a Close message is sent - // to the remote host. If |expose| is SEND_REAL_ERROR then the remote host is - // given the same status code passed to the renderer; otherwise it is sent a - // fixed "Going Away" code. Closes the stream_ and sets state_ to CLOSED. - // FailChannel() always returns CHANNEL_DELETED. It is not valid to access any - // member variables or methods after calling FailChannel(). - ChannelState FailChannel(ExposeError expose, + // RFC6455. A NotifyFailure message is sent to the renderer with |message|. + // The renderer will log the message to the console but not expose it to + // Javascript. Javascript will see a Close code of AbnormalClosure (1006) with + // an empty reason string. If state_ is CONNECTED then a Close message is sent + // to the remote host containing the supplied |code| and |reason|. If the + // stream is open, closes it and sets state_ to CLOSED. FailChannel() always + // returns CHANNEL_DELETED. It is not valid to access any member variables or + // methods after calling FailChannel(). + ChannelState FailChannel(const std::string& message, uint16 code, const std::string& reason) WARN_UNUSED_RESULT; @@ -218,15 +299,26 @@ class NET_EXPORT WebSocketChannel { ChannelState SendClose(uint16 code, const std::string& reason) WARN_UNUSED_RESULT; - // Parses a Close frame. If no status code is supplied, then |code| is set to - // 1005 (No status code) with empty |reason|. If the supplied code is - // outside the valid range, then 1002 (Protocol error) is set instead. If the - // reason text is not valid UTF-8, then |reason| is set to an empty string - // instead. - void ParseClose(const scoped_refptr<IOBuffer>& buffer, + // Parses a Close frame payload. If no status code is supplied, then |code| is + // set to 1005 (No status code) with empty |reason|. If the reason text is not + // valid UTF-8, then |reason| is set to an empty string. If the payload size + // is 1, or the supplied code is not permitted to be sent over the network, + // then false is returned and |message| is set to an appropriate console + // message. + bool ParseClose(const scoped_refptr<IOBuffer>& buffer, size_t size, uint16* code, - std::string* reason); + std::string* reason, + std::string* message); + + // Drop this channel. + // If there are pending opening handshake notifications, notify them + // before dropping. + // + // Always returns CHANNEL_DELETED. + ChannelState DoDropChannel(bool was_clean, + uint16 code, + const std::string& reason); // Called if the closing handshake times out. Closes the connection and // informs the |event_interface_| if appropriate. @@ -256,6 +348,10 @@ class NET_EXPORT WebSocketChannel { // Destination for the current call to WebSocketStream::ReadFrames ScopedVector<WebSocketFrame> read_frames_; + // Frames that have been read but not yet forwarded to the renderer due to + // lack of quota. + std::queue<PendingReceivedFrame> pending_received_frames_; + // Handle to an in-progress WebSocketStream creation request. Only non-NULL // during the connection process. scoped_ptr<WebSocketStreamRequest> stream_request_; @@ -270,6 +366,9 @@ class NET_EXPORT WebSocketChannel { // The current amount of quota that the renderer has available for sending // on this logical channel (quota units). int current_send_quota_; + // The remaining amount of quota that the renderer will allow us to send on + // this logical channel (quota units). + int current_receive_quota_; // Timer for the closing handshake. base::OneShotTimer<WebSocketChannel> timer_; @@ -280,13 +379,35 @@ class NET_EXPORT WebSocketChannel { // Storage for the status code and reason from the time the Close frame // arrives until the connection is closed and they are passed to // OnDropChannel(). - uint16 closing_code_; - std::string closing_reason_; + uint16 received_close_code_; + std::string received_close_reason_; // The current state of the channel. Mainly used for sanity checking, but also // used to track the close state. State state_; + // |notification_sender_| is owned by this object. + scoped_ptr<HandshakeNotificationSender> notification_sender_; + + // UTF-8 validator for outgoing Text messages. + base::StreamingUtf8Validator outgoing_utf8_validator_; + bool sending_text_message_; + + // UTF-8 validator for incoming Text messages. + base::StreamingUtf8Validator incoming_utf8_validator_; + bool receiving_text_message_; + + // True if we are in the middle of receiving a message. + bool expecting_to_handle_continuation_; + + // True if we have already sent the type (Text or Binary) of the current + // message to the renderer. This can be false if the message is empty so far. + bool initial_frame_forwarded_; + + // For UMA. The time when OnConnectSuccess() method was called and |stream_| + // was set. + base::TimeTicks established_on_; + DISALLOW_COPY_AND_ASSIGN(WebSocketChannel); }; diff --git a/chromium/net/websockets/websocket_channel_test.cc b/chromium/net/websockets/websocket_channel_test.cc index 1bd75db2eb1..4a8f119a65b 100644 --- a/chromium/net/websockets/websocket_channel_test.cc +++ b/chromium/net/websockets/websocket_channel_test.cc @@ -4,6 +4,7 @@ #include "net/websockets/websocket_channel.h" +#include <limits.h> #include <string.h> #include <iostream> @@ -18,17 +19,20 @@ #include "base/memory/scoped_vector.h" #include "base/memory/weak_ptr.h" #include "base/message_loop/message_loop.h" -#include "base/safe_numerics.h" #include "base/strings/string_piece.h" #include "net/base/net_errors.h" #include "net/base/test_completion_callback.h" +#include "net/http/http_response_headers.h" #include "net/url_request/url_request_context.h" #include "net/websockets/websocket_errors.h" #include "net/websockets/websocket_event_interface.h" +#include "net/websockets/websocket_handshake_request_info.h" +#include "net/websockets/websocket_handshake_response_info.h" #include "net/websockets/websocket_mux.h" #include "testing/gmock/include/gmock/gmock.h" #include "testing/gtest/include/gtest/gtest.h" #include "url/gurl.h" +#include "url/origin.h" // Hacky macros to construct the body of a Close message from a code and a // string, while ensuring the result is a compile-time constant string. @@ -37,6 +41,7 @@ #define WEBSOCKET_CLOSE_CODE_AS_STRING_NORMAL_CLOSURE "\x03\xe8" #define WEBSOCKET_CLOSE_CODE_AS_STRING_GOING_AWAY "\x03\xe9" #define WEBSOCKET_CLOSE_CODE_AS_STRING_PROTOCOL_ERROR "\x03\xea" +#define WEBSOCKET_CLOSE_CODE_AS_STRING_ABNORMAL_CLOSURE "\x03\xee" #define WEBSOCKET_CLOSE_CODE_AS_STRING_SERVER_ERROR "\x03\xf3" namespace net { @@ -92,6 +97,7 @@ using ::testing::AnyNumber; using ::testing::DefaultValue; using ::testing::InSequence; using ::testing::MockFunction; +using ::testing::NotNull; using ::testing::Return; using ::testing::SaveArg; using ::testing::StrictMock; @@ -125,27 +131,62 @@ const size_t kDefaultQuotaRefreshTrigger = (1 << 16) + 1; // in that time! I would like my tests to run a bit quicker. const int kVeryTinyTimeoutMillis = 1; +// Enough quota to pass any test. +const int64 kPlentyOfQuota = INT_MAX; + typedef WebSocketEventInterface::ChannelState ChannelState; const ChannelState CHANNEL_ALIVE = WebSocketEventInterface::CHANNEL_ALIVE; const ChannelState CHANNEL_DELETED = WebSocketEventInterface::CHANNEL_DELETED; // This typedef mainly exists to avoid having to repeat the "NOLINT" incantation // all over the place. -typedef MockFunction<void(int)> Checkpoint; // NOLINT +typedef StrictMock< MockFunction<void(int)> > Checkpoint; // NOLINT // This mock is for testing expectations about how the EventInterface is used. class MockWebSocketEventInterface : public WebSocketEventInterface { public: - MOCK_METHOD2(OnAddChannelResponse, - ChannelState(bool, const std::string&)); // NOLINT + MockWebSocketEventInterface() {} + + MOCK_METHOD3(OnAddChannelResponse, + ChannelState(bool, + const std::string&, + const std::string&)); // NOLINT MOCK_METHOD3(OnDataFrame, ChannelState(bool, WebSocketMessageType, const std::vector<char>&)); // NOLINT - MOCK_METHOD1(OnFlowControl, ChannelState(int64)); // NOLINT + MOCK_METHOD1(OnFlowControl, ChannelState(int64)); // NOLINT MOCK_METHOD0(OnClosingHandshake, ChannelState(void)); // NOLINT - MOCK_METHOD2(OnDropChannel, - ChannelState(uint16, const std::string&)); // NOLINT + MOCK_METHOD1(OnFailChannel, ChannelState(const std::string&)); // NOLINT + MOCK_METHOD3(OnDropChannel, + ChannelState(bool, uint16, const std::string&)); // NOLINT + + // We can't use GMock with scoped_ptr. + ChannelState OnStartOpeningHandshake( + scoped_ptr<WebSocketHandshakeRequestInfo>) OVERRIDE { + OnStartOpeningHandshakeCalled(); + return CHANNEL_ALIVE; + } + ChannelState OnFinishOpeningHandshake( + scoped_ptr<WebSocketHandshakeResponseInfo>) OVERRIDE { + OnFinishOpeningHandshakeCalled(); + return CHANNEL_ALIVE; + } + virtual ChannelState OnSSLCertificateError( + scoped_ptr<SSLErrorCallbacks> ssl_error_callbacks, + const GURL& url, + const SSLInfo& ssl_info, + bool fatal) OVERRIDE { + OnSSLCertificateErrorCalled( + ssl_error_callbacks.get(), url, ssl_info, fatal); + return CHANNEL_ALIVE; + } + + MOCK_METHOD0(OnStartOpeningHandshakeCalled, void()); // NOLINT + MOCK_METHOD0(OnFinishOpeningHandshakeCalled, void()); // NOLINT + MOCK_METHOD4( + OnSSLCertificateErrorCalled, + void(SSLErrorCallbacks*, const GURL&, const SSLInfo&, bool)); // NOLINT }; // This fake EventInterface is for tests which need a WebSocketEventInterface @@ -153,7 +194,8 @@ class MockWebSocketEventInterface : public WebSocketEventInterface { class FakeWebSocketEventInterface : public WebSocketEventInterface { virtual ChannelState OnAddChannelResponse( bool fail, - const std::string& selected_protocol) OVERRIDE { + const std::string& selected_protocol, + const std::string& extensions) OVERRIDE { return fail ? CHANNEL_DELETED : CHANNEL_ALIVE; } virtual ChannelState OnDataFrame(bool fin, @@ -165,10 +207,29 @@ class FakeWebSocketEventInterface : public WebSocketEventInterface { return CHANNEL_ALIVE; } virtual ChannelState OnClosingHandshake() OVERRIDE { return CHANNEL_ALIVE; } - virtual ChannelState OnDropChannel(uint16 code, + virtual ChannelState OnFailChannel(const std::string& message) OVERRIDE { + return CHANNEL_DELETED; + } + virtual ChannelState OnDropChannel(bool was_clean, + uint16 code, const std::string& reason) OVERRIDE { return CHANNEL_DELETED; } + virtual ChannelState OnStartOpeningHandshake( + scoped_ptr<WebSocketHandshakeRequestInfo> request) OVERRIDE { + return CHANNEL_ALIVE; + } + virtual ChannelState OnFinishOpeningHandshake( + scoped_ptr<WebSocketHandshakeResponseInfo> response) OVERRIDE { + return CHANNEL_ALIVE; + } + virtual ChannelState OnSSLCertificateError( + scoped_ptr<SSLErrorCallbacks> ssl_error_callbacks, + const GURL& url, + const SSLInfo& ssl_info, + bool fatal) OVERRIDE { + return CHANNEL_ALIVE; + } }; // This fake WebSocketStream is for tests that require a WebSocketStream but are @@ -212,15 +273,9 @@ class FakeWebSocketStream : public WebSocketStream { // To make the static initialisers easier to read, we use enums rather than // bools. -enum IsFinal { - NOT_FINAL_FRAME, - FINAL_FRAME -}; +enum IsFinal { NOT_FINAL_FRAME, FINAL_FRAME }; -enum IsMasked { - NOT_MASKED, - MASKED -}; +enum IsMasked { NOT_MASKED, MASKED }; // This is used to initialise a WebSocketFrame but is statically initialisable. struct InitFrame { @@ -394,10 +449,7 @@ ACTION_P(InvokeClosureReturnDeleted, closure) { // A FakeWebSocketStream whose ReadFrames() function returns data. class ReadableFakeWebSocketStream : public FakeWebSocketStream { public: - enum IsSync { - SYNC, - ASYNC - }; + enum IsSync { SYNC, ASYNC }; // After constructing the object, call PrepareReadFrames() once for each // time you wish it to return from the test. @@ -653,7 +705,7 @@ struct ArgumentCopyingWebSocketStreamCreator { scoped_ptr<WebSocketStreamRequest> Create( const GURL& socket_url, const std::vector<std::string>& requested_subprotocols, - const GURL& origin, + const url::Origin& origin, URLRequestContext* url_request_context, const BoundNetLog& net_log, scoped_ptr<WebSocketStream::ConnectDelegate> connect_delegate) { @@ -667,7 +719,7 @@ struct ArgumentCopyingWebSocketStreamCreator { } GURL socket_url; - GURL origin; + url::Origin origin; std::vector<std::string> requested_subprotocols; URLRequestContext* url_request_context; BoundNetLog net_log; @@ -681,6 +733,13 @@ std::vector<char> AsVector(const std::string& s) { return std::vector<char>(s.begin(), s.end()); } +class FakeSSLErrorCallbacks + : public WebSocketEventInterface::SSLErrorCallbacks { + public: + virtual void CancelSSLRequest(int error, const SSLInfo* ssl_info) OVERRIDE {} + virtual void ContinueSSLRequest() OVERRIDE {} +}; + // Base class for all test fixtures. class WebSocketChannelTest : public ::testing::Test { protected: @@ -703,6 +762,9 @@ class WebSocketChannelTest : public ::testing::Test { // well. This method is virtual so that subclasses can also set the stream. virtual void CreateChannelAndConnectSuccessfully() { CreateChannelAndConnect(); + // Most tests aren't concerned with flow control from the renderer, so allow + // MAX_INT quota units. + channel_->SendFlowControl(kPlentyOfQuota); connect_data_.creator.connect_delegate->OnSuccess(stream_.Pass()); } @@ -727,10 +789,7 @@ class WebSocketChannelTest : public ::testing::Test { // A struct containing the data that will be used to connect the channel. // Grouped for readability. struct ConnectData { - ConnectData() : - socket_url("ws://ws/"), - origin("http://ws/") - {} + ConnectData() : socket_url("ws://ws/"), origin("http://ws") {} // URLRequestContext object. URLRequestContext url_request_context; @@ -740,7 +799,7 @@ class WebSocketChannelTest : public ::testing::Test { // Requested protocols for the request. std::vector<std::string> requested_subprotocols; // Origin of the request - GURL origin; + url::Origin origin; // A fake WebSocketStreamCreator that just records its arguments. ArgumentCopyingWebSocketStreamCreator creator; @@ -761,7 +820,11 @@ enum EventInterfaceCall { EVENT_ON_DATA_FRAME = 0x2, EVENT_ON_FLOW_CONTROL = 0x4, EVENT_ON_CLOSING_HANDSHAKE = 0x8, - EVENT_ON_DROP_CHANNEL = 0x10, + EVENT_ON_FAIL_CHANNEL = 0x10, + EVENT_ON_DROP_CHANNEL = 0x20, + EVENT_ON_START_OPENING_HANDSHAKE = 0x40, + EVENT_ON_FINISH_OPENING_HANDSHAKE = 0x80, + EVENT_ON_SSL_CERTIFICATE_ERROR = 0x100, }; class WebSocketChannelDeletingTest : public WebSocketChannelTest { @@ -780,7 +843,11 @@ class WebSocketChannelDeletingTest : public WebSocketChannelTest { : deleting_(EVENT_ON_ADD_CHANNEL_RESPONSE | EVENT_ON_DATA_FRAME | EVENT_ON_FLOW_CONTROL | EVENT_ON_CLOSING_HANDSHAKE | - EVENT_ON_DROP_CHANNEL) {} + EVENT_ON_FAIL_CHANNEL | + EVENT_ON_DROP_CHANNEL | + EVENT_ON_START_OPENING_HANDSHAKE | + EVENT_ON_FINISH_OPENING_HANDSHAKE | + EVENT_ON_SSL_CERTIFICATE_ERROR) {} // Create a ChannelDeletingFakeWebSocketEventInterface. Defined out-of-line to // avoid circular dependency. virtual scoped_ptr<WebSocketEventInterface> CreateEventInterface() OVERRIDE; @@ -802,7 +869,8 @@ class ChannelDeletingFakeWebSocketEventInterface virtual ChannelState OnAddChannelResponse( bool fail, - const std::string& selected_protocol) OVERRIDE { + const std::string& selected_protocol, + const std::string& extensions) OVERRIDE { return fixture_->DeleteIfDeleting(EVENT_ON_ADD_CHANNEL_RESPONSE); } @@ -820,11 +888,32 @@ class ChannelDeletingFakeWebSocketEventInterface return fixture_->DeleteIfDeleting(EVENT_ON_CLOSING_HANDSHAKE); } - virtual ChannelState OnDropChannel(uint16 code, + virtual ChannelState OnFailChannel(const std::string& message) OVERRIDE { + return fixture_->DeleteIfDeleting(EVENT_ON_FAIL_CHANNEL); + } + + virtual ChannelState OnDropChannel(bool was_clean, + uint16 code, const std::string& reason) OVERRIDE { return fixture_->DeleteIfDeleting(EVENT_ON_DROP_CHANNEL); } + virtual ChannelState OnStartOpeningHandshake( + scoped_ptr<WebSocketHandshakeRequestInfo> request) OVERRIDE { + return fixture_->DeleteIfDeleting(EVENT_ON_START_OPENING_HANDSHAKE); + } + virtual ChannelState OnFinishOpeningHandshake( + scoped_ptr<WebSocketHandshakeResponseInfo> response) OVERRIDE { + return fixture_->DeleteIfDeleting(EVENT_ON_FINISH_OPENING_HANDSHAKE); + } + virtual ChannelState OnSSLCertificateError( + scoped_ptr<SSLErrorCallbacks> ssl_error_callbacks, + const GURL& url, + const SSLInfo& ssl_info, + bool fatal) OVERRIDE { + return fixture_->DeleteIfDeleting(EVENT_ON_SSL_CERTIFICATE_ERROR); + } + private: // A pointer to the test fixture. Owned by the test harness; this object will // be deleted before it is. @@ -844,9 +933,11 @@ class WebSocketChannelEventInterfaceTest : public WebSocketChannelTest { WebSocketChannelEventInterfaceTest() : event_interface_(new StrictMock<MockWebSocketEventInterface>) { DefaultValue<ChannelState>::Set(CHANNEL_ALIVE); - ON_CALL(*event_interface_, OnAddChannelResponse(true, _)) + ON_CALL(*event_interface_, OnAddChannelResponse(true, _, _)) + .WillByDefault(Return(CHANNEL_DELETED)); + ON_CALL(*event_interface_, OnDropChannel(_, _, _)) .WillByDefault(Return(CHANNEL_DELETED)); - ON_CALL(*event_interface_, OnDropChannel(_, _)) + ON_CALL(*event_interface_, OnFailChannel(_)) .WillByDefault(Return(CHANNEL_DELETED)); } @@ -880,11 +971,54 @@ class WebSocketChannelStreamTest : public WebSocketChannelTest { scoped_ptr<MockWebSocketStream> mock_stream_; }; +// Fixture for tests which test UTF-8 validation of sent Text frames via the +// EventInterface. +class WebSocketChannelSendUtf8Test + : public WebSocketChannelEventInterfaceTest { + public: + virtual void SetUp() { + set_stream(make_scoped_ptr(new WriteableFakeWebSocketStream)); + // For the purpose of the tests using this fixture, it doesn't matter + // whether these methods are called or not. + EXPECT_CALL(*event_interface_, OnAddChannelResponse(_, _, _)) + .Times(AnyNumber()); + EXPECT_CALL(*event_interface_, OnFlowControl(_)) + .Times(AnyNumber()); + } +}; + +// Fixture for tests which test use of receive quota from the renderer. +class WebSocketChannelFlowControlTest + : public WebSocketChannelEventInterfaceTest { + protected: + // Tests using this fixture should use CreateChannelAndConnectWithQuota() + // instead of CreateChannelAndConnectSuccessfully(). + void CreateChannelAndConnectWithQuota(int64 quota) { + CreateChannelAndConnect(); + channel_->SendFlowControl(quota); + connect_data_.creator.connect_delegate->OnSuccess(stream_.Pass()); + } + + virtual void CreateChannelAndConnectSuccesfully() { NOTREACHED(); } +}; + +// Fixture for tests which test UTF-8 validation of received Text frames using a +// mock WebSocketStream. +class WebSocketChannelReceiveUtf8Test : public WebSocketChannelStreamTest { + public: + virtual void SetUp() { + // For the purpose of the tests using this fixture, it doesn't matter + // whether these methods are called or not. + EXPECT_CALL(*mock_stream_, GetSubProtocol()).Times(AnyNumber()); + EXPECT_CALL(*mock_stream_, GetExtensions()).Times(AnyNumber()); + } +}; + // Simple test that everything that should be passed to the creator function is // passed to the creator function. TEST_F(WebSocketChannelTest, EverythingIsPassedToTheCreatorFunction) { connect_data_.socket_url = GURL("ws://example.com/test"); - connect_data_.origin = GURL("http://example.com/test"); + connect_data_.origin = url::Origin("http://example.com"); connect_data_.requested_subprotocols.push_back("Sinbad"); CreateChannelAndConnect(); @@ -896,7 +1030,7 @@ TEST_F(WebSocketChannelTest, EverythingIsPassedToTheCreatorFunction) { EXPECT_EQ(connect_data_.socket_url, actual.socket_url); EXPECT_EQ(connect_data_.requested_subprotocols, actual.requested_subprotocols); - EXPECT_EQ(connect_data_.origin, actual.origin); + EXPECT_EQ(connect_data_.origin.string(), actual.origin.string()); } // Verify that calling SendFlowControl before the connection is established does @@ -915,8 +1049,7 @@ TEST_F(WebSocketChannelTest, SendFlowControlDuringHandshakeOkay) { TEST_F(WebSocketChannelDeletingTest, OnAddChannelResponseFail) { CreateChannelAndConnect(); EXPECT_TRUE(channel_); - connect_data_.creator.connect_delegate->OnFailure( - kWebSocketErrorNoStatusReceived); + connect_data_.creator.connect_delegate->OnFailure("bye"); EXPECT_EQ(NULL, channel_.get()); } @@ -964,7 +1097,7 @@ TEST_F(WebSocketChannelDeletingTest, OnFlowControlAfterConnect) { TEST_F(WebSocketChannelDeletingTest, OnFlowControlAfterSend) { set_stream(make_scoped_ptr(new WriteableFakeWebSocketStream)); // Avoid deleting the channel yet. - deleting_ = EVENT_ON_DROP_CHANNEL; + deleting_ = EVENT_ON_FAIL_CHANNEL | EVENT_ON_DROP_CHANNEL; CreateChannelAndConnectSuccessfully(); ASSERT_TRUE(channel_); deleting_ = EVENT_ON_FLOW_CONTROL; @@ -1025,9 +1158,50 @@ TEST_F(WebSocketChannelDeletingTest, OnDropChannelReadError) { EXPECT_EQ(NULL, channel_.get()); } +TEST_F(WebSocketChannelDeletingTest, OnNotifyStartOpeningHandshakeError) { + scoped_ptr<ReadableFakeWebSocketStream> stream( + new ReadableFakeWebSocketStream); + static const InitFrame frames[] = { + {FINAL_FRAME, WebSocketFrameHeader::kOpCodeText, NOT_MASKED, "HELLO"}}; + stream->PrepareReadFrames(ReadableFakeWebSocketStream::ASYNC, OK, frames); + set_stream(stream.Pass()); + deleting_ = EVENT_ON_START_OPENING_HANDSHAKE; + + CreateChannelAndConnectSuccessfully(); + ASSERT_TRUE(channel_); + channel_->OnStartOpeningHandshake(scoped_ptr<WebSocketHandshakeRequestInfo>( + new WebSocketHandshakeRequestInfo(GURL("http://www.example.com/"), + base::Time()))); + base::MessageLoop::current()->RunUntilIdle(); + EXPECT_EQ(NULL, channel_.get()); +} + +TEST_F(WebSocketChannelDeletingTest, OnNotifyFinishOpeningHandshakeError) { + scoped_ptr<ReadableFakeWebSocketStream> stream( + new ReadableFakeWebSocketStream); + static const InitFrame frames[] = { + {FINAL_FRAME, WebSocketFrameHeader::kOpCodeText, NOT_MASKED, "HELLO"}}; + stream->PrepareReadFrames(ReadableFakeWebSocketStream::ASYNC, OK, frames); + set_stream(stream.Pass()); + deleting_ = EVENT_ON_FINISH_OPENING_HANDSHAKE; + + CreateChannelAndConnectSuccessfully(); + ASSERT_TRUE(channel_); + scoped_refptr<HttpResponseHeaders> response_headers( + new HttpResponseHeaders("")); + channel_->OnFinishOpeningHandshake(scoped_ptr<WebSocketHandshakeResponseInfo>( + new WebSocketHandshakeResponseInfo(GURL("http://www.example.com/"), + 200, + "OK", + response_headers, + base::Time()))); + base::MessageLoop::current()->RunUntilIdle(); + EXPECT_EQ(NULL, channel_.get()); +} + TEST_F(WebSocketChannelDeletingTest, FailChannelInSendFrame) { set_stream(make_scoped_ptr(new WriteableFakeWebSocketStream)); - deleting_ = EVENT_ON_DROP_CHANNEL; + deleting_ = EVENT_ON_FAIL_CHANNEL; CreateChannelAndConnectSuccessfully(); ASSERT_TRUE(channel_); channel_->SendFrame(true, @@ -1042,7 +1216,7 @@ TEST_F(WebSocketChannelDeletingTest, FailChannelInOnReadDone) { stream->PrepareReadFramesError(ReadableFakeWebSocketStream::ASYNC, ERR_WS_PROTOCOL_ERROR); set_stream(stream.Pass()); - deleting_ = EVENT_ON_DROP_CHANNEL; + deleting_ = EVENT_ON_FAIL_CHANNEL; CreateChannelAndConnectSuccessfully(); ASSERT_TRUE(channel_); base::MessageLoop::current()->RunUntilIdle(); @@ -1056,7 +1230,7 @@ TEST_F(WebSocketChannelDeletingTest, FailChannelDueToMaskedFrame) { {FINAL_FRAME, WebSocketFrameHeader::kOpCodeText, MASKED, "HELLO"}}; stream->PrepareReadFrames(ReadableFakeWebSocketStream::SYNC, OK, frames); set_stream(stream.Pass()); - deleting_ = EVENT_ON_DROP_CHANNEL; + deleting_ = EVENT_ON_FAIL_CHANNEL; CreateChannelAndConnectSuccessfully(); EXPECT_EQ(NULL, channel_.get()); @@ -1066,10 +1240,10 @@ TEST_F(WebSocketChannelDeletingTest, FailChannelDueToBadControlFrame) { scoped_ptr<ReadableFakeWebSocketStream> stream( new ReadableFakeWebSocketStream); static const InitFrame frames[] = { - {NOT_FINAL_FRAME, WebSocketFrameHeader::kOpCodePong, NOT_MASKED, ""}}; + {FINAL_FRAME, 0xF, NOT_MASKED, ""}}; stream->PrepareReadFrames(ReadableFakeWebSocketStream::SYNC, OK, frames); set_stream(stream.Pass()); - deleting_ = EVENT_ON_DROP_CHANNEL; + deleting_ = EVENT_ON_FAIL_CHANNEL; CreateChannelAndConnectSuccessfully(); EXPECT_EQ(NULL, channel_.get()); @@ -1080,10 +1254,10 @@ TEST_F(WebSocketChannelDeletingTest, FailChannelDueToBadControlFrameNull) { scoped_ptr<ReadableFakeWebSocketStream> stream( new ReadableFakeWebSocketStream); static const InitFrame frames[] = { - {NOT_FINAL_FRAME, WebSocketFrameHeader::kOpCodePong, NOT_MASKED, NULL}}; + {FINAL_FRAME, 0xF, NOT_MASKED, NULL}}; stream->PrepareReadFrames(ReadableFakeWebSocketStream::SYNC, OK, frames); set_stream(stream.Pass()); - deleting_ = EVENT_ON_DROP_CHANNEL; + deleting_ = EVENT_ON_FAIL_CHANNEL; CreateChannelAndConnectSuccessfully(); EXPECT_EQ(NULL, channel_.get()); @@ -1093,12 +1267,12 @@ TEST_F(WebSocketChannelDeletingTest, FailChannelDueToPongAfterClose) { scoped_ptr<ReadableFakeWebSocketStream> stream( new ReadableFakeWebSocketStream); static const InitFrame frames[] = { - {FINAL_FRAME, WebSocketFrameHeader::kOpCodeClose, NOT_MASKED, - CLOSE_DATA(NORMAL_CLOSURE, "Success")}, - {FINAL_FRAME, WebSocketFrameHeader::kOpCodePong, NOT_MASKED, ""}}; + {FINAL_FRAME, WebSocketFrameHeader::kOpCodeClose, NOT_MASKED, + CLOSE_DATA(NORMAL_CLOSURE, "Success")}, + {FINAL_FRAME, WebSocketFrameHeader::kOpCodePong, NOT_MASKED, ""}}; stream->PrepareReadFrames(ReadableFakeWebSocketStream::SYNC, OK, frames); set_stream(stream.Pass()); - deleting_ = EVENT_ON_DROP_CHANNEL; + deleting_ = EVENT_ON_FAIL_CHANNEL; CreateChannelAndConnectSuccessfully(); EXPECT_EQ(NULL, channel_.get()); @@ -1108,12 +1282,12 @@ TEST_F(WebSocketChannelDeletingTest, FailChannelDueToPongAfterCloseNull) { scoped_ptr<ReadableFakeWebSocketStream> stream( new ReadableFakeWebSocketStream); static const InitFrame frames[] = { - {FINAL_FRAME, WebSocketFrameHeader::kOpCodeClose, NOT_MASKED, - CLOSE_DATA(NORMAL_CLOSURE, "Success")}, - {FINAL_FRAME, WebSocketFrameHeader::kOpCodePong, NOT_MASKED, NULL}}; + {FINAL_FRAME, WebSocketFrameHeader::kOpCodeClose, NOT_MASKED, + CLOSE_DATA(NORMAL_CLOSURE, "Success")}, + {FINAL_FRAME, WebSocketFrameHeader::kOpCodePong, NOT_MASKED, NULL}}; stream->PrepareReadFrames(ReadableFakeWebSocketStream::SYNC, OK, frames); set_stream(stream.Pass()); - deleting_ = EVENT_ON_DROP_CHANNEL; + deleting_ = EVENT_ON_FAIL_CHANNEL; CreateChannelAndConnectSuccessfully(); EXPECT_EQ(NULL, channel_.get()); @@ -1125,7 +1299,7 @@ TEST_F(WebSocketChannelDeletingTest, FailChannelDueToUnknownOpCode) { static const InitFrame frames[] = {{FINAL_FRAME, 0x7, NOT_MASKED, ""}}; stream->PrepareReadFrames(ReadableFakeWebSocketStream::SYNC, OK, frames); set_stream(stream.Pass()); - deleting_ = EVENT_ON_DROP_CHANNEL; + deleting_ = EVENT_ON_FAIL_CHANNEL; CreateChannelAndConnectSuccessfully(); EXPECT_EQ(NULL, channel_.get()); @@ -1137,7 +1311,21 @@ TEST_F(WebSocketChannelDeletingTest, FailChannelDueToUnknownOpCodeNull) { static const InitFrame frames[] = {{FINAL_FRAME, 0x7, NOT_MASKED, NULL}}; stream->PrepareReadFrames(ReadableFakeWebSocketStream::SYNC, OK, frames); set_stream(stream.Pass()); - deleting_ = EVENT_ON_DROP_CHANNEL; + deleting_ = EVENT_ON_FAIL_CHANNEL; + + CreateChannelAndConnectSuccessfully(); + EXPECT_EQ(NULL, channel_.get()); +} + +TEST_F(WebSocketChannelDeletingTest, FailChannelDueInvalidCloseReason) { + scoped_ptr<ReadableFakeWebSocketStream> stream( + new ReadableFakeWebSocketStream); + static const InitFrame frames[] = { + {FINAL_FRAME, WebSocketFrameHeader::kOpCodeClose, + NOT_MASKED, CLOSE_DATA(NORMAL_CLOSURE, "\xFF")}}; + stream->PrepareReadFrames(ReadableFakeWebSocketStream::SYNC, OK, frames); + set_stream(stream.Pass()); + deleting_ = EVENT_ON_FAIL_CHANNEL; CreateChannelAndConnectSuccessfully(); EXPECT_EQ(NULL, channel_.get()); @@ -1145,7 +1333,7 @@ TEST_F(WebSocketChannelDeletingTest, FailChannelDueToUnknownOpCodeNull) { TEST_F(WebSocketChannelEventInterfaceTest, ConnectSuccessReported) { // false means success. - EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, "")); + EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, "", "")); // OnFlowControl is always called immediately after connect to provide initial // quota to the renderer. EXPECT_CALL(*event_interface_, OnFlowControl(_)); @@ -1156,23 +1344,21 @@ TEST_F(WebSocketChannelEventInterfaceTest, ConnectSuccessReported) { } TEST_F(WebSocketChannelEventInterfaceTest, ConnectFailureReported) { - // true means failure. - EXPECT_CALL(*event_interface_, OnAddChannelResponse(true, "")); + EXPECT_CALL(*event_interface_, OnFailChannel("hello")); CreateChannelAndConnect(); - connect_data_.creator.connect_delegate->OnFailure( - kWebSocketErrorNoStatusReceived); + connect_data_.creator.connect_delegate->OnFailure("hello"); } TEST_F(WebSocketChannelEventInterfaceTest, NonWebSocketSchemeRejected) { - EXPECT_CALL(*event_interface_, OnAddChannelResponse(true, "")); + EXPECT_CALL(*event_interface_, OnAddChannelResponse(true, "", "")); connect_data_.socket_url = GURL("http://www.google.com/"); CreateChannelAndConnect(); } TEST_F(WebSocketChannelEventInterfaceTest, ProtocolPassed) { - EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, "Bob")); + EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, "Bob", "")); EXPECT_CALL(*event_interface_, OnFlowControl(_)); CreateChannelAndConnect(); @@ -1181,6 +1367,17 @@ TEST_F(WebSocketChannelEventInterfaceTest, ProtocolPassed) { scoped_ptr<WebSocketStream>(new FakeWebSocketStream("Bob", ""))); } +TEST_F(WebSocketChannelEventInterfaceTest, ExtensionsPassed) { + EXPECT_CALL(*event_interface_, + OnAddChannelResponse(false, "", "extension1, extension2")); + EXPECT_CALL(*event_interface_, OnFlowControl(_)); + + CreateChannelAndConnect(); + + connect_data_.creator.connect_delegate->OnSuccess(scoped_ptr<WebSocketStream>( + new FakeWebSocketStream("", "extension1, extension2"))); +} + // The first frames from the server can arrive together with the handshake, in // which case they will be available as soon as ReadFrames() is called the first // time. @@ -1193,7 +1390,7 @@ TEST_F(WebSocketChannelEventInterfaceTest, DataLeftFromHandshake) { set_stream(stream.Pass()); { InSequence s; - EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _)); + EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _, _)); EXPECT_CALL(*event_interface_, OnFlowControl(_)); EXPECT_CALL( *event_interface_, @@ -1218,12 +1415,13 @@ TEST_F(WebSocketChannelEventInterfaceTest, CloseAfterHandshake) { set_stream(stream.Pass()); { InSequence s; - EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _)); + EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _, _)); EXPECT_CALL(*event_interface_, OnFlowControl(_)); EXPECT_CALL(*event_interface_, OnClosingHandshake()); - EXPECT_CALL(*event_interface_, - OnDropChannel(kWebSocketErrorInternalServerError, - "Internal Server Error")); + EXPECT_CALL( + *event_interface_, + OnDropChannel( + true, kWebSocketErrorInternalServerError, "Internal Server Error")); } CreateChannelAndConnectSuccessfully(); @@ -1239,10 +1437,10 @@ TEST_F(WebSocketChannelEventInterfaceTest, ConnectionCloseAfterHandshake) { set_stream(stream.Pass()); { InSequence s; - EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _)); + EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _, _)); EXPECT_CALL(*event_interface_, OnFlowControl(_)); EXPECT_CALL(*event_interface_, - OnDropChannel(kWebSocketErrorAbnormalClosure, _)); + OnDropChannel(false, kWebSocketErrorAbnormalClosure, _)); } CreateChannelAndConnectSuccessfully(); @@ -1260,7 +1458,7 @@ TEST_F(WebSocketChannelEventInterfaceTest, NormalAsyncRead) { set_stream(stream.Pass()); { InSequence s; - EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _)); + EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _, _)); EXPECT_CALL(*event_interface_, OnFlowControl(_)); EXPECT_CALL(checkpoint, Call(1)); EXPECT_CALL( @@ -1290,7 +1488,7 @@ TEST_F(WebSocketChannelEventInterfaceTest, AsyncThenSyncRead) { set_stream(stream.Pass()); { InSequence s; - EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _)); + EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _, _)); EXPECT_CALL(*event_interface_, OnFlowControl(_)); EXPECT_CALL( *event_interface_, @@ -1332,7 +1530,7 @@ TEST_F(WebSocketChannelEventInterfaceTest, FragmentedMessage) { set_stream(stream.Pass()); { InSequence s; - EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _)); + EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _, _)); EXPECT_CALL(*event_interface_, OnFlowControl(_)); EXPECT_CALL( *event_interface_, @@ -1368,7 +1566,7 @@ TEST_F(WebSocketChannelEventInterfaceTest, NullMessage) { {FINAL_FRAME, WebSocketFrameHeader::kOpCodeText, NOT_MASKED, NULL}}; stream->PrepareReadFrames(ReadableFakeWebSocketStream::SYNC, OK, frames); set_stream(stream.Pass()); - EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _)); + EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _, _)); EXPECT_CALL(*event_interface_, OnFlowControl(_)); EXPECT_CALL( *event_interface_, @@ -1376,29 +1574,6 @@ TEST_F(WebSocketChannelEventInterfaceTest, NullMessage) { CreateChannelAndConnectSuccessfully(); } -// A control frame is not permitted to be split into multiple frames. RFC6455 -// 5.5 "All control frames ... MUST NOT be fragmented." -TEST_F(WebSocketChannelEventInterfaceTest, MultiFrameControlMessageIsRejected) { - scoped_ptr<ReadableFakeWebSocketStream> stream( - new ReadableFakeWebSocketStream); - static const InitFrame frames[] = { - {NOT_FINAL_FRAME, WebSocketFrameHeader::kOpCodePing, NOT_MASKED, "Pi"}, - {FINAL_FRAME, WebSocketFrameHeader::kOpCodeContinuation, - NOT_MASKED, "ng"}}; - stream->PrepareReadFrames(ReadableFakeWebSocketStream::ASYNC, OK, frames); - set_stream(stream.Pass()); - { - InSequence s; - EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _)); - EXPECT_CALL(*event_interface_, OnFlowControl(_)); - EXPECT_CALL(*event_interface_, - OnDropChannel(kWebSocketErrorProtocolError, _)); - } - - CreateChannelAndConnectSuccessfully(); - base::MessageLoop::current()->RunUntilIdle(); -} - // Connection closed by the remote host without a closing handshake. TEST_F(WebSocketChannelEventInterfaceTest, AsyncAbnormalClosure) { scoped_ptr<ReadableFakeWebSocketStream> stream( @@ -1408,10 +1583,10 @@ TEST_F(WebSocketChannelEventInterfaceTest, AsyncAbnormalClosure) { set_stream(stream.Pass()); { InSequence s; - EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _)); + EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _, _)); EXPECT_CALL(*event_interface_, OnFlowControl(_)); EXPECT_CALL(*event_interface_, - OnDropChannel(kWebSocketErrorAbnormalClosure, _)); + OnDropChannel(false, kWebSocketErrorAbnormalClosure, _)); } CreateChannelAndConnectSuccessfully(); @@ -1427,10 +1602,10 @@ TEST_F(WebSocketChannelEventInterfaceTest, ConnectionReset) { set_stream(stream.Pass()); { InSequence s; - EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _)); + EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _, _)); EXPECT_CALL(*event_interface_, OnFlowControl(_)); EXPECT_CALL(*event_interface_, - OnDropChannel(kWebSocketErrorAbnormalClosure, _)); + OnDropChannel(false, kWebSocketErrorAbnormalClosure, _)); } CreateChannelAndConnectSuccessfully(); @@ -1448,10 +1623,12 @@ TEST_F(WebSocketChannelEventInterfaceTest, MaskedFramesAreRejected) { set_stream(stream.Pass()); { InSequence s; - EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _)); + EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _, _)); EXPECT_CALL(*event_interface_, OnFlowControl(_)); - EXPECT_CALL(*event_interface_, - OnDropChannel(kWebSocketErrorProtocolError, _)); + EXPECT_CALL( + *event_interface_, + OnFailChannel( + "A server must not mask any frames that it sends to the client.")); } CreateChannelAndConnectSuccessfully(); @@ -1469,10 +1646,10 @@ TEST_F(WebSocketChannelEventInterfaceTest, UnknownOpCodeIsRejected) { set_stream(stream.Pass()); { InSequence s; - EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _)); + EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _, _)); EXPECT_CALL(*event_interface_, OnFlowControl(_)); EXPECT_CALL(*event_interface_, - OnDropChannel(kWebSocketErrorProtocolError, _)); + OnFailChannel("Unrecognized frame opcode: 4")); } CreateChannelAndConnectSuccessfully(); @@ -1500,7 +1677,7 @@ TEST_F(WebSocketChannelEventInterfaceTest, ControlFrameInDataMessage) { set_stream(stream.Pass()); { InSequence s; - EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _)); + EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _, _)); EXPECT_CALL(*event_interface_, OnFlowControl(_)); EXPECT_CALL( *event_interface_, @@ -1525,7 +1702,7 @@ TEST_F(WebSocketChannelEventInterfaceTest, PongWithNullData) { {FINAL_FRAME, WebSocketFrameHeader::kOpCodePong, NOT_MASKED, NULL}}; stream->PrepareReadFrames(ReadableFakeWebSocketStream::ASYNC, OK, frames); set_stream(stream.Pass()); - EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _)); + EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _, _)); EXPECT_CALL(*event_interface_, OnFlowControl(_)); CreateChannelAndConnectSuccessfully(); @@ -1545,10 +1722,12 @@ TEST_F(WebSocketChannelEventInterfaceTest, FrameAfterInvalidFrame) { set_stream(stream.Pass()); { InSequence s; - EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _)); + EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _, _)); EXPECT_CALL(*event_interface_, OnFlowControl(_)); - EXPECT_CALL(*event_interface_, - OnDropChannel(kWebSocketErrorProtocolError, _)); + EXPECT_CALL( + *event_interface_, + OnFailChannel( + "A server must not mask any frames that it sends to the client.")); } CreateChannelAndConnectSuccessfully(); @@ -1561,7 +1740,7 @@ TEST_F(WebSocketChannelEventInterfaceTest, SmallWriteDoesntUpdateQuota) { set_stream(make_scoped_ptr(new WriteableFakeWebSocketStream)); { InSequence s; - EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _)); + EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _, _)); EXPECT_CALL(*event_interface_, OnFlowControl(_)); } @@ -1578,7 +1757,7 @@ TEST_F(WebSocketChannelEventInterfaceTest, LargeWriteUpdatesQuota) { Checkpoint checkpoint; { InSequence s; - EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _)); + EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _, _)); EXPECT_CALL(*event_interface_, OnFlowControl(_)); EXPECT_CALL(checkpoint, Call(1)); EXPECT_CALL(*event_interface_, OnFlowControl(_)); @@ -1599,7 +1778,7 @@ TEST_F(WebSocketChannelEventInterfaceTest, QuotaReallyIsRefreshed) { Checkpoint checkpoint; { InSequence s; - EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _)); + EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _, _)); EXPECT_CALL(*event_interface_, OnFlowControl(_)); EXPECT_CALL(checkpoint, Call(1)); EXPECT_CALL(*event_interface_, OnFlowControl(_)); @@ -1629,10 +1808,9 @@ TEST_F(WebSocketChannelEventInterfaceTest, WriteOverQuotaIsRejected) { set_stream(make_scoped_ptr(new WriteableFakeWebSocketStream)); { InSequence s; - EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _)); + EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _, _)); EXPECT_CALL(*event_interface_, OnFlowControl(kDefaultInitialQuota)); - EXPECT_CALL(*event_interface_, - OnDropChannel(kWebSocketMuxErrorSendQuotaViolation, _)); + EXPECT_CALL(*event_interface_, OnFailChannel("Send quota exceeded")); } CreateChannelAndConnectSuccessfully(); @@ -1647,11 +1825,11 @@ TEST_F(WebSocketChannelEventInterfaceTest, FailedWrite) { Checkpoint checkpoint; { InSequence s; - EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _)); + EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _, _)); EXPECT_CALL(*event_interface_, OnFlowControl(_)); EXPECT_CALL(checkpoint, Call(1)); EXPECT_CALL(*event_interface_, - OnDropChannel(kWebSocketErrorAbnormalClosure, _)); + OnDropChannel(false, kWebSocketErrorAbnormalClosure, _)); EXPECT_CALL(checkpoint, Call(2)); } @@ -1667,10 +1845,10 @@ TEST_F(WebSocketChannelEventInterfaceTest, SendCloseDropsChannel) { set_stream(make_scoped_ptr(new EchoeyFakeWebSocketStream)); { InSequence s; - EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _)); + EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _, _)); EXPECT_CALL(*event_interface_, OnFlowControl(_)); EXPECT_CALL(*event_interface_, - OnDropChannel(kWebSocketNormalClosure, "Fred")); + OnDropChannel(true, kWebSocketNormalClosure, "Fred")); } CreateChannelAndConnectSuccessfully(); @@ -1679,15 +1857,25 @@ TEST_F(WebSocketChannelEventInterfaceTest, SendCloseDropsChannel) { base::MessageLoop::current()->RunUntilIdle(); } +// StartClosingHandshake() also works before connection completes, and calls +// OnDropChannel. +TEST_F(WebSocketChannelEventInterfaceTest, CloseDuringConnection) { + EXPECT_CALL(*event_interface_, + OnDropChannel(false, kWebSocketErrorAbnormalClosure, "")); + + CreateChannelAndConnect(); + channel_->StartClosingHandshake(kWebSocketNormalClosure, "Joe"); +} + // OnDropChannel() is only called once when a write() on the socket triggers a // connection reset. TEST_F(WebSocketChannelEventInterfaceTest, OnDropChannelCalledOnce) { set_stream(make_scoped_ptr(new ResetOnWriteFakeWebSocketStream)); - EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _)); + EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _, _)); EXPECT_CALL(*event_interface_, OnFlowControl(_)); EXPECT_CALL(*event_interface_, - OnDropChannel(kWebSocketErrorAbnormalClosure, "Abnormal Closure")) + OnDropChannel(false, kWebSocketErrorAbnormalClosure, "")) .Times(1); CreateChannelAndConnectSuccessfully(); @@ -1707,11 +1895,11 @@ TEST_F(WebSocketChannelEventInterfaceTest, CloseWithNoPayloadGivesStatus1005) { stream->PrepareReadFramesError(ReadableFakeWebSocketStream::SYNC, ERR_CONNECTION_CLOSED); set_stream(stream.Pass()); - EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _)); + EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _, _)); EXPECT_CALL(*event_interface_, OnFlowControl(_)); EXPECT_CALL(*event_interface_, OnClosingHandshake()); EXPECT_CALL(*event_interface_, - OnDropChannel(kWebSocketErrorNoStatusReceived, _)); + OnDropChannel(true, kWebSocketErrorNoStatusReceived, _)); CreateChannelAndConnectSuccessfully(); } @@ -1727,28 +1915,27 @@ TEST_F(WebSocketChannelEventInterfaceTest, stream->PrepareReadFramesError(ReadableFakeWebSocketStream::SYNC, ERR_CONNECTION_CLOSED); set_stream(stream.Pass()); - EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _)); + EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _, _)); EXPECT_CALL(*event_interface_, OnFlowControl(_)); EXPECT_CALL(*event_interface_, OnClosingHandshake()); EXPECT_CALL(*event_interface_, - OnDropChannel(kWebSocketErrorNoStatusReceived, _)); + OnDropChannel(true, kWebSocketErrorNoStatusReceived, _)); CreateChannelAndConnectSuccessfully(); } -// If ReadFrames() returns ERR_WS_PROTOCOL_ERROR, then -// kWebSocketErrorProtocolError must be sent to the renderer. +// If ReadFrames() returns ERR_WS_PROTOCOL_ERROR, then the connection must be +// failed. TEST_F(WebSocketChannelEventInterfaceTest, SyncProtocolErrorGivesStatus1002) { scoped_ptr<ReadableFakeWebSocketStream> stream( new ReadableFakeWebSocketStream); stream->PrepareReadFramesError(ReadableFakeWebSocketStream::SYNC, ERR_WS_PROTOCOL_ERROR); set_stream(stream.Pass()); - EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _)); + EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _, _)); EXPECT_CALL(*event_interface_, OnFlowControl(_)); - EXPECT_CALL(*event_interface_, - OnDropChannel(kWebSocketErrorProtocolError, _)); + EXPECT_CALL(*event_interface_, OnFailChannel("Invalid frame header")); CreateChannelAndConnectSuccessfully(); } @@ -1760,16 +1947,195 @@ TEST_F(WebSocketChannelEventInterfaceTest, AsyncProtocolErrorGivesStatus1002) { stream->PrepareReadFramesError(ReadableFakeWebSocketStream::ASYNC, ERR_WS_PROTOCOL_ERROR); set_stream(stream.Pass()); - EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _)); + EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _, _)); EXPECT_CALL(*event_interface_, OnFlowControl(_)); - EXPECT_CALL(*event_interface_, - OnDropChannel(kWebSocketErrorProtocolError, _)); + EXPECT_CALL(*event_interface_, OnFailChannel("Invalid frame header")); + + CreateChannelAndConnectSuccessfully(); + base::MessageLoop::current()->RunUntilIdle(); +} + +TEST_F(WebSocketChannelEventInterfaceTest, StartHandshakeRequest) { + { + InSequence s; + EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _, _)); + EXPECT_CALL(*event_interface_, OnFlowControl(_)); + EXPECT_CALL(*event_interface_, OnStartOpeningHandshakeCalled()); + } + + CreateChannelAndConnectSuccessfully(); + + scoped_ptr<WebSocketHandshakeRequestInfo> request_info( + new WebSocketHandshakeRequestInfo(GURL("ws://www.example.com/"), + base::Time())); + connect_data_.creator.connect_delegate->OnStartOpeningHandshake( + request_info.Pass()); + + base::MessageLoop::current()->RunUntilIdle(); +} + +TEST_F(WebSocketChannelEventInterfaceTest, FinishHandshakeRequest) { + { + InSequence s; + EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _, _)); + EXPECT_CALL(*event_interface_, OnFlowControl(_)); + EXPECT_CALL(*event_interface_, OnFinishOpeningHandshakeCalled()); + } CreateChannelAndConnectSuccessfully(); + + scoped_refptr<HttpResponseHeaders> response_headers( + new HttpResponseHeaders("")); + scoped_ptr<WebSocketHandshakeResponseInfo> response_info( + new WebSocketHandshakeResponseInfo(GURL("ws://www.example.com/"), + 200, + "OK", + response_headers, + base::Time())); + connect_data_.creator.connect_delegate->OnFinishOpeningHandshake( + response_info.Pass()); base::MessageLoop::current()->RunUntilIdle(); } +TEST_F(WebSocketChannelEventInterfaceTest, FailJustAfterHandshake) { + { + InSequence s; + EXPECT_CALL(*event_interface_, OnStartOpeningHandshakeCalled()); + EXPECT_CALL(*event_interface_, OnFinishOpeningHandshakeCalled()); + EXPECT_CALL(*event_interface_, OnFailChannel("bye")); + } + + CreateChannelAndConnect(); + + WebSocketStream::ConnectDelegate* connect_delegate = + connect_data_.creator.connect_delegate.get(); + GURL url("ws://www.example.com/"); + scoped_ptr<WebSocketHandshakeRequestInfo> request_info( + new WebSocketHandshakeRequestInfo(url, base::Time())); + scoped_refptr<HttpResponseHeaders> response_headers( + new HttpResponseHeaders("")); + scoped_ptr<WebSocketHandshakeResponseInfo> response_info( + new WebSocketHandshakeResponseInfo(url, + 200, + "OK", + response_headers, + base::Time())); + connect_delegate->OnStartOpeningHandshake(request_info.Pass()); + connect_delegate->OnFinishOpeningHandshake(response_info.Pass()); + + connect_delegate->OnFailure("bye"); + base::MessageLoop::current()->RunUntilIdle(); +} + +// Any frame after close is invalid. This test uses a Text frame. See also +// test "PingAfterCloseIfRejected". +TEST_F(WebSocketChannelEventInterfaceTest, DataAfterCloseIsRejected) { + scoped_ptr<ReadableFakeWebSocketStream> stream( + new ReadableFakeWebSocketStream); + static const InitFrame frames[] = { + {FINAL_FRAME, WebSocketFrameHeader::kOpCodeClose, NOT_MASKED, + CLOSE_DATA(NORMAL_CLOSURE, "OK")}, + {FINAL_FRAME, WebSocketFrameHeader::kOpCodeText, NOT_MASKED, "Payload"}}; + stream->PrepareReadFrames(ReadableFakeWebSocketStream::SYNC, OK, frames); + set_stream(stream.Pass()); + EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _, _)); + EXPECT_CALL(*event_interface_, OnFlowControl(_)); + + { + InSequence s; + EXPECT_CALL(*event_interface_, OnClosingHandshake()); + EXPECT_CALL(*event_interface_, + OnFailChannel("Data frame received after close")); + } + + CreateChannelAndConnectSuccessfully(); +} + +// A Close frame with a one-byte payload elicits a specific console error +// message. +TEST_F(WebSocketChannelEventInterfaceTest, OneByteClosePayloadMessage) { + scoped_ptr<ReadableFakeWebSocketStream> stream( + new ReadableFakeWebSocketStream); + static const InitFrame frames[] = { + {FINAL_FRAME, WebSocketFrameHeader::kOpCodeClose, NOT_MASKED, "\x03"}}; + stream->PrepareReadFrames(ReadableFakeWebSocketStream::SYNC, OK, frames); + set_stream(stream.Pass()); + EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _, _)); + EXPECT_CALL(*event_interface_, OnFlowControl(_)); + EXPECT_CALL( + *event_interface_, + OnFailChannel( + "Received a broken close frame containing an invalid size body.")); + + CreateChannelAndConnectSuccessfully(); +} + +// A Close frame with a reserved status code also elicits a specific console +// error message. +TEST_F(WebSocketChannelEventInterfaceTest, ClosePayloadReservedStatusMessage) { + scoped_ptr<ReadableFakeWebSocketStream> stream( + new ReadableFakeWebSocketStream); + static const InitFrame frames[] = { + {FINAL_FRAME, WebSocketFrameHeader::kOpCodeClose, + NOT_MASKED, CLOSE_DATA(ABNORMAL_CLOSURE, "Not valid on wire")}}; + stream->PrepareReadFrames(ReadableFakeWebSocketStream::SYNC, OK, frames); + set_stream(stream.Pass()); + EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _, _)); + EXPECT_CALL(*event_interface_, OnFlowControl(_)); + EXPECT_CALL( + *event_interface_, + OnFailChannel( + "Received a broken close frame containing a reserved status code.")); + + CreateChannelAndConnectSuccessfully(); +} + +// A Close frame with invalid UTF-8 also elicits a specific console error +// message. +TEST_F(WebSocketChannelEventInterfaceTest, ClosePayloadInvalidReason) { + scoped_ptr<ReadableFakeWebSocketStream> stream( + new ReadableFakeWebSocketStream); + static const InitFrame frames[] = { + {FINAL_FRAME, WebSocketFrameHeader::kOpCodeClose, + NOT_MASKED, CLOSE_DATA(NORMAL_CLOSURE, "\xFF")}}; + stream->PrepareReadFrames(ReadableFakeWebSocketStream::SYNC, OK, frames); + set_stream(stream.Pass()); + EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _, _)); + EXPECT_CALL(*event_interface_, OnFlowControl(_)); + EXPECT_CALL( + *event_interface_, + OnFailChannel( + "Received a broken close frame containing invalid UTF-8.")); + + CreateChannelAndConnectSuccessfully(); +} + +// The reserved bits must all be clear on received frames. Extensions should +// clear the bits when they are set correctly before passing on the frame. +TEST_F(WebSocketChannelEventInterfaceTest, ReservedBitsMustNotBeSet) { + scoped_ptr<ReadableFakeWebSocketStream> stream( + new ReadableFakeWebSocketStream); + static const InitFrame frames[] = { + {FINAL_FRAME, WebSocketFrameHeader::kOpCodeText, + NOT_MASKED, "sakana"}}; + // It is not worth adding support for reserved bits to InitFrame just for this + // one test, so set the bit manually. + ScopedVector<WebSocketFrame> raw_frames = CreateFrameVector(frames); + raw_frames[0]->header.reserved1 = true; + stream->PrepareRawReadFrames( + ReadableFakeWebSocketStream::SYNC, OK, raw_frames.Pass()); + set_stream(stream.Pass()); + EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _, _)); + EXPECT_CALL(*event_interface_, OnFlowControl(_)); + EXPECT_CALL(*event_interface_, + OnFailChannel( + "One or more reserved bits are on: reserved1 = 1, " + "reserved2 = 0, reserved3 = 0")); + + CreateChannelAndConnectSuccessfully(); +} + // The closing handshake times out and sends an OnDropChannel event if no // response to the client Close message is received. TEST_F(WebSocketChannelEventInterfaceTest, @@ -1779,7 +2145,7 @@ TEST_F(WebSocketChannelEventInterfaceTest, stream->PrepareReadFramesError(ReadableFakeWebSocketStream::SYNC, ERR_IO_PENDING); set_stream(stream.Pass()); - EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _)); + EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _, _)); EXPECT_CALL(*event_interface_, OnFlowControl(_)); // This checkpoint object verifies that the OnDropChannel message comes after // the timeout. @@ -1789,7 +2155,7 @@ TEST_F(WebSocketChannelEventInterfaceTest, InSequence s; EXPECT_CALL(checkpoint, Call(1)); EXPECT_CALL(*event_interface_, - OnDropChannel(kWebSocketErrorAbnormalClosure, _)) + OnDropChannel(false, kWebSocketErrorAbnormalClosure, _)) .WillOnce(InvokeClosureReturnDeleted(completion.closure())); } CreateChannelAndConnectSuccessfully(); @@ -1814,7 +2180,7 @@ TEST_F(WebSocketChannelEventInterfaceTest, NOT_MASKED, CLOSE_DATA(NORMAL_CLOSURE, "OK")}}; stream->PrepareReadFrames(ReadableFakeWebSocketStream::ASYNC, OK, frames); set_stream(stream.Pass()); - EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _)); + EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _, _)); EXPECT_CALL(*event_interface_, OnFlowControl(_)); Checkpoint checkpoint; TestClosure completion; @@ -1823,7 +2189,7 @@ TEST_F(WebSocketChannelEventInterfaceTest, EXPECT_CALL(checkpoint, Call(1)); EXPECT_CALL(*event_interface_, OnClosingHandshake()); EXPECT_CALL(*event_interface_, - OnDropChannel(kWebSocketErrorAbnormalClosure, _)) + OnDropChannel(false, kWebSocketErrorAbnormalClosure, _)) .WillOnce(InvokeClosureReturnDeleted(completion.closure())); } CreateChannelAndConnectSuccessfully(); @@ -1833,6 +2199,280 @@ TEST_F(WebSocketChannelEventInterfaceTest, completion.WaitForResult(); } +// The renderer should provide us with some quota immediately, and then +// WebSocketChannel calls ReadFrames as soon as the stream is available. +TEST_F(WebSocketChannelStreamTest, FlowControlEarly) { + Checkpoint checkpoint; + EXPECT_CALL(*mock_stream_, GetSubProtocol()).Times(AnyNumber()); + EXPECT_CALL(*mock_stream_, GetExtensions()).Times(AnyNumber()); + { + InSequence s; + EXPECT_CALL(checkpoint, Call(1)); + EXPECT_CALL(*mock_stream_, ReadFrames(_, _)) + .WillOnce(Return(ERR_IO_PENDING)); + EXPECT_CALL(checkpoint, Call(2)); + } + + set_stream(mock_stream_.Pass()); + CreateChannelAndConnect(); + channel_->SendFlowControl(kPlentyOfQuota); + checkpoint.Call(1); + connect_data_.creator.connect_delegate->OnSuccess(stream_.Pass()); + checkpoint.Call(2); +} + +// If for some reason the connect succeeds before the renderer sends us quota, +// we shouldn't call ReadFrames() immediately. +// TODO(ricea): Actually we should call ReadFrames() with a small limit so we +// can still handle control frames. This should be done once we have any API to +// expose quota to the lower levels. +TEST_F(WebSocketChannelStreamTest, FlowControlLate) { + Checkpoint checkpoint; + EXPECT_CALL(*mock_stream_, GetSubProtocol()).Times(AnyNumber()); + EXPECT_CALL(*mock_stream_, GetExtensions()).Times(AnyNumber()); + { + InSequence s; + EXPECT_CALL(checkpoint, Call(1)); + EXPECT_CALL(*mock_stream_, ReadFrames(_, _)) + .WillOnce(Return(ERR_IO_PENDING)); + EXPECT_CALL(checkpoint, Call(2)); + } + + set_stream(mock_stream_.Pass()); + CreateChannelAndConnect(); + connect_data_.creator.connect_delegate->OnSuccess(stream_.Pass()); + checkpoint.Call(1); + channel_->SendFlowControl(kPlentyOfQuota); + checkpoint.Call(2); +} + +// We should stop calling ReadFrames() when all quota is used. +TEST_F(WebSocketChannelStreamTest, FlowControlStopsReadFrames) { + static const InitFrame frames[] = { + {FINAL_FRAME, WebSocketFrameHeader::kOpCodeText, NOT_MASKED, "FOUR"}}; + + EXPECT_CALL(*mock_stream_, GetSubProtocol()).Times(AnyNumber()); + EXPECT_CALL(*mock_stream_, GetExtensions()).Times(AnyNumber()); + EXPECT_CALL(*mock_stream_, ReadFrames(_, _)) + .WillOnce(ReturnFrames(&frames)); + + set_stream(mock_stream_.Pass()); + CreateChannelAndConnect(); + channel_->SendFlowControl(4); + connect_data_.creator.connect_delegate->OnSuccess(stream_.Pass()); +} + +// Providing extra quota causes ReadFrames() to be called again. +TEST_F(WebSocketChannelStreamTest, FlowControlStartsWithMoreQuota) { + static const InitFrame frames[] = { + {FINAL_FRAME, WebSocketFrameHeader::kOpCodeText, NOT_MASKED, "FOUR"}}; + Checkpoint checkpoint; + + EXPECT_CALL(*mock_stream_, GetSubProtocol()).Times(AnyNumber()); + EXPECT_CALL(*mock_stream_, GetExtensions()).Times(AnyNumber()); + { + InSequence s; + EXPECT_CALL(*mock_stream_, ReadFrames(_, _)) + .WillOnce(ReturnFrames(&frames)); + EXPECT_CALL(checkpoint, Call(1)); + EXPECT_CALL(*mock_stream_, ReadFrames(_, _)) + .WillOnce(Return(ERR_IO_PENDING)); + } + + set_stream(mock_stream_.Pass()); + CreateChannelAndConnect(); + channel_->SendFlowControl(4); + connect_data_.creator.connect_delegate->OnSuccess(stream_.Pass()); + checkpoint.Call(1); + channel_->SendFlowControl(4); +} + +// ReadFrames() isn't called again until all pending data has been passed to +// the renderer. +TEST_F(WebSocketChannelStreamTest, ReadFramesNotCalledUntilQuotaAvailable) { + static const InitFrame frames[] = { + {FINAL_FRAME, WebSocketFrameHeader::kOpCodeText, NOT_MASKED, "FOUR"}}; + Checkpoint checkpoint; + + EXPECT_CALL(*mock_stream_, GetSubProtocol()).Times(AnyNumber()); + EXPECT_CALL(*mock_stream_, GetExtensions()).Times(AnyNumber()); + { + InSequence s; + EXPECT_CALL(*mock_stream_, ReadFrames(_, _)) + .WillOnce(ReturnFrames(&frames)); + EXPECT_CALL(checkpoint, Call(1)); + EXPECT_CALL(checkpoint, Call(2)); + EXPECT_CALL(*mock_stream_, ReadFrames(_, _)) + .WillOnce(Return(ERR_IO_PENDING)); + } + + set_stream(mock_stream_.Pass()); + CreateChannelAndConnect(); + channel_->SendFlowControl(2); + connect_data_.creator.connect_delegate->OnSuccess(stream_.Pass()); + checkpoint.Call(1); + channel_->SendFlowControl(2); + checkpoint.Call(2); + channel_->SendFlowControl(2); +} + +// A message that needs to be split into frames to fit within quota should +// maintain correct semantics. +TEST_F(WebSocketChannelFlowControlTest, SingleFrameMessageSplitSync) { + scoped_ptr<ReadableFakeWebSocketStream> stream( + new ReadableFakeWebSocketStream); + static const InitFrame frames[] = { + {FINAL_FRAME, WebSocketFrameHeader::kOpCodeText, NOT_MASKED, "FOUR"}}; + stream->PrepareReadFrames(ReadableFakeWebSocketStream::SYNC, OK, frames); + set_stream(stream.Pass()); + { + InSequence s; + EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _, _)); + EXPECT_CALL(*event_interface_, OnFlowControl(_)); + EXPECT_CALL( + *event_interface_, + OnDataFrame(false, WebSocketFrameHeader::kOpCodeText, AsVector("FO"))); + EXPECT_CALL( + *event_interface_, + OnDataFrame( + false, WebSocketFrameHeader::kOpCodeContinuation, AsVector("U"))); + EXPECT_CALL( + *event_interface_, + OnDataFrame( + true, WebSocketFrameHeader::kOpCodeContinuation, AsVector("R"))); + } + + CreateChannelAndConnectWithQuota(2); + channel_->SendFlowControl(1); + channel_->SendFlowControl(1); +} + +// The code path for async messages is slightly different, so test it +// separately. +TEST_F(WebSocketChannelFlowControlTest, SingleFrameMessageSplitAsync) { + scoped_ptr<ReadableFakeWebSocketStream> stream( + new ReadableFakeWebSocketStream); + static const InitFrame frames[] = { + {FINAL_FRAME, WebSocketFrameHeader::kOpCodeText, NOT_MASKED, "FOUR"}}; + stream->PrepareReadFrames(ReadableFakeWebSocketStream::ASYNC, OK, frames); + set_stream(stream.Pass()); + Checkpoint checkpoint; + { + InSequence s; + EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _, _)); + EXPECT_CALL(*event_interface_, OnFlowControl(_)); + EXPECT_CALL(checkpoint, Call(1)); + EXPECT_CALL( + *event_interface_, + OnDataFrame(false, WebSocketFrameHeader::kOpCodeText, AsVector("FO"))); + EXPECT_CALL(checkpoint, Call(2)); + EXPECT_CALL( + *event_interface_, + OnDataFrame( + false, WebSocketFrameHeader::kOpCodeContinuation, AsVector("U"))); + EXPECT_CALL(checkpoint, Call(3)); + EXPECT_CALL( + *event_interface_, + OnDataFrame( + true, WebSocketFrameHeader::kOpCodeContinuation, AsVector("R"))); + } + + CreateChannelAndConnectWithQuota(2); + checkpoint.Call(1); + base::MessageLoop::current()->RunUntilIdle(); + checkpoint.Call(2); + channel_->SendFlowControl(1); + checkpoint.Call(3); + channel_->SendFlowControl(1); +} + +// A message split into multiple frames which is further split due to quota +// restrictions should stil be correct. +// TODO(ricea): The message ends up split into more frames than are strictly +// necessary. The complexity/performance tradeoffs here need further +// examination. +TEST_F(WebSocketChannelFlowControlTest, MultipleFrameSplit) { + scoped_ptr<ReadableFakeWebSocketStream> stream( + new ReadableFakeWebSocketStream); + static const InitFrame frames[] = { + {NOT_FINAL_FRAME, WebSocketFrameHeader::kOpCodeText, + NOT_MASKED, "FIRST FRAME IS 25 BYTES. "}, + {NOT_FINAL_FRAME, WebSocketFrameHeader::kOpCodeContinuation, + NOT_MASKED, "SECOND FRAME IS 26 BYTES. "}, + {FINAL_FRAME, WebSocketFrameHeader::kOpCodeContinuation, + NOT_MASKED, "FINAL FRAME IS 24 BYTES."}}; + stream->PrepareReadFrames(ReadableFakeWebSocketStream::SYNC, OK, frames); + set_stream(stream.Pass()); + { + InSequence s; + EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _, _)); + EXPECT_CALL(*event_interface_, OnFlowControl(_)); + EXPECT_CALL(*event_interface_, + OnDataFrame(false, + WebSocketFrameHeader::kOpCodeText, + AsVector("FIRST FRAME IS"))); + EXPECT_CALL(*event_interface_, + OnDataFrame(false, + WebSocketFrameHeader::kOpCodeContinuation, + AsVector(" 25 BYTES. "))); + EXPECT_CALL(*event_interface_, + OnDataFrame(false, + WebSocketFrameHeader::kOpCodeContinuation, + AsVector("SECOND FRAME IS 26 BYTES. "))); + EXPECT_CALL(*event_interface_, + OnDataFrame(false, + WebSocketFrameHeader::kOpCodeContinuation, + AsVector("FINAL "))); + EXPECT_CALL(*event_interface_, + OnDataFrame(true, + WebSocketFrameHeader::kOpCodeContinuation, + AsVector("FRAME IS 24 BYTES."))); + } + CreateChannelAndConnectWithQuota(14); + channel_->SendFlowControl(43); + channel_->SendFlowControl(32); +} + +// An empty message handled when we are out of quota must not be delivered +// out-of-order with respect to other messages. +TEST_F(WebSocketChannelFlowControlTest, EmptyMessageNoQuota) { + scoped_ptr<ReadableFakeWebSocketStream> stream( + new ReadableFakeWebSocketStream); + static const InitFrame frames[] = { + {FINAL_FRAME, WebSocketFrameHeader::kOpCodeText, + NOT_MASKED, "FIRST MESSAGE"}, + {FINAL_FRAME, WebSocketFrameHeader::kOpCodeText, + NOT_MASKED, NULL}, + {FINAL_FRAME, WebSocketFrameHeader::kOpCodeText, + NOT_MASKED, "THIRD MESSAGE"}}; + stream->PrepareReadFrames(ReadableFakeWebSocketStream::SYNC, OK, frames); + set_stream(stream.Pass()); + { + InSequence s; + EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _, _)); + EXPECT_CALL(*event_interface_, OnFlowControl(_)); + EXPECT_CALL(*event_interface_, + OnDataFrame(false, + WebSocketFrameHeader::kOpCodeText, + AsVector("FIRST "))); + EXPECT_CALL(*event_interface_, + OnDataFrame(true, + WebSocketFrameHeader::kOpCodeContinuation, + AsVector("MESSAGE"))); + EXPECT_CALL(*event_interface_, + OnDataFrame(true, + WebSocketFrameHeader::kOpCodeText, + AsVector(""))); + EXPECT_CALL(*event_interface_, + OnDataFrame(true, + WebSocketFrameHeader::kOpCodeText, + AsVector("THIRD MESSAGE"))); + } + + CreateChannelAndConnectWithQuota(6); + channel_->SendFlowControl(128); +} + // RFC6455 5.1 "a client MUST mask all frames that it sends to the server". // WebSocketChannel actually only sets the mask bit in the header, it doesn't // perform masking itself (not all transports actually use masking). @@ -1841,6 +2481,7 @@ TEST_F(WebSocketChannelStreamTest, SentFramesAreMasked) { {FINAL_FRAME, WebSocketFrameHeader::kOpCodeText, MASKED, "NEEDS MASKING"}}; EXPECT_CALL(*mock_stream_, GetSubProtocol()).Times(AnyNumber()); + EXPECT_CALL(*mock_stream_, GetExtensions()).Times(AnyNumber()); EXPECT_CALL(*mock_stream_, ReadFrames(_, _)).WillOnce(Return(ERR_IO_PENDING)); EXPECT_CALL(*mock_stream_, WriteFrames(EqualsFrames(expected), _)) .WillOnce(Return(OK)); @@ -1857,6 +2498,7 @@ TEST_F(WebSocketChannelStreamTest, NothingIsSentAfterClose) { {FINAL_FRAME, WebSocketFrameHeader::kOpCodeClose, MASKED, CLOSE_DATA(NORMAL_CLOSURE, "Success")}}; EXPECT_CALL(*mock_stream_, GetSubProtocol()).Times(AnyNumber()); + EXPECT_CALL(*mock_stream_, GetExtensions()).Times(AnyNumber()); EXPECT_CALL(*mock_stream_, ReadFrames(_, _)).WillOnce(Return(ERR_IO_PENDING)); EXPECT_CALL(*mock_stream_, WriteFrames(EqualsFrames(expected), _)) .WillOnce(Return(OK)); @@ -1877,6 +2519,7 @@ TEST_F(WebSocketChannelStreamTest, CloseIsEchoedBack) { {FINAL_FRAME, WebSocketFrameHeader::kOpCodeClose, MASKED, CLOSE_DATA(NORMAL_CLOSURE, "Close")}}; EXPECT_CALL(*mock_stream_, GetSubProtocol()).Times(AnyNumber()); + EXPECT_CALL(*mock_stream_, GetExtensions()).Times(AnyNumber()); EXPECT_CALL(*mock_stream_, ReadFrames(_, _)) .WillOnce(ReturnFrames(&frames)) .WillRepeatedly(Return(ERR_IO_PENDING)); @@ -1901,11 +2544,14 @@ TEST_F(WebSocketChannelStreamTest, CloseOnlySentOnce) { CompletionCallback read_callback; ScopedVector<WebSocketFrame>* frames = NULL; + // These are not interesting. + EXPECT_CALL(*mock_stream_, GetSubProtocol()).Times(AnyNumber()); + EXPECT_CALL(*mock_stream_, GetExtensions()).Times(AnyNumber()); + // Use a checkpoint to make the ordering of events clearer. Checkpoint checkpoint; { InSequence s; - EXPECT_CALL(*mock_stream_, GetSubProtocol()).Times(AnyNumber()); EXPECT_CALL(*mock_stream_, ReadFrames(_, _)) .WillOnce(DoAll(SaveArg<0>(&frames), SaveArg<1>(&read_callback), @@ -1935,9 +2581,10 @@ TEST_F(WebSocketChannelStreamTest, CloseOnlySentOnce) { TEST_F(WebSocketChannelStreamTest, InvalidCloseStatusCodeNotSent) { static const InitFrame expected[] = { {FINAL_FRAME, WebSocketFrameHeader::kOpCodeClose, - MASKED, CLOSE_DATA(SERVER_ERROR, "Internal Error")}}; + MASKED, CLOSE_DATA(SERVER_ERROR, "")}}; EXPECT_CALL(*mock_stream_, GetSubProtocol()).Times(AnyNumber()); + EXPECT_CALL(*mock_stream_, GetExtensions()).Times(AnyNumber()); EXPECT_CALL(*mock_stream_, ReadFrames(_, _)) .WillOnce(Return(ERR_IO_PENDING)); @@ -1952,9 +2599,10 @@ TEST_F(WebSocketChannelStreamTest, InvalidCloseStatusCodeNotSent) { TEST_F(WebSocketChannelStreamTest, LongCloseReasonNotSent) { static const InitFrame expected[] = { {FINAL_FRAME, WebSocketFrameHeader::kOpCodeClose, - MASKED, CLOSE_DATA(SERVER_ERROR, "Internal Error")}}; + MASKED, CLOSE_DATA(SERVER_ERROR, "")}}; EXPECT_CALL(*mock_stream_, GetSubProtocol()).Times(AnyNumber()); + EXPECT_CALL(*mock_stream_, GetExtensions()).Times(AnyNumber()); EXPECT_CALL(*mock_stream_, ReadFrames(_, _)) .WillOnce(Return(ERR_IO_PENDING)); @@ -1975,6 +2623,7 @@ TEST_F(WebSocketChannelStreamTest, Code1005IsNotEchoed) { static const InitFrame expected[] = { {FINAL_FRAME, WebSocketFrameHeader::kOpCodeClose, MASKED, ""}}; EXPECT_CALL(*mock_stream_, GetSubProtocol()).Times(AnyNumber()); + EXPECT_CALL(*mock_stream_, GetExtensions()).Times(AnyNumber()); EXPECT_CALL(*mock_stream_, ReadFrames(_, _)) .WillOnce(ReturnFrames(&frames)) .WillRepeatedly(Return(ERR_IO_PENDING)); @@ -1990,6 +2639,7 @@ TEST_F(WebSocketChannelStreamTest, Code1005IsNotEchoedNull) { static const InitFrame expected[] = { {FINAL_FRAME, WebSocketFrameHeader::kOpCodeClose, MASKED, ""}}; EXPECT_CALL(*mock_stream_, GetSubProtocol()).Times(AnyNumber()); + EXPECT_CALL(*mock_stream_, GetExtensions()).Times(AnyNumber()); EXPECT_CALL(*mock_stream_, ReadFrames(_, _)) .WillOnce(ReturnFrames(&frames)) .WillRepeatedly(Return(ERR_IO_PENDING)); @@ -1999,6 +2649,28 @@ TEST_F(WebSocketChannelStreamTest, Code1005IsNotEchoedNull) { CreateChannelAndConnectSuccessfully(); } +// Receiving an invalid UTF-8 payload in a Close frame causes us to fail the +// connection. +TEST_F(WebSocketChannelStreamTest, CloseFrameInvalidUtf8) { + static const InitFrame frames[] = { + {FINAL_FRAME, WebSocketFrameHeader::kOpCodeClose, + NOT_MASKED, CLOSE_DATA(NORMAL_CLOSURE, "\xFF")}}; + static const InitFrame expected[] = { + {FINAL_FRAME, WebSocketFrameHeader::kOpCodeClose, + MASKED, CLOSE_DATA(PROTOCOL_ERROR, "Invalid UTF-8 in Close frame")}}; + + EXPECT_CALL(*mock_stream_, GetSubProtocol()).Times(AnyNumber()); + EXPECT_CALL(*mock_stream_, GetExtensions()).Times(AnyNumber()); + EXPECT_CALL(*mock_stream_, ReadFrames(_, _)) + .WillOnce(ReturnFrames(&frames)) + .WillRepeatedly(Return(ERR_IO_PENDING)); + EXPECT_CALL(*mock_stream_, WriteFrames(EqualsFrames(expected), _)) + .WillOnce(Return(OK)); + EXPECT_CALL(*mock_stream_, Close()); + + CreateChannelAndConnectSuccessfully(); +} + // RFC6455 5.5.2 "Upon receipt of a Ping frame, an endpoint MUST send a Pong // frame in response" // 5.5.3 "A Pong frame sent in response to a Ping frame must have identical @@ -2012,6 +2684,7 @@ TEST_F(WebSocketChannelStreamTest, PingRepliedWithPong) { {FINAL_FRAME, WebSocketFrameHeader::kOpCodePong, MASKED, "Application data"}}; EXPECT_CALL(*mock_stream_, GetSubProtocol()).Times(AnyNumber()); + EXPECT_CALL(*mock_stream_, GetExtensions()).Times(AnyNumber()); EXPECT_CALL(*mock_stream_, ReadFrames(_, _)) .WillOnce(ReturnFrames(&frames)) .WillRepeatedly(Return(ERR_IO_PENDING)); @@ -2021,14 +2694,15 @@ TEST_F(WebSocketChannelStreamTest, PingRepliedWithPong) { CreateChannelAndConnectSuccessfully(); } -// A ping with a NULL payload should be responded to with a Pong with an empty +// A ping with a NULL payload should be responded to with a Pong with a NULL // payload. -TEST_F(WebSocketChannelStreamTest, NullPingRepliedWithEmptyPong) { +TEST_F(WebSocketChannelStreamTest, NullPingRepliedWithNullPong) { static const InitFrame frames[] = { {FINAL_FRAME, WebSocketFrameHeader::kOpCodePing, NOT_MASKED, NULL}}; static const InitFrame expected[] = { - {FINAL_FRAME, WebSocketFrameHeader::kOpCodePong, MASKED, ""}}; + {FINAL_FRAME, WebSocketFrameHeader::kOpCodePong, MASKED, NULL}}; EXPECT_CALL(*mock_stream_, GetSubProtocol()).Times(AnyNumber()); + EXPECT_CALL(*mock_stream_, GetExtensions()).Times(AnyNumber()); EXPECT_CALL(*mock_stream_, ReadFrames(_, _)) .WillOnce(ReturnFrames(&frames)) .WillRepeatedly(Return(ERR_IO_PENDING)); @@ -2053,6 +2727,7 @@ TEST_F(WebSocketChannelStreamTest, PongInTheMiddleOfDataMessage) { ScopedVector<WebSocketFrame>* read_frames; CompletionCallback read_callback; EXPECT_CALL(*mock_stream_, GetSubProtocol()).Times(AnyNumber()); + EXPECT_CALL(*mock_stream_, GetExtensions()).Times(AnyNumber()); EXPECT_CALL(*mock_stream_, ReadFrames(_, _)) .WillOnce(DoAll(SaveArg<0>(&read_frames), SaveArg<1>(&read_callback), @@ -2089,6 +2764,7 @@ TEST_F(WebSocketChannelStreamTest, WriteFramesOneAtATime) { Checkpoint checkpoint; EXPECT_CALL(*mock_stream_, GetSubProtocol()).Times(AnyNumber()); + EXPECT_CALL(*mock_stream_, GetExtensions()).Times(AnyNumber()); EXPECT_CALL(*mock_stream_, ReadFrames(_, _)).WillOnce(Return(ERR_IO_PENDING)); { InSequence s; @@ -2127,6 +2803,7 @@ TEST_F(WebSocketChannelStreamTest, WaitingMessagesAreBatched) { CompletionCallback write_callback; EXPECT_CALL(*mock_stream_, GetSubProtocol()).Times(AnyNumber()); + EXPECT_CALL(*mock_stream_, GetExtensions()).Times(AnyNumber()); EXPECT_CALL(*mock_stream_, ReadFrames(_, _)).WillOnce(Return(ERR_IO_PENDING)); { InSequence s; @@ -2145,17 +2822,14 @@ TEST_F(WebSocketChannelStreamTest, WaitingMessagesAreBatched) { write_callback.Run(OK); } -// When the renderer sends more on a channel than it has quota for, then we send -// a kWebSocketMuxErrorSendQuotaViolation status code (from the draft websocket -// mux specification) back to the renderer. This should not be sent to the -// remote server, which may not even implement the mux specification, and could -// even be using a different extension which uses that code to mean something -// else. -TEST_F(WebSocketChannelStreamTest, MuxErrorIsNotSentToStream) { +// When the renderer sends more on a channel than it has quota for, we send the +// remote server a kWebSocketErrorGoingAway error code. +TEST_F(WebSocketChannelStreamTest, SendGoingAwayOnRendererQuotaExceeded) { static const InitFrame expected[] = { {FINAL_FRAME, WebSocketFrameHeader::kOpCodeClose, - MASKED, CLOSE_DATA(GOING_AWAY, "Internal Error")}}; + MASKED, CLOSE_DATA(GOING_AWAY, "")}}; EXPECT_CALL(*mock_stream_, GetSubProtocol()).Times(AnyNumber()); + EXPECT_CALL(*mock_stream_, GetExtensions()).Times(AnyNumber()); EXPECT_CALL(*mock_stream_, ReadFrames(_, _)).WillOnce(Return(ERR_IO_PENDING)); EXPECT_CALL(*mock_stream_, WriteFrames(EqualsFrames(expected), _)) .WillOnce(Return(OK)); @@ -2174,6 +2848,7 @@ TEST_F(WebSocketChannelStreamTest, WrittenBinaryFramesAre8BitClean) { ScopedVector<WebSocketFrame>* frames = NULL; EXPECT_CALL(*mock_stream_, GetSubProtocol()).Times(AnyNumber()); + EXPECT_CALL(*mock_stream_, GetExtensions()).Times(AnyNumber()); EXPECT_CALL(*mock_stream_, ReadFrames(_, _)).WillOnce(Return(ERR_IO_PENDING)); EXPECT_CALL(*mock_stream_, WriteFrames(_, _)) .WillOnce(DoAll(SaveArg<0>(&frames), Return(ERR_IO_PENDING))); @@ -2207,7 +2882,7 @@ TEST_F(WebSocketChannelEventInterfaceTest, ReadBinaryFramesAre8BitClean) { stream->PrepareRawReadFrames( ReadableFakeWebSocketStream::SYNC, OK, frames.Pass()); set_stream(stream.Pass()); - EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _)); + EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _, _)); EXPECT_CALL(*event_interface_, OnFlowControl(_)); EXPECT_CALL(*event_interface_, OnDataFrame(true, @@ -2218,6 +2893,376 @@ TEST_F(WebSocketChannelEventInterfaceTest, ReadBinaryFramesAre8BitClean) { CreateChannelAndConnectSuccessfully(); } +// Invalid UTF-8 is not permitted in Text frames. +TEST_F(WebSocketChannelSendUtf8Test, InvalidUtf8Rejected) { + EXPECT_CALL( + *event_interface_, + OnFailChannel("Browser sent a text frame containing invalid UTF-8")); + + CreateChannelAndConnectSuccessfully(); + + channel_->SendFrame( + true, WebSocketFrameHeader::kOpCodeText, AsVector("\xff")); +} + +// A Text message cannot end with a partial UTF-8 character. +TEST_F(WebSocketChannelSendUtf8Test, IncompleteCharacterInFinalFrame) { + EXPECT_CALL( + *event_interface_, + OnFailChannel("Browser sent a text frame containing invalid UTF-8")); + + CreateChannelAndConnectSuccessfully(); + + channel_->SendFrame( + true, WebSocketFrameHeader::kOpCodeText, AsVector("\xc2")); +} + +// A non-final Text frame may end with a partial UTF-8 character (compare to +// previous test). +TEST_F(WebSocketChannelSendUtf8Test, IncompleteCharacterInNonFinalFrame) { + CreateChannelAndConnectSuccessfully(); + + channel_->SendFrame( + false, WebSocketFrameHeader::kOpCodeText, AsVector("\xc2")); +} + +// UTF-8 parsing context must be retained between frames. +TEST_F(WebSocketChannelSendUtf8Test, ValidCharacterSplitBetweenFrames) { + CreateChannelAndConnectSuccessfully(); + + channel_->SendFrame( + false, WebSocketFrameHeader::kOpCodeText, AsVector("\xf1")); + channel_->SendFrame(true, + WebSocketFrameHeader::kOpCodeContinuation, + AsVector("\x80\xa0\xbf")); +} + +// Similarly, an invalid character should be detected even if split. +TEST_F(WebSocketChannelSendUtf8Test, InvalidCharacterSplit) { + EXPECT_CALL( + *event_interface_, + OnFailChannel("Browser sent a text frame containing invalid UTF-8")); + + CreateChannelAndConnectSuccessfully(); + + channel_->SendFrame( + false, WebSocketFrameHeader::kOpCodeText, AsVector("\xe1")); + channel_->SendFrame(true, + WebSocketFrameHeader::kOpCodeContinuation, + AsVector("\x80\xa0\xbf")); +} + +// An invalid character must be detected in continuation frames. +TEST_F(WebSocketChannelSendUtf8Test, InvalidByteInContinuation) { + EXPECT_CALL( + *event_interface_, + OnFailChannel("Browser sent a text frame containing invalid UTF-8")); + + CreateChannelAndConnectSuccessfully(); + + channel_->SendFrame( + false, WebSocketFrameHeader::kOpCodeText, AsVector("foo")); + channel_->SendFrame( + false, WebSocketFrameHeader::kOpCodeContinuation, AsVector("bar")); + channel_->SendFrame( + true, WebSocketFrameHeader::kOpCodeContinuation, AsVector("\xff")); +} + +// However, continuation frames of a Binary frame will not be tested for UTF-8 +// validity. +TEST_F(WebSocketChannelSendUtf8Test, BinaryContinuationNotChecked) { + CreateChannelAndConnectSuccessfully(); + + channel_->SendFrame( + false, WebSocketFrameHeader::kOpCodeBinary, AsVector("foo")); + channel_->SendFrame( + false, WebSocketFrameHeader::kOpCodeContinuation, AsVector("bar")); + channel_->SendFrame( + true, WebSocketFrameHeader::kOpCodeContinuation, AsVector("\xff")); +} + +// Multiple text messages can be validated without the validation state getting +// confused. +TEST_F(WebSocketChannelSendUtf8Test, ValidateMultipleTextMessages) { + CreateChannelAndConnectSuccessfully(); + + channel_->SendFrame(true, WebSocketFrameHeader::kOpCodeText, AsVector("foo")); + channel_->SendFrame(true, WebSocketFrameHeader::kOpCodeText, AsVector("bar")); +} + +// UTF-8 validation is enforced on received Text frames. +TEST_F(WebSocketChannelEventInterfaceTest, ReceivedInvalidUtf8) { + scoped_ptr<ReadableFakeWebSocketStream> stream( + new ReadableFakeWebSocketStream); + static const InitFrame frames[] = { + {FINAL_FRAME, WebSocketFrameHeader::kOpCodeText, NOT_MASKED, "\xff"}}; + stream->PrepareReadFrames(ReadableFakeWebSocketStream::SYNC, OK, frames); + set_stream(stream.Pass()); + + EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _, _)); + EXPECT_CALL(*event_interface_, OnFlowControl(kDefaultInitialQuota)); + EXPECT_CALL(*event_interface_, + OnFailChannel("Could not decode a text frame as UTF-8.")); + + CreateChannelAndConnectSuccessfully(); + base::MessageLoop::current()->RunUntilIdle(); +} + +// Invalid UTF-8 is not sent over the network. +TEST_F(WebSocketChannelStreamTest, InvalidUtf8TextFrameNotSent) { + static const InitFrame expected[] = {{FINAL_FRAME, + WebSocketFrameHeader::kOpCodeClose, + MASKED, CLOSE_DATA(GOING_AWAY, "")}}; + EXPECT_CALL(*mock_stream_, GetSubProtocol()).Times(AnyNumber()); + EXPECT_CALL(*mock_stream_, GetExtensions()).Times(AnyNumber()); + EXPECT_CALL(*mock_stream_, ReadFrames(_, _)) + .WillRepeatedly(Return(ERR_IO_PENDING)); + EXPECT_CALL(*mock_stream_, WriteFrames(EqualsFrames(expected), _)) + .WillOnce(Return(OK)); + EXPECT_CALL(*mock_stream_, Close()).Times(1); + + CreateChannelAndConnectSuccessfully(); + + channel_->SendFrame( + true, WebSocketFrameHeader::kOpCodeText, AsVector("\xff")); +} + +// The rest of the tests for receiving invalid UTF-8 test the communication with +// the server. Since there is only one code path, it would be redundant to +// perform the same tests on the EventInterface as well. + +// If invalid UTF-8 is received in a Text frame, the connection is failed. +TEST_F(WebSocketChannelReceiveUtf8Test, InvalidTextFrameRejected) { + static const InitFrame frames[] = { + {FINAL_FRAME, WebSocketFrameHeader::kOpCodeText, NOT_MASKED, "\xff"}}; + static const InitFrame expected[] = { + {FINAL_FRAME, WebSocketFrameHeader::kOpCodeClose, MASKED, + CLOSE_DATA(PROTOCOL_ERROR, "Invalid UTF-8 in text frame")}}; + { + InSequence s; + EXPECT_CALL(*mock_stream_, ReadFrames(_, _)) + .WillOnce(ReturnFrames(&frames)) + .WillRepeatedly(Return(ERR_IO_PENDING)); + EXPECT_CALL(*mock_stream_, WriteFrames(EqualsFrames(expected), _)) + .WillOnce(Return(OK)); + EXPECT_CALL(*mock_stream_, Close()).Times(1); + } + + CreateChannelAndConnectSuccessfully(); +} + +// A received Text message is not permitted to end with a partial UTF-8 +// character. +TEST_F(WebSocketChannelReceiveUtf8Test, IncompleteCharacterReceived) { + static const InitFrame frames[] = { + {FINAL_FRAME, WebSocketFrameHeader::kOpCodeText, NOT_MASKED, "\xc2"}}; + static const InitFrame expected[] = { + {FINAL_FRAME, WebSocketFrameHeader::kOpCodeClose, MASKED, + CLOSE_DATA(PROTOCOL_ERROR, "Invalid UTF-8 in text frame")}}; + EXPECT_CALL(*mock_stream_, ReadFrames(_, _)) + .WillOnce(ReturnFrames(&frames)) + .WillRepeatedly(Return(ERR_IO_PENDING)); + EXPECT_CALL(*mock_stream_, WriteFrames(EqualsFrames(expected), _)) + .WillOnce(Return(OK)); + EXPECT_CALL(*mock_stream_, Close()).Times(1); + + CreateChannelAndConnectSuccessfully(); +} + +// However, a non-final Text frame may end with a partial UTF-8 character. +TEST_F(WebSocketChannelReceiveUtf8Test, IncompleteCharacterIncompleteMessage) { + static const InitFrame frames[] = { + {NOT_FINAL_FRAME, WebSocketFrameHeader::kOpCodeText, NOT_MASKED, "\xc2"}}; + EXPECT_CALL(*mock_stream_, ReadFrames(_, _)) + .WillOnce(ReturnFrames(&frames)) + .WillRepeatedly(Return(ERR_IO_PENDING)); + + CreateChannelAndConnectSuccessfully(); +} + +// However, it will become an error if it is followed by an empty final frame. +TEST_F(WebSocketChannelReceiveUtf8Test, TricksyIncompleteCharacter) { + static const InitFrame frames[] = { + {NOT_FINAL_FRAME, WebSocketFrameHeader::kOpCodeText, NOT_MASKED, "\xc2"}, + {FINAL_FRAME, WebSocketFrameHeader::kOpCodeContinuation, NOT_MASKED, ""}}; + static const InitFrame expected[] = { + {FINAL_FRAME, WebSocketFrameHeader::kOpCodeClose, MASKED, + CLOSE_DATA(PROTOCOL_ERROR, "Invalid UTF-8 in text frame")}}; + EXPECT_CALL(*mock_stream_, ReadFrames(_, _)) + .WillOnce(ReturnFrames(&frames)) + .WillRepeatedly(Return(ERR_IO_PENDING)); + EXPECT_CALL(*mock_stream_, WriteFrames(EqualsFrames(expected), _)) + .WillOnce(Return(OK)); + EXPECT_CALL(*mock_stream_, Close()).Times(1); + + CreateChannelAndConnectSuccessfully(); +} + +// UTF-8 parsing context must be retained between received frames of the same +// message. +TEST_F(WebSocketChannelReceiveUtf8Test, ReceivedParsingContextRetained) { + static const InitFrame frames[] = { + {NOT_FINAL_FRAME, WebSocketFrameHeader::kOpCodeText, NOT_MASKED, "\xf1"}, + {FINAL_FRAME, WebSocketFrameHeader::kOpCodeContinuation, + NOT_MASKED, "\x80\xa0\xbf"}}; + EXPECT_CALL(*mock_stream_, ReadFrames(_, _)) + .WillOnce(ReturnFrames(&frames)) + .WillRepeatedly(Return(ERR_IO_PENDING)); + + CreateChannelAndConnectSuccessfully(); +} + +// An invalid character must be detected even if split between frames. +TEST_F(WebSocketChannelReceiveUtf8Test, SplitInvalidCharacterReceived) { + static const InitFrame frames[] = { + {NOT_FINAL_FRAME, WebSocketFrameHeader::kOpCodeText, NOT_MASKED, "\xe1"}, + {FINAL_FRAME, WebSocketFrameHeader::kOpCodeContinuation, + NOT_MASKED, "\x80\xa0\xbf"}}; + static const InitFrame expected[] = { + {FINAL_FRAME, WebSocketFrameHeader::kOpCodeClose, MASKED, + CLOSE_DATA(PROTOCOL_ERROR, "Invalid UTF-8 in text frame")}}; + EXPECT_CALL(*mock_stream_, ReadFrames(_, _)) + .WillOnce(ReturnFrames(&frames)) + .WillRepeatedly(Return(ERR_IO_PENDING)); + EXPECT_CALL(*mock_stream_, WriteFrames(EqualsFrames(expected), _)) + .WillOnce(Return(OK)); + EXPECT_CALL(*mock_stream_, Close()).Times(1); + + CreateChannelAndConnectSuccessfully(); +} + +// An invalid character received in a continuation frame must be detected. +TEST_F(WebSocketChannelReceiveUtf8Test, InvalidReceivedIncontinuation) { + static const InitFrame frames[] = { + {NOT_FINAL_FRAME, WebSocketFrameHeader::kOpCodeText, NOT_MASKED, "foo"}, + {NOT_FINAL_FRAME, WebSocketFrameHeader::kOpCodeContinuation, + NOT_MASKED, "bar"}, + {FINAL_FRAME, WebSocketFrameHeader::kOpCodeContinuation, + NOT_MASKED, "\xff"}}; + static const InitFrame expected[] = { + {FINAL_FRAME, WebSocketFrameHeader::kOpCodeClose, MASKED, + CLOSE_DATA(PROTOCOL_ERROR, "Invalid UTF-8 in text frame")}}; + EXPECT_CALL(*mock_stream_, ReadFrames(_, _)) + .WillOnce(ReturnFrames(&frames)) + .WillRepeatedly(Return(ERR_IO_PENDING)); + EXPECT_CALL(*mock_stream_, WriteFrames(EqualsFrames(expected), _)) + .WillOnce(Return(OK)); + EXPECT_CALL(*mock_stream_, Close()).Times(1); + + CreateChannelAndConnectSuccessfully(); +} + +// Continuations of binary frames must not be tested for UTF-8 validity. +TEST_F(WebSocketChannelReceiveUtf8Test, ReceivedBinaryNotUtf8Tested) { + static const InitFrame frames[] = { + {NOT_FINAL_FRAME, WebSocketFrameHeader::kOpCodeBinary, NOT_MASKED, "foo"}, + {NOT_FINAL_FRAME, WebSocketFrameHeader::kOpCodeContinuation, + NOT_MASKED, "bar"}, + {FINAL_FRAME, WebSocketFrameHeader::kOpCodeContinuation, + NOT_MASKED, "\xff"}}; + EXPECT_CALL(*mock_stream_, ReadFrames(_, _)) + .WillOnce(ReturnFrames(&frames)) + .WillRepeatedly(Return(ERR_IO_PENDING)); + + CreateChannelAndConnectSuccessfully(); +} + +// Multiple Text messages can be validated. +TEST_F(WebSocketChannelReceiveUtf8Test, ValidateMultipleReceived) { + static const InitFrame frames[] = { + {FINAL_FRAME, WebSocketFrameHeader::kOpCodeText, NOT_MASKED, "foo"}, + {FINAL_FRAME, WebSocketFrameHeader::kOpCodeText, NOT_MASKED, "bar"}}; + EXPECT_CALL(*mock_stream_, ReadFrames(_, _)) + .WillOnce(ReturnFrames(&frames)) + .WillRepeatedly(Return(ERR_IO_PENDING)); + + CreateChannelAndConnectSuccessfully(); +} + +// A new data message cannot start in the middle of another data message. +TEST_F(WebSocketChannelEventInterfaceTest, BogusContinuation) { + scoped_ptr<ReadableFakeWebSocketStream> stream( + new ReadableFakeWebSocketStream); + static const InitFrame frames[] = { + {NOT_FINAL_FRAME, WebSocketFrameHeader::kOpCodeBinary, + NOT_MASKED, "frame1"}, + {FINAL_FRAME, WebSocketFrameHeader::kOpCodeText, + NOT_MASKED, "frame2"}}; + stream->PrepareReadFrames(ReadableFakeWebSocketStream::SYNC, OK, frames); + set_stream(stream.Pass()); + + EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _, _)); + EXPECT_CALL(*event_interface_, OnFlowControl(kDefaultInitialQuota)); + EXPECT_CALL( + *event_interface_, + OnDataFrame( + false, WebSocketFrameHeader::kOpCodeBinary, AsVector("frame1"))); + EXPECT_CALL( + *event_interface_, + OnFailChannel( + "Received start of new message but previous message is unfinished.")); + + CreateChannelAndConnectSuccessfully(); +} + +// A new message cannot start with a Continuation frame. +TEST_F(WebSocketChannelEventInterfaceTest, MessageStartingWithContinuation) { + scoped_ptr<ReadableFakeWebSocketStream> stream( + new ReadableFakeWebSocketStream); + static const InitFrame frames[] = { + {FINAL_FRAME, WebSocketFrameHeader::kOpCodeContinuation, + NOT_MASKED, "continuation"}}; + stream->PrepareReadFrames(ReadableFakeWebSocketStream::SYNC, OK, frames); + set_stream(stream.Pass()); + + EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _, _)); + EXPECT_CALL(*event_interface_, OnFlowControl(kDefaultInitialQuota)); + EXPECT_CALL(*event_interface_, + OnFailChannel("Received unexpected continuation frame.")); + + CreateChannelAndConnectSuccessfully(); +} + +// A frame passed to the renderer must be either non-empty or have the final bit +// set. +TEST_F(WebSocketChannelEventInterfaceTest, DataFramesNonEmptyOrFinal) { + scoped_ptr<ReadableFakeWebSocketStream> stream( + new ReadableFakeWebSocketStream); + static const InitFrame frames[] = { + {NOT_FINAL_FRAME, WebSocketFrameHeader::kOpCodeText, NOT_MASKED, ""}, + {NOT_FINAL_FRAME, WebSocketFrameHeader::kOpCodeContinuation, + NOT_MASKED, ""}, + {FINAL_FRAME, WebSocketFrameHeader::kOpCodeContinuation, NOT_MASKED, ""}}; + stream->PrepareReadFrames(ReadableFakeWebSocketStream::SYNC, OK, frames); + set_stream(stream.Pass()); + + EXPECT_CALL(*event_interface_, OnAddChannelResponse(false, _, _)); + EXPECT_CALL(*event_interface_, OnFlowControl(kDefaultInitialQuota)); + EXPECT_CALL( + *event_interface_, + OnDataFrame(true, WebSocketFrameHeader::kOpCodeText, AsVector(""))); + + CreateChannelAndConnectSuccessfully(); +} + +// Calls to OnSSLCertificateError() must be passed through to the event +// interface with the correct URL attached. +TEST_F(WebSocketChannelEventInterfaceTest, OnSSLCertificateErrorCalled) { + const GURL wss_url("wss://example.com/sslerror"); + connect_data_.socket_url = wss_url; + const SSLInfo ssl_info; + const bool fatal = true; + scoped_ptr<WebSocketEventInterface::SSLErrorCallbacks> fake_callbacks( + new FakeSSLErrorCallbacks); + + EXPECT_CALL(*event_interface_, + OnSSLCertificateErrorCalled(NotNull(), wss_url, _, fatal)); + + CreateChannelAndConnect(); + connect_data_.creator.connect_delegate->OnSSLCertificateError( + fake_callbacks.Pass(), ssl_info, fatal); +} + // If we receive another frame after Close, it is not valid. It is not // completely clear what behaviour is required from the standard in this case, // but the current implementation fails the connection. Since a Close has @@ -2232,6 +3277,7 @@ TEST_F(WebSocketChannelStreamTest, PingAfterCloseIsRejected) { {FINAL_FRAME, WebSocketFrameHeader::kOpCodeClose, MASKED, CLOSE_DATA(NORMAL_CLOSURE, "OK")}}; EXPECT_CALL(*mock_stream_, GetSubProtocol()).Times(AnyNumber()); + EXPECT_CALL(*mock_stream_, GetExtensions()).Times(AnyNumber()); EXPECT_CALL(*mock_stream_, ReadFrames(_, _)) .WillOnce(ReturnFrames(&frames)) .WillRepeatedly(Return(ERR_IO_PENDING)); @@ -2256,6 +3302,7 @@ TEST_F(WebSocketChannelStreamTest, ProtocolError) { {FINAL_FRAME, WebSocketFrameHeader::kOpCodeClose, MASKED, CLOSE_DATA(PROTOCOL_ERROR, "WebSocket Protocol Error")}}; EXPECT_CALL(*mock_stream_, GetSubProtocol()).Times(AnyNumber()); + EXPECT_CALL(*mock_stream_, GetExtensions()).Times(AnyNumber()); EXPECT_CALL(*mock_stream_, ReadFrames(_, _)) .WillOnce(Return(ERR_WS_PROTOCOL_ERROR)); EXPECT_CALL(*mock_stream_, WriteFrames(EqualsFrames(expected), _)) @@ -2273,6 +3320,7 @@ class WebSocketChannelStreamTimeoutTest : public WebSocketChannelStreamTest { virtual void CreateChannelAndConnectSuccessfully() OVERRIDE { set_stream(mock_stream_.Pass()); CreateChannelAndConnect(); + channel_->SendFlowControl(kPlentyOfQuota); channel_->SetClosingHandshakeTimeoutForTesting( TimeDelta::FromMilliseconds(kVeryTinyTimeoutMillis)); connect_data_.creator.connect_delegate->OnSuccess(stream_.Pass()); @@ -2292,6 +3340,7 @@ TEST_F(WebSocketChannelStreamTimeoutTest, ServerInitiatedCloseTimesOut) { {FINAL_FRAME, WebSocketFrameHeader::kOpCodeClose, MASKED, CLOSE_DATA(NORMAL_CLOSURE, "OK")}}; EXPECT_CALL(*mock_stream_, GetSubProtocol()).Times(AnyNumber()); + EXPECT_CALL(*mock_stream_, GetExtensions()).Times(AnyNumber()); EXPECT_CALL(*mock_stream_, ReadFrames(_, _)) .WillOnce(ReturnFrames(&frames)) .WillRepeatedly(Return(ERR_IO_PENDING)); @@ -2320,6 +3369,7 @@ TEST_F(WebSocketChannelStreamTimeoutTest, ClientInitiatedCloseTimesOut) { {FINAL_FRAME, WebSocketFrameHeader::kOpCodeClose, MASKED, CLOSE_DATA(NORMAL_CLOSURE, "OK")}}; EXPECT_CALL(*mock_stream_, GetSubProtocol()).Times(AnyNumber()); + EXPECT_CALL(*mock_stream_, GetExtensions()).Times(AnyNumber()); EXPECT_CALL(*mock_stream_, ReadFrames(_, _)) .WillRepeatedly(Return(ERR_IO_PENDING)); TestClosure completion; @@ -2348,6 +3398,7 @@ TEST_F(WebSocketChannelStreamTimeoutTest, ConnectionCloseTimesOut) { {FINAL_FRAME, WebSocketFrameHeader::kOpCodeClose, NOT_MASKED, CLOSE_DATA(NORMAL_CLOSURE, "OK")}}; EXPECT_CALL(*mock_stream_, GetSubProtocol()).Times(AnyNumber()); + EXPECT_CALL(*mock_stream_, GetExtensions()).Times(AnyNumber()); TestClosure completion; ScopedVector<WebSocketFrame>* read_frames = NULL; CompletionCallback read_callback; diff --git a/chromium/net/websockets/websocket_deflate_stream.cc b/chromium/net/websockets/websocket_deflate_stream.cc index 601670d373b..38de5fa2eca 100644 --- a/chromium/net/websockets/websocket_deflate_stream.cc +++ b/chromium/net/websockets/websocket_deflate_stream.cc @@ -36,6 +36,7 @@ const size_t kChunkSize = 4 * 1024; WebSocketDeflateStream::WebSocketDeflateStream( scoped_ptr<WebSocketStream> stream, WebSocketDeflater::ContextTakeOverMode mode, + int client_window_bits, scoped_ptr<WebSocketDeflatePredictor> predictor) : stream_(stream.Pass()), deflater_(mode), @@ -46,7 +47,9 @@ WebSocketDeflateStream::WebSocketDeflateStream( current_writing_opcode_(WebSocketFrameHeader::kOpCodeText), predictor_(predictor.Pass()) { DCHECK(stream_); - deflater_.Initialize(kWindowBits); + DCHECK_GE(client_window_bits, 8); + DCHECK_LE(client_window_bits, 15); + deflater_.Initialize(client_window_bits); inflater_.Initialize(kWindowBits); } @@ -54,16 +57,18 @@ WebSocketDeflateStream::~WebSocketDeflateStream() {} int WebSocketDeflateStream::ReadFrames(ScopedVector<WebSocketFrame>* frames, const CompletionCallback& callback) { - CompletionCallback callback_to_pass = + int result = stream_->ReadFrames( + frames, base::Bind(&WebSocketDeflateStream::OnReadComplete, base::Unretained(this), base::Unretained(frames), - callback); - int result = stream_->ReadFrames(frames, callback_to_pass); + callback)); if (result < 0) return result; DCHECK_EQ(OK, result); - return InflateAndReadIfNecessary(frames, callback_to_pass); + DCHECK(!frames->empty()); + + return InflateAndReadIfNecessary(frames, callback); } int WebSocketDeflateStream::WriteFrames(ScopedVector<WebSocketFrame>* frames, @@ -274,6 +279,11 @@ int WebSocketDeflateStream::Inflate(ScopedVector<WebSocketFrame>* frames) { for (size_t i = 0; i < frames_passed.size(); ++i) { scoped_ptr<WebSocketFrame> frame(frames_passed[i]); frames_passed[i] = NULL; + DVLOG(3) << "Input frame: opcode=" << frame->header.opcode + << " final=" << frame->header.final + << " reserved1=" << frame->header.reserved1 + << " payload_length=" << frame->header.payload_length; + if (!WebSocketFrameHeader::IsKnownDataOpCode(frame->header.opcode)) { frames_to_output.push_back(frame.release()); continue; @@ -323,9 +333,7 @@ int WebSocketDeflateStream::Inflate(ScopedVector<WebSocketFrame>* frames) { scoped_ptr<WebSocketFrame> inflated( new WebSocketFrame(WebSocketFrameHeader::kOpCodeText)); scoped_refptr<IOBufferWithSize> data = inflater_.GetOutput(size); - bool is_final = !inflater_.CurrentOutputSize(); - // |is_final| can't be true if |frame->header.final| is false. - DCHECK(!(is_final && !frame->header.final)); + bool is_final = !inflater_.CurrentOutputSize() && frame->header.final; if (!data) { DVLOG(1) << "WebSocket protocol error. " << "inflater_.GetOutput() returns an error."; @@ -337,7 +345,10 @@ int WebSocketDeflateStream::Inflate(ScopedVector<WebSocketFrame>* frames) { inflated->header.reserved1 = false; inflated->data = data; inflated->header.payload_length = data->size(); - + DVLOG(3) << "Inflated frame: opcode=" << inflated->header.opcode + << " final=" << inflated->header.final + << " reserved1=" << inflated->header.reserved1 + << " payload_length=" << inflated->header.payload_length; frames_to_output.push_back(inflated.release()); current_reading_opcode_ = WebSocketFrameHeader::kOpCodeContinuation; if (is_final) @@ -357,11 +368,18 @@ int WebSocketDeflateStream::InflateAndReadIfNecessary( int result = Inflate(frames); while (result == ERR_IO_PENDING) { DCHECK(frames->empty()); - result = stream_->ReadFrames(frames, callback); + + result = stream_->ReadFrames( + frames, + base::Bind(&WebSocketDeflateStream::OnReadComplete, + base::Unretained(this), + base::Unretained(frames), + callback)); if (result < 0) break; DCHECK_EQ(OK, result); DCHECK(!frames->empty()); + result = Inflate(frames); } if (result < 0) diff --git a/chromium/net/websockets/websocket_deflate_stream.h b/chromium/net/websockets/websocket_deflate_stream.h index a7859446f6a..39ac2dfa256 100644 --- a/chromium/net/websockets/websocket_deflate_stream.h +++ b/chromium/net/websockets/websocket_deflate_stream.h @@ -41,6 +41,7 @@ class NET_EXPORT_PRIVATE WebSocketDeflateStream : public WebSocketStream { public: WebSocketDeflateStream(scoped_ptr<WebSocketStream> stream, WebSocketDeflater::ContextTakeOverMode mode, + int client_window_bits, scoped_ptr<WebSocketDeflatePredictor> predictor); virtual ~WebSocketDeflateStream(); @@ -67,6 +68,7 @@ class NET_EXPORT_PRIVATE WebSocketDeflateStream : public WebSocketStream { NOT_WRITING, }; + // Handles asynchronous completion of ReadFrames() call on |stream_|. void OnReadComplete(ScopedVector<WebSocketFrame>* frames, const CompletionCallback& callback, int result); diff --git a/chromium/net/websockets/websocket_deflate_stream_test.cc b/chromium/net/websockets/websocket_deflate_stream_test.cc index 1775962dce1..a8b4e59c32b 100644 --- a/chromium/net/websockets/websocket_deflate_stream_test.cc +++ b/chromium/net/websockets/websocket_deflate_stream_test.cc @@ -209,17 +209,27 @@ class WebSocketDeflatePredictorMock : public WebSocketDeflatePredictor { class WebSocketDeflateStreamTest : public ::testing::Test { public: WebSocketDeflateStreamTest() - : mock_stream_(NULL) { + : mock_stream_(NULL), + predictor_(NULL) {} + virtual ~WebSocketDeflateStreamTest() {} + + virtual void SetUp() { + Initialize(WebSocketDeflater::TAKE_OVER_CONTEXT, kWindowBits); + } + + protected: + // Initialize deflate_stream_ with the given parameters. + void Initialize(WebSocketDeflater::ContextTakeOverMode mode, + int window_bits) { mock_stream_ = new testing::StrictMock<MockWebSocketStream>; predictor_ = new WebSocketDeflatePredictorMock; deflate_stream_.reset(new WebSocketDeflateStream( scoped_ptr<WebSocketStream>(mock_stream_), - WebSocketDeflater::TAKE_OVER_CONTEXT, + mode, + window_bits, scoped_ptr<WebSocketDeflatePredictor>(predictor_))); } - virtual ~WebSocketDeflateStreamTest() {} - protected: scoped_ptr<WebSocketDeflateStream> deflate_stream_; // Owned by |deflate_stream_|. MockWebSocketStream* mock_stream_; @@ -231,25 +241,41 @@ class WebSocketDeflateStreamTest : public ::testing::Test { // websocket_deflater_test.cc, we have only a few tests for this configuration // here. class WebSocketDeflateStreamWithDoNotTakeOverContextTest - : public ::testing::Test { + : public WebSocketDeflateStreamTest { public: - WebSocketDeflateStreamWithDoNotTakeOverContextTest() - : mock_stream_(NULL) { - mock_stream_ = new testing::StrictMock<MockWebSocketStream>; - predictor_ = new WebSocketDeflatePredictorMock; - deflate_stream_.reset(new WebSocketDeflateStream( - scoped_ptr<WebSocketStream>(mock_stream_), - WebSocketDeflater::DO_NOT_TAKE_OVER_CONTEXT, - scoped_ptr<WebSocketDeflatePredictor>(predictor_))); - } + WebSocketDeflateStreamWithDoNotTakeOverContextTest() {} virtual ~WebSocketDeflateStreamWithDoNotTakeOverContextTest() {} + virtual void SetUp() { + Initialize(WebSocketDeflater::DO_NOT_TAKE_OVER_CONTEXT, kWindowBits); + } +}; + +class WebSocketDeflateStreamWithClientWindowBitsTest + : public WebSocketDeflateStreamTest { + public: + WebSocketDeflateStreamWithClientWindowBitsTest() {} + virtual ~WebSocketDeflateStreamWithClientWindowBitsTest() {} + + // Overridden to postpone the call to Initialize(). + virtual void SetUp() {} + + // This needs to be called explicitly from the tests. + void SetUpWithWindowBits(int window_bits) { + Initialize(WebSocketDeflater::TAKE_OVER_CONTEXT, window_bits); + } + + // Add a frame which will be compressed to a smaller size if the window + // size is large enough. + void AddCompressibleFrameString() { + const std::string word = "Chromium"; + const std::string payload = word + std::string(256, 'a') + word; + AppendTo(&frames_, WebSocketFrameHeader::kOpCodeText, kFinal, payload); + predictor_->AddFramesToBeInput(frames_); + } + protected: - scoped_ptr<WebSocketDeflateStream> deflate_stream_; - // |mock_stream_| will be deleted when |deflate_stream_| is destroyed. - MockWebSocketStream* mock_stream_; - // |predictor_| will be deleted when |deflate_stream_| is destroyed. - WebSocketDeflatePredictorMock* predictor_; + ScopedVector<WebSocketFrame> frames_; }; // ReadFrameStub is a stub for WebSocketStream::ReadFrames. @@ -707,6 +733,47 @@ TEST_F(WebSocketDeflateStreamTest, SplitToMultipleFramesInReadFrames) { ToString(frames[0]) + ToString(frames[1]) + ToString(frames[2])); } +TEST_F(WebSocketDeflateStreamTest, InflaterInternalDataCanBeEmpty) { + WebSocketDeflater deflater(WebSocketDeflater::TAKE_OVER_CONTEXT); + deflater.Initialize(kWindowBits); + const std::string original_data(kChunkSize, 'a'); + deflater.AddBytes(original_data.data(), original_data.size()); + deflater.Finish(); + + ScopedVector<WebSocketFrame> frames_to_output; + AppendTo(&frames_to_output, + WebSocketFrameHeader::kOpCodeBinary, + kReserved1, + ToString(deflater.GetOutput(deflater.CurrentOutputSize()))); + AppendTo(&frames_to_output, + WebSocketFrameHeader::kOpCodeBinary, + kFinal, + ""); + + ReadFramesStub stub(OK, &frames_to_output); + CompletionCallback callback; + ScopedVector<WebSocketFrame> frames; + { + InSequence s; + EXPECT_CALL(*mock_stream_, ReadFrames(&frames, _)) + .WillOnce(Invoke(&stub, &ReadFramesStub::Call)); + } + + ASSERT_EQ(OK, deflate_stream_->ReadFrames(&frames, callback)); + ASSERT_EQ(2u, frames.size()); + EXPECT_EQ(WebSocketFrameHeader::kOpCodeBinary, frames[0]->header.opcode); + EXPECT_FALSE(frames[0]->header.final); + EXPECT_FALSE(frames[0]->header.reserved1); + EXPECT_EQ(kChunkSize, static_cast<size_t>(frames[0]->header.payload_length)); + + EXPECT_EQ(WebSocketFrameHeader::kOpCodeContinuation, + frames[1]->header.opcode); + EXPECT_TRUE(frames[1]->header.final); + EXPECT_FALSE(frames[1]->header.reserved1); + EXPECT_EQ(0u, static_cast<size_t>(frames[1]->header.payload_length)); + EXPECT_EQ(original_data, ToString(frames[0]) + ToString(frames[1])); +} + TEST_F(WebSocketDeflateStreamTest, Reserved1TurnsOnDuringReadingCompressedContinuationFrame) { const std::string data1("\xf2\x48\xcd", 3); @@ -886,6 +953,43 @@ TEST_F(WebSocketDeflateStreamTest, EXPECT_EQ("compressed", ToString(frames[1])); } +// This is a regression test for crbug.com/343506. +TEST_F(WebSocketDeflateStreamTest, ReadEmptyAsyncFrame) { + ScopedVector<ReadFramesStub> stub_vector; + stub_vector.push_back(new ReadFramesStub(ERR_IO_PENDING)); + stub_vector.push_back(new ReadFramesStub(ERR_IO_PENDING)); + MockCallback mock_callback; + CompletionCallback callback = + base::Bind(&MockCallback::Call, base::Unretained(&mock_callback)); + ScopedVector<WebSocketFrame> frames; + + { + InSequence s; + EXPECT_CALL(*mock_stream_, ReadFrames(&frames, _)) + .WillOnce(Invoke(stub_vector[0], &ReadFramesStub::Call)); + + EXPECT_CALL(*mock_stream_, ReadFrames(&frames, _)) + .WillOnce(Invoke(stub_vector[1], &ReadFramesStub::Call)); + + EXPECT_CALL(mock_callback, Call(OK)); + } + + ASSERT_EQ(ERR_IO_PENDING, deflate_stream_->ReadFrames(&frames, callback)); + AppendTo(stub_vector[0]->frames_passed(), + WebSocketFrameHeader::kOpCodeText, + kReserved1, + std::string()); + stub_vector[0]->callback().Run(OK); + AppendTo(stub_vector[1]->frames_passed(), + WebSocketFrameHeader::kOpCodeContinuation, + kFinal, + std::string("\x02\x00")); + stub_vector[1]->callback().Run(OK); + ASSERT_EQ(1u, frames.size()); + EXPECT_EQ(WebSocketFrameHeader::kOpCodeText, frames[0]->header.opcode); + EXPECT_EQ("", ToString(frames[0])); +} + TEST_F(WebSocketDeflateStreamTest, WriteEmpty) { ScopedVector<WebSocketFrame> frames; CompletionCallback callback; @@ -1006,7 +1110,7 @@ TEST_F(WebSocketDeflateStreamTest, WriteEmptyMessage) { EXPECT_EQ(WebSocketFrameHeader::kOpCodeText, frames_passed[0]->header.opcode); EXPECT_TRUE(frames_passed[0]->header.final); EXPECT_TRUE(frames_passed[0]->header.reserved1); - EXPECT_EQ(std::string("\x02\x00", 2), ToString(frames_passed[0])); + EXPECT_EQ(std::string("\x00", 1), ToString(frames_passed[0])); } TEST_F(WebSocketDeflateStreamTest, WriteUncompressedMessage) { @@ -1201,6 +1305,44 @@ TEST_F(WebSocketDeflateStreamWithDoNotTakeOverContextTest, EXPECT_EQ("YY", ToString(frames_passed[4])); } +// This is based on the similar test from websocket_deflater_test.cc +TEST_F(WebSocketDeflateStreamWithClientWindowBitsTest, WindowBits8) { + SetUpWithWindowBits(8); + CompletionCallback callback; + AddCompressibleFrameString(); + WriteFramesStub stub(predictor_, OK); + { + InSequence s; + EXPECT_CALL(*mock_stream_, WriteFrames(_, _)) + .WillOnce(Invoke(&stub, &WriteFramesStub::Call)); + } + ASSERT_EQ(OK, deflate_stream_->WriteFrames(&frames_, callback)); + const ScopedVector<WebSocketFrame>& frames_passed = *stub.frames(); + ASSERT_EQ(1u, frames_passed.size()); + EXPECT_EQ(std::string("r\xce(\xca\xcf\xcd,\xcdM\x1c\xe1\xc0\x39\xa3" + "(?7\xb3\x34\x17\x00", 21), + ToString(frames_passed[0])); +} + +// The same input with window_bits=10 returns smaller output. +TEST_F(WebSocketDeflateStreamWithClientWindowBitsTest, WindowBits10) { + SetUpWithWindowBits(10); + CompletionCallback callback; + AddCompressibleFrameString(); + WriteFramesStub stub(predictor_, OK); + { + InSequence s; + EXPECT_CALL(*mock_stream_, WriteFrames(_, _)) + .WillOnce(Invoke(&stub, &WriteFramesStub::Call)); + } + ASSERT_EQ(OK, deflate_stream_->WriteFrames(&frames_, callback)); + const ScopedVector<WebSocketFrame>& frames_passed = *stub.frames(); + ASSERT_EQ(1u, frames_passed.size()); + EXPECT_EQ( + std::string("r\xce(\xca\xcf\xcd,\xcdM\x1c\xe1\xc0\x19\x1a\x0e\0\0", 17), + ToString(frames_passed[0])); +} + } // namespace } // namespace net diff --git a/chromium/net/websockets/websocket_deflater.cc b/chromium/net/websockets/websocket_deflater.cc index 41d13e86870..a4c56bccb19 100644 --- a/chromium/net/websockets/websocket_deflater.cc +++ b/chromium/net/websockets/websocket_deflater.cc @@ -66,7 +66,6 @@ bool WebSocketDeflater::Finish() { // Since consecutive calls of deflate with Z_SYNC_FLUSH and no input // lead to an error, we create and return the output for the empty input // manually. - buffer_.push_back('\x02'); buffer_.push_back('\x00'); ResetContext(); return true; diff --git a/chromium/net/websockets/websocket_deflater.h b/chromium/net/websockets/websocket_deflater.h index da85bfec912..1b631e2663e 100644 --- a/chromium/net/websockets/websocket_deflater.h +++ b/chromium/net/websockets/websocket_deflater.h @@ -21,9 +21,12 @@ class IOBufferWithSize; class NET_EXPORT_PRIVATE WebSocketDeflater { public: + // Do not reorder or remove entries of this enum. The values of them are used + // in UMA. enum ContextTakeOverMode { DO_NOT_TAKE_OVER_CONTEXT, TAKE_OVER_CONTEXT, + NUM_CONTEXT_TAKEOVER_MODE_TYPES, }; explicit WebSocketDeflater(ContextTakeOverMode mode); diff --git a/chromium/net/websockets/websocket_deflater_test.cc b/chromium/net/websockets/websocket_deflater_test.cc index 03b8a3d7c52..ae0133c6424 100644 --- a/chromium/net/websockets/websocket_deflater_test.cc +++ b/chromium/net/websockets/websocket_deflater_test.cc @@ -25,7 +25,7 @@ TEST(WebSocketDeflaterTest, Construct) { ASSERT_TRUE(deflater.Finish()); scoped_refptr<IOBufferWithSize> actual = deflater.GetOutput(deflater.CurrentOutputSize()); - EXPECT_EQ(std::string("\x02\00", 2), ToString(actual.get())); + EXPECT_EQ(std::string("\00", 1), ToString(actual.get())); ASSERT_EQ(0u, deflater.CurrentOutputSize()); } @@ -93,8 +93,8 @@ TEST(WebSocketDeflaterTest, GetMultipleDeflatedOutput) { actual = deflater.GetOutput(deflater.CurrentOutputSize()); EXPECT_EQ(std::string("\xf2\x48\xcd\xc9\xc9\x07\x00\x00\x00\xff\xff" - "\x02\x00\x00\x00\xff\xff" - "\xf2\x00\x11\x00\x00", 22), + "\x00\x00\x00\xff\xff" + "\xf2\x00\x11\x00\x00", 21), ToString(actual.get())); ASSERT_EQ(0u, deflater.CurrentOutputSize()); } diff --git a/chromium/net/websockets/websocket_event_interface.h b/chromium/net/websockets/websocket_event_interface.h index baba88ce012..d32a7c131cf 100644 --- a/chromium/net/websockets/websocket_event_interface.h +++ b/chromium/net/websockets/websocket_event_interface.h @@ -12,8 +12,14 @@ #include "base/compiler_specific.h" // for WARN_UNUSED_RESULT #include "net/base/net_export.h" +class GURL; + namespace net { +class SSLInfo; +struct WebSocketHandshakeRequestInfo; +struct WebSocketHandshakeResponseInfo; + // Interface for events sent from the network layer to the content layer. These // events will generally be sent as-is to the renderer process. class NET_EXPORT WebSocketEventInterface { @@ -29,13 +35,15 @@ class NET_EXPORT WebSocketEventInterface { }; virtual ~WebSocketEventInterface() {} + // Called in response to an AddChannelRequest. This generally means that a // response has been received from the remote server, but the response might // have been generated internally. If |fail| is true, the channel cannot be // used and should be deleted, returning CHANNEL_DELETED. virtual ChannelState OnAddChannelResponse( bool fail, - const std::string& selected_subprotocol) WARN_UNUSED_RESULT = 0; + const std::string& selected_subprotocol, + const std::string& extensions) WARN_UNUSED_RESULT = 0; // Called when a data frame has been received from the remote host and needs // to be forwarded to the renderer process. @@ -63,14 +71,63 @@ class NET_EXPORT WebSocketEventInterface { // callers must take care not to provide details that could be useful to // attackers attempting to use WebSockets to probe networks. // + // |was_clean| should be true if the closing handshake completed successfully. + // // The channel should not be used again after OnDropChannel() has been // called. // // This method returns a ChannelState for consistency, but all implementations // must delete the Channel and return CHANNEL_DELETED. - virtual ChannelState OnDropChannel(uint16 code, const std::string& reason) + virtual ChannelState OnDropChannel(bool was_clean, + uint16 code, + const std::string& reason) WARN_UNUSED_RESULT = 0; + // Called when the browser fails the channel, as specified in the spec. + // + // The channel should not be used again after OnFailChannel() has been + // called. + // + // This method returns a ChannelState for consistency, but all implementations + // must delete the Channel and return CHANNEL_DELETED. + virtual ChannelState OnFailChannel(const std::string& message) + WARN_UNUSED_RESULT = 0; + + // Called when the browser starts the WebSocket Opening Handshake. + virtual ChannelState OnStartOpeningHandshake( + scoped_ptr<WebSocketHandshakeRequestInfo> request) WARN_UNUSED_RESULT = 0; + + // Called when the browser finishes the WebSocket Opening Handshake. + virtual ChannelState OnFinishOpeningHandshake( + scoped_ptr<WebSocketHandshakeResponseInfo> response) + WARN_UNUSED_RESULT = 0; + + // Callbacks to be used in response to a call to OnSSLCertificateError. Very + // similar to content::SSLErrorHandler::Delegate (which we can't use directly + // due to layering constraints). + class NET_EXPORT SSLErrorCallbacks { + public: + virtual ~SSLErrorCallbacks() {} + + // Cancels the SSL response in response to the error. + virtual void CancelSSLRequest(int error, const SSLInfo* ssl_info) = 0; + + // Continue with the SSL connection despite the error. + virtual void ContinueSSLRequest() = 0; + }; + + // Called on SSL Certificate Error during the SSL handshake. Should result in + // a call to either ssl_error_callbacks->ContinueSSLRequest() or + // ssl_error_callbacks->CancelSSLRequest(). Normally the implementation of + // this method will delegate to content::SSLManager::OnSSLCertificateError to + // make the actual decision. The callbacks must not be called after the + // WebSocketChannel has been destroyed. + virtual ChannelState OnSSLCertificateError( + scoped_ptr<SSLErrorCallbacks> ssl_error_callbacks, + const GURL& url, + const SSLInfo& ssl_info, + bool fatal) WARN_UNUSED_RESULT = 0; + protected: WebSocketEventInterface() {} diff --git a/chromium/net/websockets/websocket_frame.cc b/chromium/net/websockets/websocket_frame.cc index 763712a6f57..6fe972ba4a8 100644 --- a/chromium/net/websockets/websocket_frame.cc +++ b/chromium/net/websockets/websocket_frame.cc @@ -7,9 +7,9 @@ #include <algorithm> #include "base/basictypes.h" +#include "base/big_endian.h" #include "base/logging.h" #include "base/rand_util.h" -#include "net/base/big_endian.h" #include "net/base/io_buffer.h" #include "net/base/net_errors.h" @@ -131,10 +131,10 @@ int WriteWebSocketFrameHeader(const WebSocketFrameHeader& header, // Writes "extended payload length" field. if (extended_length_size == 2) { uint16 payload_length_16 = static_cast<uint16>(header.payload_length); - WriteBigEndian(buffer + buffer_index, payload_length_16); + base::WriteBigEndian(buffer + buffer_index, payload_length_16); buffer_index += sizeof(payload_length_16); } else if (extended_length_size == 8) { - WriteBigEndian(buffer + buffer_index, header.payload_length); + base::WriteBigEndian(buffer + buffer_index, header.payload_length); buffer_index += sizeof(header.payload_length); } diff --git a/chromium/net/websockets/websocket_frame_parser.cc b/chromium/net/websockets/websocket_frame_parser.cc index 3b199128b42..2e4c58fe302 100644 --- a/chromium/net/websockets/websocket_frame_parser.cc +++ b/chromium/net/websockets/websocket_frame_parser.cc @@ -8,11 +8,11 @@ #include <limits> #include "base/basictypes.h" +#include "base/big_endian.h" #include "base/logging.h" #include "base/memory/ref_counted.h" #include "base/memory/scoped_ptr.h" #include "base/memory/scoped_vector.h" -#include "net/base/big_endian.h" #include "net/base/io_buffer.h" #include "net/websockets/websocket_frame.h" @@ -124,7 +124,7 @@ void WebSocketFrameParser::DecodeFrameHeader() { if (end - current < 2) return; uint16 payload_length_16; - ReadBigEndian(current, &payload_length_16); + base::ReadBigEndian(current, &payload_length_16); current += 2; payload_length = payload_length_16; if (payload_length <= kMaxPayloadLengthWithoutExtendedLengthField) @@ -132,7 +132,7 @@ void WebSocketFrameParser::DecodeFrameHeader() { } else if (payload_length == kPayloadLengthWithEightByteExtendedLengthField) { if (end - current < 8) return; - ReadBigEndian(current, &payload_length); + base::ReadBigEndian(current, &payload_length); current += 8; if (payload_length <= kuint16max || payload_length > static_cast<uint64>(kint64max)) { diff --git a/chromium/net/websockets/websocket_frame_test.cc b/chromium/net/websockets/websocket_frame_test.cc index 97fac03e12e..b37dbb33001 100644 --- a/chromium/net/websockets/websocket_frame_test.cc +++ b/chromium/net/websockets/websocket_frame_test.cc @@ -308,7 +308,7 @@ TEST(WebSocketFrameTest, MaskPayloadAlignment) { }; COMPILE_ASSERT(arraysize(kTestInput) == arraysize(kTestOutput), output_and_input_arrays_have_the_same_length); - scoped_ptr_malloc<char, base::ScopedPtrAlignedFree> scratch( + scoped_ptr<char, base::AlignedFreeDeleter> scratch( static_cast<char*>( base::AlignedAlloc(kScratchBufferSize, kMaxVectorAlignment))); WebSocketMaskingKey masking_key; @@ -348,7 +348,7 @@ class WebSocketFrameTestMaskBenchmark : public testing::Test { virtual void SetUp() { std::string iterations( - CommandLine::ForCurrentProcess()->GetSwitchValueASCII( + base::CommandLine::ForCurrentProcess()->GetSwitchValueASCII( kBenchmarkIterations)); int benchmark_iterations = 0; if (!iterations.empty() && diff --git a/chromium/net/websockets/websocket_handshake_handler_spdy_test.cc b/chromium/net/websockets/websocket_handshake_handler_spdy_test.cc index a825dcdfede..064bdcfb360 100644 --- a/chromium/net/websockets/websocket_handshake_handler_spdy_test.cc +++ b/chromium/net/websockets/websocket_handshake_handler_spdy_test.cc @@ -29,8 +29,7 @@ INSTANTIATE_TEST_CASE_P( NextProto, WebSocketHandshakeHandlerSpdyTest, testing::Values(kProtoDeprecatedSPDY2, - kProtoSPDY3, kProtoSPDY31, kProtoSPDY4a2, - kProtoHTTP2Draft04)); + kProtoSPDY3, kProtoSPDY31, kProtoSPDY4)); TEST_P(WebSocketHandshakeHandlerSpdyTest, RequestResponse) { WebSocketHandshakeRequestHandler request_handler; diff --git a/chromium/net/websockets/websocket_handshake_request_info.cc b/chromium/net/websockets/websocket_handshake_request_info.cc new file mode 100644 index 00000000000..7acd4d0433c --- /dev/null +++ b/chromium/net/websockets/websocket_handshake_request_info.cc @@ -0,0 +1,19 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/websockets/websocket_handshake_request_info.h" + +#include "base/time/time.h" +#include "url/gurl.h" + +namespace net { + +WebSocketHandshakeRequestInfo::WebSocketHandshakeRequestInfo( + const GURL& url, + base::Time request_time) + : url(url), request_time(request_time) {} + +WebSocketHandshakeRequestInfo::~WebSocketHandshakeRequestInfo() {} + +} // namespace net diff --git a/chromium/net/websockets/websocket_handshake_request_info.h b/chromium/net/websockets/websocket_handshake_request_info.h new file mode 100644 index 00000000000..e5ef3369b2a --- /dev/null +++ b/chromium/net/websockets/websocket_handshake_request_info.h @@ -0,0 +1,33 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_WEBSOCKETS_WEBSOCKET_HANDSHAKE_REQUEST_INFO_H_ +#define NET_WEBSOCKETS_WEBSOCKET_HANDSHAKE_REQUEST_INFO_H_ + +#include <string> + +#include "base/time/time.h" +#include "net/base/net_export.h" +#include "net/http/http_request_headers.h" +#include "url/gurl.h" + +namespace net { + +struct NET_EXPORT WebSocketHandshakeRequestInfo { + WebSocketHandshakeRequestInfo(const GURL& url, base::Time request_time); + ~WebSocketHandshakeRequestInfo(); + // The request URL + GURL url; + // HTTP request headers + HttpRequestHeaders headers; + // The time that this request is sent + base::Time request_time; + + private: + DISALLOW_COPY_AND_ASSIGN(WebSocketHandshakeRequestInfo); +}; + +} // namespace net + +#endif // NET_WEBSOCKETS_WEBSOCKET_HANDSHAKE_REQUEST_INFO_H_ diff --git a/chromium/net/websockets/websocket_handshake_response_info.cc b/chromium/net/websockets/websocket_handshake_response_info.cc new file mode 100644 index 00000000000..b9588b63e3f --- /dev/null +++ b/chromium/net/websockets/websocket_handshake_response_info.cc @@ -0,0 +1,30 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/websockets/websocket_handshake_response_info.h" + +#include <string> + +#include "base/memory/ref_counted.h" +#include "base/time/time.h" +#include "net/http/http_response_headers.h" +#include "url/gurl.h" + +namespace net { + +WebSocketHandshakeResponseInfo::WebSocketHandshakeResponseInfo( + const GURL& url, + int status_code, + const std::string& status_text, + scoped_refptr<HttpResponseHeaders> headers, + base::Time response_time) + : url(url), + status_code(status_code), + status_text(status_text), + headers(headers), + response_time(response_time) {} + +WebSocketHandshakeResponseInfo::~WebSocketHandshakeResponseInfo() {} + +} // namespace net diff --git a/chromium/net/websockets/websocket_handshake_response_info.h b/chromium/net/websockets/websocket_handshake_response_info.h new file mode 100644 index 00000000000..66aff4417b7 --- /dev/null +++ b/chromium/net/websockets/websocket_handshake_response_info.h @@ -0,0 +1,43 @@ +// Copyright 2014 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef NET_WEBSOCKETS_WEBSOCKET_HANDSHAKE_RESPONSE_INFO_H_ +#define NET_WEBSOCKETS_WEBSOCKET_HANDSHAKE_RESPONSE_INFO_H_ + +#include <string> + +#include "base/memory/ref_counted.h" +#include "base/time/time.h" +#include "net/base/net_export.h" +#include "url/gurl.h" + +namespace net { + +class HttpResponseHeaders; + +struct NET_EXPORT WebSocketHandshakeResponseInfo { + WebSocketHandshakeResponseInfo(const GURL& url, + int status_code, + const std::string& status_text, + scoped_refptr<HttpResponseHeaders> headers, + base::Time response_time); + ~WebSocketHandshakeResponseInfo(); + // The request URL + GURL url; + // HTTP status code + int status_code; + // HTTP status text + std::string status_text; + // HTTP response headers + scoped_refptr<HttpResponseHeaders> headers; + // The time that this response arrived + base::Time response_time; + + private: + DISALLOW_COPY_AND_ASSIGN(WebSocketHandshakeResponseInfo); +}; + +} // namespace net + +#endif // NET_WEBSOCKETS_WEBSOCKET_HANDSHAKE_RESPONSE_INFO_H_ diff --git a/chromium/net/websockets/websocket_handshake_stream_base.h b/chromium/net/websockets/websocket_handshake_stream_base.h index 71d8321824f..8208c0e742f 100644 --- a/chromium/net/websockets/websocket_handshake_stream_base.h +++ b/chromium/net/websockets/websocket_handshake_stream_base.h @@ -9,6 +9,8 @@ // Since net/http can be built without linking net/websockets code, // this file must not introduce any link-time dependencies on websockets. +#include <string> + #include "base/basictypes.h" #include "base/memory/scoped_ptr.h" #include "base/memory/weak_ptr.h" diff --git a/chromium/net/websockets/websocket_handshake_stream_create_helper.cc b/chromium/net/websockets/websocket_handshake_stream_create_helper.cc index 8f1060c4868..e68052ef102 100644 --- a/chromium/net/websockets/websocket_handshake_stream_create_helper.cc +++ b/chromium/net/websockets/websocket_handshake_stream_create_helper.cc @@ -14,9 +14,14 @@ namespace net { WebSocketHandshakeStreamCreateHelper::WebSocketHandshakeStreamCreateHelper( + WebSocketStream::ConnectDelegate* connect_delegate, const std::vector<std::string>& requested_subprotocols) : requested_subprotocols_(requested_subprotocols), - stream_(NULL) {} + stream_(NULL), + connect_delegate_(connect_delegate), + failure_message_(NULL) { + DCHECK(connect_delegate_); +} WebSocketHandshakeStreamCreateHelper::~WebSocketHandshakeStreamCreateHelper() {} @@ -24,11 +29,18 @@ WebSocketHandshakeStreamBase* WebSocketHandshakeStreamCreateHelper::CreateBasicStream( scoped_ptr<ClientSocketHandle> connection, bool using_proxy) { - return stream_ = - new WebSocketBasicHandshakeStream(connection.Pass(), - using_proxy, - requested_subprotocols_, - std::vector<std::string>()); + DCHECK(failure_message_) << "set_failure_message() must be called"; + // The list of supported extensions and parameters is hard-coded. + // TODO(ricea): If more extensions are added, consider a more flexible + // method. + std::vector<std::string> extensions( + 1, "permessage-deflate; client_max_window_bits"); + return stream_ = new WebSocketBasicHandshakeStream(connection.Pass(), + connect_delegate_, + using_proxy, + requested_subprotocols_, + extensions, + failure_message_); } // TODO(ricea): Create a WebSocketSpdyHandshakeStream. crbug.com/323852 diff --git a/chromium/net/websockets/websocket_handshake_stream_create_helper.h b/chromium/net/websockets/websocket_handshake_stream_create_helper.h index 31be2313ff7..648f8fd23fe 100644 --- a/chromium/net/websockets/websocket_handshake_stream_create_helper.h +++ b/chromium/net/websockets/websocket_handshake_stream_create_helper.h @@ -10,6 +10,7 @@ #include "net/base/net_export.h" #include "net/websockets/websocket_handshake_stream_base.h" +#include "net/websockets/websocket_stream.h" namespace net { @@ -22,7 +23,9 @@ namespace net { class NET_EXPORT_PRIVATE WebSocketHandshakeStreamCreateHelper : public WebSocketHandshakeStreamBase::CreateHelper { public: + // |connect_delegate| must out-live this object. explicit WebSocketHandshakeStreamCreateHelper( + WebSocketStream::ConnectDelegate* connect_delegate, const std::vector<std::string>& requested_subprotocols); virtual ~WebSocketHandshakeStreamCreateHelper(); @@ -42,17 +45,30 @@ class NET_EXPORT_PRIVATE WebSocketHandshakeStreamCreateHelper // Return the WebSocketHandshakeStreamBase object that we created. In the case // where CreateBasicStream() was called more than once, returns the most // recent stream, which will be the one on which the handshake succeeded. + // It is not safe to call this if the handshake failed. WebSocketHandshakeStreamBase* stream() { return stream_; } + // Set a pointer to the std::string into which to write any failure messages + // that are encountered. This method must be called before CreateBasicStream() + // or CreateSpdyStream(). The |failure_message| pointer must remain valid as + // long as this object exists. + void set_failure_message(std::string* failure_message) { + failure_message_ = failure_message; + } + private: const std::vector<std::string> requested_subprotocols_; // This is owned by the caller of CreateBaseStream() or // CreateSpdyStream(). Both the stream and this object will be destroyed // during the destruction of the URLRequest object associated with the - // handshake. + // handshake. This is only guaranteed to be a valid pointer if the handshake + // succeeded. WebSocketHandshakeStreamBase* stream_; + WebSocketStream::ConnectDelegate* connect_delegate_; + std::string* failure_message_; + DISALLOW_COPY_AND_ASSIGN(WebSocketHandshakeStreamCreateHelper); }; diff --git a/chromium/net/websockets/websocket_handshake_stream_create_helper_test.cc b/chromium/net/websockets/websocket_handshake_stream_create_helper_test.cc index 7566edf6174..644679410b5 100644 --- a/chromium/net/websockets/websocket_handshake_stream_create_helper_test.cc +++ b/chromium/net/websockets/websocket_handshake_stream_create_helper_test.cc @@ -4,6 +4,9 @@ #include "net/websockets/websocket_handshake_stream_create_helper.h" +#include <string> +#include <vector> + #include "net/base/completion_callback.h" #include "net/base/net_errors.h" #include "net/http/http_request_headers.h" @@ -54,6 +57,23 @@ class MockClientSocketHandleFactory { DISALLOW_COPY_AND_ASSIGN(MockClientSocketHandleFactory); }; +class TestConnectDelegate : public WebSocketStream::ConnectDelegate { + public: + virtual ~TestConnectDelegate() {} + + virtual void OnSuccess(scoped_ptr<WebSocketStream> stream) OVERRIDE {} + virtual void OnFailure(const std::string& failure_message) OVERRIDE {} + virtual void OnStartOpeningHandshake( + scoped_ptr<WebSocketHandshakeRequestInfo> request) OVERRIDE {} + virtual void OnFinishOpeningHandshake( + scoped_ptr<WebSocketHandshakeResponseInfo> response) OVERRIDE {} + virtual void OnSSLCertificateError( + scoped_ptr<WebSocketEventInterface::SSLErrorCallbacks> + ssl_error_callbacks, + const SSLInfo& ssl_info, + bool fatal) OVERRIDE {} +}; + class WebSocketHandshakeStreamCreateHelperTest : public ::testing::Test { protected: scoped_ptr<WebSocketStream> CreateAndInitializeStream( @@ -63,7 +83,9 @@ class WebSocketHandshakeStreamCreateHelperTest : public ::testing::Test { const std::string& origin, const std::string& extra_request_headers, const std::string& extra_response_headers) { - WebSocketHandshakeStreamCreateHelper create_helper(sub_protocols); + WebSocketHandshakeStreamCreateHelper create_helper(&connect_delegate_, + sub_protocols); + create_helper.set_failure_message(&failure_message_); scoped_ptr<ClientSocketHandle> socket_handle = socket_handle_factory_.CreateClientSocketHandle( @@ -91,6 +113,8 @@ class WebSocketHandshakeStreamCreateHelperTest : public ::testing::Test { HttpRequestHeaders headers; headers.SetHeader("Host", "localhost"); headers.SetHeader("Connection", "Upgrade"); + headers.SetHeader("Pragma", "no-cache"); + headers.SetHeader("Cache-Control", "no-cache"); headers.SetHeader("Upgrade", "websocket"); headers.SetHeader("Origin", origin); headers.SetHeader("Sec-WebSocket-Version", "13"); @@ -114,6 +138,8 @@ class WebSocketHandshakeStreamCreateHelperTest : public ::testing::Test { } MockClientSocketHandleFactory socket_handle_factory_; + TestConnectDelegate connect_delegate_; + std::string failure_message_; }; // Confirm that the basic case works as expected. @@ -132,14 +158,47 @@ TEST_F(WebSocketHandshakeStreamCreateHelperTest, SubProtocols) { sub_protocols.push_back("chat"); sub_protocols.push_back("superchat"); scoped_ptr<WebSocketStream> stream = - CreateAndInitializeStream("ws://localhost/", "/", - sub_protocols, "http://localhost/", + CreateAndInitializeStream("ws://localhost/", + "/", + sub_protocols, + "http://localhost/", "Sec-WebSocket-Protocol: chat, superchat\r\n", "Sec-WebSocket-Protocol: superchat\r\n"); EXPECT_EQ("superchat", stream->GetSubProtocol()); } -// TODO(ricea): Test extensions once they are implemented. +// Verify that extension name is available. Bad extension names are tested in +// websocket_stream_test.cc. +TEST_F(WebSocketHandshakeStreamCreateHelperTest, Extensions) { + scoped_ptr<WebSocketStream> stream = CreateAndInitializeStream( + "ws://localhost/", + "/", + std::vector<std::string>(), + "http://localhost/", + "", + "Sec-WebSocket-Extensions: permessage-deflate\r\n"); + EXPECT_EQ("permessage-deflate", stream->GetExtensions()); +} + +// Verify that extension parameters are available. Bad parameters are tested in +// websocket_stream_test.cc. +TEST_F(WebSocketHandshakeStreamCreateHelperTest, ExtensionParameters) { + scoped_ptr<WebSocketStream> stream = CreateAndInitializeStream( + "ws://localhost/", + "/", + std::vector<std::string>(), + "http://localhost/", + "", + "Sec-WebSocket-Extensions: permessage-deflate;" + " client_max_window_bits=14; server_max_window_bits=14;" + " server_no_context_takeover; client_no_context_takeover\r\n"); + + EXPECT_EQ( + "permessage-deflate;" + " client_max_window_bits=14; server_max_window_bits=14;" + " server_no_context_takeover; client_no_context_takeover", + stream->GetExtensions()); +} } // namespace } // namespace net diff --git a/chromium/net/websockets/websocket_job.cc b/chromium/net/websockets/websocket_job.cc index b0f5be8bf3f..eb653fa5a97 100644 --- a/chromium/net/websockets/websocket_job.cc +++ b/chromium/net/websockets/websocket_job.cc @@ -36,9 +36,10 @@ const char* const kSetCookieHeaders[] = { }; net::SocketStreamJob* WebSocketJobFactory( - const GURL& url, net::SocketStream::Delegate* delegate) { + const GURL& url, net::SocketStream::Delegate* delegate, + net::URLRequestContext* context, net::CookieStore* cookie_store) { net::WebSocketJob* job = new net::WebSocketJob(delegate); - job->InitSocketStream(new net::SocketStream(url, job)); + job->InitSocketStream(new net::SocketStream(url, job, context, cookie_store)); return job; } @@ -58,18 +59,11 @@ static base::LazyInstance<WebSocketJobInitSingleton> g_websocket_job_init = namespace net { -bool WebSocketJob::websocket_over_spdy_enabled_ = false; - // static void WebSocketJob::EnsureInit() { g_websocket_job_init.Get(); } -// static -void WebSocketJob::set_websocket_over_spdy_enabled(bool enabled) { - websocket_over_spdy_enabled_ = enabled; -} - WebSocketJob::WebSocketJob(SocketStream::Delegate* delegate) : delegate_(delegate), state_(INITIALIZED), @@ -303,9 +297,10 @@ void WebSocketJob::OnSentSpdyHeaders() { DCHECK_NE(INITIALIZED, state_); if (state_ != CONNECTING) return; - if (delegate_) - delegate_->OnSentData(socket_.get(), handshake_request_->original_length()); + size_t original_length = handshake_request_->original_length(); handshake_request_.reset(); + if (delegate_) + delegate_->OnSentData(socket_.get(), original_length); } void WebSocketJob::OnSpdyResponseHeadersUpdated( @@ -370,11 +365,11 @@ void WebSocketJob::AddCookieHeaderAndSend() { if (socket_.get() && delegate_ && state_ == CONNECTING) { handshake_request_->RemoveHeaders(kCookieHeaders, arraysize(kCookieHeaders)); - if (allow && socket_->context()->cookie_store()) { + if (allow && socket_->cookie_store()) { // Add cookies, including HttpOnly cookies. CookieOptions cookie_options; cookie_options.set_include_httponly(); - socket_->context()->cookie_store()->GetCookiesWithOptionsAsync( + socket_->cookie_store()->GetCookiesWithOptionsAsync( GetURLForCookies(), cookie_options, base::Bind(&WebSocketJob::LoadCookieCallback, weak_ptr_factory_.GetWeakPtr())); @@ -387,7 +382,7 @@ void WebSocketJob::AddCookieHeaderAndSend() { void WebSocketJob::LoadCookieCallback(const std::string& cookie) { if (!cookie.empty()) // TODO(tyoshino): Sending cookie means that connection doesn't need - // kPrivacyModeEnabled as cookies may be server-bound and channel id + // PRIVACY_MODE_ENABLED as cookies may be server-bound and channel id // wouldn't negatively affect privacy anyway. Need to restart connection // or refactor to determine cookie status prior to connecting. handshake_request_->AppendHeaderIfMissing("Cookie", cookie); @@ -422,11 +417,12 @@ void WebSocketJob::OnSentHandshakeRequest( if (handshake_request_sent_ >= handshake_request_->raw_length()) { // handshake request has been sent. // notify original size of handshake request to delegate. - if (delegate_) - delegate_->OnSentData( - socket, - handshake_request_->original_length()); + // Reset the handshake_request_ first in case this object is deleted by the + // delegate. + size_t original_length = handshake_request_->original_length(); handshake_request_.reset(); + if (delegate_) + delegate_->OnSentData(socket, original_length); } } @@ -505,7 +501,7 @@ void WebSocketJob::SaveNextCookie() { callback_pending_ = false; save_next_cookie_running_ = true; - if (socket_->context()->cookie_store()) { + if (socket_->cookie_store()) { GURL url_for_cookies = GetURLForCookies(); CookieOptions options; @@ -526,7 +522,7 @@ void WebSocketJob::SaveNextCookie() { continue; callback_pending_ = true; - socket_->context()->cookie_store()->SetCookieWithOptionsAsync( + socket_->cookie_store()->SetCookieWithOptionsAsync( url_for_cookies, cookie, options, base::Bind(&WebSocketJob::OnCookieSaved, weak_ptr_factory_.GetWeakPtr())); @@ -563,9 +559,8 @@ void WebSocketJob::OnCookieSaved(bool cookie_status) { GURL WebSocketJob::GetURLForCookies() const { GURL url = socket_->url(); std::string scheme = socket_->is_secure() ? "https" : "http"; - url_canon::Replacements<char> replacements; - replacements.SetScheme(scheme.c_str(), - url_parse::Component(0, scheme.length())); + url::Replacements<char> replacements; + replacements.SetScheme(scheme.c_str(), url::Component(0, scheme.length())); return url.ReplaceComponents(replacements); } @@ -577,16 +572,13 @@ int WebSocketJob::TrySpdyStream() { if (!socket_.get()) return ERR_FAILED; - if (!websocket_over_spdy_enabled_) - return OK; - // Check if we have a SPDY session available. HttpTransactionFactory* factory = socket_->context()->http_transaction_factory(); if (!factory) return OK; scoped_refptr<HttpNetworkSession> session = factory->GetSession(); - if (!session.get()) + if (!session.get() || !session->params().enable_websocket_over_spdy) return OK; SpdySessionPool* spdy_pool = session->spdy_session_pool(); PrivacyMode privacy_mode = socket_->privacy_mode(); diff --git a/chromium/net/websockets/websocket_job.h b/chromium/net/websockets/websocket_job.h index 119c4dcfaa9..2e90a24d16c 100644 --- a/chromium/net/websockets/websocket_job.h +++ b/chromium/net/websockets/websocket_job.h @@ -49,10 +49,6 @@ class NET_EXPORT WebSocketJob static void EnsureInit(); - // Enable or Disable WebSocket over SPDY feature. - // This function is intended to be called before I/O thread starts. - static void set_websocket_over_spdy_enabled(bool enabled); - State state() const { return state_; } virtual void Connect() OVERRIDE; virtual bool SendData(const char* data, int len) OVERRIDE; @@ -124,8 +120,6 @@ class NET_EXPORT WebSocketJob void CloseInternal(); void SendPending(); - static bool websocket_over_spdy_enabled_; - SocketStream::Delegate* delegate_; State state_; bool waiting_; diff --git a/chromium/net/websockets/websocket_job_test.cc b/chromium/net/websockets/websocket_job_test.cc index bdbae709eb9..7b87a870d8d 100644 --- a/chromium/net/websockets/websocket_job_test.cc +++ b/chromium/net/websockets/websocket_job_test.cc @@ -41,8 +41,9 @@ namespace { class MockSocketStream : public SocketStream { public: - MockSocketStream(const GURL& url, SocketStream::Delegate* delegate) - : SocketStream(url, delegate) {} + MockSocketStream(const GURL& url, SocketStream::Delegate* delegate, + URLRequestContext* context, CookieStore* cookie_store) + : SocketStream(url, delegate, context, cookie_store) {} virtual void Connect() OVERRIDE {} virtual bool SendData(const char* data, int len) OVERRIDE { @@ -203,6 +204,12 @@ class MockCookieStore : public CookieStore { callback.Run(GetCookiesWithOptions(url, options)); } + virtual void GetAllCookiesForURLAsync( + const GURL& url, + const GetCookieListCallback& callback) OVERRIDE { + ADD_FAILURE(); + } + virtual void DeleteCookieAsync(const GURL& url, const std::string& cookie_name, const base::Closure& callback) OVERRIDE { @@ -216,6 +223,14 @@ class MockCookieStore : public CookieStore { ADD_FAILURE(); } + virtual void DeleteAllCreatedBetweenForHostAsync( + const base::Time delete_begin, + const base::Time delete_end, + const GURL& url, + const DeleteCallback& callback) OVERRIDE { + ADD_FAILURE(); + } + virtual void DeleteSessionCookiesAsync(const DeleteCallback&) OVERRIDE { ADD_FAILURE(); } @@ -259,11 +274,14 @@ class MockURLRequestContext : public URLRequestContext { class MockHttpTransactionFactory : public HttpTransactionFactory { public: - MockHttpTransactionFactory(NextProto next_proto, OrderedSocketData* data) { + MockHttpTransactionFactory(NextProto next_proto, + OrderedSocketData* data, + bool enable_websocket_over_spdy) { data_ = data; MockConnect connect_data(SYNCHRONOUS, OK); data_->set_connect_data(connect_data); session_deps_.reset(new SpdySessionDependencies(next_proto)); + session_deps_->enable_websocket_over_spdy = enable_websocket_over_spdy; session_deps_->socket_factory->AddSocketDataProvider(data_); http_session_ = SpdySessionDependencies::SpdyCreateSession(session_deps_.get()); @@ -271,15 +289,14 @@ class MockHttpTransactionFactory : public HttpTransactionFactory { host_port_pair_.set_port(80); spdy_session_key_ = SpdySessionKey(host_port_pair_, ProxyServer::Direct(), - kPrivacyModeDisabled); + PRIVACY_MODE_DISABLED); session_ = CreateInsecureSpdySession( http_session_, spdy_session_key_, BoundNetLog()); } virtual int CreateTransaction( RequestPriority priority, - scoped_ptr<HttpTransaction>* trans, - HttpTransactionDelegate* delegate) OVERRIDE { + scoped_ptr<HttpTransaction>* trans) OVERRIDE { NOTREACHED(); return ERR_UNEXPECTED; } @@ -302,12 +319,82 @@ class MockHttpTransactionFactory : public HttpTransactionFactory { SpdySessionKey spdy_session_key_; }; +class DeletingSocketStreamDelegate : public SocketStream::Delegate { + public: + DeletingSocketStreamDelegate() + : delete_next_(false) {} + + // Since this class needs to be able to delete |job_|, it must be the only + // reference holder (except for temporary references). Provide access to the + // pointer for tests to use. + WebSocketJob* job() { return job_.get(); } + + void set_job(WebSocketJob* job) { job_ = job; } + + // After calling this, the next call to a method on this delegate will delete + // the WebSocketJob object. + void set_delete_next(bool delete_next) { delete_next_ = delete_next; } + + void DeleteJobMaybe() { + if (delete_next_) { + job_->DetachContext(); + job_->DetachDelegate(); + job_ = NULL; + } + } + + // SocketStream::Delegate implementation + + // OnStartOpenConnection() is not implemented by SocketStreamDispatcherHost + + virtual void OnConnected(SocketStream* socket, + int max_pending_send_allowed) OVERRIDE { + DeleteJobMaybe(); + } + + virtual void OnSentData(SocketStream* socket, int amount_sent) OVERRIDE { + DeleteJobMaybe(); + } + + virtual void OnReceivedData(SocketStream* socket, + const char* data, + int len) OVERRIDE { + DeleteJobMaybe(); + } + + virtual void OnClose(SocketStream* socket) OVERRIDE { DeleteJobMaybe(); } + + virtual void OnAuthRequired(SocketStream* socket, + AuthChallengeInfo* auth_info) OVERRIDE { + DeleteJobMaybe(); + } + + virtual void OnSSLCertificateError(SocketStream* socket, + const SSLInfo& ssl_info, + bool fatal) OVERRIDE { + DeleteJobMaybe(); + } + + virtual void OnError(const SocketStream* socket, int error) OVERRIDE { + DeleteJobMaybe(); + } + + // CanGetCookies() and CanSetCookies() do not appear to be able to delete the + // WebSocketJob object. + + private: + scoped_refptr<WebSocketJob> job_; + bool delete_next_; +}; + } // namespace class WebSocketJobTest : public PlatformTest, public ::testing::WithParamInterface<NextProto> { public: - WebSocketJobTest() : spdy_util_(GetParam()) {} + WebSocketJobTest() + : spdy_util_(GetParam()), + enable_websocket_over_spdy_(false) {} virtual void SetUp() OVERRIDE { stream_type_ = STREAM_INVALID; @@ -334,6 +421,7 @@ class WebSocketJobTest : public PlatformTest, int WaitForResult() { return sync_test_callback_.WaitForResult(); } + protected: enum StreamType { STREAM_INVALID, @@ -357,12 +445,13 @@ class WebSocketJobTest : public PlatformTest, websocket_ = new WebSocketJob(delegate); if (stream_type == STREAM_MOCK_SOCKET) - socket_ = new MockSocketStream(url, websocket_.get()); + socket_ = new MockSocketStream(url, websocket_.get(), context_.get(), + NULL); if (stream_type == STREAM_SOCKET || stream_type == STREAM_SPDY_WEBSOCKET) { if (stream_type == STREAM_SPDY_WEBSOCKET) { - http_factory_.reset( - new MockHttpTransactionFactory(GetParam(), data_.get())); + http_factory_.reset(new MockHttpTransactionFactory( + GetParam(), data_.get(), enable_websocket_over_spdy_)); context_->set_http_transaction_factory(http_factory_.get()); } @@ -373,7 +462,7 @@ class WebSocketJobTest : public PlatformTest, host_resolver_.reset(new MockHostResolver); context_->set_host_resolver(host_resolver_.get()); - socket_ = new SocketStream(url, websocket_.get()); + socket_ = new SocketStream(url, websocket_.get(), context_.get(), NULL); socket_factory_.reset(new MockClientSocketFactory); DCHECK(data_.get()); socket_factory_->AddSocketDataProvider(data_.get()); @@ -381,7 +470,6 @@ class WebSocketJobTest : public PlatformTest, } websocket_->InitSocketStream(socket_.get()); - websocket_->set_context(context_.get()); // MockHostResolver resolves all hosts to 127.0.0.1; however, when we create // a WebSocketJob purely to block another one in a throttling test, we don't // perform a real connect. In that case, the following address is used @@ -448,6 +536,9 @@ class WebSocketJobTest : public PlatformTest, scoped_ptr<MockHostResolver> host_resolver_; scoped_ptr<MockHttpTransactionFactory> http_factory_; + // Must be set before call to enable_websocket_over_spdy, defaults to false. + bool enable_websocket_over_spdy_; + static const char kHandshakeRequestWithoutCookie[]; static const char kHandshakeRequestWithCookie[]; static const char kHandshakeRequestWithFilteredCookie[]; @@ -466,6 +557,34 @@ class WebSocketJobTest : public PlatformTest, static const size_t kDataWorldLength; }; +// Tests using this fixture verify that the WebSocketJob can handle being +// deleted while calling back to the delegate correctly. These tests need to be +// run under AddressSanitizer or other systems for detecting use-after-free +// errors in order to find problems. +class WebSocketJobDeleteTest : public ::testing::Test { + protected: + WebSocketJobDeleteTest() + : delegate_(new DeletingSocketStreamDelegate), + cookie_store_(new MockCookieStore), + context_(new MockURLRequestContext(cookie_store_.get())) { + WebSocketJob* websocket = new WebSocketJob(delegate_.get()); + delegate_->set_job(websocket); + + socket_ = new MockSocketStream( + GURL("ws://127.0.0.1/"), websocket, context_.get(), NULL); + + websocket->InitSocketStream(socket_.get()); + } + + void SetDeleteNext() { return delegate_->set_delete_next(true); } + WebSocketJob* job() { return delegate_->job(); } + + scoped_ptr<DeletingSocketStreamDelegate> delegate_; + scoped_refptr<MockCookieStore> cookie_store_; + scoped_ptr<MockURLRequestContext> context_; + scoped_refptr<SocketStream> socket_; +}; + const char WebSocketJobTest::kHandshakeRequestWithoutCookie[] = "GET /demo HTTP/1.1\r\n" "Host: example.com\r\n" @@ -598,11 +717,10 @@ INSTANTIATE_TEST_CASE_P( NextProto, WebSocketJobTest, testing::Values(kProtoDeprecatedSPDY2, - kProtoSPDY3, kProtoSPDY31, kProtoSPDY4a2, - kProtoHTTP2Draft04)); + kProtoSPDY3, kProtoSPDY31, kProtoSPDY4)); TEST_P(WebSocketJobTest, DelayedCookies) { - WebSocketJob::set_websocket_over_spdy_enabled(true); + enable_websocket_over_spdy_ = true; GURL url("ws://example.com/demo"); GURL cookieUrl("http://example.com/demo"); CookieOptions cookie_options; @@ -731,14 +849,14 @@ void WebSocketJobTest::TestHSTSUpgrade() { scoped_refptr<SocketStreamJob> job = SocketStreamJob::CreateSocketStreamJob( url, &delegate, context_->transport_security_state(), - context_->ssl_config_service()); + context_->ssl_config_service(), NULL, NULL); EXPECT_TRUE(GetSocket(job.get())->is_secure()); job->DetachDelegate(); url = GURL("ws://donotupgrademe.com/"); job = SocketStreamJob::CreateSocketStreamJob( url, &delegate, context_->transport_security_state(), - context_->ssl_config_service()); + context_->ssl_config_service(), NULL, NULL); EXPECT_FALSE(GetSocket(job.get())->is_secure()); job->DetachDelegate(); } @@ -1005,87 +1123,79 @@ void WebSocketJobTest::TestThrottlingLimit() { // Execute tests in both spdy-disabled mode and spdy-enabled mode. TEST_P(WebSocketJobTest, SimpleHandshake) { - WebSocketJob::set_websocket_over_spdy_enabled(false); TestSimpleHandshake(); } TEST_P(WebSocketJobTest, SlowHandshake) { - WebSocketJob::set_websocket_over_spdy_enabled(false); TestSlowHandshake(); } TEST_P(WebSocketJobTest, HandshakeWithCookie) { - WebSocketJob::set_websocket_over_spdy_enabled(false); TestHandshakeWithCookie(); } TEST_P(WebSocketJobTest, HandshakeWithCookieButNotAllowed) { - WebSocketJob::set_websocket_over_spdy_enabled(false); TestHandshakeWithCookieButNotAllowed(); } TEST_P(WebSocketJobTest, HSTSUpgrade) { - WebSocketJob::set_websocket_over_spdy_enabled(false); TestHSTSUpgrade(); } TEST_P(WebSocketJobTest, InvalidSendData) { - WebSocketJob::set_websocket_over_spdy_enabled(false); TestInvalidSendData(); } TEST_P(WebSocketJobTest, SimpleHandshakeSpdyEnabled) { - WebSocketJob::set_websocket_over_spdy_enabled(true); + enable_websocket_over_spdy_ = true; TestSimpleHandshake(); } TEST_P(WebSocketJobTest, SlowHandshakeSpdyEnabled) { - WebSocketJob::set_websocket_over_spdy_enabled(true); + enable_websocket_over_spdy_ = true; TestSlowHandshake(); } TEST_P(WebSocketJobTest, HandshakeWithCookieSpdyEnabled) { - WebSocketJob::set_websocket_over_spdy_enabled(true); + enable_websocket_over_spdy_ = true; TestHandshakeWithCookie(); } TEST_P(WebSocketJobTest, HandshakeWithCookieButNotAllowedSpdyEnabled) { - WebSocketJob::set_websocket_over_spdy_enabled(true); + enable_websocket_over_spdy_ = true; TestHandshakeWithCookieButNotAllowed(); } TEST_P(WebSocketJobTest, HSTSUpgradeSpdyEnabled) { - WebSocketJob::set_websocket_over_spdy_enabled(true); + enable_websocket_over_spdy_ = true; TestHSTSUpgrade(); } TEST_P(WebSocketJobTest, InvalidSendDataSpdyEnabled) { - WebSocketJob::set_websocket_over_spdy_enabled(true); + enable_websocket_over_spdy_ = true; TestInvalidSendData(); } TEST_P(WebSocketJobTest, ConnectByWebSocket) { - WebSocketJob::set_websocket_over_spdy_enabled(false); + enable_websocket_over_spdy_ = true; TestConnectByWebSocket(THROTTLING_OFF); } TEST_P(WebSocketJobTest, ConnectByWebSocketSpdyEnabled) { - WebSocketJob::set_websocket_over_spdy_enabled(true); + enable_websocket_over_spdy_ = true; TestConnectByWebSocket(THROTTLING_OFF); } TEST_P(WebSocketJobTest, ConnectBySpdy) { - WebSocketJob::set_websocket_over_spdy_enabled(false); TestConnectBySpdy(SPDY_OFF, THROTTLING_OFF); } TEST_P(WebSocketJobTest, ConnectBySpdySpdyEnabled) { - WebSocketJob::set_websocket_over_spdy_enabled(true); + enable_websocket_over_spdy_ = true; TestConnectBySpdy(SPDY_ON, THROTTLING_OFF); } TEST_P(WebSocketJobTest, ThrottlingWebSocket) { - WebSocketJob::set_websocket_over_spdy_enabled(false); TestConnectByWebSocket(THROTTLING_ON); } @@ -1094,20 +1204,89 @@ TEST_P(WebSocketJobTest, ThrottlingMaxNumberOfThrottledJobLimit) { } TEST_P(WebSocketJobTest, ThrottlingWebSocketSpdyEnabled) { - WebSocketJob::set_websocket_over_spdy_enabled(true); + enable_websocket_over_spdy_ = true; TestConnectByWebSocket(THROTTLING_ON); } TEST_P(WebSocketJobTest, ThrottlingSpdy) { - WebSocketJob::set_websocket_over_spdy_enabled(false); TestConnectBySpdy(SPDY_OFF, THROTTLING_ON); } TEST_P(WebSocketJobTest, ThrottlingSpdySpdyEnabled) { - WebSocketJob::set_websocket_over_spdy_enabled(true); + enable_websocket_over_spdy_ = true; TestConnectBySpdy(SPDY_ON, THROTTLING_ON); } +TEST_F(WebSocketJobDeleteTest, OnClose) { + SetDeleteNext(); + job()->OnClose(socket_.get()); + // OnClose() sets WebSocketJob::_socket to NULL before we can detach it, so + // socket_->delegate is still set at this point. Clear it to avoid hitting + // DCHECK(!delegate_) in the SocketStream destructor. SocketStream::Finish() + // is the only caller of this method in real code, and it also sets delegate_ + // to NULL. + socket_->DetachDelegate(); + EXPECT_FALSE(job()); +} + +TEST_F(WebSocketJobDeleteTest, OnAuthRequired) { + SetDeleteNext(); + job()->OnAuthRequired(socket_.get(), NULL); + EXPECT_FALSE(job()); +} + +TEST_F(WebSocketJobDeleteTest, OnSSLCertificateError) { + SSLInfo ssl_info; + SetDeleteNext(); + job()->OnSSLCertificateError(socket_.get(), ssl_info, true); + EXPECT_FALSE(job()); +} + +TEST_F(WebSocketJobDeleteTest, OnError) { + SetDeleteNext(); + job()->OnError(socket_.get(), ERR_CONNECTION_RESET); + EXPECT_FALSE(job()); +} + +TEST_F(WebSocketJobDeleteTest, OnSentSpdyHeaders) { + job()->Connect(); + SetDeleteNext(); + job()->OnSentSpdyHeaders(); + EXPECT_FALSE(job()); +} + +TEST_F(WebSocketJobDeleteTest, OnSentHandshakeRequest) { + static const char kMinimalRequest[] = + "GET /demo HTTP/1.1\r\n" + "Host: example.com\r\n" + "Upgrade: WebSocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" + "Origin: http://example.com\r\n" + "Sec-WebSocket-Version: 13\r\n" + "\r\n"; + const size_t kMinimalRequestSize = arraysize(kMinimalRequest) - 1; + job()->Connect(); + job()->SendData(kMinimalRequest, kMinimalRequestSize); + SetDeleteNext(); + job()->OnSentData(socket_.get(), kMinimalRequestSize); + EXPECT_FALSE(job()); +} + +TEST_F(WebSocketJobDeleteTest, NotifyHeadersComplete) { + static const char kMinimalResponse[] = + "HTTP/1.1 101 Switching Protocols\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n" + "\r\n"; + job()->Connect(); + SetDeleteNext(); + job()->OnReceivedData( + socket_.get(), kMinimalResponse, arraysize(kMinimalResponse) - 1); + EXPECT_FALSE(job()); +} + // TODO(toyoshim): Add tests to verify throttling, SPDY stream limitation. // TODO(toyoshim,yutak): Add tests to verify closing handshake. } // namespace net diff --git a/chromium/net/websockets/websocket_net_log_params_test.cc b/chromium/net/websockets/websocket_net_log_params_test.cc index 4690fd66964..d6d2a0d3ff5 100644 --- a/chromium/net/websockets/websocket_net_log_params_test.cc +++ b/chromium/net/websockets/websocket_net_log_params_test.cc @@ -44,7 +44,7 @@ TEST(NetLogWebSocketHandshakeParameterTest, ToValue) { scoped_ptr<base::Value> actual( net::NetLogWebSocketHandshakeCallback(&testInput, - net::NetLog::LOG_BASIC)); + net::NetLog::LOG_ALL)); EXPECT_TRUE(expected.Equals(actual.get())); } diff --git a/chromium/net/websockets/websocket_stream.cc b/chromium/net/websockets/websocket_stream.cc index e81c24e706e..8ddce8d9b72 100644 --- a/chromium/net/websockets/websocket_stream.cc +++ b/chromium/net/websockets/websocket_stream.cc @@ -6,16 +6,21 @@ #include "base/logging.h" #include "base/memory/scoped_ptr.h" +#include "base/metrics/histogram.h" +#include "base/metrics/sparse_histogram.h" +#include "net/base/load_flags.h" #include "net/http/http_request_headers.h" #include "net/http/http_status_code.h" #include "net/url_request/url_request.h" #include "net/url_request/url_request_context.h" #include "net/websockets/websocket_errors.h" +#include "net/websockets/websocket_event_interface.h" #include "net/websockets/websocket_handshake_constants.h" #include "net/websockets/websocket_handshake_stream_base.h" #include "net/websockets/websocket_handshake_stream_create_helper.h" #include "net/websockets/websocket_test_util.h" #include "url/gurl.h" +#include "url/origin.h" namespace net { namespace { @@ -24,10 +29,32 @@ class StreamRequestImpl; class Delegate : public URLRequest::Delegate { public: - explicit Delegate(StreamRequestImpl* owner) : owner_(owner) {} - virtual ~Delegate() {} + enum HandshakeResult { + INCOMPLETE, + CONNECTED, + FAILED, + NUM_HANDSHAKE_RESULT_TYPES, + }; + + explicit Delegate(StreamRequestImpl* owner) + : owner_(owner), result_(INCOMPLETE) {} + virtual ~Delegate() { + UMA_HISTOGRAM_ENUMERATION( + "Net.WebSocket.HandshakeResult", result_, NUM_HANDSHAKE_RESULT_TYPES); + } // Implementation of URLRequest::Delegate methods. + virtual void OnReceivedRedirect(URLRequest* request, + const GURL& new_url, + bool* defer_redirect) OVERRIDE { + // HTTP status codes returned by HttpStreamParser are filtered by + // WebSocketBasicHandshakeStream, and only 101, 401 and 407 are permitted + // back up the stack to HttpNetworkTransaction. In particular, redirect + // codes are never allowed, and so URLRequest never sees a redirect on a + // WebSocket request. + NOTREACHED(); + } + virtual void OnResponseStarted(URLRequest* request) OVERRIDE; virtual void OnAuthRequired(URLRequest* request, @@ -45,6 +72,7 @@ class Delegate : public URLRequest::Delegate { private: StreamRequestImpl* owner_; + HandshakeResult result_; }; class StreamRequestImpl : public WebSocketStreamRequest { @@ -52,25 +80,64 @@ class StreamRequestImpl : public WebSocketStreamRequest { StreamRequestImpl( const GURL& url, const URLRequestContext* context, + const url::Origin& origin, scoped_ptr<WebSocketStream::ConnectDelegate> connect_delegate, - WebSocketHandshakeStreamCreateHelper* create_helper) + scoped_ptr<WebSocketHandshakeStreamCreateHelper> create_helper) : delegate_(new Delegate(this)), url_request_(url, DEFAULT_PRIORITY, delegate_.get(), context), connect_delegate_(connect_delegate.Pass()), - create_helper_(create_helper) {} + create_helper_(create_helper.release()) { + create_helper_->set_failure_message(&failure_message_); + HttpRequestHeaders headers; + headers.SetHeader(websockets::kUpgrade, websockets::kWebSocketLowercase); + headers.SetHeader(HttpRequestHeaders::kConnection, websockets::kUpgrade); + headers.SetHeader(HttpRequestHeaders::kOrigin, origin.string()); + headers.SetHeader(websockets::kSecWebSocketVersion, + websockets::kSupportedVersion); + url_request_.SetExtraRequestHeaders(headers); + + // This passes the ownership of |create_helper_| to |url_request_|. + url_request_.SetUserData( + WebSocketHandshakeStreamBase::CreateHelper::DataKey(), + create_helper_); + url_request_.SetLoadFlags(LOAD_DISABLE_CACHE | + LOAD_BYPASS_CACHE | + LOAD_DO_NOT_PROMPT_FOR_LOGIN); + } // Destroying this object destroys the URLRequest, which cancels the request // and so terminates the handshake if it is incomplete. virtual ~StreamRequestImpl() {} - URLRequest* url_request() { return &url_request_; } + void Start() { + url_request_.Start(); + } void PerformUpgrade() { connect_delegate_->OnSuccess(create_helper_->stream()->Upgrade()); } void ReportFailure() { - connect_delegate_->OnFailure(kWebSocketErrorAbnormalClosure); + if (failure_message_.empty()) { + switch (url_request_.status().status()) { + case URLRequestStatus::SUCCESS: + case URLRequestStatus::IO_PENDING: + break; + case URLRequestStatus::CANCELED: + failure_message_ = "WebSocket opening handshake was canceled"; + break; + case URLRequestStatus::FAILED: + failure_message_ = + std::string("Error in connection establishment: ") + + ErrorToString(url_request_.status().error()); + break; + } + } + connect_delegate_->OnFailure(failure_message_); + } + + WebSocketStream::ConnectDelegate* connect_delegate() const { + return connect_delegate_.get(); } private: @@ -86,11 +153,47 @@ class StreamRequestImpl : public WebSocketStreamRequest { // Owned by the URLRequest. WebSocketHandshakeStreamCreateHelper* create_helper_; + + // The failure message supplied by WebSocketBasicHandshakeStream, if any. + std::string failure_message_; +}; + +class SSLErrorCallbacks : public WebSocketEventInterface::SSLErrorCallbacks { + public: + explicit SSLErrorCallbacks(URLRequest* url_request) + : url_request_(url_request) {} + + virtual void CancelSSLRequest(int error, const SSLInfo* ssl_info) OVERRIDE { + if (ssl_info) { + url_request_->CancelWithSSLError(error, *ssl_info); + } else { + url_request_->CancelWithError(error); + } + } + + virtual void ContinueSSLRequest() OVERRIDE { + url_request_->ContinueDespiteLastError(); + } + + private: + URLRequest* url_request_; }; void Delegate::OnResponseStarted(URLRequest* request) { - switch (request->GetResponseCode()) { + // All error codes, including OK and ABORTED, as with + // Net.ErrorCodesForMainFrame3 + UMA_HISTOGRAM_SPARSE_SLOWLY("Net.WebSocket.ErrorCodes", + -request->status().error()); + if (!request->status().is_success()) { + DVLOG(3) << "OnResponseStarted (request failed)"; + owner_->ReportFailure(); + return; + } + const int response_code = request->GetResponseCode(); + DVLOG(3) << "OnResponseStarted (response code " << response_code << ")"; + switch (response_code) { case HTTP_SWITCHING_PROTOCOLS: + result_ = CONNECTED; owner_->PerformUpgrade(); return; @@ -99,61 +202,42 @@ void Delegate::OnResponseStarted(URLRequest* request) { return; default: + result_ = FAILED; owner_->ReportFailure(); } } void Delegate::OnAuthRequired(URLRequest* request, AuthChallengeInfo* auth_info) { + // This should only be called if credentials are not already stored. request->CancelAuth(); } void Delegate::OnCertificateRequested(URLRequest* request, SSLCertRequestInfo* cert_request_info) { - request->ContinueWithCertificate(NULL); + // This method is called when a client certificate is requested, and the + // request context does not already contain a client certificate selection for + // the endpoint. In this case, a main frame resource request would pop-up UI + // to permit selection of a client certificate, but since WebSockets are + // sub-resources they should not pop-up UI and so there is nothing more we can + // do. + request->Cancel(); } void Delegate::OnSSLCertificateError(URLRequest* request, const SSLInfo& ssl_info, bool fatal) { - request->Cancel(); + owner_->connect_delegate()->OnSSLCertificateError( + scoped_ptr<WebSocketEventInterface::SSLErrorCallbacks>( + new SSLErrorCallbacks(request)), + ssl_info, + fatal); } void Delegate::OnReadCompleted(URLRequest* request, int bytes_read) { NOTREACHED(); } -// Internal implementation of CreateAndConnectStream and -// CreateAndConnectStreamForTesting. -scoped_ptr<WebSocketStreamRequest> CreateAndConnectStreamWithCreateHelper( - const GURL& socket_url, - scoped_ptr<WebSocketHandshakeStreamCreateHelper> create_helper, - const GURL& origin, - URLRequestContext* url_request_context, - const BoundNetLog& net_log, - scoped_ptr<WebSocketStream::ConnectDelegate> connect_delegate) { - scoped_ptr<StreamRequestImpl> request( - new StreamRequestImpl(socket_url, - url_request_context, - connect_delegate.Pass(), - create_helper.get())); - HttpRequestHeaders headers; - headers.SetHeader(websockets::kUpgrade, websockets::kWebSocketLowercase); - headers.SetHeader(HttpRequestHeaders::kConnection, websockets::kUpgrade); - headers.SetHeader(HttpRequestHeaders::kOrigin, origin.spec()); - // TODO(ricea): Move the version number to websocket_handshake_constants.h - headers.SetHeader(websockets::kSecWebSocketVersion, - websockets::kSupportedVersion); - request->url_request()->SetExtraRequestHeaders(headers); - request->url_request()->SetUserData( - WebSocketHandshakeStreamBase::CreateHelper::DataKey(), - create_helper.release()); - request->url_request()->SetLoadFlags(LOAD_DISABLE_CACHE | - LOAD_DO_NOT_PROMPT_FOR_LOGIN); - request->url_request()->Start(); - return request.PassAs<WebSocketStreamRequest>(); -} - } // namespace WebSocketStreamRequest::~WebSocketStreamRequest() {} @@ -166,34 +250,39 @@ WebSocketStream::ConnectDelegate::~ConnectDelegate() {} scoped_ptr<WebSocketStreamRequest> WebSocketStream::CreateAndConnectStream( const GURL& socket_url, const std::vector<std::string>& requested_subprotocols, - const GURL& origin, + const url::Origin& origin, URLRequestContext* url_request_context, const BoundNetLog& net_log, scoped_ptr<ConnectDelegate> connect_delegate) { scoped_ptr<WebSocketHandshakeStreamCreateHelper> create_helper( - new WebSocketHandshakeStreamCreateHelper(requested_subprotocols)); - return CreateAndConnectStreamWithCreateHelper(socket_url, - create_helper.Pass(), - origin, - url_request_context, - net_log, - connect_delegate.Pass()); + new WebSocketHandshakeStreamCreateHelper(connect_delegate.get(), + requested_subprotocols)); + scoped_ptr<StreamRequestImpl> request( + new StreamRequestImpl(socket_url, + url_request_context, + origin, + connect_delegate.Pass(), + create_helper.Pass())); + request->Start(); + return request.PassAs<WebSocketStreamRequest>(); } // This is declared in websocket_test_util.h. scoped_ptr<WebSocketStreamRequest> CreateAndConnectStreamForTesting( - const GURL& socket_url, - scoped_ptr<WebSocketHandshakeStreamCreateHelper> create_helper, - const GURL& origin, - URLRequestContext* url_request_context, - const BoundNetLog& net_log, - scoped_ptr<WebSocketStream::ConnectDelegate> connect_delegate) { - return CreateAndConnectStreamWithCreateHelper(socket_url, - create_helper.Pass(), - origin, - url_request_context, - net_log, - connect_delegate.Pass()); + const GURL& socket_url, + scoped_ptr<WebSocketHandshakeStreamCreateHelper> create_helper, + const url::Origin& origin, + URLRequestContext* url_request_context, + const BoundNetLog& net_log, + scoped_ptr<WebSocketStream::ConnectDelegate> connect_delegate) { + scoped_ptr<StreamRequestImpl> request( + new StreamRequestImpl(socket_url, + url_request_context, + origin, + connect_delegate.Pass(), + create_helper.Pass())); + request->Start(); + return request.PassAs<WebSocketStreamRequest>(); } } // namespace net diff --git a/chromium/net/websockets/websocket_stream.h b/chromium/net/websockets/websocket_stream.h index c08f8dc39b7..09f11b22f1a 100644 --- a/chromium/net/websockets/websocket_stream.h +++ b/chromium/net/websockets/websocket_stream.h @@ -14,9 +14,16 @@ #include "base/memory/scoped_vector.h" #include "net/base/completion_callback.h" #include "net/base/net_export.h" +#include "net/websockets/websocket_event_interface.h" +#include "net/websockets/websocket_handshake_request_info.h" +#include "net/websockets/websocket_handshake_response_info.h" class GURL; +namespace url { +class Origin; +} // namespace url + namespace net { class BoundNetLog; @@ -57,10 +64,26 @@ class NET_EXPORT_PRIVATE WebSocketStream { // WebSocketStream. virtual void OnSuccess(scoped_ptr<WebSocketStream> stream) = 0; - // Called on failure to connect. The parameter is either one of the values - // defined in net::WebSocketError, or an error defined by some WebSocket - // extension protocol that we implement. - virtual void OnFailure(unsigned short websocket_error) = 0; + // Called on failure to connect. + // |message| contains defails of the failure. + virtual void OnFailure(const std::string& message) = 0; + + // Called when the WebSocket Opening Handshake starts. + virtual void OnStartOpeningHandshake( + scoped_ptr<WebSocketHandshakeRequestInfo> request) = 0; + + // Called when the WebSocket Opening Handshake ends. + virtual void OnFinishOpeningHandshake( + scoped_ptr<WebSocketHandshakeResponseInfo> response) = 0; + + // Called when there is an SSL certificate error. Should call + // ssl_error_callbacks->ContinueSSLRequest() or + // ssl_error_callbacks->CancelSSLRequest(). + virtual void OnSSLCertificateError( + scoped_ptr<WebSocketEventInterface::SSLErrorCallbacks> + ssl_error_callbacks, + const SSLInfo& ssl_info, + bool fatal) = 0; }; // Create and connect a WebSocketStream of an appropriate type. The actual @@ -78,7 +101,7 @@ class NET_EXPORT_PRIVATE WebSocketStream { static scoped_ptr<WebSocketStreamRequest> CreateAndConnectStream( const GURL& socket_url, const std::vector<std::string>& requested_subprotocols, - const GURL& origin, + const url::Origin& origin, URLRequestContext* url_request_context, const BoundNetLog& net_log, scoped_ptr<ConnectDelegate> connect_delegate); @@ -125,6 +148,10 @@ class NET_EXPORT_PRIVATE WebSocketStream { // calling callback.Run() (and any calling methods in the same object) must // return immediately without any further method calls or access to member // variables. Implementors should write test(s) for this case. + // + // Extensions which use reserved header bits should clear them when they are + // set correctly. If the reserved header bits are set incorrectly, it is okay + // to leave it to the caller to report the error. virtual int ReadFrames(ScopedVector<WebSocketFrame>* frames, const CompletionCallback& callback) = 0; diff --git a/chromium/net/websockets/websocket_stream_test.cc b/chromium/net/websockets/websocket_stream_test.cc index 3e11a95ac1c..0a6b99be4d6 100644 --- a/chromium/net/websockets/websocket_stream_test.cc +++ b/chromium/net/websockets/websocket_stream_test.cc @@ -4,31 +4,67 @@ #include "net/websockets/websocket_stream.h" +#include <algorithm> #include <string> +#include <utility> #include <vector> +#include "base/memory/scoped_vector.h" +#include "base/metrics/histogram.h" +#include "base/metrics/histogram_samples.h" +#include "base/metrics/statistics_recorder.h" #include "base/run_loop.h" +#include "base/strings/stringprintf.h" #include "net/base/net_errors.h" +#include "net/base/test_data_directory.h" +#include "net/http/http_request_headers.h" +#include "net/http/http_response_headers.h" #include "net/socket/client_socket_handle.h" #include "net/socket/socket_test_util.h" +#include "net/test/cert_test_util.h" #include "net/url_request/url_request_test_util.h" #include "net/websockets/websocket_basic_handshake_stream.h" +#include "net/websockets/websocket_frame.h" +#include "net/websockets/websocket_handshake_request_info.h" +#include "net/websockets/websocket_handshake_response_info.h" #include "net/websockets/websocket_handshake_stream_create_helper.h" #include "net/websockets/websocket_test_util.h" #include "testing/gtest/include/gtest/gtest.h" #include "url/gurl.h" +#include "url/origin.h" namespace net { namespace { +typedef std::pair<std::string, std::string> HeaderKeyValuePair; + +std::vector<HeaderKeyValuePair> ToVector(const HttpRequestHeaders& headers) { + HttpRequestHeaders::Iterator it(headers); + std::vector<HeaderKeyValuePair> result; + while (it.GetNext()) + result.push_back(HeaderKeyValuePair(it.name(), it.value())); + return result; +} + +std::vector<HeaderKeyValuePair> ToVector(const HttpResponseHeaders& headers) { + void* iter = NULL; + std::string name, value; + std::vector<HeaderKeyValuePair> result; + while (headers.EnumerateHeaderLines(&iter, &name, &value)) + result.push_back(HeaderKeyValuePair(name, value)); + return result; +} + // A sub-class of WebSocketHandshakeStreamCreateHelper which always sets a // deterministic key to use in the WebSocket handshake. class DeterministicKeyWebSocketHandshakeStreamCreateHelper : public WebSocketHandshakeStreamCreateHelper { public: DeterministicKeyWebSocketHandshakeStreamCreateHelper( + WebSocketStream::ConnectDelegate* connect_delegate, const std::vector<std::string>& requested_subprotocols) - : WebSocketHandshakeStreamCreateHelper(requested_subprotocols) {} + : WebSocketHandshakeStreamCreateHelper(connect_delegate, + requested_subprotocols) {} virtual WebSocketHandshakeStreamBase* CreateBasicStream( scoped_ptr<ClientSocketHandle> connection, @@ -44,8 +80,8 @@ class DeterministicKeyWebSocketHandshakeStreamCreateHelper }; class WebSocketStreamCreateTest : public ::testing::Test { - protected: - WebSocketStreamCreateTest() : websocket_error_(0) {} + public: + WebSocketStreamCreateTest() : has_failed_(false), ssl_fatal_(false) {} void CreateAndConnectCustomResponse( const std::string& socket_url, @@ -82,7 +118,7 @@ class WebSocketStreamCreateTest : public ::testing::Test { const std::vector<std::string>& sub_protocols, const std::string& origin, scoped_ptr<DeterministicSocketData> socket_data) { - url_request_context_host_.SetRawExpectations(socket_data.Pass()); + url_request_context_host_.AddRawExpectations(socket_data.Pass()); CreateAndConnectStream(socket_url, sub_protocols, origin); } @@ -91,16 +127,25 @@ class WebSocketStreamCreateTest : public ::testing::Test { void CreateAndConnectStream(const std::string& socket_url, const std::vector<std::string>& sub_protocols, const std::string& origin) { + for (size_t i = 0; i < ssl_data_.size(); ++i) { + scoped_ptr<SSLSocketDataProvider> ssl_data(ssl_data_[i]); + ssl_data_[i] = NULL; + url_request_context_host_.AddSSLSocketDataProvider(ssl_data.Pass()); + } + ssl_data_.clear(); + scoped_ptr<WebSocketStream::ConnectDelegate> connect_delegate( + new TestConnectDelegate(this)); + WebSocketStream::ConnectDelegate* delegate = connect_delegate.get(); + scoped_ptr<WebSocketHandshakeStreamCreateHelper> create_helper( + new DeterministicKeyWebSocketHandshakeStreamCreateHelper( + delegate, sub_protocols)); stream_request_ = ::net::CreateAndConnectStreamForTesting( GURL(socket_url), - scoped_ptr<WebSocketHandshakeStreamCreateHelper>( - new DeterministicKeyWebSocketHandshakeStreamCreateHelper( - sub_protocols)), - GURL(origin), + create_helper.Pass(), + url::Origin(origin), url_request_context_host_.GetURLRequestContext(), BoundNetLog(), - scoped_ptr<WebSocketStream::ConnectDelegate>( - new TestConnectDelegate(this))); + connect_delegate.Pass()); } static void RunUntilIdle() { base::RunLoop().RunUntilIdle(); } @@ -110,18 +155,43 @@ class WebSocketStreamCreateTest : public ::testing::Test { return std::vector<std::string>(); } - uint16 error() const { return websocket_error_; } + const std::string& failure_message() const { return failure_message_; } + bool has_failed() const { return has_failed_; } class TestConnectDelegate : public WebSocketStream::ConnectDelegate { public: - TestConnectDelegate(WebSocketStreamCreateTest* owner) : owner_(owner) {} + explicit TestConnectDelegate(WebSocketStreamCreateTest* owner) + : owner_(owner) {} virtual void OnSuccess(scoped_ptr<WebSocketStream> stream) OVERRIDE { stream.swap(owner_->stream_); } - virtual void OnFailure(uint16 websocket_error) OVERRIDE { - owner_->websocket_error_ = websocket_error; + virtual void OnFailure(const std::string& message) OVERRIDE { + owner_->has_failed_ = true; + owner_->failure_message_ = message; + } + + virtual void OnStartOpeningHandshake( + scoped_ptr<WebSocketHandshakeRequestInfo> request) OVERRIDE { + if (owner_->request_info_) + ADD_FAILURE(); + owner_->request_info_ = request.Pass(); + } + virtual void OnFinishOpeningHandshake( + scoped_ptr<WebSocketHandshakeResponseInfo> response) OVERRIDE { + if (owner_->response_info_) + ADD_FAILURE(); + owner_->response_info_ = response.Pass(); + } + virtual void OnSSLCertificateError( + scoped_ptr<WebSocketEventInterface::SSLErrorCallbacks> + ssl_error_callbacks, + const SSLInfo& ssl_info, + bool fatal) OVERRIDE { + owner_->ssl_error_callbacks_ = ssl_error_callbacks.Pass(); + owner_->ssl_info_ = ssl_info; + owner_->ssl_fatal_ = fatal; } private: @@ -132,22 +202,144 @@ class WebSocketStreamCreateTest : public ::testing::Test { scoped_ptr<WebSocketStreamRequest> stream_request_; // Only set if the connection succeeded. scoped_ptr<WebSocketStream> stream_; - // Only set if the connection failed. 0 otherwise. - uint16 websocket_error_; + // Only set if the connection failed. + std::string failure_message_; + bool has_failed_; + scoped_ptr<WebSocketHandshakeRequestInfo> request_info_; + scoped_ptr<WebSocketHandshakeResponseInfo> response_info_; + scoped_ptr<WebSocketEventInterface::SSLErrorCallbacks> ssl_error_callbacks_; + SSLInfo ssl_info_; + bool ssl_fatal_; + ScopedVector<SSLSocketDataProvider> ssl_data_; +}; + +// There are enough tests of the Sec-WebSocket-Extensions header that they +// deserve their own test fixture. +class WebSocketStreamCreateExtensionTest : public WebSocketStreamCreateTest { + public: + // Performs a standard connect, with the value of the Sec-WebSocket-Extensions + // header in the response set to |extensions_header_value|. Runs the event + // loop to allow the connect to complete. + void CreateAndConnectWithExtensions( + const std::string& extensions_header_value) { + CreateAndConnectStandard( + "ws://localhost/testing_path", + "/testing_path", + NoSubProtocols(), + "http://localhost", + "", + "Sec-WebSocket-Extensions: " + extensions_header_value + "\r\n"); + RunUntilIdle(); + } +}; + +class WebSocketStreamCreateUMATest : public ::testing::Test { + public: + // This enum should match with the enum in Delegate in websocket_stream.cc. + enum HandshakeResult { + INCOMPLETE, + CONNECTED, + FAILED, + NUM_HANDSHAKE_RESULT_TYPES, + }; + + class StreamCreation : public WebSocketStreamCreateTest { + virtual void TestBody() OVERRIDE {} + }; + + scoped_ptr<base::HistogramSamples> GetSamples(const std::string& name) { + base::HistogramBase* histogram = + base::StatisticsRecorder::FindHistogram(name); + return histogram ? histogram->SnapshotSamples() + : scoped_ptr<base::HistogramSamples>(); + } }; // Confirm that the basic case works as expected. TEST_F(WebSocketStreamCreateTest, SimpleSuccess) { CreateAndConnectStandard( - "ws://localhost/", "/", NoSubProtocols(), "http://localhost/", "", ""); + "ws://localhost/", "/", NoSubProtocols(), "http://localhost", "", ""); + EXPECT_FALSE(request_info_); + EXPECT_FALSE(response_info_); RunUntilIdle(); + EXPECT_FALSE(has_failed()); EXPECT_TRUE(stream_); + EXPECT_TRUE(request_info_); + EXPECT_TRUE(response_info_); +} + +TEST_F(WebSocketStreamCreateTest, HandshakeInfo) { + static const char kResponse[] = + "HTTP/1.1 101 Switching Protocols\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n" + "foo: bar, baz\r\n" + "hoge: fuga\r\n" + "hoge: piyo\r\n" + "\r\n"; + + CreateAndConnectCustomResponse( + "ws://localhost/", + "/", + NoSubProtocols(), + "http://localhost", + "", + kResponse); + EXPECT_FALSE(request_info_); + EXPECT_FALSE(response_info_); + RunUntilIdle(); + EXPECT_TRUE(stream_); + ASSERT_TRUE(request_info_); + ASSERT_TRUE(response_info_); + std::vector<HeaderKeyValuePair> request_headers = + ToVector(request_info_->headers); + // We examine the contents of request_info_ and response_info_ + // mainly only in this test case. + EXPECT_EQ(GURL("ws://localhost/"), request_info_->url); + EXPECT_EQ(GURL("ws://localhost/"), response_info_->url); + EXPECT_EQ(101, response_info_->status_code); + EXPECT_EQ("Switching Protocols", response_info_->status_text); + ASSERT_EQ(12u, request_headers.size()); + EXPECT_EQ(HeaderKeyValuePair("Host", "localhost"), request_headers[0]); + EXPECT_EQ(HeaderKeyValuePair("Connection", "Upgrade"), request_headers[1]); + EXPECT_EQ(HeaderKeyValuePair("Pragma", "no-cache"), request_headers[2]); + EXPECT_EQ(HeaderKeyValuePair("Cache-Control", "no-cache"), + request_headers[3]); + EXPECT_EQ(HeaderKeyValuePair("Upgrade", "websocket"), request_headers[4]); + EXPECT_EQ(HeaderKeyValuePair("Origin", "http://localhost"), + request_headers[5]); + EXPECT_EQ(HeaderKeyValuePair("Sec-WebSocket-Version", "13"), + request_headers[6]); + EXPECT_EQ(HeaderKeyValuePair("User-Agent", ""), request_headers[7]); + EXPECT_EQ(HeaderKeyValuePair("Accept-Encoding", "gzip,deflate"), + request_headers[8]); + EXPECT_EQ(HeaderKeyValuePair("Accept-Language", "en-us,fr"), + request_headers[9]); + EXPECT_EQ("Sec-WebSocket-Key", request_headers[10].first); + EXPECT_EQ(HeaderKeyValuePair("Sec-WebSocket-Extensions", + "permessage-deflate; client_max_window_bits"), + request_headers[11]); + + std::vector<HeaderKeyValuePair> response_headers = + ToVector(*response_info_->headers); + ASSERT_EQ(6u, response_headers.size()); + // Sort the headers for ease of verification. + std::sort(response_headers.begin(), response_headers.end()); + + EXPECT_EQ(HeaderKeyValuePair("Connection", "Upgrade"), response_headers[0]); + EXPECT_EQ("Sec-WebSocket-Accept", response_headers[1].first); + EXPECT_EQ(HeaderKeyValuePair("Upgrade", "websocket"), response_headers[2]); + EXPECT_EQ(HeaderKeyValuePair("foo", "bar, baz"), response_headers[3]); + EXPECT_EQ(HeaderKeyValuePair("hoge", "fuga"), response_headers[4]); + EXPECT_EQ(HeaderKeyValuePair("hoge", "piyo"), response_headers[5]); } // Confirm that the stream isn't established until the message loop runs. TEST_F(WebSocketStreamCreateTest, NeedsToRunLoop) { CreateAndConnectStandard( - "ws://localhost/", "/", NoSubProtocols(), "http://localhost/", "", ""); + "ws://localhost/", "/", NoSubProtocols(), "http://localhost", "", ""); + EXPECT_FALSE(has_failed()); EXPECT_FALSE(stream_); } @@ -156,10 +348,11 @@ TEST_F(WebSocketStreamCreateTest, PathIsUsed) { CreateAndConnectStandard("ws://localhost/testing_path", "/testing_path", NoSubProtocols(), - "http://localhost/", + "http://localhost", "", ""); RunUntilIdle(); + EXPECT_FALSE(has_failed()); EXPECT_TRUE(stream_); } @@ -168,10 +361,11 @@ TEST_F(WebSocketStreamCreateTest, OriginIsUsed) { CreateAndConnectStandard("ws://localhost/testing_path", "/testing_path", NoSubProtocols(), - "http://google.com/", + "http://google.com", "", ""); RunUntilIdle(); + EXPECT_FALSE(has_failed()); EXPECT_TRUE(stream_); } @@ -183,12 +377,13 @@ TEST_F(WebSocketStreamCreateTest, SubProtocolIsUsed) { CreateAndConnectStandard("ws://localhost/testing_path", "/testing_path", sub_protocols, - "http://google.com/", + "http://google.com", "Sec-WebSocket-Protocol: chatv11.chromium.org, " "chatv20.chromium.org\r\n", "Sec-WebSocket-Protocol: chatv20.chromium.org\r\n"); RunUntilIdle(); EXPECT_TRUE(stream_); + EXPECT_FALSE(has_failed()); EXPECT_EQ("chatv20.chromium.org", stream_->GetSubProtocol()); } @@ -197,25 +392,35 @@ TEST_F(WebSocketStreamCreateTest, UnsolicitedSubProtocol) { CreateAndConnectStandard("ws://localhost/testing_path", "/testing_path", NoSubProtocols(), - "http://google.com/", + "http://google.com", "", "Sec-WebSocket-Protocol: chatv20.chromium.org\r\n"); RunUntilIdle(); EXPECT_FALSE(stream_); - EXPECT_EQ(1006, error()); + EXPECT_TRUE(has_failed()); + EXPECT_EQ("Error during WebSocket handshake: " + "Response must not include 'Sec-WebSocket-Protocol' header " + "if not present in request: chatv20.chromium.org", + failure_message()); } // Missing sub-protocol response is rejected. TEST_F(WebSocketStreamCreateTest, UnacceptedSubProtocol) { + std::vector<std::string> sub_protocols; + sub_protocols.push_back("chat.example.com"); CreateAndConnectStandard("ws://localhost/testing_path", "/testing_path", - std::vector<std::string>(1, "chat.example.com"), - "http://localhost/", + sub_protocols, + "http://localhost", "Sec-WebSocket-Protocol: chat.example.com\r\n", ""); RunUntilIdle(); EXPECT_FALSE(stream_); - EXPECT_EQ(1006, error()); + EXPECT_TRUE(has_failed()); + EXPECT_EQ("Error during WebSocket handshake: " + "Sent non-empty 'Sec-WebSocket-Protocol' header " + "but no response was received", + failure_message()); } // Only one sub-protocol can be accepted. @@ -226,41 +431,261 @@ TEST_F(WebSocketStreamCreateTest, MultipleSubProtocolsInResponse) { CreateAndConnectStandard("ws://localhost/testing_path", "/testing_path", sub_protocols, - "http://google.com/", + "http://google.com", "Sec-WebSocket-Protocol: chatv11.chromium.org, " "chatv20.chromium.org\r\n", "Sec-WebSocket-Protocol: chatv11.chromium.org, " "chatv20.chromium.org\r\n"); RunUntilIdle(); EXPECT_FALSE(stream_); - EXPECT_EQ(1006, error()); + EXPECT_TRUE(has_failed()); + EXPECT_EQ("Error during WebSocket handshake: " + "'Sec-WebSocket-Protocol' header must not appear " + "more than once in a response", + failure_message()); } -// Unknown extension in the response is rejected -TEST_F(WebSocketStreamCreateTest, UnknownExtension) { +// Unmatched sub-protocol should be rejected. +TEST_F(WebSocketStreamCreateTest, UnmatchedSubProtocolInResponse) { + std::vector<std::string> sub_protocols; + sub_protocols.push_back("chatv11.chromium.org"); + sub_protocols.push_back("chatv20.chromium.org"); CreateAndConnectStandard("ws://localhost/testing_path", "/testing_path", - NoSubProtocols(), - "http://localhost/", - "", - "Sec-WebSocket-Extensions: x-unknown-extension\r\n"); + sub_protocols, + "http://google.com", + "Sec-WebSocket-Protocol: chatv11.chromium.org, " + "chatv20.chromium.org\r\n", + "Sec-WebSocket-Protocol: chatv21.chromium.org\r\n"); + RunUntilIdle(); + EXPECT_FALSE(stream_); + EXPECT_TRUE(has_failed()); + EXPECT_EQ("Error during WebSocket handshake: " + "'Sec-WebSocket-Protocol' header value 'chatv21.chromium.org' " + "in response does not match any of sent values", + failure_message()); +} + +// permessage-deflate extension basic success case. +TEST_F(WebSocketStreamCreateExtensionTest, PerMessageDeflateSuccess) { + CreateAndConnectWithExtensions("permessage-deflate"); + EXPECT_TRUE(stream_); + EXPECT_FALSE(has_failed()); +} + +// permessage-deflate extensions success with all parameters. +TEST_F(WebSocketStreamCreateExtensionTest, PerMessageDeflateParamsSuccess) { + CreateAndConnectWithExtensions( + "permessage-deflate; client_no_context_takeover; " + "server_max_window_bits=11; client_max_window_bits=13; " + "server_no_context_takeover"); + EXPECT_TRUE(stream_); + EXPECT_FALSE(has_failed()); +} + +// Verify that incoming messages are actually decompressed with +// permessage-deflate enabled. +TEST_F(WebSocketStreamCreateExtensionTest, PerMessageDeflateInflates) { + CreateAndConnectCustomResponse( + "ws://localhost/testing_path", + "/testing_path", + NoSubProtocols(), + "http://localhost", + "", + WebSocketStandardResponse( + "Sec-WebSocket-Extensions: permessage-deflate\r\n") + + std::string( + "\xc1\x07" // WebSocket header (FIN + RSV1, Text payload 7 bytes) + "\xf2\x48\xcd\xc9\xc9\x07\x00", // "Hello" DEFLATE compressed + 9)); RunUntilIdle(); + + ASSERT_TRUE(stream_); + ScopedVector<WebSocketFrame> frames; + CompletionCallback callback; + ASSERT_EQ(OK, stream_->ReadFrames(&frames, callback)); + ASSERT_EQ(1U, frames.size()); + ASSERT_EQ(5U, frames[0]->header.payload_length); + EXPECT_EQ("Hello", std::string(frames[0]->data->data(), 5)); +} + +// Unknown extension in the response is rejected +TEST_F(WebSocketStreamCreateExtensionTest, UnknownExtension) { + CreateAndConnectWithExtensions("x-unknown-extension"); + EXPECT_FALSE(stream_); + EXPECT_TRUE(has_failed()); + EXPECT_EQ("Error during WebSocket handshake: " + "Found an unsupported extension 'x-unknown-extension' " + "in 'Sec-WebSocket-Extensions' header", + failure_message()); +} + +// Malformed extensions are rejected (this file does not cover all possible +// parse failures, as the parser is covered thoroughly by its own unit tests). +TEST_F(WebSocketStreamCreateExtensionTest, MalformedExtension) { + CreateAndConnectWithExtensions(";"); + EXPECT_FALSE(stream_); + EXPECT_TRUE(has_failed()); + EXPECT_EQ( + "Error during WebSocket handshake: 'Sec-WebSocket-Extensions' header " + "value is rejected by the parser: ;", + failure_message()); +} + +// The permessage-deflate extension may only be specified once. +TEST_F(WebSocketStreamCreateExtensionTest, OnlyOnePerMessageDeflateAllowed) { + CreateAndConnectWithExtensions( + "permessage-deflate, permessage-deflate; client_max_window_bits=10"); + EXPECT_FALSE(stream_); + EXPECT_TRUE(has_failed()); + EXPECT_EQ( + "Error during WebSocket handshake: " + "Received duplicate permessage-deflate response", + failure_message()); +} + +// permessage-deflate parameters may not be duplicated. +TEST_F(WebSocketStreamCreateExtensionTest, NoDuplicateParameters) { + CreateAndConnectWithExtensions( + "permessage-deflate; client_no_context_takeover; " + "client_no_context_takeover"); + EXPECT_FALSE(stream_); + EXPECT_TRUE(has_failed()); + EXPECT_EQ( + "Error during WebSocket handshake: Error in permessage-deflate: " + "Received duplicate permessage-deflate extension parameter " + "client_no_context_takeover", + failure_message()); +} + +// permessage-deflate parameters must start with "client_" or "server_" +TEST_F(WebSocketStreamCreateExtensionTest, BadParameterPrefix) { + CreateAndConnectWithExtensions( + "permessage-deflate; absurd_no_context_takeover"); + EXPECT_FALSE(stream_); + EXPECT_TRUE(has_failed()); + EXPECT_EQ( + "Error during WebSocket handshake: Error in permessage-deflate: " + "Received an unexpected permessage-deflate extension parameter", + failure_message()); +} + +// permessage-deflate parameters must be either *_no_context_takeover or +// *_max_window_bits +TEST_F(WebSocketStreamCreateExtensionTest, BadParameterSuffix) { + CreateAndConnectWithExtensions( + "permessage-deflate; client_max_content_bits=5"); + EXPECT_FALSE(stream_); + EXPECT_TRUE(has_failed()); + EXPECT_EQ( + "Error during WebSocket handshake: Error in permessage-deflate: " + "Received an unexpected permessage-deflate extension parameter", + failure_message()); +} + +// *_no_context_takeover parameters must not have an argument +TEST_F(WebSocketStreamCreateExtensionTest, BadParameterValue) { + CreateAndConnectWithExtensions( + "permessage-deflate; client_no_context_takeover=true"); + EXPECT_FALSE(stream_); + EXPECT_TRUE(has_failed()); + EXPECT_EQ( + "Error during WebSocket handshake: Error in permessage-deflate: " + "Received invalid client_no_context_takeover parameter", + failure_message()); +} + +// *_max_window_bits must have an argument +TEST_F(WebSocketStreamCreateExtensionTest, NoMaxWindowBitsArgument) { + CreateAndConnectWithExtensions("permessage-deflate; client_max_window_bits"); + EXPECT_FALSE(stream_); + EXPECT_TRUE(has_failed()); + EXPECT_EQ( + "Error during WebSocket handshake: Error in permessage-deflate: " + "client_max_window_bits must have value", + failure_message()); +} + +// *_max_window_bits must be an integer +TEST_F(WebSocketStreamCreateExtensionTest, MaxWindowBitsValueInteger) { + CreateAndConnectWithExtensions( + "permessage-deflate; server_max_window_bits=banana"); + EXPECT_FALSE(stream_); + EXPECT_TRUE(has_failed()); + EXPECT_EQ( + "Error during WebSocket handshake: Error in permessage-deflate: " + "Received invalid server_max_window_bits parameter", + failure_message()); +} + +// *_max_window_bits must be >= 8 +TEST_F(WebSocketStreamCreateExtensionTest, MaxWindowBitsValueTooSmall) { + CreateAndConnectWithExtensions( + "permessage-deflate; server_max_window_bits=7"); + EXPECT_FALSE(stream_); + EXPECT_TRUE(has_failed()); + EXPECT_EQ( + "Error during WebSocket handshake: Error in permessage-deflate: " + "Received invalid server_max_window_bits parameter", + failure_message()); +} + +// *_max_window_bits must be <= 15 +TEST_F(WebSocketStreamCreateExtensionTest, MaxWindowBitsValueTooBig) { + CreateAndConnectWithExtensions( + "permessage-deflate; client_max_window_bits=16"); + EXPECT_FALSE(stream_); + EXPECT_TRUE(has_failed()); + EXPECT_EQ( + "Error during WebSocket handshake: Error in permessage-deflate: " + "Received invalid client_max_window_bits parameter", + failure_message()); +} + +// *_max_window_bits must not start with 0 +TEST_F(WebSocketStreamCreateExtensionTest, MaxWindowBitsValueStartsWithZero) { + CreateAndConnectWithExtensions( + "permessage-deflate; client_max_window_bits=08"); + EXPECT_FALSE(stream_); + EXPECT_TRUE(has_failed()); + EXPECT_EQ( + "Error during WebSocket handshake: Error in permessage-deflate: " + "Received invalid client_max_window_bits parameter", + failure_message()); +} + +// *_max_window_bits must not start with + +TEST_F(WebSocketStreamCreateExtensionTest, MaxWindowBitsValueStartsWithPlus) { + CreateAndConnectWithExtensions( + "permessage-deflate; server_max_window_bits=+9"); EXPECT_FALSE(stream_); - EXPECT_EQ(1006, error()); + EXPECT_TRUE(has_failed()); + EXPECT_EQ( + "Error during WebSocket handshake: Error in permessage-deflate: " + "Received invalid server_max_window_bits parameter", + failure_message()); } +// TODO(ricea): Check that WebSocketDeflateStream is initialised with the +// arguments from the server. This is difficult because the data written to the +// socket is randomly masked. + // Additional Sec-WebSocket-Accept headers should be rejected. TEST_F(WebSocketStreamCreateTest, DoubleAccept) { CreateAndConnectStandard( "ws://localhost/", "/", NoSubProtocols(), - "http://localhost/", + "http://localhost", "", "Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n"); RunUntilIdle(); EXPECT_FALSE(stream_); - EXPECT_EQ(1006, error()); + EXPECT_TRUE(has_failed()); + EXPECT_EQ("Error during WebSocket handshake: " + "'Sec-WebSocket-Accept' header must not appear " + "more than once in a response", + failure_message()); } // Response code 200 must be rejected. @@ -274,11 +699,13 @@ TEST_F(WebSocketStreamCreateTest, InvalidStatusCode) { CreateAndConnectCustomResponse("ws://localhost/", "/", NoSubProtocols(), - "http://localhost/", + "http://localhost", "", kInvalidStatusCodeResponse); RunUntilIdle(); - EXPECT_EQ(1006, error()); + EXPECT_TRUE(has_failed()); + EXPECT_EQ("Error during WebSocket handshake: Unexpected response code: 200", + failure_message()); } // Redirects are not followed (according to the WHATWG WebSocket API, which @@ -295,11 +722,13 @@ TEST_F(WebSocketStreamCreateTest, RedirectsRejected) { CreateAndConnectCustomResponse("ws://localhost/", "/", NoSubProtocols(), - "http://localhost/", + "http://localhost", "", kRedirectResponse); RunUntilIdle(); - EXPECT_EQ(1006, error()); + EXPECT_TRUE(has_failed()); + EXPECT_EQ("Error during WebSocket handshake: Unexpected response code: 302", + failure_message()); } // Malformed responses should be rejected. HttpStreamParser will accept just @@ -317,11 +746,13 @@ TEST_F(WebSocketStreamCreateTest, MalformedResponse) { CreateAndConnectCustomResponse("ws://localhost/", "/", NoSubProtocols(), - "http://localhost/", + "http://localhost", "", kMalformedResponse); RunUntilIdle(); - EXPECT_EQ(1006, error()); + EXPECT_TRUE(has_failed()); + EXPECT_EQ("Error during WebSocket handshake: Invalid status line", + failure_message()); } // Upgrade header must be present. @@ -334,11 +765,13 @@ TEST_F(WebSocketStreamCreateTest, MissingUpgradeHeader) { CreateAndConnectCustomResponse("ws://localhost/", "/", NoSubProtocols(), - "http://localhost/", + "http://localhost", "", kMissingUpgradeResponse); RunUntilIdle(); - EXPECT_EQ(1006, error()); + EXPECT_TRUE(has_failed()); + EXPECT_EQ("Error during WebSocket handshake: 'Upgrade' header is missing", + failure_message()); } // There must only be one upgrade header. @@ -347,10 +780,34 @@ TEST_F(WebSocketStreamCreateTest, DoubleUpgradeHeader) { "ws://localhost/", "/", NoSubProtocols(), - "http://localhost/", + "http://localhost", "", "Upgrade: HTTP/2.0\r\n"); RunUntilIdle(); - EXPECT_EQ(1006, error()); + EXPECT_TRUE(has_failed()); + EXPECT_EQ("Error during WebSocket handshake: " + "'Upgrade' header must not appear more than once in a response", + failure_message()); +} + +// There must only be one correct upgrade header. +TEST_F(WebSocketStreamCreateTest, IncorrectUpgradeHeader) { + static const char kMissingUpgradeResponse[] = + "HTTP/1.1 101 Switching Protocols\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n" + "Upgrade: hogefuga\r\n" + "\r\n"; + CreateAndConnectCustomResponse("ws://localhost/", + "/", + NoSubProtocols(), + "http://localhost", + "", + kMissingUpgradeResponse); + RunUntilIdle(); + EXPECT_TRUE(has_failed()); + EXPECT_EQ("Error during WebSocket handshake: " + "'Upgrade' header value is not 'WebSocket': hogefuga", + failure_message()); } // Connection header must be present. @@ -363,11 +820,35 @@ TEST_F(WebSocketStreamCreateTest, MissingConnectionHeader) { CreateAndConnectCustomResponse("ws://localhost/", "/", NoSubProtocols(), - "http://localhost/", + "http://localhost", + "", + kMissingConnectionResponse); + RunUntilIdle(); + EXPECT_TRUE(has_failed()); + EXPECT_EQ("Error during WebSocket handshake: " + "'Connection' header is missing", + failure_message()); +} + +// Connection header must contain "Upgrade". +TEST_F(WebSocketStreamCreateTest, IncorrectConnectionHeader) { + static const char kMissingConnectionResponse[] = + "HTTP/1.1 101 Switching Protocols\r\n" + "Upgrade: websocket\r\n" + "Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n" + "Connection: hogefuga\r\n" + "\r\n"; + CreateAndConnectCustomResponse("ws://localhost/", + "/", + NoSubProtocols(), + "http://localhost", "", kMissingConnectionResponse); RunUntilIdle(); - EXPECT_EQ(1006, error()); + EXPECT_TRUE(has_failed()); + EXPECT_EQ("Error during WebSocket handshake: " + "'Connection' header value must contain 'Upgrade'", + failure_message()); } // Connection header is permitted to contain other tokens. @@ -381,10 +862,11 @@ TEST_F(WebSocketStreamCreateTest, AdditionalTokenInConnectionHeader) { CreateAndConnectCustomResponse("ws://localhost/", "/", NoSubProtocols(), - "http://localhost/", + "http://localhost", "", kAdditionalConnectionTokenResponse); RunUntilIdle(); + EXPECT_FALSE(has_failed()); EXPECT_TRUE(stream_); } @@ -398,11 +880,14 @@ TEST_F(WebSocketStreamCreateTest, MissingSecWebSocketAccept) { CreateAndConnectCustomResponse("ws://localhost/", "/", NoSubProtocols(), - "http://localhost/", + "http://localhost", "", kMissingAcceptResponse); RunUntilIdle(); - EXPECT_EQ(1006, error()); + EXPECT_TRUE(has_failed()); + EXPECT_EQ("Error during WebSocket handshake: " + "'Sec-WebSocket-Accept' header is missing", + failure_message()); } // Sec-WebSocket-Accept header must match the key that was sent. @@ -416,20 +901,26 @@ TEST_F(WebSocketStreamCreateTest, WrongSecWebSocketAccept) { CreateAndConnectCustomResponse("ws://localhost/", "/", NoSubProtocols(), - "http://localhost/", + "http://localhost", "", kIncorrectAcceptResponse); RunUntilIdle(); - EXPECT_EQ(1006, error()); + EXPECT_TRUE(has_failed()); + EXPECT_EQ("Error during WebSocket handshake: " + "Incorrect 'Sec-WebSocket-Accept' header value", + failure_message()); } // Cancellation works. TEST_F(WebSocketStreamCreateTest, Cancellation) { CreateAndConnectStandard( - "ws://localhost/", "/", NoSubProtocols(), "http://localhost/", "", ""); + "ws://localhost/", "/", NoSubProtocols(), "http://localhost", "", ""); stream_request_.reset(); RunUntilIdle(); + EXPECT_FALSE(has_failed()); EXPECT_FALSE(stream_); + EXPECT_FALSE(request_info_); + EXPECT_FALSE(response_info_); } // Connect failure must look just like negotiation failure. @@ -439,9 +930,13 @@ TEST_F(WebSocketStreamCreateTest, ConnectionFailure) { socket_data->set_connect_data( MockConnect(SYNCHRONOUS, ERR_CONNECTION_REFUSED)); CreateAndConnectRawExpectations("ws://localhost/", NoSubProtocols(), - "http://localhost/", socket_data.Pass()); + "http://localhost", socket_data.Pass()); RunUntilIdle(); - EXPECT_EQ(1006, error()); + EXPECT_TRUE(has_failed()); + EXPECT_EQ("Error in connection establishment: net::ERR_CONNECTION_REFUSED", + failure_message()); + EXPECT_FALSE(request_info_); + EXPECT_FALSE(response_info_); } // Connect timeout must look just like any other failure. @@ -451,9 +946,11 @@ TEST_F(WebSocketStreamCreateTest, ConnectionTimeout) { socket_data->set_connect_data( MockConnect(ASYNC, ERR_CONNECTION_TIMED_OUT)); CreateAndConnectRawExpectations("ws://localhost/", NoSubProtocols(), - "http://localhost/", socket_data.Pass()); + "http://localhost", socket_data.Pass()); RunUntilIdle(); - EXPECT_EQ(1006, error()); + EXPECT_TRUE(has_failed()); + EXPECT_EQ("Error in connection establishment: net::ERR_CONNECTION_TIMED_OUT", + failure_message()); } // Cancellation during connect works. @@ -463,10 +960,11 @@ TEST_F(WebSocketStreamCreateTest, CancellationDuringConnect) { socket_data->set_connect_data(MockConnect(SYNCHRONOUS, ERR_IO_PENDING)); CreateAndConnectRawExpectations("ws://localhost/", NoSubProtocols(), - "http://localhost/", + "http://localhost", socket_data.Pass()); stream_request_.reset(); RunUntilIdle(); + EXPECT_FALSE(has_failed()); EXPECT_FALSE(stream_); } @@ -482,17 +980,20 @@ TEST_F(WebSocketStreamCreateTest, CancellationDuringWrite) { socket_data->SetStop(1); CreateAndConnectRawExpectations("ws://localhost/", NoSubProtocols(), - "http://localhost/", + "http://localhost", make_scoped_ptr(socket_data)); socket_data->Run(); stream_request_.reset(); RunUntilIdle(); + EXPECT_FALSE(has_failed()); EXPECT_FALSE(stream_); + EXPECT_TRUE(request_info_); + EXPECT_FALSE(response_info_); } // Cancellation during read of the response headers works. TEST_F(WebSocketStreamCreateTest, CancellationDuringRead) { - std::string request = WebSocketStandardRequest("/", "http://localhost/", ""); + std::string request = WebSocketStandardRequest("/", "http://localhost", ""); MockWrite writes[] = {MockWrite(ASYNC, 0, request.c_str())}; MockRead reads[] = { MockRead(ASYNC, 1, "HTTP/1.1 101 Switching Protocols\r\nUpgr"), @@ -503,12 +1004,176 @@ TEST_F(WebSocketStreamCreateTest, CancellationDuringRead) { socket_data->SetStop(1); CreateAndConnectRawExpectations("ws://localhost/", NoSubProtocols(), - "http://localhost/", + "http://localhost", make_scoped_ptr(socket_data)); socket_data->Run(); stream_request_.reset(); RunUntilIdle(); + EXPECT_FALSE(has_failed()); EXPECT_FALSE(stream_); + EXPECT_TRUE(request_info_); + EXPECT_FALSE(response_info_); +} + +// Over-size response headers (> 256KB) should not cause a crash. This is a +// regression test for crbug.com/339456. It is based on the layout test +// "cookie-flood.html". +TEST_F(WebSocketStreamCreateTest, VeryLargeResponseHeaders) { + std::string set_cookie_headers; + set_cookie_headers.reserve(45 * 10000); + for (int i = 0; i < 10000; ++i) { + set_cookie_headers += + base::StringPrintf("Set-Cookie: WK-websocket-test-flood-%d=1\r\n", i); + } + CreateAndConnectStandard("ws://localhost/", "/", NoSubProtocols(), + "http://localhost", "", set_cookie_headers); + RunUntilIdle(); + EXPECT_TRUE(has_failed()); + EXPECT_FALSE(response_info_); +} + +// If the remote host closes the connection without sending headers, we should +// log the console message "Connection closed before receiving a handshake +// response". +TEST_F(WebSocketStreamCreateTest, NoResponse) { + std::string request = WebSocketStandardRequest("/", "http://localhost", ""); + MockWrite writes[] = {MockWrite(ASYNC, request.data(), request.size(), 0)}; + MockRead reads[] = {MockRead(ASYNC, 0, 1)}; + DeterministicSocketData* socket_data(new DeterministicSocketData( + reads, arraysize(reads), writes, arraysize(writes))); + socket_data->set_connect_data(MockConnect(SYNCHRONOUS, OK)); + CreateAndConnectRawExpectations("ws://localhost/", + NoSubProtocols(), + "http://localhost", + make_scoped_ptr(socket_data)); + socket_data->RunFor(2); + EXPECT_TRUE(has_failed()); + EXPECT_FALSE(stream_); + EXPECT_FALSE(response_info_); + EXPECT_EQ("Connection closed before receiving a handshake response", + failure_message()); +} + +TEST_F(WebSocketStreamCreateTest, SelfSignedCertificateFailure) { + ssl_data_.push_back( + new SSLSocketDataProvider(ASYNC, ERR_CERT_AUTHORITY_INVALID)); + ssl_data_[0]->cert = + ImportCertFromFile(GetTestCertsDirectory(), "unittest.selfsigned.der"); + ASSERT_TRUE(ssl_data_[0]->cert); + scoped_ptr<DeterministicSocketData> raw_socket_data( + new DeterministicSocketData(NULL, 0, NULL, 0)); + CreateAndConnectRawExpectations("wss://localhost/", + NoSubProtocols(), + "http://localhost", + raw_socket_data.Pass()); + RunUntilIdle(); + EXPECT_FALSE(has_failed()); + ASSERT_TRUE(ssl_error_callbacks_); + ssl_error_callbacks_->CancelSSLRequest(ERR_CERT_AUTHORITY_INVALID, + &ssl_info_); + RunUntilIdle(); + EXPECT_TRUE(has_failed()); +} + +TEST_F(WebSocketStreamCreateTest, SelfSignedCertificateSuccess) { + scoped_ptr<SSLSocketDataProvider> ssl_data( + new SSLSocketDataProvider(ASYNC, ERR_CERT_AUTHORITY_INVALID)); + ssl_data->cert = + ImportCertFromFile(GetTestCertsDirectory(), "unittest.selfsigned.der"); + ASSERT_TRUE(ssl_data->cert); + ssl_data_.push_back(ssl_data.release()); + ssl_data.reset(new SSLSocketDataProvider(ASYNC, OK)); + ssl_data_.push_back(ssl_data.release()); + url_request_context_host_.AddRawExpectations( + make_scoped_ptr(new DeterministicSocketData(NULL, 0, NULL, 0))); + CreateAndConnectStandard( + "wss://localhost/", "/", NoSubProtocols(), "http://localhost", "", ""); + RunUntilIdle(); + ASSERT_TRUE(ssl_error_callbacks_); + ssl_error_callbacks_->ContinueSSLRequest(); + RunUntilIdle(); + EXPECT_FALSE(has_failed()); + EXPECT_TRUE(stream_); +} + +TEST_F(WebSocketStreamCreateUMATest, Incomplete) { + const std::string name("Net.WebSocket.HandshakeResult"); + scoped_ptr<base::HistogramSamples> original(GetSamples(name)); + + { + StreamCreation creation; + creation.CreateAndConnectStandard("ws://localhost/", + "/", + creation.NoSubProtocols(), + "http://localhost", + "", + ""); + } + + scoped_ptr<base::HistogramSamples> samples(GetSamples(name)); + ASSERT_TRUE(samples); + if (original) { + samples->Subtract(*original); // Cancel the original values. + } + EXPECT_EQ(1, samples->GetCount(INCOMPLETE)); + EXPECT_EQ(0, samples->GetCount(CONNECTED)); + EXPECT_EQ(0, samples->GetCount(FAILED)); +} + +TEST_F(WebSocketStreamCreateUMATest, Connected) { + const std::string name("Net.WebSocket.HandshakeResult"); + scoped_ptr<base::HistogramSamples> original(GetSamples(name)); + + { + StreamCreation creation; + creation.CreateAndConnectStandard("ws://localhost/", + "/", + creation.NoSubProtocols(), + "http://localhost", + "", + ""); + creation.RunUntilIdle(); + } + + scoped_ptr<base::HistogramSamples> samples(GetSamples(name)); + ASSERT_TRUE(samples); + if (original) { + samples->Subtract(*original); // Cancel the original values. + } + EXPECT_EQ(0, samples->GetCount(INCOMPLETE)); + EXPECT_EQ(1, samples->GetCount(CONNECTED)); + EXPECT_EQ(0, samples->GetCount(FAILED)); +} + +TEST_F(WebSocketStreamCreateUMATest, Failed) { + const std::string name("Net.WebSocket.HandshakeResult"); + scoped_ptr<base::HistogramSamples> original(GetSamples(name)); + + { + StreamCreation creation; + static const char kInvalidStatusCodeResponse[] = + "HTTP/1.1 200 OK\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n" + "\r\n"; + creation.CreateAndConnectCustomResponse("ws://localhost/", + "/", + creation.NoSubProtocols(), + "http://localhost", + "", + kInvalidStatusCodeResponse); + creation.RunUntilIdle(); + } + + scoped_ptr<base::HistogramSamples> samples(GetSamples(name)); + ASSERT_TRUE(samples); + if (original) { + samples->Subtract(*original); // Cancel the original values. + } + EXPECT_EQ(1, samples->GetCount(INCOMPLETE)); + EXPECT_EQ(0, samples->GetCount(CONNECTED)); + EXPECT_EQ(0, samples->GetCount(FAILED)); } } // namespace diff --git a/chromium/net/websockets/websocket_test_util.cc b/chromium/net/websockets/websocket_test_util.cc index 55113c6f15c..bfa89803447 100644 --- a/chromium/net/websockets/websocket_test_util.cc +++ b/chromium/net/websockets/websocket_test_util.cc @@ -4,7 +4,12 @@ #include "net/websockets/websocket_test_util.h" +#include <algorithm> +#include <vector> + #include "base/basictypes.h" +#include "base/memory/scoped_vector.h" +#include "base/stl_util.h" #include "base/strings/stringprintf.h" #include "net/socket/socket_test_util.h" @@ -37,6 +42,8 @@ std::string WebSocketStandardRequest(const std::string& path, "GET %s HTTP/1.1\r\n" "Host: localhost\r\n" "Connection: Upgrade\r\n" + "Pragma: no-cache\r\n" + "Cache-Control: no-cache\r\n" "Upgrade: websocket\r\n" "Origin: %s\r\n" "Sec-WebSocket-Version: 13\r\n" @@ -44,6 +51,7 @@ std::string WebSocketStandardRequest(const std::string& path, "Accept-Encoding: gzip,deflate\r\n" "Accept-Language: en-us,fr\r\n" "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" + "Sec-WebSocket-Extensions: permessage-deflate; client_max_window_bits\r\n" "%s\r\n", path.c_str(), origin.c_str(), @@ -63,9 +71,10 @@ std::string WebSocketStandardResponse(const std::string& extra_headers) { struct WebSocketDeterministicMockClientSocketFactoryMaker::Detail { std::string expect_written; std::string return_to_read; - MockRead read; + std::vector<MockRead> reads; MockWrite write; - scoped_ptr<DeterministicSocketData> data; + ScopedVector<DeterministicSocketData> socket_data_vector; + ScopedVector<SSLSocketDataProvider> ssl_socket_data_vector; DeterministicMockClientSocketFactory factory; }; @@ -84,22 +93,46 @@ WebSocketDeterministicMockClientSocketFactoryMaker::factory() { void WebSocketDeterministicMockClientSocketFactoryMaker::SetExpectations( const std::string& expect_written, const std::string& return_to_read) { + const size_t kHttpStreamParserBufferSize = 4096; // We need to extend the lifetime of these strings. detail_->expect_written = expect_written; detail_->return_to_read = return_to_read; - detail_->write = MockWrite(SYNCHRONOUS, 0, detail_->expect_written.c_str()); - detail_->read = MockRead(SYNCHRONOUS, 1, detail_->return_to_read.c_str()); + int sequence = 0; + detail_->write = MockWrite(SYNCHRONOUS, + detail_->expect_written.data(), + detail_->expect_written.size(), + sequence++); + // HttpStreamParser reads 4KB at a time. We need to take this implementation + // detail into account if |return_to_read| is big enough. + for (size_t place = 0; place < detail_->return_to_read.size(); + place += kHttpStreamParserBufferSize) { + detail_->reads.push_back( + MockRead(SYNCHRONOUS, detail_->return_to_read.data() + place, + std::min(detail_->return_to_read.size() - place, + kHttpStreamParserBufferSize), + sequence++)); + } scoped_ptr<DeterministicSocketData> socket_data( - new DeterministicSocketData(&detail_->read, 1, &detail_->write, 1)); + new DeterministicSocketData(vector_as_array(&detail_->reads), + detail_->reads.size(), + &detail_->write, + 1)); socket_data->set_connect_data(MockConnect(SYNCHRONOUS, OK)); - socket_data->SetStop(2); - SetRawExpectations(socket_data.Pass()); + socket_data->SetStop(sequence); + AddRawExpectations(socket_data.Pass()); } -void WebSocketDeterministicMockClientSocketFactoryMaker::SetRawExpectations( +void WebSocketDeterministicMockClientSocketFactoryMaker::AddRawExpectations( scoped_ptr<DeterministicSocketData> socket_data) { - detail_->data = socket_data.Pass(); - detail_->factory.AddSocketDataProvider(detail_->data.get()); + detail_->factory.AddSocketDataProvider(socket_data.get()); + detail_->socket_data_vector.push_back(socket_data.release()); +} + +void +WebSocketDeterministicMockClientSocketFactoryMaker::AddSSLSocketDataProvider( + scoped_ptr<SSLSocketDataProvider> ssl_socket_data) { + detail_->factory.AddSSLSocketDataProvider(ssl_socket_data.get()); + detail_->ssl_socket_data_vector.push_back(ssl_socket_data.release()); } WebSocketTestURLRequestContextHost::WebSocketTestURLRequestContextHost() @@ -109,9 +142,14 @@ WebSocketTestURLRequestContextHost::WebSocketTestURLRequestContextHost() WebSocketTestURLRequestContextHost::~WebSocketTestURLRequestContextHost() {} -void WebSocketTestURLRequestContextHost::SetRawExpectations( +void WebSocketTestURLRequestContextHost::AddRawExpectations( scoped_ptr<DeterministicSocketData> socket_data) { - maker_.SetRawExpectations(socket_data.Pass()); + maker_.AddRawExpectations(socket_data.Pass()); +} + +void WebSocketTestURLRequestContextHost::AddSSLSocketDataProvider( + scoped_ptr<SSLSocketDataProvider> ssl_socket_data) { + maker_.AddSSLSocketDataProvider(ssl_socket_data.Pass()); } TestURLRequestContext* diff --git a/chromium/net/websockets/websocket_test_util.h b/chromium/net/websockets/websocket_test_util.h index 71b2ce668c8..2ad86c08fe0 100644 --- a/chromium/net/websockets/websocket_test_util.h +++ b/chromium/net/websockets/websocket_test_util.h @@ -14,13 +14,18 @@ class GURL; +namespace url { +class Origin; +} // namespace url + namespace net { class BoundNetLog; +class DeterministicMockClientSocketFactory; class DeterministicSocketData; class URLRequestContext; class WebSocketHandshakeStreamCreateHelper; -class DeterministicMockClientSocketFactory; +struct SSLSocketDataProvider; class LinearCongruentialGenerator { public: @@ -38,7 +43,7 @@ NET_EXPORT_PRIVATE extern scoped_ptr<WebSocketStreamRequest> CreateAndConnectStreamForTesting( const GURL& socket_url, scoped_ptr<WebSocketHandshakeStreamCreateHelper> create_helper, - const GURL& origin, + const url::Origin& origin, URLRequestContext* url_request_context, const BoundNetLog& net_log, scoped_ptr<WebSocketStream::ConnectDelegate> connect_delegate); @@ -61,15 +66,26 @@ class WebSocketDeterministicMockClientSocketFactoryMaker { WebSocketDeterministicMockClientSocketFactoryMaker(); ~WebSocketDeterministicMockClientSocketFactoryMaker(); - // The socket created by the factory will expect |expect_written| to be - // written to the socket, and will respond with |return_to_read|. The test - // will fail if the expected text is not written, or all the bytes are not - // read. + // Tell the factory to create a socket which expects |expect_written| to be + // written, and responds with |return_to_read|. The test will fail if the + // expected text is not written, or all the bytes are not read. This adds data + // for a new mock-socket using AddRawExpections(), and so can be called + // multiple times to queue up multiple mock sockets, but usually in those + // cases the lower-level AddRawExpections() interface is more appropriate. void SetExpectations(const std::string& expect_written, const std::string& return_to_read); - // A low-level interface to permit arbitrary expectations to be set. - void SetRawExpectations(scoped_ptr<DeterministicSocketData> socket_data); + // A low-level interface to permit arbitrary expectations to be added. The + // mock sockets will be created in the same order that they were added. + void AddRawExpectations(scoped_ptr<DeterministicSocketData> socket_data); + + // Allow an SSL socket data provider to be added. You must also supply a mock + // transport socket for it to use. If the mock SSL handshake fails then the + // mock transport socket will connect but have nothing read or written. If the + // mock handshake succeeds then the data from the underlying transport socket + // will be passed through unchanged (without encryption). + void AddSSLSocketDataProvider( + scoped_ptr<SSLSocketDataProvider> ssl_socket_data); // Call to get a pointer to the factory, which remains owned by this object. DeterministicMockClientSocketFactory* factory(); @@ -94,9 +110,13 @@ struct WebSocketTestURLRequestContextHost { maker_.SetExpectations(expect_written, return_to_read); } - void SetRawExpectations(scoped_ptr<DeterministicSocketData> socket_data); + void AddRawExpectations(scoped_ptr<DeterministicSocketData> socket_data); + + // Allow an SSL socket data provider to be added. + void AddSSLSocketDataProvider( + scoped_ptr<SSLSocketDataProvider> ssl_socket_data); - // Call after calling one of SetExpections() or SetRawExpectations(). The + // Call after calling one of SetExpections() or AddRawExpectations(). The // returned pointer remains owned by this object. This should only be called // once. TestURLRequestContext* GetURLRequestContext(); diff --git a/chromium/net/websockets/websocket_throttle_test.cc b/chromium/net/websockets/websocket_throttle_test.cc index 14237b9b265..4d9300191bb 100644 --- a/chromium/net/websockets/websocket_throttle_test.cc +++ b/chromium/net/websockets/websocket_throttle_test.cc @@ -16,20 +16,34 @@ #include "testing/platform_test.h" #include "url/gurl.h" -class DummySocketStreamDelegate : public net::SocketStream::Delegate { +namespace net { + +namespace { + +class DummySocketStreamDelegate : public SocketStream::Delegate { public: DummySocketStreamDelegate() {} virtual ~DummySocketStreamDelegate() {} virtual void OnConnected( - net::SocketStream* socket, int max_pending_send_allowed) OVERRIDE {} - virtual void OnSentData(net::SocketStream* socket, + SocketStream* socket, int max_pending_send_allowed) OVERRIDE {} + virtual void OnSentData(SocketStream* socket, int amount_sent) OVERRIDE {} - virtual void OnReceivedData(net::SocketStream* socket, + virtual void OnReceivedData(SocketStream* socket, const char* data, int len) OVERRIDE {} - virtual void OnClose(net::SocketStream* socket) OVERRIDE {} + virtual void OnClose(SocketStream* socket) OVERRIDE {} }; -namespace net { +class WebSocketThrottleTestContext : public TestURLRequestContext { + public: + explicit WebSocketThrottleTestContext(bool enable_websocket_over_spdy) + : TestURLRequestContext(true) { + HttpNetworkSession::Params params; + params.enable_websocket_over_spdy = enable_websocket_over_spdy; + Init(); + } +}; + +} // namespace class WebSocketThrottleTest : public PlatformTest { protected: @@ -60,11 +74,10 @@ class WebSocketThrottleTest : public PlatformTest { }; TEST_F(WebSocketThrottleTest, Throttle) { - TestURLRequestContext context; - DummySocketStreamDelegate delegate; // TODO(toyoshim): We need to consider both spdy-enabled and spdy-disabled // configuration. - WebSocketJob::set_websocket_over_spdy_enabled(true); + WebSocketThrottleTestContext context(true); + DummySocketStreamDelegate delegate; // For host1: 1.2.3.4, 1.2.3.5, 1.2.3.6 AddressList addr; @@ -73,8 +86,7 @@ TEST_F(WebSocketThrottleTest, Throttle) { addr.push_back(MakeAddr(1, 2, 3, 6)); scoped_refptr<WebSocketJob> w1(new WebSocketJob(&delegate)); scoped_refptr<SocketStream> s1( - new SocketStream(GURL("ws://host1/"), w1.get())); - s1->set_context(&context); + new SocketStream(GURL("ws://host1/"), w1.get(), &context, NULL)); w1->InitSocketStream(s1.get()); WebSocketThrottleTest::MockSocketStreamConnect(s1.get(), addr); @@ -94,8 +106,7 @@ TEST_F(WebSocketThrottleTest, Throttle) { addr.push_back(MakeAddr(1, 2, 3, 4)); scoped_refptr<WebSocketJob> w2(new WebSocketJob(&delegate)); scoped_refptr<SocketStream> s2( - new SocketStream(GURL("ws://host2/"), w2.get())); - s2->set_context(&context); + new SocketStream(GURL("ws://host2/"), w2.get(), &context, NULL)); w2->InitSocketStream(s2.get()); WebSocketThrottleTest::MockSocketStreamConnect(s2.get(), addr); @@ -115,8 +126,7 @@ TEST_F(WebSocketThrottleTest, Throttle) { addr.push_back(MakeAddr(1, 2, 3, 5)); scoped_refptr<WebSocketJob> w3(new WebSocketJob(&delegate)); scoped_refptr<SocketStream> s3( - new SocketStream(GURL("ws://host3/"), w3.get())); - s3->set_context(&context); + new SocketStream(GURL("ws://host3/"), w3.get(), &context, NULL)); w3->InitSocketStream(s3.get()); WebSocketThrottleTest::MockSocketStreamConnect(s3.get(), addr); @@ -136,8 +146,7 @@ TEST_F(WebSocketThrottleTest, Throttle) { addr.push_back(MakeAddr(1, 2, 3, 6)); scoped_refptr<WebSocketJob> w4(new WebSocketJob(&delegate)); scoped_refptr<SocketStream> s4( - new SocketStream(GURL("ws://host4/"), w4.get())); - s4->set_context(&context); + new SocketStream(GURL("ws://host4/"), w4.get(), &context, NULL)); w4->InitSocketStream(s4.get()); WebSocketThrottleTest::MockSocketStreamConnect(s4.get(), addr); @@ -156,8 +165,7 @@ TEST_F(WebSocketThrottleTest, Throttle) { addr.push_back(MakeAddr(1, 2, 3, 6)); scoped_refptr<WebSocketJob> w5(new WebSocketJob(&delegate)); scoped_refptr<SocketStream> s5( - new SocketStream(GURL("ws://host5/"), w5.get())); - s5->set_context(&context); + new SocketStream(GURL("ws://host5/"), w5.get(), &context, NULL)); w5->InitSocketStream(s5.get()); WebSocketThrottleTest::MockSocketStreamConnect(s5.get(), addr); @@ -176,8 +184,7 @@ TEST_F(WebSocketThrottleTest, Throttle) { addr.push_back(MakeAddr(1, 2, 3, 6)); scoped_refptr<WebSocketJob> w6(new WebSocketJob(&delegate)); scoped_refptr<SocketStream> s6( - new SocketStream(GURL("ws://host6/"), w6.get())); - s6->set_context(&context); + new SocketStream(GURL("ws://host6/"), w6.get(), &context, NULL)); w6->InitSocketStream(s6.get()); WebSocketThrottleTest::MockSocketStreamConnect(s6.get(), addr); @@ -279,9 +286,8 @@ TEST_F(WebSocketThrottleTest, Throttle) { } TEST_F(WebSocketThrottleTest, NoThrottleForDuplicateAddress) { - TestURLRequestContext context; + WebSocketThrottleTestContext context(true); DummySocketStreamDelegate delegate; - WebSocketJob::set_websocket_over_spdy_enabled(true); // For localhost: 127.0.0.1, 127.0.0.1 AddressList addr; @@ -289,8 +295,7 @@ TEST_F(WebSocketThrottleTest, NoThrottleForDuplicateAddress) { addr.push_back(MakeAddr(127, 0, 0, 1)); scoped_refptr<WebSocketJob> w1(new WebSocketJob(&delegate)); scoped_refptr<SocketStream> s1( - new SocketStream(GURL("ws://localhost/"), w1.get())); - s1->set_context(&context); + new SocketStream(GURL("ws://localhost/"), w1.get(), &context, NULL)); w1->InitSocketStream(s1.get()); WebSocketThrottleTest::MockSocketStreamConnect(s1.get(), addr); @@ -309,17 +314,15 @@ TEST_F(WebSocketThrottleTest, NoThrottleForDuplicateAddress) { // A connection should not be blocked by another connection to the same IP // with a different port. TEST_F(WebSocketThrottleTest, NoThrottleForDistinctPort) { - TestURLRequestContext context; + WebSocketThrottleTestContext context(false); DummySocketStreamDelegate delegate; IPAddressNumber localhost; ParseIPLiteralToNumber("127.0.0.1", &localhost); - WebSocketJob::set_websocket_over_spdy_enabled(false); // socket1: 127.0.0.1:80 scoped_refptr<WebSocketJob> w1(new WebSocketJob(&delegate)); scoped_refptr<SocketStream> s1( - new SocketStream(GURL("ws://localhost:80/"), w1.get())); - s1->set_context(&context); + new SocketStream(GURL("ws://localhost:80/"), w1.get(), &context, NULL)); w1->InitSocketStream(s1.get()); MockSocketStreamConnect(s1.get(), AddressList::CreateFromIPAddress(localhost, 80)); @@ -332,8 +335,7 @@ TEST_F(WebSocketThrottleTest, NoThrottleForDistinctPort) { // socket2: 127.0.0.1:81 scoped_refptr<WebSocketJob> w2(new WebSocketJob(&delegate)); scoped_refptr<SocketStream> s2( - new SocketStream(GURL("ws://localhost:81/"), w2.get())); - s2->set_context(&context); + new SocketStream(GURL("ws://localhost:81/"), w2.get(), &context, NULL)); w2->InitSocketStream(s2.get()); MockSocketStreamConnect(s2.get(), AddressList::CreateFromIPAddress(localhost, 81)); @@ -354,4 +356,4 @@ TEST_F(WebSocketThrottleTest, NoThrottleForDistinctPort) { base::MessageLoopForIO::current()->RunUntilIdle(); } -} +} // namespace net |