From 04fd789a677a84b3c908bb8ac02413d16271d5fb Mon Sep 17 00:00:00 2001 From: Solly Ross Date: Wed, 6 May 2015 13:49:13 -0400 Subject: Update Tests and Test Plugins This commit updates the unit tests to work with the current code and adds in tests for the auth and token plugin functionality. --- tests/test_websocket.py | 420 +++++++++++++++++++++++++++---------------- tests/test_websocketproxy.py | 199 ++++++++++---------- tests/tox.ini | 20 --- 3 files changed, 365 insertions(+), 274 deletions(-) delete mode 100644 tests/tox.ini (limited to 'tests') diff --git a/tests/test_websocket.py b/tests/test_websocket.py index c7a106f..acd7699 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -26,201 +26,303 @@ import stubout import sys import tempfile import unittest -from ssl import SSLError -from websockify import websocket as websocket -from SimpleHTTPServer import SimpleHTTPRequestHandler +import socket +import signal +from websockify import websocket + +try: + from SimpleHTTPServer import SimpleHTTPRequestHandler +except ImportError: + from http.server import SimpleHTTPRequestHandler + +try: + from StringIO import StringIO + BytesIO = StringIO +except ImportError: + from io import StringIO + from io import BytesIO + + + + +def raise_oserror(*args, **kwargs): + raise OSError('fake error') -class MockConnection(object): - def __init__(self, path): - self.path = path +class FakeSocket(object): + def __init__(self, data=''): + if isinstance(data, bytes): + self._data = data + else: + self._data = data.encode('latin_1') - def makefile(self, mode='r', bufsize=-1): - return open(self.path, mode, bufsize) + def recv(self, amt, flags=None): + res = self._data[0:amt] + if not (flags & socket.MSG_PEEK): + self._data = self._data[amt:] + return res -class WebSocketTestCase(unittest.TestCase): + def makefile(self, mode='r', buffsize=None): + if 'b' in mode: + return BytesIO(self._data) + else: + return StringIO(self._data.decode('latin_1')) - def _init_logger(self, tmpdir): - name = 'websocket-unittest' - logger = logging.getLogger(name) - logger.setLevel(logging.DEBUG) - logger.propagate = True - filename = "%s.log" % (name) - handler = logging.FileHandler(filename) - handler.setFormatter(logging.Formatter("%(message)s")) - logger.addHandler(handler) +class WebSocketRequestHandlerTestCase(unittest.TestCase): def setUp(self): - """Called automatically before each test.""" - super(WebSocketTestCase, self).setUp() + super(WebSocketRequestHandlerTestCase, self).setUp() self.stubs = stubout.StubOutForTesting() - # Temporary dir for test data - self.tmpdir = tempfile.mkdtemp() - # Put log somewhere persistent - self._init_logger('./') + self.tmpdir = tempfile.mkdtemp('-websockify-tests') # Mock this out cause it screws tests up self.stubs.Set(os, 'chdir', lambda *args, **kwargs: None) - self.server = self._get_websockserver(daemon=True, - ssl_only=False) - self.soc = self.server.socket('localhost') + self.stubs.Set(SimpleHTTPRequestHandler, 'send_response', + lambda *args, **kwargs: None) def tearDown(self): """Called automatically after each test.""" self.stubs.UnsetAll() - shutil.rmtree(self.tmpdir) - super(WebSocketTestCase, self).tearDown() + os.rmdir(self.tmpdir) + super(WebSocketRequestHandlerTestCase, self).tearDown() + + def _get_server(self, handler_class=websocket.WebSocketRequestHandler, + **kwargs): + web = kwargs.pop('web', self.tmpdir) + return websocket.WebSocketServer( + handler_class, listen_host='localhost', + listen_port=80, key=self.tmpdir, web=web, + record=self.tmpdir, daemon=False, ssl_only=0, idle_timeout=1, + **kwargs) - def _get_websockserver(self, **kwargs): - return websocket.WebSocketServer(listen_host='localhost', - listen_port=80, - key=self.tmpdir, - web=self.tmpdir, - record=self.tmpdir, - **kwargs) + def test_normal_get_with_only_upgrade_returns_error(self): + server = self._get_server(web=None) + handler = websocket.WebSocketRequestHandler( + FakeSocket('GET /tmp.txt HTTP/1.1'), '127.0.0.1', server) - def _mock_os_open_oserror(self, file, flags): - raise OSError('') + def fake_send_response(self, code, message=None): + self.last_code = code - def _mock_os_close_oserror(self, fd): - raise OSError('') + self.stubs.Set(SimpleHTTPRequestHandler, 'send_response', + fake_send_response) - def _mock_os_close_oserror_EBADF(self, fd): - raise OSError(errno.EBADF, '') + handler.do_GET() + self.assertEqual(handler.last_code, 405) - def _mock_socket(self, *args, **kwargs): - return self.soc + def test_list_dir_with_file_only_returns_error(self): + server = self._get_server(file_only=True) + handler = websocket.WebSocketRequestHandler( + FakeSocket('GET / HTTP/1.1'), '127.0.0.1', server) - def _mock_select(self, rlist, wlist, xlist, timeout=None): - return '_mock_select' + def fake_send_response(self, code, message=None): + self.last_code = code - def _mock_select_exception(self, rlist, wlist, xlist, timeout=None): - raise Exception + self.stubs.Set(SimpleHTTPRequestHandler, 'send_response', + fake_send_response) - def _mock_select_keyboardinterrupt(self, rlist, wlist, - xlist, timeout=None): - raise KeyboardInterrupt + handler.path = '/' + handler.do_GET() + self.assertEqual(handler.last_code, 404) - def _mock_select_systemexit(self, rlist, wlist, xlist, timeout=None): - sys.exit() - def test_daemonize_error(self): - soc = self._get_websockserver(daemon=True, ssl_only=1, idle_timeout=1) - self.stubs.Set(os, 'fork', lambda *args: None) +class WebSocketServerTestCase(unittest.TestCase): + def setUp(self): + super(WebSocketServerTestCase, self).setUp() + self.stubs = stubout.StubOutForTesting() + self.tmpdir = tempfile.mkdtemp('-websockify-tests') + # Mock this out cause it screws tests up + self.stubs.Set(os, 'chdir', lambda *args, **kwargs: None) + + def tearDown(self): + """Called automatically after each test.""" + self.stubs.UnsetAll() + os.rmdir(self.tmpdir) + super(WebSocketServerTestCase, self).tearDown() + + def _get_server(self, handler_class=websocket.WebSocketRequestHandler, + **kwargs): + return websocket.WebSocketServer( + handler_class, listen_host='localhost', + listen_port=80, key=self.tmpdir, web=self.tmpdir, + record=self.tmpdir, **kwargs) + + def test_daemonize_raises_error_while_closing_fds(self): + server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1) + self.stubs.Set(os, 'fork', lambda *args: 0) + self.stubs.Set(signal, 'signal', lambda *args: None) self.stubs.Set(os, 'setsid', lambda *args: None) - self.stubs.Set(os, 'close', self._mock_os_close_oserror) - self.assertRaises(OSError, soc.daemonize, keepfd=None, chdir='./') + self.stubs.Set(os, 'close', raise_oserror) + self.assertRaises(OSError, server.daemonize, keepfd=None, chdir='./') + + def test_daemonize_ignores_ebadf_error_while_closing_fds(self): + def raise_oserror_ebadf(fd): + raise OSError(errno.EBADF, 'fake error') - def test_daemonize_EBADF_error(self): - soc = self._get_websockserver(daemon=True, ssl_only=1, idle_timeout=1) - self.stubs.Set(os, 'fork', lambda *args: None) + server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1) + self.stubs.Set(os, 'fork', lambda *args: 0) self.stubs.Set(os, 'setsid', lambda *args: None) - self.stubs.Set(os, 'close', self._mock_os_close_oserror_EBADF) - self.stubs.Set(os, 'open', self._mock_os_open_oserror) - self.assertRaises(OSError, soc.daemonize, keepfd=None, chdir='./') - - def test_decode_hybi(self): - soc = self._get_websockserver(daemon=False, ssl_only=1, idle_timeout=1) - self.assertRaises(Exception, soc.decode_hybi, 'a' * 128, - base64=True) - - def test_do_websocket_handshake(self): - soc = self._get_websockserver(daemon=True, ssl_only=0, idle_timeout=1) - soc.scheme = 'scheme' - headers = {'Sec-WebSocket-Protocol': 'binary', - 'Sec-WebSocket-Version': '7', - 'Sec-WebSocket-Key': 'foo'} - soc.do_websocket_handshake(headers, '127.0.0.1') - - def test_do_handshake(self): - soc = self._get_websockserver(daemon=True, ssl_only=0, idle_timeout=1) - self.stubs.Set(select, 'select', self._mock_select) - self.stubs.Set(socket._socketobject, 'recv', lambda *args: 'mock_recv') - self.assertRaises(Exception, soc.do_handshake, self.soc, '127.0.0.1') - - def test_do_handshake_ssl_error(self): - soc = self._get_websockserver(daemon=True, ssl_only=0, idle_timeout=1) - - def _mock_wrap_socket(*args, **kwargs): - from ssl import SSLError - raise SSLError('unit test exception') - - self.stubs.Set(select, 'select', self._mock_select) - self.stubs.Set(socket._socketobject, 'recv', lambda *args: '\x16') - self.stubs.Set(ssl, 'wrap_socket', _mock_wrap_socket) - self.assertRaises(SSLError, soc.do_handshake, self.soc, '127.0.0.1') - - def test_fallback_SIGCHILD(self): - soc = self._get_websockserver(daemon=True, ssl_only=0, idle_timeout=1) - soc.fallback_SIGCHLD(None, None) - - def test_start_server_Exception(self): - soc = self._get_websockserver(daemon=False, ssl_only=1, idle_timeout=1) - self.stubs.Set(websocket.WebSocketServer, 'socket', self._mock_socket) + self.stubs.Set(signal, 'signal', lambda *args: None) + self.stubs.Set(os, 'close', raise_oserror_ebadf) + self.stubs.Set(os, 'open', raise_oserror) + self.assertRaises(OSError, server.daemonize, keepfd=None, chdir='./') + + def test_handshake_fails_on_not_ready(self): + server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1) + + def fake_select(rlist, wlist, xlist, timeout=None): + return ([], [], []) + + self.stubs.Set(select, 'select', fake_select) + self.assertRaises( + websocket.WebSocketServer.EClose, server.do_handshake, + FakeSocket(), '127.0.0.1') + + def test_empty_handshake_fails(self): + server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1) + + sock = FakeSocket('') + + def fake_select(rlist, wlist, xlist, timeout=None): + return ([sock], [], []) + + self.stubs.Set(select, 'select', fake_select) + self.assertRaises( + websocket.WebSocketServer.EClose, server.do_handshake, + sock, '127.0.0.1') + + def test_handshake_policy_request(self): + # TODO(directxman12): implement + pass + + def test_handshake_ssl_only_without_ssl_raises_error(self): + server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1) + + sock = FakeSocket('some initial data') + + def fake_select(rlist, wlist, xlist, timeout=None): + return ([sock], [], []) + + self.stubs.Set(select, 'select', fake_select) + self.assertRaises( + websocket.WebSocketServer.EClose, server.do_handshake, + sock, '127.0.0.1') + + def test_do_handshake_no_ssl(self): + class FakeHandler(object): + CALLED = False + def __init__(self, *args, **kwargs): + type(self).CALLED = True + + FakeHandler.CALLED = False + + server = self._get_server( + handler_class=FakeHandler, daemon=True, + ssl_only=0, idle_timeout=1) + + sock = FakeSocket('some initial data') + + def fake_select(rlist, wlist, xlist, timeout=None): + return ([sock], [], []) + + self.stubs.Set(select, 'select', fake_select) + self.assertEqual(server.do_handshake(sock, '127.0.0.1'), sock) + self.assertTrue(FakeHandler.CALLED, True) + + def test_do_handshake_ssl(self): + # TODO(directxman12): implement this + pass + + def test_do_handshake_ssl_without_ssl_raises_error(self): + # TODO(directxman12): implement this + pass + + def test_do_handshake_ssl_without_cert_raises_error(self): + server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1, + cert='afdsfasdafdsafdsafdsafdas') + + sock = FakeSocket("\x16some ssl data") + + def fake_select(rlist, wlist, xlist, timeout=None): + return ([sock], [], []) + + self.stubs.Set(select, 'select', fake_select) + self.assertRaises( + websocket.WebSocketServer.EClose, server.do_handshake, + sock, '127.0.0.1') + + def test_do_handshake_ssl_error_eof_raises_close_error(self): + server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1) + + sock = FakeSocket("\x16some ssl data") + + def fake_select(rlist, wlist, xlist, timeout=None): + return ([sock], [], []) + + def fake_wrap_socket(*args, **kwargs): + raise ssl.SSLError(ssl.SSL_ERROR_EOF) + + self.stubs.Set(select, 'select', fake_select) + self.stubs.Set(ssl, 'wrap_socket', fake_wrap_socket) + self.assertRaises( + websocket.WebSocketServer.EClose, server.do_handshake, + sock, '127.0.0.1') + + def test_fallback_sigchld_handler(self): + # TODO(directxman12): implement this + pass + + def test_start_server_error(self): + server = self._get_server(daemon=False, ssl_only=1, idle_timeout=1) + sock = server.socket('localhost') + + def fake_select(rlist, wlist, xlist, timeout=None): + raise Exception("fake error") + + self.stubs.Set(websocket.WebSocketServer, 'socket', + lambda *args, **kwargs: sock) self.stubs.Set(websocket.WebSocketServer, 'daemonize', lambda *args, **kwargs: None) - self.stubs.Set(select, 'select', self._mock_select_exception) - self.assertEqual(None, soc.start_server()) + self.stubs.Set(select, 'select', fake_select) + server.start_server() + + def test_start_server_keyboardinterrupt(self): + server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1) + sock = server.socket('localhost') + + def fake_select(rlist, wlist, xlist, timeout=None): + raise KeyboardInterrupt - def test_start_server_KeyboardInterrupt(self): - soc = self._get_websockserver(daemon=False, ssl_only=1, idle_timeout=1) - self.stubs.Set(websocket.WebSocketServer, 'socket', self._mock_socket) + self.stubs.Set(websocket.WebSocketServer, 'socket', + lambda *args, **kwargs: sock) self.stubs.Set(websocket.WebSocketServer, 'daemonize', lambda *args, **kwargs: None) - self.stubs.Set(select, 'select', self._mock_select_keyboardinterrupt) - self.assertEqual(None, soc.start_server()) + self.stubs.Set(select, 'select', fake_select) + server.start_server() def test_start_server_systemexit(self): - websocket.ssl = None - self.stubs.Set(websocket.WebSocketServer, 'socket', self._mock_socket) + server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1) + sock = server.socket('localhost') + + def fake_select(rlist, wlist, xlist, timeout=None): + sys.exit() + + self.stubs.Set(websocket.WebSocketServer, 'socket', + lambda *args, **kwargs: sock) self.stubs.Set(websocket.WebSocketServer, 'daemonize', lambda *args, **kwargs: None) - self.stubs.Set(select, 'select', self._mock_select_systemexit) - soc = self._get_websockserver(daemon=True, ssl_only=0, idle_timeout=1, - verbose=True) - self.assertEqual(None, soc.start_server()) - - def test_WSRequestHandle_do_GET_nofile(self): - request = 'GET /tmp.txt HTTP/0.9' - with tempfile.NamedTemporaryFile() as test_file: - test_file.write(request) - test_file.flush() - test_file.seek(0) - con = MockConnection(test_file.name) - soc = websocket.WSRequestHandler(con, "127.0.0.1", file_only=True) - soc.path = '' - soc.headers = {'upgrade': ''} - self.stubs.Set(SimpleHTTPRequestHandler, 'send_response', - lambda *args: None) - soc.do_GET() - self.assertEqual(404, soc.last_code) - - def test_WSRequestHandle_do_GET_hidden_resource(self): - request = 'GET /tmp.txt HTTP/0.9' - with tempfile.NamedTemporaryFile() as test_file: - test_file.write(request) - test_file.flush() - test_file.seek(0) - con = MockConnection(test_file.name) - soc = websocket.WSRequestHandler(con, '127.0.0.1', no_parent=True) - soc.path = test_file.name + '?' - soc.headers = {'upgrade': ''} - soc.webroot = 'no match startswith' - self.stubs.Set(SimpleHTTPRequestHandler, - 'send_response', - lambda *args: None) - soc.do_GET() - self.assertEqual(403, soc.last_code) - - def testsocket_set_keepalive_options(self): + self.stubs.Set(select, 'select', fake_select) + server.start_server() + + def test_socket_set_keepalive_options(self): keepcnt = 12 keepidle = 34 keepintvl = 56 - sock = self.server.socket('localhost', - tcp_keepcnt=keepcnt, - tcp_keepidle=keepidle, - tcp_keepintvl=keepintvl) + server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1) + sock = server.socket('localhost', + tcp_keepcnt=keepcnt, + tcp_keepidle=keepidle, + tcp_keepintvl=keepintvl) self.assertEqual(sock.getsockopt(socket.SOL_TCP, socket.TCP_KEEPCNT), keepcnt) @@ -229,11 +331,11 @@ class WebSocketTestCase(unittest.TestCase): self.assertEqual(sock.getsockopt(socket.SOL_TCP, socket.TCP_KEEPINTVL), keepintvl) - sock = self.server.socket('localhost', - tcp_keepalive=False, - tcp_keepcnt=keepcnt, - tcp_keepidle=keepidle, - tcp_keepintvl=keepintvl) + sock = server.socket('localhost', + tcp_keepalive=False, + tcp_keepcnt=keepcnt, + tcp_keepidle=keepidle, + tcp_keepintvl=keepintvl) self.assertNotEqual(sock.getsockopt(socket.SOL_TCP, socket.TCP_KEEPCNT), keepcnt) diff --git a/tests/test_websocketproxy.py b/tests/test_websocketproxy.py index cf940ae..8103ef6 100644 --- a/tests/test_websocketproxy.py +++ b/tests/test_websocketproxy.py @@ -1,6 +1,6 @@ # vim: tabstop=4 shiftwidth=4 softtabstop=4 -# Copyright(c)2013 NTT corp. All Rights Reserved. +# Copyright(c) 2015 Red Hat, Inc All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. You may obtain @@ -15,113 +15,122 @@ # under the License. """ Unit tests for websocketproxy """ -import os -import logging -import select -import shutil -import stubout -import subprocess -import tempfile -import time + import unittest +import unittest +import socket +import stubout + +from websockify import websocket from websockify import websocketproxy +from websockify import token_plugins +from websockify import auth_plugins +try: + from StringIO import StringIO + BytesIO = StringIO +except ImportError: + from io import StringIO + from io import BytesIO -class MockSocket(object): - def __init__(*args, **kwargs): - pass - def shutdown(*args): - pass +class FakeSocket(object): + def __init__(self, data=''): + if isinstance(data, bytes): + self._data = data + else: + self._data = data.encode('latin_1') - def close(*args): - pass + def recv(self, amt, flags=None): + res = self._data[0:amt] + if not (flags & socket.MSG_PEEK): + self._data = self._data[amt:] + + return res + + def makefile(self, mode='r', buffsize=None): + if 'b' in mode: + return BytesIO(self._data) + else: + return StringIO(self._data.decode('latin_1')) -class WebSocketProxyTest(unittest.TestCase): +class FakeServer(object): + class EClose(Exception): + pass - def _init_logger(self, tmpdir): - name = 'websocket-unittest' - logger = logging.getLogger(name) - logger.setLevel(logging.DEBUG) - logger.propagate = True - filename = "%s.log" % (name) - handler = logging.FileHandler(filename) - handler.setFormatter(logging.Formatter("%(message)s")) - logger.addHandler(handler) + def __init__(self): + self.token_plugin = None + self.auth_plugin = None + self.wrap_cmd = None + self.ssl_target = None + self.unix_target = None +class ProxyRequestHandlerTestCase(unittest.TestCase): def setUp(self): - """Called automatically before each test.""" - super(WebSocketProxyTest, self).setUp() - self.soc = '' + super(ProxyRequestHandlerTestCase, self).setUp() self.stubs = stubout.StubOutForTesting() - # Temporary dir for test data - self.tmpdir = tempfile.mkdtemp() - # Put log somewhere persistent - self._init_logger('./') - # Mock this out cause it screws tests up - self.stubs.Set(os, 'chdir', lambda *args, **kwargs: None) + self.handler = websocketproxy.ProxyRequestHandler( + FakeSocket(''), "127.0.0.1", FakeServer()) + self.handler.path = "https://localhost:6080/websockify?token=blah" + self.handler.headers = None + self.stubs.Set(websocket.WebSocketServer, 'socket', + staticmethod(lambda *args, **kwargs: None)) def tearDown(self): - """Called automatically after each test.""" self.stubs.UnsetAll() - shutil.rmtree(self.tmpdir) - super(WebSocketProxyTest, self).tearDown() - - def _get_websockproxy(self, **kwargs): - return websocketproxy.WebSocketProxy(key=self.tmpdir, - web=self.tmpdir, - record=self.tmpdir, - **kwargs) - - def test_run_wrap_cmd(self): - web_socket_proxy = self._get_websockproxy() - web_socket_proxy.__dict__["wrap_cmd"] = "wrap_cmd" - - def mock_Popen(*args, **kwargs): - return '_mock_cmd' - - self.stubs.Set(subprocess, 'Popen', mock_Popen) - web_socket_proxy.run_wrap_cmd() - self.assertEquals(web_socket_proxy.spawn_message, True) - - def test_started(self): - web_socket_proxy = self._get_websockproxy() - web_socket_proxy.__dict__["spawn_message"] = False - web_socket_proxy.__dict__["wrap_cmd"] = "wrap_cmd" - - def mock_run_wrap_cmd(*args, **kwargs): - web_socket_proxy.__dict__["spawn_message"] = True - - self.stubs.Set(web_socket_proxy, 'run_wrap_cmd', mock_run_wrap_cmd) - web_socket_proxy.started() - self.assertEquals(web_socket_proxy.__dict__["spawn_message"], True) - - def test_poll(self): - web_socket_proxy = self._get_websockproxy() - web_socket_proxy.__dict__["wrap_cmd"] = "wrap_cmd" - web_socket_proxy.__dict__["wrap_mode"] = "respawn" - web_socket_proxy.__dict__["wrap_times"] = [99999999] - web_socket_proxy.__dict__["spawn_message"] = True - web_socket_proxy.__dict__["cmd"] = None - self.stubs.Set(time, 'time', lambda: 100000000.000) - web_socket_proxy.poll() - self.assertEquals(web_socket_proxy.spawn_message, False) - - def test_new_client(self): - web_socket_proxy = self._get_websockproxy() - web_socket_proxy.__dict__["verbose"] = "verbose" - web_socket_proxy.__dict__["daemon"] = None - web_socket_proxy.__dict__["client"] = "client" - - self.stubs.Set(web_socket_proxy, 'socket', MockSocket) - - def mock_select(*args, **kwargs): - ins = None - outs = None - excepts = "excepts" - return ins, outs, excepts - - self.stubs.Set(select, 'select', mock_select) - self.assertRaises(Exception, web_socket_proxy.new_websocket_client) + super(ProxyRequestHandlerTestCase, self).tearDown() + + def test_get_target(self): + class TestPlugin(token_plugins.BasePlugin): + def lookup(self, token): + return ("some host", "some port") + + host, port = self.handler.get_target( + TestPlugin(None), self.handler.path) + + self.assertEqual(host, "some host") + self.assertEqual(port, "some port") + + def test_get_target_raises_error_on_unknown_token(self): + class TestPlugin(token_plugins.BasePlugin): + def lookup(self, token): + return None + + self.assertRaises(FakeServer.EClose, self.handler.get_target, + TestPlugin(None), "https://localhost:6080/websockify?token=blah") + + def test_token_plugin(self): + class TestPlugin(token_plugins.BasePlugin): + def lookup(self, token): + return (self.source + token).split(',') + + self.stubs.Set(websocketproxy.ProxyRequestHandler, 'do_proxy', + lambda *args, **kwargs: None) + + self.handler.server.token_plugin = TestPlugin("somehost,") + self.handler.new_websocket_client() + + self.assertEqual(self.handler.server.target_host, "somehost") + self.assertEqual(self.handler.server.target_port, "blah") + + def test_auth_plugin(self): + class TestPlugin(auth_plugins.BasePlugin): + def authenticate(self, headers, target_host, target_port): + if target_host == self.source: + raise auth_plugins.AuthenticationError("some error") + + self.stubs.Set(websocketproxy.ProxyRequestHandler, 'do_proxy', + staticmethod(lambda *args, **kwargs: None)) + + self.handler.server.auth_plugin = TestPlugin("somehost") + self.handler.server.target_host = "somehost" + self.handler.server.target_port = "someport" + + self.assertRaises(auth_plugins.AuthenticationError, + self.handler.new_websocket_client) + + self.handler.server.target_host = "someotherhost" + self.handler.new_websocket_client() + diff --git a/tests/tox.ini b/tests/tox.ini deleted file mode 100644 index 098e89c..0000000 --- a/tests/tox.ini +++ /dev/null @@ -1,20 +0,0 @@ -# Tox (http://tox.testrun.org/) is a tool for running tests -# in multiple virtualenvs. This configuration file will run the -# test suite on all supported python versions. To use it, "pip install tox" -# and then run "tox" from this directory. - -[tox] -envlist = py24,py25,py26,py27,py30 -setupdir = ../ - -[testenv] -commands = nosetests {posargs} -deps = - mox - nose - -# At some point we should enable this since tox epdctes it to exist but -# the code will need pep8ising first. -#[testenv:pep8] -#commands = flake8 -#dep = flake8 -- cgit v1.2.1