diff options
Diffstat (limited to 'tests/test_websocket.py')
-rw-r--r-- | tests/test_websocket.py | 260 |
1 files changed, 242 insertions, 18 deletions
diff --git a/tests/test_websocket.py b/tests/test_websocket.py index c603189..49efe81 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -14,11 +14,27 @@ # License for the specific language governing permissions and limitations # under the License. -"""Unit tests for websockify.""" - +""" Unit tests for websocket """ +import errno +import os +import select import socket +import ssl +import stubout +import sys +import tempfile import unittest +from ssl import SSLError from websockify import websocket as websocket +from SimpleHTTPServer import SimpleHTTPRequestHandler + + +class MockConnection(object): + def __init__(self, path): + self.path = path + + def makefile(self, mode='r', bufsize=-1): + return open(self.path, mode, bufsize) class WebSocketTestCase(unittest.TestCase): @@ -26,27 +42,235 @@ class WebSocketTestCase(unittest.TestCase): def setUp(self): """Called automatically before each test.""" super(WebSocketTestCase, self).setUp() + self.stubs = stubout.StubOutForTesting() + self.server = websocket.WebSocketServer(listen_host='localhost', + listen_port=80, + key='./', + web='./', + record='./', + daemon=True, + ssl_only=False) + self.soc = self.server.socket('localhost') def tearDown(self): """Called automatically after each test.""" + self.stubs.UnsetAll() super(WebSocketTestCase, self).tearDown() + def _mock_os_open_oserror(self, file, flags): + raise OSError('') + + def _mock_os_close_oserror(self, fd): + raise OSError('') + + def _mock_os_close_oserror_EBADF(self, fd): + raise OSError(errno.EBADF, '') + + def _mock_socket(self, *args, **kwargs): + return self.soc + + def _mock_select(self, rlist, wlist, xlist, timeout=None): + return '_mock_select' + + def _mock_select_exception(self, rlist, wlist, xlist, timeout=None): + raise Exception + + def _mock_select_keyboardinterrupt(self, rlist, wlist, + xlist, timeout=None): + raise KeyboardInterrupt + + def _mock_select_systemexit(self, rlist, wlist, xlist, timeout=None): + sys.exit() + + def test_daemonize_error(self): + soc = websocket.WebSocketServer(listen_host='localhost', + listen_port=80, + key='../', + web='../', + record='../', + daemon=True, + ssl_only=1, + idle_timeout=1) + self.stubs.Set(os, 'fork', lambda *args: None) + self.stubs.Set(os, 'setsid', lambda *args: None) + self.stubs.Set(os, 'close', self._mock_os_close_oserror) + self.assertRaises(OSError, soc.daemonize, keepfd=None, chdir='./') + + def test_daemonize_EBADF_error(self): + soc = websocket.WebSocketServer(listen_host='localhost', + listen_port=80, + key='../', + web='../', + record='../', + daemon=True, + ssl_only=1, + idle_timeout=1) + self.stubs.Set(os, 'fork', lambda *args: None) + self.stubs.Set(os, 'setsid', lambda *args: None) + self.stubs.Set(os, 'close', self._mock_os_close_oserror_EBADF) + self.stubs.Set(os, 'open', self._mock_os_open_oserror) + self.assertRaises(OSError, soc.daemonize, keepfd=None, chdir='./') + + def test_decode_hybi(self): + soc = websocket.WebSocketServer(listen_host='localhost', + listen_port=80, + key='../', + web='../', + record='../', + daemon=False, + ssl_only=1, + idle_timeout=1) + + self.assertRaises(Exception, soc.decode_hybi, 'a' * 128, + base64=True) + + def test_do_websocket_handshake(self): + soc = websocket.WebSocketServer(listen_host='localhost', + listen_port=80, + key='../', + web='../', + record='../', + daemon=True, + ssl_only=0, + idle_timeout=1) + soc.scheme = 'scheme' + headers = {'Sec-WebSocket-Protocol': 'binary', + 'Sec-WebSocket-Version': '7', + 'Sec-WebSocket-Key': 'foo'} + soc.do_websocket_handshake(headers, '127.0.0.1') + + def test_do_handshake(self): + soc = websocket.WebSocketServer(listen_host='localhost', + listen_port=80, + key='../', + web='../', + record='../', + daemon=True, + ssl_only=0, + idle_timeout=1) + self.stubs.Set(select, 'select', self._mock_select) + self.stubs.Set(socket._socketobject, 'recv', lambda *args: 'mock_recv') + self.assertRaises(Exception, soc.do_handshake, self.soc, '127.0.0.1') + + def test_do_handshake_ssl_error(self): + soc = websocket.WebSocketServer(listen_host='localhost', + listen_port=80, + key='../', + web='../', + record='../', + daemon=True, + ssl_only=0, + idle_timeout=1) + + def _mock_wrap_socket(*args, **kwargs): + from ssl import SSLError + raise SSLError('unit test exception') + + self.stubs.Set(select, 'select', self._mock_select) + self.stubs.Set(socket._socketobject, 'recv', lambda *args: '\x16') + self.stubs.Set(ssl, 'wrap_socket', _mock_wrap_socket) + self.assertRaises(SSLError, soc.do_handshake, self.soc, '127.0.0.1') + + def test_fallback_SIGCHILD(self): + soc = websocket.WebSocketServer(listen_host='localhost', + listen_port=80, + key='../', + web='../', + record='../', + daemon=True, + ssl_only=0, + idle_timeout=1) + soc.fallback_SIGCHLD(None, None) + + def test_start_server_Exception(self): + soc = websocket.WebSocketServer(listen_host='localhost', + listen_port=80, + key='../', + web='../', + record='../', + daemon=False, + ssl_only=1, + idle_timeout=1) + self.stubs.Set(websocket.WebSocketServer, 'socket', self._mock_socket) + self.stubs.Set(websocket.WebSocketServer, 'daemonize', + lambda *args, **kwargs: None) + self.stubs.Set(select, 'select', self._mock_select_exception) + self.assertEqual(None, soc.start_server()) + + def test_start_server_KeyboardInterrupt(self): + soc = websocket.WebSocketServer(listen_host='localhost', + listen_port=80, + key='../', + web='../', + record='../', + cert='xxxxxx', + daemon=True, + ssl_only=1, + idle_timeout=1) + self.stubs.Set(websocket.WebSocketServer, 'socket', self._mock_socket) + self.stubs.Set(websocket.WebSocketServer, 'daemonize', + lambda *args, **kwargs: None) + self.stubs.Set(select, 'select', self._mock_select_keyboardinterrupt) + self.assertEqual(None, soc.start_server()) + + def test_start_server_systemexit(self): + websocket.ssl = None + soc = websocket.WebSocketServer(listen_host='localhost', + listen_port=80, + key='../', + web='../', + record='../', + daemon=True, + ssl_only=0, + idle_timeout=1, + verbose=True) + self.stubs.Set(websocket.WebSocketServer, 'socket', self._mock_socket) + self.stubs.Set(websocket.WebSocketServer, 'daemonize', + lambda *args, **kwargs: None) + self.stubs.Set(select, 'select', self._mock_select_systemexit) + self.assertEqual(None, soc.start_server()) + + def test_WSRequestHandle_do_GET_nofile(self): + request = 'GET /tmp.txt HTTP/0.9' + with tempfile.NamedTemporaryFile() as test_file: + test_file.write(request) + test_file.flush() + test_file.seek(0) + con = MockConnection(test_file.name) + soc = websocket.WSRequestHandler(con, "127.0.0.1", file_only=True) + soc.path = '' + soc.headers = {'upgrade': ''} + self.stubs.Set(SimpleHTTPRequestHandler, 'send_response', + lambda *args: None) + soc.do_GET() + self.assertEqual(404, soc.last_code) + + def test_WSRequestHandle_do_GET_hidden_resource(self): + request = 'GET /tmp.txt HTTP/0.9' + with tempfile.NamedTemporaryFile() as test_file: + test_file.write(request) + test_file.flush() + test_file.seek(0) + con = MockConnection(test_file.name) + soc = websocket.WSRequestHandler(con, '127.0.0.1', no_parent=True) + soc.path = test_file.name + '?' + soc.headers = {'upgrade': ''} + soc.webroot = 'no match startswith' + self.stubs.Set(SimpleHTTPRequestHandler, + 'send_response', + lambda *args: None) + soc.do_GET() + self.assertEqual(403, soc.last_code) + def testsocket_set_keepalive_options(self): - server = websocket.WebSocketServer(listen_host='localhost', - listen_port=80, - key='./', - web='./', - record='./', - daemon=True, - ssl_only=1) keepcnt = 12 keepidle = 34 keepintvl = 56 - sock = server.socket('localhost', - tcp_keepcnt=keepcnt, - tcp_keepidle=keepidle, - tcp_keepintvl=keepintvl) + sock = self.server.socket('localhost', + tcp_keepcnt=keepcnt, + tcp_keepidle=keepidle, + tcp_keepintvl=keepintvl) self.assertEqual(sock.getsockopt(socket.SOL_TCP, socket.TCP_KEEPCNT), keepcnt) @@ -55,11 +279,11 @@ class WebSocketTestCase(unittest.TestCase): self.assertEqual(sock.getsockopt(socket.SOL_TCP, socket.TCP_KEEPINTVL), keepintvl) - sock = server.socket('localhost', - tcp_keepalive=False, - tcp_keepcnt=keepcnt, - tcp_keepidle=keepidle, - tcp_keepintvl=keepintvl) + sock = self.server.socket('localhost', + tcp_keepalive=False, + tcp_keepcnt=keepcnt, + tcp_keepidle=keepidle, + tcp_keepintvl=keepintvl) self.assertNotEqual(sock.getsockopt(socket.SOL_TCP, socket.TCP_KEEPCNT), keepcnt) |