summaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
authorPierre Ossman <pierre@ossman.eu>2016-09-15 19:51:26 +0200
committerPierre Ossman <pierre@ossman.eu>2017-02-01 08:22:27 +0100
commit8a697622495fd319582cd1c604e7eb2cc0ac0ef6 (patch)
tree9270b1bb631c6559d2c0e9049a0d9b505b4c507c /tests
parent4099949984eb80ef33c2d0dd216991124975a5d2 (diff)
downloadwebsockify-8a697622495fd319582cd1c604e7eb2cc0ac0ef6.tar.gz
Separate out raw WebSocket protocol handling
Diffstat (limited to 'tests')
-rwxr-xr-xtests/echo.py4
-rwxr-xr-xtests/echo_client.py70
-rwxr-xr-xtests/load.py14
-rw-r--r--tests/test_websocket.py386
-rw-r--r--tests/test_websocketproxy.py4
-rw-r--r--tests/test_websocketserver.py347
6 files changed, 448 insertions, 377 deletions
diff --git a/tests/echo.py b/tests/echo.py
index e6a6851..3d81e04 100755
--- a/tests/echo.py
+++ b/tests/echo.py
@@ -12,7 +12,7 @@ as taken from http://docs.python.org/dev/library/ssl.html#certificates
import os, sys, select, optparse, logging
sys.path.insert(0,os.path.join(os.path.dirname(__file__), ".."))
-from websockify.websocket import WebSocketServer, WebSocketRequestHandler
+from websockify.websocketserver import WebSocketServer, WebSocketRequestHandler
class WebSocketEcho(WebSocketRequestHandler):
"""
@@ -48,7 +48,7 @@ class WebSocketEcho(WebSocketRequestHandler):
cqueue.extend(frames)
if closed:
- self.send_close()
+ break
if __name__ == '__main__':
parser = optparse.OptionParser(usage="%prog [options] listen_port")
diff --git a/tests/echo_client.py b/tests/echo_client.py
new file mode 100755
index 0000000..6d745ec
--- /dev/null
+++ b/tests/echo_client.py
@@ -0,0 +1,70 @@
+#!/usr/bin/env python
+
+import os
+import sys
+import optparse
+import select
+
+sys.path.insert(0,os.path.join(os.path.dirname(__file__), ".."))
+from websockify.websocket import WebSocket, \
+ WebSocketWantReadError, WebSocketWantWriteError
+
+parser = optparse.OptionParser(usage="%prog URL")
+(opts, args) = parser.parse_args()
+
+try:
+ if len(args) != 1: raise
+ URL = args[0]
+except:
+ parser.error("Invalid arguments")
+
+sock = WebSocket()
+print("Connecting to %s..." % URL)
+sock.connect(URL)
+print("Connected.")
+
+def send(msg):
+ while True:
+ try:
+ sock.sendmsg(msg)
+ break
+ except WebSocketWantReadError:
+ msg = ''
+ ins, outs, excepts = select.select([sock], [], [])
+ if excepts: raise Exception("Socket exception")
+ except WebSocketWantWriteError:
+ msg = ''
+ ins, outs, excepts = select.select([], [sock], [])
+ if excepts: raise Exception("Socket exception")
+
+def read():
+ while True:
+ try:
+ return sock.recvmsg()
+ except WebSocketWantReadError:
+ ins, outs, excepts = select.select([sock], [], [])
+ if excepts: raise Exception("Socket exception")
+ except WebSocketWantWriteError:
+ ins, outs, excepts = select.select([], [sock], [])
+ if excepts: raise Exception("Socket exception")
+
+counter = 1
+while True:
+ msg = "Message #%d" % counter
+ counter += 1
+ send(msg)
+ print("Sent message: %r" % msg)
+
+ while True:
+ ins, outs, excepts = select.select([sock], [], [], 1.0)
+ if excepts: raise Exception("Socket exception")
+
+ if ins == []:
+ break
+
+ while True:
+ msg = read()
+ print("Received message: %r" % msg)
+
+ if not sock.pending():
+ break
diff --git a/tests/load.py b/tests/load.py
index c76feb1..caf6b58 100755
--- a/tests/load.py
+++ b/tests/load.py
@@ -8,7 +8,7 @@ given a sequence number. Any errors are reported and counted.
import sys, os, select, random, time, optparse, logging
sys.path.insert(0,os.path.join(os.path.dirname(__file__), ".."))
-from websockify.websocket import WebSocketServer, WebSocketRequestHandler
+from websockify.websocketserver import WebSocketServer, WebSocketRequestHandler
class WebSocketLoadServer(WebSocketServer):
@@ -35,12 +35,10 @@ class WebSocketLoad(WebSocketRequestHandler):
self.send_cnt = 0
self.recv_cnt = 0
- try:
- self.responder(self.request)
- except:
- print "accumulated errors:", self.errors
- self.errors = 0
- raise
+ self.responder(self.request)
+
+ print "accumulated errors:", self.errors
+ self.errors = 0
def responder(self, client):
c_pend = 0
@@ -62,7 +60,7 @@ class WebSocketLoad(WebSocketRequestHandler):
print err
if closed:
- self.send_close()
+ break
now = time.time() * 1000
if client in outs:
diff --git a/tests/test_websocket.py b/tests/test_websocket.py
index 545fa1c..77d0eca 100644
--- a/tests/test_websocket.py
+++ b/tests/test_websocket.py
@@ -15,418 +15,74 @@
# under the License.
""" Unit tests for websocket """
-import errno
-import os
-import logging
-import select
-import shutil
-import socket
-import ssl
-from mox3 import stubout
-import sys
-import tempfile
import unittest
-import socket
-import signal
from websockify import websocket
-try:
- from SimpleHTTPServer import SimpleHTTPRequestHandler
-except ImportError:
- from http.server import SimpleHTTPRequestHandler
-
-try:
- from StringIO import StringIO
- BytesIO = StringIO
-except ImportError:
- from io import StringIO
- from io import BytesIO
-
-
-
-
-def raise_oserror(*args, **kwargs):
- raise OSError('fake error')
-
-
-class FakeSocket(object):
- def __init__(self, data=''):
- if isinstance(data, bytes):
- self._data = data
- else:
- self._data = data.encode('latin_1')
-
- def recv(self, amt, flags=None):
- res = self._data[0:amt]
- if not (flags & socket.MSG_PEEK):
- self._data = self._data[amt:]
-
- return res
-
- def makefile(self, mode='r', buffsize=None):
- if 'b' in mode:
- return BytesIO(self._data)
- else:
- return StringIO(self._data.decode('latin_1'))
-
-
-class WebSocketRequestHandlerTestCase(unittest.TestCase):
- def setUp(self):
- super(WebSocketRequestHandlerTestCase, self).setUp()
- self.stubs = stubout.StubOutForTesting()
- self.tmpdir = tempfile.mkdtemp('-websockify-tests')
- # Mock this out cause it screws tests up
- self.stubs.Set(os, 'chdir', lambda *args, **kwargs: None)
- self.stubs.Set(SimpleHTTPRequestHandler, 'send_response',
- lambda *args, **kwargs: None)
-
- def tearDown(self):
- """Called automatically after each test."""
- self.stubs.UnsetAll()
- os.rmdir(self.tmpdir)
- super(WebSocketRequestHandlerTestCase, self).tearDown()
-
- def _get_server(self, handler_class=websocket.WebSocketRequestHandler,
- **kwargs):
- web = kwargs.pop('web', self.tmpdir)
- return websocket.WebSocketServer(
- handler_class, listen_host='localhost',
- listen_port=80, key=self.tmpdir, web=web,
- record=self.tmpdir, daemon=False, ssl_only=0, idle_timeout=1,
- **kwargs)
-
- def test_normal_get_with_only_upgrade_returns_error(self):
- server = self._get_server(web=None)
- handler = websocket.WebSocketRequestHandler(
- FakeSocket('GET /tmp.txt HTTP/1.1'), '127.0.0.1', server)
-
- def fake_send_response(self, code, message=None):
- self.last_code = code
-
- self.stubs.Set(SimpleHTTPRequestHandler, 'send_response',
- fake_send_response)
-
- handler.do_GET()
- self.assertEqual(handler.last_code, 405)
-
- def test_list_dir_with_file_only_returns_error(self):
- server = self._get_server(file_only=True)
- handler = websocket.WebSocketRequestHandler(
- FakeSocket('GET / HTTP/1.1'), '127.0.0.1', server)
-
- def fake_send_response(self, code, message=None):
- self.last_code = code
-
- self.stubs.Set(SimpleHTTPRequestHandler, 'send_response',
- fake_send_response)
-
- handler.path = '/'
- handler.do_GET()
- self.assertEqual(handler.last_code, 404)
-
-
-class WebSocketServerTestCase(unittest.TestCase):
- def setUp(self):
- super(WebSocketServerTestCase, self).setUp()
- self.stubs = stubout.StubOutForTesting()
- self.tmpdir = tempfile.mkdtemp('-websockify-tests')
- # Mock this out cause it screws tests up
- self.stubs.Set(os, 'chdir', lambda *args, **kwargs: None)
-
- def tearDown(self):
- """Called automatically after each test."""
- self.stubs.UnsetAll()
- os.rmdir(self.tmpdir)
- super(WebSocketServerTestCase, self).tearDown()
-
- def _get_server(self, handler_class=websocket.WebSocketRequestHandler,
- **kwargs):
- return websocket.WebSocketServer(
- handler_class, listen_host='localhost',
- listen_port=80, key=self.tmpdir, web=self.tmpdir,
- record=self.tmpdir, **kwargs)
-
- def test_daemonize_raises_error_while_closing_fds(self):
- server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1)
- self.stubs.Set(os, 'fork', lambda *args: 0)
- self.stubs.Set(signal, 'signal', lambda *args: None)
- self.stubs.Set(os, 'setsid', lambda *args: None)
- self.stubs.Set(os, 'close', raise_oserror)
- self.assertRaises(OSError, server.daemonize, keepfd=None, chdir='./')
-
- def test_daemonize_ignores_ebadf_error_while_closing_fds(self):
- def raise_oserror_ebadf(fd):
- raise OSError(errno.EBADF, 'fake error')
-
- server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1)
- self.stubs.Set(os, 'fork', lambda *args: 0)
- self.stubs.Set(os, 'setsid', lambda *args: None)
- self.stubs.Set(signal, 'signal', lambda *args: None)
- self.stubs.Set(os, 'close', raise_oserror_ebadf)
- self.stubs.Set(os, 'open', raise_oserror)
- self.assertRaises(OSError, server.daemonize, keepfd=None, chdir='./')
-
- def test_handshake_fails_on_not_ready(self):
- server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1)
-
- def fake_select(rlist, wlist, xlist, timeout=None):
- return ([], [], [])
-
- self.stubs.Set(select, 'select', fake_select)
- self.assertRaises(
- websocket.WebSocketServer.EClose, server.do_handshake,
- FakeSocket(), '127.0.0.1')
-
- def test_empty_handshake_fails(self):
- server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1)
-
- sock = FakeSocket('')
-
- def fake_select(rlist, wlist, xlist, timeout=None):
- return ([sock], [], [])
-
- self.stubs.Set(select, 'select', fake_select)
- self.assertRaises(
- websocket.WebSocketServer.EClose, server.do_handshake,
- sock, '127.0.0.1')
-
- def test_handshake_policy_request(self):
- # TODO(directxman12): implement
- pass
-
- def test_handshake_ssl_only_without_ssl_raises_error(self):
- server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1)
-
- sock = FakeSocket('some initial data')
-
- def fake_select(rlist, wlist, xlist, timeout=None):
- return ([sock], [], [])
-
- self.stubs.Set(select, 'select', fake_select)
- self.assertRaises(
- websocket.WebSocketServer.EClose, server.do_handshake,
- sock, '127.0.0.1')
-
- def test_do_handshake_no_ssl(self):
- class FakeHandler(object):
- CALLED = False
- def __init__(self, *args, **kwargs):
- type(self).CALLED = True
-
- FakeHandler.CALLED = False
-
- server = self._get_server(
- handler_class=FakeHandler, daemon=True,
- ssl_only=0, idle_timeout=1)
-
- sock = FakeSocket('some initial data')
-
- def fake_select(rlist, wlist, xlist, timeout=None):
- return ([sock], [], [])
-
- self.stubs.Set(select, 'select', fake_select)
- self.assertEqual(server.do_handshake(sock, '127.0.0.1'), sock)
- self.assertTrue(FakeHandler.CALLED, True)
-
- def test_do_handshake_ssl(self):
- # TODO(directxman12): implement this
- pass
-
- def test_do_handshake_ssl_without_ssl_raises_error(self):
- # TODO(directxman12): implement this
- pass
-
- def test_do_handshake_ssl_without_cert_raises_error(self):
- server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1,
- cert='afdsfasdafdsafdsafdsafdas')
-
- sock = FakeSocket("\x16some ssl data")
-
- def fake_select(rlist, wlist, xlist, timeout=None):
- return ([sock], [], [])
-
- self.stubs.Set(select, 'select', fake_select)
- self.assertRaises(
- websocket.WebSocketServer.EClose, server.do_handshake,
- sock, '127.0.0.1')
-
- def test_do_handshake_ssl_error_eof_raises_close_error(self):
- server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1)
-
- sock = FakeSocket("\x16some ssl data")
-
- def fake_select(rlist, wlist, xlist, timeout=None):
- return ([sock], [], [])
-
- def fake_wrap_socket(*args, **kwargs):
- raise ssl.SSLError(ssl.SSL_ERROR_EOF)
-
- self.stubs.Set(select, 'select', fake_select)
- self.stubs.Set(ssl, 'wrap_socket', fake_wrap_socket)
- self.assertRaises(
- websocket.WebSocketServer.EClose, server.do_handshake,
- sock, '127.0.0.1')
-
- def test_fallback_sigchld_handler(self):
- # TODO(directxman12): implement this
- pass
-
- def test_start_server_error(self):
- server = self._get_server(daemon=False, ssl_only=1, idle_timeout=1)
- sock = server.socket('localhost')
-
- def fake_select(rlist, wlist, xlist, timeout=None):
- raise Exception("fake error")
-
- self.stubs.Set(websocket.WebSocketServer, 'socket',
- lambda *args, **kwargs: sock)
- self.stubs.Set(websocket.WebSocketServer, 'daemonize',
- lambda *args, **kwargs: None)
- self.stubs.Set(select, 'select', fake_select)
- server.start_server()
-
- def test_start_server_keyboardinterrupt(self):
- server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1)
- sock = server.socket('localhost')
-
- def fake_select(rlist, wlist, xlist, timeout=None):
- raise KeyboardInterrupt
-
- self.stubs.Set(websocket.WebSocketServer, 'socket',
- lambda *args, **kwargs: sock)
- self.stubs.Set(websocket.WebSocketServer, 'daemonize',
- lambda *args, **kwargs: None)
- self.stubs.Set(select, 'select', fake_select)
- server.start_server()
-
- def test_start_server_systemexit(self):
- server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1)
- sock = server.socket('localhost')
-
- def fake_select(rlist, wlist, xlist, timeout=None):
- sys.exit()
-
- self.stubs.Set(websocket.WebSocketServer, 'socket',
- lambda *args, **kwargs: sock)
- self.stubs.Set(websocket.WebSocketServer, 'daemonize',
- lambda *args, **kwargs: None)
- self.stubs.Set(select, 'select', fake_select)
- server.start_server()
-
- def test_socket_set_keepalive_options(self):
- keepcnt = 12
- keepidle = 34
- keepintvl = 56
-
- server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1)
- sock = server.socket('localhost',
- tcp_keepcnt=keepcnt,
- tcp_keepidle=keepidle,
- tcp_keepintvl=keepintvl)
-
- if hasattr(socket, 'TCP_KEEPCNT'):
- self.assertEqual(sock.getsockopt(socket.SOL_TCP,
- socket.TCP_KEEPCNT), keepcnt)
- self.assertEqual(sock.getsockopt(socket.SOL_TCP,
- socket.TCP_KEEPIDLE), keepidle)
- self.assertEqual(sock.getsockopt(socket.SOL_TCP,
- socket.TCP_KEEPINTVL), keepintvl)
-
- sock = server.socket('localhost',
- tcp_keepalive=False,
- tcp_keepcnt=keepcnt,
- tcp_keepidle=keepidle,
- tcp_keepintvl=keepintvl)
-
- if hasattr(socket, 'TCP_KEEPCNT'):
- self.assertNotEqual(sock.getsockopt(socket.SOL_TCP,
- socket.TCP_KEEPCNT), keepcnt)
- self.assertNotEqual(sock.getsockopt(socket.SOL_TCP,
- socket.TCP_KEEPIDLE), keepidle)
- self.assertNotEqual(sock.getsockopt(socket.SOL_TCP,
- socket.TCP_KEEPINTVL), keepintvl)
-
-
class HyBiEncodeDecodeTestCase(unittest.TestCase):
def test_decode_hybi_text(self):
buf = b'\x81\x85\x37\xfa\x21\x3d\x7f\x9f\x4d\x51\x58'
- res = websocket.WebSocketRequestHandler.decode_hybi(buf)
+ ws = websocket.WebSocket()
+ res = ws._decode_hybi(buf)
self.assertEqual(res['fin'], 1)
self.assertEqual(res['opcode'], 0x1)
self.assertEqual(res['masked'], True)
- self.assertEqual(res['length'], 5)
+ self.assertEqual(res['length'], len(buf))
self.assertEqual(res['payload'], b'Hello')
- self.assertEqual(res['left'], 0)
def test_decode_hybi_binary(self):
buf = b'\x82\x04\x01\x02\x03\x04'
- res = websocket.WebSocketRequestHandler.decode_hybi(buf, strict=False)
+ ws = websocket.WebSocket()
+ res = ws._decode_hybi(buf)
self.assertEqual(res['fin'], 1)
self.assertEqual(res['opcode'], 0x2)
- self.assertEqual(res['length'], 4)
+ self.assertEqual(res['length'], len(buf))
self.assertEqual(res['payload'], b'\x01\x02\x03\x04')
- self.assertEqual(res['left'], 0)
def test_decode_hybi_extended_16bit_binary(self):
data = (b'\x01\x02\x03\x04' * 65) # len > 126 -- len == 260
buf = b'\x82\x7e\x01\x04' + data
- res = websocket.WebSocketRequestHandler.decode_hybi(buf, strict=False)
+ ws = websocket.WebSocket()
+ res = ws._decode_hybi(buf)
self.assertEqual(res['fin'], 1)
self.assertEqual(res['opcode'], 0x2)
- self.assertEqual(res['length'], 260)
+ self.assertEqual(res['length'], len(buf))
self.assertEqual(res['payload'], data)
- self.assertEqual(res['left'], 0)
def test_decode_hybi_extended_64bit_binary(self):
data = (b'\x01\x02\x03\x04' * 65) # len > 126 -- len == 260
buf = b'\x82\x7f\x00\x00\x00\x00\x00\x00\x01\x04' + data
- res = websocket.WebSocketRequestHandler.decode_hybi(buf, strict=False)
+ ws = websocket.WebSocket()
+ res = ws._decode_hybi(buf)
self.assertEqual(res['fin'], 1)
self.assertEqual(res['opcode'], 0x2)
- self.assertEqual(res['length'], 260)
+ self.assertEqual(res['length'], len(buf))
self.assertEqual(res['payload'], data)
- self.assertEqual(res['left'], 0)
def test_decode_hybi_multi(self):
buf1 = b'\x01\x03\x48\x65\x6c'
buf2 = b'\x80\x02\x6c\x6f'
- res1 = websocket.WebSocketRequestHandler.decode_hybi(buf1, strict=False)
+ ws = websocket.WebSocket()
+
+ res1 = ws._decode_hybi(buf1)
self.assertEqual(res1['fin'], 0)
self.assertEqual(res1['opcode'], 0x1)
- self.assertEqual(res1['length'], 3)
+ self.assertEqual(res1['length'], len(buf1))
self.assertEqual(res1['payload'], b'Hel')
- self.assertEqual(res1['left'], 0)
- res2 = websocket.WebSocketRequestHandler.decode_hybi(buf2, strict=False)
+ res2 = ws._decode_hybi(buf2)
self.assertEqual(res2['fin'], 1)
self.assertEqual(res2['opcode'], 0x0)
- self.assertEqual(res2['length'], 2)
+ self.assertEqual(res2['length'], len(buf2))
self.assertEqual(res2['payload'], b'lo')
- self.assertEqual(res2['left'], 0)
def test_encode_hybi_basic(self):
- res = websocket.WebSocketRequestHandler.encode_hybi(b'Hello', 0x1)
- expected = (b'\x81\x05\x48\x65\x6c\x6c\x6f', 2, 0)
+ ws = websocket.WebSocket()
+ res = ws._encode_hybi(0x1, b'Hello')
+ expected = b'\x81\x05\x48\x65\x6c\x6c\x6f'
self.assertEqual(res, expected)
-
- def test_strict_mode_refuses_unmasked_client_frames(self):
- buf = b'\x81\x05\x48\x65\x6c\x6c\x6f'
- self.assertRaises(websocket.WebSocketRequestHandler.CClose,
- websocket.WebSocketRequestHandler.decode_hybi,
- buf)
-
- def test_no_strict_mode_accepts_unmasked_client_frames(self):
- buf = b'\x81\x05\x48\x65\x6c\x6c\x6f'
- res = websocket.WebSocketRequestHandler.decode_hybi(buf, strict=False)
-
- self.assertEqual(res['fin'], 1)
- self.assertEqual(res['opcode'], 0x1)
- self.assertEqual(res['masked'], False)
- self.assertEqual(res['length'], 5)
- self.assertEqual(res['payload'], b'Hello')
diff --git a/tests/test_websocketproxy.py b/tests/test_websocketproxy.py
index b48796e..ac08dfa 100644
--- a/tests/test_websocketproxy.py
+++ b/tests/test_websocketproxy.py
@@ -22,7 +22,7 @@ import socket
from mox3 import stubout
-from websockify import websocket
+from websockify import websocketserver
from websockify import websocketproxy
from websockify import token_plugins
from websockify import auth_plugins
@@ -75,7 +75,7 @@ class ProxyRequestHandlerTestCase(unittest.TestCase):
FakeSocket(''), "127.0.0.1", FakeServer())
self.handler.path = "https://localhost:6080/websockify?token=blah"
self.handler.headers = None
- self.stubs.Set(websocket.WebSocketServer, 'socket',
+ self.stubs.Set(websocketserver.WebSocketServer, 'socket',
staticmethod(lambda *args, **kwargs: None))
def tearDown(self):
diff --git a/tests/test_websocketserver.py b/tests/test_websocketserver.py
new file mode 100644
index 0000000..aaeeee6
--- /dev/null
+++ b/tests/test_websocketserver.py
@@ -0,0 +1,347 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright(c)2013 NTT corp. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+""" Unit tests for websocketserver """
+import errno
+import os
+import logging
+import select
+import shutil
+import socket
+import ssl
+from mox3 import stubout
+import sys
+import tempfile
+import unittest
+import socket
+import signal
+from websockify import websocketserver
+
+try:
+ from SimpleHTTPServer import SimpleHTTPRequestHandler
+except ImportError:
+ from http.server import SimpleHTTPRequestHandler
+
+try:
+ from StringIO import StringIO
+ BytesIO = StringIO
+except ImportError:
+ from io import StringIO
+ from io import BytesIO
+
+
+
+
+def raise_oserror(*args, **kwargs):
+ raise OSError('fake error')
+
+
+class FakeSocket(object):
+ def __init__(self, data=''):
+ if isinstance(data, bytes):
+ self._data = data
+ else:
+ self._data = data.encode('latin_1')
+
+ def recv(self, amt, flags=None):
+ res = self._data[0:amt]
+ if not (flags & socket.MSG_PEEK):
+ self._data = self._data[amt:]
+
+ return res
+
+ def makefile(self, mode='r', buffsize=None):
+ if 'b' in mode:
+ return BytesIO(self._data)
+ else:
+ return StringIO(self._data.decode('latin_1'))
+
+
+class WebSocketRequestHandlerTestCase(unittest.TestCase):
+ def setUp(self):
+ super(WebSocketRequestHandlerTestCase, self).setUp()
+ self.stubs = stubout.StubOutForTesting()
+ self.tmpdir = tempfile.mkdtemp('-websockify-tests')
+ # Mock this out cause it screws tests up
+ self.stubs.Set(os, 'chdir', lambda *args, **kwargs: None)
+ self.stubs.Set(SimpleHTTPRequestHandler, 'send_response',
+ lambda *args, **kwargs: None)
+
+ def tearDown(self):
+ """Called automatically after each test."""
+ self.stubs.UnsetAll()
+ os.rmdir(self.tmpdir)
+ super(WebSocketRequestHandlerTestCase, self).tearDown()
+
+ def _get_server(self, handler_class=websocketserver.WebSocketRequestHandler,
+ **kwargs):
+ web = kwargs.pop('web', self.tmpdir)
+ return websocketserver.WebSocketServer(
+ handler_class, listen_host='localhost',
+ listen_port=80, key=self.tmpdir, web=web,
+ record=self.tmpdir, daemon=False, ssl_only=0, idle_timeout=1,
+ **kwargs)
+
+ def test_normal_get_with_only_upgrade_returns_error(self):
+ server = self._get_server(web=None)
+ handler = websocketserver.WebSocketRequestHandler(
+ FakeSocket('GET /tmp.txt HTTP/1.1'), '127.0.0.1', server)
+
+ def fake_send_response(self, code, message=None):
+ self.last_code = code
+
+ self.stubs.Set(SimpleHTTPRequestHandler, 'send_response',
+ fake_send_response)
+
+ handler.do_GET()
+ self.assertEqual(handler.last_code, 405)
+
+ def test_list_dir_with_file_only_returns_error(self):
+ server = self._get_server(file_only=True)
+ handler = websocketserver.WebSocketRequestHandler(
+ FakeSocket('GET / HTTP/1.1'), '127.0.0.1', server)
+
+ def fake_send_response(self, code, message=None):
+ self.last_code = code
+
+ self.stubs.Set(SimpleHTTPRequestHandler, 'send_response',
+ fake_send_response)
+
+ handler.path = '/'
+ handler.do_GET()
+ self.assertEqual(handler.last_code, 404)
+
+
+class WebSocketServerTestCase(unittest.TestCase):
+ def setUp(self):
+ super(WebSocketServerTestCase, self).setUp()
+ self.stubs = stubout.StubOutForTesting()
+ self.tmpdir = tempfile.mkdtemp('-websockify-tests')
+ # Mock this out cause it screws tests up
+ self.stubs.Set(os, 'chdir', lambda *args, **kwargs: None)
+
+ def tearDown(self):
+ """Called automatically after each test."""
+ self.stubs.UnsetAll()
+ os.rmdir(self.tmpdir)
+ super(WebSocketServerTestCase, self).tearDown()
+
+ def _get_server(self, handler_class=websocketserver.WebSocketRequestHandler,
+ **kwargs):
+ return websocketserver.WebSocketServer(
+ handler_class, listen_host='localhost',
+ listen_port=80, key=self.tmpdir, web=self.tmpdir,
+ record=self.tmpdir, **kwargs)
+
+ def test_daemonize_raises_error_while_closing_fds(self):
+ server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1)
+ self.stubs.Set(os, 'fork', lambda *args: 0)
+ self.stubs.Set(signal, 'signal', lambda *args: None)
+ self.stubs.Set(os, 'setsid', lambda *args: None)
+ self.stubs.Set(os, 'close', raise_oserror)
+ self.assertRaises(OSError, server.daemonize, keepfd=None, chdir='./')
+
+ def test_daemonize_ignores_ebadf_error_while_closing_fds(self):
+ def raise_oserror_ebadf(fd):
+ raise OSError(errno.EBADF, 'fake error')
+
+ server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1)
+ self.stubs.Set(os, 'fork', lambda *args: 0)
+ self.stubs.Set(os, 'setsid', lambda *args: None)
+ self.stubs.Set(signal, 'signal', lambda *args: None)
+ self.stubs.Set(os, 'close', raise_oserror_ebadf)
+ self.stubs.Set(os, 'open', raise_oserror)
+ self.assertRaises(OSError, server.daemonize, keepfd=None, chdir='./')
+
+ def test_handshake_fails_on_not_ready(self):
+ server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1)
+
+ def fake_select(rlist, wlist, xlist, timeout=None):
+ return ([], [], [])
+
+ self.stubs.Set(select, 'select', fake_select)
+ self.assertRaises(
+ websocketserver.WebSocketServer.EClose, server.do_handshake,
+ FakeSocket(), '127.0.0.1')
+
+ def test_empty_handshake_fails(self):
+ server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1)
+
+ sock = FakeSocket('')
+
+ def fake_select(rlist, wlist, xlist, timeout=None):
+ return ([sock], [], [])
+
+ self.stubs.Set(select, 'select', fake_select)
+ self.assertRaises(
+ websocketserver.WebSocketServer.EClose, server.do_handshake,
+ sock, '127.0.0.1')
+
+ def test_handshake_policy_request(self):
+ # TODO(directxman12): implement
+ pass
+
+ def test_handshake_ssl_only_without_ssl_raises_error(self):
+ server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1)
+
+ sock = FakeSocket('some initial data')
+
+ def fake_select(rlist, wlist, xlist, timeout=None):
+ return ([sock], [], [])
+
+ self.stubs.Set(select, 'select', fake_select)
+ self.assertRaises(
+ websocketserver.WebSocketServer.EClose, server.do_handshake,
+ sock, '127.0.0.1')
+
+ def test_do_handshake_no_ssl(self):
+ class FakeHandler(object):
+ CALLED = False
+ def __init__(self, *args, **kwargs):
+ type(self).CALLED = True
+
+ FakeHandler.CALLED = False
+
+ server = self._get_server(
+ handler_class=FakeHandler, daemon=True,
+ ssl_only=0, idle_timeout=1)
+
+ sock = FakeSocket('some initial data')
+
+ def fake_select(rlist, wlist, xlist, timeout=None):
+ return ([sock], [], [])
+
+ self.stubs.Set(select, 'select', fake_select)
+ self.assertEqual(server.do_handshake(sock, '127.0.0.1'), sock)
+ self.assertTrue(FakeHandler.CALLED, True)
+
+ def test_do_handshake_ssl(self):
+ # TODO(directxman12): implement this
+ pass
+
+ def test_do_handshake_ssl_without_ssl_raises_error(self):
+ # TODO(directxman12): implement this
+ pass
+
+ def test_do_handshake_ssl_without_cert_raises_error(self):
+ server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1,
+ cert='afdsfasdafdsafdsafdsafdas')
+
+ sock = FakeSocket("\x16some ssl data")
+
+ def fake_select(rlist, wlist, xlist, timeout=None):
+ return ([sock], [], [])
+
+ self.stubs.Set(select, 'select', fake_select)
+ self.assertRaises(
+ websocketserver.WebSocketServer.EClose, server.do_handshake,
+ sock, '127.0.0.1')
+
+ def test_do_handshake_ssl_error_eof_raises_close_error(self):
+ server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1)
+
+ sock = FakeSocket("\x16some ssl data")
+
+ def fake_select(rlist, wlist, xlist, timeout=None):
+ return ([sock], [], [])
+
+ def fake_wrap_socket(*args, **kwargs):
+ raise ssl.SSLError(ssl.SSL_ERROR_EOF)
+
+ self.stubs.Set(select, 'select', fake_select)
+ self.stubs.Set(ssl, 'wrap_socket', fake_wrap_socket)
+ self.assertRaises(
+ websocketserver.WebSocketServer.EClose, server.do_handshake,
+ sock, '127.0.0.1')
+
+ def test_fallback_sigchld_handler(self):
+ # TODO(directxman12): implement this
+ pass
+
+ def test_start_server_error(self):
+ server = self._get_server(daemon=False, ssl_only=1, idle_timeout=1)
+ sock = server.socket('localhost')
+
+ def fake_select(rlist, wlist, xlist, timeout=None):
+ raise Exception("fake error")
+
+ self.stubs.Set(websocketserver.WebSocketServer, 'socket',
+ lambda *args, **kwargs: sock)
+ self.stubs.Set(websocketserver.WebSocketServer, 'daemonize',
+ lambda *args, **kwargs: None)
+ self.stubs.Set(select, 'select', fake_select)
+ server.start_server()
+
+ def test_start_server_keyboardinterrupt(self):
+ server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1)
+ sock = server.socket('localhost')
+
+ def fake_select(rlist, wlist, xlist, timeout=None):
+ raise KeyboardInterrupt
+
+ self.stubs.Set(websocketserver.WebSocketServer, 'socket',
+ lambda *args, **kwargs: sock)
+ self.stubs.Set(websocketserver.WebSocketServer, 'daemonize',
+ lambda *args, **kwargs: None)
+ self.stubs.Set(select, 'select', fake_select)
+ server.start_server()
+
+ def test_start_server_systemexit(self):
+ server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1)
+ sock = server.socket('localhost')
+
+ def fake_select(rlist, wlist, xlist, timeout=None):
+ sys.exit()
+
+ self.stubs.Set(websocketserver.WebSocketServer, 'socket',
+ lambda *args, **kwargs: sock)
+ self.stubs.Set(websocketserver.WebSocketServer, 'daemonize',
+ lambda *args, **kwargs: None)
+ self.stubs.Set(select, 'select', fake_select)
+ server.start_server()
+
+ def test_socket_set_keepalive_options(self):
+ keepcnt = 12
+ keepidle = 34
+ keepintvl = 56
+
+ server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1)
+ sock = server.socket('localhost',
+ tcp_keepcnt=keepcnt,
+ tcp_keepidle=keepidle,
+ tcp_keepintvl=keepintvl)
+
+ if hasattr(socket, 'TCP_KEEPCNT'):
+ self.assertEqual(sock.getsockopt(socket.SOL_TCP,
+ socket.TCP_KEEPCNT), keepcnt)
+ self.assertEqual(sock.getsockopt(socket.SOL_TCP,
+ socket.TCP_KEEPIDLE), keepidle)
+ self.assertEqual(sock.getsockopt(socket.SOL_TCP,
+ socket.TCP_KEEPINTVL), keepintvl)
+
+ sock = server.socket('localhost',
+ tcp_keepalive=False,
+ tcp_keepcnt=keepcnt,
+ tcp_keepidle=keepidle,
+ tcp_keepintvl=keepintvl)
+
+ if hasattr(socket, 'TCP_KEEPCNT'):
+ self.assertNotEqual(sock.getsockopt(socket.SOL_TCP,
+ socket.TCP_KEEPCNT), keepcnt)
+ self.assertNotEqual(sock.getsockopt(socket.SOL_TCP,
+ socket.TCP_KEEPIDLE), keepidle)
+ self.assertNotEqual(sock.getsockopt(socket.SOL_TCP,
+ socket.TCP_KEEPINTVL), keepintvl)