summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJoel Martin <github@martintribe.org>2013-11-20 05:21:18 -0800
committerJoel Martin <github@martintribe.org>2013-11-20 05:21:18 -0800
commita04edfe80f54b44df5a3579f71710560c6b7b4fc (patch)
treef4d2944bfe482591c6527672fa55d70a992acadf
parenta47be21f9fa69ddf8d888ff9e3c75cdfc9e31c00 (diff)
parent32c1abd5d9643296e0d30abe2b2ccde324d5abcc (diff)
downloadwebsockify-a04edfe80f54b44df5a3579f71710560c6b7b4fc.tar.gz
Merge pull request #105 from dosaboy/topic/unit-test-cleanup
Added temp dir for unit test data and cleanup
-rw-r--r--tests/test_websocket.py130
-rw-r--r--tests/test_websocketproxy.py44
-rw-r--r--tests/tox.ini8
-rw-r--r--websockify/websocket.py15
4 files changed, 90 insertions, 107 deletions
diff --git a/tests/test_websocket.py b/tests/test_websocket.py
index 49efe81..c7a106f 100644
--- a/tests/test_websocket.py
+++ b/tests/test_websocket.py
@@ -17,7 +17,9 @@
""" Unit tests for websocket """
import errno
import os
+import logging
import select
+import shutil
import socket
import ssl
import stubout
@@ -39,24 +41,44 @@ class MockConnection(object):
class WebSocketTestCase(unittest.TestCase):
+ def _init_logger(self, tmpdir):
+ name = 'websocket-unittest'
+ logger = logging.getLogger(name)
+ logger.setLevel(logging.DEBUG)
+ logger.propagate = True
+ filename = "%s.log" % (name)
+ handler = logging.FileHandler(filename)
+ handler.setFormatter(logging.Formatter("%(message)s"))
+ logger.addHandler(handler)
+
def setUp(self):
"""Called automatically before each test."""
super(WebSocketTestCase, self).setUp()
self.stubs = stubout.StubOutForTesting()
- self.server = websocket.WebSocketServer(listen_host='localhost',
- listen_port=80,
- key='./',
- web='./',
- record='./',
- daemon=True,
- ssl_only=False)
+ # Temporary dir for test data
+ self.tmpdir = tempfile.mkdtemp()
+ # Put log somewhere persistent
+ self._init_logger('./')
+ # Mock this out cause it screws tests up
+ self.stubs.Set(os, 'chdir', lambda *args, **kwargs: None)
+ self.server = self._get_websockserver(daemon=True,
+ ssl_only=False)
self.soc = self.server.socket('localhost')
def tearDown(self):
"""Called automatically after each test."""
self.stubs.UnsetAll()
+ shutil.rmtree(self.tmpdir)
super(WebSocketTestCase, self).tearDown()
+ def _get_websockserver(self, **kwargs):
+ return websocket.WebSocketServer(listen_host='localhost',
+ listen_port=80,
+ key=self.tmpdir,
+ web=self.tmpdir,
+ record=self.tmpdir,
+ **kwargs)
+
def _mock_os_open_oserror(self, file, flags):
raise OSError('')
@@ -83,28 +105,14 @@ class WebSocketTestCase(unittest.TestCase):
sys.exit()
def test_daemonize_error(self):
- soc = websocket.WebSocketServer(listen_host='localhost',
- listen_port=80,
- key='../',
- web='../',
- record='../',
- daemon=True,
- ssl_only=1,
- idle_timeout=1)
+ soc = self._get_websockserver(daemon=True, ssl_only=1, idle_timeout=1)
self.stubs.Set(os, 'fork', lambda *args: None)
self.stubs.Set(os, 'setsid', lambda *args: None)
self.stubs.Set(os, 'close', self._mock_os_close_oserror)
self.assertRaises(OSError, soc.daemonize, keepfd=None, chdir='./')
def test_daemonize_EBADF_error(self):
- soc = websocket.WebSocketServer(listen_host='localhost',
- listen_port=80,
- key='../',
- web='../',
- record='../',
- daemon=True,
- ssl_only=1,
- idle_timeout=1)
+ soc = self._get_websockserver(daemon=True, ssl_only=1, idle_timeout=1)
self.stubs.Set(os, 'fork', lambda *args: None)
self.stubs.Set(os, 'setsid', lambda *args: None)
self.stubs.Set(os, 'close', self._mock_os_close_oserror_EBADF)
@@ -112,27 +120,12 @@ class WebSocketTestCase(unittest.TestCase):
self.assertRaises(OSError, soc.daemonize, keepfd=None, chdir='./')
def test_decode_hybi(self):
- soc = websocket.WebSocketServer(listen_host='localhost',
- listen_port=80,
- key='../',
- web='../',
- record='../',
- daemon=False,
- ssl_only=1,
- idle_timeout=1)
-
+ soc = self._get_websockserver(daemon=False, ssl_only=1, idle_timeout=1)
self.assertRaises(Exception, soc.decode_hybi, 'a' * 128,
base64=True)
def test_do_websocket_handshake(self):
- soc = websocket.WebSocketServer(listen_host='localhost',
- listen_port=80,
- key='../',
- web='../',
- record='../',
- daemon=True,
- ssl_only=0,
- idle_timeout=1)
+ soc = self._get_websockserver(daemon=True, ssl_only=0, idle_timeout=1)
soc.scheme = 'scheme'
headers = {'Sec-WebSocket-Protocol': 'binary',
'Sec-WebSocket-Version': '7',
@@ -140,27 +133,13 @@ class WebSocketTestCase(unittest.TestCase):
soc.do_websocket_handshake(headers, '127.0.0.1')
def test_do_handshake(self):
- soc = websocket.WebSocketServer(listen_host='localhost',
- listen_port=80,
- key='../',
- web='../',
- record='../',
- daemon=True,
- ssl_only=0,
- idle_timeout=1)
+ soc = self._get_websockserver(daemon=True, ssl_only=0, idle_timeout=1)
self.stubs.Set(select, 'select', self._mock_select)
self.stubs.Set(socket._socketobject, 'recv', lambda *args: 'mock_recv')
self.assertRaises(Exception, soc.do_handshake, self.soc, '127.0.0.1')
def test_do_handshake_ssl_error(self):
- soc = websocket.WebSocketServer(listen_host='localhost',
- listen_port=80,
- key='../',
- web='../',
- record='../',
- daemon=True,
- ssl_only=0,
- idle_timeout=1)
+ soc = self._get_websockserver(daemon=True, ssl_only=0, idle_timeout=1)
def _mock_wrap_socket(*args, **kwargs):
from ssl import SSLError
@@ -172,25 +151,11 @@ class WebSocketTestCase(unittest.TestCase):
self.assertRaises(SSLError, soc.do_handshake, self.soc, '127.0.0.1')
def test_fallback_SIGCHILD(self):
- soc = websocket.WebSocketServer(listen_host='localhost',
- listen_port=80,
- key='../',
- web='../',
- record='../',
- daemon=True,
- ssl_only=0,
- idle_timeout=1)
+ soc = self._get_websockserver(daemon=True, ssl_only=0, idle_timeout=1)
soc.fallback_SIGCHLD(None, None)
def test_start_server_Exception(self):
- soc = websocket.WebSocketServer(listen_host='localhost',
- listen_port=80,
- key='../',
- web='../',
- record='../',
- daemon=False,
- ssl_only=1,
- idle_timeout=1)
+ soc = self._get_websockserver(daemon=False, ssl_only=1, idle_timeout=1)
self.stubs.Set(websocket.WebSocketServer, 'socket', self._mock_socket)
self.stubs.Set(websocket.WebSocketServer, 'daemonize',
lambda *args, **kwargs: None)
@@ -198,15 +163,7 @@ class WebSocketTestCase(unittest.TestCase):
self.assertEqual(None, soc.start_server())
def test_start_server_KeyboardInterrupt(self):
- soc = websocket.WebSocketServer(listen_host='localhost',
- listen_port=80,
- key='../',
- web='../',
- record='../',
- cert='xxxxxx',
- daemon=True,
- ssl_only=1,
- idle_timeout=1)
+ soc = self._get_websockserver(daemon=False, ssl_only=1, idle_timeout=1)
self.stubs.Set(websocket.WebSocketServer, 'socket', self._mock_socket)
self.stubs.Set(websocket.WebSocketServer, 'daemonize',
lambda *args, **kwargs: None)
@@ -215,19 +172,12 @@ class WebSocketTestCase(unittest.TestCase):
def test_start_server_systemexit(self):
websocket.ssl = None
- soc = websocket.WebSocketServer(listen_host='localhost',
- listen_port=80,
- key='../',
- web='../',
- record='../',
- daemon=True,
- ssl_only=0,
- idle_timeout=1,
- verbose=True)
self.stubs.Set(websocket.WebSocketServer, 'socket', self._mock_socket)
self.stubs.Set(websocket.WebSocketServer, 'daemonize',
lambda *args, **kwargs: None)
self.stubs.Set(select, 'select', self._mock_select_systemexit)
+ soc = self._get_websockserver(daemon=True, ssl_only=0, idle_timeout=1,
+ verbose=True)
self.assertEqual(None, soc.start_server())
def test_WSRequestHandle_do_GET_nofile(self):
diff --git a/tests/test_websocketproxy.py b/tests/test_websocketproxy.py
index 0197ce5..0fdd0fb 100644
--- a/tests/test_websocketproxy.py
+++ b/tests/test_websocketproxy.py
@@ -15,11 +15,15 @@
# under the License.
""" Unit tests for websocketproxy """
-import unittest
-import time
-import subprocess
-import stubout
+import os
+import logging
import select
+import shutil
+import stubout
+import subprocess
+import tempfile
+import time
+import unittest
from websockify import websocketproxy
@@ -37,20 +41,42 @@ class MockSocket(object):
class WebSocketProxyTest(unittest.TestCase):
+ def _init_logger(self, tmpdir):
+ name = 'websocket-unittest'
+ logger = logging.getLogger(name)
+ logger.setLevel(logging.DEBUG)
+ logger.propagate = True
+ filename = "%s.log" % (name)
+ handler = logging.FileHandler(filename)
+ handler.setFormatter(logging.Formatter("%(message)s"))
+ logger.addHandler(handler)
+
def setUp(self):
"""Called automatically before each test."""
super(WebSocketProxyTest, self).setUp()
-
self.soc = ''
self.stubs = stubout.StubOutForTesting()
+ # Temporary dir for test data
+ self.tmpdir = tempfile.mkdtemp()
+ # Put log somewhere persistent
+ self._init_logger('./')
+ # Mock this out cause it screws tests up
+ self.stubs.Set(os, 'chdir', lambda *args, **kwargs: None)
def tearDown(self):
"""Called automatically after each test."""
self.stubs.UnsetAll()
+ shutil.rmtree(self.tmpdir)
super(WebSocketProxyTest, self).tearDown()
+ def _get_websockproxy(self, **kwargs):
+ return websocketproxy.WebSocketProxy(key=self.tmpdir,
+ web=self.tmpdir,
+ record=self.tmpdir,
+ **kwargs)
+
def test_run_wrap_cmd(self):
- web_socket_proxy = websocketproxy.WebSocketProxy()
+ web_socket_proxy = self._get_websockproxy()
web_socket_proxy.__dict__["wrap_cmd"] = "wrap_cmd"
def mock_Popen(*args, **kwargs):
@@ -61,7 +87,7 @@ class WebSocketProxyTest(unittest.TestCase):
self.assertEquals(web_socket_proxy.spawn_message, True)
def test_started(self):
- web_socket_proxy = websocketproxy.WebSocketProxy()
+ web_socket_proxy = self._get_websockproxy()
web_socket_proxy.__dict__["spawn_message"] = False
web_socket_proxy.__dict__["wrap_cmd"] = "wrap_cmd"
@@ -73,7 +99,7 @@ class WebSocketProxyTest(unittest.TestCase):
self.assertEquals(web_socket_proxy.__dict__["spawn_message"], True)
def test_poll(self):
- web_socket_proxy = websocketproxy.WebSocketProxy()
+ web_socket_proxy = self._get_websockproxy()
web_socket_proxy.__dict__["wrap_cmd"] = "wrap_cmd"
web_socket_proxy.__dict__["wrap_mode"] = "respawn"
web_socket_proxy.__dict__["wrap_times"] = [99999999]
@@ -84,7 +110,7 @@ class WebSocketProxyTest(unittest.TestCase):
self.assertEquals(web_socket_proxy.spawn_message, False)
def test_new_client(self):
- web_socket_proxy = websocketproxy.WebSocketProxy()
+ web_socket_proxy = self._get_websockproxy()
web_socket_proxy.__dict__["verbose"] = "verbose"
web_socket_proxy.__dict__["daemon"] = None
web_socket_proxy.__dict__["client"] = "client"
diff --git a/tests/tox.ini b/tests/tox.ini
index 4f28f3f..098e89c 100644
--- a/tests/tox.ini
+++ b/tests/tox.ini
@@ -4,7 +4,7 @@
# and then run "tox" from this directory.
[tox]
-envlist = py24, py25, py26, py27, py30
+envlist = py24,py25,py26,py27,py30
setupdir = ../
[testenv]
@@ -12,3 +12,9 @@ commands = nosetests {posargs}
deps =
mox
nose
+
+# At some point we should enable this since tox epdctes it to exist but
+# the code will need pep8ising first.
+#[testenv:pep8]
+#commands = flake8
+#dep = flake8
diff --git a/websockify/websocket.py b/websockify/websocket.py
index d93a1fc..d5ea96b 100644
--- a/websockify/websocket.py
+++ b/websockify/websocket.py
@@ -331,7 +331,7 @@ Sec-WebSocket-Accept: %s\r
return header + buf, len(header), 0
@staticmethod
- def decode_hybi(buf, base64=False):
+ def decode_hybi(buf, base64=False, logger=None):
""" Decode HyBi style WebSocket packets.
Returns:
{'fin' : 0_or_1,
@@ -355,7 +355,8 @@ Sec-WebSocket-Accept: %s\r
'close_code' : 1000,
'close_reason' : ''}
- logger = WebSocketServer.get_logger()
+ if logger is None:
+ logger = WebSocketServer.get_logger()
blen = len(buf)
f['left'] = blen
@@ -395,16 +396,15 @@ Sec-WebSocket-Accept: %s\r
f['payload'] = WebSocketServer.unmask(buf, f['hlen'],
f['length'])
else:
- self.vmsg("Unmasked frame: %s" % repr(buf))
+ logger.debug("Unmasked frame: %s" % repr(buf))
f['payload'] = buf[(f['hlen'] + f['masked'] * 4):full_len]
if base64 and f['opcode'] in [1, 2]:
try:
f['payload'] = b64decode(f['payload'])
except:
- self.warn("Exception while b64decoding buffer: %s",
- repr(buf))
- self.vmsg('Exception', exc_info=True)
+ logger.exception("Exception while b64decoding buffer: %s" %
+ (repr(buf)))
raise
if f['opcode'] == 0x08:
@@ -510,7 +510,8 @@ Sec-WebSocket-Accept: %s\r
self.recv_part = None
while buf:
- frame = self.decode_hybi(buf, base64=self.base64)
+ frame = self.decode_hybi(buf, base64=self.base64,
+ logger=self.logger)
#self.msg("Received buf: %s, frame: %s", repr(buf), frame)
if frame['payload'] == None: