diff options
author | astrand <astrand@cendio.se> | 2013-12-19 01:31:14 -0800 |
---|---|---|
committer | astrand <astrand@cendio.se> | 2013-12-19 01:31:14 -0800 |
commit | 38db12c2b0883669214240d4cf404ff08b8259b2 (patch) | |
tree | 327b4a650ed7fcbb872de61d309b778393afde8f | |
parent | b662d185ca2b01d7214fceedfb742fa97216ff03 (diff) | |
parent | a749611370da26de3f70ff001db8f65ff50ac664 (diff) | |
download | websockify-38db12c2b0883669214240d4cf404ff08b8259b2.tar.gz |
Merge pull request #111 from astrand/master
Refactor to use standard SocketServer RequestHandler design.
-rwxr-xr-x | tests/echo.py | 7 | ||||
-rwxr-xr-x | tests/load.py | 25 | ||||
-rw-r--r-- | websockify/websocket.py | 357 | ||||
-rwxr-xr-x | websockify/websocketproxy.py | 136 |
4 files changed, 317 insertions, 208 deletions
diff --git a/tests/echo.py b/tests/echo.py index ad83296..27bdc46 100755 --- a/tests/echo.py +++ b/tests/echo.py @@ -12,9 +12,9 @@ as taken from http://docs.python.org/dev/library/ssl.html#certificates import os, sys, select, optparse sys.path.insert(0,os.path.join(os.path.dirname(__file__), "..")) -from websockify.websocket import WebSocketServer +from websockify.websocket import WebSocketServer, WebSocketRequestHandler -class WebSocketEcho(WebSocketServer): +class WebSocketEcho(WebSocketRequestHandler): """ WebSockets server that echos back whatever is received from the client. """ @@ -49,7 +49,6 @@ class WebSocketEcho(WebSocketServer): if closed: self.send_close() - raise self.EClose(closed) if __name__ == '__main__': parser = optparse.OptionParser(usage="%prog [options] listen_port") @@ -70,6 +69,6 @@ if __name__ == '__main__': parser.error("Invalid arguments") opts.web = "." - server = WebSocketEcho(**opts.__dict__) + server = WebSocketServer(WebSocketEcho, **opts.__dict__) server.start_server() diff --git a/tests/load.py b/tests/load.py index fdff106..2565368 100755 --- a/tests/load.py +++ b/tests/load.py @@ -8,28 +8,30 @@ given a sequence number. Any errors are reported and counted. import sys, os, select, random, time, optparse sys.path.insert(0,os.path.join(os.path.dirname(__file__), "..")) -from websockify.websocket import WebSocketServer +from websockify.websocket import WebSocketServer, WebSocketRequestHandler -class WebSocketLoad(WebSocketServer): +class WebSocketLoadServer(WebSocketServer): - buffer_size = 65536 - - max_packet_size = 10000 recv_cnt = 0 send_cnt = 0 def __init__(self, *args, **kwargs): - self.errors = 0 self.delay = kwargs.pop('delay') + WebSocketServer.__init__(self, *args, **kwargs) + + +class WebSocketLoad(WebSocketRequestHandler): + + max_packet_size = 10000 + + def new_websocket_client(self): print "Prepopulating random array" self.rand_array = [] for i in range(0, self.max_packet_size): self.rand_array.append(random.randint(0, 9)) - WebSocketServer.__init__(self, *args, **kwargs) - - def new_websocket_client(self): + self.errors = 0 self.send_cnt = 0 self.recv_cnt = 0 @@ -61,14 +63,13 @@ class WebSocketLoad(WebSocketServer): if closed: self.send_close() - raise self.EClose(closed) now = time.time() * 1000 if client in outs: if c_pend: last_send = now c_pend = self.send_frames() - elif now > (last_send + self.delay): + elif now > (last_send + self.server.delay): last_send = now c_pend = self.send_frames([self.generate()]) @@ -162,6 +163,6 @@ if __name__ == '__main__': parser.error("Invalid arguments") opts.web = "." - server = WebSocketLoad(**opts.__dict__) + server = WebSocketLoadServer(WebSocketLoad, **opts.__dict__) server.start_server() diff --git a/websockify/websocket.py b/websockify/websocket.py index d264273..889bc40 100644 --- a/websockify/websocket.py +++ b/websockify/websocket.py @@ -65,30 +65,51 @@ if multiprocessing and sys.platform == 'win32': import multiprocessing.reduction -class WebSocketServer(object): +# HTTP handler with WebSocket upgrade support +class WebSocketRequestHandler(SimpleHTTPRequestHandler): """ - WebSockets server class. + WebSocket Request Handler Class, derived from SimpleHTTPRequestHandler. Must be sub-classed with new_websocket_client method definition. + The request handler can be configured by setting optional + attributes on the server object: + + * only_upgrade: If true, SimpleHTTPRequestHandler will not be enabled, + only websocket is allowed. + * verbose: If true, verbose logging is activated. + * daemon: Running as daemon, do not write to console etc + * record: Record raw frame data as JavaScript array into specified filename + * run_once: Handle a single request + * handler_id: A sequence number for this connection, appended to record filename """ - - log_prefix = "websocket" buffer_size = 65536 GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" - server_handshake_hybi = """HTTP/1.1 101 Switching Protocols\r -Upgrade: websocket\r -Connection: Upgrade\r -Sec-WebSocket-Accept: %s\r -""" + server_version = "WebSockify" + + protocol_version = "HTTP/1.1" # An exception while the WebSocket client was connected class CClose(Exception): pass - class Terminate(Exception): - pass + def __init__(self, req, addr, server): + # Retrieve a few configuration variables from the server + self.only_upgrade = getattr(server, "only_upgrade", False) + self.verbose = getattr(server, "verbose", False) + self.daemon = getattr(server, "daemon", False) + self.record = getattr(server, "record", False) + self.run_once = getattr(server, "run_once", False) + self.rec = None + self.handler_id = getattr(server, "handler_id", False) + self.file_only = getattr(server, "file_only", False) + self.traffic = getattr(server, "traffic", False) + self.logger = getattr(server, "logger", None) + if self.logger is None: + self.logger = WebSocketServer.get_logger() + + SimpleHTTPRequestHandler.__init__(self, req, addr, server) @staticmethod def unmask(buf, hlen, plen): @@ -213,7 +234,7 @@ Sec-WebSocket-Accept: %s\r # Process 1 frame if f['masked']: # unmask payload - f['payload'] = WebSocketServer.unmask(buf, f['hlen'], + f['payload'] = WebSocketRequestHandler.unmask(buf, f['hlen'], f['length']) else: logger.debug("Unmasked frame: %s" % repr(buf)) @@ -237,9 +258,33 @@ Sec-WebSocket-Accept: %s\r # - # Main WebSocketServer methods + # WebSocketRequestHandler logging/output functions # + def print_traffic(self, token="."): + """ Show traffic flow mode. """ + if self.traffic: + sys.stdout.write(token) + sys.stdout.flush() + + def msg(self, msg, *args, **kwargs): + """ Output message with handler_id prefix. """ + prefix = "% 3d: " % self.handler_id + self.logger.log(logging.INFO, "%s%s" % (prefix, msg), *args, **kwargs) + + def vmsg(self, msg, *args, **kwargs): + """ Same as msg() but as debug. """ + prefix = "% 3d: " % self.handler_id + self.logger.log(logging.DEBUG, "%s%s" % (prefix, msg), *args, **kwargs) + + def warn(self, msg, *args, **kwargs): + """ Same as msg() but as warning. """ + prefix = "% 3d: " % self.handler_id + self.logger.log(logging.WARN, "%s%s" % (prefix, msg), *args, **kwargs) + + # + # Main WebSocketRequestHandler methods + # def send_frames(self, bufs=None): """ Encode and send WebSocket frames. Any frames already queued will be sent first. If buf is not set then only queued @@ -322,7 +367,7 @@ Sec-WebSocket-Accept: %s\r start = frame['hlen'] end = frame['hlen'] + frame['length'] if frame['masked']: - recbuf = WebSocketServer.unmask(buf, frame['hlen'], + recbuf = WebSocketRequestHandler.unmask(buf, frame['hlen'], frame['length']) else: recbuf = buf[frame['hlen']:frame['hlen'] + @@ -347,9 +392,8 @@ Sec-WebSocket-Accept: %s\r buf, h, t = self.encode_hybi(msg, opcode=0x08, base64=False) self.request.send(buf) - def do_websocket_handshake(self, headers, path): - h = self.headers = headers - self.path = path + def do_websocket_handshake(self): + h = self.headers prot = 'WebSocket-Protocol' protocols = h.get('Sec-'+prot, h.get(prot, '')).split(',') @@ -364,7 +408,8 @@ Sec-WebSocket-Accept: %s\r if ver in ['7', '8', '13']: self.version = "hybi-%02d" % int(ver) else: - raise self.EClose('Unsupported protocol version %s' % ver) + self.send_error(400, "Unsupported protocol version %s" % ver) + return False key = h['Sec-WebSocket-Key'] @@ -374,42 +419,158 @@ Sec-WebSocket-Accept: %s\r elif 'base64' in protocols: self.base64 = True else: - raise self.EClose("Client must support 'binary' or 'base64' protocol") + self.send_error(400, "Client must support 'binary' or 'base64' protocol") + return False # Generate the hash value for the accept header accept = b64encode(sha1(s2b(key + self.GUID)).digest()) - response = self.server_handshake_hybi % b2s(accept) + self.send_response(101, "Switching Protocols") + self.send_header("Upgrade", "websocket") + self.send_header("Connection", "Upgrade") + self.send_header("Sec-WebSocket-Accept", b2s(accept)) if self.base64: - response += "Sec-WebSocket-Protocol: base64\r\n" + self.send_header("Sec-WebSocket-Protocol", "base64") else: - response += "Sec-WebSocket-Protocol: binary\r\n" - response += "\r\n" + self.send_header("Sec-WebSocket-Protocol", "binary") + self.end_headers() + return True + else: + self.send_error(400, "Missing Sec-WebSocket-Version header. Hixie protocols not supported.") + + return False + + def handle_websocket(self): + """Upgrade a connection to Websocket, if requested. If this succeeds, + new_websocket_client() will be called. Otherwise, False is returned. + """ + if (self.headers.get('upgrade') and + self.headers.get('upgrade').lower() == 'websocket'): + + if not self.do_websocket_handshake(): + return False + + # Indicate to server that a Websocket upgrade was done + self.server.ws_connection = True + # Initialize per client settings + self.send_parts = [] + self.recv_part = None + self.start_time = int(time.time()*1000) + + # client_address is empty with, say, UNIX domain sockets + client_addr = "" + is_ssl = False + try: + client_addr = self.client_address[0] + is_ssl = self.client_address[2] + except IndexError: + pass + + if is_ssl: + self.stype = "SSL/TLS (wss://)" + else: + self.stype = "Plain non-SSL (ws://)" + + self.log_message("%s: %s WebSocket connection" % (client_addr, + self.stype)) + self.log_message("%s: Version %s, base64: '%s'" % (client_addr, + self.version, self.base64)) + if self.path != '/': + self.log_message("%s: Path: '%s'" % (client_addr, self.path)) + + if self.record: + # Record raw frame data as JavaScript array + fname = "%s.%s" % (self.record, + self.handler_id) + self.log_message("opening record file: %s" % fname) + self.rec = open(fname, 'w+') + encoding = "binary" + if self.base64: encoding = "base64" + self.rec.write("var VNC_frame_encoding = '%s';\n" + % encoding) + self.rec.write("var VNC_frame_data = [\n") + try: + self.new_websocket_client() + except self.CClose: + # Close the client + _, exc, _ = sys.exc_info() + self.send_close(exc.args[0], exc.args[1]) + return True else: - raise self.EClose("Missing Sec-WebSocket-Version header. Hixie protocols not supported.") + return False - return response + def do_GET(self): + """Handle GET request. Calls handle_websocket(). If unsuccessful, + and web server is enabled, SimpleHTTPRequestHandler.do_GET will be called.""" + if not self.handle_websocket(): + if self.only_upgrade: + self.send_error(405, "Method Not Allowed") + else: + SimpleHTTPRequestHandler.do_GET(self) + def list_directory(self, path): + if self.file_only: + self.send_error(404, "No such file") + else: + return SimpleHTTPRequestHandler.list_directory(self, path) + def new_websocket_client(self): """ Do something with a WebSockets client connection. """ - raise("WebSocketServer.new_websocket_client() must be overloaded") + raise Exception("WebSocketRequestHandler.new_websocket_client() must be overloaded") + + def do_HEAD(self): + if self.only_upgrade: + self.send_error(405, "Method Not Allowed") + else: + SimpleHTTPRequestHandler.do_HEAD(self) + + def finish(self): + if self.rec: + self.rec.write("'EOF'];\n") + self.rec.close() + + def handle(self): + # When using run_once, we have a single process, so + # we cannot loop in BaseHTTPRequestHandler.handle; we + # must return and handle new connections + if self.run_once: + self.handle_one_request() + else: + SimpleHTTPRequestHandler.handle(self) + + def log_request(self, code='-', size='-'): + if self.verbose: + SimpleHTTPRequestHandler.log_request(self, code, size) + + +class WebSocketServer(object): + """ + WebSockets server class. + As an alternative, the standard library SocketServer can be used + """ policy_response = """<cross-domain-policy><allow-access-from domain="*" to-ports="*" /></cross-domain-policy>\n""" + log_prefix = "websocket" # An exception before the WebSocket connection was established class EClose(Exception): pass - def __init__(self, listen_host='', listen_port=None, source_is_ipv6=False, + class Terminate(Exception): + pass + + def __init__(self, RequestHandlerClass, listen_host='', + listen_port=None, source_is_ipv6=False, verbose=False, cert='', key='', ssl_only=None, daemon=False, record='', web='', - file_only=False, no_parent=False, + file_only=False, run_once=False, timeout=0, idle_timeout=0, traffic=False, tcp_keepalive=True, tcp_keepcnt=None, tcp_keepidle=None, tcp_keepintvl=None): # settings + self.RequestHandlerClass = RequestHandlerClass self.verbose = verbose self.listen_host = listen_host self.listen_port = listen_port @@ -423,12 +584,8 @@ Sec-WebSocket-Accept: %s\r self.launch_time = time.time() self.ws_connection = False - self.i_am_client = False self.handler_id = 1 - self.file_only = file_only - self.no_parent = no_parent - self.logger = self.get_logger() self.tcp_keepalive = tcp_keepalive self.tcp_keepcnt = tcp_keepcnt @@ -447,6 +604,7 @@ Sec-WebSocket-Accept: %s\r if self.web: os.chdir(self.web) + self.only_upgrade = not self.web # Sanity checks if not ssl and self.ssl_only: @@ -593,7 +751,6 @@ Sec-WebSocket-Accept: %s\r - Send a WebSockets handshake server response. - Return the socket for this WebSocket client. """ - stype = "" ready = select.select([sock], [], [], 3)[0] @@ -637,43 +794,19 @@ Sec-WebSocket-Accept: %s\r else: raise - self.scheme = "wss" - stype = "SSL/TLS (wss://)" - elif self.ssl_only: raise self.EClose("non-SSL connection received but disallowed") else: retsock = sock - self.scheme = "ws" - stype = "Plain non-SSL (ws://)" - - wsh = WSRequestHandler(retsock, address, not self.web, - self.file_only, self.no_parent) - if wsh.last_code == 101: - # Continue on to handle WebSocket upgrade - pass - elif wsh.last_code == 405: - raise self.EClose("Normal web request received but disallowed") - elif wsh.last_code < 200 or wsh.last_code >= 300: - raise self.EClose(wsh.last_message) - elif self.verbose: - raise self.EClose(wsh.last_message) - else: - raise self.EClose("") - response = self.do_websocket_handshake(wsh.headers, wsh.path) + # If the address is like (host, port), we are extending it + # with a flag indicating SSL. Not many other options + # available... + if len(address) == 2: + address = (address[0], address[1], (retsock != sock)) - self.msg("%s: %s WebSocket connection" % (address[0], stype)) - self.msg("%s: Version %s, base64: '%s'" % (address[0], - self.version, self.base64)) - if self.path != '/': - self.msg("%s: Path: '%s'" % (address[0], self.path)) - - - # Send server WebSockets handshake response - #self.msg("sending response [%s]" % response) - retsock.send(s2b(response)) + self.RequestHandlerClass(retsock, address, self) # Return the WebSockets socket which may be SSL wrapped return retsock @@ -681,32 +814,18 @@ Sec-WebSocket-Accept: %s\r # # WebSocketServer logging/output functions # - def print_traffic(self, token="."): - """ Show traffic flow mode. """ - if self.traffic: - sys.stdout.write(token) - sys.stdout.flush() - - - def log(self, lvl, msg, *args, **kwargs): - """ Wrapper around python logging """ - prefix = "" - if self.i_am_client: - prefix = "% 3d: " % self.handler_id - self.logger.log(lvl, "%s%s" % (prefix, msg), - *args, **kwargs) def msg(self, *args, **kwargs): - """ Output message with handler_id prefix. """ - self.log(logging.INFO, *args, **kwargs) + """ Output message as info """ + self.logger.log(logging.INFO, *args, **kwargs) def vmsg(self, *args, **kwargs): """ Same as msg() but as debug. """ - self.log(logging.DEBUG, *args, **kwargs) + self.logger.log(logging.DEBUG, *args, **kwargs) def warn(self, *args, **kwargs): """ Same as msg() but as warning. """ - self.log(logging.WARN, *args, **kwargs) + self.logger.log(logging.WARN, *args, **kwargs) # @@ -748,38 +867,11 @@ Sec-WebSocket-Accept: %s\r def top_new_client(self, startsock, address): """ Do something with a WebSockets client connection. """ - # Initialize per client settings - self.i_am_client = True - self.send_parts = [] - self.recv_part = None - self.base64 = False - self.rec = None - self.start_time = int(time.time()*1000) - # handler process + client = None try: try: - self.request = self.do_handshake(startsock, address) - - if self.record: - # Record raw frame data as JavaScript array - fname = "%s.%s" % (self.record, - self.handler_id) - self.msg("opening record file: %s" % fname) - self.rec = open(fname, 'w+') - encoding = "binary" - if self.base64: encoding = "base64" - self.rec.write("var VNC_frame_encoding = '%s';\n" - % encoding) - self.rec.write("var VNC_frame_data = [\n") - - self.ws_connection = True - self.new_websocket_client() - except self.CClose: - # Close the client - _, exc, _ = sys.exc_info() - if self.request: - self.send_close(exc.args[0], exc.args[1]) + client = self.do_handshake(startsock, address) except self.EClose: _, exc, _ = sys.exc_info() # Connection was not a WebSockets connection @@ -792,14 +884,11 @@ Sec-WebSocket-Accept: %s\r self.msg("handler exception: %s" % str(exc)) self.vmsg("exception", exc_info=True) finally: - if self.rec: - self.rec.write("'EOF'];\n") - self.rec.close() - if self.request and self.request != startsock: + if client and client != startsock: # Close the SSL wrapped socket # Original socket closed by caller - self.request.close() + client.close() def start_server(self): """ @@ -841,7 +930,6 @@ Sec-WebSocket-Accept: %s\r while True: try: try: - self.request = None startsock = None pid = err = 0 child_count = 0 @@ -940,40 +1028,3 @@ Sec-WebSocket-Accept: %s\r signal.signal(sig, func) -# HTTP handler with WebSocket upgrade support -class WSRequestHandler(SimpleHTTPRequestHandler): - def __init__(self, req, addr, only_upgrade=False, file_only=False, - no_parent=False): - self.only_upgrade = only_upgrade # only allow upgrades - self.webroot = os.path.realpath(".") - self.file_only = file_only - self.no_parent = no_parent - SimpleHTTPRequestHandler.__init__(self, req, addr, object()) - - def do_GET(self): - abspath = os.path.realpath("." + (self.path.split('?')[0])) - if (self.headers.get('upgrade') and - self.headers.get('upgrade').lower() == 'websocket'): - - # Just indicate that an WebSocket upgrade is needed - self.last_code = 101 - self.last_message = "101 Switching Protocols" - elif self.only_upgrade: - # Normal web request responses are disabled - self.last_code = 405 - self.last_message = "405 Method Not Allowed" - elif self.file_only and not os.path.isfile(abspath): - self.send_response(404, "No such file") - elif self.no_parent and not abspath.startswith(self.webroot): - self.send_response(403, "Hidden resources") - else: - SimpleHTTPRequestHandler.do_GET(self) - - def send_response(self, code, message=None): - # Save the status code - self.last_code = code - SimpleHTTPRequestHandler.send_response(self, code, message) - - def log_message(self, f, *args): - # Save instead of printing - self.last_message = f % args diff --git a/websockify/websocketproxy.py b/websockify/websocketproxy.py index b3eb71b..f51cc7c 100755 --- a/websockify/websocketproxy.py +++ b/websockify/websocketproxy.py @@ -12,6 +12,10 @@ as taken from http://docs.python.org/dev/library/ssl.html#certificates ''' import signal, socket, optparse, time, os, sys, subprocess, logging +try: from socketserver import ForkingMixIn +except: from SocketServer import ForkingMixIn +try: from http.server import HTTPServer +except: from BaseHTTPServer import HTTPServer from select import select import websocket try: @@ -20,15 +24,7 @@ except: from cgi import parse_qs from urlparse import urlparse -class WebSocketProxy(websocket.WebSocketServer): - """ - Proxy traffic to and from a WebSockets client to a normal TCP - socket server target. All traffic to/from the client is base64 - encoded/decoded to allow binary data to be sent/received to/from - the target. - """ - - buffer_size = 65536 +class ProxyRequestHandler(websocket.WebSocketRequestHandler): traffic_legend = """ Traffic Legend: @@ -42,35 +38,31 @@ Traffic Legend: <. - Client send partial """ - # - # Routines below this point are connection handler routines and - # will be run in a separate forked process for each connection. - # - def new_websocket_client(self): """ Called after a new WebSocket connection has been established. """ # Checks if we receive a token, and look # for a valid target for it then - if self.target_cfg: - (self.target_host, self.target_port) = self.get_target(self.target_cfg, self.path) + if self.server.target_cfg: + (self.server.target_host, self.server.target_port) = self.get_target(self.server.target_cfg, self.path) # Connect to the target - if self.wrap_cmd: - msg = "connecting to command: '%s' (port %s)" % (" ".join(self.wrap_cmd), self.target_port) - elif self.unix_target: - msg = "connecting to unix socket: %s" % self.unix_target + if self.server.wrap_cmd: + msg = "connecting to command: '%s' (port %s)" % (" ".join(self.server.wrap_cmd), self.server.target_port) + elif self.server.unix_target: + msg = "connecting to unix socket: %s" % self.server.unix_target else: msg = "connecting to: %s:%s" % ( - self.target_host, self.target_port) + self.server.target_host, self.server.target_port) - if self.ssl_target: + if self.server.ssl_target: msg += " (using SSL)" - self.msg(msg) + self.log_message(msg) - tsock = self.socket(self.target_host, self.target_port, - connect=True, use_ssl=self.ssl_target, unix_socket=self.unix_target) + tsock = websocket.WebSocketServer.socket(self.server.target_host, + self.server.target_port, + connect=True, use_ssl=self.server.ssl_target, unix_socket=self.server.unix_target) self.print_traffic(self.traffic_legend) @@ -81,8 +73,9 @@ Traffic Legend: if tsock: tsock.shutdown(socket.SHUT_RDWR) tsock.close() - self.vmsg("%s:%s: Closed target" %( - self.target_host, self.target_port)) + if self.verbose: + self.log_message("%s:%s: Closed target" %( + self.server.target_host, self.server.target_port)) raise def get_target(self, target_cfg, path): @@ -154,8 +147,9 @@ Traffic Legend: if closed: # TODO: What about blocking on client socket? - self.vmsg("%s:%s: Client closed connection" %( - self.target_host, self.target_port)) + if self.verbose: + self.log_message("%s:%s: Client closed connection" %( + self.server.target_host, self.server.target_port)) raise self.CClose(closed['code'], closed['reason']) @@ -175,20 +169,25 @@ Traffic Legend: # Receive target data, encode it and queue for client buf = target.recv(self.buffer_size) if len(buf) == 0: - self.vmsg("%s:%s: Target closed connection" %( - self.target_host, self.target_port)) + if self.verbose: + self.log_message("%s:%s: Target closed connection" %( + self.server.target_host, self.server.target_port)) raise self.CClose(1000, "Target closed") cqueue.append(buf) self.print_traffic("{") +class WebSocketProxy(websocket.WebSocketServer): + """ + Proxy traffic to and from a WebSockets client to a normal TCP + socket server target. All traffic to/from the client is base64 + encoded/decoded to allow binary data to be sent/received to/from + the target. + """ - # - # Routines below this point are run in the master listener - # process. - # + buffer_size = 65536 - def __init__(self, *args, **kwargs): + def __init__(self, RequestHandlerClass=ProxyRequestHandler, *args, **kwargs): # Save off proxy specific options self.target_host = kwargs.pop('target_host', None) self.target_port = kwargs.pop('target_port', None) @@ -229,7 +228,7 @@ Traffic Legend: "REBIND_OLD_PORT": str(kwargs['listen_port']), "REBIND_NEW_PORT": str(self.target_port)}) - websocket.WebSocketServer.__init__(self, *args, **kwargs) + websocket.WebSocketServer.__init__(self, RequestHandlerClass, *args, **kwargs) def run_wrap_cmd(self): self.msg("Starting '%s'", " ".join(self.wrap_cmd)) @@ -358,6 +357,8 @@ def websockify_init(): help="Configuration file containing valid targets " "in the form 'token: host:port' or, alternatively, a " "directory containing configuration files of this form") + parser.add_option("--libserver", action="store_true", + help="use Python library SocketServer engine") (opts, args) = parser.parse_args() if opts.verbose: @@ -406,8 +407,65 @@ def websockify_init(): opts.target_cfg = os.path.abspath(opts.target_cfg) # Create and start the WebSockets proxy - server = WebSocketProxy(**opts.__dict__) - server.start_server() + libserver = opts.libserver + del opts.libserver + if libserver: + # Use standard Python SocketServer framework + server = LibProxyServer(**opts.__dict__) + server.serve_forever() + else: + # Use internal service framework + server = WebSocketProxy(**opts.__dict__) + server.start_server() + + +class LibProxyServer(ForkingMixIn, HTTPServer): + """ + Just like WebSocketProxy, but uses standard Python SocketServer + framework. + """ + + def __init__(self, RequestHandlerClass=ProxyRequestHandler, **kwargs): + # Save off proxy specific options + self.target_host = kwargs.pop('target_host', None) + self.target_port = kwargs.pop('target_port', None) + self.wrap_cmd = kwargs.pop('wrap_cmd', None) + self.wrap_mode = kwargs.pop('wrap_mode', None) + self.unix_target = kwargs.pop('unix_target', None) + self.ssl_target = kwargs.pop('ssl_target', None) + self.target_cfg = kwargs.pop('target_cfg', None) + self.daemon = False + self.target_cfg = None + + # Server configuration + listen_host = kwargs.pop('listen_host', '') + listen_port = kwargs.pop('listen_port', None) + web = kwargs.pop('web', '') + + # Configuration affecting base request handler + self.only_upgrade = not web + self.verbose = kwargs.pop('verbose', False) + record = kwargs.pop('record', '') + if record: + self.record = os.path.abspath(record) + self.run_once = kwargs.pop('run_once', False) + self.handler_id = 0 + + for arg in kwargs.keys(): + print("warning: option %s ignored when using --libserver" % arg) + + if web: + os.chdir(web) + + HTTPServer.__init__(self, (listen_host, listen_port), + RequestHandlerClass) + + + def process_request(self, request, client_address): + """Override process_request to implement a counter""" + self.handler_id += 1 + ForkingMixIn.process_request(self, request, client_address) + if __name__ == '__main__': websockify_init() |