diff options
Diffstat (limited to 'tests/auto/websockets/qwebsocket/tst_qwebsocket.cpp')
-rw-r--r-- | tests/auto/websockets/qwebsocket/tst_qwebsocket.cpp | 698 |
1 files changed, 602 insertions, 96 deletions
diff --git a/tests/auto/websockets/qwebsocket/tst_qwebsocket.cpp b/tests/auto/websockets/qwebsocket/tst_qwebsocket.cpp index 329419c..6881afb 100644 --- a/tests/auto/websockets/qwebsocket/tst_qwebsocket.cpp +++ b/tests/auto/websockets/qwebsocket/tst_qwebsocket.cpp @@ -1,40 +1,32 @@ -/**************************************************************************** -** -** Copyright (C) 2016 Kurt Pattyn <pattyn.kurt@gmail.com>. -** Contact: https://www.qt.io/licensing/ -** -** This file is part of the test suite of the Qt Toolkit. -** -** $QT_BEGIN_LICENSE:GPL-EXCEPT$ -** Commercial License Usage -** Licensees holding valid commercial Qt licenses may use this file in -** accordance with the commercial license agreement provided with the -** Software or, alternatively, in accordance with the terms contained in -** a written agreement between you and The Qt Company. For licensing terms -** and conditions see https://www.qt.io/terms-conditions. For further -** information use the contact form at https://www.qt.io/contact-us. -** -** GNU General Public License Usage -** Alternatively, this file may be used under the terms of the GNU -** General Public License version 3 as published by the Free Software -** Foundation with exceptions as appearing in the file LICENSE.GPL3-EXCEPT -** included in the packaging of this file. Please review the following -** information to ensure the GNU General Public License requirements will -** be met: https://www.gnu.org/licenses/gpl-3.0.html. -** -** $QT_END_LICENSE$ -** -****************************************************************************/ +// Copyright (C) 2016 Kurt Pattyn <pattyn.kurt@gmail.com>. +// SPDX-License-Identifier: LicenseRef-Qt-Commercial OR GPL-3.0-only WITH Qt-GPL-exception-1.0 +#include <QRegularExpression> #include <QString> #include <QtTest> #include <QtWebSockets/QWebSocket> +#include <QtWebSockets/QWebSocketHandshakeOptions> +#include <QtWebSockets/QWebSocketCorsAuthenticator> #include <QtWebSockets/QWebSocketServer> #include <QtWebSockets/qwebsocketprotocol.h> +#include <QtNetwork/qtcpserver.h> +#include <QtNetwork/qauthenticator.h> +#include <QtNetwork/qtcpsocket.h> + +#if QT_CONFIG(ssl) +#include <QtNetwork/qsslserver.h> +#include <QtNetwork/qsslcertificate.h> +#include <QtNetwork/qsslkey.h> +#endif + +#include <utility> + QT_USE_NAMESPACE Q_DECLARE_METATYPE(QWebSocketProtocol::Version) +using namespace Qt::StringLiterals; + class EchoServer : public QObject { Q_OBJECT @@ -50,6 +42,7 @@ public: Q_SIGNALS: void newConnection(QUrl requestUrl); void newConnection(QNetworkRequest request); + void originAuthenticationRequired(QWebSocketCorsAuthenticator* pAuthenticator); private Q_SLOTS: void onNewConnection(); @@ -72,9 +65,13 @@ EchoServer::EchoServer(QObject *parent, quint64 maxAllowedIncomingMessageSize, q m_maxAllowedIncomingFrameSize(maxAllowedIncomingFrameSize), m_clients() { + m_pWebSocketServer->setSupportedSubprotocols({ QStringLiteral("protocol1"), + QStringLiteral("protocol2") }); if (m_pWebSocketServer->listen(QHostAddress(QStringLiteral("127.0.0.1")))) { connect(m_pWebSocketServer, SIGNAL(newConnection()), this, SLOT(onNewConnection())); + connect(m_pWebSocketServer, &QWebSocketServer::originAuthenticationRequired, + this, &EchoServer::originAuthenticationRequired); } } @@ -146,16 +143,23 @@ private Q_SLOTS: void tst_sendTextMessage(); void tst_sendBinaryMessage(); void tst_errorString(); + void tst_openRequest_data(); void tst_openRequest(); + void tst_protocolAccessor(); + void protocolsHeaderGeneration_data(); + void protocolsHeaderGeneration(); void tst_moveToThread(); void tst_moveToThreadNoWarning(); #ifndef QT_NO_NETWORKPROXY void tst_setProxy(); #endif + void authenticationRequired_data(); + void authenticationRequired(); void overlongCloseReason(); void incomingMessageTooLong(); void incomingFrameTooLong(); void testingFrameAndMessageSizeApi(); + void customHeader(); }; tst_QWebSocket::tst_QWebSocket() @@ -324,17 +328,17 @@ void tst_QWebSocket::tst_invalidOpen() QCOMPARE(socket.sendTextMessage(QStringLiteral("A text message")), 0); QCOMPARE(socket.sendBinaryMessage(QByteArrayLiteral("A text message")), 0); - if (errorSpy.count() == 0) + if (errorSpy.size() == 0) QVERIFY(errorSpy.wait()); - QCOMPARE(errorSpy.count(), 1); + QCOMPARE(errorSpy.size(), 1); QList<QVariant> arguments = errorSpy.takeFirst(); QAbstractSocket::SocketError socketError = qvariant_cast<QAbstractSocket::SocketError>(arguments.at(0)); QCOMPARE(socketError, QAbstractSocket::ConnectionRefusedError); - QCOMPARE(aboutToCloseSpy.count(), 0); - QCOMPARE(connectedSpy.count(), 0); - QCOMPARE(disconnectedSpy.count(), disconnectedCount); - QCOMPARE(stateChangedSpy.count(), stateChangedCount); + QCOMPARE(aboutToCloseSpy.size(), 0); + QCOMPARE(connectedSpy.size(), 0); + QCOMPARE(disconnectedSpy.size(), disconnectedCount); + QCOMPARE(stateChangedSpy.size(), stateChangedCount); if (stateChangedCount == 2) { arguments = stateChangedSpy.takeFirst(); QAbstractSocket::SocketState socketState = @@ -343,13 +347,13 @@ void tst_QWebSocket::tst_invalidOpen() socketState = qvariant_cast<QAbstractSocket::SocketState>(arguments.at(0)); QCOMPARE(socketState, QAbstractSocket::UnconnectedState); } - QCOMPARE(readChannelFinishedSpy.count(), 0); - QCOMPARE(textFrameReceivedSpy.count(), 0); - QCOMPARE(binaryFrameReceivedSpy.count(), 0); - QCOMPARE(textMessageReceivedSpy.count(), 0); - QCOMPARE(binaryMessageReceivedSpy.count(), 0); - QCOMPARE(pongSpy.count(), 0); - QCOMPARE(bytesWrittenSpy.count(), 0); + QCOMPARE(readChannelFinishedSpy.size(), 0); + QCOMPARE(textFrameReceivedSpy.size(), 0); + QCOMPARE(binaryFrameReceivedSpy.size(), 0); + QCOMPARE(textMessageReceivedSpy.size(), 0); + QCOMPARE(binaryMessageReceivedSpy.size(), 0); + QCOMPARE(pongSpy.size(), 0); + QCOMPARE(bytesWrittenSpy.size(), 0); } void tst_QWebSocket::tst_invalidOrigin() @@ -390,28 +394,28 @@ void tst_QWebSocket::tst_invalidOrigin() QVERIFY(errorSpy.wait()); - QCOMPARE(errorSpy.count(), 1); + QCOMPARE(errorSpy.size(), 1); QList<QVariant> arguments = errorSpy.takeFirst(); QAbstractSocket::SocketError socketError = qvariant_cast<QAbstractSocket::SocketError>(arguments.at(0)); QCOMPARE(socketError, QAbstractSocket::ConnectionRefusedError); - QCOMPARE(aboutToCloseSpy.count(), 0); - QCOMPARE(connectedSpy.count(), 0); - QCOMPARE(disconnectedSpy.count(), 1); - QCOMPARE(stateChangedSpy.count(), 2); //connectingstate, unconnectedstate + QCOMPARE(aboutToCloseSpy.size(), 0); + QCOMPARE(connectedSpy.size(), 0); + QCOMPARE(disconnectedSpy.size(), 1); + QCOMPARE(stateChangedSpy.size(), 2); //connectingstate, unconnectedstate arguments = stateChangedSpy.takeFirst(); QAbstractSocket::SocketState socketState = qvariant_cast<QAbstractSocket::SocketState>(arguments.at(0)); arguments = stateChangedSpy.takeFirst(); socketState = qvariant_cast<QAbstractSocket::SocketState>(arguments.at(0)); QCOMPARE(socketState, QAbstractSocket::UnconnectedState); - QCOMPARE(readChannelFinishedSpy.count(), 0); - QCOMPARE(textFrameReceivedSpy.count(), 0); - QCOMPARE(binaryFrameReceivedSpy.count(), 0); - QCOMPARE(textMessageReceivedSpy.count(), 0); - QCOMPARE(binaryMessageReceivedSpy.count(), 0); - QCOMPARE(pongSpy.count(), 0); - QCOMPARE(bytesWrittenSpy.count(), 0); + QCOMPARE(readChannelFinishedSpy.size(), 0); + QCOMPARE(textFrameReceivedSpy.size(), 0); + QCOMPARE(binaryFrameReceivedSpy.size(), 0); + QCOMPARE(textMessageReceivedSpy.size(), 0); + QCOMPARE(binaryMessageReceivedSpy.size(), 0); + QCOMPARE(pongSpy.size(), 0); + QCOMPARE(bytesWrittenSpy.size(), 0); } void tst_QWebSocket::tst_sendTextMessage() @@ -440,8 +444,8 @@ void tst_QWebSocket::tst_sendTextMessage() socket.open(url); - QTRY_COMPARE(socketConnectedSpy.count(), 1); - QCOMPARE(socketError.count(), 0); + QTRY_COMPARE(socketConnectedSpy.size(), 1); + QCOMPARE(socketError.size(), 0); QCOMPARE(socket.state(), QAbstractSocket::ConnectedState); QList<QVariant> arguments = serverConnectedSpy.takeFirst(); QUrl urlConnected = arguments.at(0).toUrl(); @@ -454,14 +458,14 @@ void tst_QWebSocket::tst_sendTextMessage() QVERIFY(textMessageReceived.wait(500)); QCOMPARE(socket.bytesToWrite(), 0); - QCOMPARE(textMessageReceived.count(), 1); - QCOMPARE(binaryMessageReceived.count(), 0); - QCOMPARE(binaryFrameReceived.count(), 0); + QCOMPARE(textMessageReceived.size(), 1); + QCOMPARE(binaryMessageReceived.size(), 0); + QCOMPARE(binaryFrameReceived.size(), 0); arguments = textMessageReceived.takeFirst(); QString messageReceived = arguments.at(0).toString(); QCOMPARE(messageReceived, QStringLiteral("Hello world!")); - QCOMPARE(textFrameReceived.count(), 1); + QCOMPARE(textFrameReceived.size(), 1); arguments = textFrameReceived.takeFirst(); QString frameReceived = arguments.at(0).toString(); bool isLastFrame = arguments.at(1).toBool(); @@ -476,8 +480,8 @@ void tst_QWebSocket::tst_sendTextMessage() // QTBUG-74464 QWebsocket doesn't receive text (binary) message with size > 32 kb socket.open(url); - QTRY_COMPARE(socketConnectedSpy.count(), 1); - QCOMPARE(socketError.count(), 0); + QTRY_COMPARE(socketConnectedSpy.size(), 1); + QCOMPARE(socketError.size(), 0); QCOMPARE(socket.state(), QAbstractSocket::ConnectedState); arguments = serverConnectedSpy.takeFirst(); urlConnected = arguments.at(0).toUrl(); @@ -487,16 +491,16 @@ void tst_QWebSocket::tst_sendTextMessage() // transmit a long text message with 1 MB QString longString(0x100000, 'a'); socket.sendTextMessage(longString); - QVERIFY(socket.bytesToWrite() > longString.length()); + QVERIFY(socket.bytesToWrite() > longString.size()); QVERIFY(textMessageReceived.wait()); QCOMPARE(socket.bytesToWrite(), 0); - QCOMPARE(textMessageReceived.count(), 1); - QCOMPARE(binaryMessageReceived.count(), 0); - QCOMPARE(binaryFrameReceived.count(), 0); + QCOMPARE(textMessageReceived.size(), 1); + QCOMPARE(binaryMessageReceived.size(), 0); + QCOMPARE(binaryFrameReceived.size(), 0); arguments = textMessageReceived.takeFirst(); messageReceived = arguments.at(0).toString(); - QCOMPARE(messageReceived.length(), longString.length()); + QCOMPARE(messageReceived.size(), longString.size()); QCOMPARE(messageReceived, longString); arguments = textFrameReceived.takeLast(); @@ -512,20 +516,20 @@ void tst_QWebSocket::tst_sendTextMessage() socket.open(QUrl(QStringLiteral("ws://") + echoServer.hostAddress().toString() + QStringLiteral(":") + QString::number(echoServer.port()))); - QTRY_COMPARE(socketConnectedSpy.count(), 1); + QTRY_COMPARE(socketConnectedSpy.size(), 1); QCOMPARE(socket.state(), QAbstractSocket::ConnectedState); socket.sendTextMessage(QStringLiteral("Hello world!")); QVERIFY(textMessageReceived.wait(500)); - QCOMPARE(textMessageReceived.count(), 1); - QCOMPARE(binaryMessageReceived.count(), 0); - QCOMPARE(binaryFrameReceived.count(), 0); + QCOMPARE(textMessageReceived.size(), 1); + QCOMPARE(binaryMessageReceived.size(), 0); + QCOMPARE(binaryFrameReceived.size(), 0); arguments = textMessageReceived.takeFirst(); messageReceived = arguments.at(0).toString(); QCOMPARE(messageReceived, QStringLiteral("Hello world!")); - QCOMPARE(textFrameReceived.count(), 1); + QCOMPARE(textFrameReceived.size(), 1); arguments = textFrameReceived.takeFirst(); frameReceived = arguments.at(0).toString(); isLastFrame = arguments.at(1).toBool(); @@ -556,7 +560,7 @@ void tst_QWebSocket::tst_sendBinaryMessage() socket.open(QUrl(QStringLiteral("ws://") + echoServer.hostAddress().toString() + QStringLiteral(":") + QString::number(echoServer.port()))); - QTRY_COMPARE(socketConnectedSpy.count(), 1); + QTRY_COMPARE(socketConnectedSpy.size(), 1); QCOMPARE(socket.state(), QAbstractSocket::ConnectedState); QCOMPARE(socket.bytesToWrite(), 0); @@ -566,14 +570,14 @@ void tst_QWebSocket::tst_sendBinaryMessage() QVERIFY(binaryMessageReceived.wait(500)); QCOMPARE(socket.bytesToWrite(), 0); - QCOMPARE(textMessageReceived.count(), 0); - QCOMPARE(textFrameReceived.count(), 0); - QCOMPARE(binaryMessageReceived.count(), 1); + QCOMPARE(textMessageReceived.size(), 0); + QCOMPARE(textFrameReceived.size(), 0); + QCOMPARE(binaryMessageReceived.size(), 1); QList<QVariant> arguments = binaryMessageReceived.takeFirst(); QByteArray messageReceived = arguments.at(0).toByteArray(); QCOMPARE(messageReceived, QByteArrayLiteral("Hello world!")); - QCOMPARE(binaryFrameReceived.count(), 1); + QCOMPARE(binaryFrameReceived.size(), 1); arguments = binaryFrameReceived.takeFirst(); QByteArray frameReceived = arguments.at(0).toByteArray(); bool isLastFrame = arguments.at(1).toBool(); @@ -590,20 +594,20 @@ void tst_QWebSocket::tst_sendBinaryMessage() socket.open(QUrl(QStringLiteral("ws://") + echoServer.hostAddress().toString() + QStringLiteral(":") + QString::number(echoServer.port()))); - QTRY_COMPARE(socketConnectedSpy.count(), 1); + QTRY_COMPARE(socketConnectedSpy.size(), 1); QCOMPARE(socket.state(), QAbstractSocket::ConnectedState); socket.sendBinaryMessage(QByteArrayLiteral("Hello world!")); QVERIFY(binaryMessageReceived.wait(500)); - QCOMPARE(textMessageReceived.count(), 0); - QCOMPARE(textFrameReceived.count(), 0); - QCOMPARE(binaryMessageReceived.count(), 1); + QCOMPARE(textMessageReceived.size(), 0); + QCOMPARE(textFrameReceived.size(), 0); + QCOMPARE(binaryMessageReceived.size(), 1); arguments = binaryMessageReceived.takeFirst(); messageReceived = arguments.at(0).toByteArray(); QCOMPARE(messageReceived, QByteArrayLiteral("Hello world!")); - QCOMPARE(binaryFrameReceived.count(), 1); + QCOMPARE(binaryFrameReceived.size(), 1); arguments = binaryFrameReceived.takeFirst(); frameReceived = arguments.at(0).toByteArray(); isLastFrame = arguments.at(1).toBool(); @@ -623,16 +627,63 @@ void tst_QWebSocket::tst_errorString() socket.open(QUrl(QStringLiteral("ws://someserver.on.mars:9999"))); - QTRY_COMPARE_WITH_TIMEOUT(errorSpy.count(), 1, 10000); + QTRY_COMPARE_WITH_TIMEOUT(errorSpy.size(), 1, 10000); QList<QVariant> arguments = errorSpy.takeFirst(); QAbstractSocket::SocketError socketError = qvariant_cast<QAbstractSocket::SocketError>(arguments.at(0)); QCOMPARE(socketError, QAbstractSocket::HostNotFoundError); QCOMPARE(socket.errorString(), QStringLiteral("Host not found")); + + // Check that handshake status code is parsed. The error is triggered by + // refusing the origin authentication + EchoServer echoServer; + errorSpy.clear(); + QSignalSpy socketConnectedSpy(&socket, SIGNAL(connected())); + QSignalSpy serverConnectedSpy(&echoServer, SIGNAL(newConnection(QUrl))); + connect(&echoServer, &EchoServer::originAuthenticationRequired, + &socket, [](QWebSocketCorsAuthenticator* pAuthenticator){ + pAuthenticator->setAllowed(false); + }); + + socket.open(QUrl(QStringLiteral("ws://") + echoServer.hostAddress().toString() + + QStringLiteral(":") + QString::number(echoServer.port()))); + QTRY_VERIFY(errorSpy.size() > 0); + QCOMPARE(serverConnectedSpy.size(), 0); + QCOMPARE(socketConnectedSpy.size(), 0); + QCOMPARE(socket.errorString(), + QStringLiteral("QWebSocketPrivate::processHandshake: Unhandled http status code: 403" + " (Access Forbidden).")); +} + +void tst_QWebSocket::tst_openRequest_data() +{ + QTest::addColumn<QStringList>("subprotocols"); + QTest::addColumn<QString>("subprotocolHeader"); + QTest::addColumn<QRegularExpression>("warningExpression"); + + QTest::addRow("no subprotocols") << QStringList{} << QString{} << QRegularExpression{}; + QTest::addRow("single subprotocol") << QStringList{"foobar"} << QStringLiteral("foobar") + << QRegularExpression{}; + QTest::addRow("multiple subprotocols") << QStringList{"foo", "bar"} + << QStringLiteral("foo, bar") + << QRegularExpression{}; + QTest::addRow("subprotocol with whitespace") + << QStringList{"chat", "foo\r\nbar with space"} + << QStringLiteral("chat") + << QRegularExpression{".*invalid.*bar with space"}; + + QTest::addRow("subprotocol with invalid chars") + << QStringList{"chat", "foo{}"} + << QStringLiteral("chat") + << QRegularExpression{".*invalid.*foo"}; } void tst_QWebSocket::tst_openRequest() { + QFETCH(QStringList, subprotocols); + QFETCH(QString, subprotocolHeader); + QFETCH(QRegularExpression, warningExpression); + EchoServer echoServer; QWebSocket socket; @@ -647,18 +698,119 @@ void tst_QWebSocket::tst_openRequest() url.setQuery(query); QNetworkRequest req(url); req.setRawHeader("X-Custom-Header", "A custom header"); - socket.open(req); + QWebSocketHandshakeOptions options; + options.setSubprotocols(subprotocols); - QTRY_COMPARE(socketConnectedSpy.count(), 1); - QTRY_COMPARE(serverRequestSpy.count(), 1); + if (!warningExpression.pattern().isEmpty()) + QTest::ignoreMessage(QtWarningMsg, warningExpression); + + socket.open(req, options); + + QTRY_COMPARE(socketConnectedSpy.size(), 1); + QTRY_COMPARE(serverRequestSpy.size(), 1); QCOMPARE(socket.state(), QAbstractSocket::ConnectedState); QList<QVariant> arguments = serverRequestSpy.takeFirst(); QNetworkRequest requestConnected = arguments.at(0).value<QNetworkRequest>(); QCOMPARE(requestConnected.url(), req.url()); QCOMPARE(requestConnected.rawHeader("X-Custom-Header"), req.rawHeader("X-Custom-Header")); + + if (subprotocols.isEmpty()) + QVERIFY(!requestConnected.hasRawHeader("Sec-WebSocket-Protocol")); + else + QCOMPARE(requestConnected.rawHeader("Sec-WebSocket-Protocol"), subprotocolHeader); + + + socket.close(); +} + +void tst_QWebSocket::tst_protocolAccessor() +{ + EchoServer echoServer; + + QWebSocket socket; + + QUrl url = QUrl(QStringLiteral("ws://") + echoServer.hostAddress().toString() + + QLatin1Char(':') + QString::number(echoServer.port())); + + QWebSocketHandshakeOptions options; + options.setSubprotocols({ "foo", "protocol2" }); + + socket.open(url, options); + + QTRY_COMPARE(socket.state(), QAbstractSocket::ConnectedState); + + QCOMPARE(socket.subprotocol(), "protocol2"); + socket.close(); } +void tst_QWebSocket::protocolsHeaderGeneration_data() +{ + QTest::addColumn<QStringList>("subprotocols"); + QTest::addColumn<int>("numInvalidEntries"); + + using QSL = QStringList; + QTest::addRow("all-invalid") << QSL{ "hello?", "------,,,,------" } << 2; + QTest::addRow("one-valid") << QSL{ "hello?", "ImValid" } << 1; + QTest::addRow("all-valid") << QSL{ "hello", "ImValid" } << 0; +} + +// We test that the Sec-WebSocket-Protocol header is generated normally in presence +// of one or more invalid entries. That is, it should not be included at all +// if there are no valid entries, and there should be no separators with only +// one valid entry. +void tst_QWebSocket::protocolsHeaderGeneration() +{ + QFETCH(const QStringList, subprotocols); + QFETCH(const int, numInvalidEntries); + const bool containsValidEntry = numInvalidEntries != subprotocols.size(); + + QTcpServer tcpServer; + QVERIFY(tcpServer.listen()); + + QWebSocket socket; + + QUrl url = QUrl("ws://127.0.0.1:%1"_L1.arg(QString::number(tcpServer.serverPort()))); + + QWebSocketHandshakeOptions options; + options.setSubprotocols(subprotocols); + + QCOMPARE(options.subprotocols().size(), subprotocols.size()); + for (int i = 0; i < numInvalidEntries; ++i) { + QTest::ignoreMessage(QtMsgType::QtWarningMsg, + QRegularExpression("Ignoring invalid WebSocket subprotocol name \".*\"")); + } + socket.open(url, options); + + QTRY_VERIFY(tcpServer.hasPendingConnections()); + QTcpSocket *serverSocket = tcpServer.nextPendingConnection(); + QVERIFY(serverSocket); + + bool hasSeenHeader = false; + while (serverSocket->state() == QAbstractSocket::ConnectedState) { + if (!serverSocket->canReadLine()) { + QTRY_VERIFY2(serverSocket->canReadLine(), + "Reached end-of-data without seeing end-of-header!"); + } + const QByteArray fullLine = serverSocket->readLine(); + QByteArrayView line = fullLine; + if (line == "\r\n") // End-of-Header + break; + QByteArrayView headerPrefix = "Sec-WebSocket-Protocol:"; + if (line.size() < headerPrefix.size()) + continue; + if (line.first(headerPrefix.size()).compare(headerPrefix, Qt::CaseInsensitive) != 0) + continue; + hasSeenHeader = true; + QByteArrayView protocols = line.sliced(headerPrefix.size()).trimmed(); + QVERIFY(!protocols.empty()); + QCOMPARE(protocols.count(','), subprotocols.size() - numInvalidEntries - 1); + // Keep going in case we encounter the header again + } + QCOMPARE(hasSeenHeader, containsValidEntry); + serverSocket->disconnectFromHost(); +} + class WebSocket : public QWebSocket { Q_OBJECT @@ -730,7 +882,7 @@ void tst_QWebSocket::tst_moveToThread() url.setQuery(query); socket->asyncOpen(url); - if (socketConnectedSpy.count() == 0) + if (socketConnectedSpy.size() == 0) QVERIFY(socketConnectedSpy.wait(500)); socket->asyncSendTextMessage(textMessage); @@ -779,6 +931,309 @@ void tst_QWebSocket::tst_setProxy() socket.setProxy(proxy); QCOMPARE(socket.proxy(), proxy); } +#endif // QT_NO_NETWORKPROXY + +class AuthServer : public QTcpServer +{ + Q_OBJECT +public: + AuthServer() + { + connect(this, &QTcpServer::pendingConnectionAvailable, this, &AuthServer::handleConnection); + } + + void incomingConnection(qintptr sockfd) override + { + if (withEncryption) { +#if QT_CONFIG(ssl) + auto *sslSocket = new QSslSocket(this); + connect(sslSocket, &QSslSocket::encrypted, this, + [this, sslSocket]() { + addPendingConnection(sslSocket); + }); + sslSocket->setSslConfiguration(configuration); + sslSocket->setSocketDescriptor(sockfd); + sslSocket->startServerEncryption(); +#else + QFAIL("withEncryption should not be 'true' if we don't have TLS"); +#endif + } else { + QTcpSocket *socket = new QTcpSocket(this); + socket->setSocketDescriptor(sockfd); + addPendingConnection(socket); + } + } + + void handleConnection() + { + QTcpSocket *serverSocket = nextPendingConnection(); + connect(serverSocket, &QTcpSocket::readyRead, this, &AuthServer::handleReadyRead); + } + + void handleReadyRead() + { + auto *serverSocket = qobject_cast<QTcpSocket *>(sender()); + incomingData.append(serverSocket->readAll()); + if (finished) { + qWarning() << "Unexpected trailing data..." << incomingData; + return; + } + if (!incomingData.contains("\r\n\r\n")) { + qDebug("Not all of the data arrived at once, waiting for more..."); + return; + } + // Move incomingData into local variable and reset it since we received it all: + const QByteArray fullHeader = std::exchange(incomingData, {}); + + QLatin1StringView authView = getHeaderValue("Authorization"_L1, fullHeader); + if (authView.isEmpty()) + return writeAuthRequired(serverSocket); + qsizetype sep = authView.indexOf(' '); + if (sep == -1) + return writeAuthRequired(serverSocket); + QLatin1StringView authenticateMethod = authView.first(sep); + QLatin1StringView authenticateAttempt = authView.sliced(sep + 1); + if (authenticateMethod != "Basic" || authenticateAttempt != expectedBasicPayload()) + return writeAuthRequired(serverSocket); + + QLatin1StringView keyView = getHeaderValue("Sec-WebSocket-Key"_L1, fullHeader); + QVERIFY(!keyView.isEmpty()); + + const QByteArray accept = + QByteArrayView(keyView) % QByteArrayLiteral("258EAFA5-E914-47DA-95CA-C5AB0DC85B11"); + auto generatedKey = QCryptographicHash::hash(accept, QCryptographicHash::Sha1).toBase64(); + serverSocket->write("HTTP/1.1 101 Switching Protocols\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Accept: " % generatedKey % "\r\n" + "\r\n"); + finished = true; + } + + void writeAuthRequired(QTcpSocket *socket) const + { + QByteArray payload = "HTTP/1.1 401 UNAUTHORIZED\r\n" + "WWW-Authenticate: Basic realm=shadow\r\n"; + if (withConnectionClose) + payload.append("Connection: Close\r\n"); + else if (withContentLength) + payload.append("Content-Length: " % QByteArray::number(body.size()) % "\r\n"); + payload.append("\r\n"); + + if (withBody) + payload.append(body); + + socket->write(payload); + if (withConnectionClose) + socket->disconnectFromHost(); + } + + static QLatin1StringView getHeaderValue(const QLatin1StringView keyHeader, + const QByteArrayView fullHeader) + { + const auto fullHeaderView = QLatin1StringView(fullHeader); + const qsizetype headerStart = fullHeaderView.indexOf(keyHeader, 0, Qt::CaseInsensitive); + if (headerStart == -1) + return {}; + qsizetype valueStart = headerStart + keyHeader.size(); + Q_ASSERT(fullHeaderView.size() > valueStart); + Q_ASSERT(fullHeaderView[valueStart] == ':'); + ++valueStart; + const qsizetype valueEnd = fullHeaderView.indexOf(QLatin1StringView("\r\n"), valueStart); + if (valueEnd == -1) + return {}; + return fullHeaderView.sliced(valueStart, valueEnd - valueStart).trimmed(); + } + + static QByteArray expectedBasicPayload() + { + return QByteArray(user % ':' % password).toBase64(); + } + + static constexpr QByteArrayView user = "user"; + static constexpr QByteArrayView password = "password"; + static constexpr QUtf8StringView body = "Authorization required"; + + bool withBody = false; + bool withContentLength = true; + bool withConnectionClose = false; + bool withEncryption = false; +#if QT_CONFIG(ssl) + QSslConfiguration configuration; +#endif + +private: + QByteArray incomingData; + bool finished = false; +}; + +struct ServerScenario { + QByteArrayView label; + bool withContentLength = false; + bool withBody = false; + bool withConnectionClose = false; + bool withEncryption = false; +}; +struct Credentials { QString username, password; }; +struct ClientScenario { + QByteArrayView label; + Credentials urlCredentials; + QVector<Credentials> callbackCredentials; + bool expectSuccess = true; +}; + +void tst_QWebSocket::authenticationRequired_data() +{ + const QString correctUser = QString::fromUtf8(AuthServer::user.toByteArray()); + const QString correctPassword = QString::fromUtf8(AuthServer::password.toByteArray()); + + QTest::addColumn<ServerScenario>("serverScenario"); + QTest::addColumn<ClientScenario>("clientScenario"); + + // Need to test multiple server scenarios: + // 1. Normal server (connection: keep-alive, Content-Length) + // 2. Older server (connection: close, Content-Length) + // 3. Even older server (connection: close, no Content-Length) + // 4. Strange server (connection: close, no Content-Length, no body) + // 5. Quiet server (connection: keep-alive, no Content-Length, no body) + ServerScenario serverScenarios[] = { + { "normal-server", true, true, false, false }, + { "connection-close", true, true, true, false }, + { "connection-close-no-content-length", false, true, true, false }, + { "connection-close-no-content-length-no-body", false, false, true, false }, + { "keep-alive-no-content-length-no-body", false, false, false, false }, + }; + + // And some client scenarios + // 1. User/pass supplied in url + // 2. User/pass supplied in callback + // 3. _Wrong_ user/pass supplied in URL, correct in callback + // 4. _Wrong_ user/pass supplied in URL, _wrong_ supplied in callback + // 5. No user/pass supplied in URL, nothing supplied in callback + // 5. No user/pass supplied in URL, wrong, then correct, supplied in callback + ClientScenario clientScenarios[]{ + { "url-ok", {correctUser, correctPassword}, {} }, + { "callback-ok", {}, { {correctUser, correctPassword } } }, + { "url-wrong-callback-ok", {u"admin"_s, u"admin"_s}, { {correctUser, correctPassword} } }, + { "url-wrong-callback-wrong", {u"admin"_s, u"admin"_s}, { {u"test"_s, u"test"_s} }, false }, + { "no-creds", {{}, {}}, {}, false }, + { "url-wrong-callback-2-ok", {u"admin"_s, u"admin"_s}, { {u"test"_s, u"test"_s}, {correctUser , correctPassword} } }, + }; + + for (auto &server : serverScenarios) { + for (auto &client : clientScenarios) { + QTest::addRow("Server:%s,Client:%s", server.label.data(), client.label.data()) + << server << client; + } + } +#if QT_CONFIG(ssl) + if (!QSslSocket::supportsSsl()) { + qDebug("Skipping the SslServer part of this test because proper TLS is not supported."); + return; + } + // And double that, but now with TLS + for (auto &server : serverScenarios) { + server.withEncryption = true; + for (auto &client : clientScenarios) { + QTest::addRow("SslServer:%s,Client:%s", server.label.data(), client.label.data()) + << server << client; + } + } +#endif +} + +void tst_QWebSocket::authenticationRequired() +{ + QFETCH(const ServerScenario, serverScenario); + QFETCH(const ClientScenario, clientScenario); + + int credentialIndex = 0; + auto handleAuthenticationRequired = [&clientScenario, + &credentialIndex](QAuthenticator *authenticator) { + if (credentialIndex == clientScenario.callbackCredentials.size()) { + if (clientScenario.expectSuccess) + QFAIL("Ran out of credentials to try, but failed to authorize!"); + if (clientScenario.callbackCredentials.isEmpty()) + return; + // If we don't expect to succeed, retry the last returned credentials. + // QAuthenticator should notice there is no change in user/pass and + // ignore it, leading to authentication failure. + --credentialIndex; + } + // Verify that realm parsing works: + QCOMPARE_EQ(authenticator->realm(), u"shadow"_s); + + Credentials credentials = clientScenario.callbackCredentials[credentialIndex++]; + authenticator->setUser(credentials.username); + authenticator->setPassword(credentials.password); + }; + + AuthServer server; + server.withBody = serverScenario.withBody; + server.withContentLength = serverScenario.withContentLength; + server.withConnectionClose = serverScenario.withConnectionClose; + server.withEncryption = serverScenario.withEncryption; +#if QT_CONFIG(ssl) + if (serverScenario.withEncryption) { + QSslConfiguration config = QSslConfiguration::defaultConfiguration(); + QList<QSslCertificate> certificates = QSslCertificate::fromPath(u":/localhost.cert"_s); + QVERIFY(!certificates.isEmpty()); + config.setLocalCertificateChain(certificates); + QFile keyFile(u":/localhost.key"_s); + QVERIFY(keyFile.open(QIODevice::ReadOnly)); + config.setPrivateKey(QSslKey(keyFile.readAll(), QSsl::Rsa)); + server.configuration = config; + } +#endif + + QVERIFY(server.listen()); + QUrl url = QUrl(u"ws://127.0.0.1"_s); + if (serverScenario.withEncryption) + url.setScheme(u"wss"_s); + url.setPort(server.serverPort()); + url.setUserName(clientScenario.urlCredentials.username); + url.setPassword(clientScenario.urlCredentials.password); + + QWebSocket socket; + QSignalSpy connectedSpy(&socket, &QWebSocket::connected); + QSignalSpy errorSpy(&socket, &QWebSocket::errorOccurred); + QSignalSpy stateChangedSpy(&socket, &QWebSocket::stateChanged); + connect(&socket, &QWebSocket::authenticationRequired, &socket, handleAuthenticationRequired); +#if QT_CONFIG(ssl) + if (serverScenario.withEncryption) { + auto config = socket.sslConfiguration(); + config.setPeerVerifyMode(QSslSocket::VerifyNone); + socket.setSslConfiguration(config); + QObject::connect(&socket, &QWebSocket::sslErrors, &socket, + qOverload<>(&QWebSocket::ignoreSslErrors)); + } +#endif + socket.open(url); + + if (clientScenario.expectSuccess) { + // Wait for connected! + QTRY_COMPARE_EQ(connectedSpy.size(), 1); + QCOMPARE_EQ(errorSpy.size(), 0); + // connecting->connected + const int ExpectedStateChanges = 2; + QTRY_COMPARE_EQ(stateChangedSpy.size(), ExpectedStateChanges); + auto firstState = stateChangedSpy.at(0).front().value<QAbstractSocket::SocketState>(); + QCOMPARE_EQ(firstState, QAbstractSocket::ConnectingState); + auto secondState = stateChangedSpy.at(1).front().value<QAbstractSocket::SocketState>(); + QCOMPARE_EQ(secondState, QAbstractSocket::ConnectedState); + } else { + // Wait for error! + QTRY_COMPARE_EQ(errorSpy.size(), 1); + QCOMPARE_EQ(connectedSpy.size(), 0); + // connecting->unconnected + const int ExpectedStateChanges = 2; + QTRY_COMPARE_EQ(stateChangedSpy.size(), ExpectedStateChanges); + auto firstState = stateChangedSpy.at(0).front().value<QAbstractSocket::SocketState>(); + QCOMPARE_EQ(firstState, QAbstractSocket::ConnectingState); + auto secondState = stateChangedSpy.at(1).front().value<QAbstractSocket::SocketState>(); + QCOMPARE_EQ(secondState, QAbstractSocket::UnconnectedState); + } +} void tst_QWebSocket::overlongCloseReason() { @@ -796,16 +1251,16 @@ void tst_QWebSocket::overlongCloseReason() QUrl url = QUrl(QStringLiteral("ws://") + echoServer.hostAddress().toString() + QStringLiteral(":") + QString::number(echoServer.port())); socket.open(url); - QTRY_COMPARE(socketConnectedSpy.count(), 1); - QTRY_COMPARE(serverConnectedSpy.count(), 1); + QTRY_COMPARE(socketConnectedSpy.size(), 1); + QTRY_COMPARE(serverConnectedSpy.size(), 1); const QString reason(200, QChar::fromLatin1('a')); socket.close(QWebSocketProtocol::CloseCodeGoingAway, reason); QCOMPARE(socket.closeCode(), QWebSocketProtocol::CloseCodeGoingAway); // Max length of a control frame is 125, but 2 bytes are used for the close code: - QCOMPARE(socket.closeReason().length(), 123); + QCOMPARE(socket.closeReason().size(), 123); QCOMPARE(socket.closeReason(), reason.left(123)); - QTRY_COMPARE(socketDisconnectedSpy.count(), 1); + QTRY_COMPARE(socketDisconnectedSpy.size(), 1); } void tst_QWebSocket::incomingMessageTooLong() @@ -825,13 +1280,13 @@ void tst_QWebSocket::incomingMessageTooLong() QUrl url = QUrl(QStringLiteral("ws://") + echoServer.hostAddress().toString() + QStringLiteral(":") + QString::number(echoServer.port())); socket.open(url); - QTRY_COMPARE(socketConnectedSpy.count(), 1); - QTRY_COMPARE(serverConnectedSpy.count(), 1); + QTRY_COMPARE(socketConnectedSpy.size(), 1); + QTRY_COMPARE(serverConnectedSpy.size(), 1); QString payload(maxAllowedIncomingMessageSize+1, 'a'); QCOMPARE(socket.sendTextMessage(payload), payload.size()); - QTRY_COMPARE(socketDisconnectedSpy.count(), 1); + QTRY_COMPARE(socketDisconnectedSpy.size(), 1); QCOMPARE(socket.closeCode(), QWebSocketProtocol::CloseCodeTooMuchData); } @@ -853,13 +1308,13 @@ void tst_QWebSocket::incomingFrameTooLong() QUrl url = QUrl(QStringLiteral("ws://") + echoServer.hostAddress().toString() + QStringLiteral(":") + QString::number(echoServer.port())); socket.open(url); - QTRY_COMPARE(socketConnectedSpy.count(), 1); - QTRY_COMPARE(serverConnectedSpy.count(), 1); + QTRY_COMPARE(socketConnectedSpy.size(), 1); + QTRY_COMPARE(serverConnectedSpy.size(), 1); QString payload(maxAllowedIncomingFrameSize+1, 'a'); QCOMPARE(socket.sendTextMessage(payload), payload.size()); - QTRY_COMPARE(socketDisconnectedSpy.count(), 1); + QTRY_COMPARE(socketDisconnectedSpy.size(), 1); QCOMPARE(socket.closeCode(), QWebSocketProtocol::CloseCodeTooMuchData); } @@ -881,7 +1336,58 @@ void tst_QWebSocket::testingFrameAndMessageSizeApi() QTRY_COMPARE(maxAllowedIncomingMessageSize, socket.maxAllowedIncomingMessageSize()); } -#endif // QT_NO_NETWORKPROXY +void tst_QWebSocket::customHeader() +{ + QTcpServer server; + QSignalSpy serverSpy(&server, &QTcpServer::newConnection); + + server.listen(); + QUrl url = QUrl(QStringLiteral("ws://127.0.0.1")); + url.setPort(server.serverPort()); + + QNetworkRequest request(url); + request.setRawHeader("CustomHeader", "Example"); + QWebSocket socket; + socket.open(request); + + // Custom websocket server below (needed because a QWebSocketServer on + // localhost doesn't show the issue): + QVERIFY(serverSpy.wait()); + QTcpSocket *serverSocket = server.nextPendingConnection(); + QSignalSpy serverSocketSpy(serverSocket, &QTcpSocket::readyRead); + QByteArray data; + while (!data.contains("\r\n\r\n")) { + QVERIFY(serverSocketSpy.wait()); + data.append(serverSocket->readAll()); + } + QVERIFY(data.contains("CustomHeader: Example")); + const auto view = QLatin1String(data); + const auto keyHeader = QLatin1String("Sec-WebSocket-Key:"); + const qsizetype keyStart = view.indexOf(keyHeader, 0, Qt::CaseInsensitive) + keyHeader.size(); + QVERIFY(keyStart != -1); + const qsizetype keyEnd = view.indexOf(QLatin1String("\r\n"), keyStart); + QVERIFY(keyEnd != -1); + const QLatin1String keyView = view.sliced(keyStart, keyEnd - keyStart).trimmed(); + const QByteArray accept = + QByteArrayView(keyView) % QByteArrayLiteral("258EAFA5-E914-47DA-95CA-C5AB0DC85B11"); + serverSocket->write( + "HTTP/1.1 101 Switching Protocols\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Accept: " + % QCryptographicHash::hash(accept, QCryptographicHash::Sha1).toBase64() + ); // trailing \r\n\r\n intentionally left off to make the client wait for it + serverSocket->flush(); + // This would freeze prior to the fix for QTBUG-102111, because the client would loop forever. + // We use qWait to give the OS some time to move the bytes over to the client and push the event + // to our eventloop. + QTest::qWait(100); + serverSocket->write("\r\n\r\n"); + + // And check the client properly connects: + QSignalSpy connectedSpy(&socket, &QWebSocket::connected); + QVERIFY(connectedSpy.wait()); +} QTEST_MAIN(tst_QWebSocket) |