diff options
author | Lorry Tar Creator <lorry-tar-importer@lorry> | 2017-06-27 06:07:23 +0000 |
---|---|---|
committer | Lorry Tar Creator <lorry-tar-importer@lorry> | 2017-06-27 06:07:23 +0000 |
commit | 1bf1084f2b10c3b47fd1a588d85d21ed0eb41d0c (patch) | |
tree | 46dcd36c86e7fbc6e5df36deb463b33e9967a6f7 /Source/WebCore/Modules/websockets | |
parent | 32761a6cee1d0dee366b885b7b9c777e67885688 (diff) | |
download | WebKitGtk-tarball-master.tar.gz |
webkitgtk-2.16.5HEADwebkitgtk-2.16.5master
Diffstat (limited to 'Source/WebCore/Modules/websockets')
25 files changed, 1117 insertions, 1235 deletions
diff --git a/Source/WebCore/Modules/websockets/CloseEvent.h b/Source/WebCore/Modules/websockets/CloseEvent.h index fda51d8d7..85954f07f 100644 --- a/Source/WebCore/Modules/websockets/CloseEvent.h +++ b/Source/WebCore/Modules/websockets/CloseEvent.h @@ -28,41 +28,29 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ -#ifndef CloseEvent_h -#define CloseEvent_h +#pragma once #include "Event.h" #include "EventNames.h" namespace WebCore { -struct CloseEventInit : public EventInit { - CloseEventInit() - : wasClean(false) - , code(0) - { - }; - - bool wasClean; - unsigned short code; - String reason; -}; - class CloseEvent : public Event { public: - static PassRefPtr<CloseEvent> create() + static Ref<CloseEvent> create(bool wasClean, unsigned short code, const String& reason) { - return adoptRef(new CloseEvent()); + return adoptRef(*new CloseEvent(wasClean, code, reason)); } - static PassRefPtr<CloseEvent> create(bool wasClean, unsigned short code, const String& reason) - { - return adoptRef(new CloseEvent(wasClean, code, reason)); - } + struct Init : EventInit { + bool wasClean { false }; + unsigned short code { 0 }; + String reason; + }; - static PassRefPtr<CloseEvent> create(const AtomicString& type, const CloseEventInit& initializer) + static Ref<CloseEvent> create(const AtomicString& type, const Init& initializer, IsTrusted isTrusted = IsTrusted::No) { - return adoptRef(new CloseEvent(type, initializer)); + return adoptRef(*new CloseEvent(type, initializer, isTrusted)); } bool wasClean() const { return m_wasClean; } @@ -70,16 +58,9 @@ public: String reason() const { return m_reason; } // Event function. - virtual EventInterface eventInterface() const override { return CloseEventInterfaceType; } + EventInterface eventInterface() const override { return CloseEventInterfaceType; } private: - CloseEvent() - : Event(eventNames().closeEvent, false, false) - , m_wasClean(false) - , m_code(0) - { - } - CloseEvent(bool wasClean, int code, const String& reason) : Event(eventNames().closeEvent, false, false) , m_wasClean(wasClean) @@ -88,8 +69,8 @@ private: { } - CloseEvent(const AtomicString& type, const CloseEventInit& initializer) - : Event(type, initializer) + CloseEvent(const AtomicString& type, const Init& initializer, IsTrusted isTrusted) + : Event(type, initializer, isTrusted) , m_wasClean(initializer.wasClean) , m_code(initializer.code) , m_reason(initializer.reason) @@ -102,5 +83,3 @@ private: }; } // namespace WebCore - -#endif // CloseEvent_h diff --git a/Source/WebCore/Modules/websockets/CloseEvent.idl b/Source/WebCore/Modules/websockets/CloseEvent.idl index 3f29a4959..db536c227 100644 --- a/Source/WebCore/Modules/websockets/CloseEvent.idl +++ b/Source/WebCore/Modules/websockets/CloseEvent.idl @@ -28,12 +28,17 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ +// FIXME: This should be exposed to workers as well. [ - JSNoStaticTables, - ConstructorTemplate=Event + Constructor(DOMString type, optional CloseEventInit eventInitDict), ] interface CloseEvent : Event { - [InitializedByEventConstructor] readonly attribute boolean wasClean; - [InitializedByEventConstructor] readonly attribute unsigned short code; - [InitializedByEventConstructor] readonly attribute DOMString reason; + readonly attribute boolean wasClean; + readonly attribute unsigned short code; + readonly attribute USVString reason; }; +dictionary CloseEventInit : EventInit { + boolean wasClean = false; + unsigned short code = 0; + USVString reason = ""; +}; diff --git a/Source/WebCore/Modules/websockets/ThreadableWebSocketChannel.cpp b/Source/WebCore/Modules/websockets/ThreadableWebSocketChannel.cpp index 919c60501..18aaa8c12 100644 --- a/Source/WebCore/Modules/websockets/ThreadableWebSocketChannel.cpp +++ b/Source/WebCore/Modules/websockets/ThreadableWebSocketChannel.cpp @@ -31,7 +31,6 @@ #include "config.h" #if ENABLE(WEB_SOCKETS) - #include "ThreadableWebSocketChannel.h" #include "Document.h" @@ -43,27 +42,18 @@ #include "WorkerRunLoop.h" #include "WorkerThread.h" #include "WorkerThreadableWebSocketChannel.h" -#include <wtf/PassRefPtr.h> -#include <wtf/text/WTFString.h> namespace WebCore { -static const char webSocketChannelMode[] = "webSocketChannelMode"; - -PassRefPtr<ThreadableWebSocketChannel> ThreadableWebSocketChannel::create(ScriptExecutionContext* context, WebSocketChannelClient* client) +Ref<ThreadableWebSocketChannel> ThreadableWebSocketChannel::create(ScriptExecutionContext& context, WebSocketChannelClient& client, SocketProvider& provider) { - ASSERT(context); - ASSERT(client); - - if (context->isWorkerGlobalScope()) { - WorkerGlobalScope* workerGlobalScope = static_cast<WorkerGlobalScope*>(context); - WorkerRunLoop& runLoop = workerGlobalScope->thread()->runLoop(); - String mode = webSocketChannelMode; - mode.append(String::number(runLoop.createUniqueId())); - return WorkerThreadableWebSocketChannel::create(workerGlobalScope, client, mode); + if (is<WorkerGlobalScope>(context)) { + WorkerGlobalScope& workerGlobalScope = downcast<WorkerGlobalScope>(context); + WorkerRunLoop& runLoop = workerGlobalScope.thread().runLoop(); + return WorkerThreadableWebSocketChannel::create(workerGlobalScope, client, makeString("webSocketChannelMode", String::number(runLoop.createUniqueId())), provider); } - return WebSocketChannel::create(toDocument(context), client); + return WebSocketChannel::create(downcast<Document>(context), client, provider); } } // namespace WebCore diff --git a/Source/WebCore/Modules/websockets/ThreadableWebSocketChannel.h b/Source/WebCore/Modules/websockets/ThreadableWebSocketChannel.h index ded34c85b..9327e2f02 100644 --- a/Source/WebCore/Modules/websockets/ThreadableWebSocketChannel.h +++ b/Source/WebCore/Modules/websockets/ThreadableWebSocketChannel.h @@ -28,18 +28,15 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ -#ifndef ThreadableWebSocketChannel_h -#define ThreadableWebSocketChannel_h +#pragma once #if ENABLE(WEB_SOCKETS) #include <wtf/Forward.h> #include <wtf/Noncopyable.h> -#include <wtf/PassRefPtr.h> namespace JSC { class ArrayBuffer; -class ArrayBufferView; } namespace WebCore { @@ -47,18 +44,18 @@ namespace WebCore { class Blob; class URL; class ScriptExecutionContext; +class SocketProvider; class WebSocketChannelClient; class ThreadableWebSocketChannel { WTF_MAKE_NONCOPYABLE(ThreadableWebSocketChannel); public: + static Ref<ThreadableWebSocketChannel> create(ScriptExecutionContext&, WebSocketChannelClient&, SocketProvider&); ThreadableWebSocketChannel() { } - static PassRefPtr<ThreadableWebSocketChannel> create(ScriptExecutionContext*, WebSocketChannelClient*); enum SendResult { SendSuccess, - SendFail, - InvalidMessage + SendFail }; virtual void connect(const URL&, const String& protocol) = 0; @@ -66,8 +63,8 @@ public: virtual String extensions() = 0; // Will be available after didConnect() callback is invoked. virtual SendResult send(const String& message) = 0; virtual SendResult send(const JSC::ArrayBuffer&, unsigned byteOffset, unsigned byteLength) = 0; - virtual SendResult send(const Blob&) = 0; - virtual unsigned long bufferedAmount() const = 0; + virtual SendResult send(Blob&) = 0; + virtual unsigned bufferedAmount() const = 0; virtual void close(int code, const String& reason) = 0; // Log the reason text and close the connection. Will call didClose(). virtual void fail(const String& reason) = 0; @@ -88,5 +85,3 @@ protected: } // namespace WebCore #endif // ENABLE(WEB_SOCKETS) - -#endif // ThreadableWebSocketChannel_h diff --git a/Source/WebCore/Modules/websockets/ThreadableWebSocketChannelClientWrapper.cpp b/Source/WebCore/Modules/websockets/ThreadableWebSocketChannelClientWrapper.cpp index 0ab3c5eb1..d19f6cc9e 100644 --- a/Source/WebCore/Modules/websockets/ThreadableWebSocketChannelClientWrapper.cpp +++ b/Source/WebCore/Modules/websockets/ThreadableWebSocketChannelClientWrapper.cpp @@ -32,19 +32,17 @@ #if ENABLE(WEB_SOCKETS) #include "ThreadableWebSocketChannelClientWrapper.h" -#include "CrossThreadCopier.h" -#include "CrossThreadTask.h" #include "ScriptExecutionContext.h" #include "WebSocketChannelClient.h" -#include <wtf/PassRefPtr.h> #include <wtf/RefPtr.h> +#include <wtf/text/StringView.h> namespace WebCore { -ThreadableWebSocketChannelClientWrapper::ThreadableWebSocketChannelClientWrapper(ScriptExecutionContext* context, WebSocketChannelClient* client) +ThreadableWebSocketChannelClientWrapper::ThreadableWebSocketChannelClientWrapper(ScriptExecutionContext& context, WebSocketChannelClient& client) : m_context(context) - , m_client(client) - , m_peer(0) + , m_client(&client) + , m_peer(nullptr) , m_failedWebSocketChannelCreation(false) , m_syncMethodDone(true) , m_sendRequestResult(ThreadableWebSocketChannel::SendFail) @@ -53,9 +51,9 @@ ThreadableWebSocketChannelClientWrapper::ThreadableWebSocketChannelClientWrapper { } -PassRefPtr<ThreadableWebSocketChannelClientWrapper> ThreadableWebSocketChannelClientWrapper::create(ScriptExecutionContext* context, WebSocketChannelClient* client) +Ref<ThreadableWebSocketChannelClientWrapper> ThreadableWebSocketChannelClientWrapper::create(ScriptExecutionContext& context, WebSocketChannelClient& client) { - return adoptRef(new ThreadableWebSocketChannelClientWrapper(context, client)); + return adoptRef(*new ThreadableWebSocketChannelClientWrapper(context, client)); } void ThreadableWebSocketChannelClientWrapper::clearSyncMethodDone() @@ -86,7 +84,7 @@ void ThreadableWebSocketChannelClientWrapper::didCreateWebSocketChannel(WorkerTh void ThreadableWebSocketChannelClientWrapper::clearPeer() { - m_peer = 0; + m_peer = nullptr; } bool ThreadableWebSocketChannelClientWrapper::failedWebSocketChannelCreation() const @@ -110,8 +108,7 @@ void ThreadableWebSocketChannelClientWrapper::setSubprotocol(const String& subpr { unsigned length = subprotocol.length(); m_subprotocol.resize(length); - if (length) - memcpy(m_subprotocol.data(), subprotocol.deprecatedCharacters(), sizeof(UChar) * length); + StringView(subprotocol).getCharactersWithUpconvert(m_subprotocol.data()); } String ThreadableWebSocketChannelClientWrapper::extensions() const @@ -125,8 +122,7 @@ void ThreadableWebSocketChannelClientWrapper::setExtensions(const String& extens { unsigned length = extensions.length(); m_extensions.resize(length); - if (length) - memcpy(m_extensions.data(), extensions.deprecatedCharacters(), sizeof(UChar) * length); + StringView(extensions).getCharactersWithUpconvert(m_extensions.data()); } ThreadableWebSocketChannel::SendResult ThreadableWebSocketChannelClientWrapper::sendRequestResult() const @@ -140,12 +136,12 @@ void ThreadableWebSocketChannelClientWrapper::setSendRequestResult(ThreadableWeb m_syncMethodDone = true; } -unsigned long ThreadableWebSocketChannelClientWrapper::bufferedAmount() const +unsigned ThreadableWebSocketChannelClientWrapper::bufferedAmount() const { return m_bufferedAmount; } -void ThreadableWebSocketChannelClientWrapper::setBufferedAmount(unsigned long bufferedAmount) +void ThreadableWebSocketChannelClientWrapper::setBufferedAmount(unsigned bufferedAmount) { m_bufferedAmount = bufferedAmount; m_syncMethodDone = true; @@ -153,54 +149,93 @@ void ThreadableWebSocketChannelClientWrapper::setBufferedAmount(unsigned long bu void ThreadableWebSocketChannelClientWrapper::clearClient() { - m_client = 0; + m_client = nullptr; } void ThreadableWebSocketChannelClientWrapper::didConnect() { - m_pendingTasks.append(createCallbackTask(&didConnectCallback, this)); + m_pendingTasks.append(std::make_unique<ScriptExecutionContext::Task>([this, protectedThis = makeRef(*this)] (ScriptExecutionContext&) { + if (m_client) + m_client->didConnect(); + })); + if (!m_suspended) processPendingTasks(); } void ThreadableWebSocketChannelClientWrapper::didReceiveMessage(const String& message) { - m_pendingTasks.append(createCallbackTask(&didReceiveMessageCallback, this, message)); + m_pendingTasks.append(std::make_unique<ScriptExecutionContext::Task>([this, protectedThis = makeRef(*this), message = message.isolatedCopy()] (ScriptExecutionContext&) { + if (m_client) + m_client->didReceiveMessage(message); + })); + if (!m_suspended) processPendingTasks(); } -void ThreadableWebSocketChannelClientWrapper::didReceiveBinaryData(PassOwnPtr<Vector<char>> binaryData) +void ThreadableWebSocketChannelClientWrapper::didReceiveBinaryData(Vector<uint8_t>&& binaryData) { - m_pendingTasks.append(createCallbackTask(&didReceiveBinaryDataCallback, this, binaryData)); + m_pendingTasks.append(std::make_unique<ScriptExecutionContext::Task>([this, protectedThis = makeRef(*this), binaryData = WTFMove(binaryData)] (ScriptExecutionContext&) mutable { + if (m_client) + m_client->didReceiveBinaryData(WTFMove(binaryData)); + })); + if (!m_suspended) processPendingTasks(); } -void ThreadableWebSocketChannelClientWrapper::didUpdateBufferedAmount(unsigned long bufferedAmount) +void ThreadableWebSocketChannelClientWrapper::didUpdateBufferedAmount(unsigned bufferedAmount) { - m_pendingTasks.append(createCallbackTask(&didUpdateBufferedAmountCallback, this, bufferedAmount)); + m_pendingTasks.append(std::make_unique<ScriptExecutionContext::Task>([this, protectedThis = makeRef(*this), bufferedAmount] (ScriptExecutionContext&) { + if (m_client) + m_client->didUpdateBufferedAmount(bufferedAmount); + })); + if (!m_suspended) processPendingTasks(); } void ThreadableWebSocketChannelClientWrapper::didStartClosingHandshake() { - m_pendingTasks.append(createCallbackTask(&didStartClosingHandshakeCallback, this)); + m_pendingTasks.append(std::make_unique<ScriptExecutionContext::Task>([this, protectedThis = makeRef(*this)] (ScriptExecutionContext&) { + if (m_client) + m_client->didStartClosingHandshake(); + })); + if (!m_suspended) processPendingTasks(); } -void ThreadableWebSocketChannelClientWrapper::didClose(unsigned long unhandledBufferedAmount, WebSocketChannelClient::ClosingHandshakeCompletionStatus closingHandshakeCompletion, unsigned short code, const String& reason) +void ThreadableWebSocketChannelClientWrapper::didClose(unsigned unhandledBufferedAmount, WebSocketChannelClient::ClosingHandshakeCompletionStatus closingHandshakeCompletion, unsigned short code, const String& reason) { - m_pendingTasks.append(createCallbackTask(&didCloseCallback, this, unhandledBufferedAmount, closingHandshakeCompletion, code, reason)); + m_pendingTasks.append(std::make_unique<ScriptExecutionContext::Task>([this, protectedThis = makeRef(*this), unhandledBufferedAmount, closingHandshakeCompletion, code, reason = reason.isolatedCopy()] (ScriptExecutionContext&) { + if (m_client) + m_client->didClose(unhandledBufferedAmount, closingHandshakeCompletion, code, reason); + })); + if (!m_suspended) processPendingTasks(); } void ThreadableWebSocketChannelClientWrapper::didReceiveMessageError() { - m_pendingTasks.append(createCallbackTask(&didReceiveMessageErrorCallback, this)); + m_pendingTasks.append(std::make_unique<ScriptExecutionContext::Task>([this, protectedThis = makeRef(*this)] (ScriptExecutionContext&) { + if (m_client) + m_client->didReceiveMessageError(); + })); + + if (!m_suspended) + processPendingTasks(); +} + +void ThreadableWebSocketChannelClientWrapper::didUpgradeURL() +{ + m_pendingTasks.append(std::make_unique<ScriptExecutionContext::Task>([this, protectedThis = makeRef(*this)] (ScriptExecutionContext&) { + if (m_client) + m_client->didUpgradeURL(); + })); + if (!m_suspended) processPendingTasks(); } @@ -216,12 +251,6 @@ void ThreadableWebSocketChannelClientWrapper::resume() processPendingTasks(); } -void ThreadableWebSocketChannelClientWrapper::processPendingTasksCallback(ScriptExecutionContext* context, PassRefPtr<ThreadableWebSocketChannelClientWrapper> wrapper) -{ - ASSERT_UNUSED(context, context->isWorkerGlobalScope()); - wrapper->processPendingTasks(); -} - void ThreadableWebSocketChannelClientWrapper::processPendingTasks() { if (m_suspended) @@ -229,62 +258,16 @@ void ThreadableWebSocketChannelClientWrapper::processPendingTasks() if (!m_syncMethodDone) { // When a synchronous operation is in progress (i.e. the execution stack contains // WorkerThreadableWebSocketChannel::waitForMethodCompletion()), we cannot invoke callbacks in this run loop. - m_context->postTask(createCallbackTask(&ThreadableWebSocketChannelClientWrapper::processPendingTasksCallback, this)); + m_context.postTask([this, protectedThis = makeRef(*this)] (ScriptExecutionContext& context) { + ASSERT_UNUSED(context, context.isWorkerGlobalScope()); + processPendingTasks(); + }); return; } - Vector<OwnPtr<ScriptExecutionContext::Task>> tasks; - tasks.swap(m_pendingTasks); - for (Vector<OwnPtr<ScriptExecutionContext::Task>>::const_iterator iter = tasks.begin(); iter != tasks.end(); ++iter) - (*iter)->performTask(0); -} -void ThreadableWebSocketChannelClientWrapper::didConnectCallback(ScriptExecutionContext* context, PassRefPtr<ThreadableWebSocketChannelClientWrapper> wrapper) -{ - ASSERT_UNUSED(context, !context); - if (wrapper->m_client) - wrapper->m_client->didConnect(); -} - -void ThreadableWebSocketChannelClientWrapper::didReceiveMessageCallback(ScriptExecutionContext* context, PassRefPtr<ThreadableWebSocketChannelClientWrapper> wrapper, const String& message) -{ - ASSERT_UNUSED(context, !context); - if (wrapper->m_client) - wrapper->m_client->didReceiveMessage(message); -} - -void ThreadableWebSocketChannelClientWrapper::didReceiveBinaryDataCallback(ScriptExecutionContext* context, PassRefPtr<ThreadableWebSocketChannelClientWrapper> wrapper, PassOwnPtr<Vector<char>> binaryData) -{ - ASSERT_UNUSED(context, !context); - if (wrapper->m_client) - wrapper->m_client->didReceiveBinaryData(binaryData); -} - -void ThreadableWebSocketChannelClientWrapper::didUpdateBufferedAmountCallback(ScriptExecutionContext* context, PassRefPtr<ThreadableWebSocketChannelClientWrapper> wrapper, unsigned long bufferedAmount) -{ - ASSERT_UNUSED(context, !context); - if (wrapper->m_client) - wrapper->m_client->didUpdateBufferedAmount(bufferedAmount); -} - -void ThreadableWebSocketChannelClientWrapper::didStartClosingHandshakeCallback(ScriptExecutionContext* context, PassRefPtr<ThreadableWebSocketChannelClientWrapper> wrapper) -{ - ASSERT_UNUSED(context, !context); - if (wrapper->m_client) - wrapper->m_client->didStartClosingHandshake(); -} - -void ThreadableWebSocketChannelClientWrapper::didCloseCallback(ScriptExecutionContext* context, PassRefPtr<ThreadableWebSocketChannelClientWrapper> wrapper, unsigned long unhandledBufferedAmount, WebSocketChannelClient::ClosingHandshakeCompletionStatus closingHandshakeCompletion, unsigned short code, const String& reason) -{ - ASSERT_UNUSED(context, !context); - if (wrapper->m_client) - wrapper->m_client->didClose(unhandledBufferedAmount, closingHandshakeCompletion, code, reason); -} - -void ThreadableWebSocketChannelClientWrapper::didReceiveMessageErrorCallback(ScriptExecutionContext* context, PassRefPtr<ThreadableWebSocketChannelClientWrapper> wrapper) -{ - ASSERT_UNUSED(context, !context); - if (wrapper->m_client) - wrapper->m_client->didReceiveMessageError(); + Vector<std::unique_ptr<ScriptExecutionContext::Task>> pendingTasks = WTFMove(m_pendingTasks); + for (auto& task : pendingTasks) + task->performTask(m_context); } } // namespace WebCore diff --git a/Source/WebCore/Modules/websockets/ThreadableWebSocketChannelClientWrapper.h b/Source/WebCore/Modules/websockets/ThreadableWebSocketChannelClientWrapper.h index 29f54950a..492f0ca2a 100644 --- a/Source/WebCore/Modules/websockets/ThreadableWebSocketChannelClientWrapper.h +++ b/Source/WebCore/Modules/websockets/ThreadableWebSocketChannelClientWrapper.h @@ -28,8 +28,7 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ -#ifndef ThreadableWebSocketChannelClientWrapper_h -#define ThreadableWebSocketChannelClientWrapper_h +#pragma once #if ENABLE(WEB_SOCKETS) @@ -37,9 +36,8 @@ #include "ThreadableWebSocketChannel.h" #include "WebSocketChannelClient.h" #include "WorkerThreadableWebSocketChannel.h" +#include <memory> #include <wtf/Forward.h> -#include <wtf/OwnPtr.h> -#include <wtf/PassOwnPtr.h> #include <wtf/Threading.h> #include <wtf/Vector.h> #include <wtf/text/WTFString.h> @@ -51,7 +49,7 @@ class WebSocketChannelClient; class ThreadableWebSocketChannelClientWrapper : public ThreadSafeRefCounted<ThreadableWebSocketChannelClientWrapper> { public: - static PassRefPtr<ThreadableWebSocketChannelClientWrapper> create(ScriptExecutionContext*, WebSocketChannelClient*); + static Ref<ThreadableWebSocketChannelClientWrapper> create(ScriptExecutionContext&, WebSocketChannelClient&); void clearSyncMethodDone(); void setSyncMethodDone(); @@ -73,37 +71,29 @@ public: ThreadableWebSocketChannel::SendResult sendRequestResult() const; void setSendRequestResult(ThreadableWebSocketChannel::SendResult); - unsigned long bufferedAmount() const; - void setBufferedAmount(unsigned long); + unsigned bufferedAmount() const; + void setBufferedAmount(unsigned); void clearClient(); void didConnect(); void didReceiveMessage(const String& message); - void didReceiveBinaryData(PassOwnPtr<Vector<char>>); - void didUpdateBufferedAmount(unsigned long bufferedAmount); + void didReceiveBinaryData(Vector<uint8_t>&&); + void didUpdateBufferedAmount(unsigned bufferedAmount); void didStartClosingHandshake(); - void didClose(unsigned long unhandledBufferedAmount, WebSocketChannelClient::ClosingHandshakeCompletionStatus, unsigned short code, const String& reason); + void didClose(unsigned unhandledBufferedAmount, WebSocketChannelClient::ClosingHandshakeCompletionStatus, unsigned short code, const String& reason); void didReceiveMessageError(); + void didUpgradeURL(); void suspend(); void resume(); private: - ThreadableWebSocketChannelClientWrapper(ScriptExecutionContext*, WebSocketChannelClient*); + ThreadableWebSocketChannelClientWrapper(ScriptExecutionContext&, WebSocketChannelClient&); void processPendingTasks(); - static void didConnectCallback(ScriptExecutionContext*, PassRefPtr<ThreadableWebSocketChannelClientWrapper>); - static void didReceiveMessageCallback(ScriptExecutionContext*, PassRefPtr<ThreadableWebSocketChannelClientWrapper>, const String& message); - static void didReceiveBinaryDataCallback(ScriptExecutionContext*, PassRefPtr<ThreadableWebSocketChannelClientWrapper>, PassOwnPtr<Vector<char>>); - static void didUpdateBufferedAmountCallback(ScriptExecutionContext*, PassRefPtr<ThreadableWebSocketChannelClientWrapper>, unsigned long bufferedAmount); - static void didStartClosingHandshakeCallback(ScriptExecutionContext*, PassRefPtr<ThreadableWebSocketChannelClientWrapper>); - static void didCloseCallback(ScriptExecutionContext*, PassRefPtr<ThreadableWebSocketChannelClientWrapper>, unsigned long unhandledBufferedAmount, WebSocketChannelClient::ClosingHandshakeCompletionStatus, unsigned short code, const String& reason); - static void processPendingTasksCallback(ScriptExecutionContext*, PassRefPtr<ThreadableWebSocketChannelClientWrapper>); - static void didReceiveMessageErrorCallback(ScriptExecutionContext*, PassRefPtr<ThreadableWebSocketChannelClientWrapper>); - - ScriptExecutionContext* m_context; + ScriptExecutionContext& m_context; WebSocketChannelClient* m_client; WorkerThreadableWebSocketChannel::Peer* m_peer; bool m_failedWebSocketChannelCreation; @@ -112,13 +102,11 @@ private: Vector<UChar> m_subprotocol; Vector<UChar> m_extensions; ThreadableWebSocketChannel::SendResult m_sendRequestResult; - unsigned long m_bufferedAmount; + unsigned m_bufferedAmount; bool m_suspended; - Vector<OwnPtr<ScriptExecutionContext::Task>> m_pendingTasks; + Vector<std::unique_ptr<ScriptExecutionContext::Task>> m_pendingTasks; }; } // namespace WebCore #endif // ENABLE(WEB_SOCKETS) - -#endif // ThreadableWebSocketChannelClientWrapper_h diff --git a/Source/WebCore/Modules/websockets/WebSocket.cpp b/Source/WebCore/Modules/websockets/WebSocket.cpp index 72b192af3..bed936267 100644 --- a/Source/WebCore/Modules/websockets/WebSocket.cpp +++ b/Source/WebCore/Modules/websockets/WebSocket.cpp @@ -1,5 +1,6 @@ /* * Copyright (C) 2011 Google Inc. All rights reserved. + * Copyright (C) 2015-2016 Apple Inc. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are @@ -35,35 +36,38 @@ #include "WebSocket.h" #include "Blob.h" -#include "BlobData.h" #include "CloseEvent.h" #include "ContentSecurityPolicy.h" #include "DOMWindow.h" #include "Document.h" #include "Event.h" -#include "EventException.h" #include "EventListener.h" #include "EventNames.h" #include "ExceptionCode.h" #include "Frame.h" #include "Logging.h" #include "MessageEvent.h" -#include "ScriptCallStack.h" +#include "ResourceLoadObserver.h" #include "ScriptController.h" #include "ScriptExecutionContext.h" #include "SecurityOrigin.h" +#include "SocketProvider.h" #include "ThreadableWebSocketChannel.h" #include "WebSocketChannel.h" +#include <inspector/ScriptCallStack.h> #include <runtime/ArrayBuffer.h> #include <runtime/ArrayBufferView.h> #include <wtf/HashSet.h> -#include <wtf/OwnPtr.h> -#include <wtf/PassOwnPtr.h> +#include <wtf/RunLoop.h> #include <wtf/StdLibExtras.h> #include <wtf/text/CString.h> #include <wtf/text/StringBuilder.h> #include <wtf/text/WTFString.h> +#if USE(WEB_THREAD) +#include "WebCoreThreadRun.h" +#endif + namespace WebCore { const size_t maxReasonSizeInBytes = 123; @@ -81,12 +85,12 @@ static inline bool isValidProtocolCharacter(UChar character) && character != '{' && character != '}'; } -static bool isValidProtocolString(const String& protocol) +static bool isValidProtocolString(StringView protocol) { if (protocol.isEmpty()) return false; - for (size_t i = 0; i < protocol.length(); ++i) { - if (!isValidProtocolCharacter(protocol[i])) + for (auto codeUnit : protocol.codeUnits()) { + if (!isValidProtocolCharacter(codeUnit)) return false; } return true; @@ -99,7 +103,7 @@ static String encodeProtocolString(const String& protocol) if (protocol[i] < 0x20 || protocol[i] > 0x7E) builder.append(String::format("\\u%04X", protocol[i])); else if (protocol[i] == 0x5c) - builder.append("\\\\"); + builder.appendLiteral("\\\\"); else builder.append(protocol[i]); } @@ -117,10 +121,10 @@ static String joinStrings(const Vector<String>& strings, const char* separator) return builder.toString(); } -static unsigned long saturateAdd(unsigned long a, unsigned long b) +static unsigned saturateAdd(unsigned a, unsigned b) { - if (std::numeric_limits<unsigned long>::max() - a < b) - return std::numeric_limits<unsigned long>::max(); + if (std::numeric_limits<unsigned>::max() - a < b) + return std::numeric_limits<unsigned>::max(); return a + b; } @@ -136,19 +140,16 @@ bool WebSocket::isAvailable() return webSocketsAvailable; } -const char* WebSocket::subProtocolSeperator() +const char* WebSocket::subprotocolSeparator() { return ", "; } WebSocket::WebSocket(ScriptExecutionContext& context) : ActiveDOMObject(&context) - , m_state(CONNECTING) - , m_bufferedAmount(0) - , m_bufferedAmountAfterClose(0) - , m_binaryType(BinaryTypeBlob) - , m_subprotocol("") - , m_extensions("") + , m_subprotocol(emptyString()) + , m_extensions(emptyString()) + , m_resumeTimer(*this, &WebSocket::resumeTimerFired) { } @@ -158,102 +159,95 @@ WebSocket::~WebSocket() m_channel->disconnect(); } -PassRefPtr<WebSocket> WebSocket::create(ScriptExecutionContext& context) -{ - RefPtr<WebSocket> webSocket(adoptRef(new WebSocket(context))); - webSocket->suspendIfNeeded(); - return webSocket.release(); -} - -PassRefPtr<WebSocket> WebSocket::create(ScriptExecutionContext& context, const String& url, ExceptionCode& ec) +ExceptionOr<Ref<WebSocket>> WebSocket::create(ScriptExecutionContext& context, const String& url) { - Vector<String> protocols; - return WebSocket::create(context, url, protocols, ec); + return create(context, url, Vector<String> { }); } -PassRefPtr<WebSocket> WebSocket::create(ScriptExecutionContext& context, const String& url, const Vector<String>& protocols, ExceptionCode& ec) +ExceptionOr<Ref<WebSocket>> WebSocket::create(ScriptExecutionContext& context, const String& url, const Vector<String>& protocols) { - if (url.isNull()) { - ec = SYNTAX_ERR; - return 0; - } + if (url.isNull()) + return Exception { SYNTAX_ERR }; - RefPtr<WebSocket> webSocket(adoptRef(new WebSocket(context))); - webSocket->suspendIfNeeded(); + auto socket = adoptRef(*new WebSocket(context)); + socket->suspendIfNeeded(); - webSocket->connect(context.completeURL(url), protocols, ec); - if (ec) - return 0; + auto result = socket->connect(context.completeURL(url), protocols); + if (result.hasException()) + return result.releaseException(); - return webSocket.release(); + return WTFMove(socket); } -PassRefPtr<WebSocket> WebSocket::create(ScriptExecutionContext& context, const String& url, const String& protocol, ExceptionCode& ec) +ExceptionOr<Ref<WebSocket>> WebSocket::create(ScriptExecutionContext& context, const String& url, const String& protocol) { - Vector<String> protocols; - protocols.append(protocol); - return WebSocket::create(context, url, protocols, ec); + return create(context, url, Vector<String> { 1, protocol }); } -void WebSocket::connect(const String& url, ExceptionCode& ec) +ExceptionOr<void> WebSocket::connect(const String& url) { - Vector<String> protocols; - connect(url, protocols, ec); + return connect(url, Vector<String> { }); } -void WebSocket::connect(const String& url, const String& protocol, ExceptionCode& ec) +ExceptionOr<void> WebSocket::connect(const String& url, const String& protocol) { - Vector<String> protocols; - protocols.append(protocol); - connect(url, protocols, ec); + return connect(url, Vector<String> { 1, protocol }); } -void WebSocket::connect(const String& url, const Vector<String>& protocols, ExceptionCode& ec) +ExceptionOr<void> WebSocket::connect(const String& url, const Vector<String>& protocols) { LOG(Network, "WebSocket %p connect() url='%s'", this, url.utf8().data()); m_url = URL(URL(), url); + ASSERT(scriptExecutionContext()); + auto& context = *scriptExecutionContext(); + if (!m_url.isValid()) { - scriptExecutionContext()->addConsoleMessage(JSMessageSource, ErrorMessageLevel, "Invalid url for WebSocket " + m_url.stringCenterEllipsizedToLength()); + context.addConsoleMessage(MessageSource::JS, MessageLevel::Error, "Invalid url for WebSocket " + m_url.stringCenterEllipsizedToLength()); m_state = CLOSED; - ec = SYNTAX_ERR; - return; + return Exception { SYNTAX_ERR }; } if (!m_url.protocolIs("ws") && !m_url.protocolIs("wss")) { - scriptExecutionContext()->addConsoleMessage(JSMessageSource, ErrorMessageLevel, "Wrong url scheme for WebSocket " + m_url.stringCenterEllipsizedToLength()); + context.addConsoleMessage(MessageSource::JS, MessageLevel::Error, "Wrong url scheme for WebSocket " + m_url.stringCenterEllipsizedToLength()); m_state = CLOSED; - ec = SYNTAX_ERR; - return; + return Exception { SYNTAX_ERR }; } if (m_url.hasFragmentIdentifier()) { - scriptExecutionContext()->addConsoleMessage(JSMessageSource, ErrorMessageLevel, "URL has fragment component " + m_url.stringCenterEllipsizedToLength()); + context.addConsoleMessage(MessageSource::JS, MessageLevel::Error, "URL has fragment component " + m_url.stringCenterEllipsizedToLength()); m_state = CLOSED; - ec = SYNTAX_ERR; - return; + return Exception { SYNTAX_ERR }; } + + ASSERT(context.contentSecurityPolicy()); + auto& contentSecurityPolicy = *context.contentSecurityPolicy(); + + contentSecurityPolicy.upgradeInsecureRequestIfNeeded(m_url, ContentSecurityPolicy::InsecureRequestType::Load); + if (!portAllowed(m_url)) { - scriptExecutionContext()->addConsoleMessage(JSMessageSource, ErrorMessageLevel, "WebSocket port " + String::number(m_url.port()) + " blocked"); + String message; + if (m_url.port()) + message = makeString("WebSocket port ", String::number(m_url.port().value()), " blocked"); + else + message = ASCIILiteral("WebSocket without port blocked"); + context.addConsoleMessage(MessageSource::JS, MessageLevel::Error, message); m_state = CLOSED; - ec = SECURITY_ERR; - return; + return Exception { SECURITY_ERR }; } // FIXME: Convert this to check the isolated world's Content Security Policy once webkit.org/b/104520 is solved. - bool shouldBypassMainWorldContentSecurityPolicy = false; - if (scriptExecutionContext()->isDocument()) { - Document* document = toDocument(scriptExecutionContext()); - shouldBypassMainWorldContentSecurityPolicy = document->frame()->script().shouldBypassMainWorldContentSecurityPolicy(); - } - if (!shouldBypassMainWorldContentSecurityPolicy && !scriptExecutionContext()->contentSecurityPolicy()->allowConnectToSource(m_url)) { + if (!context.shouldBypassMainWorldContentSecurityPolicy() && !contentSecurityPolicy.allowConnectToSource(m_url)) { m_state = CLOSED; // FIXME: Should this be throwing an exception? - ec = SECURITY_ERR; - return; + return Exception { SECURITY_ERR }; } - m_channel = ThreadableWebSocketChannel::create(scriptExecutionContext(), this); + if (auto* provider = context.socketProvider()) + m_channel = ThreadableWebSocketChannel::create(*scriptExecutionContext(), *this, *provider); + + // Every ScriptExecutionContext should have a SocketProvider. + RELEASE_ASSERT(m_channel); // FIXME: There is a disagreement about restriction of subprotocols between WebSocket API and hybi-10 protocol // draft. The former simply says "only characters in the range U+0021 to U+007E are allowed," while the latter @@ -262,138 +256,155 @@ void WebSocket::connect(const String& url, const Vector<String>& protocols, Exce // // Here, we throw SYNTAX_ERR if the given protocols do not meet the latter criteria. This behavior does not // comply with WebSocket API specification, but it seems to be the only reasonable way to handle this conflict. - for (size_t i = 0; i < protocols.size(); ++i) { - if (!isValidProtocolString(protocols[i])) { - scriptExecutionContext()->addConsoleMessage(JSMessageSource, ErrorMessageLevel, "Wrong protocol for WebSocket '" + encodeProtocolString(protocols[i]) + "'"); + for (auto& protocol : protocols) { + if (!isValidProtocolString(protocol)) { + context.addConsoleMessage(MessageSource::JS, MessageLevel::Error, "Wrong protocol for WebSocket '" + encodeProtocolString(protocol) + "'"); m_state = CLOSED; - ec = SYNTAX_ERR; - return; + return Exception { SYNTAX_ERR }; } } HashSet<String> visited; - for (size_t i = 0; i < protocols.size(); ++i) { - if (!visited.add(protocols[i]).isNewEntry) { - scriptExecutionContext()->addConsoleMessage(JSMessageSource, ErrorMessageLevel, "WebSocket protocols contain duplicates: '" + encodeProtocolString(protocols[i]) + "'"); + for (auto& protocol : protocols) { + if (!visited.add(protocol).isNewEntry) { + context.addConsoleMessage(MessageSource::JS, MessageLevel::Error, "WebSocket protocols contain duplicates: '" + encodeProtocolString(protocol) + "'"); m_state = CLOSED; - ec = SYNTAX_ERR; - return; + return Exception { SYNTAX_ERR }; } } + if (is<Document>(context)) { + Document& document = downcast<Document>(context); + if (!document.frame()->loader().mixedContentChecker().canRunInsecureContent(document.securityOrigin(), m_url)) { + // Balanced by the call to ActiveDOMObject::unsetPendingActivity() in WebSocket::stop(). + ActiveDOMObject::setPendingActivity(this); + + // We must block this connection. Instead of throwing an exception, we indicate this + // using the error event. But since this code executes as part of the WebSocket's + // constructor, we have to wait until the constructor has completed before firing the + // event; otherwise, users can't connect to the event. +#if USE(WEB_THREAD) + ref(); + dispatch_async(dispatch_get_main_queue(), ^{ + WebThreadRun(^{ + dispatchOrQueueErrorEvent(); + stop(); + deref(); + }); + }); +#else + RunLoop::main().dispatch([this, protectedThis = makeRef(*this)]() { + dispatchOrQueueErrorEvent(); + stop(); + }); +#endif + return { }; + } else + ResourceLoadObserver::sharedObserver().logWebSocketLoading(document.frame(), m_url); + } + String protocolString; if (!protocols.isEmpty()) - protocolString = joinStrings(protocols, subProtocolSeperator()); + protocolString = joinStrings(protocols, subprotocolSeparator()); m_channel->connect(m_url, protocolString); ActiveDOMObject::setPendingActivity(this); + + return { }; } -void WebSocket::send(const String& message, ExceptionCode& ec) +ExceptionOr<void> WebSocket::send(const String& message) { LOG(Network, "WebSocket %p send() Sending String '%s'", this, message.utf8().data()); - if (m_state == CONNECTING) { - ec = INVALID_STATE_ERR; - return; - } + if (m_state == CONNECTING) + return Exception { INVALID_STATE_ERR }; // No exception is raised if the connection was once established but has subsequently been closed. if (m_state == CLOSING || m_state == CLOSED) { size_t payloadSize = message.utf8().length(); m_bufferedAmountAfterClose = saturateAdd(m_bufferedAmountAfterClose, payloadSize); m_bufferedAmountAfterClose = saturateAdd(m_bufferedAmountAfterClose, getFramingOverhead(payloadSize)); - return; + return { }; } ASSERT(m_channel); - ThreadableWebSocketChannel::SendResult result = m_channel->send(message); - if (result == ThreadableWebSocketChannel::InvalidMessage) { - scriptExecutionContext()->addConsoleMessage(JSMessageSource, ErrorMessageLevel, "Websocket message contains invalid character(s)."); - ec = SYNTAX_ERR; - return; - } + m_channel->send(message); + return { }; } -void WebSocket::send(ArrayBuffer* binaryData, ExceptionCode& ec) +ExceptionOr<void> WebSocket::send(ArrayBuffer& binaryData) { - LOG(Network, "WebSocket %p send() Sending ArrayBuffer %p", this, binaryData); - ASSERT(binaryData); - if (m_state == CONNECTING) { - ec = INVALID_STATE_ERR; - return; - } + LOG(Network, "WebSocket %p send() Sending ArrayBuffer %p", this, &binaryData); + if (m_state == CONNECTING) + return Exception { INVALID_STATE_ERR }; if (m_state == CLOSING || m_state == CLOSED) { - unsigned payloadSize = binaryData->byteLength(); + unsigned payloadSize = binaryData.byteLength(); m_bufferedAmountAfterClose = saturateAdd(m_bufferedAmountAfterClose, payloadSize); m_bufferedAmountAfterClose = saturateAdd(m_bufferedAmountAfterClose, getFramingOverhead(payloadSize)); - return; + return { }; } ASSERT(m_channel); - m_channel->send(*binaryData, 0, binaryData->byteLength()); + m_channel->send(binaryData, 0, binaryData.byteLength()); + return { }; } -void WebSocket::send(ArrayBufferView* arrayBufferView, ExceptionCode& ec) +ExceptionOr<void> WebSocket::send(ArrayBufferView& arrayBufferView) { - LOG(Network, "WebSocket %p send() Sending ArrayBufferView %p", this, arrayBufferView); - ASSERT(arrayBufferView); - if (m_state == CONNECTING) { - ec = INVALID_STATE_ERR; - return; - } + LOG(Network, "WebSocket %p send() Sending ArrayBufferView %p", this, &arrayBufferView); + + if (m_state == CONNECTING) + return Exception { INVALID_STATE_ERR }; if (m_state == CLOSING || m_state == CLOSED) { - unsigned payloadSize = arrayBufferView->byteLength(); + unsigned payloadSize = arrayBufferView.byteLength(); m_bufferedAmountAfterClose = saturateAdd(m_bufferedAmountAfterClose, payloadSize); m_bufferedAmountAfterClose = saturateAdd(m_bufferedAmountAfterClose, getFramingOverhead(payloadSize)); - return; + return { }; } ASSERT(m_channel); - RefPtr<ArrayBuffer> arrayBuffer(arrayBufferView->buffer()); - m_channel->send(*arrayBuffer, arrayBufferView->byteOffset(), arrayBufferView->byteLength()); + m_channel->send(*arrayBufferView.unsharedBuffer(), arrayBufferView.byteOffset(), arrayBufferView.byteLength()); + return { }; } -void WebSocket::send(Blob* binaryData, ExceptionCode& ec) +ExceptionOr<void> WebSocket::send(Blob& binaryData) { - LOG(Network, "WebSocket %p send() Sending Blob '%s'", this, binaryData->url().stringCenterEllipsizedToLength().utf8().data()); - ASSERT(binaryData); - if (m_state == CONNECTING) { - ec = INVALID_STATE_ERR; - return; - } + LOG(Network, "WebSocket %p send() Sending Blob '%s'", this, binaryData.url().stringCenterEllipsizedToLength().utf8().data()); + if (m_state == CONNECTING) + return Exception { INVALID_STATE_ERR }; if (m_state == CLOSING || m_state == CLOSED) { - unsigned long payloadSize = static_cast<unsigned long>(binaryData->size()); + unsigned payloadSize = static_cast<unsigned>(binaryData.size()); m_bufferedAmountAfterClose = saturateAdd(m_bufferedAmountAfterClose, payloadSize); m_bufferedAmountAfterClose = saturateAdd(m_bufferedAmountAfterClose, getFramingOverhead(payloadSize)); - return; + return { }; } ASSERT(m_channel); - m_channel->send(*binaryData); + m_channel->send(binaryData); + return { }; } -void WebSocket::close(int code, const String& reason, ExceptionCode& ec) +ExceptionOr<void> WebSocket::close(std::optional<unsigned short> optionalCode, const String& reason) { + int code = optionalCode ? optionalCode.value() : static_cast<int>(WebSocketChannel::CloseEventCodeNotSpecified); if (code == WebSocketChannel::CloseEventCodeNotSpecified) LOG(Network, "WebSocket %p close() without code and reason", this); else { LOG(Network, "WebSocket %p close() code=%d reason='%s'", this, code, reason.utf8().data()); - if (!(code == WebSocketChannel::CloseEventCodeNormalClosure || (WebSocketChannel::CloseEventCodeMinimumUserDefined <= code && code <= WebSocketChannel::CloseEventCodeMaximumUserDefined))) { - ec = INVALID_ACCESS_ERR; - return; - } + if (!(code == WebSocketChannel::CloseEventCodeNormalClosure || (WebSocketChannel::CloseEventCodeMinimumUserDefined <= code && code <= WebSocketChannel::CloseEventCodeMaximumUserDefined))) + return Exception { INVALID_ACCESS_ERR }; CString utf8 = reason.utf8(StrictConversionReplacingUnpairedSurrogatesWithFFFD); if (utf8.length() > maxReasonSizeInBytes) { - scriptExecutionContext()->addConsoleMessage(JSMessageSource, ErrorMessageLevel, "WebSocket close message is too long."); - ec = SYNTAX_ERR; - return; + scriptExecutionContext()->addConsoleMessage(MessageSource::JS, MessageLevel::Error, ASCIILiteral("WebSocket close message is too long.")); + return Exception { SYNTAX_ERR }; } } if (m_state == CLOSING || m_state == CLOSED) - return; + return { }; if (m_state == CONNECTING) { m_state = CLOSING; m_channel->fail("WebSocket is closed before the connection is established."); - return; + return { }; } m_state = CLOSING; if (m_channel) m_channel->close(code, reason); + return { }; } const URL& WebSocket::url() const @@ -406,7 +417,7 @@ WebSocket::State WebSocket::readyState() const return m_state; } -unsigned long WebSocket::bufferedAmount() const +unsigned WebSocket::bufferedAmount() const { return saturateAdd(m_bufferedAmount, m_bufferedAmountAfterClose); } @@ -424,26 +435,27 @@ String WebSocket::extensions() const String WebSocket::binaryType() const { switch (m_binaryType) { - case BinaryTypeBlob: - return "blob"; - case BinaryTypeArrayBuffer: - return "arraybuffer"; + case BinaryType::Blob: + return ASCIILiteral("blob"); + case BinaryType::ArrayBuffer: + return ASCIILiteral("arraybuffer"); } ASSERT_NOT_REACHED(); return String(); } -void WebSocket::setBinaryType(const String& binaryType) +ExceptionOr<void> WebSocket::setBinaryType(const String& binaryType) { if (binaryType == "blob") { - m_binaryType = BinaryTypeBlob; - return; + m_binaryType = BinaryType::Blob; + return { }; } if (binaryType == "arraybuffer") { - m_binaryType = BinaryTypeArrayBuffer; - return; + m_binaryType = BinaryType::ArrayBuffer; + return { }; } - scriptExecutionContext()->addConsoleMessage(JSMessageSource, ErrorMessageLevel, "'" + binaryType + "' is not a valid value for binaryType; binaryType remains unchanged."); + scriptExecutionContext()->addConsoleMessage(MessageSource::JS, MessageLevel::Error, "'" + binaryType + "' is not a valid value for binaryType; binaryType remains unchanged."); + return Exception { SYNTAX_ERR }; } EventTargetInterface WebSocket::eventTargetInterface() const @@ -464,21 +476,49 @@ void WebSocket::contextDestroyed() ActiveDOMObject::contextDestroyed(); } -bool WebSocket::canSuspend() const +bool WebSocket::canSuspendForDocumentSuspension() const { - return !m_channel; + return true; } -void WebSocket::suspend(ReasonForSuspension) +void WebSocket::suspend(ReasonForSuspension reason) { - if (m_channel) - m_channel->suspend(); + if (m_resumeTimer.isActive()) + m_resumeTimer.stop(); + + m_shouldDelayEventFiring = true; + + if (m_channel) { + if (reason == ActiveDOMObject::PageCache) { + // This will cause didClose() to be called. + m_channel->fail("WebSocket is closed due to suspension."); + } else + m_channel->suspend(); + } } void WebSocket::resume() { if (m_channel) m_channel->resume(); + else if (!m_pendingEvents.isEmpty() && !m_resumeTimer.isActive()) { + // Fire the pending events in a timer as we are not allowed to execute arbitrary JS from resume(). + m_resumeTimer.startOneShot(0); + } + + m_shouldDelayEventFiring = false; +} + +void WebSocket::resumeTimerFired() +{ + Ref<WebSocket> protectedThis(*this); + + ASSERT(!m_pendingEvents.isEmpty()); + + // Check m_shouldDelayEventFiring when iterating in case firing an event causes + // suspend() to be called. + while (!m_pendingEvents.isEmpty() && !m_shouldDelayEventFiring) + dispatchEvent(m_pendingEvents.takeFirst()); } void WebSocket::stop() @@ -486,18 +526,24 @@ void WebSocket::stop() bool pending = hasPendingActivity(); if (m_channel) m_channel->disconnect(); - m_channel = 0; + m_channel = nullptr; m_state = CLOSED; + m_pendingEvents.clear(); ActiveDOMObject::stop(); if (pending) ActiveDOMObject::unsetPendingActivity(this); } +const char* WebSocket::activeDOMObjectName() const +{ + return "WebSocket"; +} + void WebSocket::didConnect() { LOG(Network, "WebSocket %p didConnect()", this); if (m_state != CONNECTING) { - didClose(0, ClosingHandshakeIncomplete, WebSocketChannel::CloseEventCodeAbnormalClosure, ""); + didClose(0, ClosingHandshakeIncomplete, WebSocketChannel::CloseEventCodeAbnormalClosure, emptyString()); return; } ASSERT(scriptExecutionContext()); @@ -516,23 +562,16 @@ void WebSocket::didReceiveMessage(const String& msg) dispatchEvent(MessageEvent::create(msg, SecurityOrigin::create(m_url)->toString())); } -void WebSocket::didReceiveBinaryData(PassOwnPtr<Vector<char>> binaryData) +void WebSocket::didReceiveBinaryData(Vector<uint8_t>&& binaryData) { - LOG(Network, "WebSocket %p didReceiveBinaryData() %lu byte binary message", this, static_cast<unsigned long>(binaryData->size())); + LOG(Network, "WebSocket %p didReceiveBinaryData() %u byte binary message", this, static_cast<unsigned>(binaryData.size())); switch (m_binaryType) { - case BinaryTypeBlob: { - size_t size = binaryData->size(); - RefPtr<RawData> rawData = RawData::create(); - binaryData->swap(*rawData->mutableData()); - auto blobData = std::make_unique<BlobData>(); - blobData->appendData(rawData.release(), 0, BlobDataItem::toEndOfFile); - RefPtr<Blob> blob = Blob::create(std::move(blobData), size); - dispatchEvent(MessageEvent::create(blob.release(), SecurityOrigin::create(m_url)->toString())); + case BinaryType::Blob: + // FIXME: We just received the data from NetworkProcess, and are sending it back. This is inefficient. + dispatchEvent(MessageEvent::create(Blob::create(WTFMove(binaryData), emptyString()), SecurityOrigin::create(m_url)->toString())); break; - } - - case BinaryTypeArrayBuffer: - dispatchEvent(MessageEvent::create(ArrayBuffer::create(binaryData->data(), binaryData->size()), SecurityOrigin::create(m_url)->toString())); + case BinaryType::ArrayBuffer: + dispatchEvent(MessageEvent::create(ArrayBuffer::create(binaryData.data(), binaryData.size()), SecurityOrigin::create(m_url)->toString())); break; } } @@ -540,13 +579,14 @@ void WebSocket::didReceiveBinaryData(PassOwnPtr<Vector<char>> binaryData) void WebSocket::didReceiveMessageError() { LOG(Network, "WebSocket %p didReceiveErrorMessage()", this); + m_state = CLOSED; ASSERT(scriptExecutionContext()); - dispatchEvent(Event::create(eventNames().errorEvent, false, false)); + dispatchOrQueueErrorEvent(); } -void WebSocket::didUpdateBufferedAmount(unsigned long bufferedAmount) +void WebSocket::didUpdateBufferedAmount(unsigned bufferedAmount) { - LOG(Network, "WebSocket %p didUpdateBufferedAmount() New bufferedAmount is %lu", this, bufferedAmount); + LOG(Network, "WebSocket %p didUpdateBufferedAmount() New bufferedAmount is %u", this, bufferedAmount); if (m_state == CLOSED) return; m_bufferedAmount = bufferedAmount; @@ -558,7 +598,7 @@ void WebSocket::didStartClosingHandshake() m_state = CLOSING; } -void WebSocket::didClose(unsigned long unhandledBufferedAmount, ClosingHandshakeCompletionStatus closingHandshakeCompletion, unsigned short code, const String& reason) +void WebSocket::didClose(unsigned unhandledBufferedAmount, ClosingHandshakeCompletionStatus closingHandshakeCompletion, unsigned short code, const String& reason) { LOG(Network, "WebSocket %p didClose()", this); if (!m_channel) @@ -567,16 +607,23 @@ void WebSocket::didClose(unsigned long unhandledBufferedAmount, ClosingHandshake m_state = CLOSED; m_bufferedAmount = unhandledBufferedAmount; ASSERT(scriptExecutionContext()); - RefPtr<CloseEvent> event = CloseEvent::create(wasClean, code, reason); - dispatchEvent(event); + + dispatchOrQueueEvent(CloseEvent::create(wasClean, code, reason)); + if (m_channel) { m_channel->disconnect(); - m_channel = 0; + m_channel = nullptr; } if (hasPendingActivity()) ActiveDOMObject::unsetPendingActivity(this); } +void WebSocket::didUpgradeURL() +{ + ASSERT(m_url.protocolIs("ws")); + m_url.setProtocol("wss"); +} + size_t WebSocket::getFramingOverhead(size_t payloadSize) { static const size_t hybiBaseFramingOverhead = 2; // Every frame has at least two-byte header. @@ -591,6 +638,23 @@ size_t WebSocket::getFramingOverhead(size_t payloadSize) return overhead; } +void WebSocket::dispatchOrQueueErrorEvent() +{ + if (m_dispatchedErrorEvent) + return; + + m_dispatchedErrorEvent = true; + dispatchOrQueueEvent(Event::create(eventNames().errorEvent, false, false)); +} + +void WebSocket::dispatchOrQueueEvent(Ref<Event>&& event) +{ + if (m_shouldDelayEventFiring) + m_pendingEvents.append(WTFMove(event)); + else + dispatchEvent(event); +} + } // namespace WebCore #endif diff --git a/Source/WebCore/Modules/websockets/WebSocket.h b/Source/WebCore/Modules/websockets/WebSocket.h index 75482d811..7d43f0436 100644 --- a/Source/WebCore/Modules/websockets/WebSocket.h +++ b/Source/WebCore/Modules/websockets/WebSocket.h @@ -28,37 +28,38 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ -#ifndef WebSocket_h -#define WebSocket_h +#pragma once #if ENABLE(WEB_SOCKETS) #include "ActiveDOMObject.h" -#include "EventListener.h" -#include "EventNames.h" #include "EventTarget.h" +#include "ExceptionOr.h" +#include "Timer.h" #include "URL.h" -#include "WebSocketChannel.h" #include "WebSocketChannelClient.h" -#include <wtf/Forward.h> -#include <wtf/OwnPtr.h> -#include <wtf/RefCounted.h> -#include <wtf/text/AtomicStringHash.h> +#include <wtf/Deque.h> + +namespace JSC { +class ArrayBuffer; +class ArrayBufferView; +} namespace WebCore { class Blob; class ThreadableWebSocketChannel; -class WebSocket final : public RefCounted<WebSocket>, public EventTargetWithInlineData, public ActiveDOMObject, public WebSocketChannelClient { +class WebSocket final : public RefCounted<WebSocket>, public EventTargetWithInlineData, public ActiveDOMObject, private WebSocketChannelClient { public: static void setIsAvailable(bool); static bool isAvailable(); - static const char* subProtocolSeperator(); - static PassRefPtr<WebSocket> create(ScriptExecutionContext&); - static PassRefPtr<WebSocket> create(ScriptExecutionContext&, const String& url, ExceptionCode&); - static PassRefPtr<WebSocket> create(ScriptExecutionContext&, const String& url, const String& protocol, ExceptionCode&); - static PassRefPtr<WebSocket> create(ScriptExecutionContext&, const String& url, const Vector<String>& protocols, ExceptionCode&); + + static const char* subprotocolSeparator(); + + static ExceptionOr<Ref<WebSocket>> create(ScriptExecutionContext&, const String& url); + static ExceptionOr<Ref<WebSocket>> create(ScriptExecutionContext&, const String& url, const String& protocol); + static ExceptionOr<Ref<WebSocket>> create(ScriptExecutionContext&, const String& url, const Vector<String>& protocols); virtual ~WebSocket(); enum State { @@ -68,83 +69,79 @@ public: CLOSED = 3 }; - void connect(const String& url, ExceptionCode&); - void connect(const String& url, const String& protocol, ExceptionCode&); - void connect(const String& url, const Vector<String>& protocols, ExceptionCode&); + ExceptionOr<void> connect(const String& url); + ExceptionOr<void> connect(const String& url, const String& protocol); + ExceptionOr<void> connect(const String& url, const Vector<String>& protocols); - void send(const String& message, ExceptionCode&); - void send(JSC::ArrayBuffer*, ExceptionCode&); - void send(JSC::ArrayBufferView*, ExceptionCode&); - void send(Blob*, ExceptionCode&); + ExceptionOr<void> send(const String& message); + ExceptionOr<void> send(JSC::ArrayBuffer&); + ExceptionOr<void> send(JSC::ArrayBufferView&); + ExceptionOr<void> send(Blob&); - void close(int code, const String& reason, ExceptionCode&); - void close(ExceptionCode& ec) { close(WebSocketChannel::CloseEventCodeNotSpecified, String(), ec); } - void close(int code, ExceptionCode& ec) { close(code, String(), ec); } + ExceptionOr<void> close(std::optional<unsigned short> code, const String& reason); const URL& url() const; State readyState() const; - unsigned long bufferedAmount() const; + unsigned bufferedAmount() const; String protocol() const; String extensions() const; String binaryType() const; - void setBinaryType(const String&); + ExceptionOr<void> setBinaryType(const String&); - DEFINE_ATTRIBUTE_EVENT_LISTENER(open); - DEFINE_ATTRIBUTE_EVENT_LISTENER(message); - DEFINE_ATTRIBUTE_EVENT_LISTENER(error); - DEFINE_ATTRIBUTE_EVENT_LISTENER(close); + using RefCounted::ref; + using RefCounted::deref; - // EventTarget functions. - virtual EventTargetInterface eventTargetInterface() const override; - virtual ScriptExecutionContext* scriptExecutionContext() const override; +private: + explicit WebSocket(ScriptExecutionContext&); - using RefCounted<WebSocket>::ref; - using RefCounted<WebSocket>::deref; + void resumeTimerFired(); + void dispatchOrQueueErrorEvent(); + void dispatchOrQueueEvent(Ref<Event>&&); - // WebSocketChannelClient functions. - virtual void didConnect() override; - virtual void didReceiveMessage(const String& message) override; - virtual void didReceiveBinaryData(PassOwnPtr<Vector<char>>) override; - virtual void didReceiveMessageError() override; - virtual void didUpdateBufferedAmount(unsigned long bufferedAmount) override; - virtual void didStartClosingHandshake() override; - virtual void didClose(unsigned long unhandledBufferedAmount, ClosingHandshakeCompletionStatus, unsigned short code, const String& reason) override; + void contextDestroyed() final; + bool canSuspendForDocumentSuspension() const final; + void suspend(ReasonForSuspension) final; + void resume() final; + void stop() final; + const char* activeDOMObjectName() const final; -private: - explicit WebSocket(ScriptExecutionContext&); + EventTargetInterface eventTargetInterface() const final; + ScriptExecutionContext* scriptExecutionContext() const final; - // ActiveDOMObject functions. - virtual void contextDestroyed() override; - virtual bool canSuspend() const override; - virtual void suspend(ReasonForSuspension) override; - virtual void resume() override; - virtual void stop() override; + void refEventTarget() final { ref(); } + void derefEventTarget() final { deref(); } - virtual void refEventTarget() override { ref(); } - virtual void derefEventTarget() override { deref(); } + void didConnect() final; + void didReceiveMessage(const String& message) final; + void didReceiveBinaryData(Vector<uint8_t>&&) final; + void didReceiveMessageError() final; + void didUpdateBufferedAmount(unsigned bufferedAmount) final; + void didStartClosingHandshake() final; + void didClose(unsigned unhandledBufferedAmount, ClosingHandshakeCompletionStatus, unsigned short code, const String& reason) final; + void didUpgradeURL() final; size_t getFramingOverhead(size_t payloadSize); - enum BinaryType { - BinaryTypeBlob, - BinaryTypeArrayBuffer - }; + enum class BinaryType { Blob, ArrayBuffer }; RefPtr<ThreadableWebSocketChannel> m_channel; - State m_state; + State m_state { CONNECTING }; URL m_url; - unsigned long m_bufferedAmount; - unsigned long m_bufferedAmountAfterClose; - BinaryType m_binaryType; + unsigned m_bufferedAmount { 0 }; + unsigned m_bufferedAmountAfterClose { 0 }; + BinaryType m_binaryType { BinaryType::Blob }; String m_subprotocol; String m_extensions; + + Timer m_resumeTimer; + bool m_shouldDelayEventFiring { false }; + Deque<Ref<Event>> m_pendingEvents; + bool m_dispatchedErrorEvent { false }; }; } // namespace WebCore #endif // ENABLE(WEB_SOCKETS) - -#endif // WebSocket_h diff --git a/Source/WebCore/Modules/websockets/WebSocket.idl b/Source/WebCore/Modules/websockets/WebSocket.idl index 59de969b9..d71ed2b86 100644 --- a/Source/WebCore/Modules/websockets/WebSocket.idl +++ b/Source/WebCore/Modules/websockets/WebSocket.idl @@ -30,22 +30,18 @@ */ [ - GlobalContext=DOMWindow&WorkerGlobalScope, - EnabledAtRuntime, - Conditional=WEB_SOCKETS, ActiveDOMObject, - Constructor(DOMString url), - Constructor(DOMString url, sequence<DOMString> protocols), - Constructor(DOMString url, DOMString protocol), - ConstructorRaisesException, + Conditional=WEB_SOCKETS, + Constructor(USVString url, optional sequence<DOMString> protocols = []), + Constructor(USVString url, DOMString protocol), + ConstructorMayThrowException, ConstructorCallWith=ScriptExecutionContext, - EventTarget, - JSNoStaticTables, -] interface WebSocket { - readonly attribute DOMString URL; // Lowercased .url is the one in the spec, but leaving .URL for compatibility reasons. - readonly attribute DOMString url; + EnabledAtRuntime, + Exposed=(Window,Worker), +] interface WebSocket : EventTarget { + readonly attribute USVString URL; // Lowercased .url is the one in the spec, but leaving .URL for compatibility reasons. + readonly attribute USVString url; - // ready state const unsigned short CONNECTING = 0; const unsigned short OPEN = 1; const unsigned short CLOSING = 2; @@ -54,30 +50,20 @@ readonly attribute unsigned long bufferedAmount; - // networking - attribute EventListener onopen; - attribute EventListener onmessage; - attribute EventListener onerror; - attribute EventListener onclose; - - [TreatReturnedNullStringAs=Undefined] readonly attribute DOMString protocol; - [TreatReturnedNullStringAs=Undefined] readonly attribute DOMString extensions; + attribute EventHandler onopen; + attribute EventHandler onmessage; + attribute EventHandler onerror; + attribute EventHandler onclose; - attribute DOMString binaryType; + readonly attribute DOMString? protocol; + readonly attribute DOMString? extensions; - [RaisesException] void send(ArrayBuffer data); - [RaisesException] void send(ArrayBufferView data); - [RaisesException] void send(Blob data); - [RaisesException] void send(DOMString data); + [SetterMayThrowException] attribute DOMString binaryType; - [RaisesException] void close([Clamp] optional unsigned short code, optional DOMString reason); + [MayThrowException] void send(ArrayBuffer data); + [MayThrowException] void send(ArrayBufferView data); + [MayThrowException] void send(Blob data); + [MayThrowException] void send(USVString data); - // EventTarget interface - void addEventListener(DOMString type, - EventListener listener, - optional boolean useCapture); - void removeEventListener(DOMString type, - EventListener listener, - optional boolean useCapture); - [RaisesException] boolean dispatchEvent(Event evt); + [MayThrowException] void close([Clamp] optional unsigned short code, optional DOMString reason); }; diff --git a/Source/WebCore/Modules/websockets/WebSocketChannel.cpp b/Source/WebCore/Modules/websockets/WebSocketChannel.cpp index e13d63885..e76f6b9cf 100644 --- a/Source/WebCore/Modules/websockets/WebSocketChannel.cpp +++ b/Source/WebCore/Modules/websockets/WebSocketChannel.cpp @@ -37,31 +37,24 @@ #include "Blob.h" #include "CookieJar.h" #include "Document.h" -#include "ExceptionCodePlaceholder.h" #include "FileError.h" #include "FileReaderLoader.h" #include "Frame.h" -#include "FrameLoader.h" -#include "FrameLoaderClient.h" #include "InspectorInstrumentation.h" #include "Logging.h" #include "Page.h" #include "ProgressTracker.h" #include "ResourceRequest.h" -#include "ScriptCallStack.h" #include "ScriptExecutionContext.h" -#include "Settings.h" +#include "SocketProvider.h" #include "SocketStreamError.h" #include "SocketStreamHandle.h" +#include "UserContentProvider.h" #include "WebSocketChannelClient.h" #include "WebSocketHandshake.h" - #include <runtime/ArrayBuffer.h> -#include <wtf/Deque.h> #include <wtf/FastMalloc.h> #include <wtf/HashMap.h> -#include <wtf/OwnPtr.h> -#include <wtf/PassOwnPtr.h> #include <wtf/text/CString.h> #include <wtf/text/StringHash.h> #include <wtf/text/WTFString.h> @@ -70,29 +63,17 @@ namespace WebCore { const double TCPMaximumSegmentLifetime = 2 * 60.0; -WebSocketChannel::WebSocketChannel(Document* document, WebSocketChannelClient* client) - : m_document(document) - , m_client(client) - , m_resumeTimer(this, &WebSocketChannel::resumeTimerFired) - , m_suspended(false) - , m_closing(false) - , m_receivedClosingHandshake(false) - , m_closingTimer(this, &WebSocketChannel::closingTimerFired) - , m_closed(false) - , m_shouldDiscardReceivedData(false) - , m_unhandledBufferedAmount(0) - , m_identifier(0) - , m_hasContinuousFrame(false) - , m_closeEventCode(CloseEventCodeAbnormalClosure) - , m_outgoingFrameQueueStatus(OutgoingFrameQueueOpen) -#if ENABLE(BLOB) - , m_blobLoaderStatus(BlobLoaderNotStarted) -#endif +WebSocketChannel::WebSocketChannel(Document& document, WebSocketChannelClient& client, SocketProvider& provider) + : m_document(&document) + , m_client(&client) + , m_resumeTimer(*this, &WebSocketChannel::resumeTimerFired) + , m_closingTimer(*this, &WebSocketChannel::closingTimerFired) + , m_socketProvider(provider) { - if (Page* page = m_document->page()) + if (Page* page = document.page()) m_identifier = page->progress().createUniqueIdentifier(); - LOG(Network, "WebSocketChannel %p ctor, identifier %lu", this, m_identifier); + LOG(Network, "WebSocketChannel %p ctor, identifier %u", this, m_identifier); } WebSocketChannel::~WebSocketChannel() @@ -100,29 +81,62 @@ WebSocketChannel::~WebSocketChannel() LOG(Network, "WebSocketChannel %p dtor", this); } -void WebSocketChannel::connect(const URL& url, const String& protocol) +void WebSocketChannel::connect(const URL& requestedURL, const String& protocol) { LOG(Network, "WebSocketChannel %p connect()", this); + + URL url = requestedURL; + bool allowCookies = true; +#if ENABLE(CONTENT_EXTENSIONS) + if (auto* page = m_document->page()) { + if (auto* documentLoader = m_document->loader()) { + auto blockedStatus = page->userContentProvider().processContentExtensionRulesForLoad(url, ResourceType::Raw, *documentLoader); + if (blockedStatus.blockedLoad) { + Ref<WebSocketChannel> protectedThis(*this); + callOnMainThread([protectedThis = WTFMove(protectedThis)] { + if (protectedThis->m_client) + protectedThis->m_client->didReceiveMessageError(); + }); + return; + } + if (blockedStatus.madeHTTPS) { + ASSERT(url.protocolIs("ws")); + url.setProtocol("wss"); + if (m_client) + m_client->didUpgradeURL(); + } + if (blockedStatus.blockedCookies) + allowCookies = false; + } + } +#endif + ASSERT(!m_handle); ASSERT(!m_suspended); - m_handshake = adoptPtr(new WebSocketHandshake(url, protocol, m_document)); + m_handshake = std::make_unique<WebSocketHandshake>(url, protocol, m_document, allowCookies); m_handshake->reset(); if (m_deflateFramer.canDeflate()) m_handshake->addExtensionProcessor(m_deflateFramer.createExtensionProcessor()); if (m_identifier) - InspectorInstrumentation::didCreateWebSocket(m_document, m_identifier, url, m_document->url(), protocol); - ref(); - m_handle = SocketStreamHandle::create(m_handshake->url(), this); + InspectorInstrumentation::didCreateWebSocket(m_document, m_identifier, url); + + if (Frame* frame = m_document->frame()) { + ref(); + Page* page = frame->page(); + SessionID sessionID = page ? page->sessionID() : SessionID::defaultSessionID(); + String partition = m_document->topDocument().securityOrigin().domainForCachePartition(); + m_handle = m_socketProvider->createSocketStreamHandle(m_handshake->url(), *this, sessionID, partition); + } } String WebSocketChannel::subprotocol() { LOG(Network, "WebSocketChannel %p subprotocol()", this); if (!m_handshake || m_handshake->mode() != WebSocketHandshake::Connected) - return ""; + return emptyString(); String serverProtocol = m_handshake->serverWebSocketProtocol(); if (serverProtocol.isNull()) - return ""; + return emptyString(); return serverProtocol; } @@ -130,10 +144,10 @@ String WebSocketChannel::extensions() { LOG(Network, "WebSocketChannel %p extensions()", this); if (!m_handshake || m_handshake->mode() != WebSocketHandshake::Connected) - return ""; + return emptyString(); String extensions = m_handshake->acceptedExtensions(); if (extensions.isNull()) - return ""; + return emptyString(); return extensions; } @@ -160,9 +174,9 @@ ThreadableWebSocketChannel::SendResult WebSocketChannel::send(const ArrayBuffer& return ThreadableWebSocketChannel::SendSuccess; } -ThreadableWebSocketChannel::SendResult WebSocketChannel::send(const Blob& binaryData) +ThreadableWebSocketChannel::SendResult WebSocketChannel::send(Blob& binaryData) { - LOG(Network, "WebSocketChannel %p send() Sending Blob '%s'", this, binaryData.url().stringCenterEllipsizedToLength().utf8().data()); + LOG(Network, "WebSocketChannel %p send() Sending Blob '%s'", this, binaryData.url().string().utf8().data()); enqueueBlobFrame(WebSocketFrame::OpCodeBinary, binaryData); processOutgoingFrameQueue(); return ThreadableWebSocketChannel::SendSuccess; @@ -176,7 +190,7 @@ bool WebSocketChannel::send(const char* data, int length) return true; } -unsigned long WebSocketChannel::bufferedAmount() const +unsigned WebSocketChannel::bufferedAmount() const { LOG(Network, "WebSocketChannel %p bufferedAmount()", this); ASSERT(m_handle); @@ -190,7 +204,7 @@ void WebSocketChannel::close(int code, const String& reason) ASSERT(!m_suspended); if (!m_handle) return; - Ref<WebSocketChannel> protect(*this); // An attempt to send closing handshake may fail, which will get the channel closed and dereferenced. + Ref<WebSocketChannel> protectedThis(*this); // An attempt to send closing handshake may fail, which will get the channel closed and dereferenced. startClosingHandshake(code, reason); if (m_closing && !m_closingTimer.isActive()) m_closingTimer.startOneShot(2 * TCPMaximumSegmentLifetime); @@ -202,12 +216,19 @@ void WebSocketChannel::fail(const String& reason) ASSERT(!m_suspended); if (m_document) { InspectorInstrumentation::didReceiveWebSocketFrameError(m_document, m_identifier, reason); - m_document->addConsoleMessage(NetworkMessageSource, ErrorMessageLevel, "WebSocket connection to '" + m_handshake->url().stringCenterEllipsizedToLength() + "' failed: " + reason); + + String consoleMessage; + if (m_handshake) + consoleMessage = makeString("WebSocket connection to '", m_handshake->url().stringCenterEllipsizedToLength(), "' failed: ", reason); + else + consoleMessage = makeString("WebSocket connection failed: ", reason); + + m_document->addConsoleMessage(MessageSource::Network, MessageLevel::Error, consoleMessage); } // Hybi-10 specification explicitly states we must not continue to handle incoming data // once the WebSocket connection is failed (section 7.1.7). - Ref<WebSocketChannel> protect(*this); // The client can close the channel, potentially removing the last reference. + Ref<WebSocketChannel> protectedThis(*this); // The client can close the channel, potentially removing the last reference. m_shouldDiscardReceivedData = true; if (!m_buffer.isEmpty()) skipBuffer(m_buffer.size()); // Save memory. @@ -219,7 +240,8 @@ void WebSocketChannel::fail(const String& reason) if (m_handle && !m_closed) m_handle->disconnect(); // Will call didClose(). - ASSERT(m_closed); + // We should be closed by now, but if we never got a handshake then we never even opened. + ASSERT(m_closed || !m_handshake); } void WebSocketChannel::disconnect() @@ -228,9 +250,9 @@ void WebSocketChannel::disconnect() if (m_identifier && m_document) InspectorInstrumentation::didCloseWebSocket(m_document, m_identifier); if (m_handshake) - m_handshake->clearScriptExecutionContext(); - m_client = 0; - m_document = 0; + m_handshake->clearDocument(); + m_client = nullptr; + m_document = nullptr; if (m_handle) m_handle->disconnect(); } @@ -247,33 +269,25 @@ void WebSocketChannel::resume() m_resumeTimer.startOneShot(0); } -void WebSocketChannel::willOpenSocketStream(SocketStreamHandle* handle) -{ - LOG(Network, "WebSocketChannel %p willOpenSocketStream()", this); - ASSERT(handle); - if (m_document->frame()) - m_document->frame()->loader().client().dispatchWillOpenSocketStream(handle); -} - -void WebSocketChannel::didOpenSocketStream(SocketStreamHandle* handle) +void WebSocketChannel::didOpenSocketStream(SocketStreamHandle& handle) { LOG(Network, "WebSocketChannel %p didOpenSocketStream()", this); - ASSERT(handle == m_handle); + ASSERT(&handle == m_handle); if (!m_document) return; if (m_identifier) InspectorInstrumentation::willSendWebSocketHandshakeRequest(m_document, m_identifier, m_handshake->clientHandshakeRequest()); CString handshakeMessage = m_handshake->clientHandshakeMessage(); - if (!handle->send(handshakeMessage.data(), handshakeMessage.length())) + if (!handle.send(handshakeMessage.data(), handshakeMessage.length())) fail("Failed to send WebSocket handshake."); } -void WebSocketChannel::didCloseSocketStream(SocketStreamHandle* handle) +void WebSocketChannel::didCloseSocketStream(SocketStreamHandle& handle) { LOG(Network, "WebSocketChannel %p didCloseSocketStream()", this); if (m_identifier && m_document) InspectorInstrumentation::didCloseWebSocket(m_document, m_identifier); - ASSERT_UNUSED(handle, handle == m_handle || !m_handle); + ASSERT_UNUSED(handle, &handle == m_handle || !m_handle); m_closed = true; if (m_closingTimer.isActive()) m_closingTimer.stop(); @@ -284,54 +298,58 @@ void WebSocketChannel::didCloseSocketStream(SocketStreamHandle* handle) if (m_suspended) return; WebSocketChannelClient* client = m_client; - m_client = 0; - m_document = 0; - m_handle = 0; + m_client = nullptr; + m_document = nullptr; + m_handle = nullptr; if (client) client->didClose(m_unhandledBufferedAmount, m_receivedClosingHandshake ? WebSocketChannelClient::ClosingHandshakeComplete : WebSocketChannelClient::ClosingHandshakeIncomplete, m_closeEventCode, m_closeEventReason); } deref(); } -void WebSocketChannel::didReceiveSocketStreamData(SocketStreamHandle* handle, const char* data, int len) +void WebSocketChannel::didReceiveSocketStreamData(SocketStreamHandle& handle, const char* data, std::optional<size_t> len) { - LOG(Network, "WebSocketChannel %p didReceiveSocketStreamData() Received %d bytes", this, len); - Ref<WebSocketChannel> protect(*this); // The client can close the channel, potentially removing the last reference. - ASSERT(handle == m_handle); + if (len) + LOG(Network, "WebSocketChannel %p didReceiveSocketStreamData() Received %zu bytes", this, len.value()); + else + LOG(Network, "WebSocketChannel %p didReceiveSocketStreamData() Received no bytes", this); + Ref<WebSocketChannel> protectedThis(*this); // The client can close the channel, potentially removing the last reference. + ASSERT(&handle == m_handle); if (!m_document) { return; } - if (len <= 0) { - handle->disconnect(); + if (!len || !len.value()) { + handle.disconnect(); return; } if (!m_client) { m_shouldDiscardReceivedData = true; - handle->disconnect(); + handle.disconnect(); return; } if (m_shouldDiscardReceivedData) return; - if (!appendToBuffer(data, len)) { + if (!appendToBuffer(data, len.value())) { m_shouldDiscardReceivedData = true; fail("Ran out of memory while receiving WebSocket data."); return; } - while (!m_suspended && m_client && !m_buffer.isEmpty()) + while (!m_suspended && m_client && !m_buffer.isEmpty()) { if (!processBuffer()) break; + } } -void WebSocketChannel::didUpdateBufferedAmount(SocketStreamHandle*, size_t bufferedAmount) +void WebSocketChannel::didUpdateBufferedAmount(SocketStreamHandle&, size_t bufferedAmount) { if (m_client) m_client->didUpdateBufferedAmount(bufferedAmount); } -void WebSocketChannel::didFailSocketStream(SocketStreamHandle* handle, const SocketStreamError& error) +void WebSocketChannel::didFailSocketStream(SocketStreamHandle& handle, const SocketStreamError& error) { LOG(Network, "WebSocketChannel %p didFailSocketStream()", this); - ASSERT(handle == m_handle || !m_handle); + ASSERT(&handle == m_handle || !m_handle); if (m_document) { String message; if (error.isNull()) @@ -341,21 +359,12 @@ void WebSocketChannel::didFailSocketStream(SocketStreamHandle* handle, const Soc else message = "WebSocket network error: " + error.localizedDescription(); InspectorInstrumentation::didReceiveWebSocketFrameError(m_document, m_identifier, message); - m_document->addConsoleMessage(NetworkMessageSource, ErrorMessageLevel, message); + m_document->addConsoleMessage(MessageSource::Network, MessageLevel::Error, message); } m_shouldDiscardReceivedData = true; - handle->disconnect(); -} - -void WebSocketChannel::didReceiveAuthenticationChallenge(SocketStreamHandle*, const AuthenticationChallenge&) -{ + handle.disconnect(); } -void WebSocketChannel::didCancelAuthenticationChallenge(SocketStreamHandle*, const AuthenticationChallenge&) -{ -} - -#if ENABLE(BLOB) void WebSocketChannel::didStartLoading() { LOG(Network, "WebSocketChannel %p didStartLoading()", this); @@ -385,18 +394,17 @@ void WebSocketChannel::didFail(int errorCode) LOG(Network, "WebSocketChannel %p didFail() errorCode=%d", this, errorCode); ASSERT(m_blobLoader); ASSERT(m_blobLoaderStatus == BlobLoaderStarted); - m_blobLoader.clear(); + m_blobLoader = nullptr; m_blobLoaderStatus = BlobLoaderFailed; fail("Failed to load Blob: error code = " + String::number(errorCode)); // FIXME: Generate human-friendly reason message. deref(); } -#endif bool WebSocketChannel::appendToBuffer(const char* data, size_t len) { size_t newBufferSize = m_buffer.size() + len; if (newBufferSize < m_buffer.size()) { - LOG(Network, "WebSocketChannel %p appendToBuffer() Buffer overflow (%lu bytes already in receive buffer and appending %lu bytes)", this, static_cast<unsigned long>(m_buffer.size()), static_cast<unsigned long>(len)); + LOG(Network, "WebSocketChannel %p appendToBuffer() Buffer overflow (%u bytes already in receive buffer and appending %u bytes)", this, static_cast<unsigned>(m_buffer.size()), static_cast<unsigned>(len)); return false; } m_buffer.append(data, len); @@ -415,7 +423,7 @@ bool WebSocketChannel::processBuffer() ASSERT(!m_suspended); ASSERT(m_client); ASSERT(!m_buffer.isEmpty()); - LOG(Network, "WebSocketChannel %p processBuffer() Receive buffer has %lu bytes", this, static_cast<unsigned long>(m_buffer.size())); + LOG(Network, "WebSocketChannel %p processBuffer() Receive buffer has %u bytes", this, static_cast<unsigned>(m_buffer.size())); if (m_shouldDiscardReceivedData) return false; @@ -425,7 +433,7 @@ bool WebSocketChannel::processBuffer() return false; } - Ref<WebSocketChannel> protect(*this); // The client can close the channel, potentially removing the last reference. + Ref<WebSocketChannel> protectedThis(*this); // The client can close the channel, potentially removing the last reference. if (m_handshake->mode() == WebSocketHandshake::Incomplete) { int headerLength = m_handshake->readServerHandshake(m_buffer.data(), m_buffer.size()); @@ -435,16 +443,16 @@ bool WebSocketChannel::processBuffer() if (m_identifier) InspectorInstrumentation::didReceiveWebSocketHandshakeResponse(m_document, m_identifier, m_handshake->serverHandshakeResponse()); if (!m_handshake->serverSetCookie().isEmpty()) { - if (cookiesEnabled(m_document)) { + if (m_document && cookiesEnabled(*m_document)) { // Exception (for sandboxed documents) ignored. - m_document->setCookie(m_handshake->serverSetCookie(), IGNORE_EXCEPTION); + m_document->setCookie(m_handshake->serverSetCookie()); } } // FIXME: handle set-cookie2. LOG(Network, "WebSocketChannel %p Connected", this); skipBuffer(headerLength); m_client->didConnect(); - LOG(Network, "WebSocketChannel %p %lu bytes remaining in m_buffer", this, static_cast<unsigned long>(m_buffer.size())); + LOG(Network, "WebSocketChannel %p %u bytes remaining in m_buffer", this, static_cast<unsigned>(m_buffer.size())); return !m_buffer.isEmpty(); } ASSERT(m_handshake->mode() == WebSocketHandshake::Failed); @@ -460,16 +468,14 @@ bool WebSocketChannel::processBuffer() return processFrame(); } -void WebSocketChannel::resumeTimerFired(Timer<WebSocketChannel>* timer) +void WebSocketChannel::resumeTimerFired() { - ASSERT_UNUSED(timer, timer == &m_resumeTimer); - - Ref<WebSocketChannel> protect(*this); // The client can close the channel, potentially removing the last reference. + Ref<WebSocketChannel> protectedThis(*this); // The client can close the channel, potentially removing the last reference. while (!m_suspended && m_client && !m_buffer.isEmpty()) if (!processBuffer()) break; if (!m_suspended && m_client && m_closed && m_handle) - didCloseSocketStream(m_handle.get()); + didCloseSocketStream(*m_handle); } void WebSocketChannel::startClosingHandshake(int code, const String& reason) @@ -489,7 +495,7 @@ void WebSocketChannel::startClosingHandshake(int code, const String& reason) buf.append(reason.utf8().data(), reason.utf8().length()); } enqueueRawFrame(WebSocketFrame::OpCodeClose, buf.data(), buf.size()); - Ref<WebSocketChannel> protect(*this); // An attempt to send closing handshake may fail, which will get the channel closed and dereferenced. + Ref<WebSocketChannel> protectedThis(*this); // An attempt to send closing handshake may fail, which will get the channel closed and dereferenced. processOutgoingFrameQueue(); if (m_closed) { @@ -502,10 +508,9 @@ void WebSocketChannel::startClosingHandshake(int code, const String& reason) m_client->didStartClosingHandshake(); } -void WebSocketChannel::closingTimerFired(Timer<WebSocketChannel>* timer) +void WebSocketChannel::closingTimerFired() { LOG(Network, "WebSocketChannel %p closingTimerFired()", this); - ASSERT_UNUSED(timer, &m_closingTimer == timer); if (m_handle) m_handle->disconnect(); } @@ -529,7 +534,7 @@ bool WebSocketChannel::processFrame() ASSERT(m_buffer.data() < frameEnd); ASSERT(frameEnd <= m_buffer.data() + m_buffer.size()); - OwnPtr<InflateResultHolder> inflateResult = m_deflateFramer.inflate(frame); + auto inflateResult = m_deflateFramer.inflate(frame); if (!inflateResult->succeeded()) { fail(inflateResult->failureReason()); return false; @@ -587,22 +592,20 @@ bool WebSocketChannel::processFrame() // so we should pretend that we have finished to read this frame and // make sure that the member variables are in a consistent state before // the handler is invoked. - // Vector<char>::swap() is used here to clear m_continuousFrameData. - OwnPtr<Vector<char>> continuousFrameData = adoptPtr(new Vector<char>); - m_continuousFrameData.swap(*continuousFrameData); + Vector<uint8_t> continuousFrameData = WTFMove(m_continuousFrameData); m_hasContinuousFrame = false; if (m_continuousFrameOpCode == WebSocketFrame::OpCodeText) { String message; - if (continuousFrameData->size()) - message = String::fromUTF8(continuousFrameData->data(), continuousFrameData->size()); + if (continuousFrameData.size()) + message = String::fromUTF8(continuousFrameData.data(), continuousFrameData.size()); else - message = ""; + message = emptyString(); if (message.isNull()) fail("Could not decode a text frame as UTF-8."); else m_client->didReceiveMessage(message); } else if (m_continuousFrameOpCode == WebSocketFrame::OpCodeBinary) - m_client->didReceiveBinaryData(continuousFrameData.release()); + m_client->didReceiveBinaryData(WTFMove(continuousFrameData)); } break; @@ -612,7 +615,7 @@ bool WebSocketChannel::processFrame() if (frame.payloadLength) message = String::fromUTF8(frame.payload, frame.payloadLength); else - message = ""; + message = emptyString(); skipBuffer(frameEnd - m_buffer.data()); if (message.isNull()) fail("Could not decode a text frame as UTF-8."); @@ -629,10 +632,10 @@ bool WebSocketChannel::processFrame() case WebSocketFrame::OpCodeBinary: if (frame.final) { - OwnPtr<Vector<char>> binaryData = adoptPtr(new Vector<char>(frame.payloadLength)); - memcpy(binaryData->data(), frame.payload, frame.payloadLength); + Vector<uint8_t> binaryData(frame.payloadLength); + memcpy(binaryData.data(), frame.payload, frame.payloadLength); skipBuffer(frameEnd - m_buffer.data()); - m_client->didReceiveBinaryData(binaryData.release()); + m_client->didReceiveBinaryData(WTFMove(binaryData)); } else { m_hasContinuousFrame = true; m_continuousFrameOpCode = WebSocketFrame::OpCodeBinary; @@ -662,12 +665,13 @@ bool WebSocketChannel::processFrame() if (frame.payloadLength >= 3) m_closeEventReason = String::fromUTF8(&frame.payload[2], frame.payloadLength - 2); else - m_closeEventReason = ""; + m_closeEventReason = emptyString(); skipBuffer(frameEnd - m_buffer.data()); m_receivedClosingHandshake = true; startClosingHandshake(m_closeEventCode, m_closeEventReason); if (m_closing) { - m_outgoingFrameQueueStatus = OutgoingFrameQueueClosing; + if (m_outgoingFrameQueueStatus == OutgoingFrameQueueOpen) + m_outgoingFrameQueueStatus = OutgoingFrameQueueClosing; processOutgoingFrameQueue(); } break; @@ -696,33 +700,33 @@ bool WebSocketChannel::processFrame() void WebSocketChannel::enqueueTextFrame(const CString& string) { ASSERT(m_outgoingFrameQueueStatus == OutgoingFrameQueueOpen); - OwnPtr<QueuedFrame> frame = adoptPtr(new QueuedFrame); + auto frame = std::make_unique<QueuedFrame>(); frame->opCode = WebSocketFrame::OpCodeText; frame->frameType = QueuedFrameTypeString; frame->stringData = string; - m_outgoingFrameQueue.append(frame.release()); + m_outgoingFrameQueue.append(WTFMove(frame)); } void WebSocketChannel::enqueueRawFrame(WebSocketFrame::OpCode opCode, const char* data, size_t dataLength) { ASSERT(m_outgoingFrameQueueStatus == OutgoingFrameQueueOpen); - OwnPtr<QueuedFrame> frame = adoptPtr(new QueuedFrame); + auto frame = std::make_unique<QueuedFrame>(); frame->opCode = opCode; frame->frameType = QueuedFrameTypeVector; frame->vectorData.resize(dataLength); if (dataLength) memcpy(frame->vectorData.data(), data, dataLength); - m_outgoingFrameQueue.append(frame.release()); + m_outgoingFrameQueue.append(WTFMove(frame)); } -void WebSocketChannel::enqueueBlobFrame(WebSocketFrame::OpCode opCode, const Blob& blob) +void WebSocketChannel::enqueueBlobFrame(WebSocketFrame::OpCode opCode, Blob& blob) { ASSERT(m_outgoingFrameQueueStatus == OutgoingFrameQueueOpen); - OwnPtr<QueuedFrame> frame = adoptPtr(new QueuedFrame); + auto frame = std::make_unique<QueuedFrame>(); frame->opCode = opCode; frame->frameType = QueuedFrameTypeBlob; - frame->blobData = Blob::create(blob.url(), blob.type(), blob.size()); - m_outgoingFrameQueue.append(frame.release()); + frame->blobData = &blob; + m_outgoingFrameQueue.append(WTFMove(frame)); } void WebSocketChannel::processOutgoingFrameQueue() @@ -730,10 +734,10 @@ void WebSocketChannel::processOutgoingFrameQueue() if (m_outgoingFrameQueueStatus == OutgoingFrameQueueClosed) return; - Ref<WebSocketChannel> protect(*this); // Any call to fail() will get the channel closed and dereferenced. + Ref<WebSocketChannel> protectedThis(*this); // Any call to fail() will get the channel closed and dereferenced. while (!m_outgoingFrameQueue.isEmpty()) { - OwnPtr<QueuedFrame> frame = m_outgoingFrameQueue.takeFirst(); + auto frame = m_outgoingFrameQueue.takeFirst(); switch (frame->frameType) { case QueuedFrameTypeString: { if (!sendFrame(frame->opCode, frame->stringData.data(), frame->stringData.length())) @@ -747,34 +751,31 @@ void WebSocketChannel::processOutgoingFrameQueue() break; case QueuedFrameTypeBlob: { -#if ENABLE(BLOB) switch (m_blobLoaderStatus) { case BlobLoaderNotStarted: ref(); // Will be derefed after didFinishLoading() or didFail(). ASSERT(!m_blobLoader); - m_blobLoader = adoptPtr(new FileReaderLoader(FileReaderLoader::ReadAsArrayBuffer, this)); + ASSERT(frame->blobData); + m_blobLoader = std::make_unique<FileReaderLoader>(FileReaderLoader::ReadAsArrayBuffer, this); m_blobLoaderStatus = BlobLoaderStarted; - m_blobLoader->start(m_document, frame->blobData.get()); - m_outgoingFrameQueue.prepend(frame.release()); + m_blobLoader->start(m_document, *frame->blobData); + m_outgoingFrameQueue.prepend(WTFMove(frame)); return; case BlobLoaderStarted: case BlobLoaderFailed: - m_outgoingFrameQueue.prepend(frame.release()); + m_outgoingFrameQueue.prepend(WTFMove(frame)); return; case BlobLoaderFinished: { RefPtr<ArrayBuffer> result = m_blobLoader->arrayBufferResult(); - m_blobLoader.clear(); + m_blobLoader = nullptr; m_blobLoaderStatus = BlobLoaderNotStarted; if (!sendFrame(frame->opCode, static_cast<const char*>(result->data()), result->byteLength())) fail("Failed to send WebSocket frame."); break; } } -#else - fail("FileReader is not available. Could not send a Blob as WebSocket binary message."); -#endif break; } @@ -795,12 +796,10 @@ void WebSocketChannel::abortOutgoingFrameQueue() { m_outgoingFrameQueue.clear(); m_outgoingFrameQueueStatus = OutgoingFrameQueueClosed; -#if ENABLE(BLOB) if (m_blobLoaderStatus == BlobLoaderStarted) { m_blobLoader->cancel(); didFail(FileError::ABORT_ERR); } -#endif } bool WebSocketChannel::sendFrame(WebSocketFrame::OpCode opCode, const char* data, size_t dataLength) @@ -811,7 +810,7 @@ bool WebSocketChannel::sendFrame(WebSocketFrame::OpCode opCode, const char* data WebSocketFrame frame(opCode, true, false, true, data, dataLength); InspectorInstrumentation::didSendWebSocketFrame(m_document, m_identifier, frame); - OwnPtr<DeflateResultHolder> deflateResult = m_deflateFramer.deflate(frame); + auto deflateResult = m_deflateFramer.deflate(frame); if (!deflateResult->succeeded()) { fail(deflateResult->failureReason()); return false; diff --git a/Source/WebCore/Modules/websockets/WebSocketChannel.h b/Source/WebCore/Modules/websockets/WebSocketChannel.h index 1102f6aaf..b3a52ccc7 100644 --- a/Source/WebCore/Modules/websockets/WebSocketChannel.h +++ b/Source/WebCore/Modules/websockets/WebSocketChannel.h @@ -28,8 +28,7 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ -#ifndef WebSocketChannel_h -#define WebSocketChannel_h +#pragma once #if ENABLE(WEB_SOCKETS) @@ -39,7 +38,6 @@ #include "Timer.h" #include "WebSocketDeflateFramer.h" #include "WebSocketFrame.h" -#include "WebSocketHandshake.h" #include <wtf/Deque.h> #include <wtf/Forward.h> #include <wtf/RefCounted.h> @@ -51,46 +49,42 @@ namespace WebCore { class Blob; class Document; class FileReaderLoader; +class SocketProvider; class SocketStreamHandle; class SocketStreamError; class WebSocketChannelClient; +class WebSocketHandshake; -class WebSocketChannel : public RefCounted<WebSocketChannel>, public SocketStreamHandleClient, public ThreadableWebSocketChannel -#if ENABLE(BLOB) - , public FileReaderLoaderClient -#endif +class WebSocketChannel : public RefCounted<WebSocketChannel>, public SocketStreamHandleClient, public ThreadableWebSocketChannel, public FileReaderLoaderClient { WTF_MAKE_FAST_ALLOCATED; public: - static PassRefPtr<WebSocketChannel> create(Document* document, WebSocketChannelClient* client) { return adoptRef(new WebSocketChannel(document, client)); } + static Ref<WebSocketChannel> create(Document& document, WebSocketChannelClient& client, SocketProvider& provider) { return adoptRef(*new WebSocketChannel(document, client, provider)); } virtual ~WebSocketChannel(); bool send(const char* data, int length); // ThreadableWebSocketChannel functions. - virtual void connect(const URL&, const String& protocol) override; - virtual String subprotocol() override; - virtual String extensions() override; - virtual ThreadableWebSocketChannel::SendResult send(const String& message) override; - virtual ThreadableWebSocketChannel::SendResult send(const JSC::ArrayBuffer&, unsigned byteOffset, unsigned byteLength) override; - virtual ThreadableWebSocketChannel::SendResult send(const Blob&) override; - virtual unsigned long bufferedAmount() const override; - virtual void close(int code, const String& reason) override; // Start closing handshake. - virtual void fail(const String& reason) override; - virtual void disconnect() override; - - virtual void suspend() override; - virtual void resume() override; + void connect(const URL&, const String& protocol) override; + String subprotocol() override; + String extensions() override; + ThreadableWebSocketChannel::SendResult send(const String& message) override; + ThreadableWebSocketChannel::SendResult send(const JSC::ArrayBuffer&, unsigned byteOffset, unsigned byteLength) override; + ThreadableWebSocketChannel::SendResult send(Blob&) override; + unsigned bufferedAmount() const override; + void close(int code, const String& reason) override; // Start closing handshake. + void fail(const String& reason) override; + void disconnect() override; + + void suspend() override; + void resume() override; // SocketStreamHandleClient functions. - virtual void willOpenSocketStream(SocketStreamHandle*) override; - virtual void didOpenSocketStream(SocketStreamHandle*) override; - virtual void didCloseSocketStream(SocketStreamHandle*) override; - virtual void didReceiveSocketStreamData(SocketStreamHandle*, const char*, int) override; - virtual void didUpdateBufferedAmount(SocketStreamHandle*, size_t bufferedAmount) override; - virtual void didFailSocketStream(SocketStreamHandle*, const SocketStreamError&) override; - virtual void didReceiveAuthenticationChallenge(SocketStreamHandle*, const AuthenticationChallenge&) override; - virtual void didCancelAuthenticationChallenge(SocketStreamHandle*, const AuthenticationChallenge&) override; + void didOpenSocketStream(SocketStreamHandle&) final; + void didCloseSocketStream(SocketStreamHandle&) final; + void didReceiveSocketStreamData(SocketStreamHandle&, const char*, std::optional<size_t>) final; + void didUpdateBufferedAmount(SocketStreamHandle&, size_t bufferedAmount) final; + void didFailSocketStream(SocketStreamHandle&, const SocketStreamError&) final; enum CloseEventCode { CloseEventCodeNotSpecified = -1, @@ -111,30 +105,28 @@ public: CloseEventCodeMaximumUserDefined = 4999 }; -#if ENABLE(BLOB) // FileReaderLoaderClient functions. - virtual void didStartLoading(); - virtual void didReceiveData(); - virtual void didFinishLoading(); - virtual void didFail(int errorCode); -#endif + void didStartLoading() override; + void didReceiveData() override; + void didFinishLoading() override; + void didFail(int errorCode) override; using RefCounted<WebSocketChannel>::ref; using RefCounted<WebSocketChannel>::deref; protected: - virtual void refThreadableWebSocketChannel() { ref(); } - virtual void derefThreadableWebSocketChannel() { deref(); } + void refThreadableWebSocketChannel() override { ref(); } + void derefThreadableWebSocketChannel() override { deref(); } private: - WebSocketChannel(Document*, WebSocketChannelClient*); + WEBCORE_EXPORT WebSocketChannel(Document&, WebSocketChannelClient&, SocketProvider&); bool appendToBuffer(const char* data, size_t len); void skipBuffer(size_t len); bool processBuffer(); - void resumeTimerFired(Timer<WebSocketChannel>*); + void resumeTimerFired(); void startClosingHandshake(int code, const String& reason); - void closingTimerFired(Timer<WebSocketChannel>*); + void closingTimerFired(); bool processFrame(); @@ -161,7 +153,7 @@ private: }; void enqueueTextFrame(const CString&); void enqueueRawFrame(WebSocketFrame::OpCode, const char* data, size_t dataLength); - void enqueueBlobFrame(WebSocketFrame::OpCode, const Blob&); + void enqueueBlobFrame(WebSocketFrame::OpCode, Blob&); void processOutgoingFrameQueue(); void abortOutgoingFrameQueue(); @@ -182,53 +174,48 @@ private: // instead of call sendFrame() directly. bool sendFrame(WebSocketFrame::OpCode, const char* data, size_t dataLength); -#if ENABLE(BLOB) enum BlobLoaderStatus { BlobLoaderNotStarted, BlobLoaderStarted, BlobLoaderFinished, BlobLoaderFailed }; -#endif Document* m_document; WebSocketChannelClient* m_client; - OwnPtr<WebSocketHandshake> m_handshake; + std::unique_ptr<WebSocketHandshake> m_handshake; RefPtr<SocketStreamHandle> m_handle; Vector<char> m_buffer; - Timer<WebSocketChannel> m_resumeTimer; - bool m_suspended; - bool m_closing; - bool m_receivedClosingHandshake; - Timer<WebSocketChannel> m_closingTimer; - bool m_closed; - bool m_shouldDiscardReceivedData; - unsigned long m_unhandledBufferedAmount; + Timer m_resumeTimer; + bool m_suspended { false }; + bool m_closing { false }; + bool m_receivedClosingHandshake { false }; + Timer m_closingTimer; + bool m_closed { false }; + bool m_shouldDiscardReceivedData { false }; + unsigned m_unhandledBufferedAmount { 0 }; - unsigned long m_identifier; // m_identifier == 0 means that we could not obtain a valid identifier. + unsigned m_identifier { 0 }; // m_identifier == 0 means that we could not obtain a valid identifier. // Private members only for hybi-10 protocol. - bool m_hasContinuousFrame; + bool m_hasContinuousFrame { false }; WebSocketFrame::OpCode m_continuousFrameOpCode; - Vector<char> m_continuousFrameData; - unsigned short m_closeEventCode; + Vector<uint8_t> m_continuousFrameData; + unsigned short m_closeEventCode { CloseEventCodeAbnormalClosure }; String m_closeEventReason; - Deque<OwnPtr<QueuedFrame>> m_outgoingFrameQueue; - OutgoingFrameQueueStatus m_outgoingFrameQueueStatus; + Deque<std::unique_ptr<QueuedFrame>> m_outgoingFrameQueue; + OutgoingFrameQueueStatus m_outgoingFrameQueueStatus { OutgoingFrameQueueOpen }; -#if ENABLE(BLOB) // FIXME: Load two or more Blobs simultaneously for better performance. - OwnPtr<FileReaderLoader> m_blobLoader; - BlobLoaderStatus m_blobLoaderStatus; -#endif + std::unique_ptr<FileReaderLoader> m_blobLoader; + BlobLoaderStatus m_blobLoaderStatus { BlobLoaderNotStarted }; WebSocketDeflateFramer m_deflateFramer; + Ref<SocketProvider> m_socketProvider; }; } // namespace WebCore #endif // ENABLE(WEB_SOCKETS) - -#endif // WebSocketChannel_h diff --git a/Source/WebCore/Modules/websockets/WebSocketChannelClient.h b/Source/WebCore/Modules/websockets/WebSocketChannelClient.h index 46641c06c..b7b7fcd2d 100644 --- a/Source/WebCore/Modules/websockets/WebSocketChannelClient.h +++ b/Source/WebCore/Modules/websockets/WebSocketChannelClient.h @@ -28,38 +28,35 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ -#ifndef WebSocketChannelClient_h -#define WebSocketChannelClient_h +#pragma once #if ENABLE(WEB_SOCKETS) #include <wtf/Forward.h> -#include <wtf/PassOwnPtr.h> #include <wtf/Vector.h> namespace WebCore { - class WebSocketChannelClient { - public: - virtual ~WebSocketChannelClient() { } - virtual void didConnect() { } - virtual void didReceiveMessage(const String&) { } - virtual void didReceiveBinaryData(PassOwnPtr<Vector<char>>) { } - virtual void didReceiveMessageError() { } - virtual void didUpdateBufferedAmount(unsigned long /* bufferedAmount */) { } - virtual void didStartClosingHandshake() { } - enum ClosingHandshakeCompletionStatus { - ClosingHandshakeIncomplete, - ClosingHandshakeComplete - }; - virtual void didClose(unsigned long /* unhandledBufferedAmount */, ClosingHandshakeCompletionStatus, unsigned short /* code */, const String& /* reason */) { } - - protected: - WebSocketChannelClient() { } +class WebSocketChannelClient { +public: + virtual ~WebSocketChannelClient() { } + virtual void didConnect() = 0; + virtual void didReceiveMessage(const String&) = 0; + virtual void didReceiveBinaryData(Vector<uint8_t>&&) = 0; + virtual void didReceiveMessageError() = 0; + virtual void didUpdateBufferedAmount(unsigned bufferedAmount) = 0; + virtual void didStartClosingHandshake() = 0; + enum ClosingHandshakeCompletionStatus { + ClosingHandshakeIncomplete, + ClosingHandshakeComplete }; + virtual void didClose(unsigned unhandledBufferedAmount, ClosingHandshakeCompletionStatus, unsigned short code, const String& reason) = 0; + virtual void didUpgradeURL() = 0; + +protected: + WebSocketChannelClient() { } +}; } // namespace WebCore #endif // ENABLE(WEB_SOCKETS) - -#endif // WebSocketChannelClient_h diff --git a/Source/WebCore/Modules/websockets/WebSocketDeflateFramer.cpp b/Source/WebCore/Modules/websockets/WebSocketDeflateFramer.cpp index f3aa3f081..4677c053d 100644 --- a/Source/WebCore/Modules/websockets/WebSocketDeflateFramer.cpp +++ b/Source/WebCore/Modules/websockets/WebSocketDeflateFramer.cpp @@ -43,19 +43,14 @@ namespace WebCore { class WebSocketExtensionDeflateFrame : public WebSocketExtensionProcessor { WTF_MAKE_FAST_ALLOCATED; public: - static PassOwnPtr<WebSocketExtensionDeflateFrame> create(WebSocketDeflateFramer* framer) - { - return adoptPtr(new WebSocketExtensionDeflateFrame(framer)); - } + explicit WebSocketExtensionDeflateFrame(WebSocketDeflateFramer*); virtual ~WebSocketExtensionDeflateFrame() { } - virtual String handshakeString() override; - virtual bool processResponse(const HashMap<String, String>&) override; - virtual String failureReason() override { return m_failureReason; } + String handshakeString() override; + bool processResponse(const HashMap<String, String>&) override; + String failureReason() override { return m_failureReason; } private: - WebSocketExtensionDeflateFrame(WebSocketDeflateFramer*); - WebSocketDeflateFramer* m_framer; bool m_responseProcessed; String m_failureReason; @@ -84,7 +79,7 @@ bool WebSocketExtensionDeflateFrame::processResponse(const HashMap<String, Strin } m_responseProcessed = true; - int expectedNumParameters = 0; + unsigned expectedNumParameters = 0; int windowBits = 15; HashMap<String, String>::const_iterator parameter = serverParameters.find("max_window_bits"); if (parameter != serverParameters.end()) { @@ -161,9 +156,9 @@ WebSocketDeflateFramer::WebSocketDeflateFramer() { } -PassOwnPtr<WebSocketExtensionProcessor> WebSocketDeflateFramer::createExtensionProcessor() +std::unique_ptr<WebSocketExtensionProcessor> WebSocketDeflateFramer::createExtensionProcessor() { - return WebSocketExtensionDeflateFrame::create(this); + return std::make_unique<WebSocketExtensionDeflateFrame>(this); } bool WebSocketDeflateFramer::canDeflate() const @@ -178,33 +173,33 @@ bool WebSocketDeflateFramer::canDeflate() const #if USE(ZLIB) void WebSocketDeflateFramer::enableDeflate(int windowBits, WebSocketDeflater::ContextTakeOverMode mode) { - m_deflater = WebSocketDeflater::create(windowBits, mode); - m_inflater = WebSocketInflater::create(); + m_deflater = std::make_unique<WebSocketDeflater>(windowBits, mode); + m_inflater = std::make_unique<WebSocketInflater>(); if (!m_deflater->initialize() || !m_inflater->initialize()) { - m_deflater.clear(); - m_inflater.clear(); + m_deflater = nullptr; + m_inflater = nullptr; return; } m_enabled = true; } #endif -PassOwnPtr<DeflateResultHolder> WebSocketDeflateFramer::deflate(WebSocketFrame& frame) +std::unique_ptr<DeflateResultHolder> WebSocketDeflateFramer::deflate(WebSocketFrame& frame) { #if USE(ZLIB) - OwnPtr<DeflateResultHolder> result = DeflateResultHolder::create(this); + auto result = std::make_unique<DeflateResultHolder>(this); if (!enabled() || !WebSocketFrame::isNonControlOpCode(frame.opCode) || !frame.payloadLength) - return result.release(); + return result; if (!m_deflater->addBytes(frame.payload, frame.payloadLength) || !m_deflater->finish()) { result->fail("Failed to compress frame"); - return result.release(); + return result; } frame.compress = true; frame.payload = m_deflater->data(); frame.payloadLength = m_deflater->size(); - return result.release(); + return result; #else - return DeflateResultHolder::create(this); + return std::make_unique<DeflateResultHolder>(this); #endif } @@ -216,30 +211,30 @@ void WebSocketDeflateFramer::resetDeflateContext() #endif } -PassOwnPtr<InflateResultHolder> WebSocketDeflateFramer::inflate(WebSocketFrame& frame) +std::unique_ptr<InflateResultHolder> WebSocketDeflateFramer::inflate(WebSocketFrame& frame) { - OwnPtr<InflateResultHolder> result = InflateResultHolder::create(this); + auto result = std::make_unique<InflateResultHolder>(this); if (!enabled() && frame.compress) { result->fail("Compressed bit must be 0 if no negotiated deflate-frame extension"); - return result.release(); + return result; } #if USE(ZLIB) if (!frame.compress) - return result.release(); + return result; if (!WebSocketFrame::isNonControlOpCode(frame.opCode)) { result->fail("Received unexpected compressed frame"); - return result.release(); + return result; } if (!m_inflater->addBytes(frame.payload, frame.payloadLength) || !m_inflater->finish()) { result->fail("Failed to decompress frame"); - return result.release(); + return result; } frame.compress = false; frame.payload = m_inflater->data(); frame.payloadLength = m_inflater->size(); - return result.release(); + return result; #else - return result.release(); + return result; #endif } diff --git a/Source/WebCore/Modules/websockets/WebSocketDeflateFramer.h b/Source/WebCore/Modules/websockets/WebSocketDeflateFramer.h index f3232af70..7cbc06da8 100644 --- a/Source/WebCore/Modules/websockets/WebSocketDeflateFramer.h +++ b/Source/WebCore/Modules/websockets/WebSocketDeflateFramer.h @@ -28,8 +28,7 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ -#ifndef WebSocketDeflateFramer_h -#define WebSocketDeflateFramer_h +#pragma once #if ENABLE(WEB_SOCKETS) @@ -38,8 +37,6 @@ #endif #include "WebSocketExtensionProcessor.h" #include "WebSocketFrame.h" -#include <wtf/OwnPtr.h> -#include <wtf/PassOwnPtr.h> namespace WebCore { @@ -48,11 +45,7 @@ class WebSocketDeflateFramer; class DeflateResultHolder { WTF_MAKE_FAST_ALLOCATED; public: - static PassOwnPtr<DeflateResultHolder> create(WebSocketDeflateFramer* framer) - { - return adoptPtr(new DeflateResultHolder(framer)); - } - + explicit DeflateResultHolder(WebSocketDeflateFramer*); ~DeflateResultHolder(); bool succeeded() const { return m_succeeded; } @@ -61,8 +54,6 @@ public: void fail(const String& failureReason); private: - explicit DeflateResultHolder(WebSocketDeflateFramer*); - WebSocketDeflateFramer* m_framer; bool m_succeeded; String m_failureReason; @@ -71,11 +62,7 @@ private: class InflateResultHolder { WTF_MAKE_FAST_ALLOCATED; public: - static PassOwnPtr<InflateResultHolder> create(WebSocketDeflateFramer* framer) - { - return adoptPtr(new InflateResultHolder(framer)); - } - + explicit InflateResultHolder(WebSocketDeflateFramer*); ~InflateResultHolder(); bool succeeded() const { return m_succeeded; } @@ -84,8 +71,6 @@ public: void fail(const String& failureReason); private: - explicit InflateResultHolder(WebSocketDeflateFramer*); - WebSocketDeflateFramer* m_framer; bool m_succeeded; String m_failureReason; @@ -95,14 +80,14 @@ class WebSocketDeflateFramer { public: WebSocketDeflateFramer(); - PassOwnPtr<WebSocketExtensionProcessor> createExtensionProcessor(); + std::unique_ptr<WebSocketExtensionProcessor> createExtensionProcessor(); bool canDeflate() const; bool enabled() const { return m_enabled; } - PassOwnPtr<DeflateResultHolder> deflate(WebSocketFrame&); + std::unique_ptr<DeflateResultHolder> deflate(WebSocketFrame&); void resetDeflateContext(); - PassOwnPtr<InflateResultHolder> inflate(WebSocketFrame&); + std::unique_ptr<InflateResultHolder> inflate(WebSocketFrame&); void resetInflateContext(); void didFail(); @@ -114,13 +99,11 @@ public: private: bool m_enabled; #if USE(ZLIB) - OwnPtr<WebSocketDeflater> m_deflater; - OwnPtr<WebSocketInflater> m_inflater; + std::unique_ptr<WebSocketDeflater> m_deflater; + std::unique_ptr<WebSocketInflater> m_inflater; #endif }; -} +} // namespace WebCore #endif // ENABLE(WEB_SOCKETS) - -#endif // WebSocketDeflateFramer_h diff --git a/Source/WebCore/Modules/websockets/WebSocketDeflater.cpp b/Source/WebCore/Modules/websockets/WebSocketDeflater.cpp index aef4b1e9d..497b293a6 100644 --- a/Source/WebCore/Modules/websockets/WebSocketDeflater.cpp +++ b/Source/WebCore/Modules/websockets/WebSocketDeflater.cpp @@ -47,18 +47,13 @@ namespace WebCore { static const int defaultMemLevel = 1; static const size_t bufferIncrementUnit = 4096; -PassOwnPtr<WebSocketDeflater> WebSocketDeflater::create(int windowBits, ContextTakeOverMode contextTakeOverMode) -{ - return adoptPtr(new WebSocketDeflater(windowBits, contextTakeOverMode)); -} - WebSocketDeflater::WebSocketDeflater(int windowBits, ContextTakeOverMode contextTakeOverMode) : m_windowBits(windowBits) , m_contextTakeOverMode(contextTakeOverMode) { ASSERT(m_windowBits >= 8); ASSERT(m_windowBits <= 15); - m_stream = adoptPtr(new z_stream); + m_stream = std::make_unique<z_stream>(); memset(m_stream.get(), 0, sizeof(z_stream)); } @@ -127,15 +122,10 @@ void WebSocketDeflater::reset() deflateReset(m_stream.get()); } -PassOwnPtr<WebSocketInflater> WebSocketInflater::create(int windowBits) -{ - return adoptPtr(new WebSocketInflater(windowBits)); -} - WebSocketInflater::WebSocketInflater(int windowBits) : m_windowBits(windowBits) { - m_stream = adoptPtr(new z_stream); + m_stream = std::make_unique<z_stream>(); memset(m_stream.get(), 0, sizeof(z_stream)); } diff --git a/Source/WebCore/Modules/websockets/WebSocketDeflater.h b/Source/WebCore/Modules/websockets/WebSocketDeflater.h index f73eb20d3..8296c47ed 100644 --- a/Source/WebCore/Modules/websockets/WebSocketDeflater.h +++ b/Source/WebCore/Modules/websockets/WebSocketDeflater.h @@ -28,14 +28,11 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ -#ifndef WebSocketDeflater_h -#define WebSocketDeflater_h +#pragma once #if ENABLE(WEB_SOCKETS) #include <wtf/Noncopyable.h> -#include <wtf/OwnPtr.h> -#include <wtf/PassOwnPtr.h> #include <wtf/Vector.h> struct z_stream_s; @@ -50,8 +47,8 @@ public: DoNotTakeOverContext, TakeOverContext }; - static PassOwnPtr<WebSocketDeflater> create(int windowBits, ContextTakeOverMode = TakeOverContext); + explicit WebSocketDeflater(int windowBits, ContextTakeOverMode = TakeOverContext); ~WebSocketDeflater(); bool initialize(); @@ -62,19 +59,16 @@ public: void reset(); private: - WebSocketDeflater(int windowBits, ContextTakeOverMode); - int m_windowBits; ContextTakeOverMode m_contextTakeOverMode; Vector<char> m_buffer; - OwnPtr<z_stream> m_stream; + std::unique_ptr<z_stream> m_stream; }; class WebSocketInflater { WTF_MAKE_FAST_ALLOCATED; public: - static PassOwnPtr<WebSocketInflater> create(int windowBits = 15); - + explicit WebSocketInflater(int windowBits = 15); ~WebSocketInflater(); bool initialize(); @@ -85,15 +79,11 @@ public: void reset(); private: - explicit WebSocketInflater(int windowBits); - int m_windowBits; Vector<char> m_buffer; - OwnPtr<z_stream> m_stream; + std::unique_ptr<z_stream> m_stream; }; -} +} // namespace WebCore #endif // ENABLE(WEB_SOCKETS) - -#endif // WebSocketDeflater_h diff --git a/Source/WebCore/Modules/websockets/WebSocketExtensionDispatcher.cpp b/Source/WebCore/Modules/websockets/WebSocketExtensionDispatcher.cpp index 06b719465..d834c733e 100644 --- a/Source/WebCore/Modules/websockets/WebSocketExtensionDispatcher.cpp +++ b/Source/WebCore/Modules/websockets/WebSocketExtensionDispatcher.cpp @@ -48,16 +48,16 @@ void WebSocketExtensionDispatcher::reset() m_processors.clear(); } -void WebSocketExtensionDispatcher::addProcessor(PassOwnPtr<WebSocketExtensionProcessor> processor) +void WebSocketExtensionDispatcher::addProcessor(std::unique_ptr<WebSocketExtensionProcessor> processor) { - for (size_t i = 0; i < m_processors.size(); ++i) { - if (m_processors[i]->extensionToken() == processor->extensionToken()) + for (auto& extensionProcessor : m_processors) { + if (extensionProcessor->extensionToken() == processor->extensionToken()) return; } ASSERT(processor->handshakeString().length()); ASSERT(!processor->handshakeString().contains('\n')); ASSERT(!processor->handshakeString().contains(static_cast<UChar>('\0'))); - m_processors.append(processor); + m_processors.append(WTFMove(processor)); } const String WebSocketExtensionDispatcher::createHeaderValue() const @@ -69,7 +69,7 @@ const String WebSocketExtensionDispatcher::createHeaderValue() const StringBuilder builder; builder.append(m_processors[0]->handshakeString()); for (size_t i = 1; i < numProcessors; ++i) { - builder.append(", "); + builder.appendLiteral(", "); builder.append(m_processors[i]->handshakeString()); } return builder.toString(); @@ -78,15 +78,15 @@ const String WebSocketExtensionDispatcher::createHeaderValue() const void WebSocketExtensionDispatcher::appendAcceptedExtension(const String& extensionToken, HashMap<String, String>& extensionParameters) { if (!m_acceptedExtensionsBuilder.isEmpty()) - m_acceptedExtensionsBuilder.append(", "); + m_acceptedExtensionsBuilder.appendLiteral(", "); m_acceptedExtensionsBuilder.append(extensionToken); // FIXME: Should use ListHashSet to keep the order of the parameters. - for (HashMap<String, String>::const_iterator iterator = extensionParameters.begin(); iterator != extensionParameters.end(); ++iterator) { - m_acceptedExtensionsBuilder.append("; "); - m_acceptedExtensionsBuilder.append(iterator->key); - if (!iterator->value.isNull()) { - m_acceptedExtensionsBuilder.append("="); - m_acceptedExtensionsBuilder.append(iterator->value); + for (auto& parameter : extensionParameters) { + m_acceptedExtensionsBuilder.appendLiteral("; "); + m_acceptedExtensionsBuilder.append(parameter.key); + if (!parameter.value.isNull()) { + m_acceptedExtensionsBuilder.append('='); + m_acceptedExtensionsBuilder.append(parameter.value); } } } @@ -118,9 +118,8 @@ bool WebSocketExtensionDispatcher::processHeaderValue(const String& headerValue) return false; } - size_t index; - for (index = 0; index < m_processors.size(); ++index) { - WebSocketExtensionProcessor* processor = m_processors[index].get(); + size_t index = 0; + for (auto& processor : m_processors) { if (extensionToken == processor->extensionToken()) { if (processor->processResponse(extensionParameters)) { appendAcceptedExtension(extensionToken, extensionParameters); @@ -129,6 +128,7 @@ bool WebSocketExtensionDispatcher::processHeaderValue(const String& headerValue) fail(processor->failureReason()); return false; } + ++index; } // There is no extension which can process the response. if (index == m_processors.size()) { diff --git a/Source/WebCore/Modules/websockets/WebSocketExtensionDispatcher.h b/Source/WebCore/Modules/websockets/WebSocketExtensionDispatcher.h index 2430017d5..3b2568cee 100644 --- a/Source/WebCore/Modules/websockets/WebSocketExtensionDispatcher.h +++ b/Source/WebCore/Modules/websockets/WebSocketExtensionDispatcher.h @@ -28,14 +28,11 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ -#ifndef WebSocketExtensionDispatcher_h -#define WebSocketExtensionDispatcher_h +#pragma once #if ENABLE(WEB_SOCKETS) #include "WebSocketExtensionProcessor.h" -#include <wtf/OwnPtr.h> -#include <wtf/PassOwnPtr.h> #include <wtf/Vector.h> #include <wtf/text/StringBuilder.h> #include <wtf/text/WTFString.h> @@ -47,7 +44,7 @@ public: WebSocketExtensionDispatcher() { } void reset(); - void addProcessor(PassOwnPtr<WebSocketExtensionProcessor>); + void addProcessor(std::unique_ptr<WebSocketExtensionProcessor>); const String createHeaderValue() const; bool processHeaderValue(const String&); @@ -58,13 +55,11 @@ private: void appendAcceptedExtension(const String& extensionToken, HashMap<String, String>& extensionParameters); void fail(const String& reason); - Vector<OwnPtr<WebSocketExtensionProcessor>> m_processors; + Vector<std::unique_ptr<WebSocketExtensionProcessor>> m_processors; StringBuilder m_acceptedExtensionsBuilder; String m_failureReason; }; -} +} // namespace WebCore #endif // ENABLE(WEB_SOCKETS) - -#endif // WebSocketExtensionDispatcher_h diff --git a/Source/WebCore/Modules/websockets/WebSocketExtensionParser.h b/Source/WebCore/Modules/websockets/WebSocketExtensionParser.h index c7b0e3646..3247768a8 100644 --- a/Source/WebCore/Modules/websockets/WebSocketExtensionParser.h +++ b/Source/WebCore/Modules/websockets/WebSocketExtensionParser.h @@ -28,8 +28,7 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ -#ifndef WebSocketExtensionParser_h -#define WebSocketExtensionParser_h +#pragma once #if ENABLE(WEB_SOCKETS) @@ -70,5 +69,3 @@ private: } // namespace WebCore #endif // ENABLE(WEB_SOCKETS) - -#endif // WebSocketExtensionParser_h diff --git a/Source/WebCore/Modules/websockets/WebSocketExtensionProcessor.h b/Source/WebCore/Modules/websockets/WebSocketExtensionProcessor.h index 167460c76..023a98c12 100644 --- a/Source/WebCore/Modules/websockets/WebSocketExtensionProcessor.h +++ b/Source/WebCore/Modules/websockets/WebSocketExtensionProcessor.h @@ -28,8 +28,7 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ -#ifndef WebSocketExtensionProcessor_h -#define WebSocketExtensionProcessor_h +#pragma once #if ENABLE(WEB_SOCKETS) @@ -67,8 +66,6 @@ private: String m_extensionToken; }; -} +} // namespace WebCore #endif // ENABLE(WEB_SOCKETS) - -#endif // WebSocketExtensionProcessor_h diff --git a/Source/WebCore/Modules/websockets/WebSocketFrame.h b/Source/WebCore/Modules/websockets/WebSocketFrame.h index c1893624f..687ccd0e8 100644 --- a/Source/WebCore/Modules/websockets/WebSocketFrame.h +++ b/Source/WebCore/Modules/websockets/WebSocketFrame.h @@ -28,8 +28,7 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ -#ifndef WebSocketFrame_h -#define WebSocketFrame_h +#pragma once #if ENABLE(WEB_SOCKETS) @@ -61,7 +60,7 @@ struct WebSocketFrame { static bool needsExtendedLengthField(size_t payloadLength); static ParseFrameResult parseFrame(char* data, size_t dataLength, WebSocketFrame&, const char*& frameEnd, String& errorString); // May modify part of data to unmask the frame. - WebSocketFrame(OpCode = OpCodeInvalid, bool final = false, bool compress = false, bool masked = false, const char* payload = 0, size_t payloadLength = 0); + WebSocketFrame(OpCode = OpCodeInvalid, bool final = false, bool compress = false, bool masked = false, const char* payload = nullptr, size_t payloadLength = 0); void makeFrameData(Vector<char>& frameData); OpCode opCode; @@ -77,5 +76,3 @@ struct WebSocketFrame { } // namespace WebCore #endif // ENABLE(WEB_SOCKETS) - -#endif // WebSocketFrame_h diff --git a/Source/WebCore/Modules/websockets/WebSocketHandshake.cpp b/Source/WebCore/Modules/websockets/WebSocketHandshake.cpp index e6782cf8f..dd3b9a4e4 100644 --- a/Source/WebCore/Modules/websockets/WebSocketHandshake.cpp +++ b/Source/WebCore/Modules/websockets/WebSocketHandshake.cpp @@ -40,13 +40,14 @@ #include "CookieJar.h" #include "Document.h" #include "HTTPHeaderMap.h" +#include "HTTPHeaderNames.h" #include "HTTPParsers.h" #include "URL.h" #include "Logging.h" #include "ResourceRequest.h" -#include "ScriptCallStack.h" #include "ScriptExecutionContext.h" #include "SecurityOrigin.h" +#include <wtf/ASCIICType.h> #include <wtf/CryptographicallyRandomNumber.h> #include <wtf/MD5.h> #include <wtf/SHA1.h> @@ -56,6 +57,7 @@ #include <wtf/text/Base64.h> #include <wtf/text/CString.h> #include <wtf/text/StringBuilder.h> +#include <wtf/text/StringView.h> #include <wtf/text/WTFString.h> #include <wtf/unicode/CharacterNames.h> @@ -81,10 +83,10 @@ static String hostName(const URL& url, bool secure) { ASSERT(url.protocolIs("wss") == secure); StringBuilder builder; - builder.append(url.host().lower()); - if (url.port() && ((!secure && url.port() != 80) || (secure && url.port() != 443))) { + builder.append(url.host().convertToASCIILowercase()); + if (url.port() && ((!secure && url.port().value() != 80) || (secure && url.port().value() != 443))) { builder.append(':'); - builder.appendNumber(url.port()); + builder.appendNumber(url.port().value()); } return builder.toString(); } @@ -118,12 +120,13 @@ String WebSocketHandshake::getExpectedWebSocketAccept(const String& secWebSocket return base64Encode(hash.data(), SHA1::hashSize); } -WebSocketHandshake::WebSocketHandshake(const URL& url, const String& protocol, ScriptExecutionContext* context) +WebSocketHandshake::WebSocketHandshake(const URL& url, const String& protocol, Document* document, bool allowCookies) : m_url(url) , m_clientProtocol(protocol) , m_secure(m_url.protocolIs("wss")) - , m_context(context) + , m_document(document) , m_mode(Incomplete) + , m_allowCookies(allowCookies) { m_secWebSocketKey = generateSecWebSocketKey(); m_expectedAccept = getExpectedWebSocketAccept(m_secWebSocketKey); @@ -140,12 +143,13 @@ const URL& WebSocketHandshake::url() const void WebSocketHandshake::setURL(const URL& url) { - m_url = url.copy(); + m_url = url.isolatedCopy(); } +// FIXME: Return type should just be String, not const String. const String WebSocketHandshake::host() const { - return m_url.host().lower(); + return m_url.host().convertToASCIILowercase(); } const String& WebSocketHandshake::clientProtocol() const @@ -165,14 +169,14 @@ bool WebSocketHandshake::secure() const String WebSocketHandshake::clientOrigin() const { - return m_context->securityOrigin()->toString(); + return m_document->securityOrigin().toString(); } String WebSocketHandshake::clientLocation() const { StringBuilder builder; builder.append(m_secure ? "wss" : "ws"); - builder.append("://"); + builder.appendLiteral("://"); builder.append(hostName(m_url, m_secure)); builder.append(resourceName(m_url)); return builder.toString(); @@ -183,9 +187,9 @@ CString WebSocketHandshake::clientHandshakeMessage() const // Keep the following consistent with clientHandshakeRequest(). StringBuilder builder; - builder.append("GET "); + builder.appendLiteral("GET "); builder.append(resourceName(m_url)); - builder.append(" HTTP/1.1\r\n"); + builder.appendLiteral(" HTTP/1.1\r\n"); Vector<String> fields; fields.append("Upgrade: websocket"); @@ -196,12 +200,10 @@ CString WebSocketHandshake::clientHandshakeMessage() const fields.append("Sec-WebSocket-Protocol: " + m_clientProtocol); URL url = httpURLForAuthenticationAndCookies(); - if (m_context->isDocument()) { - Document* document = toDocument(m_context); - String cookie = cookieRequestHeaderFieldValue(document, url); + if (m_allowCookies && m_document) { + String cookie = cookieRequestHeaderFieldValue(*m_document, url); if (!cookie.isEmpty()) fields.append("Cookie: " + cookie); - // Set "Cookie2: <cookie>" if cookies 2 exists for url? } // Add no-cache headers to avoid compatibility issue. @@ -218,18 +220,18 @@ CString WebSocketHandshake::clientHandshakeMessage() const fields.append("Sec-WebSocket-Extensions: " + extensionValue); // Add a User-Agent header. - fields.append("User-Agent: " + m_context->userAgent(m_context->url())); + fields.append("User-Agent: " + m_document->userAgent(m_document->url())); // Fields in the handshake are sent by the client in a random order; the // order is not meaningful. Thus, it's ok to send the order we constructed // the fields. - for (size_t i = 0; i < fields.size(); i++) { - builder.append(fields[i]); - builder.append("\r\n"); + for (auto& field : fields) { + builder.append(field); + builder.appendLiteral("\r\n"); } - builder.append("\r\n"); + builder.appendLiteral("\r\n"); return builder.toString().utf8(); } @@ -237,37 +239,33 @@ CString WebSocketHandshake::clientHandshakeMessage() const ResourceRequest WebSocketHandshake::clientHandshakeRequest() const { // Keep the following consistent with clientHandshakeMessage(). - // FIXME: do we need to store m_secWebSocketKey1, m_secWebSocketKey2 and - // m_key3 in the request? ResourceRequest request(m_url); request.setHTTPMethod("GET"); - request.addHTTPHeaderField("Connection", "Upgrade"); - request.addHTTPHeaderField("Host", hostName(m_url, m_secure)); - request.addHTTPHeaderField("Origin", clientOrigin()); + request.setHTTPHeaderField(HTTPHeaderName::Connection, "Upgrade"); + request.setHTTPHeaderField(HTTPHeaderName::Host, hostName(m_url, m_secure)); + request.setHTTPHeaderField(HTTPHeaderName::Origin, clientOrigin()); if (!m_clientProtocol.isEmpty()) - request.addHTTPHeaderField("Sec-WebSocket-Protocol", m_clientProtocol); + request.setHTTPHeaderField(HTTPHeaderName::SecWebSocketProtocol, m_clientProtocol); URL url = httpURLForAuthenticationAndCookies(); - if (m_context->isDocument()) { - Document* document = toDocument(m_context); - String cookie = cookieRequestHeaderFieldValue(document, url); + if (m_allowCookies && m_document) { + String cookie = cookieRequestHeaderFieldValue(*m_document, url); if (!cookie.isEmpty()) - request.addHTTPHeaderField("Cookie", cookie); - // Set "Cookie2: <cookie>" if cookies 2 exists for url? + request.setHTTPHeaderField(HTTPHeaderName::Cookie, cookie); } - request.addHTTPHeaderField("Pragma", "no-cache"); - request.addHTTPHeaderField("Cache-Control", "no-cache"); + request.setHTTPHeaderField(HTTPHeaderName::Pragma, "no-cache"); + request.setHTTPHeaderField(HTTPHeaderName::CacheControl, "no-cache"); - request.addHTTPHeaderField("Sec-WebSocket-Key", m_secWebSocketKey); - request.addHTTPHeaderField("Sec-WebSocket-Version", "13"); + request.setHTTPHeaderField(HTTPHeaderName::SecWebSocketKey, m_secWebSocketKey); + request.setHTTPHeaderField(HTTPHeaderName::SecWebSocketVersion, "13"); const String extensionValue = m_extensionDispatcher.createHeaderValue(); if (extensionValue.length()) - request.addHTTPHeaderField("Sec-WebSocket-Extensions", extensionValue); + request.setHTTPHeaderField(HTTPHeaderName::SecWebSocketExtensions, extensionValue); // Add a User-Agent header. - request.addHTTPHeaderField("User-Agent", m_context->userAgent(m_context->url())); + request.setHTTPHeaderField(HTTPHeaderName::UserAgent, m_document->userAgent(m_document->url())); return request; } @@ -278,9 +276,9 @@ void WebSocketHandshake::reset() m_extensionDispatcher.reset(); } -void WebSocketHandshake::clearScriptExecutionContext() +void WebSocketHandshake::clearDocument() { - m_context = 0; + m_document = nullptr; } int WebSocketHandshake::readServerHandshake(const char* header, size_t len) @@ -303,7 +301,7 @@ int WebSocketHandshake::readServerHandshake(const char* header, size_t len) if (statusCode != 101) { m_mode = Failed; - m_failureReason = "Unexpected response code: " + String::number(statusCode); + m_failureReason = makeString("Unexpected response code: ", String::number(statusCode)); return len; } m_mode = Normal; @@ -340,32 +338,27 @@ String WebSocketHandshake::failureReason() const String WebSocketHandshake::serverWebSocketProtocol() const { - return m_serverHandshakeResponse.httpHeaderFields().get("sec-websocket-protocol"); + return m_serverHandshakeResponse.httpHeaderFields().get(HTTPHeaderName::SecWebSocketProtocol); } String WebSocketHandshake::serverSetCookie() const { - return m_serverHandshakeResponse.httpHeaderFields().get("set-cookie"); -} - -String WebSocketHandshake::serverSetCookie2() const -{ - return m_serverHandshakeResponse.httpHeaderFields().get("set-cookie2"); + return m_serverHandshakeResponse.httpHeaderFields().get(HTTPHeaderName::SetCookie); } String WebSocketHandshake::serverUpgrade() const { - return m_serverHandshakeResponse.httpHeaderFields().get("upgrade"); + return m_serverHandshakeResponse.httpHeaderFields().get(HTTPHeaderName::Upgrade); } String WebSocketHandshake::serverConnection() const { - return m_serverHandshakeResponse.httpHeaderFields().get("connection"); + return m_serverHandshakeResponse.httpHeaderFields().get(HTTPHeaderName::Connection); } String WebSocketHandshake::serverWebSocketAccept() const { - return m_serverHandshakeResponse.httpHeaderFields().get("sec-websocket-accept"); + return m_serverHandshakeResponse.httpHeaderFields().get(HTTPHeaderName::SecWebSocketAccept); } String WebSocketHandshake::acceptedExtensions() const @@ -378,19 +371,56 @@ const ResourceResponse& WebSocketHandshake::serverHandshakeResponse() const return m_serverHandshakeResponse; } -void WebSocketHandshake::addExtensionProcessor(PassOwnPtr<WebSocketExtensionProcessor> processor) +void WebSocketHandshake::addExtensionProcessor(std::unique_ptr<WebSocketExtensionProcessor> processor) { - m_extensionDispatcher.addProcessor(processor); + m_extensionDispatcher.addProcessor(WTFMove(processor)); } URL WebSocketHandshake::httpURLForAuthenticationAndCookies() const { - URL url = m_url.copy(); + URL url = m_url.isolatedCopy(); bool couldSetProtocol = url.setProtocol(m_secure ? "https" : "http"); ASSERT_UNUSED(couldSetProtocol, couldSetProtocol); return url; } +// https://tools.ietf.org/html/rfc6455#section-4.1 +// "The HTTP version MUST be at least 1.1." +static inline bool headerHasValidHTTPVersion(StringView httpStatusLine) +{ + const char* httpVersionStaticPreambleLiteral = "HTTP/"; + StringView httpVersionStaticPreamble(reinterpret_cast<const LChar*>(httpVersionStaticPreambleLiteral), strlen(httpVersionStaticPreambleLiteral)); + if (!httpStatusLine.startsWith(httpVersionStaticPreamble)) + return false; + + // Check that there is a version number which should be at least three characters after "HTTP/" + unsigned preambleLength = httpVersionStaticPreamble.length(); + if (httpStatusLine.length() < preambleLength + 3) + return false; + + auto dotPosition = httpStatusLine.find('.', preambleLength); + if (dotPosition == notFound) + return false; + + StringView majorVersionView = httpStatusLine.substring(preambleLength, dotPosition - preambleLength); + bool isValid; + int majorVersion = majorVersionView.toIntStrict(isValid); + if (!isValid) + return false; + + unsigned minorVersionLength; + unsigned charactersLeftAfterDotPosition = httpStatusLine.length() - dotPosition; + for (minorVersionLength = 1; minorVersionLength < charactersLeftAfterDotPosition; minorVersionLength++) { + if (!isASCIIDigit(httpStatusLine[dotPosition + minorVersionLength])) + break; + } + int minorVersion = (httpStatusLine.substring(dotPosition + 1, minorVersionLength)).toIntStrict(isValid); + if (!isValid) + return false; + + return (majorVersion >= 1 && minorVersion >= 1) || majorVersion >= 2; +} + // Returns the header length (including "\r\n"), or -1 if we have not received enough data yet. // If the line is malformed or the status code is not a 3-digit number, // statusCode and statusText will be set to -1 and a null string, respectively. @@ -403,8 +433,8 @@ int WebSocketHandshake::readStatusLine(const char* header, size_t headerLength, statusCode = -1; statusText = String(); - const char* space1 = 0; - const char* space2 = 0; + const char* space1 = nullptr; + const char* space2 = nullptr; const char* p; size_t consumedLength; @@ -418,7 +448,10 @@ int WebSocketHandshake::readStatusLine(const char* header, size_t headerLength, // The caller isn't prepared to deal with null bytes in status // line. WebSockets specification doesn't prohibit this, but HTTP // does, so we'll just treat this as an error. - m_failureReason = "Status line contains embedded null"; + m_failureReason = ASCIILiteral("Status line contains embedded null"); + return p + 1 - header; + } else if (!isASCII(*p)) { + m_failureReason = ASCIILiteral("Status line contains non-ASCII character"); return p + 1 - header; } else if (*p == '\n') break; @@ -429,32 +462,38 @@ int WebSocketHandshake::readStatusLine(const char* header, size_t headerLength, const char* end = p + 1; int lineLength = end - header; if (lineLength > maximumLength) { - m_failureReason = "Status line is too long"; + m_failureReason = ASCIILiteral("Status line is too long"); return maximumLength; } // The line must end with "\r\n". if (lineLength < 2 || *(end - 2) != '\r') { - m_failureReason = "Status line does not end with CRLF"; + m_failureReason = ASCIILiteral("Status line does not end with CRLF"); return lineLength; } if (!space1 || !space2) { - m_failureReason = "No response code found: " + trimInputSample(header, lineLength - 2); + m_failureReason = makeString("No response code found: ", trimInputSample(header, lineLength - 2)); + return lineLength; + } + + StringView httpStatusLine(reinterpret_cast<const LChar*>(header), space1 - header); + if (!headerHasValidHTTPVersion(httpStatusLine)) { + m_failureReason = makeString("Invalid HTTP version string: ", httpStatusLine); return lineLength; } - String statusCodeString(space1 + 1, space2 - space1 - 1); + StringView statusCodeString(reinterpret_cast<const LChar*>(space1 + 1), space2 - space1 - 1); if (statusCodeString.length() != 3) // Status code must consist of three digits. return lineLength; for (int i = 0; i < 3; ++i) - if (statusCodeString[i] < '0' || statusCodeString[i] > '9') { - m_failureReason = "Invalid status code: " + statusCodeString; + if (!isASCIIDigit(statusCodeString[i])) { + m_failureReason = makeString("Invalid status code: ", statusCodeString); return lineLength; } bool ok = false; - statusCode = statusCodeString.toInt(&ok); + statusCode = statusCodeString.toIntStrict(ok); ASSERT(ok); statusText = String(space2 + 1, end - space2 - 3); // Exclude "\r\n". @@ -463,7 +502,7 @@ int WebSocketHandshake::readStatusLine(const char* header, size_t headerLength, const char* WebSocketHandshake::readHTTPHeaders(const char* start, const char* end) { - AtomicString name; + StringView name; String value; bool sawSecWebSocketExtensionsHeaderField = false; bool sawSecWebSocketAcceptHeaderField = false; @@ -472,39 +511,57 @@ const char* WebSocketHandshake::readHTTPHeaders(const char* start, const char* e for (; p < end; p++) { size_t consumedLength = parseHTTPHeader(p, end - p, m_failureReason, name, value); if (!consumedLength) - return 0; + return nullptr; p += consumedLength; // Stop once we consumed an empty line. if (name.isEmpty()) break; - if (equalIgnoringCase("sec-websocket-extensions", name)) { + HTTPHeaderName headerName; + if (!findHTTPHeaderName(name, headerName)) { + // Evidence in the wild shows that services make use of custom headers in the handshake + m_serverHandshakeResponse.addHTTPHeaderField(name.toString(), value); + continue; + } + + // https://tools.ietf.org/html/rfc7230#section-3.2.4 + // "Newly defined header fields SHOULD limit their field values to US-ASCII octets." + if ((headerName == HTTPHeaderName::SecWebSocketExtensions + || headerName == HTTPHeaderName::SecWebSocketAccept + || headerName == HTTPHeaderName::SecWebSocketProtocol) + && !value.containsOnlyASCII()) { + m_failureReason = makeString(name, " header value should only contain ASCII characters"); + return nullptr; + } + + if (headerName == HTTPHeaderName::SecWebSocketExtensions) { if (sawSecWebSocketExtensionsHeaderField) { - m_failureReason = "The Sec-WebSocket-Extensions header MUST NOT appear more than once in an HTTP response"; - return 0; + m_failureReason = ASCIILiteral("The Sec-WebSocket-Extensions header must not appear more than once in an HTTP response"); + return nullptr; } if (!m_extensionDispatcher.processHeaderValue(value)) { m_failureReason = m_extensionDispatcher.failureReason(); - return 0; + return nullptr; } sawSecWebSocketExtensionsHeaderField = true; - } else if (equalIgnoringCase("Sec-WebSocket-Accept", name)) { - if (sawSecWebSocketAcceptHeaderField) { - m_failureReason = "The Sec-WebSocket-Accept header MUST NOT appear more than once in an HTTP response"; - return 0; + } else { + if (headerName == HTTPHeaderName::SecWebSocketAccept) { + if (sawSecWebSocketAcceptHeaderField) { + m_failureReason = ASCIILiteral("The Sec-WebSocket-Accept header must not appear more than once in an HTTP response"); + return nullptr; + } + sawSecWebSocketAcceptHeaderField = true; + } else if (headerName == HTTPHeaderName::SecWebSocketProtocol) { + if (sawSecWebSocketProtocolHeaderField) { + m_failureReason = ASCIILiteral("The Sec-WebSocket-Protocol header must not appear more than once in an HTTP response"); + return nullptr; + } + sawSecWebSocketProtocolHeaderField = true; } - m_serverHandshakeResponse.addHTTPHeaderField(name, value); - sawSecWebSocketAcceptHeaderField = true; - } else if (equalIgnoringCase("Sec-WebSocket-Protocol", name)) { - if (sawSecWebSocketProtocolHeaderField) { - m_failureReason = "The Sec-WebSocket-Protocol header MUST NOT appear more than once in an HTTP response"; - return 0; - } - m_serverHandshakeResponse.addHTTPHeaderField(name, value); - sawSecWebSocketProtocolHeaderField = true; - } else - m_serverHandshakeResponse.addHTTPHeaderField(name, value); + + m_serverHandshakeResponse.addHTTPHeaderField(headerName, value); + } } return p; } @@ -517,40 +574,40 @@ bool WebSocketHandshake::checkResponseHeaders() const String& serverWebSocketAccept = this->serverWebSocketAccept(); if (serverUpgrade.isNull()) { - m_failureReason = "Error during WebSocket handshake: 'Upgrade' header is missing"; + m_failureReason = ASCIILiteral("Error during WebSocket handshake: 'Upgrade' header is missing"); return false; } if (serverConnection.isNull()) { - m_failureReason = "Error during WebSocket handshake: 'Connection' header is missing"; + m_failureReason = ASCIILiteral("Error during WebSocket handshake: 'Connection' header is missing"); return false; } if (serverWebSocketAccept.isNull()) { - m_failureReason = "Error during WebSocket handshake: 'Sec-WebSocket-Accept' header is missing"; + m_failureReason = ASCIILiteral("Error during WebSocket handshake: 'Sec-WebSocket-Accept' header is missing"); return false; } - if (!equalIgnoringCase(serverUpgrade, "websocket")) { - m_failureReason = "Error during WebSocket handshake: 'Upgrade' header value is not 'WebSocket'"; + if (!equalLettersIgnoringASCIICase(serverUpgrade, "websocket")) { + m_failureReason = ASCIILiteral("Error during WebSocket handshake: 'Upgrade' header value is not 'WebSocket'"); return false; } - if (!equalIgnoringCase(serverConnection, "upgrade")) { - m_failureReason = "Error during WebSocket handshake: 'Connection' header value is not 'Upgrade'"; + if (!equalLettersIgnoringASCIICase(serverConnection, "upgrade")) { + m_failureReason = ASCIILiteral("Error during WebSocket handshake: 'Connection' header value is not 'Upgrade'"); return false; } if (serverWebSocketAccept != m_expectedAccept) { - m_failureReason = "Error during WebSocket handshake: Sec-WebSocket-Accept mismatch"; + m_failureReason = ASCIILiteral("Error during WebSocket handshake: Sec-WebSocket-Accept mismatch"); return false; } if (!serverWebSocketProtocol.isNull()) { if (m_clientProtocol.isEmpty()) { - m_failureReason = "Error during WebSocket handshake: Sec-WebSocket-Protocol mismatch"; + m_failureReason = ASCIILiteral("Error during WebSocket handshake: Sec-WebSocket-Protocol mismatch"); return false; } Vector<String> result; - m_clientProtocol.split(String(WebSocket::subProtocolSeperator()), result); + m_clientProtocol.split(WebSocket::subprotocolSeparator(), result); if (!result.contains(serverWebSocketProtocol)) { - m_failureReason = "Error during WebSocket handshake: Sec-WebSocket-Protocol mismatch"; + m_failureReason = ASCIILiteral("Error during WebSocket handshake: Sec-WebSocket-Protocol mismatch"); return false; } } diff --git a/Source/WebCore/Modules/websockets/WebSocketHandshake.h b/Source/WebCore/Modules/websockets/WebSocketHandshake.h index ccab04501..9f701c530 100644 --- a/Source/WebCore/Modules/websockets/WebSocketHandshake.h +++ b/Source/WebCore/Modules/websockets/WebSocketHandshake.h @@ -28,8 +28,7 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ -#ifndef WebSocketHandshake_h -#define WebSocketHandshake_h +#pragma once #if ENABLE(WEB_SOCKETS) @@ -37,13 +36,12 @@ #include "ResourceResponse.h" #include "WebSocketExtensionDispatcher.h" #include "WebSocketExtensionProcessor.h" -#include <wtf/PassOwnPtr.h> #include <wtf/text/WTFString.h> namespace WebCore { +class Document; class ResourceRequest; -class ScriptExecutionContext; class WebSocketHandshake { WTF_MAKE_NONCOPYABLE(WebSocketHandshake); WTF_MAKE_FAST_ALLOCATED; @@ -51,7 +49,7 @@ public: enum Mode { Incomplete, Normal, Failed, Connected }; - WebSocketHandshake(const URL&, const String& protocol, ScriptExecutionContext*); + WebSocketHandshake(const URL&, const String& protocol, Document*, bool allowCookies); ~WebSocketHandshake(); const URL& url() const; @@ -70,7 +68,7 @@ public: ResourceRequest clientHandshakeRequest() const; void reset(); - void clearScriptExecutionContext(); + void clearDocument(); int readServerHandshake(const char* header, size_t len); Mode mode() const; @@ -78,7 +76,6 @@ public: String serverWebSocketProtocol() const; String serverSetCookie() const; - String serverSetCookie2() const; String serverUpgrade() const; String serverConnection() const; String serverWebSocketAccept() const; @@ -86,7 +83,7 @@ public: const ResourceResponse& serverHandshakeResponse() const; - void addExtensionProcessor(PassOwnPtr<WebSocketExtensionProcessor>); + void addExtensionProcessor(std::unique_ptr<WebSocketExtensionProcessor>); static String getExpectedWebSocketAccept(const String& secWebSocketKey); @@ -103,9 +100,10 @@ private: URL m_url; String m_clientProtocol; bool m_secure; - ScriptExecutionContext* m_context; + Document* m_document; Mode m_mode; + bool m_allowCookies; ResourceResponse m_serverHandshakeResponse; @@ -120,5 +118,3 @@ private: } // namespace WebCore #endif // ENABLE(WEB_SOCKETS) - -#endif // WebSocketHandshake_h diff --git a/Source/WebCore/Modules/websockets/WorkerThreadableWebSocketChannel.cpp b/Source/WebCore/Modules/websockets/WorkerThreadableWebSocketChannel.cpp index 65b8d2bc1..f0d86b615 100644 --- a/Source/WebCore/Modules/websockets/WorkerThreadableWebSocketChannel.cpp +++ b/Source/WebCore/Modules/websockets/WorkerThreadableWebSocketChannel.cpp @@ -35,9 +35,9 @@ #include "WorkerThreadableWebSocketChannel.h" #include "Blob.h" -#include "CrossThreadTask.h" #include "Document.h" #include "ScriptExecutionContext.h" +#include "SocketProvider.h" #include "ThreadableWebSocketChannelClientWrapper.h" #include "WebSocketChannel.h" #include "WebSocketChannelClient.h" @@ -47,15 +47,15 @@ #include "WorkerThread.h" #include <runtime/ArrayBuffer.h> #include <wtf/MainThread.h> -#include <wtf/PassRefPtr.h> #include <wtf/text/WTFString.h> namespace WebCore { -WorkerThreadableWebSocketChannel::WorkerThreadableWebSocketChannel(WorkerGlobalScope* context, WebSocketChannelClient* client, const String& taskMode) +WorkerThreadableWebSocketChannel::WorkerThreadableWebSocketChannel(WorkerGlobalScope& context, WebSocketChannelClient& client, const String& taskMode, SocketProvider& provider) : m_workerGlobalScope(context) , m_workerClientWrapper(ThreadableWebSocketChannelClientWrapper::create(context, client)) - , m_bridge(Bridge::create(m_workerClientWrapper, m_workerGlobalScope, taskMode)) + , m_bridge(Bridge::create(m_workerClientWrapper.copyRef(), m_workerGlobalScope.copyRef(), taskMode, provider)) + , m_socketProvider(provider) { m_bridge->initialize(); } @@ -74,13 +74,11 @@ void WorkerThreadableWebSocketChannel::connect(const URL& url, const String& pro String WorkerThreadableWebSocketChannel::subprotocol() { - ASSERT(m_workerClientWrapper); return m_workerClientWrapper->subprotocol(); } String WorkerThreadableWebSocketChannel::extensions() { - ASSERT(m_workerClientWrapper); return m_workerClientWrapper->extensions(); } @@ -98,14 +96,14 @@ ThreadableWebSocketChannel::SendResult WorkerThreadableWebSocketChannel::send(co return m_bridge->send(binaryData, byteOffset, byteLength); } -ThreadableWebSocketChannel::SendResult WorkerThreadableWebSocketChannel::send(const Blob& binaryData) +ThreadableWebSocketChannel::SendResult WorkerThreadableWebSocketChannel::send(Blob& binaryData) { if (!m_bridge) return ThreadableWebSocketChannel::SendFail; return m_bridge->send(binaryData); } -unsigned long WorkerThreadableWebSocketChannel::bufferedAmount() const +unsigned WorkerThreadableWebSocketChannel::bufferedAmount() const { if (!m_bridge) return 0; @@ -127,7 +125,7 @@ void WorkerThreadableWebSocketChannel::fail(const String& reason) void WorkerThreadableWebSocketChannel::disconnect() { m_bridge->disconnect(); - m_bridge.clear(); + m_bridge = nullptr; } void WorkerThreadableWebSocketChannel::suspend() @@ -144,10 +142,10 @@ void WorkerThreadableWebSocketChannel::resume() m_bridge->resume(); } -WorkerThreadableWebSocketChannel::Peer::Peer(PassRefPtr<ThreadableWebSocketChannelClientWrapper> clientWrapper, WorkerLoaderProxy& loaderProxy, ScriptExecutionContext* context, const String& taskMode) - : m_workerClientWrapper(clientWrapper) +WorkerThreadableWebSocketChannel::Peer::Peer(Ref<ThreadableWebSocketChannelClientWrapper>&& clientWrapper, WorkerLoaderProxy& loaderProxy, ScriptExecutionContext& context, const String& taskMode, SocketProvider& provider) + : m_workerClientWrapper(WTFMove(clientWrapper)) , m_loaderProxy(loaderProxy) - , m_mainWebSocketChannel(WebSocketChannel::create(toDocument(context), this)) + , m_mainWebSocketChannel(WebSocketChannel::create(downcast<Document>(context), *this, provider)) , m_taskMode(taskMode) { ASSERT(isMainThread()); @@ -168,52 +166,53 @@ void WorkerThreadableWebSocketChannel::Peer::connect(const URL& url, const Strin m_mainWebSocketChannel->connect(url, protocol); } -static void workerGlobalScopeDidSend(ScriptExecutionContext* context, PassRefPtr<ThreadableWebSocketChannelClientWrapper> workerClientWrapper, ThreadableWebSocketChannel::SendResult sendRequestResult) -{ - ASSERT_UNUSED(context, context->isWorkerGlobalScope()); - workerClientWrapper->setSendRequestResult(sendRequestResult); -} - void WorkerThreadableWebSocketChannel::Peer::send(const String& message) { ASSERT(isMainThread()); - if (!m_mainWebSocketChannel || !m_workerClientWrapper) + if (!m_mainWebSocketChannel) return; + ThreadableWebSocketChannel::SendResult sendRequestResult = m_mainWebSocketChannel->send(message); - m_loaderProxy.postTaskForModeToWorkerGlobalScope(createCallbackTask(&workerGlobalScopeDidSend, m_workerClientWrapper, sendRequestResult), m_taskMode); + m_loaderProxy.postTaskForModeToWorkerGlobalScope([workerClientWrapper = m_workerClientWrapper.copyRef(), sendRequestResult](ScriptExecutionContext&) mutable { + workerClientWrapper->setSendRequestResult(sendRequestResult); + }, m_taskMode); } void WorkerThreadableWebSocketChannel::Peer::send(const ArrayBuffer& binaryData) { ASSERT(isMainThread()); - if (!m_mainWebSocketChannel || !m_workerClientWrapper) + if (!m_mainWebSocketChannel) return; + ThreadableWebSocketChannel::SendResult sendRequestResult = m_mainWebSocketChannel->send(binaryData, 0, binaryData.byteLength()); - m_loaderProxy.postTaskForModeToWorkerGlobalScope(createCallbackTask(&workerGlobalScopeDidSend, m_workerClientWrapper, sendRequestResult), m_taskMode); + m_loaderProxy.postTaskForModeToWorkerGlobalScope([workerClientWrapper = m_workerClientWrapper.copyRef(), sendRequestResult](ScriptExecutionContext&) mutable { + workerClientWrapper->setSendRequestResult(sendRequestResult); + }, m_taskMode); } -void WorkerThreadableWebSocketChannel::Peer::send(const Blob& binaryData) +void WorkerThreadableWebSocketChannel::Peer::send(Blob& binaryData) { ASSERT(isMainThread()); - if (!m_mainWebSocketChannel || !m_workerClientWrapper) + if (!m_mainWebSocketChannel) return; - ThreadableWebSocketChannel::SendResult sendRequestResult = m_mainWebSocketChannel->send(binaryData); - m_loaderProxy.postTaskForModeToWorkerGlobalScope(createCallbackTask(&workerGlobalScopeDidSend, m_workerClientWrapper, sendRequestResult), m_taskMode); -} -static void workerGlobalScopeDidGetBufferedAmount(ScriptExecutionContext* context, PassRefPtr<ThreadableWebSocketChannelClientWrapper> workerClientWrapper, unsigned long bufferedAmount) -{ - ASSERT_UNUSED(context, context->isWorkerGlobalScope()); - workerClientWrapper->setBufferedAmount(bufferedAmount); + ThreadableWebSocketChannel::SendResult sendRequestResult = m_mainWebSocketChannel->send(binaryData); + m_loaderProxy.postTaskForModeToWorkerGlobalScope([workerClientWrapper = m_workerClientWrapper.copyRef(), sendRequestResult](ScriptExecutionContext&) mutable { + workerClientWrapper->setSendRequestResult(sendRequestResult); + }, m_taskMode); } void WorkerThreadableWebSocketChannel::Peer::bufferedAmount() { ASSERT(isMainThread()); - if (!m_mainWebSocketChannel || !m_workerClientWrapper) + if (!m_mainWebSocketChannel) return; - unsigned long bufferedAmount = m_mainWebSocketChannel->bufferedAmount(); - m_loaderProxy.postTaskForModeToWorkerGlobalScope(createCallbackTask(&workerGlobalScopeDidGetBufferedAmount, m_workerClientWrapper, bufferedAmount), m_taskMode); + + unsigned bufferedAmount = m_mainWebSocketChannel->bufferedAmount(); + m_loaderProxy.postTaskForModeToWorkerGlobalScope([workerClientWrapper = m_workerClientWrapper.copyRef(), bufferedAmount](ScriptExecutionContext& context) mutable { + ASSERT_UNUSED(context, context.isWorkerGlobalScope()); + workerClientWrapper->setBufferedAmount(bufferedAmount); + }, m_taskMode); } void WorkerThreadableWebSocketChannel::Peer::close(int code, const String& reason) @@ -238,7 +237,7 @@ void WorkerThreadableWebSocketChannel::Peer::disconnect() if (!m_mainWebSocketChannel) return; m_mainWebSocketChannel->disconnect(); - m_mainWebSocketChannel = 0; + m_mainWebSocketChannel = nullptr; } void WorkerThreadableWebSocketChannel::Peer::suspend() @@ -257,101 +256,98 @@ void WorkerThreadableWebSocketChannel::Peer::resume() m_mainWebSocketChannel->resume(); } -static void workerGlobalScopeDidConnect(ScriptExecutionContext* context, PassRefPtr<ThreadableWebSocketChannelClientWrapper> workerClientWrapper, const String& subprotocol, const String& extensions) -{ - ASSERT_UNUSED(context, context->isWorkerGlobalScope()); - workerClientWrapper->setSubprotocol(subprotocol); - workerClientWrapper->setExtensions(extensions); - workerClientWrapper->didConnect(); -} - void WorkerThreadableWebSocketChannel::Peer::didConnect() { ASSERT(isMainThread()); - m_loaderProxy.postTaskForModeToWorkerGlobalScope(createCallbackTask(&workerGlobalScopeDidConnect, m_workerClientWrapper, m_mainWebSocketChannel->subprotocol(), m_mainWebSocketChannel->extensions()), m_taskMode); -} -static void workerGlobalScopeDidReceiveMessage(ScriptExecutionContext* context, PassRefPtr<ThreadableWebSocketChannelClientWrapper> workerClientWrapper, const String& message) -{ - ASSERT_UNUSED(context, context->isWorkerGlobalScope()); - workerClientWrapper->didReceiveMessage(message); + String subprotocol = m_mainWebSocketChannel->subprotocol(); + String extensions = m_mainWebSocketChannel->extensions(); + m_loaderProxy.postTaskForModeToWorkerGlobalScope([workerClientWrapper = m_workerClientWrapper.copyRef(), subprotocol = subprotocol.isolatedCopy(), extensions = extensions.isolatedCopy()](ScriptExecutionContext& context) mutable { + ASSERT_UNUSED(context, context.isWorkerGlobalScope()); + workerClientWrapper->setSubprotocol(subprotocol); + workerClientWrapper->setExtensions(extensions); + workerClientWrapper->didConnect(); + }, m_taskMode); } void WorkerThreadableWebSocketChannel::Peer::didReceiveMessage(const String& message) { ASSERT(isMainThread()); - m_loaderProxy.postTaskForModeToWorkerGlobalScope(createCallbackTask(&workerGlobalScopeDidReceiveMessage, m_workerClientWrapper, message), m_taskMode); -} -static void workerGlobalScopeDidReceiveBinaryData(ScriptExecutionContext* context, PassRefPtr<ThreadableWebSocketChannelClientWrapper> workerClientWrapper, PassOwnPtr<Vector<char>> binaryData) -{ - ASSERT_UNUSED(context, context->isWorkerGlobalScope()); - workerClientWrapper->didReceiveBinaryData(binaryData); + m_loaderProxy.postTaskForModeToWorkerGlobalScope([workerClientWrapper = m_workerClientWrapper.copyRef(), message = message.isolatedCopy()](ScriptExecutionContext& context) mutable { + ASSERT_UNUSED(context, context.isWorkerGlobalScope()); + workerClientWrapper->didReceiveMessage(message); + }, m_taskMode); } -void WorkerThreadableWebSocketChannel::Peer::didReceiveBinaryData(PassOwnPtr<Vector<char>> binaryData) +void WorkerThreadableWebSocketChannel::Peer::didReceiveBinaryData(Vector<uint8_t>&& binaryData) { ASSERT(isMainThread()); - m_loaderProxy.postTaskForModeToWorkerGlobalScope(createCallbackTask(&workerGlobalScopeDidReceiveBinaryData, m_workerClientWrapper, binaryData), m_taskMode); -} -static void workerGlobalScopeDidUpdateBufferedAmount(ScriptExecutionContext* context, PassRefPtr<ThreadableWebSocketChannelClientWrapper> workerClientWrapper, unsigned long bufferedAmount) -{ - ASSERT_UNUSED(context, context->isWorkerGlobalScope()); - workerClientWrapper->didUpdateBufferedAmount(bufferedAmount); + m_loaderProxy.postTaskForModeToWorkerGlobalScope([workerClientWrapper = m_workerClientWrapper.copyRef(), binaryData = WTFMove(binaryData)](ScriptExecutionContext& context) mutable { + ASSERT_UNUSED(context, context.isWorkerGlobalScope()); + workerClientWrapper->didReceiveBinaryData(WTFMove(binaryData)); + }, m_taskMode); } -void WorkerThreadableWebSocketChannel::Peer::didUpdateBufferedAmount(unsigned long bufferedAmount) +void WorkerThreadableWebSocketChannel::Peer::didUpdateBufferedAmount(unsigned bufferedAmount) { ASSERT(isMainThread()); - m_loaderProxy.postTaskForModeToWorkerGlobalScope(createCallbackTask(&workerGlobalScopeDidUpdateBufferedAmount, m_workerClientWrapper, bufferedAmount), m_taskMode); -} -static void workerGlobalScopeDidStartClosingHandshake(ScriptExecutionContext* context, PassRefPtr<ThreadableWebSocketChannelClientWrapper> workerClientWrapper) -{ - ASSERT_UNUSED(context, context->isWorkerGlobalScope()); - workerClientWrapper->didStartClosingHandshake(); + m_loaderProxy.postTaskForModeToWorkerGlobalScope([workerClientWrapper = m_workerClientWrapper.copyRef(), bufferedAmount](ScriptExecutionContext& context) mutable { + ASSERT_UNUSED(context, context.isWorkerGlobalScope()); + workerClientWrapper->didUpdateBufferedAmount(bufferedAmount); + }, m_taskMode); } void WorkerThreadableWebSocketChannel::Peer::didStartClosingHandshake() { ASSERT(isMainThread()); - m_loaderProxy.postTaskForModeToWorkerGlobalScope(createCallbackTask(&workerGlobalScopeDidStartClosingHandshake, m_workerClientWrapper), m_taskMode); -} -static void workerGlobalScopeDidClose(ScriptExecutionContext* context, PassRefPtr<ThreadableWebSocketChannelClientWrapper> workerClientWrapper, unsigned long unhandledBufferedAmount, WebSocketChannelClient::ClosingHandshakeCompletionStatus closingHandshakeCompletion, unsigned short code, const String& reason) -{ - ASSERT_UNUSED(context, context->isWorkerGlobalScope()); - workerClientWrapper->didClose(unhandledBufferedAmount, closingHandshakeCompletion, code, reason); + m_loaderProxy.postTaskForModeToWorkerGlobalScope([workerClientWrapper = m_workerClientWrapper.copyRef()](ScriptExecutionContext& context) mutable { + ASSERT_UNUSED(context, context.isWorkerGlobalScope()); + workerClientWrapper->didStartClosingHandshake(); + }, m_taskMode); } -void WorkerThreadableWebSocketChannel::Peer::didClose(unsigned long unhandledBufferedAmount, ClosingHandshakeCompletionStatus closingHandshakeCompletion, unsigned short code, const String& reason) +void WorkerThreadableWebSocketChannel::Peer::didClose(unsigned unhandledBufferedAmount, ClosingHandshakeCompletionStatus closingHandshakeCompletion, unsigned short code, const String& reason) { ASSERT(isMainThread()); - m_mainWebSocketChannel = 0; - m_loaderProxy.postTaskForModeToWorkerGlobalScope(createCallbackTask(&workerGlobalScopeDidClose, m_workerClientWrapper, unhandledBufferedAmount, closingHandshakeCompletion, code, reason), m_taskMode); + m_mainWebSocketChannel = nullptr; + + m_loaderProxy.postTaskForModeToWorkerGlobalScope([workerClientWrapper = m_workerClientWrapper.copyRef(), unhandledBufferedAmount, closingHandshakeCompletion, code, reason = reason.isolatedCopy()](ScriptExecutionContext& context) mutable { + ASSERT_UNUSED(context, context.isWorkerGlobalScope()); + workerClientWrapper->didClose(unhandledBufferedAmount, closingHandshakeCompletion, code, reason); + }, m_taskMode); } -static void workerGlobalScopeDidReceiveMessageError(ScriptExecutionContext* context, PassRefPtr<ThreadableWebSocketChannelClientWrapper> workerClientWrapper) +void WorkerThreadableWebSocketChannel::Peer::didReceiveMessageError() { - ASSERT_UNUSED(context, context->isWorkerGlobalScope()); - workerClientWrapper->didReceiveMessageError(); + ASSERT(isMainThread()); + + m_loaderProxy.postTaskForModeToWorkerGlobalScope([workerClientWrapper = m_workerClientWrapper.copyRef()](ScriptExecutionContext& context) mutable { + ASSERT_UNUSED(context, context.isWorkerGlobalScope()); + workerClientWrapper->didReceiveMessageError(); + }, m_taskMode); } -void WorkerThreadableWebSocketChannel::Peer::didReceiveMessageError() +void WorkerThreadableWebSocketChannel::Peer::didUpgradeURL() { ASSERT(isMainThread()); - m_loaderProxy.postTaskForModeToWorkerGlobalScope(createCallbackTask(&workerGlobalScopeDidReceiveMessageError, m_workerClientWrapper), m_taskMode); + + m_loaderProxy.postTaskForModeToWorkerGlobalScope([workerClientWrapper = m_workerClientWrapper.copyRef()](ScriptExecutionContext& context) mutable { + ASSERT_UNUSED(context, context.isWorkerGlobalScope()); + workerClientWrapper->didUpgradeURL(); + }, m_taskMode); } -WorkerThreadableWebSocketChannel::Bridge::Bridge(PassRefPtr<ThreadableWebSocketChannelClientWrapper> workerClientWrapper, PassRefPtr<WorkerGlobalScope> workerGlobalScope, const String& taskMode) - : m_workerClientWrapper(workerClientWrapper) - , m_workerGlobalScope(workerGlobalScope) - , m_loaderProxy(m_workerGlobalScope->thread()->workerLoaderProxy()) +WorkerThreadableWebSocketChannel::Bridge::Bridge(Ref<ThreadableWebSocketChannelClientWrapper>&& workerClientWrapper, Ref<WorkerGlobalScope>&& workerGlobalScope, const String& taskMode, Ref<SocketProvider>&& socketProvider) + : m_workerClientWrapper(WTFMove(workerClientWrapper)) + , m_workerGlobalScope(WTFMove(workerGlobalScope)) + , m_loaderProxy(m_workerGlobalScope->thread().workerLoaderProxy()) , m_taskMode(taskMode) - , m_peer(0) + , m_socketProvider(WTFMove(socketProvider)) { - ASSERT(m_workerClientWrapper.get()); } WorkerThreadableWebSocketChannel::Bridge::~Bridge() @@ -359,273 +355,210 @@ WorkerThreadableWebSocketChannel::Bridge::~Bridge() disconnect(); } -class WorkerThreadableWebSocketChannel::WorkerGlobalScopeDidInitializeTask : public ScriptExecutionContext::Task { -public: - static PassOwnPtr<ScriptExecutionContext::Task> create(WorkerThreadableWebSocketChannel::Peer* peer, - WorkerLoaderProxy* loaderProxy, - PassRefPtr<ThreadableWebSocketChannelClientWrapper> workerClientWrapper) - { - return adoptPtr(new WorkerGlobalScopeDidInitializeTask(peer, loaderProxy, workerClientWrapper)); - } - - virtual ~WorkerGlobalScopeDidInitializeTask() { } - virtual void performTask(ScriptExecutionContext* context) override - { - ASSERT_UNUSED(context, context->isWorkerGlobalScope()); - if (m_workerClientWrapper->failedWebSocketChannelCreation()) { - // If Bridge::initialize() quitted earlier, we need to kick mainThreadDestroy() to delete the peer. - OwnPtr<WorkerThreadableWebSocketChannel::Peer> peer = adoptPtr(m_peer); - m_peer = 0; - m_loaderProxy->postTaskToLoader(createCallbackTask(&WorkerThreadableWebSocketChannel::mainThreadDestroy, peer.release())); - } else - m_workerClientWrapper->didCreateWebSocketChannel(m_peer); - } - virtual bool isCleanupTask() const override { return true; } - -private: - WorkerGlobalScopeDidInitializeTask(WorkerThreadableWebSocketChannel::Peer* peer, - WorkerLoaderProxy* loaderProxy, - PassRefPtr<ThreadableWebSocketChannelClientWrapper> workerClientWrapper) - : m_peer(peer) - , m_loaderProxy(loaderProxy) - , m_workerClientWrapper(workerClientWrapper) - { - } - - WorkerThreadableWebSocketChannel::Peer* m_peer; - WorkerLoaderProxy* m_loaderProxy; - RefPtr<ThreadableWebSocketChannelClientWrapper> m_workerClientWrapper; -}; - -void WorkerThreadableWebSocketChannel::Bridge::mainThreadInitialize(ScriptExecutionContext* context, WorkerLoaderProxy* loaderProxy, PassRefPtr<ThreadableWebSocketChannelClientWrapper> prpClientWrapper, const String& taskMode) +void WorkerThreadableWebSocketChannel::Bridge::mainThreadInitialize(ScriptExecutionContext& context, WorkerLoaderProxy& loaderProxy, Ref<ThreadableWebSocketChannelClientWrapper>&& clientWrapper, const String& taskMode, Ref<SocketProvider>&& provider) { ASSERT(isMainThread()); - ASSERT_UNUSED(context, context->isDocument()); - - RefPtr<ThreadableWebSocketChannelClientWrapper> clientWrapper = prpClientWrapper; - - Peer* peer = Peer::create(clientWrapper, *loaderProxy, context, taskMode); - bool sent = loaderProxy->postTaskForModeToWorkerGlobalScope( - WorkerThreadableWebSocketChannel::WorkerGlobalScopeDidInitializeTask::create(peer, loaderProxy, clientWrapper), taskMode); - if (!sent) { + ASSERT(context.isDocument()); + + bool sent = loaderProxy.postTaskForModeToWorkerGlobalScope({ + ScriptExecutionContext::Task::CleanupTask, + [clientWrapper = clientWrapper.copyRef(), &loaderProxy, peer = std::make_unique<Peer>(clientWrapper.copyRef(), loaderProxy, context, taskMode, WTFMove(provider))](ScriptExecutionContext& context) mutable { + ASSERT_UNUSED(context, context.isWorkerGlobalScope()); + if (clientWrapper->failedWebSocketChannelCreation()) { + // If Bridge::initialize() quitted earlier, we need to kick mainThreadDestroy() to delete the peer. + loaderProxy.postTaskToLoader([peer = WTFMove(peer)](ScriptExecutionContext& context) { + ASSERT(isMainThread()); + ASSERT_UNUSED(context, context.isDocument()); + }); + } else + clientWrapper->didCreateWebSocketChannel(peer.release()); + } + }, taskMode); + + if (!sent) clientWrapper->clearPeer(); - delete peer; - } } void WorkerThreadableWebSocketChannel::Bridge::initialize() { ASSERT(!m_peer); setMethodNotCompleted(); - Ref<Bridge> protect(*this); - m_loaderProxy.postTaskToLoader( - createCallbackTask(&Bridge::mainThreadInitialize, - AllowCrossThreadAccess(&m_loaderProxy), m_workerClientWrapper, m_taskMode)); + Ref<Bridge> protectedThis(*this); + + m_loaderProxy.postTaskToLoader([&loaderProxy = m_loaderProxy, workerClientWrapper = m_workerClientWrapper.copyRef(), taskMode = m_taskMode.isolatedCopy(), provider = m_socketProvider.copyRef()](ScriptExecutionContext& context) mutable { + mainThreadInitialize(context, loaderProxy, WTFMove(workerClientWrapper), taskMode, WTFMove(provider)); + }); waitForMethodCompletion(); + // m_peer may be null when the nested runloop exited before a peer has created. m_peer = m_workerClientWrapper->peer(); if (!m_peer) m_workerClientWrapper->setFailedWebSocketChannelCreation(); } -void WorkerThreadableWebSocketChannel::mainThreadConnect(ScriptExecutionContext* context, Peer* peer, const URL& url, const String& protocol) -{ - ASSERT(isMainThread()); - ASSERT_UNUSED(context, context->isDocument()); - ASSERT(peer); - - peer->connect(url, protocol); -} - void WorkerThreadableWebSocketChannel::Bridge::connect(const URL& url, const String& protocol) { - ASSERT(m_workerClientWrapper); if (!m_peer) return; - m_loaderProxy.postTaskToLoader(createCallbackTask(&WorkerThreadableWebSocketChannel::mainThreadConnect, AllowCrossThreadAccess(m_peer), url, protocol)); -} -void WorkerThreadableWebSocketChannel::mainThreadSend(ScriptExecutionContext* context, Peer* peer, const String& message) -{ - ASSERT(isMainThread()); - ASSERT_UNUSED(context, context->isDocument()); - ASSERT(peer); + m_loaderProxy.postTaskToLoader([peer = m_peer, url = url.isolatedCopy(), protocol = protocol.isolatedCopy()](ScriptExecutionContext& context) { + ASSERT(isMainThread()); + ASSERT_UNUSED(context, context.isDocument()); + ASSERT(peer); - peer->send(message); -} - -void WorkerThreadableWebSocketChannel::mainThreadSendArrayBuffer(ScriptExecutionContext* context, Peer* peer, PassOwnPtr<Vector<char>> data) -{ - ASSERT(isMainThread()); - ASSERT_UNUSED(context, context->isDocument()); - ASSERT(peer); - - RefPtr<ArrayBuffer> arrayBuffer = ArrayBuffer::create(data->data(), data->size()); - peer->send(*arrayBuffer); -} - -void WorkerThreadableWebSocketChannel::mainThreadSendBlob(ScriptExecutionContext* context, Peer* peer, const URL& url, const String& type, long long size) -{ - ASSERT(isMainThread()); - ASSERT_UNUSED(context, context->isDocument()); - ASSERT(peer); - - RefPtr<Blob> blob = Blob::create(url, type, size); - peer->send(*blob); + peer->connect(url, protocol); + }); } ThreadableWebSocketChannel::SendResult WorkerThreadableWebSocketChannel::Bridge::send(const String& message) { - if (!m_workerClientWrapper || !m_peer) + if (!m_peer) return ThreadableWebSocketChannel::SendFail; setMethodNotCompleted(); - m_loaderProxy.postTaskToLoader(createCallbackTask(&WorkerThreadableWebSocketChannel::mainThreadSend, AllowCrossThreadAccess(m_peer), message)); - Ref<Bridge> protect(*this); + + m_loaderProxy.postTaskToLoader([peer = m_peer, message = message.isolatedCopy()](ScriptExecutionContext& context) { + ASSERT(isMainThread()); + ASSERT_UNUSED(context, context.isDocument()); + ASSERT(peer); + + peer->send(message); + }); + + Ref<Bridge> protectedThis(*this); waitForMethodCompletion(); - ThreadableWebSocketChannelClientWrapper* clientWrapper = m_workerClientWrapper.get(); - if (!clientWrapper) - return ThreadableWebSocketChannel::SendFail; - return clientWrapper->sendRequestResult(); + return m_workerClientWrapper->sendRequestResult(); } ThreadableWebSocketChannel::SendResult WorkerThreadableWebSocketChannel::Bridge::send(const ArrayBuffer& binaryData, unsigned byteOffset, unsigned byteLength) { - if (!m_workerClientWrapper || !m_peer) + if (!m_peer) return ThreadableWebSocketChannel::SendFail; + // ArrayBuffer isn't thread-safe, hence the content of ArrayBuffer is copied into Vector<char>. - OwnPtr<Vector<char>> data = adoptPtr(new Vector<char>(byteLength)); + Vector<char> data(byteLength); if (binaryData.byteLength()) - memcpy(data->data(), static_cast<const char*>(binaryData.data()) + byteOffset, byteLength); + memcpy(data.data(), static_cast<const char*>(binaryData.data()) + byteOffset, byteLength); setMethodNotCompleted(); - m_loaderProxy.postTaskToLoader(createCallbackTask(&WorkerThreadableWebSocketChannel::mainThreadSendArrayBuffer, AllowCrossThreadAccess(m_peer), data.release())); - Ref<Bridge> protect(*this); + + m_loaderProxy.postTaskToLoader([peer = m_peer, data = WTFMove(data)](ScriptExecutionContext& context) { + ASSERT(isMainThread()); + ASSERT_UNUSED(context, context.isDocument()); + ASSERT(peer); + + auto arrayBuffer = ArrayBuffer::create(data.data(), data.size()); + peer->send(arrayBuffer); + }); + + Ref<Bridge> protectedThis(*this); waitForMethodCompletion(); - ThreadableWebSocketChannelClientWrapper* clientWrapper = m_workerClientWrapper.get(); - if (!clientWrapper) - return ThreadableWebSocketChannel::SendFail; - return clientWrapper->sendRequestResult(); + return m_workerClientWrapper->sendRequestResult(); } -ThreadableWebSocketChannel::SendResult WorkerThreadableWebSocketChannel::Bridge::send(const Blob& binaryData) +ThreadableWebSocketChannel::SendResult WorkerThreadableWebSocketChannel::Bridge::send(Blob& binaryData) { - if (!m_workerClientWrapper || !m_peer) + if (!m_peer) return ThreadableWebSocketChannel::SendFail; setMethodNotCompleted(); - m_loaderProxy.postTaskToLoader(createCallbackTask(&WorkerThreadableWebSocketChannel::mainThreadSendBlob, AllowCrossThreadAccess(m_peer), binaryData.url(), binaryData.type(), binaryData.size())); - Ref<Bridge> protect(*this); - waitForMethodCompletion(); - ThreadableWebSocketChannelClientWrapper* clientWrapper = m_workerClientWrapper.get(); - if (!clientWrapper) - return ThreadableWebSocketChannel::SendFail; - return clientWrapper->sendRequestResult(); -} -void WorkerThreadableWebSocketChannel::mainThreadBufferedAmount(ScriptExecutionContext* context, Peer* peer) -{ - ASSERT(isMainThread()); - ASSERT_UNUSED(context, context->isDocument()); - ASSERT(peer); + m_loaderProxy.postTaskToLoader([peer = m_peer, url = binaryData.url().isolatedCopy(), type = binaryData.type().isolatedCopy(), size = binaryData.size()](ScriptExecutionContext& context) { + ASSERT(isMainThread()); + ASSERT_UNUSED(context, context.isDocument()); + ASSERT(peer); - peer->bufferedAmount(); + peer->send(Blob::deserialize(url, type, size, { })); + }); + + Ref<Bridge> protectedThis(*this); + waitForMethodCompletion(); + return m_workerClientWrapper->sendRequestResult(); } -unsigned long WorkerThreadableWebSocketChannel::Bridge::bufferedAmount() +unsigned WorkerThreadableWebSocketChannel::Bridge::bufferedAmount() { - if (!m_workerClientWrapper || !m_peer) + if (!m_peer) return 0; setMethodNotCompleted(); - m_loaderProxy.postTaskToLoader(createCallbackTask(&WorkerThreadableWebSocketChannel::mainThreadBufferedAmount, AllowCrossThreadAccess(m_peer))); - Ref<Bridge> protect(*this); - waitForMethodCompletion(); - ThreadableWebSocketChannelClientWrapper* clientWrapper = m_workerClientWrapper.get(); - if (clientWrapper) - return clientWrapper->bufferedAmount(); - return 0; -} -void WorkerThreadableWebSocketChannel::mainThreadClose(ScriptExecutionContext* context, Peer* peer, int code, const String& reason) -{ - ASSERT(isMainThread()); - ASSERT_UNUSED(context, context->isDocument()); - ASSERT(peer); + m_loaderProxy.postTaskToLoader([peer = m_peer](ScriptExecutionContext& context) { + ASSERT(isMainThread()); + ASSERT_UNUSED(context, context.isDocument()); + ASSERT(peer); - peer->close(code, reason); + peer->bufferedAmount(); + }); + + Ref<Bridge> protectedThis(*this); + waitForMethodCompletion(); + return m_workerClientWrapper->bufferedAmount(); } void WorkerThreadableWebSocketChannel::Bridge::close(int code, const String& reason) { if (!m_peer) return; - m_loaderProxy.postTaskToLoader(createCallbackTask(&WorkerThreadableWebSocketChannel::mainThreadClose, AllowCrossThreadAccess(m_peer), code, reason)); -} -void WorkerThreadableWebSocketChannel::mainThreadFail(ScriptExecutionContext* context, Peer* peer, const String& reason) -{ - ASSERT(isMainThread()); - ASSERT_UNUSED(context, context->isDocument()); - ASSERT(peer); + m_loaderProxy.postTaskToLoader([peer = m_peer, code, reason = reason.isolatedCopy()](ScriptExecutionContext& context) { + ASSERT(isMainThread()); + ASSERT_UNUSED(context, context.isDocument()); + ASSERT(peer); - peer->fail(reason); + peer->close(code, reason); + }); } void WorkerThreadableWebSocketChannel::Bridge::fail(const String& reason) { if (!m_peer) return; - m_loaderProxy.postTaskToLoader(createCallbackTask(&WorkerThreadableWebSocketChannel::mainThreadFail, AllowCrossThreadAccess(m_peer), reason)); -} -void WorkerThreadableWebSocketChannel::mainThreadDestroy(ScriptExecutionContext* context, PassOwnPtr<Peer> peer) -{ - ASSERT(isMainThread()); - ASSERT_UNUSED(context, context->isDocument()); - ASSERT_UNUSED(peer, peer); + m_loaderProxy.postTaskToLoader([peer = m_peer, reason = reason.isolatedCopy()](ScriptExecutionContext& context) { + ASSERT(isMainThread()); + ASSERT_UNUSED(context, context.isDocument()); + ASSERT(peer); - // Peer object will be deleted even if the task does not run in the main thread's cleanup period, because - // the destructor for the task object (created by createCallbackTask()) will automatically delete the peer. + peer->fail(reason); + }); } void WorkerThreadableWebSocketChannel::Bridge::disconnect() { clearClientWrapper(); if (m_peer) { - OwnPtr<Peer> peer = adoptPtr(m_peer); - m_peer = 0; - m_loaderProxy.postTaskToLoader(createCallbackTask(&WorkerThreadableWebSocketChannel::mainThreadDestroy, peer.release())); + m_loaderProxy.postTaskToLoader([peer = std::unique_ptr<Peer>(m_peer)](ScriptExecutionContext& context) { + ASSERT(isMainThread()); + ASSERT_UNUSED(context, context.isDocument()); + }); + m_peer = nullptr; } - m_workerGlobalScope = 0; -} - -void WorkerThreadableWebSocketChannel::mainThreadSuspend(ScriptExecutionContext* context, Peer* peer) -{ - ASSERT(isMainThread()); - ASSERT_UNUSED(context, context->isDocument()); - ASSERT(peer); - - peer->suspend(); + m_workerGlobalScope = nullptr; } void WorkerThreadableWebSocketChannel::Bridge::suspend() { if (!m_peer) return; - m_loaderProxy.postTaskToLoader(createCallbackTask(&WorkerThreadableWebSocketChannel::mainThreadSuspend, AllowCrossThreadAccess(m_peer))); -} -void WorkerThreadableWebSocketChannel::mainThreadResume(ScriptExecutionContext* context, Peer* peer) -{ - ASSERT(isMainThread()); - ASSERT_UNUSED(context, context->isDocument()); - ASSERT(peer); + m_loaderProxy.postTaskToLoader([peer = m_peer](ScriptExecutionContext& context) { + ASSERT(isMainThread()); + ASSERT_UNUSED(context, context.isDocument()); + ASSERT(peer); - peer->resume(); + peer->suspend(); + }); } void WorkerThreadableWebSocketChannel::Bridge::resume() { if (!m_peer) return; - m_loaderProxy.postTaskToLoader(createCallbackTask(&WorkerThreadableWebSocketChannel::mainThreadResume, AllowCrossThreadAccess(m_peer))); + + m_loaderProxy.postTaskToLoader([peer = m_peer](ScriptExecutionContext& context) { + ASSERT(isMainThread()); + ASSERT_UNUSED(context, context.isDocument()); + ASSERT(peer); + + peer->resume(); + }); } void WorkerThreadableWebSocketChannel::Bridge::clearClientWrapper() @@ -635,7 +568,6 @@ void WorkerThreadableWebSocketChannel::Bridge::clearClientWrapper() void WorkerThreadableWebSocketChannel::Bridge::setMethodNotCompleted() { - ASSERT(m_workerClientWrapper); m_workerClientWrapper->clearSyncMethodDone(); } @@ -645,12 +577,12 @@ void WorkerThreadableWebSocketChannel::Bridge::waitForMethodCompletion() { if (!m_workerGlobalScope) return; - WorkerRunLoop& runLoop = m_workerGlobalScope->thread()->runLoop(); + WorkerRunLoop& runLoop = m_workerGlobalScope->thread().runLoop(); MessageQueueWaitResult result = MessageQueueMessageReceived; - ThreadableWebSocketChannelClientWrapper* clientWrapper = m_workerClientWrapper.get(); + ThreadableWebSocketChannelClientWrapper* clientWrapper = m_workerClientWrapper.ptr(); while (m_workerGlobalScope && clientWrapper && !clientWrapper->syncMethodDone() && result != MessageQueueTerminated) { result = runLoop.runInMode(m_workerGlobalScope.get(), m_taskMode); // May cause this bridge to get disconnected, which makes m_workerGlobalScope become null. - clientWrapper = m_workerClientWrapper.get(); + clientWrapper = m_workerClientWrapper.ptr(); } } diff --git a/Source/WebCore/Modules/websockets/WorkerThreadableWebSocketChannel.h b/Source/WebCore/Modules/websockets/WorkerThreadableWebSocketChannel.h index 90b0ea07e..be90d12f1 100644 --- a/Source/WebCore/Modules/websockets/WorkerThreadableWebSocketChannel.h +++ b/Source/WebCore/Modules/websockets/WorkerThreadableWebSocketChannel.h @@ -28,8 +28,7 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ -#ifndef WorkerThreadableWebSocketChannel_h -#define WorkerThreadableWebSocketChannel_h +#pragma once #if ENABLE(WEB_SOCKETS) @@ -37,7 +36,6 @@ #include "WebSocketChannelClient.h" #include "WorkerGlobalScope.h" -#include <wtf/PassRefPtr.h> #include <wtf/RefCounted.h> #include <wtf/RefPtr.h> #include <wtf/Threading.h> @@ -55,41 +53,38 @@ class WorkerRunLoop; class WorkerThreadableWebSocketChannel : public RefCounted<WorkerThreadableWebSocketChannel>, public ThreadableWebSocketChannel { WTF_MAKE_FAST_ALLOCATED; public: - static PassRefPtr<ThreadableWebSocketChannel> create(WorkerGlobalScope* workerGlobalScope, WebSocketChannelClient* client, const String& taskMode) + static Ref<ThreadableWebSocketChannel> create(WorkerGlobalScope& workerGlobalScope, WebSocketChannelClient& client, const String& taskMode, SocketProvider& provider) { - return adoptRef(new WorkerThreadableWebSocketChannel(workerGlobalScope, client, taskMode)); + return adoptRef(*new WorkerThreadableWebSocketChannel(workerGlobalScope, client, taskMode, provider)); } virtual ~WorkerThreadableWebSocketChannel(); // ThreadableWebSocketChannel functions. - virtual void connect(const URL&, const String& protocol) override; - virtual String subprotocol() override; - virtual String extensions() override; - virtual ThreadableWebSocketChannel::SendResult send(const String& message) override; - virtual ThreadableWebSocketChannel::SendResult send(const JSC::ArrayBuffer&, unsigned byteOffset, unsigned byteLength) override; - virtual ThreadableWebSocketChannel::SendResult send(const Blob&) override; - virtual unsigned long bufferedAmount() const override; - virtual void close(int code, const String& reason) override; - virtual void fail(const String& reason) override; - virtual void disconnect() override; // Will suppress didClose(). - virtual void suspend() override; - virtual void resume() override; + void connect(const URL&, const String& protocol) override; + String subprotocol() override; + String extensions() override; + ThreadableWebSocketChannel::SendResult send(const String& message) override; + ThreadableWebSocketChannel::SendResult send(const JSC::ArrayBuffer&, unsigned byteOffset, unsigned byteLength) override; + ThreadableWebSocketChannel::SendResult send(Blob&) override; + unsigned bufferedAmount() const override; + void close(int code, const String& reason) override; + void fail(const String& reason) override; + void disconnect() override; // Will suppress didClose(). + void suspend() override; + void resume() override; // Generated by the bridge. The Peer and its bridge should have identical // lifetimes. class Peer : public WebSocketChannelClient { WTF_MAKE_NONCOPYABLE(Peer); WTF_MAKE_FAST_ALLOCATED; public: - static Peer* create(PassRefPtr<ThreadableWebSocketChannelClientWrapper> clientWrapper, WorkerLoaderProxy& loaderProxy, ScriptExecutionContext* context, const String& taskMode) - { - return new Peer(clientWrapper, loaderProxy, context, taskMode); - } + Peer(Ref<ThreadableWebSocketChannelClientWrapper>&&, WorkerLoaderProxy&, ScriptExecutionContext&, const String& taskMode, SocketProvider&); ~Peer(); void connect(const URL&, const String& protocol); void send(const String& message); void send(const JSC::ArrayBuffer&); - void send(const Blob&); + void send(Blob&); void bufferedAmount(); void close(int code, const String& reason); void fail(const String& reason); @@ -98,18 +93,17 @@ public: void resume(); // WebSocketChannelClient functions. - virtual void didConnect() override; - virtual void didReceiveMessage(const String& message) override; - virtual void didReceiveBinaryData(PassOwnPtr<Vector<char>>) override; - virtual void didUpdateBufferedAmount(unsigned long bufferedAmount) override; - virtual void didStartClosingHandshake() override; - virtual void didClose(unsigned long unhandledBufferedAmount, ClosingHandshakeCompletionStatus, unsigned short code, const String& reason) override; - virtual void didReceiveMessageError() override; + void didConnect() final; + void didReceiveMessage(const String& message) final; + void didReceiveBinaryData(Vector<uint8_t>&&) final; + void didUpdateBufferedAmount(unsigned bufferedAmount) final; + void didStartClosingHandshake() final; + void didClose(unsigned unhandledBufferedAmount, ClosingHandshakeCompletionStatus, unsigned short code, const String& reason) final; + void didReceiveMessageError() final; + void didUpgradeURL() final; private: - Peer(PassRefPtr<ThreadableWebSocketChannelClientWrapper>, WorkerLoaderProxy&, ScriptExecutionContext*, const String& taskMode); - - RefPtr<ThreadableWebSocketChannelClientWrapper> m_workerClientWrapper; + Ref<ThreadableWebSocketChannelClientWrapper> m_workerClientWrapper; WorkerLoaderProxy& m_loaderProxy; RefPtr<ThreadableWebSocketChannel> m_mainWebSocketChannel; String m_taskMode; @@ -119,24 +113,24 @@ public: using RefCounted<WorkerThreadableWebSocketChannel>::deref; protected: - virtual void refThreadableWebSocketChannel() { ref(); } - virtual void derefThreadableWebSocketChannel() { deref(); } + void refThreadableWebSocketChannel() override { ref(); } + void derefThreadableWebSocketChannel() override { deref(); } private: // Bridge for Peer. Running on the worker thread. class Bridge : public RefCounted<Bridge> { public: - static PassRefPtr<Bridge> create(PassRefPtr<ThreadableWebSocketChannelClientWrapper> workerClientWrapper, PassRefPtr<WorkerGlobalScope> workerGlobalScope, const String& taskMode) + static Ref<Bridge> create(Ref<ThreadableWebSocketChannelClientWrapper>&& workerClientWrapper, Ref<WorkerGlobalScope>&& workerGlobalScope, const String& taskMode, Ref<SocketProvider>&& provider) { - return adoptRef(new Bridge(workerClientWrapper, workerGlobalScope, taskMode)); + return adoptRef(*new Bridge(WTFMove(workerClientWrapper), WTFMove(workerGlobalScope), taskMode, WTFMove(provider))); } ~Bridge(); void initialize(); void connect(const URL&, const String& protocol); ThreadableWebSocketChannel::SendResult send(const String& message); ThreadableWebSocketChannel::SendResult send(const JSC::ArrayBuffer&, unsigned byteOffset, unsigned byteLength); - ThreadableWebSocketChannel::SendResult send(const Blob&); - unsigned long bufferedAmount(); + ThreadableWebSocketChannel::SendResult send(Blob&); + unsigned bufferedAmount(); void close(int code, const String& reason); void fail(const String& reason); void disconnect(); @@ -147,12 +141,12 @@ private: using RefCounted<Bridge>::deref; private: - Bridge(PassRefPtr<ThreadableWebSocketChannelClientWrapper>, PassRefPtr<WorkerGlobalScope>, const String& taskMode); + Bridge(Ref<ThreadableWebSocketChannelClientWrapper>&&, Ref<WorkerGlobalScope>&&, const String& taskMode, Ref<SocketProvider>&&); - static void setWebSocketChannel(ScriptExecutionContext*, Bridge* thisPtr, Peer*, PassRefPtr<ThreadableWebSocketChannelClientWrapper>); + static void setWebSocketChannel(ScriptExecutionContext*, Bridge* thisPtr, Peer*, Ref<ThreadableWebSocketChannelClientWrapper>&&); // Executed on the main thread to create a Peer for this bridge. - static void mainThreadInitialize(ScriptExecutionContext*, WorkerLoaderProxy*, PassRefPtr<ThreadableWebSocketChannelClientWrapper>, const String& taskMode); + static void mainThreadInitialize(ScriptExecutionContext&, WorkerLoaderProxy&, Ref<ThreadableWebSocketChannelClientWrapper>&&, const String& taskMode, Ref<SocketProvider>&&); // Executed on the worker context's thread. void clearClientWrapper(); @@ -160,35 +154,24 @@ private: void setMethodNotCompleted(); void waitForMethodCompletion(); - RefPtr<ThreadableWebSocketChannelClientWrapper> m_workerClientWrapper; + Ref<ThreadableWebSocketChannelClientWrapper> m_workerClientWrapper; RefPtr<WorkerGlobalScope> m_workerGlobalScope; WorkerLoaderProxy& m_loaderProxy; String m_taskMode; - Peer* m_peer; + Peer* m_peer { nullptr }; + Ref<SocketProvider> m_socketProvider; }; - WorkerThreadableWebSocketChannel(WorkerGlobalScope*, WebSocketChannelClient*, const String& taskMode); - - static void mainThreadConnect(ScriptExecutionContext*, Peer*, const URL&, const String& protocol); - static void mainThreadSend(ScriptExecutionContext*, Peer*, const String& message); - static void mainThreadSendArrayBuffer(ScriptExecutionContext*, Peer*, PassOwnPtr<Vector<char>>); - static void mainThreadSendBlob(ScriptExecutionContext*, Peer*, const URL&, const String& type, long long size); - static void mainThreadBufferedAmount(ScriptExecutionContext*, Peer*); - static void mainThreadClose(ScriptExecutionContext*, Peer*, int code, const String& reason); - static void mainThreadFail(ScriptExecutionContext*, Peer*, const String& reason); - static void mainThreadDestroy(ScriptExecutionContext*, PassOwnPtr<Peer>); - static void mainThreadSuspend(ScriptExecutionContext*, Peer*); - static void mainThreadResume(ScriptExecutionContext*, Peer*); + WEBCORE_EXPORT WorkerThreadableWebSocketChannel(WorkerGlobalScope&, WebSocketChannelClient&, const String& taskMode, SocketProvider&); class WorkerGlobalScopeDidInitializeTask; - RefPtr<WorkerGlobalScope> m_workerGlobalScope; - RefPtr<ThreadableWebSocketChannelClientWrapper> m_workerClientWrapper; + Ref<WorkerGlobalScope> m_workerGlobalScope; + Ref<ThreadableWebSocketChannelClientWrapper> m_workerClientWrapper; RefPtr<Bridge> m_bridge; + Ref<SocketProvider> m_socketProvider; }; } // namespace WebCore #endif // ENABLE(WEB_SOCKETS) - -#endif // WorkerThreadableWebSocketChannel_h |