diff options
author | samhed <samuel@cendio.se> | 2014-04-14 14:45:15 +0200 |
---|---|---|
committer | samhed <samuel@cendio.se> | 2014-04-14 14:45:15 +0200 |
commit | 082027dc8742292ce3e3de0f742fed5c13e68337 (patch) | |
tree | af85b869182f9caf777dade99ab3dd88ffbb41cd | |
parent | 60a415ae1c338856b1e47ea3517067e6e8257c85 (diff) | |
download | novnc-082027dc8742292ce3e3de0f742fed5c13e68337.tar.gz |
Sync with websockify
Pull 90b519edf0c1857d
-rw-r--r-- | include/websock.js | 16 | ||||
-rw-r--r-- | utils/websocket.py | 915 | ||||
-rwxr-xr-x | utils/websockify | 388 |
3 files changed, 768 insertions, 551 deletions
diff --git a/include/websock.js b/include/websock.js index d3790fe..01a24c3 100644 --- a/include/websock.js +++ b/include/websock.js @@ -262,7 +262,7 @@ function on(evt, handler) { eventHandlers[evt] = handler; } -function init(protocols, ws_schema) { +function init(protocols) { rQ = []; rQi = 0; sQ = []; @@ -278,14 +278,11 @@ function init(protocols, ws_schema) { bt = true; } - // Check for full binary type support in WebSocket - // Inspired by: - // https://github.com/Modernizr/Modernizr/issues/370 - // https://github.com/Modernizr/Modernizr/blob/master/feature-detects/websockets/binary.js + // Check for full binary type support in WebSockets + // TODO: this sucks, the property should exist on the prototype + // but it does not. try { - if (bt && - ('binaryType' in WebSocket.prototype || - !!(new WebSocket(ws_schema + '://.').binaryType))) { + if (bt && ('binaryType' in (new WebSocket("ws://localhost:17523")))) { Util.Info("Detected binaryType support in WebSockets"); wsbt = true; } @@ -328,8 +325,7 @@ function init(protocols, ws_schema) { } function open(uri, protocols) { - var ws_schema = uri.match(/^([a-z]+):\/\//)[1]; - protocols = init(protocols, ws_schema); + protocols = init(protocols); if (test_mode) { websocket = {}; diff --git a/utils/websocket.py b/utils/websocket.py index 8240fc3..67f5aef 100644 --- a/utils/websocket.py +++ b/utils/websocket.py @@ -16,7 +16,7 @@ as taken from http://docs.python.org/dev/library/ssl.html#certificates ''' -import os, sys, time, errno, signal, socket, traceback, select +import os, sys, time, errno, signal, socket, select, logging import array, struct from base64 import b64encode, b64decode @@ -59,177 +59,57 @@ for mod, msg in [('numpy', 'HyBi protocol will be slower'), except ImportError: globals()[mod] = None print("WARNING: no '%s' module, %s" % (mod, msg)) + if multiprocessing and sys.platform == 'win32': # make sockets pickle-able/inheritable import multiprocessing.reduction -class WebSocketServer(object): +# HTTP handler with WebSocket upgrade support +class WebSocketRequestHandler(SimpleHTTPRequestHandler): """ - WebSockets server class. - Must be sub-classed with new_client method definition. + WebSocket Request Handler Class, derived from SimpleHTTPRequestHandler. + Must be sub-classed with new_websocket_client method definition. + The request handler can be configured by setting optional + attributes on the server object: + + * only_upgrade: If true, SimpleHTTPRequestHandler will not be enabled, + only websocket is allowed. + * verbose: If true, verbose logging is activated. + * daemon: Running as daemon, do not write to console etc + * record: Record raw frame data as JavaScript array into specified filename + * run_once: Handle a single request + * handler_id: A sequence number for this connection, appended to record filename """ - buffer_size = 65536 - server_handshake_hybi = """HTTP/1.1 101 Switching Protocols\r -Upgrade: websocket\r -Connection: Upgrade\r -Sec-WebSocket-Accept: %s\r -""" - GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" - policy_response = """<cross-domain-policy><allow-access-from domain="*" to-ports="*" /></cross-domain-policy>\n""" + server_version = "WebSockify" - # An exception before the WebSocket connection was established - class EClose(Exception): - pass + protocol_version = "HTTP/1.1" # An exception while the WebSocket client was connected class CClose(Exception): pass - def __init__(self, listen_host='', listen_port=None, source_is_ipv6=False, - verbose=False, cert='', key='', ssl_only=None, - daemon=False, record='', web='', - run_once=False, timeout=0, idle_timeout=0): - - # settings - self.verbose = verbose - self.listen_host = listen_host - self.listen_port = listen_port - self.prefer_ipv6 = source_is_ipv6 - self.ssl_only = ssl_only - self.daemon = daemon - self.run_once = run_once - self.timeout = timeout - self.idle_timeout = idle_timeout - - self.launch_time = time.time() - self.ws_connection = False - self.handler_id = 1 - - # Make paths settings absolute - self.cert = os.path.abspath(cert) - self.key = self.web = self.record = '' - if key: - self.key = os.path.abspath(key) - if web: - self.web = os.path.abspath(web) - if record: - self.record = os.path.abspath(record) - - if self.web: - os.chdir(self.web) - - # Sanity checks - if not ssl and self.ssl_only: - raise Exception("No 'ssl' module and SSL-only specified") - if self.daemon and not resource: - raise Exception("Module 'resource' required to daemonize") - - # Show configuration - print("WebSocket server settings:") - print(" - Listen on %s:%s" % ( - self.listen_host, self.listen_port)) - print(" - Flash security policy server") - if self.web: - print(" - Web server. Web root: %s" % self.web) - if ssl: - if os.path.exists(self.cert): - print(" - SSL/TLS support") - if self.ssl_only: - print(" - Deny non-SSL/TLS connections") - else: - print(" - No SSL/TLS support (no cert file)") - else: - print(" - No SSL/TLS support (no 'ssl' module)") - if self.daemon: - print(" - Backgrounding (daemon)") - if self.record: - print(" - Recording to '%s.*'" % self.record) - - # - # WebSocketServer static methods - # - - @staticmethod - def socket(host, port=None, connect=False, prefer_ipv6=False, unix_socket=None, use_ssl=False): - """ Resolve a host (and optional port) to an IPv4 or IPv6 - address. Create a socket. Bind to it if listen is set, - otherwise connect to it. Return the socket. - """ - flags = 0 - if host == '': - host = None - if connect and not (port or unix_socket): - raise Exception("Connect mode requires a port") - if use_ssl and not ssl: - raise Exception("SSL socket requested but Python SSL module not loaded."); - if not connect and use_ssl: - raise Exception("SSL only supported in connect mode (for now)") - if not connect: - flags = flags | socket.AI_PASSIVE - - if not unix_socket: - addrs = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM, - socket.IPPROTO_TCP, flags) - if not addrs: - raise Exception("Could not resolve host '%s'" % host) - addrs.sort(key=lambda x: x[0]) - if prefer_ipv6: - addrs.reverse() - sock = socket.socket(addrs[0][0], addrs[0][1]) - if connect: - sock.connect(addrs[0][4]) - if use_ssl: - sock = ssl.wrap_socket(sock) - else: - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - sock.bind(addrs[0][4]) - sock.listen(100) - else: - sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - sock.connect(unix_socket) - - return sock - - @staticmethod - def daemonize(keepfd=None, chdir='/'): - os.umask(0) - if chdir: - os.chdir(chdir) - else: - os.chdir('/') - os.setgid(os.getgid()) # relinquish elevations - os.setuid(os.getuid()) # relinquish elevations - - # Double fork to daemonize - if os.fork() > 0: os._exit(0) # Parent exits - os.setsid() # Obtain new process group - if os.fork() > 0: os._exit(0) # Parent exits - - # Signal handling - def terminate(a,b): os._exit(0) - signal.signal(signal.SIGTERM, terminate) - signal.signal(signal.SIGINT, signal.SIG_IGN) + def __init__(self, req, addr, server): + # Retrieve a few configuration variables from the server + self.only_upgrade = getattr(server, "only_upgrade", False) + self.verbose = getattr(server, "verbose", False) + self.daemon = getattr(server, "daemon", False) + self.record = getattr(server, "record", False) + self.run_once = getattr(server, "run_once", False) + self.rec = None + self.handler_id = getattr(server, "handler_id", False) + self.file_only = getattr(server, "file_only", False) + self.traffic = getattr(server, "traffic", False) - # Close open files - maxfd = resource.getrlimit(resource.RLIMIT_NOFILE)[1] - if maxfd == resource.RLIM_INFINITY: maxfd = 256 - for fd in reversed(range(maxfd)): - try: - if fd != keepfd: - os.close(fd) - except OSError: - _, exc, _ = sys.exc_info() - if exc.errno != errno.EBADF: raise + self.logger = getattr(server, "logger", None) + if self.logger is None: + self.logger = WebSocketServer.get_logger() - # Redirect I/O to /dev/null - os.dup2(os.open(os.devnull, os.O_RDWR), sys.stdin.fileno()) - os.dup2(os.open(os.devnull, os.O_RDWR), sys.stdout.fileno()) - os.dup2(os.open(os.devnull, os.O_RDWR), sys.stderr.fileno()) + SimpleHTTPRequestHandler.__init__(self, req, addr, server) @staticmethod def unmask(buf, hlen, plen): @@ -246,7 +126,7 @@ Sec-WebSocket-Accept: %s\r b = numpy.bitwise_xor(data, mask).tostring() if plen % 4: - #print("Partial unmask") + #self.msg("Partial unmask") mask = numpy.frombuffer(buf, dtype=numpy.dtype('B'), offset=hlen, count=(plen % 4)) data = numpy.frombuffer(buf, dtype=numpy.dtype('B'), @@ -287,12 +167,12 @@ Sec-WebSocket-Accept: %s\r elif payload_len >= 65536: header = pack('>BBQ', b1, 127, payload_len) - #print("Encoded: %s" % repr(header + buf)) + #self.msg("Encoded: %s", repr(header + buf)) return header + buf, len(header), 0 @staticmethod - def decode_hybi(buf, base64=False): + def decode_hybi(buf, base64=False, logger=None): """ Decode HyBi style WebSocket packets. Returns: {'fin' : 0_or_1, @@ -316,6 +196,9 @@ Sec-WebSocket-Accept: %s\r 'close_code' : 1000, 'close_reason' : ''} + if logger is None: + logger = WebSocketServer.get_logger() + blen = len(buf) f['left'] = blen @@ -351,18 +234,18 @@ Sec-WebSocket-Accept: %s\r # Process 1 frame if f['masked']: # unmask payload - f['payload'] = WebSocketServer.unmask(buf, f['hlen'], + f['payload'] = WebSocketRequestHandler.unmask(buf, f['hlen'], f['length']) else: - print("Unmasked frame: %s" % repr(buf)) + logger.debug("Unmasked frame: %s" % repr(buf)) f['payload'] = buf[(f['hlen'] + f['masked'] * 4):full_len] if base64 and f['opcode'] in [1, 2]: try: f['payload'] = b64decode(f['payload']) except: - print("Exception while b64decoding buffer: %s" % - repr(buf)) + logger.exception("Exception while b64decoding buffer: %s" % + (repr(buf))) raise if f['opcode'] == 0x08: @@ -375,27 +258,32 @@ Sec-WebSocket-Accept: %s\r # - # WebSocketServer logging/output functions + # WebSocketRequestHandler logging/output functions # - def traffic(self, token="."): - """ Show traffic flow in verbose mode. """ - if self.verbose and not self.daemon: + def print_traffic(self, token="."): + """ Show traffic flow mode. """ + if self.traffic: sys.stdout.write(token) sys.stdout.flush() - def msg(self, msg): + def msg(self, msg, *args, **kwargs): """ Output message with handler_id prefix. """ - if not self.daemon: - print("% 3d: %s" % (self.handler_id, msg)) + prefix = "% 3d: " % self.handler_id + self.logger.log(logging.INFO, "%s%s" % (prefix, msg), *args, **kwargs) - def vmsg(self, msg): - """ Same as msg() but only if verbose. """ - if self.verbose: - self.msg(msg) + def vmsg(self, msg, *args, **kwargs): + """ Same as msg() but as debug. """ + prefix = "% 3d: " % self.handler_id + self.logger.log(logging.DEBUG, "%s%s" % (prefix, msg), *args, **kwargs) + + def warn(self, msg, *args, **kwargs): + """ Same as msg() but as warning. """ + prefix = "% 3d: " % self.handler_id + self.logger.log(logging.WARN, "%s%s" % (prefix, msg), *args, **kwargs) # - # Main WebSocketServer methods + # Main WebSocketRequestHandler methods # def send_frames(self, bufs=None): """ Encode and send WebSocket frames. Any frames already @@ -424,12 +312,12 @@ Sec-WebSocket-Accept: %s\r while self.send_parts: # Send pending frames buf = self.send_parts.pop(0) - sent = self.client.send(buf) + sent = self.request.send(buf) if sent == len(buf): - self.traffic("<") + self.print_traffic("<") else: - self.traffic("<.") + self.print_traffic("<.") self.send_parts.insert(0, buf[sent:]) break @@ -446,7 +334,7 @@ Sec-WebSocket-Accept: %s\r bufs = [] tdelta = int(time.time()*1000) - self.start_time - buf = self.client.recv(self.buffer_size) + buf = self.request.recv(self.buffer_size) if len(buf) == 0: closed = {'code': 1000, 'reason': "Client closed abruptly"} return bufs, closed @@ -457,12 +345,13 @@ Sec-WebSocket-Accept: %s\r self.recv_part = None while buf: - frame = self.decode_hybi(buf, base64=self.base64) - #print("Received buf: %s, frame: %s" % (repr(buf), frame)) + frame = self.decode_hybi(buf, base64=self.base64, + logger=self.logger) + #self.msg("Received buf: %s, frame: %s", repr(buf), frame) if frame['payload'] == None: # Incomplete/partial frame - self.traffic("}.") + self.print_traffic("}.") if frame['left'] > 0: self.recv_part = buf[-frame['left']:] break @@ -472,13 +361,13 @@ Sec-WebSocket-Accept: %s\r 'reason': frame['close_reason']} break - self.traffic("}") + self.print_traffic("}") if self.rec: start = frame['hlen'] end = frame['hlen'] + frame['length'] if frame['masked']: - recbuf = WebSocketServer.unmask(buf, frame['hlen'], + recbuf = WebSocketRequestHandler.unmask(buf, frame['hlen'], frame['length']) else: recbuf = buf[frame['hlen']:frame['hlen'] + @@ -501,11 +390,10 @@ Sec-WebSocket-Accept: %s\r msg = pack(">H%ds" % len(reason), code, reason) buf, h, t = self.encode_hybi(msg, opcode=0x08, base64=False) - self.client.send(buf) + self.request.send(buf) - def do_websocket_handshake(self, headers, path): - h = self.headers = headers - self.path = path + def do_websocket_handshake(self): + h = self.headers prot = 'WebSocket-Protocol' protocols = h.get('Sec-'+prot, h.get(prot, '')).split(',') @@ -520,7 +408,8 @@ Sec-WebSocket-Accept: %s\r if ver in ['7', '8', '13']: self.version = "hybi-%02d" % int(ver) else: - raise self.EClose('Unsupported protocol version %s' % ver) + self.send_error(400, "Unsupported protocol version %s" % ver) + return False key = h['Sec-WebSocket-Key'] @@ -530,23 +419,320 @@ Sec-WebSocket-Accept: %s\r elif 'base64' in protocols: self.base64 = True else: - raise self.EClose("Client must support 'binary' or 'base64' protocol") + self.send_error(400, "Client must support 'binary' or 'base64' protocol") + return False # Generate the hash value for the accept header accept = b64encode(sha1(s2b(key + self.GUID)).digest()) - response = self.server_handshake_hybi % b2s(accept) + self.send_response(101, "Switching Protocols") + self.send_header("Upgrade", "websocket") + self.send_header("Connection", "Upgrade") + self.send_header("Sec-WebSocket-Accept", b2s(accept)) if self.base64: - response += "Sec-WebSocket-Protocol: base64\r\n" + self.send_header("Sec-WebSocket-Protocol", "base64") else: - response += "Sec-WebSocket-Protocol: binary\r\n" - response += "\r\n" + self.send_header("Sec-WebSocket-Protocol", "binary") + self.end_headers() + return True + else: + self.send_error(400, "Missing Sec-WebSocket-Version header. Hixie protocols not supported.") + + return False + + def handle_websocket(self): + """Upgrade a connection to Websocket, if requested. If this succeeds, + new_websocket_client() will be called. Otherwise, False is returned. + """ + if (self.headers.get('upgrade') and + self.headers.get('upgrade').lower() == 'websocket'): + + if not self.do_websocket_handshake(): + return False + + # Indicate to server that a Websocket upgrade was done + self.server.ws_connection = True + # Initialize per client settings + self.send_parts = [] + self.recv_part = None + self.start_time = int(time.time()*1000) + + # client_address is empty with, say, UNIX domain sockets + client_addr = "" + is_ssl = False + try: + client_addr = self.client_address[0] + is_ssl = self.client_address[2] + except IndexError: + pass + + if is_ssl: + self.stype = "SSL/TLS (wss://)" + else: + self.stype = "Plain non-SSL (ws://)" + + self.log_message("%s: %s WebSocket connection", client_addr, + self.stype) + self.log_message("%s: Version %s, base64: '%s'", client_addr, + self.version, self.base64) + if self.path != '/': + self.log_message("%s: Path: '%s'", client_addr, self.path) + + if self.record: + # Record raw frame data as JavaScript array + fname = "%s.%s" % (self.record, + self.handler_id) + self.log_message("opening record file: %s", fname) + self.rec = open(fname, 'w+') + encoding = "binary" + if self.base64: encoding = "base64" + self.rec.write("var VNC_frame_encoding = '%s';\n" + % encoding) + self.rec.write("var VNC_frame_data = [\n") + + try: + self.new_websocket_client() + except self.CClose: + # Close the client + _, exc, _ = sys.exc_info() + self.send_close(exc.args[0], exc.args[1]) + return True + else: + return False + + def do_GET(self): + """Handle GET request. Calls handle_websocket(). If unsuccessful, + and web server is enabled, SimpleHTTPRequestHandler.do_GET will be called.""" + if not self.handle_websocket(): + if self.only_upgrade: + self.send_error(405, "Method Not Allowed") + else: + SimpleHTTPRequestHandler.do_GET(self) + + def list_directory(self, path): + if self.file_only: + self.send_error(404, "No such file") + else: + return SimpleHTTPRequestHandler.list_directory(self, path) + def new_websocket_client(self): + """ Do something with a WebSockets client connection. """ + raise Exception("WebSocketRequestHandler.new_websocket_client() must be overloaded") + + def do_HEAD(self): + if self.only_upgrade: + self.send_error(405, "Method Not Allowed") + else: + SimpleHTTPRequestHandler.do_HEAD(self) + + def finish(self): + if self.rec: + self.rec.write("'EOF'];\n") + self.rec.close() + + def handle(self): + # When using run_once, we have a single process, so + # we cannot loop in BaseHTTPRequestHandler.handle; we + # must return and handle new connections + if self.run_once: + self.handle_one_request() else: - raise self.EClose("Missing Sec-WebSocket-Version header. Hixie protocols not supported.") + SimpleHTTPRequestHandler.handle(self) + + def log_request(self, code='-', size='-'): + if self.verbose: + SimpleHTTPRequestHandler.log_request(self, code, size) + + +class WebSocketServer(object): + """ + WebSockets server class. + As an alternative, the standard library SocketServer can be used + """ - return response + policy_response = """<cross-domain-policy><allow-access-from domain="*" to-ports="*" /></cross-domain-policy>\n""" + log_prefix = "websocket" + + # An exception before the WebSocket connection was established + class EClose(Exception): + pass + class Terminate(Exception): + pass + + def __init__(self, RequestHandlerClass, listen_host='', + listen_port=None, source_is_ipv6=False, + verbose=False, cert='', key='', ssl_only=None, + daemon=False, record='', web='', + file_only=False, + run_once=False, timeout=0, idle_timeout=0, traffic=False, + tcp_keepalive=True, tcp_keepcnt=None, tcp_keepidle=None, + tcp_keepintvl=None): + + # settings + self.RequestHandlerClass = RequestHandlerClass + self.verbose = verbose + self.listen_host = listen_host + self.listen_port = listen_port + self.prefer_ipv6 = source_is_ipv6 + self.ssl_only = ssl_only + self.daemon = daemon + self.run_once = run_once + self.timeout = timeout + self.idle_timeout = idle_timeout + self.traffic = traffic + + self.launch_time = time.time() + self.ws_connection = False + self.handler_id = 1 + + self.logger = self.get_logger() + self.tcp_keepalive = tcp_keepalive + self.tcp_keepcnt = tcp_keepcnt + self.tcp_keepidle = tcp_keepidle + self.tcp_keepintvl = tcp_keepintvl + + # Make paths settings absolute + self.cert = os.path.abspath(cert) + self.key = self.web = self.record = '' + if key: + self.key = os.path.abspath(key) + if web: + self.web = os.path.abspath(web) + if record: + self.record = os.path.abspath(record) + + if self.web: + os.chdir(self.web) + self.only_upgrade = not self.web + + # Sanity checks + if not ssl and self.ssl_only: + raise Exception("No 'ssl' module and SSL-only specified") + if self.daemon and not resource: + raise Exception("Module 'resource' required to daemonize") + + # Show configuration + self.msg("WebSocket server settings:") + self.msg(" - Listen on %s:%s", + self.listen_host, self.listen_port) + self.msg(" - Flash security policy server") + if self.web: + self.msg(" - Web server. Web root: %s", self.web) + if ssl: + if os.path.exists(self.cert): + self.msg(" - SSL/TLS support") + if self.ssl_only: + self.msg(" - Deny non-SSL/TLS connections") + else: + self.msg(" - No SSL/TLS support (no cert file)") + else: + self.msg(" - No SSL/TLS support (no 'ssl' module)") + if self.daemon: + self.msg(" - Backgrounding (daemon)") + if self.record: + self.msg(" - Recording to '%s.*'", self.record) + + # + # WebSocketServer static methods + # + + @staticmethod + def get_logger(): + return logging.getLogger("%s.%s" % ( + WebSocketServer.log_prefix, + WebSocketServer.__class__.__name__)) + + @staticmethod + def socket(host, port=None, connect=False, prefer_ipv6=False, + unix_socket=None, use_ssl=False, tcp_keepalive=True, + tcp_keepcnt=None, tcp_keepidle=None, tcp_keepintvl=None): + """ Resolve a host (and optional port) to an IPv4 or IPv6 + address. Create a socket. Bind to it if listen is set, + otherwise connect to it. Return the socket. + """ + flags = 0 + if host == '': + host = None + if connect and not (port or unix_socket): + raise Exception("Connect mode requires a port") + if use_ssl and not ssl: + raise Exception("SSL socket requested but Python SSL module not loaded."); + if not connect and use_ssl: + raise Exception("SSL only supported in connect mode (for now)") + if not connect: + flags = flags | socket.AI_PASSIVE + + if not unix_socket: + addrs = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM, + socket.IPPROTO_TCP, flags) + if not addrs: + raise Exception("Could not resolve host '%s'" % host) + addrs.sort(key=lambda x: x[0]) + if prefer_ipv6: + addrs.reverse() + sock = socket.socket(addrs[0][0], addrs[0][1]) + + if tcp_keepalive: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) + if tcp_keepcnt: + sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPCNT, + tcp_keepcnt) + if tcp_keepidle: + sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPIDLE, + tcp_keepidle) + if tcp_keepintvl: + sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPINTVL, + tcp_keepintvl) + + if connect: + sock.connect(addrs[0][4]) + if use_ssl: + sock = ssl.wrap_socket(sock) + else: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(addrs[0][4]) + sock.listen(100) + else: + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + sock.connect(unix_socket) + + return sock + + @staticmethod + def daemonize(keepfd=None, chdir='/'): + os.umask(0) + if chdir: + os.chdir(chdir) + else: + os.chdir('/') + os.setgid(os.getgid()) # relinquish elevations + os.setuid(os.getuid()) # relinquish elevations + + # Double fork to daemonize + if os.fork() > 0: os._exit(0) # Parent exits + os.setsid() # Obtain new process group + if os.fork() > 0: os._exit(0) # Parent exits + + # Signal handling + signal.signal(signal.SIGTERM, signal.SIG_IGN) + signal.signal(signal.SIGINT, signal.SIG_IGN) + + # Close open files + maxfd = resource.getrlimit(resource.RLIMIT_NOFILE)[1] + if maxfd == resource.RLIM_INFINITY: maxfd = 256 + for fd in reversed(range(maxfd)): + try: + if fd != keepfd: + os.close(fd) + except OSError: + _, exc, _ = sys.exc_info() + if exc.errno != errno.EBADF: raise + + # Redirect I/O to /dev/null + os.dup2(os.open(os.devnull, os.O_RDWR), sys.stdin.fileno()) + os.dup2(os.open(os.devnull, os.O_RDWR), sys.stdout.fileno()) + os.dup2(os.open(os.devnull, os.O_RDWR), sys.stderr.fileno()) def do_handshake(self, sock, address): """ @@ -565,7 +751,6 @@ Sec-WebSocket-Accept: %s\r - Send a WebSockets handshake server response. - Return the socket for this WebSocket client. """ - stype = "" ready = select.select([sock], [], [], 3)[0] @@ -609,45 +794,38 @@ Sec-WebSocket-Accept: %s\r else: raise - self.scheme = "wss" - stype = "SSL/TLS (wss://)" - elif self.ssl_only: raise self.EClose("non-SSL connection received but disallowed") else: retsock = sock - self.scheme = "ws" - stype = "Plain non-SSL (ws://)" - wsh = WSRequestHandler(retsock, address, not self.web) - if wsh.last_code == 101: - # Continue on to handle WebSocket upgrade - pass - elif wsh.last_code == 405: - raise self.EClose("Normal web request received but disallowed") - elif wsh.last_code < 200 or wsh.last_code >= 300: - raise self.EClose(wsh.last_message) - elif self.verbose: - raise self.EClose(wsh.last_message) - else: - raise self.EClose("") + # If the address is like (host, port), we are extending it + # with a flag indicating SSL. Not many other options + # available... + if len(address) == 2: + address = (address[0], address[1], (retsock != sock)) - response = self.do_websocket_handshake(wsh.headers, wsh.path) + self.RequestHandlerClass(retsock, address, self) - self.msg("%s: %s WebSocket connection" % (address[0], stype)) - self.msg("%s: Version %s, base64: '%s'" % (address[0], - self.version, self.base64)) - if self.path != '/': - self.msg("%s: Path: '%s'" % (address[0], self.path)) + # Return the WebSockets socket which may be SSL wrapped + return retsock + # + # WebSocketServer logging/output functions + # - # Send server WebSockets handshake response - #self.msg("sending response [%s]" % response) - retsock.send(s2b(response)) + def msg(self, *args, **kwargs): + """ Output message as info """ + self.logger.log(logging.INFO, *args, **kwargs) - # Return the WebSockets socket which may be SSL wrapped - return retsock + def vmsg(self, *args, **kwargs): + """ Same as msg() but as debug. """ + self.logger.log(logging.DEBUG, *args, **kwargs) + + def warn(self, *args, **kwargs): + """ Same as msg() but as warning. """ + self.logger.log(logging.WARN, *args, **kwargs) # @@ -662,6 +840,12 @@ Sec-WebSocket-Accept: %s\r #self.vmsg("Running poll()") pass + def terminate(self): + raise self.Terminate() + + def multiprocessing_SIGCHLD(self, sig, stack): + self.vmsg('Reaing zombies, active child count is %s', len(multiprocessing.active_children())) + def fallback_SIGCHLD(self, sig, stack): # Reap zombies when using os.fork() (python 2.4) self.vmsg("Got SIGCHLD, reaping zombies") @@ -675,213 +859,172 @@ Sec-WebSocket-Accept: %s\r def do_SIGINT(self, sig, stack): self.msg("Got SIGINT, exiting") - sys.exit(0) + self.terminate() + + def do_SIGTERM(self, sig, stack): + self.msg("Got SIGTERM, exiting") + self.terminate() def top_new_client(self, startsock, address): """ Do something with a WebSockets client connection. """ - # Initialize per client settings - self.send_parts = [] - self.recv_part = None - self.base64 = False - self.rec = None - self.start_time = int(time.time()*1000) - # handler process + client = None try: try: - self.client = self.do_handshake(startsock, address) - - if self.record: - # Record raw frame data as JavaScript array - fname = "%s.%s" % (self.record, - self.handler_id) - self.msg("opening record file: %s" % fname) - self.rec = open(fname, 'w+') - encoding = "binary" - if self.base64: encoding = "base64" - self.rec.write("var VNC_frame_encoding = '%s';\n" - % encoding) - self.rec.write("var VNC_frame_data = [\n") - - self.ws_connection = True - self.new_client() - except self.CClose: - # Close the client - _, exc, _ = sys.exc_info() - if self.client: - self.send_close(exc.args[0], exc.args[1]) + client = self.do_handshake(startsock, address) except self.EClose: _, exc, _ = sys.exc_info() # Connection was not a WebSockets connection if exc.args[0]: self.msg("%s: %s" % (address[0], exc.args[0])) + except WebSocketServer.Terminate: + raise except Exception: _, exc, _ = sys.exc_info() self.msg("handler exception: %s" % str(exc)) - if self.verbose: - self.msg(traceback.format_exc()) + self.vmsg("exception", exc_info=True) finally: - if self.rec: - self.rec.write("'EOF'];\n") - self.rec.close() - if self.client and self.client != startsock: + if client and client != startsock: # Close the SSL wrapped socket # Original socket closed by caller - self.client.close() - - def new_client(self): - """ Do something with a WebSockets client connection. """ - raise("WebSocketServer.new_client() must be overloaded") + client.close() def start_server(self): """ Daemonize if requested. Listen for for connections. Run do_handshake() method for each connection. If the connection - is a WebSockets client then call new_client() method (which must + is a WebSockets client then call new_websocket_client() method (which must be overridden) for each new client connection. """ - lsock = self.socket(self.listen_host, self.listen_port, False, self.prefer_ipv6) + lsock = self.socket(self.listen_host, self.listen_port, False, + self.prefer_ipv6, + tcp_keepalive=self.tcp_keepalive, + tcp_keepcnt=self.tcp_keepcnt, + tcp_keepidle=self.tcp_keepidle, + tcp_keepintvl=self.tcp_keepintvl) if self.daemon: self.daemonize(keepfd=lsock.fileno(), chdir=self.web) self.started() # Some things need to happen after daemonizing - # Allow override of SIGINT + # Allow override of signals + original_signals = { + signal.SIGINT: signal.getsignal(signal.SIGINT), + signal.SIGTERM: signal.getsignal(signal.SIGTERM), + signal.SIGCHLD: signal.getsignal(signal.SIGCHLD), + } signal.signal(signal.SIGINT, self.do_SIGINT) + signal.signal(signal.SIGTERM, self.do_SIGTERM) if not multiprocessing: # os.fork() (python 2.4) child reaper signal.signal(signal.SIGCHLD, self.fallback_SIGCHLD) + else: + # make sure that _cleanup is called when children die + # by calling active_children on SIGCHLD + signal.signal(signal.SIGCHLD, self.multiprocessing_SIGCHLD) last_active_time = self.launch_time - while True: - try: + try: + while True: try: - self.client = None - startsock = None - pid = err = 0 - child_count = 0 - - if multiprocessing and self.idle_timeout: - child_count = len(multiprocessing.active_children()) - - time_elapsed = time.time() - self.launch_time - if self.timeout and time_elapsed > self.timeout: - self.msg('listener exit due to --timeout %s' - % self.timeout) - break - - if self.idle_timeout: - idle_time = 0 - if child_count == 0: - idle_time = time.time() - last_active_time - else: - idle_time = 0 - last_active_time = time.time() - - if idle_time > self.idle_timeout and child_count == 0: - self.msg('listener exit due to --idle-timeout %s' - % self.idle_timeout) - break - try: - self.poll() + startsock = None + pid = err = 0 + child_count = 0 + + if multiprocessing: + # Collect zombie child processes + child_count = len(multiprocessing.active_children()) + + time_elapsed = time.time() - self.launch_time + if self.timeout and time_elapsed > self.timeout: + self.msg('listener exit due to --timeout %s' + % self.timeout) + break - ready = select.select([lsock], [], [], 1)[0] - if lsock in ready: - startsock, address = lsock.accept() - else: - continue - except Exception: - _, exc, _ = sys.exc_info() - if hasattr(exc, 'errno'): - err = exc.errno - elif hasattr(exc, 'args'): - err = exc.args[0] - else: - err = exc[0] - if err == errno.EINTR: - self.vmsg("Ignoring interrupted syscall") - continue - else: + if self.idle_timeout: + idle_time = 0 + if child_count == 0: + idle_time = time.time() - last_active_time + else: + idle_time = 0 + last_active_time = time.time() + + if idle_time > self.idle_timeout and child_count == 0: + self.msg('listener exit due to --idle-timeout %s' + % self.idle_timeout) + break + + try: + self.poll() + + ready = select.select([lsock], [], [], 1)[0] + if lsock in ready: + startsock, address = lsock.accept() + else: + continue + except self.Terminate: raise - - if self.run_once: - # Run in same process if run_once - self.top_new_client(startsock, address) - if self.ws_connection : - self.msg('%s: exiting due to --run-once' - % address[0]) - break - elif multiprocessing: - self.vmsg('%s: new handler Process' % address[0]) - p = multiprocessing.Process( - target=self.top_new_client, - args=(startsock, address)) - p.start() - # child will not return - else: - # python 2.4 - self.vmsg('%s: forking handler' % address[0]) - pid = os.fork() - if pid == 0: - # child handler process + except Exception: + _, exc, _ = sys.exc_info() + if hasattr(exc, 'errno'): + err = exc.errno + elif hasattr(exc, 'args'): + err = exc.args[0] + else: + err = exc[0] + if err == errno.EINTR: + self.vmsg("Ignoring interrupted syscall") + continue + else: + raise + + if self.run_once: + # Run in same process if run_once self.top_new_client(startsock, address) - break # child process exits - - # parent process - self.handler_id += 1 - - except KeyboardInterrupt: - _, exc, _ = sys.exc_info() - print("In KeyboardInterrupt") - pass - except SystemExit: - _, exc, _ = sys.exc_info() - print("In SystemExit") - break - except Exception: - _, exc, _ = sys.exc_info() - self.msg("handler exception: %s" % str(exc)) - if self.verbose: - self.msg(traceback.format_exc()) - - finally: - if startsock: - startsock.close() - - # Close listen port - self.vmsg("Closing socket listening at %s:%s" - % (self.listen_host, self.listen_port)) - lsock.close() - + if self.ws_connection : + self.msg('%s: exiting due to --run-once' + % address[0]) + break + elif multiprocessing: + self.vmsg('%s: new handler Process' % address[0]) + p = multiprocessing.Process( + target=self.top_new_client, + args=(startsock, address)) + p.start() + # child will not return + else: + # python 2.4 + self.vmsg('%s: forking handler' % address[0]) + pid = os.fork() + if pid == 0: + # child handler process + self.top_new_client(startsock, address) + break # child process exits + + # parent process + self.handler_id += 1 + + except (self.Terminate, SystemExit, KeyboardInterrupt): + self.msg("In exit") + break + except Exception: + self.msg("handler exception: %s", str(exc)) + self.vmsg("exception", exc_info=True) -# HTTP handler with WebSocket upgrade support -class WSRequestHandler(SimpleHTTPRequestHandler): - def __init__(self, req, addr, only_upgrade=False): - self.only_upgrade = only_upgrade # only allow upgrades - SimpleHTTPRequestHandler.__init__(self, req, addr, object()) + finally: + if startsock: + startsock.close() + finally: + # Close listen port + self.vmsg("Closing socket listening at %s:%s", + self.listen_host, self.listen_port) + lsock.close() - def do_GET(self): - if (self.headers.get('upgrade') and - self.headers.get('upgrade').lower() == 'websocket'): - - # Just indicate that an WebSocket upgrade is needed - self.last_code = 101 - self.last_message = "101 Switching Protocols" - elif self.only_upgrade: - # Normal web request responses are disabled - self.last_code = 405 - self.last_message = "405 Method Not Allowed" - else: - SimpleHTTPRequestHandler.do_GET(self) + # Restore signals + for sig, func in original_signals.items(): + signal.signal(sig, func) - def send_response(self, code, message=None): - # Save the status code - self.last_code = code - SimpleHTTPRequestHandler.send_response(self, code, message) - def log_message(self, f, *args): - # Save instead of printing - self.last_message = f % args diff --git a/utils/websockify b/utils/websockify index 1154d92..7b3ec11 100755 --- a/utils/websockify +++ b/utils/websockify @@ -11,7 +11,11 @@ as taken from http://docs.python.org/dev/library/ssl.html#certificates ''' -import signal, socket, optparse, time, os, sys, subprocess +import signal, socket, optparse, time, os, sys, subprocess, logging +try: from socketserver import ForkingMixIn +except: from SocketServer import ForkingMixIn +try: from http.server import HTTPServer +except: from BaseHTTPServer import HTTPServer from select import select import websocket try: @@ -20,15 +24,7 @@ except: from cgi import parse_qs from urlparse import urlparse -class WebSocketProxy(websocket.WebSocketServer): - """ - Proxy traffic to and from a WebSockets client to a normal TCP - socket server target. All traffic to/from the client is base64 - encoded/decoded to allow binary data to be sent/received to/from - the target. - """ - - buffer_size = 65536 +class ProxyRequestHandler(websocket.WebSocketRequestHandler): traffic_legend = """ Traffic Legend: @@ -42,148 +38,33 @@ Traffic Legend: <. - Client send partial """ - def __init__(self, *args, **kwargs): - # Save off proxy specific options - self.target_host = kwargs.pop('target_host', None) - self.target_port = kwargs.pop('target_port', None) - self.wrap_cmd = kwargs.pop('wrap_cmd', None) - self.wrap_mode = kwargs.pop('wrap_mode', None) - self.unix_target = kwargs.pop('unix_target', None) - self.ssl_target = kwargs.pop('ssl_target', None) - self.target_cfg = kwargs.pop('target_cfg', None) - # Last 3 timestamps command was run - self.wrap_times = [0, 0, 0] - - if self.wrap_cmd: - rebinder_path = ['./', os.path.dirname(sys.argv[0])] - self.rebinder = None - - for rdir in rebinder_path: - rpath = os.path.join(rdir, "rebind.so") - if os.path.exists(rpath): - self.rebinder = rpath - break - - if not self.rebinder: - raise Exception("rebind.so not found, perhaps you need to run make") - self.rebinder = os.path.abspath(self.rebinder) - - self.target_host = "127.0.0.1" # Loopback - # Find a free high port - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.bind(('', 0)) - self.target_port = sock.getsockname()[1] - sock.close() - - os.environ.update({ - "LD_PRELOAD": self.rebinder, - "REBIND_OLD_PORT": str(kwargs['listen_port']), - "REBIND_NEW_PORT": str(self.target_port)}) - - if self.target_cfg: - self.target_cfg = os.path.abspath(self.target_cfg) - - websocket.WebSocketServer.__init__(self, *args, **kwargs) - - def run_wrap_cmd(self): - print("Starting '%s'" % " ".join(self.wrap_cmd)) - self.wrap_times.append(time.time()) - self.wrap_times.pop(0) - self.cmd = subprocess.Popen( - self.wrap_cmd, env=os.environ, preexec_fn=_subprocess_setup) - self.spawn_message = True - - def started(self): - """ - Called after Websockets server startup (i.e. after daemonize) - """ - # Need to call wrapped command after daemonization so we can - # know when the wrapped command exits - if self.wrap_cmd: - dst_string = "'%s' (port %s)" % (" ".join(self.wrap_cmd), self.target_port) - elif self.unix_target: - dst_string = self.unix_target - else: - dst_string = "%s:%s" % (self.target_host, self.target_port) - - if self.target_cfg: - msg = " - proxying from %s:%s to targets in %s" % ( - self.listen_host, self.listen_port, self.target_cfg) - else: - msg = " - proxying from %s:%s to %s" % ( - self.listen_host, self.listen_port, dst_string) - - if self.ssl_target: - msg += " (using SSL)" - - print(msg + "\n") - - if self.wrap_cmd: - self.run_wrap_cmd() - - def poll(self): - # If we are wrapping a command, check it's status - - if self.wrap_cmd and self.cmd: - ret = self.cmd.poll() - if ret != None: - self.vmsg("Wrapped command exited (or daemon). Returned %s" % ret) - self.cmd = None - - if self.wrap_cmd and self.cmd == None: - # Response to wrapped command being gone - if self.wrap_mode == "ignore": - pass - elif self.wrap_mode == "exit": - sys.exit(ret) - elif self.wrap_mode == "respawn": - now = time.time() - avg = sum(self.wrap_times)/len(self.wrap_times) - if (now - avg) < 10: - # 3 times in the last 10 seconds - if self.spawn_message: - print("Command respawning too fast") - self.spawn_message = False - else: - self.run_wrap_cmd() - - # - # Routines above this point are run in the master listener - # process. - # - - # - # Routines below this point are connection handler routines and - # will be run in a separate forked process for each connection. - # - - def new_client(self): + def new_websocket_client(self): """ Called after a new WebSocket connection has been established. """ # Checks if we receive a token, and look # for a valid target for it then - if self.target_cfg: - (self.target_host, self.target_port) = self.get_target(self.target_cfg, self.path) + if self.server.target_cfg: + (self.server.target_host, self.server.target_port) = self.get_target(self.server.target_cfg, self.path) # Connect to the target - if self.wrap_cmd: - msg = "connecting to command: '%s' (port %s)" % (" ".join(self.wrap_cmd), self.target_port) - elif self.unix_target: - msg = "connecting to unix socket: %s" % self.unix_target + if self.server.wrap_cmd: + msg = "connecting to command: '%s' (port %s)" % (" ".join(self.server.wrap_cmd), self.server.target_port) + elif self.server.unix_target: + msg = "connecting to unix socket: %s" % self.server.unix_target else: msg = "connecting to: %s:%s" % ( - self.target_host, self.target_port) + self.server.target_host, self.server.target_port) - if self.ssl_target: + if self.server.ssl_target: msg += " (using SSL)" - self.msg(msg) + self.log_message(msg) - tsock = self.socket(self.target_host, self.target_port, - connect=True, use_ssl=self.ssl_target, unix_socket=self.unix_target) + tsock = websocket.WebSocketServer.socket(self.server.target_host, + self.server.target_port, + connect=True, use_ssl=self.server.ssl_target, unix_socket=self.server.unix_target) - if self.verbose and not self.daemon: - print(self.traffic_legend) + self.print_traffic(self.traffic_legend) # Start proxying try: @@ -192,8 +73,9 @@ Traffic Legend: if tsock: tsock.shutdown(socket.SHUT_RDWR) tsock.close() - self.vmsg("%s:%s: Closed target" %( - self.target_host, self.target_port)) + if self.verbose: + self.log_message("%s:%s: Closed target", + self.server.target_host, self.server.target_port) raise def get_target(self, target_cfg, path): @@ -242,31 +124,32 @@ Traffic Legend: cqueue = [] c_pend = 0 tqueue = [] - rlist = [self.client, target] + rlist = [self.request, target] while True: wlist = [] if tqueue: wlist.append(target) - if cqueue or c_pend: wlist.append(self.client) + if cqueue or c_pend: wlist.append(self.request) ins, outs, excepts = select(rlist, wlist, [], 1) if excepts: raise Exception("Socket exception") - if self.client in outs: + if self.request in outs: # Send queued target data to the client c_pend = self.send_frames(cqueue) cqueue = [] - if self.client in ins: + if self.request in ins: # Receive client data, decode it, and queue for target bufs, closed = self.recv_frames() tqueue.extend(bufs) if closed: # TODO: What about blocking on client socket? - self.vmsg("%s:%s: Client closed connection" %( - self.target_host, self.target_port)) + if self.verbose: + self.log_message("%s:%s: Client closed connection", + self.server.target_host, self.server.target_port) raise self.CClose(closed['code'], closed['reason']) @@ -275,24 +158,139 @@ Traffic Legend: dat = tqueue.pop(0) sent = target.send(dat) if sent == len(dat): - self.traffic(">") + self.print_traffic(">") else: # requeue the remaining data tqueue.insert(0, dat[sent:]) - self.traffic(".>") + self.print_traffic(".>") if target in ins: # Receive target data, encode it and queue for client buf = target.recv(self.buffer_size) if len(buf) == 0: - self.vmsg("%s:%s: Target closed connection" %( - self.target_host, self.target_port)) + if self.verbose: + self.log_message("%s:%s: Target closed connection", + self.server.target_host, self.server.target_port) raise self.CClose(1000, "Target closed") cqueue.append(buf) - self.traffic("{") + self.print_traffic("{") + +class WebSocketProxy(websocket.WebSocketServer): + """ + Proxy traffic to and from a WebSockets client to a normal TCP + socket server target. All traffic to/from the client is base64 + encoded/decoded to allow binary data to be sent/received to/from + the target. + """ + + buffer_size = 65536 + + def __init__(self, RequestHandlerClass=ProxyRequestHandler, *args, **kwargs): + # Save off proxy specific options + self.target_host = kwargs.pop('target_host', None) + self.target_port = kwargs.pop('target_port', None) + self.wrap_cmd = kwargs.pop('wrap_cmd', None) + self.wrap_mode = kwargs.pop('wrap_mode', None) + self.unix_target = kwargs.pop('unix_target', None) + self.ssl_target = kwargs.pop('ssl_target', None) + self.target_cfg = kwargs.pop('target_cfg', None) + # Last 3 timestamps command was run + self.wrap_times = [0, 0, 0] + + if self.wrap_cmd: + wsdir = os.path.dirname(sys.argv[0]) + rebinder_path = [os.path.join(wsdir, "..", "lib"), + os.path.join(wsdir, "..", "lib", "websockify"), + wsdir] + self.rebinder = None + + for rdir in rebinder_path: + rpath = os.path.join(rdir, "rebind.so") + if os.path.exists(rpath): + self.rebinder = rpath + break + + if not self.rebinder: + raise Exception("rebind.so not found, perhaps you need to run make") + self.rebinder = os.path.abspath(self.rebinder) + + self.target_host = "127.0.0.1" # Loopback + # Find a free high port + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.bind(('', 0)) + self.target_port = sock.getsockname()[1] + sock.close() + os.environ.update({ + "LD_PRELOAD": self.rebinder, + "REBIND_OLD_PORT": str(kwargs['listen_port']), + "REBIND_NEW_PORT": str(self.target_port)}) + + websocket.WebSocketServer.__init__(self, RequestHandlerClass, *args, **kwargs) + + def run_wrap_cmd(self): + self.msg("Starting '%s'", " ".join(self.wrap_cmd)) + self.wrap_times.append(time.time()) + self.wrap_times.pop(0) + self.cmd = subprocess.Popen( + self.wrap_cmd, env=os.environ, preexec_fn=_subprocess_setup) + self.spawn_message = True + + def started(self): + """ + Called after Websockets server startup (i.e. after daemonize) + """ + # Need to call wrapped command after daemonization so we can + # know when the wrapped command exits + if self.wrap_cmd: + dst_string = "'%s' (port %s)" % (" ".join(self.wrap_cmd), self.target_port) + elif self.unix_target: + dst_string = self.unix_target + else: + dst_string = "%s:%s" % (self.target_host, self.target_port) + + if self.target_cfg: + msg = " - proxying from %s:%s to targets in %s" % ( + self.listen_host, self.listen_port, self.target_cfg) + else: + msg = " - proxying from %s:%s to %s" % ( + self.listen_host, self.listen_port, dst_string) + + if self.ssl_target: + msg += " (using SSL)" + + self.msg("%s", msg) + + if self.wrap_cmd: + self.run_wrap_cmd() + + def poll(self): + # If we are wrapping a command, check it's status + + if self.wrap_cmd and self.cmd: + ret = self.cmd.poll() + if ret != None: + self.vmsg("Wrapped command exited (or daemon). Returned %s" % ret) + self.cmd = None + + if self.wrap_cmd and self.cmd == None: + # Response to wrapped command being gone + if self.wrap_mode == "ignore": + pass + elif self.wrap_mode == "exit": + sys.exit(ret) + elif self.wrap_mode == "respawn": + now = time.time() + avg = sum(self.wrap_times)/len(self.wrap_times) + if (now - avg) < 10: + # 3 times in the last 10 seconds + if self.spawn_message: + self.warn("Command respawning too fast") + self.spawn_message = False + else: + self.run_wrap_cmd() def _subprocess_setup(): @@ -301,14 +299,28 @@ def _subprocess_setup(): signal.signal(signal.SIGPIPE, signal.SIG_DFL) +def logger_init(): + logger = logging.getLogger(WebSocketProxy.log_prefix) + logger.propagate = False + logger.setLevel(logging.INFO) + h = logging.StreamHandler() + h.setLevel(logging.DEBUG) + h.setFormatter(logging.Formatter("%(message)s")) + logger.addHandler(h) + + def websockify_init(): + logger_init() + usage = "\n %prog [options]" usage += " [source_addr:]source_port [target_addr:target_port]" usage += "\n %prog [options]" usage += " [source_addr:]source_port -- WRAP_COMMAND_LINE" parser = optparse.OptionParser(usage=usage) parser.add_option("--verbose", "-v", action="store_true", - help="verbose messages and per frame traffic") + help="verbose messages") + parser.add_option("--traffic", action="store_true", + help="per frame traffic") parser.add_option("--record", help="record sessions to FILE.[session_number]", metavar="FILE") parser.add_option("--daemon", "-D", @@ -345,8 +357,13 @@ def websockify_init(): help="Configuration file containing valid targets " "in the form 'token: host:port' or, alternatively, a " "directory containing configuration files of this form") + parser.add_option("--libserver", action="store_true", + help="use Python library SocketServer engine") (opts, args) = parser.parse_args() + if opts.verbose: + logging.getLogger(WebSocketProxy.log_prefix).setLevel(logging.DEBUG) + # Sanity checks if len(args) < 2 and not (opts.target_cfg or opts.unix_target): parser.error("Too few arguments") @@ -385,9 +402,70 @@ def websockify_init(): try: opts.target_port = int(opts.target_port) except: parser.error("Error parsing target port") + # Transform to absolute path as daemon may chdir + if opts.target_cfg: + opts.target_cfg = os.path.abspath(opts.target_cfg) + # Create and start the WebSockets proxy - server = WebSocketProxy(**opts.__dict__) - server.start_server() + libserver = opts.libserver + del opts.libserver + if libserver: + # Use standard Python SocketServer framework + server = LibProxyServer(**opts.__dict__) + server.serve_forever() + else: + # Use internal service framework + server = WebSocketProxy(**opts.__dict__) + server.start_server() + + +class LibProxyServer(ForkingMixIn, HTTPServer): + """ + Just like WebSocketProxy, but uses standard Python SocketServer + framework. + """ + + def __init__(self, RequestHandlerClass=ProxyRequestHandler, **kwargs): + # Save off proxy specific options + self.target_host = kwargs.pop('target_host', None) + self.target_port = kwargs.pop('target_port', None) + self.wrap_cmd = kwargs.pop('wrap_cmd', None) + self.wrap_mode = kwargs.pop('wrap_mode', None) + self.unix_target = kwargs.pop('unix_target', None) + self.ssl_target = kwargs.pop('ssl_target', None) + self.target_cfg = kwargs.pop('target_cfg', None) + self.daemon = False + self.target_cfg = None + + # Server configuration + listen_host = kwargs.pop('listen_host', '') + listen_port = kwargs.pop('listen_port', None) + web = kwargs.pop('web', '') + + # Configuration affecting base request handler + self.only_upgrade = not web + self.verbose = kwargs.pop('verbose', False) + record = kwargs.pop('record', '') + if record: + self.record = os.path.abspath(record) + self.run_once = kwargs.pop('run_once', False) + self.handler_id = 0 + + for arg in kwargs.keys(): + print("warning: option %s ignored when using --libserver" % arg) + + if web: + os.chdir(web) + + HTTPServer.__init__(self, (listen_host, listen_port), + RequestHandlerClass) + + + def process_request(self, request, client_address): + """Override process_request to implement a counter""" + self.handler_id += 1 + ForkingMixIn.process_request(self, request, client_address) + if __name__ == '__main__': websockify_init() |