summaryrefslogtreecommitdiff
path: root/tests/test_websocket.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_websocket.py')
-rw-r--r--tests/test_websocket.py260
1 files changed, 242 insertions, 18 deletions
diff --git a/tests/test_websocket.py b/tests/test_websocket.py
index c603189..49efe81 100644
--- a/tests/test_websocket.py
+++ b/tests/test_websocket.py
@@ -14,11 +14,27 @@
# License for the specific language governing permissions and limitations
# under the License.
-"""Unit tests for websockify."""
-
+""" Unit tests for websocket """
+import errno
+import os
+import select
import socket
+import ssl
+import stubout
+import sys
+import tempfile
import unittest
+from ssl import SSLError
from websockify import websocket as websocket
+from SimpleHTTPServer import SimpleHTTPRequestHandler
+
+
+class MockConnection(object):
+ def __init__(self, path):
+ self.path = path
+
+ def makefile(self, mode='r', bufsize=-1):
+ return open(self.path, mode, bufsize)
class WebSocketTestCase(unittest.TestCase):
@@ -26,27 +42,235 @@ class WebSocketTestCase(unittest.TestCase):
def setUp(self):
"""Called automatically before each test."""
super(WebSocketTestCase, self).setUp()
+ self.stubs = stubout.StubOutForTesting()
+ self.server = websocket.WebSocketServer(listen_host='localhost',
+ listen_port=80,
+ key='./',
+ web='./',
+ record='./',
+ daemon=True,
+ ssl_only=False)
+ self.soc = self.server.socket('localhost')
def tearDown(self):
"""Called automatically after each test."""
+ self.stubs.UnsetAll()
super(WebSocketTestCase, self).tearDown()
+ def _mock_os_open_oserror(self, file, flags):
+ raise OSError('')
+
+ def _mock_os_close_oserror(self, fd):
+ raise OSError('')
+
+ def _mock_os_close_oserror_EBADF(self, fd):
+ raise OSError(errno.EBADF, '')
+
+ def _mock_socket(self, *args, **kwargs):
+ return self.soc
+
+ def _mock_select(self, rlist, wlist, xlist, timeout=None):
+ return '_mock_select'
+
+ def _mock_select_exception(self, rlist, wlist, xlist, timeout=None):
+ raise Exception
+
+ def _mock_select_keyboardinterrupt(self, rlist, wlist,
+ xlist, timeout=None):
+ raise KeyboardInterrupt
+
+ def _mock_select_systemexit(self, rlist, wlist, xlist, timeout=None):
+ sys.exit()
+
+ def test_daemonize_error(self):
+ soc = websocket.WebSocketServer(listen_host='localhost',
+ listen_port=80,
+ key='../',
+ web='../',
+ record='../',
+ daemon=True,
+ ssl_only=1,
+ idle_timeout=1)
+ self.stubs.Set(os, 'fork', lambda *args: None)
+ self.stubs.Set(os, 'setsid', lambda *args: None)
+ self.stubs.Set(os, 'close', self._mock_os_close_oserror)
+ self.assertRaises(OSError, soc.daemonize, keepfd=None, chdir='./')
+
+ def test_daemonize_EBADF_error(self):
+ soc = websocket.WebSocketServer(listen_host='localhost',
+ listen_port=80,
+ key='../',
+ web='../',
+ record='../',
+ daemon=True,
+ ssl_only=1,
+ idle_timeout=1)
+ self.stubs.Set(os, 'fork', lambda *args: None)
+ self.stubs.Set(os, 'setsid', lambda *args: None)
+ self.stubs.Set(os, 'close', self._mock_os_close_oserror_EBADF)
+ self.stubs.Set(os, 'open', self._mock_os_open_oserror)
+ self.assertRaises(OSError, soc.daemonize, keepfd=None, chdir='./')
+
+ def test_decode_hybi(self):
+ soc = websocket.WebSocketServer(listen_host='localhost',
+ listen_port=80,
+ key='../',
+ web='../',
+ record='../',
+ daemon=False,
+ ssl_only=1,
+ idle_timeout=1)
+
+ self.assertRaises(Exception, soc.decode_hybi, 'a' * 128,
+ base64=True)
+
+ def test_do_websocket_handshake(self):
+ soc = websocket.WebSocketServer(listen_host='localhost',
+ listen_port=80,
+ key='../',
+ web='../',
+ record='../',
+ daemon=True,
+ ssl_only=0,
+ idle_timeout=1)
+ soc.scheme = 'scheme'
+ headers = {'Sec-WebSocket-Protocol': 'binary',
+ 'Sec-WebSocket-Version': '7',
+ 'Sec-WebSocket-Key': 'foo'}
+ soc.do_websocket_handshake(headers, '127.0.0.1')
+
+ def test_do_handshake(self):
+ soc = websocket.WebSocketServer(listen_host='localhost',
+ listen_port=80,
+ key='../',
+ web='../',
+ record='../',
+ daemon=True,
+ ssl_only=0,
+ idle_timeout=1)
+ self.stubs.Set(select, 'select', self._mock_select)
+ self.stubs.Set(socket._socketobject, 'recv', lambda *args: 'mock_recv')
+ self.assertRaises(Exception, soc.do_handshake, self.soc, '127.0.0.1')
+
+ def test_do_handshake_ssl_error(self):
+ soc = websocket.WebSocketServer(listen_host='localhost',
+ listen_port=80,
+ key='../',
+ web='../',
+ record='../',
+ daemon=True,
+ ssl_only=0,
+ idle_timeout=1)
+
+ def _mock_wrap_socket(*args, **kwargs):
+ from ssl import SSLError
+ raise SSLError('unit test exception')
+
+ self.stubs.Set(select, 'select', self._mock_select)
+ self.stubs.Set(socket._socketobject, 'recv', lambda *args: '\x16')
+ self.stubs.Set(ssl, 'wrap_socket', _mock_wrap_socket)
+ self.assertRaises(SSLError, soc.do_handshake, self.soc, '127.0.0.1')
+
+ def test_fallback_SIGCHILD(self):
+ soc = websocket.WebSocketServer(listen_host='localhost',
+ listen_port=80,
+ key='../',
+ web='../',
+ record='../',
+ daemon=True,
+ ssl_only=0,
+ idle_timeout=1)
+ soc.fallback_SIGCHLD(None, None)
+
+ def test_start_server_Exception(self):
+ soc = websocket.WebSocketServer(listen_host='localhost',
+ listen_port=80,
+ key='../',
+ web='../',
+ record='../',
+ daemon=False,
+ ssl_only=1,
+ idle_timeout=1)
+ self.stubs.Set(websocket.WebSocketServer, 'socket', self._mock_socket)
+ self.stubs.Set(websocket.WebSocketServer, 'daemonize',
+ lambda *args, **kwargs: None)
+ self.stubs.Set(select, 'select', self._mock_select_exception)
+ self.assertEqual(None, soc.start_server())
+
+ def test_start_server_KeyboardInterrupt(self):
+ soc = websocket.WebSocketServer(listen_host='localhost',
+ listen_port=80,
+ key='../',
+ web='../',
+ record='../',
+ cert='xxxxxx',
+ daemon=True,
+ ssl_only=1,
+ idle_timeout=1)
+ self.stubs.Set(websocket.WebSocketServer, 'socket', self._mock_socket)
+ self.stubs.Set(websocket.WebSocketServer, 'daemonize',
+ lambda *args, **kwargs: None)
+ self.stubs.Set(select, 'select', self._mock_select_keyboardinterrupt)
+ self.assertEqual(None, soc.start_server())
+
+ def test_start_server_systemexit(self):
+ websocket.ssl = None
+ soc = websocket.WebSocketServer(listen_host='localhost',
+ listen_port=80,
+ key='../',
+ web='../',
+ record='../',
+ daemon=True,
+ ssl_only=0,
+ idle_timeout=1,
+ verbose=True)
+ self.stubs.Set(websocket.WebSocketServer, 'socket', self._mock_socket)
+ self.stubs.Set(websocket.WebSocketServer, 'daemonize',
+ lambda *args, **kwargs: None)
+ self.stubs.Set(select, 'select', self._mock_select_systemexit)
+ self.assertEqual(None, soc.start_server())
+
+ def test_WSRequestHandle_do_GET_nofile(self):
+ request = 'GET /tmp.txt HTTP/0.9'
+ with tempfile.NamedTemporaryFile() as test_file:
+ test_file.write(request)
+ test_file.flush()
+ test_file.seek(0)
+ con = MockConnection(test_file.name)
+ soc = websocket.WSRequestHandler(con, "127.0.0.1", file_only=True)
+ soc.path = ''
+ soc.headers = {'upgrade': ''}
+ self.stubs.Set(SimpleHTTPRequestHandler, 'send_response',
+ lambda *args: None)
+ soc.do_GET()
+ self.assertEqual(404, soc.last_code)
+
+ def test_WSRequestHandle_do_GET_hidden_resource(self):
+ request = 'GET /tmp.txt HTTP/0.9'
+ with tempfile.NamedTemporaryFile() as test_file:
+ test_file.write(request)
+ test_file.flush()
+ test_file.seek(0)
+ con = MockConnection(test_file.name)
+ soc = websocket.WSRequestHandler(con, '127.0.0.1', no_parent=True)
+ soc.path = test_file.name + '?'
+ soc.headers = {'upgrade': ''}
+ soc.webroot = 'no match startswith'
+ self.stubs.Set(SimpleHTTPRequestHandler,
+ 'send_response',
+ lambda *args: None)
+ soc.do_GET()
+ self.assertEqual(403, soc.last_code)
+
def testsocket_set_keepalive_options(self):
- server = websocket.WebSocketServer(listen_host='localhost',
- listen_port=80,
- key='./',
- web='./',
- record='./',
- daemon=True,
- ssl_only=1)
keepcnt = 12
keepidle = 34
keepintvl = 56
- sock = server.socket('localhost',
- tcp_keepcnt=keepcnt,
- tcp_keepidle=keepidle,
- tcp_keepintvl=keepintvl)
+ sock = self.server.socket('localhost',
+ tcp_keepcnt=keepcnt,
+ tcp_keepidle=keepidle,
+ tcp_keepintvl=keepintvl)
self.assertEqual(sock.getsockopt(socket.SOL_TCP,
socket.TCP_KEEPCNT), keepcnt)
@@ -55,11 +279,11 @@ class WebSocketTestCase(unittest.TestCase):
self.assertEqual(sock.getsockopt(socket.SOL_TCP,
socket.TCP_KEEPINTVL), keepintvl)
- sock = server.socket('localhost',
- tcp_keepalive=False,
- tcp_keepcnt=keepcnt,
- tcp_keepidle=keepidle,
- tcp_keepintvl=keepintvl)
+ sock = self.server.socket('localhost',
+ tcp_keepalive=False,
+ tcp_keepcnt=keepcnt,
+ tcp_keepidle=keepidle,
+ tcp_keepintvl=keepintvl)
self.assertNotEqual(sock.getsockopt(socket.SOL_TCP,
socket.TCP_KEEPCNT), keepcnt)