summaryrefslogtreecommitdiff
path: root/Source/WebCore/Modules/websockets/WebSocketHandshake.cpp
diff options
context:
space:
mode:
authorLorry Tar Creator <lorry-tar-importer@lorry>2017-06-27 06:07:23 +0000
committerLorry Tar Creator <lorry-tar-importer@lorry>2017-06-27 06:07:23 +0000
commit1bf1084f2b10c3b47fd1a588d85d21ed0eb41d0c (patch)
tree46dcd36c86e7fbc6e5df36deb463b33e9967a6f7 /Source/WebCore/Modules/websockets/WebSocketHandshake.cpp
parent32761a6cee1d0dee366b885b7b9c777e67885688 (diff)
downloadWebKitGtk-tarball-master.tar.gz
Diffstat (limited to 'Source/WebCore/Modules/websockets/WebSocketHandshake.cpp')
-rw-r--r--Source/WebCore/Modules/websockets/WebSocketHandshake.cpp249
1 files changed, 153 insertions, 96 deletions
diff --git a/Source/WebCore/Modules/websockets/WebSocketHandshake.cpp b/Source/WebCore/Modules/websockets/WebSocketHandshake.cpp
index e6782cf8f..dd3b9a4e4 100644
--- a/Source/WebCore/Modules/websockets/WebSocketHandshake.cpp
+++ b/Source/WebCore/Modules/websockets/WebSocketHandshake.cpp
@@ -40,13 +40,14 @@
#include "CookieJar.h"
#include "Document.h"
#include "HTTPHeaderMap.h"
+#include "HTTPHeaderNames.h"
#include "HTTPParsers.h"
#include "URL.h"
#include "Logging.h"
#include "ResourceRequest.h"
-#include "ScriptCallStack.h"
#include "ScriptExecutionContext.h"
#include "SecurityOrigin.h"
+#include <wtf/ASCIICType.h>
#include <wtf/CryptographicallyRandomNumber.h>
#include <wtf/MD5.h>
#include <wtf/SHA1.h>
@@ -56,6 +57,7 @@
#include <wtf/text/Base64.h>
#include <wtf/text/CString.h>
#include <wtf/text/StringBuilder.h>
+#include <wtf/text/StringView.h>
#include <wtf/text/WTFString.h>
#include <wtf/unicode/CharacterNames.h>
@@ -81,10 +83,10 @@ static String hostName(const URL& url, bool secure)
{
ASSERT(url.protocolIs("wss") == secure);
StringBuilder builder;
- builder.append(url.host().lower());
- if (url.port() && ((!secure && url.port() != 80) || (secure && url.port() != 443))) {
+ builder.append(url.host().convertToASCIILowercase());
+ if (url.port() && ((!secure && url.port().value() != 80) || (secure && url.port().value() != 443))) {
builder.append(':');
- builder.appendNumber(url.port());
+ builder.appendNumber(url.port().value());
}
return builder.toString();
}
@@ -118,12 +120,13 @@ String WebSocketHandshake::getExpectedWebSocketAccept(const String& secWebSocket
return base64Encode(hash.data(), SHA1::hashSize);
}
-WebSocketHandshake::WebSocketHandshake(const URL& url, const String& protocol, ScriptExecutionContext* context)
+WebSocketHandshake::WebSocketHandshake(const URL& url, const String& protocol, Document* document, bool allowCookies)
: m_url(url)
, m_clientProtocol(protocol)
, m_secure(m_url.protocolIs("wss"))
- , m_context(context)
+ , m_document(document)
, m_mode(Incomplete)
+ , m_allowCookies(allowCookies)
{
m_secWebSocketKey = generateSecWebSocketKey();
m_expectedAccept = getExpectedWebSocketAccept(m_secWebSocketKey);
@@ -140,12 +143,13 @@ const URL& WebSocketHandshake::url() const
void WebSocketHandshake::setURL(const URL& url)
{
- m_url = url.copy();
+ m_url = url.isolatedCopy();
}
+// FIXME: Return type should just be String, not const String.
const String WebSocketHandshake::host() const
{
- return m_url.host().lower();
+ return m_url.host().convertToASCIILowercase();
}
const String& WebSocketHandshake::clientProtocol() const
@@ -165,14 +169,14 @@ bool WebSocketHandshake::secure() const
String WebSocketHandshake::clientOrigin() const
{
- return m_context->securityOrigin()->toString();
+ return m_document->securityOrigin().toString();
}
String WebSocketHandshake::clientLocation() const
{
StringBuilder builder;
builder.append(m_secure ? "wss" : "ws");
- builder.append("://");
+ builder.appendLiteral("://");
builder.append(hostName(m_url, m_secure));
builder.append(resourceName(m_url));
return builder.toString();
@@ -183,9 +187,9 @@ CString WebSocketHandshake::clientHandshakeMessage() const
// Keep the following consistent with clientHandshakeRequest().
StringBuilder builder;
- builder.append("GET ");
+ builder.appendLiteral("GET ");
builder.append(resourceName(m_url));
- builder.append(" HTTP/1.1\r\n");
+ builder.appendLiteral(" HTTP/1.1\r\n");
Vector<String> fields;
fields.append("Upgrade: websocket");
@@ -196,12 +200,10 @@ CString WebSocketHandshake::clientHandshakeMessage() const
fields.append("Sec-WebSocket-Protocol: " + m_clientProtocol);
URL url = httpURLForAuthenticationAndCookies();
- if (m_context->isDocument()) {
- Document* document = toDocument(m_context);
- String cookie = cookieRequestHeaderFieldValue(document, url);
+ if (m_allowCookies && m_document) {
+ String cookie = cookieRequestHeaderFieldValue(*m_document, url);
if (!cookie.isEmpty())
fields.append("Cookie: " + cookie);
- // Set "Cookie2: <cookie>" if cookies 2 exists for url?
}
// Add no-cache headers to avoid compatibility issue.
@@ -218,18 +220,18 @@ CString WebSocketHandshake::clientHandshakeMessage() const
fields.append("Sec-WebSocket-Extensions: " + extensionValue);
// Add a User-Agent header.
- fields.append("User-Agent: " + m_context->userAgent(m_context->url()));
+ fields.append("User-Agent: " + m_document->userAgent(m_document->url()));
// Fields in the handshake are sent by the client in a random order; the
// order is not meaningful. Thus, it's ok to send the order we constructed
// the fields.
- for (size_t i = 0; i < fields.size(); i++) {
- builder.append(fields[i]);
- builder.append("\r\n");
+ for (auto& field : fields) {
+ builder.append(field);
+ builder.appendLiteral("\r\n");
}
- builder.append("\r\n");
+ builder.appendLiteral("\r\n");
return builder.toString().utf8();
}
@@ -237,37 +239,33 @@ CString WebSocketHandshake::clientHandshakeMessage() const
ResourceRequest WebSocketHandshake::clientHandshakeRequest() const
{
// Keep the following consistent with clientHandshakeMessage().
- // FIXME: do we need to store m_secWebSocketKey1, m_secWebSocketKey2 and
- // m_key3 in the request?
ResourceRequest request(m_url);
request.setHTTPMethod("GET");
- request.addHTTPHeaderField("Connection", "Upgrade");
- request.addHTTPHeaderField("Host", hostName(m_url, m_secure));
- request.addHTTPHeaderField("Origin", clientOrigin());
+ request.setHTTPHeaderField(HTTPHeaderName::Connection, "Upgrade");
+ request.setHTTPHeaderField(HTTPHeaderName::Host, hostName(m_url, m_secure));
+ request.setHTTPHeaderField(HTTPHeaderName::Origin, clientOrigin());
if (!m_clientProtocol.isEmpty())
- request.addHTTPHeaderField("Sec-WebSocket-Protocol", m_clientProtocol);
+ request.setHTTPHeaderField(HTTPHeaderName::SecWebSocketProtocol, m_clientProtocol);
URL url = httpURLForAuthenticationAndCookies();
- if (m_context->isDocument()) {
- Document* document = toDocument(m_context);
- String cookie = cookieRequestHeaderFieldValue(document, url);
+ if (m_allowCookies && m_document) {
+ String cookie = cookieRequestHeaderFieldValue(*m_document, url);
if (!cookie.isEmpty())
- request.addHTTPHeaderField("Cookie", cookie);
- // Set "Cookie2: <cookie>" if cookies 2 exists for url?
+ request.setHTTPHeaderField(HTTPHeaderName::Cookie, cookie);
}
- request.addHTTPHeaderField("Pragma", "no-cache");
- request.addHTTPHeaderField("Cache-Control", "no-cache");
+ request.setHTTPHeaderField(HTTPHeaderName::Pragma, "no-cache");
+ request.setHTTPHeaderField(HTTPHeaderName::CacheControl, "no-cache");
- request.addHTTPHeaderField("Sec-WebSocket-Key", m_secWebSocketKey);
- request.addHTTPHeaderField("Sec-WebSocket-Version", "13");
+ request.setHTTPHeaderField(HTTPHeaderName::SecWebSocketKey, m_secWebSocketKey);
+ request.setHTTPHeaderField(HTTPHeaderName::SecWebSocketVersion, "13");
const String extensionValue = m_extensionDispatcher.createHeaderValue();
if (extensionValue.length())
- request.addHTTPHeaderField("Sec-WebSocket-Extensions", extensionValue);
+ request.setHTTPHeaderField(HTTPHeaderName::SecWebSocketExtensions, extensionValue);
// Add a User-Agent header.
- request.addHTTPHeaderField("User-Agent", m_context->userAgent(m_context->url()));
+ request.setHTTPHeaderField(HTTPHeaderName::UserAgent, m_document->userAgent(m_document->url()));
return request;
}
@@ -278,9 +276,9 @@ void WebSocketHandshake::reset()
m_extensionDispatcher.reset();
}
-void WebSocketHandshake::clearScriptExecutionContext()
+void WebSocketHandshake::clearDocument()
{
- m_context = 0;
+ m_document = nullptr;
}
int WebSocketHandshake::readServerHandshake(const char* header, size_t len)
@@ -303,7 +301,7 @@ int WebSocketHandshake::readServerHandshake(const char* header, size_t len)
if (statusCode != 101) {
m_mode = Failed;
- m_failureReason = "Unexpected response code: " + String::number(statusCode);
+ m_failureReason = makeString("Unexpected response code: ", String::number(statusCode));
return len;
}
m_mode = Normal;
@@ -340,32 +338,27 @@ String WebSocketHandshake::failureReason() const
String WebSocketHandshake::serverWebSocketProtocol() const
{
- return m_serverHandshakeResponse.httpHeaderFields().get("sec-websocket-protocol");
+ return m_serverHandshakeResponse.httpHeaderFields().get(HTTPHeaderName::SecWebSocketProtocol);
}
String WebSocketHandshake::serverSetCookie() const
{
- return m_serverHandshakeResponse.httpHeaderFields().get("set-cookie");
-}
-
-String WebSocketHandshake::serverSetCookie2() const
-{
- return m_serverHandshakeResponse.httpHeaderFields().get("set-cookie2");
+ return m_serverHandshakeResponse.httpHeaderFields().get(HTTPHeaderName::SetCookie);
}
String WebSocketHandshake::serverUpgrade() const
{
- return m_serverHandshakeResponse.httpHeaderFields().get("upgrade");
+ return m_serverHandshakeResponse.httpHeaderFields().get(HTTPHeaderName::Upgrade);
}
String WebSocketHandshake::serverConnection() const
{
- return m_serverHandshakeResponse.httpHeaderFields().get("connection");
+ return m_serverHandshakeResponse.httpHeaderFields().get(HTTPHeaderName::Connection);
}
String WebSocketHandshake::serverWebSocketAccept() const
{
- return m_serverHandshakeResponse.httpHeaderFields().get("sec-websocket-accept");
+ return m_serverHandshakeResponse.httpHeaderFields().get(HTTPHeaderName::SecWebSocketAccept);
}
String WebSocketHandshake::acceptedExtensions() const
@@ -378,19 +371,56 @@ const ResourceResponse& WebSocketHandshake::serverHandshakeResponse() const
return m_serverHandshakeResponse;
}
-void WebSocketHandshake::addExtensionProcessor(PassOwnPtr<WebSocketExtensionProcessor> processor)
+void WebSocketHandshake::addExtensionProcessor(std::unique_ptr<WebSocketExtensionProcessor> processor)
{
- m_extensionDispatcher.addProcessor(processor);
+ m_extensionDispatcher.addProcessor(WTFMove(processor));
}
URL WebSocketHandshake::httpURLForAuthenticationAndCookies() const
{
- URL url = m_url.copy();
+ URL url = m_url.isolatedCopy();
bool couldSetProtocol = url.setProtocol(m_secure ? "https" : "http");
ASSERT_UNUSED(couldSetProtocol, couldSetProtocol);
return url;
}
+// https://tools.ietf.org/html/rfc6455#section-4.1
+// "The HTTP version MUST be at least 1.1."
+static inline bool headerHasValidHTTPVersion(StringView httpStatusLine)
+{
+ const char* httpVersionStaticPreambleLiteral = "HTTP/";
+ StringView httpVersionStaticPreamble(reinterpret_cast<const LChar*>(httpVersionStaticPreambleLiteral), strlen(httpVersionStaticPreambleLiteral));
+ if (!httpStatusLine.startsWith(httpVersionStaticPreamble))
+ return false;
+
+ // Check that there is a version number which should be at least three characters after "HTTP/"
+ unsigned preambleLength = httpVersionStaticPreamble.length();
+ if (httpStatusLine.length() < preambleLength + 3)
+ return false;
+
+ auto dotPosition = httpStatusLine.find('.', preambleLength);
+ if (dotPosition == notFound)
+ return false;
+
+ StringView majorVersionView = httpStatusLine.substring(preambleLength, dotPosition - preambleLength);
+ bool isValid;
+ int majorVersion = majorVersionView.toIntStrict(isValid);
+ if (!isValid)
+ return false;
+
+ unsigned minorVersionLength;
+ unsigned charactersLeftAfterDotPosition = httpStatusLine.length() - dotPosition;
+ for (minorVersionLength = 1; minorVersionLength < charactersLeftAfterDotPosition; minorVersionLength++) {
+ if (!isASCIIDigit(httpStatusLine[dotPosition + minorVersionLength]))
+ break;
+ }
+ int minorVersion = (httpStatusLine.substring(dotPosition + 1, minorVersionLength)).toIntStrict(isValid);
+ if (!isValid)
+ return false;
+
+ return (majorVersion >= 1 && minorVersion >= 1) || majorVersion >= 2;
+}
+
// Returns the header length (including "\r\n"), or -1 if we have not received enough data yet.
// If the line is malformed or the status code is not a 3-digit number,
// statusCode and statusText will be set to -1 and a null string, respectively.
@@ -403,8 +433,8 @@ int WebSocketHandshake::readStatusLine(const char* header, size_t headerLength,
statusCode = -1;
statusText = String();
- const char* space1 = 0;
- const char* space2 = 0;
+ const char* space1 = nullptr;
+ const char* space2 = nullptr;
const char* p;
size_t consumedLength;
@@ -418,7 +448,10 @@ int WebSocketHandshake::readStatusLine(const char* header, size_t headerLength,
// The caller isn't prepared to deal with null bytes in status
// line. WebSockets specification doesn't prohibit this, but HTTP
// does, so we'll just treat this as an error.
- m_failureReason = "Status line contains embedded null";
+ m_failureReason = ASCIILiteral("Status line contains embedded null");
+ return p + 1 - header;
+ } else if (!isASCII(*p)) {
+ m_failureReason = ASCIILiteral("Status line contains non-ASCII character");
return p + 1 - header;
} else if (*p == '\n')
break;
@@ -429,32 +462,38 @@ int WebSocketHandshake::readStatusLine(const char* header, size_t headerLength,
const char* end = p + 1;
int lineLength = end - header;
if (lineLength > maximumLength) {
- m_failureReason = "Status line is too long";
+ m_failureReason = ASCIILiteral("Status line is too long");
return maximumLength;
}
// The line must end with "\r\n".
if (lineLength < 2 || *(end - 2) != '\r') {
- m_failureReason = "Status line does not end with CRLF";
+ m_failureReason = ASCIILiteral("Status line does not end with CRLF");
return lineLength;
}
if (!space1 || !space2) {
- m_failureReason = "No response code found: " + trimInputSample(header, lineLength - 2);
+ m_failureReason = makeString("No response code found: ", trimInputSample(header, lineLength - 2));
+ return lineLength;
+ }
+
+ StringView httpStatusLine(reinterpret_cast<const LChar*>(header), space1 - header);
+ if (!headerHasValidHTTPVersion(httpStatusLine)) {
+ m_failureReason = makeString("Invalid HTTP version string: ", httpStatusLine);
return lineLength;
}
- String statusCodeString(space1 + 1, space2 - space1 - 1);
+ StringView statusCodeString(reinterpret_cast<const LChar*>(space1 + 1), space2 - space1 - 1);
if (statusCodeString.length() != 3) // Status code must consist of three digits.
return lineLength;
for (int i = 0; i < 3; ++i)
- if (statusCodeString[i] < '0' || statusCodeString[i] > '9') {
- m_failureReason = "Invalid status code: " + statusCodeString;
+ if (!isASCIIDigit(statusCodeString[i])) {
+ m_failureReason = makeString("Invalid status code: ", statusCodeString);
return lineLength;
}
bool ok = false;
- statusCode = statusCodeString.toInt(&ok);
+ statusCode = statusCodeString.toIntStrict(ok);
ASSERT(ok);
statusText = String(space2 + 1, end - space2 - 3); // Exclude "\r\n".
@@ -463,7 +502,7 @@ int WebSocketHandshake::readStatusLine(const char* header, size_t headerLength,
const char* WebSocketHandshake::readHTTPHeaders(const char* start, const char* end)
{
- AtomicString name;
+ StringView name;
String value;
bool sawSecWebSocketExtensionsHeaderField = false;
bool sawSecWebSocketAcceptHeaderField = false;
@@ -472,39 +511,57 @@ const char* WebSocketHandshake::readHTTPHeaders(const char* start, const char* e
for (; p < end; p++) {
size_t consumedLength = parseHTTPHeader(p, end - p, m_failureReason, name, value);
if (!consumedLength)
- return 0;
+ return nullptr;
p += consumedLength;
// Stop once we consumed an empty line.
if (name.isEmpty())
break;
- if (equalIgnoringCase("sec-websocket-extensions", name)) {
+ HTTPHeaderName headerName;
+ if (!findHTTPHeaderName(name, headerName)) {
+ // Evidence in the wild shows that services make use of custom headers in the handshake
+ m_serverHandshakeResponse.addHTTPHeaderField(name.toString(), value);
+ continue;
+ }
+
+ // https://tools.ietf.org/html/rfc7230#section-3.2.4
+ // "Newly defined header fields SHOULD limit their field values to US-ASCII octets."
+ if ((headerName == HTTPHeaderName::SecWebSocketExtensions
+ || headerName == HTTPHeaderName::SecWebSocketAccept
+ || headerName == HTTPHeaderName::SecWebSocketProtocol)
+ && !value.containsOnlyASCII()) {
+ m_failureReason = makeString(name, " header value should only contain ASCII characters");
+ return nullptr;
+ }
+
+ if (headerName == HTTPHeaderName::SecWebSocketExtensions) {
if (sawSecWebSocketExtensionsHeaderField) {
- m_failureReason = "The Sec-WebSocket-Extensions header MUST NOT appear more than once in an HTTP response";
- return 0;
+ m_failureReason = ASCIILiteral("The Sec-WebSocket-Extensions header must not appear more than once in an HTTP response");
+ return nullptr;
}
if (!m_extensionDispatcher.processHeaderValue(value)) {
m_failureReason = m_extensionDispatcher.failureReason();
- return 0;
+ return nullptr;
}
sawSecWebSocketExtensionsHeaderField = true;
- } else if (equalIgnoringCase("Sec-WebSocket-Accept", name)) {
- if (sawSecWebSocketAcceptHeaderField) {
- m_failureReason = "The Sec-WebSocket-Accept header MUST NOT appear more than once in an HTTP response";
- return 0;
+ } else {
+ if (headerName == HTTPHeaderName::SecWebSocketAccept) {
+ if (sawSecWebSocketAcceptHeaderField) {
+ m_failureReason = ASCIILiteral("The Sec-WebSocket-Accept header must not appear more than once in an HTTP response");
+ return nullptr;
+ }
+ sawSecWebSocketAcceptHeaderField = true;
+ } else if (headerName == HTTPHeaderName::SecWebSocketProtocol) {
+ if (sawSecWebSocketProtocolHeaderField) {
+ m_failureReason = ASCIILiteral("The Sec-WebSocket-Protocol header must not appear more than once in an HTTP response");
+ return nullptr;
+ }
+ sawSecWebSocketProtocolHeaderField = true;
}
- m_serverHandshakeResponse.addHTTPHeaderField(name, value);
- sawSecWebSocketAcceptHeaderField = true;
- } else if (equalIgnoringCase("Sec-WebSocket-Protocol", name)) {
- if (sawSecWebSocketProtocolHeaderField) {
- m_failureReason = "The Sec-WebSocket-Protocol header MUST NOT appear more than once in an HTTP response";
- return 0;
- }
- m_serverHandshakeResponse.addHTTPHeaderField(name, value);
- sawSecWebSocketProtocolHeaderField = true;
- } else
- m_serverHandshakeResponse.addHTTPHeaderField(name, value);
+
+ m_serverHandshakeResponse.addHTTPHeaderField(headerName, value);
+ }
}
return p;
}
@@ -517,40 +574,40 @@ bool WebSocketHandshake::checkResponseHeaders()
const String& serverWebSocketAccept = this->serverWebSocketAccept();
if (serverUpgrade.isNull()) {
- m_failureReason = "Error during WebSocket handshake: 'Upgrade' header is missing";
+ m_failureReason = ASCIILiteral("Error during WebSocket handshake: 'Upgrade' header is missing");
return false;
}
if (serverConnection.isNull()) {
- m_failureReason = "Error during WebSocket handshake: 'Connection' header is missing";
+ m_failureReason = ASCIILiteral("Error during WebSocket handshake: 'Connection' header is missing");
return false;
}
if (serverWebSocketAccept.isNull()) {
- m_failureReason = "Error during WebSocket handshake: 'Sec-WebSocket-Accept' header is missing";
+ m_failureReason = ASCIILiteral("Error during WebSocket handshake: 'Sec-WebSocket-Accept' header is missing");
return false;
}
- if (!equalIgnoringCase(serverUpgrade, "websocket")) {
- m_failureReason = "Error during WebSocket handshake: 'Upgrade' header value is not 'WebSocket'";
+ if (!equalLettersIgnoringASCIICase(serverUpgrade, "websocket")) {
+ m_failureReason = ASCIILiteral("Error during WebSocket handshake: 'Upgrade' header value is not 'WebSocket'");
return false;
}
- if (!equalIgnoringCase(serverConnection, "upgrade")) {
- m_failureReason = "Error during WebSocket handshake: 'Connection' header value is not 'Upgrade'";
+ if (!equalLettersIgnoringASCIICase(serverConnection, "upgrade")) {
+ m_failureReason = ASCIILiteral("Error during WebSocket handshake: 'Connection' header value is not 'Upgrade'");
return false;
}
if (serverWebSocketAccept != m_expectedAccept) {
- m_failureReason = "Error during WebSocket handshake: Sec-WebSocket-Accept mismatch";
+ m_failureReason = ASCIILiteral("Error during WebSocket handshake: Sec-WebSocket-Accept mismatch");
return false;
}
if (!serverWebSocketProtocol.isNull()) {
if (m_clientProtocol.isEmpty()) {
- m_failureReason = "Error during WebSocket handshake: Sec-WebSocket-Protocol mismatch";
+ m_failureReason = ASCIILiteral("Error during WebSocket handshake: Sec-WebSocket-Protocol mismatch");
return false;
}
Vector<String> result;
- m_clientProtocol.split(String(WebSocket::subProtocolSeperator()), result);
+ m_clientProtocol.split(WebSocket::subprotocolSeparator(), result);
if (!result.contains(serverWebSocketProtocol)) {
- m_failureReason = "Error during WebSocket handshake: Sec-WebSocket-Protocol mismatch";
+ m_failureReason = ASCIILiteral("Error during WebSocket handshake: Sec-WebSocket-Protocol mismatch");
return false;
}
}