diff options
author | Lorry Tar Creator <lorry-tar-importer@lorry> | 2017-06-27 06:07:23 +0000 |
---|---|---|
committer | Lorry Tar Creator <lorry-tar-importer@lorry> | 2017-06-27 06:07:23 +0000 |
commit | 1bf1084f2b10c3b47fd1a588d85d21ed0eb41d0c (patch) | |
tree | 46dcd36c86e7fbc6e5df36deb463b33e9967a6f7 /Source/WebCore/Modules/websockets/WebSocketHandshake.cpp | |
parent | 32761a6cee1d0dee366b885b7b9c777e67885688 (diff) | |
download | WebKitGtk-tarball-master.tar.gz |
webkitgtk-2.16.5HEADwebkitgtk-2.16.5master
Diffstat (limited to 'Source/WebCore/Modules/websockets/WebSocketHandshake.cpp')
-rw-r--r-- | Source/WebCore/Modules/websockets/WebSocketHandshake.cpp | 249 |
1 files changed, 153 insertions, 96 deletions
diff --git a/Source/WebCore/Modules/websockets/WebSocketHandshake.cpp b/Source/WebCore/Modules/websockets/WebSocketHandshake.cpp index e6782cf8f..dd3b9a4e4 100644 --- a/Source/WebCore/Modules/websockets/WebSocketHandshake.cpp +++ b/Source/WebCore/Modules/websockets/WebSocketHandshake.cpp @@ -40,13 +40,14 @@ #include "CookieJar.h" #include "Document.h" #include "HTTPHeaderMap.h" +#include "HTTPHeaderNames.h" #include "HTTPParsers.h" #include "URL.h" #include "Logging.h" #include "ResourceRequest.h" -#include "ScriptCallStack.h" #include "ScriptExecutionContext.h" #include "SecurityOrigin.h" +#include <wtf/ASCIICType.h> #include <wtf/CryptographicallyRandomNumber.h> #include <wtf/MD5.h> #include <wtf/SHA1.h> @@ -56,6 +57,7 @@ #include <wtf/text/Base64.h> #include <wtf/text/CString.h> #include <wtf/text/StringBuilder.h> +#include <wtf/text/StringView.h> #include <wtf/text/WTFString.h> #include <wtf/unicode/CharacterNames.h> @@ -81,10 +83,10 @@ static String hostName(const URL& url, bool secure) { ASSERT(url.protocolIs("wss") == secure); StringBuilder builder; - builder.append(url.host().lower()); - if (url.port() && ((!secure && url.port() != 80) || (secure && url.port() != 443))) { + builder.append(url.host().convertToASCIILowercase()); + if (url.port() && ((!secure && url.port().value() != 80) || (secure && url.port().value() != 443))) { builder.append(':'); - builder.appendNumber(url.port()); + builder.appendNumber(url.port().value()); } return builder.toString(); } @@ -118,12 +120,13 @@ String WebSocketHandshake::getExpectedWebSocketAccept(const String& secWebSocket return base64Encode(hash.data(), SHA1::hashSize); } -WebSocketHandshake::WebSocketHandshake(const URL& url, const String& protocol, ScriptExecutionContext* context) +WebSocketHandshake::WebSocketHandshake(const URL& url, const String& protocol, Document* document, bool allowCookies) : m_url(url) , m_clientProtocol(protocol) , m_secure(m_url.protocolIs("wss")) - , m_context(context) + , m_document(document) , m_mode(Incomplete) + , m_allowCookies(allowCookies) { m_secWebSocketKey = generateSecWebSocketKey(); m_expectedAccept = getExpectedWebSocketAccept(m_secWebSocketKey); @@ -140,12 +143,13 @@ const URL& WebSocketHandshake::url() const void WebSocketHandshake::setURL(const URL& url) { - m_url = url.copy(); + m_url = url.isolatedCopy(); } +// FIXME: Return type should just be String, not const String. const String WebSocketHandshake::host() const { - return m_url.host().lower(); + return m_url.host().convertToASCIILowercase(); } const String& WebSocketHandshake::clientProtocol() const @@ -165,14 +169,14 @@ bool WebSocketHandshake::secure() const String WebSocketHandshake::clientOrigin() const { - return m_context->securityOrigin()->toString(); + return m_document->securityOrigin().toString(); } String WebSocketHandshake::clientLocation() const { StringBuilder builder; builder.append(m_secure ? "wss" : "ws"); - builder.append("://"); + builder.appendLiteral("://"); builder.append(hostName(m_url, m_secure)); builder.append(resourceName(m_url)); return builder.toString(); @@ -183,9 +187,9 @@ CString WebSocketHandshake::clientHandshakeMessage() const // Keep the following consistent with clientHandshakeRequest(). StringBuilder builder; - builder.append("GET "); + builder.appendLiteral("GET "); builder.append(resourceName(m_url)); - builder.append(" HTTP/1.1\r\n"); + builder.appendLiteral(" HTTP/1.1\r\n"); Vector<String> fields; fields.append("Upgrade: websocket"); @@ -196,12 +200,10 @@ CString WebSocketHandshake::clientHandshakeMessage() const fields.append("Sec-WebSocket-Protocol: " + m_clientProtocol); URL url = httpURLForAuthenticationAndCookies(); - if (m_context->isDocument()) { - Document* document = toDocument(m_context); - String cookie = cookieRequestHeaderFieldValue(document, url); + if (m_allowCookies && m_document) { + String cookie = cookieRequestHeaderFieldValue(*m_document, url); if (!cookie.isEmpty()) fields.append("Cookie: " + cookie); - // Set "Cookie2: <cookie>" if cookies 2 exists for url? } // Add no-cache headers to avoid compatibility issue. @@ -218,18 +220,18 @@ CString WebSocketHandshake::clientHandshakeMessage() const fields.append("Sec-WebSocket-Extensions: " + extensionValue); // Add a User-Agent header. - fields.append("User-Agent: " + m_context->userAgent(m_context->url())); + fields.append("User-Agent: " + m_document->userAgent(m_document->url())); // Fields in the handshake are sent by the client in a random order; the // order is not meaningful. Thus, it's ok to send the order we constructed // the fields. - for (size_t i = 0; i < fields.size(); i++) { - builder.append(fields[i]); - builder.append("\r\n"); + for (auto& field : fields) { + builder.append(field); + builder.appendLiteral("\r\n"); } - builder.append("\r\n"); + builder.appendLiteral("\r\n"); return builder.toString().utf8(); } @@ -237,37 +239,33 @@ CString WebSocketHandshake::clientHandshakeMessage() const ResourceRequest WebSocketHandshake::clientHandshakeRequest() const { // Keep the following consistent with clientHandshakeMessage(). - // FIXME: do we need to store m_secWebSocketKey1, m_secWebSocketKey2 and - // m_key3 in the request? ResourceRequest request(m_url); request.setHTTPMethod("GET"); - request.addHTTPHeaderField("Connection", "Upgrade"); - request.addHTTPHeaderField("Host", hostName(m_url, m_secure)); - request.addHTTPHeaderField("Origin", clientOrigin()); + request.setHTTPHeaderField(HTTPHeaderName::Connection, "Upgrade"); + request.setHTTPHeaderField(HTTPHeaderName::Host, hostName(m_url, m_secure)); + request.setHTTPHeaderField(HTTPHeaderName::Origin, clientOrigin()); if (!m_clientProtocol.isEmpty()) - request.addHTTPHeaderField("Sec-WebSocket-Protocol", m_clientProtocol); + request.setHTTPHeaderField(HTTPHeaderName::SecWebSocketProtocol, m_clientProtocol); URL url = httpURLForAuthenticationAndCookies(); - if (m_context->isDocument()) { - Document* document = toDocument(m_context); - String cookie = cookieRequestHeaderFieldValue(document, url); + if (m_allowCookies && m_document) { + String cookie = cookieRequestHeaderFieldValue(*m_document, url); if (!cookie.isEmpty()) - request.addHTTPHeaderField("Cookie", cookie); - // Set "Cookie2: <cookie>" if cookies 2 exists for url? + request.setHTTPHeaderField(HTTPHeaderName::Cookie, cookie); } - request.addHTTPHeaderField("Pragma", "no-cache"); - request.addHTTPHeaderField("Cache-Control", "no-cache"); + request.setHTTPHeaderField(HTTPHeaderName::Pragma, "no-cache"); + request.setHTTPHeaderField(HTTPHeaderName::CacheControl, "no-cache"); - request.addHTTPHeaderField("Sec-WebSocket-Key", m_secWebSocketKey); - request.addHTTPHeaderField("Sec-WebSocket-Version", "13"); + request.setHTTPHeaderField(HTTPHeaderName::SecWebSocketKey, m_secWebSocketKey); + request.setHTTPHeaderField(HTTPHeaderName::SecWebSocketVersion, "13"); const String extensionValue = m_extensionDispatcher.createHeaderValue(); if (extensionValue.length()) - request.addHTTPHeaderField("Sec-WebSocket-Extensions", extensionValue); + request.setHTTPHeaderField(HTTPHeaderName::SecWebSocketExtensions, extensionValue); // Add a User-Agent header. - request.addHTTPHeaderField("User-Agent", m_context->userAgent(m_context->url())); + request.setHTTPHeaderField(HTTPHeaderName::UserAgent, m_document->userAgent(m_document->url())); return request; } @@ -278,9 +276,9 @@ void WebSocketHandshake::reset() m_extensionDispatcher.reset(); } -void WebSocketHandshake::clearScriptExecutionContext() +void WebSocketHandshake::clearDocument() { - m_context = 0; + m_document = nullptr; } int WebSocketHandshake::readServerHandshake(const char* header, size_t len) @@ -303,7 +301,7 @@ int WebSocketHandshake::readServerHandshake(const char* header, size_t len) if (statusCode != 101) { m_mode = Failed; - m_failureReason = "Unexpected response code: " + String::number(statusCode); + m_failureReason = makeString("Unexpected response code: ", String::number(statusCode)); return len; } m_mode = Normal; @@ -340,32 +338,27 @@ String WebSocketHandshake::failureReason() const String WebSocketHandshake::serverWebSocketProtocol() const { - return m_serverHandshakeResponse.httpHeaderFields().get("sec-websocket-protocol"); + return m_serverHandshakeResponse.httpHeaderFields().get(HTTPHeaderName::SecWebSocketProtocol); } String WebSocketHandshake::serverSetCookie() const { - return m_serverHandshakeResponse.httpHeaderFields().get("set-cookie"); -} - -String WebSocketHandshake::serverSetCookie2() const -{ - return m_serverHandshakeResponse.httpHeaderFields().get("set-cookie2"); + return m_serverHandshakeResponse.httpHeaderFields().get(HTTPHeaderName::SetCookie); } String WebSocketHandshake::serverUpgrade() const { - return m_serverHandshakeResponse.httpHeaderFields().get("upgrade"); + return m_serverHandshakeResponse.httpHeaderFields().get(HTTPHeaderName::Upgrade); } String WebSocketHandshake::serverConnection() const { - return m_serverHandshakeResponse.httpHeaderFields().get("connection"); + return m_serverHandshakeResponse.httpHeaderFields().get(HTTPHeaderName::Connection); } String WebSocketHandshake::serverWebSocketAccept() const { - return m_serverHandshakeResponse.httpHeaderFields().get("sec-websocket-accept"); + return m_serverHandshakeResponse.httpHeaderFields().get(HTTPHeaderName::SecWebSocketAccept); } String WebSocketHandshake::acceptedExtensions() const @@ -378,19 +371,56 @@ const ResourceResponse& WebSocketHandshake::serverHandshakeResponse() const return m_serverHandshakeResponse; } -void WebSocketHandshake::addExtensionProcessor(PassOwnPtr<WebSocketExtensionProcessor> processor) +void WebSocketHandshake::addExtensionProcessor(std::unique_ptr<WebSocketExtensionProcessor> processor) { - m_extensionDispatcher.addProcessor(processor); + m_extensionDispatcher.addProcessor(WTFMove(processor)); } URL WebSocketHandshake::httpURLForAuthenticationAndCookies() const { - URL url = m_url.copy(); + URL url = m_url.isolatedCopy(); bool couldSetProtocol = url.setProtocol(m_secure ? "https" : "http"); ASSERT_UNUSED(couldSetProtocol, couldSetProtocol); return url; } +// https://tools.ietf.org/html/rfc6455#section-4.1 +// "The HTTP version MUST be at least 1.1." +static inline bool headerHasValidHTTPVersion(StringView httpStatusLine) +{ + const char* httpVersionStaticPreambleLiteral = "HTTP/"; + StringView httpVersionStaticPreamble(reinterpret_cast<const LChar*>(httpVersionStaticPreambleLiteral), strlen(httpVersionStaticPreambleLiteral)); + if (!httpStatusLine.startsWith(httpVersionStaticPreamble)) + return false; + + // Check that there is a version number which should be at least three characters after "HTTP/" + unsigned preambleLength = httpVersionStaticPreamble.length(); + if (httpStatusLine.length() < preambleLength + 3) + return false; + + auto dotPosition = httpStatusLine.find('.', preambleLength); + if (dotPosition == notFound) + return false; + + StringView majorVersionView = httpStatusLine.substring(preambleLength, dotPosition - preambleLength); + bool isValid; + int majorVersion = majorVersionView.toIntStrict(isValid); + if (!isValid) + return false; + + unsigned minorVersionLength; + unsigned charactersLeftAfterDotPosition = httpStatusLine.length() - dotPosition; + for (minorVersionLength = 1; minorVersionLength < charactersLeftAfterDotPosition; minorVersionLength++) { + if (!isASCIIDigit(httpStatusLine[dotPosition + minorVersionLength])) + break; + } + int minorVersion = (httpStatusLine.substring(dotPosition + 1, minorVersionLength)).toIntStrict(isValid); + if (!isValid) + return false; + + return (majorVersion >= 1 && minorVersion >= 1) || majorVersion >= 2; +} + // Returns the header length (including "\r\n"), or -1 if we have not received enough data yet. // If the line is malformed or the status code is not a 3-digit number, // statusCode and statusText will be set to -1 and a null string, respectively. @@ -403,8 +433,8 @@ int WebSocketHandshake::readStatusLine(const char* header, size_t headerLength, statusCode = -1; statusText = String(); - const char* space1 = 0; - const char* space2 = 0; + const char* space1 = nullptr; + const char* space2 = nullptr; const char* p; size_t consumedLength; @@ -418,7 +448,10 @@ int WebSocketHandshake::readStatusLine(const char* header, size_t headerLength, // The caller isn't prepared to deal with null bytes in status // line. WebSockets specification doesn't prohibit this, but HTTP // does, so we'll just treat this as an error. - m_failureReason = "Status line contains embedded null"; + m_failureReason = ASCIILiteral("Status line contains embedded null"); + return p + 1 - header; + } else if (!isASCII(*p)) { + m_failureReason = ASCIILiteral("Status line contains non-ASCII character"); return p + 1 - header; } else if (*p == '\n') break; @@ -429,32 +462,38 @@ int WebSocketHandshake::readStatusLine(const char* header, size_t headerLength, const char* end = p + 1; int lineLength = end - header; if (lineLength > maximumLength) { - m_failureReason = "Status line is too long"; + m_failureReason = ASCIILiteral("Status line is too long"); return maximumLength; } // The line must end with "\r\n". if (lineLength < 2 || *(end - 2) != '\r') { - m_failureReason = "Status line does not end with CRLF"; + m_failureReason = ASCIILiteral("Status line does not end with CRLF"); return lineLength; } if (!space1 || !space2) { - m_failureReason = "No response code found: " + trimInputSample(header, lineLength - 2); + m_failureReason = makeString("No response code found: ", trimInputSample(header, lineLength - 2)); + return lineLength; + } + + StringView httpStatusLine(reinterpret_cast<const LChar*>(header), space1 - header); + if (!headerHasValidHTTPVersion(httpStatusLine)) { + m_failureReason = makeString("Invalid HTTP version string: ", httpStatusLine); return lineLength; } - String statusCodeString(space1 + 1, space2 - space1 - 1); + StringView statusCodeString(reinterpret_cast<const LChar*>(space1 + 1), space2 - space1 - 1); if (statusCodeString.length() != 3) // Status code must consist of three digits. return lineLength; for (int i = 0; i < 3; ++i) - if (statusCodeString[i] < '0' || statusCodeString[i] > '9') { - m_failureReason = "Invalid status code: " + statusCodeString; + if (!isASCIIDigit(statusCodeString[i])) { + m_failureReason = makeString("Invalid status code: ", statusCodeString); return lineLength; } bool ok = false; - statusCode = statusCodeString.toInt(&ok); + statusCode = statusCodeString.toIntStrict(ok); ASSERT(ok); statusText = String(space2 + 1, end - space2 - 3); // Exclude "\r\n". @@ -463,7 +502,7 @@ int WebSocketHandshake::readStatusLine(const char* header, size_t headerLength, const char* WebSocketHandshake::readHTTPHeaders(const char* start, const char* end) { - AtomicString name; + StringView name; String value; bool sawSecWebSocketExtensionsHeaderField = false; bool sawSecWebSocketAcceptHeaderField = false; @@ -472,39 +511,57 @@ const char* WebSocketHandshake::readHTTPHeaders(const char* start, const char* e for (; p < end; p++) { size_t consumedLength = parseHTTPHeader(p, end - p, m_failureReason, name, value); if (!consumedLength) - return 0; + return nullptr; p += consumedLength; // Stop once we consumed an empty line. if (name.isEmpty()) break; - if (equalIgnoringCase("sec-websocket-extensions", name)) { + HTTPHeaderName headerName; + if (!findHTTPHeaderName(name, headerName)) { + // Evidence in the wild shows that services make use of custom headers in the handshake + m_serverHandshakeResponse.addHTTPHeaderField(name.toString(), value); + continue; + } + + // https://tools.ietf.org/html/rfc7230#section-3.2.4 + // "Newly defined header fields SHOULD limit their field values to US-ASCII octets." + if ((headerName == HTTPHeaderName::SecWebSocketExtensions + || headerName == HTTPHeaderName::SecWebSocketAccept + || headerName == HTTPHeaderName::SecWebSocketProtocol) + && !value.containsOnlyASCII()) { + m_failureReason = makeString(name, " header value should only contain ASCII characters"); + return nullptr; + } + + if (headerName == HTTPHeaderName::SecWebSocketExtensions) { if (sawSecWebSocketExtensionsHeaderField) { - m_failureReason = "The Sec-WebSocket-Extensions header MUST NOT appear more than once in an HTTP response"; - return 0; + m_failureReason = ASCIILiteral("The Sec-WebSocket-Extensions header must not appear more than once in an HTTP response"); + return nullptr; } if (!m_extensionDispatcher.processHeaderValue(value)) { m_failureReason = m_extensionDispatcher.failureReason(); - return 0; + return nullptr; } sawSecWebSocketExtensionsHeaderField = true; - } else if (equalIgnoringCase("Sec-WebSocket-Accept", name)) { - if (sawSecWebSocketAcceptHeaderField) { - m_failureReason = "The Sec-WebSocket-Accept header MUST NOT appear more than once in an HTTP response"; - return 0; + } else { + if (headerName == HTTPHeaderName::SecWebSocketAccept) { + if (sawSecWebSocketAcceptHeaderField) { + m_failureReason = ASCIILiteral("The Sec-WebSocket-Accept header must not appear more than once in an HTTP response"); + return nullptr; + } + sawSecWebSocketAcceptHeaderField = true; + } else if (headerName == HTTPHeaderName::SecWebSocketProtocol) { + if (sawSecWebSocketProtocolHeaderField) { + m_failureReason = ASCIILiteral("The Sec-WebSocket-Protocol header must not appear more than once in an HTTP response"); + return nullptr; + } + sawSecWebSocketProtocolHeaderField = true; } - m_serverHandshakeResponse.addHTTPHeaderField(name, value); - sawSecWebSocketAcceptHeaderField = true; - } else if (equalIgnoringCase("Sec-WebSocket-Protocol", name)) { - if (sawSecWebSocketProtocolHeaderField) { - m_failureReason = "The Sec-WebSocket-Protocol header MUST NOT appear more than once in an HTTP response"; - return 0; - } - m_serverHandshakeResponse.addHTTPHeaderField(name, value); - sawSecWebSocketProtocolHeaderField = true; - } else - m_serverHandshakeResponse.addHTTPHeaderField(name, value); + + m_serverHandshakeResponse.addHTTPHeaderField(headerName, value); + } } return p; } @@ -517,40 +574,40 @@ bool WebSocketHandshake::checkResponseHeaders() const String& serverWebSocketAccept = this->serverWebSocketAccept(); if (serverUpgrade.isNull()) { - m_failureReason = "Error during WebSocket handshake: 'Upgrade' header is missing"; + m_failureReason = ASCIILiteral("Error during WebSocket handshake: 'Upgrade' header is missing"); return false; } if (serverConnection.isNull()) { - m_failureReason = "Error during WebSocket handshake: 'Connection' header is missing"; + m_failureReason = ASCIILiteral("Error during WebSocket handshake: 'Connection' header is missing"); return false; } if (serverWebSocketAccept.isNull()) { - m_failureReason = "Error during WebSocket handshake: 'Sec-WebSocket-Accept' header is missing"; + m_failureReason = ASCIILiteral("Error during WebSocket handshake: 'Sec-WebSocket-Accept' header is missing"); return false; } - if (!equalIgnoringCase(serverUpgrade, "websocket")) { - m_failureReason = "Error during WebSocket handshake: 'Upgrade' header value is not 'WebSocket'"; + if (!equalLettersIgnoringASCIICase(serverUpgrade, "websocket")) { + m_failureReason = ASCIILiteral("Error during WebSocket handshake: 'Upgrade' header value is not 'WebSocket'"); return false; } - if (!equalIgnoringCase(serverConnection, "upgrade")) { - m_failureReason = "Error during WebSocket handshake: 'Connection' header value is not 'Upgrade'"; + if (!equalLettersIgnoringASCIICase(serverConnection, "upgrade")) { + m_failureReason = ASCIILiteral("Error during WebSocket handshake: 'Connection' header value is not 'Upgrade'"); return false; } if (serverWebSocketAccept != m_expectedAccept) { - m_failureReason = "Error during WebSocket handshake: Sec-WebSocket-Accept mismatch"; + m_failureReason = ASCIILiteral("Error during WebSocket handshake: Sec-WebSocket-Accept mismatch"); return false; } if (!serverWebSocketProtocol.isNull()) { if (m_clientProtocol.isEmpty()) { - m_failureReason = "Error during WebSocket handshake: Sec-WebSocket-Protocol mismatch"; + m_failureReason = ASCIILiteral("Error during WebSocket handshake: Sec-WebSocket-Protocol mismatch"); return false; } Vector<String> result; - m_clientProtocol.split(String(WebSocket::subProtocolSeperator()), result); + m_clientProtocol.split(WebSocket::subprotocolSeparator(), result); if (!result.contains(serverWebSocketProtocol)) { - m_failureReason = "Error during WebSocket handshake: Sec-WebSocket-Protocol mismatch"; + m_failureReason = ASCIILiteral("Error during WebSocket handshake: Sec-WebSocket-Protocol mismatch"); return false; } } |