diff options
author | Pierre Ossman <pierre@ossman.eu> | 2016-09-15 19:51:26 +0200 |
---|---|---|
committer | Pierre Ossman <pierre@ossman.eu> | 2017-02-01 08:22:27 +0100 |
commit | 8a697622495fd319582cd1c604e7eb2cc0ac0ef6 (patch) | |
tree | 9270b1bb631c6559d2c0e9049a0d9b505b4c507c /tests | |
parent | 4099949984eb80ef33c2d0dd216991124975a5d2 (diff) | |
download | websockify-8a697622495fd319582cd1c604e7eb2cc0ac0ef6.tar.gz |
Separate out raw WebSocket protocol handling
Diffstat (limited to 'tests')
-rwxr-xr-x | tests/echo.py | 4 | ||||
-rwxr-xr-x | tests/echo_client.py | 70 | ||||
-rwxr-xr-x | tests/load.py | 14 | ||||
-rw-r--r-- | tests/test_websocket.py | 386 | ||||
-rw-r--r-- | tests/test_websocketproxy.py | 4 | ||||
-rw-r--r-- | tests/test_websocketserver.py | 347 |
6 files changed, 448 insertions, 377 deletions
diff --git a/tests/echo.py b/tests/echo.py index e6a6851..3d81e04 100755 --- a/tests/echo.py +++ b/tests/echo.py @@ -12,7 +12,7 @@ as taken from http://docs.python.org/dev/library/ssl.html#certificates import os, sys, select, optparse, logging sys.path.insert(0,os.path.join(os.path.dirname(__file__), "..")) -from websockify.websocket import WebSocketServer, WebSocketRequestHandler +from websockify.websocketserver import WebSocketServer, WebSocketRequestHandler class WebSocketEcho(WebSocketRequestHandler): """ @@ -48,7 +48,7 @@ class WebSocketEcho(WebSocketRequestHandler): cqueue.extend(frames) if closed: - self.send_close() + break if __name__ == '__main__': parser = optparse.OptionParser(usage="%prog [options] listen_port") diff --git a/tests/echo_client.py b/tests/echo_client.py new file mode 100755 index 0000000..6d745ec --- /dev/null +++ b/tests/echo_client.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python + +import os +import sys +import optparse +import select + +sys.path.insert(0,os.path.join(os.path.dirname(__file__), "..")) +from websockify.websocket import WebSocket, \ + WebSocketWantReadError, WebSocketWantWriteError + +parser = optparse.OptionParser(usage="%prog URL") +(opts, args) = parser.parse_args() + +try: + if len(args) != 1: raise + URL = args[0] +except: + parser.error("Invalid arguments") + +sock = WebSocket() +print("Connecting to %s..." % URL) +sock.connect(URL) +print("Connected.") + +def send(msg): + while True: + try: + sock.sendmsg(msg) + break + except WebSocketWantReadError: + msg = '' + ins, outs, excepts = select.select([sock], [], []) + if excepts: raise Exception("Socket exception") + except WebSocketWantWriteError: + msg = '' + ins, outs, excepts = select.select([], [sock], []) + if excepts: raise Exception("Socket exception") + +def read(): + while True: + try: + return sock.recvmsg() + except WebSocketWantReadError: + ins, outs, excepts = select.select([sock], [], []) + if excepts: raise Exception("Socket exception") + except WebSocketWantWriteError: + ins, outs, excepts = select.select([], [sock], []) + if excepts: raise Exception("Socket exception") + +counter = 1 +while True: + msg = "Message #%d" % counter + counter += 1 + send(msg) + print("Sent message: %r" % msg) + + while True: + ins, outs, excepts = select.select([sock], [], [], 1.0) + if excepts: raise Exception("Socket exception") + + if ins == []: + break + + while True: + msg = read() + print("Received message: %r" % msg) + + if not sock.pending(): + break diff --git a/tests/load.py b/tests/load.py index c76feb1..caf6b58 100755 --- a/tests/load.py +++ b/tests/load.py @@ -8,7 +8,7 @@ given a sequence number. Any errors are reported and counted. import sys, os, select, random, time, optparse, logging sys.path.insert(0,os.path.join(os.path.dirname(__file__), "..")) -from websockify.websocket import WebSocketServer, WebSocketRequestHandler +from websockify.websocketserver import WebSocketServer, WebSocketRequestHandler class WebSocketLoadServer(WebSocketServer): @@ -35,12 +35,10 @@ class WebSocketLoad(WebSocketRequestHandler): self.send_cnt = 0 self.recv_cnt = 0 - try: - self.responder(self.request) - except: - print "accumulated errors:", self.errors - self.errors = 0 - raise + self.responder(self.request) + + print "accumulated errors:", self.errors + self.errors = 0 def responder(self, client): c_pend = 0 @@ -62,7 +60,7 @@ class WebSocketLoad(WebSocketRequestHandler): print err if closed: - self.send_close() + break now = time.time() * 1000 if client in outs: diff --git a/tests/test_websocket.py b/tests/test_websocket.py index 545fa1c..77d0eca 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -15,418 +15,74 @@ # under the License. """ Unit tests for websocket """ -import errno -import os -import logging -import select -import shutil -import socket -import ssl -from mox3 import stubout -import sys -import tempfile import unittest -import socket -import signal from websockify import websocket -try: - from SimpleHTTPServer import SimpleHTTPRequestHandler -except ImportError: - from http.server import SimpleHTTPRequestHandler - -try: - from StringIO import StringIO - BytesIO = StringIO -except ImportError: - from io import StringIO - from io import BytesIO - - - - -def raise_oserror(*args, **kwargs): - raise OSError('fake error') - - -class FakeSocket(object): - def __init__(self, data=''): - if isinstance(data, bytes): - self._data = data - else: - self._data = data.encode('latin_1') - - def recv(self, amt, flags=None): - res = self._data[0:amt] - if not (flags & socket.MSG_PEEK): - self._data = self._data[amt:] - - return res - - def makefile(self, mode='r', buffsize=None): - if 'b' in mode: - return BytesIO(self._data) - else: - return StringIO(self._data.decode('latin_1')) - - -class WebSocketRequestHandlerTestCase(unittest.TestCase): - def setUp(self): - super(WebSocketRequestHandlerTestCase, self).setUp() - self.stubs = stubout.StubOutForTesting() - self.tmpdir = tempfile.mkdtemp('-websockify-tests') - # Mock this out cause it screws tests up - self.stubs.Set(os, 'chdir', lambda *args, **kwargs: None) - self.stubs.Set(SimpleHTTPRequestHandler, 'send_response', - lambda *args, **kwargs: None) - - def tearDown(self): - """Called automatically after each test.""" - self.stubs.UnsetAll() - os.rmdir(self.tmpdir) - super(WebSocketRequestHandlerTestCase, self).tearDown() - - def _get_server(self, handler_class=websocket.WebSocketRequestHandler, - **kwargs): - web = kwargs.pop('web', self.tmpdir) - return websocket.WebSocketServer( - handler_class, listen_host='localhost', - listen_port=80, key=self.tmpdir, web=web, - record=self.tmpdir, daemon=False, ssl_only=0, idle_timeout=1, - **kwargs) - - def test_normal_get_with_only_upgrade_returns_error(self): - server = self._get_server(web=None) - handler = websocket.WebSocketRequestHandler( - FakeSocket('GET /tmp.txt HTTP/1.1'), '127.0.0.1', server) - - def fake_send_response(self, code, message=None): - self.last_code = code - - self.stubs.Set(SimpleHTTPRequestHandler, 'send_response', - fake_send_response) - - handler.do_GET() - self.assertEqual(handler.last_code, 405) - - def test_list_dir_with_file_only_returns_error(self): - server = self._get_server(file_only=True) - handler = websocket.WebSocketRequestHandler( - FakeSocket('GET / HTTP/1.1'), '127.0.0.1', server) - - def fake_send_response(self, code, message=None): - self.last_code = code - - self.stubs.Set(SimpleHTTPRequestHandler, 'send_response', - fake_send_response) - - handler.path = '/' - handler.do_GET() - self.assertEqual(handler.last_code, 404) - - -class WebSocketServerTestCase(unittest.TestCase): - def setUp(self): - super(WebSocketServerTestCase, self).setUp() - self.stubs = stubout.StubOutForTesting() - self.tmpdir = tempfile.mkdtemp('-websockify-tests') - # Mock this out cause it screws tests up - self.stubs.Set(os, 'chdir', lambda *args, **kwargs: None) - - def tearDown(self): - """Called automatically after each test.""" - self.stubs.UnsetAll() - os.rmdir(self.tmpdir) - super(WebSocketServerTestCase, self).tearDown() - - def _get_server(self, handler_class=websocket.WebSocketRequestHandler, - **kwargs): - return websocket.WebSocketServer( - handler_class, listen_host='localhost', - listen_port=80, key=self.tmpdir, web=self.tmpdir, - record=self.tmpdir, **kwargs) - - def test_daemonize_raises_error_while_closing_fds(self): - server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1) - self.stubs.Set(os, 'fork', lambda *args: 0) - self.stubs.Set(signal, 'signal', lambda *args: None) - self.stubs.Set(os, 'setsid', lambda *args: None) - self.stubs.Set(os, 'close', raise_oserror) - self.assertRaises(OSError, server.daemonize, keepfd=None, chdir='./') - - def test_daemonize_ignores_ebadf_error_while_closing_fds(self): - def raise_oserror_ebadf(fd): - raise OSError(errno.EBADF, 'fake error') - - server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1) - self.stubs.Set(os, 'fork', lambda *args: 0) - self.stubs.Set(os, 'setsid', lambda *args: None) - self.stubs.Set(signal, 'signal', lambda *args: None) - self.stubs.Set(os, 'close', raise_oserror_ebadf) - self.stubs.Set(os, 'open', raise_oserror) - self.assertRaises(OSError, server.daemonize, keepfd=None, chdir='./') - - def test_handshake_fails_on_not_ready(self): - server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1) - - def fake_select(rlist, wlist, xlist, timeout=None): - return ([], [], []) - - self.stubs.Set(select, 'select', fake_select) - self.assertRaises( - websocket.WebSocketServer.EClose, server.do_handshake, - FakeSocket(), '127.0.0.1') - - def test_empty_handshake_fails(self): - server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1) - - sock = FakeSocket('') - - def fake_select(rlist, wlist, xlist, timeout=None): - return ([sock], [], []) - - self.stubs.Set(select, 'select', fake_select) - self.assertRaises( - websocket.WebSocketServer.EClose, server.do_handshake, - sock, '127.0.0.1') - - def test_handshake_policy_request(self): - # TODO(directxman12): implement - pass - - def test_handshake_ssl_only_without_ssl_raises_error(self): - server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1) - - sock = FakeSocket('some initial data') - - def fake_select(rlist, wlist, xlist, timeout=None): - return ([sock], [], []) - - self.stubs.Set(select, 'select', fake_select) - self.assertRaises( - websocket.WebSocketServer.EClose, server.do_handshake, - sock, '127.0.0.1') - - def test_do_handshake_no_ssl(self): - class FakeHandler(object): - CALLED = False - def __init__(self, *args, **kwargs): - type(self).CALLED = True - - FakeHandler.CALLED = False - - server = self._get_server( - handler_class=FakeHandler, daemon=True, - ssl_only=0, idle_timeout=1) - - sock = FakeSocket('some initial data') - - def fake_select(rlist, wlist, xlist, timeout=None): - return ([sock], [], []) - - self.stubs.Set(select, 'select', fake_select) - self.assertEqual(server.do_handshake(sock, '127.0.0.1'), sock) - self.assertTrue(FakeHandler.CALLED, True) - - def test_do_handshake_ssl(self): - # TODO(directxman12): implement this - pass - - def test_do_handshake_ssl_without_ssl_raises_error(self): - # TODO(directxman12): implement this - pass - - def test_do_handshake_ssl_without_cert_raises_error(self): - server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1, - cert='afdsfasdafdsafdsafdsafdas') - - sock = FakeSocket("\x16some ssl data") - - def fake_select(rlist, wlist, xlist, timeout=None): - return ([sock], [], []) - - self.stubs.Set(select, 'select', fake_select) - self.assertRaises( - websocket.WebSocketServer.EClose, server.do_handshake, - sock, '127.0.0.1') - - def test_do_handshake_ssl_error_eof_raises_close_error(self): - server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1) - - sock = FakeSocket("\x16some ssl data") - - def fake_select(rlist, wlist, xlist, timeout=None): - return ([sock], [], []) - - def fake_wrap_socket(*args, **kwargs): - raise ssl.SSLError(ssl.SSL_ERROR_EOF) - - self.stubs.Set(select, 'select', fake_select) - self.stubs.Set(ssl, 'wrap_socket', fake_wrap_socket) - self.assertRaises( - websocket.WebSocketServer.EClose, server.do_handshake, - sock, '127.0.0.1') - - def test_fallback_sigchld_handler(self): - # TODO(directxman12): implement this - pass - - def test_start_server_error(self): - server = self._get_server(daemon=False, ssl_only=1, idle_timeout=1) - sock = server.socket('localhost') - - def fake_select(rlist, wlist, xlist, timeout=None): - raise Exception("fake error") - - self.stubs.Set(websocket.WebSocketServer, 'socket', - lambda *args, **kwargs: sock) - self.stubs.Set(websocket.WebSocketServer, 'daemonize', - lambda *args, **kwargs: None) - self.stubs.Set(select, 'select', fake_select) - server.start_server() - - def test_start_server_keyboardinterrupt(self): - server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1) - sock = server.socket('localhost') - - def fake_select(rlist, wlist, xlist, timeout=None): - raise KeyboardInterrupt - - self.stubs.Set(websocket.WebSocketServer, 'socket', - lambda *args, **kwargs: sock) - self.stubs.Set(websocket.WebSocketServer, 'daemonize', - lambda *args, **kwargs: None) - self.stubs.Set(select, 'select', fake_select) - server.start_server() - - def test_start_server_systemexit(self): - server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1) - sock = server.socket('localhost') - - def fake_select(rlist, wlist, xlist, timeout=None): - sys.exit() - - self.stubs.Set(websocket.WebSocketServer, 'socket', - lambda *args, **kwargs: sock) - self.stubs.Set(websocket.WebSocketServer, 'daemonize', - lambda *args, **kwargs: None) - self.stubs.Set(select, 'select', fake_select) - server.start_server() - - def test_socket_set_keepalive_options(self): - keepcnt = 12 - keepidle = 34 - keepintvl = 56 - - server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1) - sock = server.socket('localhost', - tcp_keepcnt=keepcnt, - tcp_keepidle=keepidle, - tcp_keepintvl=keepintvl) - - if hasattr(socket, 'TCP_KEEPCNT'): - self.assertEqual(sock.getsockopt(socket.SOL_TCP, - socket.TCP_KEEPCNT), keepcnt) - self.assertEqual(sock.getsockopt(socket.SOL_TCP, - socket.TCP_KEEPIDLE), keepidle) - self.assertEqual(sock.getsockopt(socket.SOL_TCP, - socket.TCP_KEEPINTVL), keepintvl) - - sock = server.socket('localhost', - tcp_keepalive=False, - tcp_keepcnt=keepcnt, - tcp_keepidle=keepidle, - tcp_keepintvl=keepintvl) - - if hasattr(socket, 'TCP_KEEPCNT'): - self.assertNotEqual(sock.getsockopt(socket.SOL_TCP, - socket.TCP_KEEPCNT), keepcnt) - self.assertNotEqual(sock.getsockopt(socket.SOL_TCP, - socket.TCP_KEEPIDLE), keepidle) - self.assertNotEqual(sock.getsockopt(socket.SOL_TCP, - socket.TCP_KEEPINTVL), keepintvl) - - class HyBiEncodeDecodeTestCase(unittest.TestCase): def test_decode_hybi_text(self): buf = b'\x81\x85\x37\xfa\x21\x3d\x7f\x9f\x4d\x51\x58' - res = websocket.WebSocketRequestHandler.decode_hybi(buf) + ws = websocket.WebSocket() + res = ws._decode_hybi(buf) self.assertEqual(res['fin'], 1) self.assertEqual(res['opcode'], 0x1) self.assertEqual(res['masked'], True) - self.assertEqual(res['length'], 5) + self.assertEqual(res['length'], len(buf)) self.assertEqual(res['payload'], b'Hello') - self.assertEqual(res['left'], 0) def test_decode_hybi_binary(self): buf = b'\x82\x04\x01\x02\x03\x04' - res = websocket.WebSocketRequestHandler.decode_hybi(buf, strict=False) + ws = websocket.WebSocket() + res = ws._decode_hybi(buf) self.assertEqual(res['fin'], 1) self.assertEqual(res['opcode'], 0x2) - self.assertEqual(res['length'], 4) + self.assertEqual(res['length'], len(buf)) self.assertEqual(res['payload'], b'\x01\x02\x03\x04') - self.assertEqual(res['left'], 0) def test_decode_hybi_extended_16bit_binary(self): data = (b'\x01\x02\x03\x04' * 65) # len > 126 -- len == 260 buf = b'\x82\x7e\x01\x04' + data - res = websocket.WebSocketRequestHandler.decode_hybi(buf, strict=False) + ws = websocket.WebSocket() + res = ws._decode_hybi(buf) self.assertEqual(res['fin'], 1) self.assertEqual(res['opcode'], 0x2) - self.assertEqual(res['length'], 260) + self.assertEqual(res['length'], len(buf)) self.assertEqual(res['payload'], data) - self.assertEqual(res['left'], 0) def test_decode_hybi_extended_64bit_binary(self): data = (b'\x01\x02\x03\x04' * 65) # len > 126 -- len == 260 buf = b'\x82\x7f\x00\x00\x00\x00\x00\x00\x01\x04' + data - res = websocket.WebSocketRequestHandler.decode_hybi(buf, strict=False) + ws = websocket.WebSocket() + res = ws._decode_hybi(buf) self.assertEqual(res['fin'], 1) self.assertEqual(res['opcode'], 0x2) - self.assertEqual(res['length'], 260) + self.assertEqual(res['length'], len(buf)) self.assertEqual(res['payload'], data) - self.assertEqual(res['left'], 0) def test_decode_hybi_multi(self): buf1 = b'\x01\x03\x48\x65\x6c' buf2 = b'\x80\x02\x6c\x6f' - res1 = websocket.WebSocketRequestHandler.decode_hybi(buf1, strict=False) + ws = websocket.WebSocket() + + res1 = ws._decode_hybi(buf1) self.assertEqual(res1['fin'], 0) self.assertEqual(res1['opcode'], 0x1) - self.assertEqual(res1['length'], 3) + self.assertEqual(res1['length'], len(buf1)) self.assertEqual(res1['payload'], b'Hel') - self.assertEqual(res1['left'], 0) - res2 = websocket.WebSocketRequestHandler.decode_hybi(buf2, strict=False) + res2 = ws._decode_hybi(buf2) self.assertEqual(res2['fin'], 1) self.assertEqual(res2['opcode'], 0x0) - self.assertEqual(res2['length'], 2) + self.assertEqual(res2['length'], len(buf2)) self.assertEqual(res2['payload'], b'lo') - self.assertEqual(res2['left'], 0) def test_encode_hybi_basic(self): - res = websocket.WebSocketRequestHandler.encode_hybi(b'Hello', 0x1) - expected = (b'\x81\x05\x48\x65\x6c\x6c\x6f', 2, 0) + ws = websocket.WebSocket() + res = ws._encode_hybi(0x1, b'Hello') + expected = b'\x81\x05\x48\x65\x6c\x6c\x6f' self.assertEqual(res, expected) - - def test_strict_mode_refuses_unmasked_client_frames(self): - buf = b'\x81\x05\x48\x65\x6c\x6c\x6f' - self.assertRaises(websocket.WebSocketRequestHandler.CClose, - websocket.WebSocketRequestHandler.decode_hybi, - buf) - - def test_no_strict_mode_accepts_unmasked_client_frames(self): - buf = b'\x81\x05\x48\x65\x6c\x6c\x6f' - res = websocket.WebSocketRequestHandler.decode_hybi(buf, strict=False) - - self.assertEqual(res['fin'], 1) - self.assertEqual(res['opcode'], 0x1) - self.assertEqual(res['masked'], False) - self.assertEqual(res['length'], 5) - self.assertEqual(res['payload'], b'Hello') diff --git a/tests/test_websocketproxy.py b/tests/test_websocketproxy.py index b48796e..ac08dfa 100644 --- a/tests/test_websocketproxy.py +++ b/tests/test_websocketproxy.py @@ -22,7 +22,7 @@ import socket from mox3 import stubout -from websockify import websocket +from websockify import websocketserver from websockify import websocketproxy from websockify import token_plugins from websockify import auth_plugins @@ -75,7 +75,7 @@ class ProxyRequestHandlerTestCase(unittest.TestCase): FakeSocket(''), "127.0.0.1", FakeServer()) self.handler.path = "https://localhost:6080/websockify?token=blah" self.handler.headers = None - self.stubs.Set(websocket.WebSocketServer, 'socket', + self.stubs.Set(websocketserver.WebSocketServer, 'socket', staticmethod(lambda *args, **kwargs: None)) def tearDown(self): diff --git a/tests/test_websocketserver.py b/tests/test_websocketserver.py new file mode 100644 index 0000000..aaeeee6 --- /dev/null +++ b/tests/test_websocketserver.py @@ -0,0 +1,347 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright(c)2013 NTT corp. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" Unit tests for websocketserver """ +import errno +import os +import logging +import select +import shutil +import socket +import ssl +from mox3 import stubout +import sys +import tempfile +import unittest +import socket +import signal +from websockify import websocketserver + +try: + from SimpleHTTPServer import SimpleHTTPRequestHandler +except ImportError: + from http.server import SimpleHTTPRequestHandler + +try: + from StringIO import StringIO + BytesIO = StringIO +except ImportError: + from io import StringIO + from io import BytesIO + + + + +def raise_oserror(*args, **kwargs): + raise OSError('fake error') + + +class FakeSocket(object): + def __init__(self, data=''): + if isinstance(data, bytes): + self._data = data + else: + self._data = data.encode('latin_1') + + def recv(self, amt, flags=None): + res = self._data[0:amt] + if not (flags & socket.MSG_PEEK): + self._data = self._data[amt:] + + return res + + def makefile(self, mode='r', buffsize=None): + if 'b' in mode: + return BytesIO(self._data) + else: + return StringIO(self._data.decode('latin_1')) + + +class WebSocketRequestHandlerTestCase(unittest.TestCase): + def setUp(self): + super(WebSocketRequestHandlerTestCase, self).setUp() + self.stubs = stubout.StubOutForTesting() + self.tmpdir = tempfile.mkdtemp('-websockify-tests') + # Mock this out cause it screws tests up + self.stubs.Set(os, 'chdir', lambda *args, **kwargs: None) + self.stubs.Set(SimpleHTTPRequestHandler, 'send_response', + lambda *args, **kwargs: None) + + def tearDown(self): + """Called automatically after each test.""" + self.stubs.UnsetAll() + os.rmdir(self.tmpdir) + super(WebSocketRequestHandlerTestCase, self).tearDown() + + def _get_server(self, handler_class=websocketserver.WebSocketRequestHandler, + **kwargs): + web = kwargs.pop('web', self.tmpdir) + return websocketserver.WebSocketServer( + handler_class, listen_host='localhost', + listen_port=80, key=self.tmpdir, web=web, + record=self.tmpdir, daemon=False, ssl_only=0, idle_timeout=1, + **kwargs) + + def test_normal_get_with_only_upgrade_returns_error(self): + server = self._get_server(web=None) + handler = websocketserver.WebSocketRequestHandler( + FakeSocket('GET /tmp.txt HTTP/1.1'), '127.0.0.1', server) + + def fake_send_response(self, code, message=None): + self.last_code = code + + self.stubs.Set(SimpleHTTPRequestHandler, 'send_response', + fake_send_response) + + handler.do_GET() + self.assertEqual(handler.last_code, 405) + + def test_list_dir_with_file_only_returns_error(self): + server = self._get_server(file_only=True) + handler = websocketserver.WebSocketRequestHandler( + FakeSocket('GET / HTTP/1.1'), '127.0.0.1', server) + + def fake_send_response(self, code, message=None): + self.last_code = code + + self.stubs.Set(SimpleHTTPRequestHandler, 'send_response', + fake_send_response) + + handler.path = '/' + handler.do_GET() + self.assertEqual(handler.last_code, 404) + + +class WebSocketServerTestCase(unittest.TestCase): + def setUp(self): + super(WebSocketServerTestCase, self).setUp() + self.stubs = stubout.StubOutForTesting() + self.tmpdir = tempfile.mkdtemp('-websockify-tests') + # Mock this out cause it screws tests up + self.stubs.Set(os, 'chdir', lambda *args, **kwargs: None) + + def tearDown(self): + """Called automatically after each test.""" + self.stubs.UnsetAll() + os.rmdir(self.tmpdir) + super(WebSocketServerTestCase, self).tearDown() + + def _get_server(self, handler_class=websocketserver.WebSocketRequestHandler, + **kwargs): + return websocketserver.WebSocketServer( + handler_class, listen_host='localhost', + listen_port=80, key=self.tmpdir, web=self.tmpdir, + record=self.tmpdir, **kwargs) + + def test_daemonize_raises_error_while_closing_fds(self): + server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1) + self.stubs.Set(os, 'fork', lambda *args: 0) + self.stubs.Set(signal, 'signal', lambda *args: None) + self.stubs.Set(os, 'setsid', lambda *args: None) + self.stubs.Set(os, 'close', raise_oserror) + self.assertRaises(OSError, server.daemonize, keepfd=None, chdir='./') + + def test_daemonize_ignores_ebadf_error_while_closing_fds(self): + def raise_oserror_ebadf(fd): + raise OSError(errno.EBADF, 'fake error') + + server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1) + self.stubs.Set(os, 'fork', lambda *args: 0) + self.stubs.Set(os, 'setsid', lambda *args: None) + self.stubs.Set(signal, 'signal', lambda *args: None) + self.stubs.Set(os, 'close', raise_oserror_ebadf) + self.stubs.Set(os, 'open', raise_oserror) + self.assertRaises(OSError, server.daemonize, keepfd=None, chdir='./') + + def test_handshake_fails_on_not_ready(self): + server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1) + + def fake_select(rlist, wlist, xlist, timeout=None): + return ([], [], []) + + self.stubs.Set(select, 'select', fake_select) + self.assertRaises( + websocketserver.WebSocketServer.EClose, server.do_handshake, + FakeSocket(), '127.0.0.1') + + def test_empty_handshake_fails(self): + server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1) + + sock = FakeSocket('') + + def fake_select(rlist, wlist, xlist, timeout=None): + return ([sock], [], []) + + self.stubs.Set(select, 'select', fake_select) + self.assertRaises( + websocketserver.WebSocketServer.EClose, server.do_handshake, + sock, '127.0.0.1') + + def test_handshake_policy_request(self): + # TODO(directxman12): implement + pass + + def test_handshake_ssl_only_without_ssl_raises_error(self): + server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1) + + sock = FakeSocket('some initial data') + + def fake_select(rlist, wlist, xlist, timeout=None): + return ([sock], [], []) + + self.stubs.Set(select, 'select', fake_select) + self.assertRaises( + websocketserver.WebSocketServer.EClose, server.do_handshake, + sock, '127.0.0.1') + + def test_do_handshake_no_ssl(self): + class FakeHandler(object): + CALLED = False + def __init__(self, *args, **kwargs): + type(self).CALLED = True + + FakeHandler.CALLED = False + + server = self._get_server( + handler_class=FakeHandler, daemon=True, + ssl_only=0, idle_timeout=1) + + sock = FakeSocket('some initial data') + + def fake_select(rlist, wlist, xlist, timeout=None): + return ([sock], [], []) + + self.stubs.Set(select, 'select', fake_select) + self.assertEqual(server.do_handshake(sock, '127.0.0.1'), sock) + self.assertTrue(FakeHandler.CALLED, True) + + def test_do_handshake_ssl(self): + # TODO(directxman12): implement this + pass + + def test_do_handshake_ssl_without_ssl_raises_error(self): + # TODO(directxman12): implement this + pass + + def test_do_handshake_ssl_without_cert_raises_error(self): + server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1, + cert='afdsfasdafdsafdsafdsafdas') + + sock = FakeSocket("\x16some ssl data") + + def fake_select(rlist, wlist, xlist, timeout=None): + return ([sock], [], []) + + self.stubs.Set(select, 'select', fake_select) + self.assertRaises( + websocketserver.WebSocketServer.EClose, server.do_handshake, + sock, '127.0.0.1') + + def test_do_handshake_ssl_error_eof_raises_close_error(self): + server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1) + + sock = FakeSocket("\x16some ssl data") + + def fake_select(rlist, wlist, xlist, timeout=None): + return ([sock], [], []) + + def fake_wrap_socket(*args, **kwargs): + raise ssl.SSLError(ssl.SSL_ERROR_EOF) + + self.stubs.Set(select, 'select', fake_select) + self.stubs.Set(ssl, 'wrap_socket', fake_wrap_socket) + self.assertRaises( + websocketserver.WebSocketServer.EClose, server.do_handshake, + sock, '127.0.0.1') + + def test_fallback_sigchld_handler(self): + # TODO(directxman12): implement this + pass + + def test_start_server_error(self): + server = self._get_server(daemon=False, ssl_only=1, idle_timeout=1) + sock = server.socket('localhost') + + def fake_select(rlist, wlist, xlist, timeout=None): + raise Exception("fake error") + + self.stubs.Set(websocketserver.WebSocketServer, 'socket', + lambda *args, **kwargs: sock) + self.stubs.Set(websocketserver.WebSocketServer, 'daemonize', + lambda *args, **kwargs: None) + self.stubs.Set(select, 'select', fake_select) + server.start_server() + + def test_start_server_keyboardinterrupt(self): + server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1) + sock = server.socket('localhost') + + def fake_select(rlist, wlist, xlist, timeout=None): + raise KeyboardInterrupt + + self.stubs.Set(websocketserver.WebSocketServer, 'socket', + lambda *args, **kwargs: sock) + self.stubs.Set(websocketserver.WebSocketServer, 'daemonize', + lambda *args, **kwargs: None) + self.stubs.Set(select, 'select', fake_select) + server.start_server() + + def test_start_server_systemexit(self): + server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1) + sock = server.socket('localhost') + + def fake_select(rlist, wlist, xlist, timeout=None): + sys.exit() + + self.stubs.Set(websocketserver.WebSocketServer, 'socket', + lambda *args, **kwargs: sock) + self.stubs.Set(websocketserver.WebSocketServer, 'daemonize', + lambda *args, **kwargs: None) + self.stubs.Set(select, 'select', fake_select) + server.start_server() + + def test_socket_set_keepalive_options(self): + keepcnt = 12 + keepidle = 34 + keepintvl = 56 + + server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1) + sock = server.socket('localhost', + tcp_keepcnt=keepcnt, + tcp_keepidle=keepidle, + tcp_keepintvl=keepintvl) + + if hasattr(socket, 'TCP_KEEPCNT'): + self.assertEqual(sock.getsockopt(socket.SOL_TCP, + socket.TCP_KEEPCNT), keepcnt) + self.assertEqual(sock.getsockopt(socket.SOL_TCP, + socket.TCP_KEEPIDLE), keepidle) + self.assertEqual(sock.getsockopt(socket.SOL_TCP, + socket.TCP_KEEPINTVL), keepintvl) + + sock = server.socket('localhost', + tcp_keepalive=False, + tcp_keepcnt=keepcnt, + tcp_keepidle=keepidle, + tcp_keepintvl=keepintvl) + + if hasattr(socket, 'TCP_KEEPCNT'): + self.assertNotEqual(sock.getsockopt(socket.SOL_TCP, + socket.TCP_KEEPCNT), keepcnt) + self.assertNotEqual(sock.getsockopt(socket.SOL_TCP, + socket.TCP_KEEPIDLE), keepidle) + self.assertNotEqual(sock.getsockopt(socket.SOL_TCP, + socket.TCP_KEEPINTVL), keepintvl) |