diff options
-rwxr-xr-x | tests/echo.py | 17 | ||||
-rwxr-xr-x | tests/load.py | 27 | ||||
-rw-r--r-- | websockify/websocket.py | 712 | ||||
-rwxr-xr-x | websockify/websocketproxy.py | 356 |
4 files changed, 608 insertions, 504 deletions
diff --git a/tests/echo.py b/tests/echo.py index 1d46d50..27bdc46 100755 --- a/tests/echo.py +++ b/tests/echo.py @@ -12,15 +12,15 @@ as taken from http://docs.python.org/dev/library/ssl.html#certificates import os, sys, select, optparse sys.path.insert(0,os.path.join(os.path.dirname(__file__), "..")) -from websockify.websocket import WebSocketServer +from websockify.websocket import WebSocketServer, WebSocketRequestHandler -class WebSocketEcho(WebSocketServer): +class WebSocketEcho(WebSocketRequestHandler): """ WebSockets server that echos back whatever is received from the client. """ buffer_size = 8096 - def new_client(self): + def new_websocket_client(self): """ Echo back whatever is received. """ @@ -28,28 +28,27 @@ class WebSocketEcho(WebSocketServer): cqueue = [] c_pend = 0 cpartial = "" - rlist = [self.client] + rlist = [self.request] while True: wlist = [] - if cqueue or c_pend: wlist.append(self.client) + if cqueue or c_pend: wlist.append(self.request) ins, outs, excepts = select.select(rlist, wlist, [], 1) if excepts: raise Exception("Socket exception") - if self.client in outs: + if self.request in outs: # Send queued target data to the client c_pend = self.send_frames(cqueue) cqueue = [] - if self.client in ins: + if self.request in ins: # Receive client data, decode it, and send it back frames, closed = self.recv_frames() cqueue.extend(frames) if closed: self.send_close() - raise self.EClose(closed) if __name__ == '__main__': parser = optparse.OptionParser(usage="%prog [options] listen_port") @@ -70,6 +69,6 @@ if __name__ == '__main__': parser.error("Invalid arguments") opts.web = "." - server = WebSocketEcho(**opts.__dict__) + server = WebSocketServer(WebSocketEcho, **opts.__dict__) server.start_server() diff --git a/tests/load.py b/tests/load.py index e1354c9..2565368 100755 --- a/tests/load.py +++ b/tests/load.py @@ -8,33 +8,35 @@ given a sequence number. Any errors are reported and counted. import sys, os, select, random, time, optparse sys.path.insert(0,os.path.join(os.path.dirname(__file__), "..")) -from websockify.websocket import WebSocketServer +from websockify.websocket import WebSocketServer, WebSocketRequestHandler -class WebSocketLoad(WebSocketServer): +class WebSocketLoadServer(WebSocketServer): - buffer_size = 65536 - - max_packet_size = 10000 recv_cnt = 0 send_cnt = 0 def __init__(self, *args, **kwargs): - self.errors = 0 self.delay = kwargs.pop('delay') + WebSocketServer.__init__(self, *args, **kwargs) + + +class WebSocketLoad(WebSocketRequestHandler): + + max_packet_size = 10000 + + def new_websocket_client(self): print "Prepopulating random array" self.rand_array = [] for i in range(0, self.max_packet_size): self.rand_array.append(random.randint(0, 9)) - WebSocketServer.__init__(self, *args, **kwargs) - - def new_client(self): + self.errors = 0 self.send_cnt = 0 self.recv_cnt = 0 try: - self.responder(self.client) + self.responder(self.request) except: print "accumulated errors:", self.errors self.errors = 0 @@ -61,14 +63,13 @@ class WebSocketLoad(WebSocketServer): if closed: self.send_close() - raise self.EClose(closed) now = time.time() * 1000 if client in outs: if c_pend: last_send = now c_pend = self.send_frames() - elif now > (last_send + self.delay): + elif now > (last_send + self.server.delay): last_send = now c_pend = self.send_frames([self.generate()]) @@ -162,6 +163,6 @@ if __name__ == '__main__': parser.error("Invalid arguments") opts.web = "." - server = WebSocketLoad(**opts.__dict__) + server = WebSocketLoadServer(WebSocketLoad, **opts.__dict__) server.start_server() diff --git a/websockify/websocket.py b/websockify/websocket.py index d5ea96b..210ef09 100644 --- a/websockify/websocket.py +++ b/websockify/websocket.py @@ -65,210 +65,47 @@ if multiprocessing and sys.platform == 'win32': import multiprocessing.reduction -class WebSocketServer(object): +# HTTP handler with WebSocket upgrade support +class WebSocketRequestHandler(SimpleHTTPRequestHandler): """ - WebSockets server class. - Must be sub-classed with new_client method definition. + WebSocket Request Handler Class, derived from SimpleHTTPRequestHandler. + Must be sub-classed with new_websocket_client method definition. + The request handler can be configured by setting optional + attributes on the server object: + + * only_upgrade: If true, SimpleHTTPRequestHandler will not be enabled, + only websocket is allowed. + * verbose: If true, verbose logging is activated. + * daemon: Running as daemon, do not write to console etc + * record: Record raw frame data as JavaScript array into specified filename + * run_once: Handle a single request + * handler_id: A sequence number for this connection, appended to record filename """ - - log_prefix = "websocket" buffer_size = 65536 - server_handshake_hybi = """HTTP/1.1 101 Switching Protocols\r -Upgrade: websocket\r -Connection: Upgrade\r -Sec-WebSocket-Accept: %s\r -""" - GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" - policy_response = """<cross-domain-policy><allow-access-from domain="*" to-ports="*" /></cross-domain-policy>\n""" + server_version = "WebSockify" - # An exception before the WebSocket connection was established - class EClose(Exception): - pass + protocol_version = "HTTP/1.1" # An exception while the WebSocket client was connected class CClose(Exception): pass - class Terminate(Exception): - pass - - def __init__(self, listen_host='', listen_port=None, source_is_ipv6=False, - verbose=False, cert='', key='', ssl_only=None, - daemon=False, record='', web='', - file_only=False, no_parent=False, - run_once=False, timeout=0, idle_timeout=0, traffic=False, - tcp_keepalive=True, tcp_keepcnt=None, tcp_keepidle=None, - tcp_keepintvl=None): - - # settings - self.verbose = verbose - self.listen_host = listen_host - self.listen_port = listen_port - self.prefer_ipv6 = source_is_ipv6 - self.ssl_only = ssl_only - self.daemon = daemon - self.run_once = run_once - self.timeout = timeout - self.idle_timeout = idle_timeout - self.traffic = traffic - - self.launch_time = time.time() - self.ws_connection = False - self.i_am_client = False - self.handler_id = 1 - - self.file_only = file_only - self.no_parent = no_parent - - self.logger = self.get_logger() - self.tcp_keepalive = tcp_keepalive - self.tcp_keepcnt = tcp_keepcnt - self.tcp_keepidle = tcp_keepidle - self.tcp_keepintvl = tcp_keepintvl - - # Make paths settings absolute - self.cert = os.path.abspath(cert) - self.key = self.web = self.record = '' - if key: - self.key = os.path.abspath(key) - if web: - self.web = os.path.abspath(web) - if record: - self.record = os.path.abspath(record) - - if self.web: - os.chdir(self.web) - - # Sanity checks - if not ssl and self.ssl_only: - raise Exception("No 'ssl' module and SSL-only specified") - if self.daemon and not resource: - raise Exception("Module 'resource' required to daemonize") - - # Show configuration - self.msg("WebSocket server settings:") - self.msg(" - Listen on %s:%s", - self.listen_host, self.listen_port) - self.msg(" - Flash security policy server") - if self.web: - self.msg(" - Web server. Web root: %s", self.web) - if ssl: - if os.path.exists(self.cert): - self.msg(" - SSL/TLS support") - if self.ssl_only: - self.msg(" - Deny non-SSL/TLS connections") - else: - self.msg(" - No SSL/TLS support (no cert file)") - else: - self.msg(" - No SSL/TLS support (no 'ssl' module)") - if self.daemon: - self.msg(" - Backgrounding (daemon)") - if self.record: - self.msg(" - Recording to '%s.*'", self.record) - - # - # WebSocketServer static methods - # - - @staticmethod - def get_logger(): - return logging.getLogger("%s.%s" % ( - WebSocketServer.log_prefix, - WebSocketServer.__class__.__name__)) - - @staticmethod - def socket(host, port=None, connect=False, prefer_ipv6=False, - unix_socket=None, use_ssl=False, tcp_keepalive=True, - tcp_keepcnt=None, tcp_keepidle=None, tcp_keepintvl=None): - """ Resolve a host (and optional port) to an IPv4 or IPv6 - address. Create a socket. Bind to it if listen is set, - otherwise connect to it. Return the socket. - """ - flags = 0 - if host == '': - host = None - if connect and not (port or unix_socket): - raise Exception("Connect mode requires a port") - if use_ssl and not ssl: - raise Exception("SSL socket requested but Python SSL module not loaded."); - if not connect and use_ssl: - raise Exception("SSL only supported in connect mode (for now)") - if not connect: - flags = flags | socket.AI_PASSIVE - - if not unix_socket: - addrs = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM, - socket.IPPROTO_TCP, flags) - if not addrs: - raise Exception("Could not resolve host '%s'" % host) - addrs.sort(key=lambda x: x[0]) - if prefer_ipv6: - addrs.reverse() - sock = socket.socket(addrs[0][0], addrs[0][1]) - - if tcp_keepalive: - sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) - if tcp_keepcnt: - sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPCNT, - tcp_keepcnt) - if tcp_keepidle: - sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPIDLE, - tcp_keepidle) - if tcp_keepintvl: - sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPINTVL, - tcp_keepintvl) - - if connect: - sock.connect(addrs[0][4]) - if use_ssl: - sock = ssl.wrap_socket(sock) - else: - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - sock.bind(addrs[0][4]) - sock.listen(100) - else: - sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - sock.connect(unix_socket) - - return sock - - @staticmethod - def daemonize(keepfd=None, chdir='/'): - os.umask(0) - if chdir: - os.chdir(chdir) - else: - os.chdir('/') - os.setgid(os.getgid()) # relinquish elevations - os.setuid(os.getuid()) # relinquish elevations - - # Double fork to daemonize - if os.fork() > 0: os._exit(0) # Parent exits - os.setsid() # Obtain new process group - if os.fork() > 0: os._exit(0) # Parent exits - - # Signal handling - signal.signal(signal.SIGTERM, signal.SIG_IGN) - signal.signal(signal.SIGINT, signal.SIG_IGN) - - # Close open files - maxfd = resource.getrlimit(resource.RLIMIT_NOFILE)[1] - if maxfd == resource.RLIM_INFINITY: maxfd = 256 - for fd in reversed(range(maxfd)): - try: - if fd != keepfd: - os.close(fd) - except OSError: - _, exc, _ = sys.exc_info() - if exc.errno != errno.EBADF: raise - - # Redirect I/O to /dev/null - os.dup2(os.open(os.devnull, os.O_RDWR), sys.stdin.fileno()) - os.dup2(os.open(os.devnull, os.O_RDWR), sys.stdout.fileno()) - os.dup2(os.open(os.devnull, os.O_RDWR), sys.stderr.fileno()) + def __init__(self, req, addr, server): + # Retrieve a few configuration variables from the server + self.only_upgrade = getattr(server, "only_upgrade", False) + self.verbose = getattr(server, "verbose", False) + self.daemon = getattr(server, "daemon", False) + self.record = getattr(server, "record", False) + self.run_once = getattr(server, "run_once", False) + self.rec = None + self.handler_id = getattr(server, "handler_id", False) + self.file_only = getattr(server, "file_only", False) + self.traffic = getattr(server, "traffic", False) + + SimpleHTTPRequestHandler.__init__(self, req, addr, server) @staticmethod def unmask(buf, hlen, plen): @@ -393,7 +230,7 @@ Sec-WebSocket-Accept: %s\r # Process 1 frame if f['masked']: # unmask payload - f['payload'] = WebSocketServer.unmask(buf, f['hlen'], + f['payload'] = WebSocketRequestHandler.unmask(buf, f['hlen'], f['length']) else: logger.debug("Unmasked frame: %s" % repr(buf)) @@ -415,9 +252,8 @@ Sec-WebSocket-Accept: %s\r return f - # - # WebSocketServer logging/output functions + # WebSocketRequestHandler logging/output functions # def print_traffic(self, token="."): @@ -426,29 +262,23 @@ Sec-WebSocket-Accept: %s\r sys.stdout.write(token) sys.stdout.flush() - - def log(self, lvl, msg, *args, **kwargs): - """ Wrapper around python logging """ - prefix = "" - if self.i_am_client: - prefix = "% 3d: " % self.handler_id - self.logger.log(lvl, "%s%s" % (prefix, msg), - *args, **kwargs) - - def msg(self, *args, **kwargs): + def msg(self, msg, *args, **kwargs): """ Output message with handler_id prefix. """ - self.log(logging.INFO, *args, **kwargs) + prefix = "% 3d: " % self.handler_id + self.server.msg("%s%s" % (prefix, msg), *args, **kwargs) - def vmsg(self, *args, **kwargs): + def vmsg(self, msg, *args, **kwargs): """ Same as msg() but as debug. """ - self.log(logging.DEBUG, *args, **kwargs) + prefix = "% 3d: " % self.handler_id + self.server.vmsg("%s%s" % (prefix, msg), *args, **kwargs) - def warn(self, *args, **kwargs): + def warn(self, msg, *args, **kwargs): """ Same as msg() but as warning. """ - self.log(logging.WARN, *args, **kwargs) + prefix = "% 3d: " % self.handler_id + self.server.warn("%s%s" % (prefix, msg), *args, **kwargs) # - # Main WebSocketServer methods + # Main WebSocketRequestHandler methods # def send_frames(self, bufs=None): """ Encode and send WebSocket frames. Any frames already @@ -477,7 +307,7 @@ Sec-WebSocket-Accept: %s\r while self.send_parts: # Send pending frames buf = self.send_parts.pop(0) - sent = self.client.send(buf) + sent = self.request.send(buf) if sent == len(buf): self.print_traffic("<") @@ -499,7 +329,7 @@ Sec-WebSocket-Accept: %s\r bufs = [] tdelta = int(time.time()*1000) - self.start_time - buf = self.client.recv(self.buffer_size) + buf = self.request.recv(self.buffer_size) if len(buf) == 0: closed = {'code': 1000, 'reason': "Client closed abruptly"} return bufs, closed @@ -532,7 +362,7 @@ Sec-WebSocket-Accept: %s\r start = frame['hlen'] end = frame['hlen'] + frame['length'] if frame['masked']: - recbuf = WebSocketServer.unmask(buf, frame['hlen'], + recbuf = WebSocketRequestHandler.unmask(buf, frame['hlen'], frame['length']) else: recbuf = buf[frame['hlen']:frame['hlen'] + @@ -555,11 +385,10 @@ Sec-WebSocket-Accept: %s\r msg = pack(">H%ds" % len(reason), code, reason) buf, h, t = self.encode_hybi(msg, opcode=0x08, base64=False) - self.client.send(buf) + self.request.send(buf) - def do_websocket_handshake(self, headers, path): - h = self.headers = headers - self.path = path + def do_websocket_handshake(self): + h = self.headers prot = 'WebSocket-Protocol' protocols = h.get('Sec-'+prot, h.get(prot, '')).split(',') @@ -574,7 +403,8 @@ Sec-WebSocket-Accept: %s\r if ver in ['7', '8', '13']: self.version = "hybi-%02d" % int(ver) else: - raise self.EClose('Unsupported protocol version %s' % ver) + self.send_error(400, "Unsupported protocol version %s" % ver) + return False key = h['Sec-WebSocket-Key'] @@ -584,23 +414,319 @@ Sec-WebSocket-Accept: %s\r elif 'base64' in protocols: self.base64 = True else: - raise self.EClose("Client must support 'binary' or 'base64' protocol") + self.send_error(400, "Client must support 'binary' or 'base64' protocol") + return False # Generate the hash value for the accept header accept = b64encode(sha1(s2b(key + self.GUID)).digest()) - response = self.server_handshake_hybi % b2s(accept) + self.send_response(101, "Switching Protocols") + self.send_header("Upgrade", "websocket") + self.send_header("Connection", "Upgrade") + self.send_header("Sec-WebSocket-Accept", b2s(accept)) if self.base64: - response += "Sec-WebSocket-Protocol: base64\r\n" + self.send_header("Sec-WebSocket-Protocol", "base64") + else: + self.send_header("Sec-WebSocket-Protocol", "binary") + self.end_headers() + return True + else: + self.send_error(400, "Missing Sec-WebSocket-Version header. Hixie protocols not supported.") + + return False + + def handle_websocket(self): + """Upgrade a connection to Websocket, if requested. If this succeeds, + new_websocket_client() will be called. Otherwise, False is returned. + """ + if (self.headers.get('upgrade') and + self.headers.get('upgrade').lower() == 'websocket'): + + if not self.do_websocket_handshake(): + return False + + # Indicate to server that a Websocket upgrade was done + self.server.ws_connection = True + # Initialize per client settings + self.send_parts = [] + self.recv_part = None + self.start_time = int(time.time()*1000) + + # client_address is empty with, say, UNIX domain sockets + client_addr = "" + is_ssl = False + try: + client_addr = self.client_address[0] + is_ssl = self.client_address[2] + except IndexError: + pass + + if is_ssl: + self.stype = "SSL/TLS (wss://)" + else: + self.stype = "Plain non-SSL (ws://)" + + self.log_message("%s: %s WebSocket connection" % (client_addr, + self.stype)) + self.log_message("%s: Version %s, base64: '%s'" % (client_addr, + self.version, self.base64)) + if self.path != '/': + self.log_message("%s: Path: '%s'" % (client_addr, self.path)) + + if self.record: + # Record raw frame data as JavaScript array + fname = "%s.%s" % (self.record, + self.handler_id) + self.log_message("opening record file: %s" % fname) + self.rec = open(fname, 'w+') + encoding = "binary" + if self.base64: encoding = "base64" + self.rec.write("var VNC_frame_encoding = '%s';\n" + % encoding) + self.rec.write("var VNC_frame_data = [\n") + + try: + self.new_websocket_client() + except self.CClose: + # Close the client + _, exc, _ = sys.exc_info() + self.send_close(exc.args[0], exc.args[1]) + return True + else: + return False + + def do_GET(self): + """Handle GET request. Calls handle_websocket(). If unsuccessful, + and web server is enabled, SimpleHTTPRequestHandler.do_GET will be called.""" + if not self.handle_websocket(): + if self.only_upgrade: + self.send_error(405, "Method Not Allowed") + else: + SimpleHTTPRequestHandler.do_GET(self) + + def list_directory(self, path): + if self.file_only: + self.send_error(404, "No such file") + else: + return SimpleHTTPRequestHandler.list_directory(self, path) + + def new_websocket_client(self): + """ Do something with a WebSockets client connection. """ + raise Exception("WebSocketRequestHandler.new_websocket_client() must be overloaded") + + def do_HEAD(self): + if self.only_upgrade: + self.send_error(405, "Method Not Allowed") + else: + SimpleHTTPRequestHandler.do_HEAD(self) + + def finish(self): + if self.rec: + self.rec.write("'EOF'];\n") + self.rec.close() + + def handle(self): + # When using run_once, we have a single process, so + # we cannot loop in BaseHTTPRequestHandler.handle; we + # must return and handle new connections + if self.run_once: + self.handle_one_request() + else: + SimpleHTTPRequestHandler.handle(self) + + def log_request(self, code='-', size='-'): + if self.verbose: + SimpleHTTPRequestHandler.log_request(self, code, size) + + +class WebSocketServer(object): + """ + WebSockets server class. + As an alternative, the standard library SocketServer can be used + """ + + policy_response = """<cross-domain-policy><allow-access-from domain="*" to-ports="*" /></cross-domain-policy>\n""" + log_prefix = "websocket" + + # An exception before the WebSocket connection was established + class EClose(Exception): + pass + + class Terminate(Exception): + pass + + def __init__(self, RequestHandlerClass, listen_host='', + listen_port=None, source_is_ipv6=False, + verbose=False, cert='', key='', ssl_only=None, + daemon=False, record='', web='', file_only=False, + run_once=False, timeout=0, idle_timeout=0, traffic=False, + tcp_keepalive=True, tcp_keepcnt=None, tcp_keepidle=None, + tcp_keepintvl=None): + + # settings + self.RequestHandlerClass = RequestHandlerClass + self.verbose = verbose + self.listen_host = listen_host + self.listen_port = listen_port + self.prefer_ipv6 = source_is_ipv6 + self.ssl_only = ssl_only + self.daemon = daemon + self.run_once = run_once + self.timeout = timeout + self.idle_timeout = idle_timeout + self.traffic = traffic + + self.launch_time = time.time() + self.ws_connection = False + self.handler_id = 1 + + self.logger = self.get_logger() + self.tcp_keepalive = tcp_keepalive + self.tcp_keepcnt = tcp_keepcnt + self.tcp_keepidle = tcp_keepidle + self.tcp_keepintvl = tcp_keepintvl + + # Make paths settings absolute + self.cert = os.path.abspath(cert) + self.key = self.web = self.record = '' + if key: + self.key = os.path.abspath(key) + if web: + self.web = os.path.abspath(web) + if record: + self.record = os.path.abspath(record) + + if self.web: + os.chdir(self.web) + self.only_upgrade = not self.web + + # Sanity checks + if not ssl and self.ssl_only: + raise Exception("No 'ssl' module and SSL-only specified") + if self.daemon and not resource: + raise Exception("Module 'resource' required to daemonize") + + # Show configuration + self.msg("WebSocket server settings:") + self.msg(" - Listen on %s:%s", + self.listen_host, self.listen_port) + self.msg(" - Flash security policy server") + if self.web: + self.msg(" - Web server. Web root: %s", self.web) + if ssl: + if os.path.exists(self.cert): + self.msg(" - SSL/TLS support") + if self.ssl_only: + self.msg(" - Deny non-SSL/TLS connections") + else: + self.msg(" - No SSL/TLS support (no cert file)") + else: + self.msg(" - No SSL/TLS support (no 'ssl' module)") + if self.daemon: + self.msg(" - Backgrounding (daemon)") + if self.record: + self.msg(" - Recording to '%s.*'", self.record) + + # + # WebSocketServer static methods + # + + @staticmethod + def get_logger(): + return logging.getLogger("%s.%s" % ( + WebSocketServer.log_prefix, + WebSocketServer.__class__.__name__)) + + @staticmethod + def socket(host, port=None, connect=False, prefer_ipv6=False, + unix_socket=None, use_ssl=False, tcp_keepalive=True, + tcp_keepcnt=None, tcp_keepidle=None, tcp_keepintvl=None): + """ Resolve a host (and optional port) to an IPv4 or IPv6 + address. Create a socket. Bind to it if listen is set, + otherwise connect to it. Return the socket. + """ + flags = 0 + if host == '': + host = None + if connect and not (port or unix_socket): + raise Exception("Connect mode requires a port") + if use_ssl and not ssl: + raise Exception("SSL socket requested but Python SSL module not loaded."); + if not connect and use_ssl: + raise Exception("SSL only supported in connect mode (for now)") + if not connect: + flags = flags | socket.AI_PASSIVE + + if not unix_socket: + addrs = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM, + socket.IPPROTO_TCP, flags) + if not addrs: + raise Exception("Could not resolve host '%s'" % host) + addrs.sort(key=lambda x: x[0]) + if prefer_ipv6: + addrs.reverse() + sock = socket.socket(addrs[0][0], addrs[0][1]) + + if tcp_keepalive: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) + if tcp_keepcnt: + sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPCNT, + tcp_keepcnt) + if tcp_keepidle: + sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPIDLE, + tcp_keepidle) + if tcp_keepintvl: + sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPINTVL, + tcp_keepintvl) + + if connect: + sock.connect(addrs[0][4]) + if use_ssl: + sock = ssl.wrap_socket(sock) else: - response += "Sec-WebSocket-Protocol: binary\r\n" - response += "\r\n" + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(addrs[0][4]) + sock.listen(100) + else: + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + sock.connect(unix_socket) + + return sock + @staticmethod + def daemonize(keepfd=None, chdir='/'): + os.umask(0) + if chdir: + os.chdir(chdir) else: - raise self.EClose("Missing Sec-WebSocket-Version header. Hixie protocols not supported.") + os.chdir('/') + os.setgid(os.getgid()) # relinquish elevations + os.setuid(os.getuid()) # relinquish elevations - return response + # Double fork to daemonize + if os.fork() > 0: os._exit(0) # Parent exits + os.setsid() # Obtain new process group + if os.fork() > 0: os._exit(0) # Parent exits + + # Signal handling + signal.signal(signal.SIGTERM, signal.SIG_IGN) + signal.signal(signal.SIGINT, signal.SIG_IGN) + + # Close open files + maxfd = resource.getrlimit(resource.RLIMIT_NOFILE)[1] + if maxfd == resource.RLIM_INFINITY: maxfd = 256 + for fd in reversed(range(maxfd)): + try: + if fd != keepfd: + os.close(fd) + except OSError: + _, exc, _ = sys.exc_info() + if exc.errno != errno.EBADF: raise + # Redirect I/O to /dev/null + os.dup2(os.open(os.devnull, os.O_RDWR), sys.stdin.fileno()) + os.dup2(os.open(os.devnull, os.O_RDWR), sys.stdout.fileno()) + os.dup2(os.open(os.devnull, os.O_RDWR), sys.stderr.fileno()) def do_handshake(self, sock, address): """ @@ -619,7 +745,6 @@ Sec-WebSocket-Accept: %s\r - Send a WebSockets handshake server response. - Return the socket for this WebSocket client. """ - stype = "" ready = select.select([sock], [], [], 3)[0] @@ -663,46 +788,39 @@ Sec-WebSocket-Accept: %s\r else: raise - self.scheme = "wss" - stype = "SSL/TLS (wss://)" - elif self.ssl_only: raise self.EClose("non-SSL connection received but disallowed") else: retsock = sock - self.scheme = "ws" - stype = "Plain non-SSL (ws://)" - wsh = WSRequestHandler(retsock, address, not self.web, - self.file_only, self.no_parent) - if wsh.last_code == 101: - # Continue on to handle WebSocket upgrade - pass - elif wsh.last_code == 405: - raise self.EClose("Normal web request received but disallowed") - elif wsh.last_code < 200 or wsh.last_code >= 300: - raise self.EClose(wsh.last_message) - elif self.verbose: - raise self.EClose(wsh.last_message) - else: - raise self.EClose("") + # If the address is like (host, port), we are extending it + # with a flag indicating SSL. Not many other options + # available... + if len(address) == 2: + address = (address[0], address[1], (retsock != sock)) + + self.RequestHandlerClass(retsock, address, self) + + # Return the WebSockets socket which may be SSL wrapped + return retsock - response = self.do_websocket_handshake(wsh.headers, wsh.path) - self.msg("%s: %s WebSocket connection" % (address[0], stype)) - self.msg("%s: Version %s, base64: '%s'" % (address[0], - self.version, self.base64)) - if self.path != '/': - self.msg("%s: Path: '%s'" % (address[0], self.path)) + # + # WebSocketServer logging/output functions + # + def msg(self, *args, **kwargs): + """ Output message as info """ + self.logger.log(logging.INFO, *args, **kwargs) - # Send server WebSockets handshake response - #self.msg("sending response [%s]" % response) - retsock.send(s2b(response)) + def vmsg(self, *args, **kwargs): + """ Same as msg() but as debug. """ + self.logger.log(logging.DEBUG, *args, **kwargs) - # Return the WebSockets socket which may be SSL wrapped - return retsock + def warn(self, *args, **kwargs): + """ Same as msg() but as warning. """ + self.logger.log(logging.WARN, *args, **kwargs) # @@ -744,38 +862,11 @@ Sec-WebSocket-Accept: %s\r def top_new_client(self, startsock, address): """ Do something with a WebSockets client connection. """ - # Initialize per client settings - self.i_am_client = True - self.send_parts = [] - self.recv_part = None - self.base64 = False - self.rec = None - self.start_time = int(time.time()*1000) - # handler process + client = None try: try: - self.client = self.do_handshake(startsock, address) - - if self.record: - # Record raw frame data as JavaScript array - fname = "%s.%s" % (self.record, - self.handler_id) - self.msg("opening record file: %s" % fname) - self.rec = open(fname, 'w+') - encoding = "binary" - if self.base64: encoding = "base64" - self.rec.write("var VNC_frame_encoding = '%s';\n" - % encoding) - self.rec.write("var VNC_frame_data = [\n") - - self.ws_connection = True - self.new_client() - except self.CClose: - # Close the client - _, exc, _ = sys.exc_info() - if self.client: - self.send_close(exc.args[0], exc.args[1]) + client = self.do_handshake(startsock, address) except self.EClose: _, exc, _ = sys.exc_info() # Connection was not a WebSockets connection @@ -788,18 +879,11 @@ Sec-WebSocket-Accept: %s\r self.msg("handler exception: %s" % str(exc)) self.vmsg("exception", exc_info=True) finally: - if self.rec: - self.rec.write("'EOF'];\n") - self.rec.close() - if self.client and self.client != startsock: + if client and client != startsock: # Close the SSL wrapped socket # Original socket closed by caller - self.client.close() - - def new_client(self): - """ Do something with a WebSockets client connection. """ - raise("WebSocketServer.new_client() must be overloaded") + client.close() def start_server(self): """ @@ -841,7 +925,6 @@ Sec-WebSocket-Accept: %s\r while True: try: try: - self.client = None startsock = None pid = err = 0 child_count = 0 @@ -937,43 +1020,6 @@ Sec-WebSocket-Accept: %s\r # Restore signals for sig, func in original_signals.items(): - signal.signal(sig, func) - - -# HTTP handler with WebSocket upgrade support -class WSRequestHandler(SimpleHTTPRequestHandler): - def __init__(self, req, addr, only_upgrade=False, file_only=False, - no_parent=False): - self.only_upgrade = only_upgrade # only allow upgrades - self.webroot = os.path.realpath(".") - self.file_only = file_only - self.no_parent = no_parent - SimpleHTTPRequestHandler.__init__(self, req, addr, object()) - - def do_GET(self): - abspath = os.path.realpath("." + (self.path.split('?')[0])) - if (self.headers.get('upgrade') and - self.headers.get('upgrade').lower() == 'websocket'): - - # Just indicate that an WebSocket upgrade is needed - self.last_code = 101 - self.last_message = "101 Switching Protocols" - elif self.only_upgrade: - # Normal web request responses are disabled - self.last_code = 405 - self.last_message = "405 Method Not Allowed" - elif self.file_only and not os.path.isfile(abspath): - self.send_response(404, "No such file") - elif self.no_parent and not abspath.startswith(self.webroot): - self.send_response(403, "Hidden resources") - else: - SimpleHTTPRequestHandler.do_GET(self) + signal.signal(sig, func) - def send_response(self, code, message=None): - # Save the status code - self.last_code = code - SimpleHTTPRequestHandler.send_response(self, code, message) - def log_message(self, f, *args): - # Save instead of printing - self.last_message = f % args diff --git a/websockify/websocketproxy.py b/websockify/websocketproxy.py index 8e5d3fe..e8bbf02 100755 --- a/websockify/websocketproxy.py +++ b/websockify/websocketproxy.py @@ -12,6 +12,10 @@ as taken from http://docs.python.org/dev/library/ssl.html#certificates ''' import signal, socket, optparse, time, os, sys, subprocess, logging +try: from socketserver import ForkingMixIn +except: from SocketServer import ForkingMixIn +try: from http.server import HTTPServer +except: from BaseHTTPServer import HTTPServer from select import select import websocket try: @@ -20,167 +24,45 @@ except: from cgi import parse_qs from urlparse import urlparse -class WebSocketProxy(websocket.WebSocketServer): - """ - Proxy traffic to and from a WebSockets client to a normal TCP - socket server target. All traffic to/from the client is base64 - encoded/decoded to allow binary data to be sent/received to/from - the target. - """ - - buffer_size = 65536 +class ProxyRequestHandler(websocket.WebSocketRequestHandler): traffic_legend = """ Traffic Legend: } - Client receive }. - Client receive partial { - Target receive - + > - Target send >. - Target send partial < - Client send <. - Client send partial """ - def __init__(self, *args, **kwargs): - # Save off proxy specific options - self.target_host = kwargs.pop('target_host', None) - self.target_port = kwargs.pop('target_port', None) - self.wrap_cmd = kwargs.pop('wrap_cmd', None) - self.wrap_mode = kwargs.pop('wrap_mode', None) - self.unix_target = kwargs.pop('unix_target', None) - self.ssl_target = kwargs.pop('ssl_target', None) - self.target_cfg = kwargs.pop('target_cfg', None) - # Last 3 timestamps command was run - self.wrap_times = [0, 0, 0] - - if self.wrap_cmd: - wsdir = os.path.dirname(sys.argv[0]) - rebinder_path = [os.path.join(wsdir, "..", "lib"), - os.path.join(wsdir, "..", "lib", "websockify"), - wsdir] - self.rebinder = None - - for rdir in rebinder_path: - rpath = os.path.join(rdir, "rebind.so") - if os.path.exists(rpath): - self.rebinder = rpath - break - - if not self.rebinder: - raise Exception("rebind.so not found, perhaps you need to run make") - self.rebinder = os.path.abspath(self.rebinder) - - self.target_host = "127.0.0.1" # Loopback - # Find a free high port - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.bind(('', 0)) - self.target_port = sock.getsockname()[1] - sock.close() - - os.environ.update({ - "LD_PRELOAD": self.rebinder, - "REBIND_OLD_PORT": str(kwargs['listen_port']), - "REBIND_NEW_PORT": str(self.target_port)}) - - websocket.WebSocketServer.__init__(self, *args, **kwargs) - - def run_wrap_cmd(self): - self.msg("Starting '%s'", " ".join(self.wrap_cmd)) - self.wrap_times.append(time.time()) - self.wrap_times.pop(0) - self.cmd = subprocess.Popen( - self.wrap_cmd, env=os.environ, preexec_fn=_subprocess_setup) - self.spawn_message = True - - def started(self): - """ - Called after Websockets server startup (i.e. after daemonize) - """ - # Need to call wrapped command after daemonization so we can - # know when the wrapped command exits - if self.wrap_cmd: - dst_string = "'%s' (port %s)" % (" ".join(self.wrap_cmd), self.target_port) - elif self.unix_target: - dst_string = self.unix_target - else: - dst_string = "%s:%s" % (self.target_host, self.target_port) - - if self.target_cfg: - msg = " - proxying from %s:%s to targets in %s" % ( - self.listen_host, self.listen_port, self.target_cfg) - else: - msg = " - proxying from %s:%s to %s" % ( - self.listen_host, self.listen_port, dst_string) - - if self.ssl_target: - msg += " (using SSL)" - - self.msg("%s", msg) - - if self.wrap_cmd: - self.run_wrap_cmd() - - def poll(self): - # If we are wrapping a command, check it's status - - if self.wrap_cmd and self.cmd: - ret = self.cmd.poll() - if ret != None: - self.vmsg("Wrapped command exited (or daemon). Returned %s" % ret) - self.cmd = None - - if self.wrap_cmd and self.cmd == None: - # Response to wrapped command being gone - if self.wrap_mode == "ignore": - pass - elif self.wrap_mode == "exit": - sys.exit(ret) - elif self.wrap_mode == "respawn": - now = time.time() - avg = sum(self.wrap_times)/len(self.wrap_times) - if (now - avg) < 10: - # 3 times in the last 10 seconds - if self.spawn_message: - self.warn("Command respawning too fast") - self.spawn_message = False - else: - self.run_wrap_cmd() - - # - # Routines above this point are run in the master listener - # process. - # - - # - # Routines below this point are connection handler routines and - # will be run in a separate forked process for each connection. - # - - def new_client(self): + def new_websocket_client(self): """ Called after a new WebSocket connection has been established. """ # Checks if we receive a token, and look # for a valid target for it then - if self.target_cfg: - (self.target_host, self.target_port) = self.get_target(self.target_cfg, self.path) + if self.server.target_cfg: + (self.server.target_host, self.server.target_port) = self.get_target(self.server.target_cfg, self.path) # Connect to the target - if self.wrap_cmd: - msg = "connecting to command: '%s' (port %s)" % (" ".join(self.wrap_cmd), self.target_port) - elif self.unix_target: - msg = "connecting to unix socket: %s" % self.unix_target + if self.server.wrap_cmd: + msg = "connecting to command: '%s' (port %s)" % (" ".join(self.server.wrap_cmd), self.server.target_port) + elif self.server.unix_target: + msg = "connecting to unix socket: %s" % self.server.unix_target else: msg = "connecting to: %s:%s" % ( - self.target_host, self.target_port) + self.server.target_host, self.server.target_port) - if self.ssl_target: + if self.server.ssl_target: msg += " (using SSL)" - self.msg(msg) + self.log_message(msg) - tsock = self.socket(self.target_host, self.target_port, - connect=True, use_ssl=self.ssl_target, unix_socket=self.unix_target) + tsock = websocket.WebSocketServer.socket(self.server.target_host, + self.server.target_port, + connect=True, use_ssl=self.server.ssl_target, unix_socket=self.server.unix_target) self.print_traffic(self.traffic_legend) @@ -191,8 +73,9 @@ Traffic Legend: if tsock: tsock.shutdown(socket.SHUT_RDWR) tsock.close() - self.vmsg("%s:%s: Closed target" %( - self.target_host, self.target_port)) + if self.verbose: + self.log_message("%s:%s: Closed target" %( + self.server.target_host, self.server.target_port)) raise def get_target(self, target_cfg, path): @@ -241,31 +124,32 @@ Traffic Legend: cqueue = [] c_pend = 0 tqueue = [] - rlist = [self.client, target] + rlist = [self.request, target] while True: wlist = [] if tqueue: wlist.append(target) - if cqueue or c_pend: wlist.append(self.client) + if cqueue or c_pend: wlist.append(self.request) ins, outs, excepts = select(rlist, wlist, [], 1) if excepts: raise Exception("Socket exception") - if self.client in outs: + if self.request in outs: # Send queued target data to the client c_pend = self.send_frames(cqueue) cqueue = [] - if self.client in ins: + if self.request in ins: # Receive client data, decode it, and queue for target bufs, closed = self.recv_frames() tqueue.extend(bufs) if closed: # TODO: What about blocking on client socket? - self.vmsg("%s:%s: Client closed connection" %( - self.target_host, self.target_port)) + if self.verbose: + self.log_message("%s:%s: Client closed connection" %( + self.server.target_host, self.server.target_port)) raise self.CClose(closed['code'], closed['reason']) @@ -285,13 +169,128 @@ Traffic Legend: # Receive target data, encode it and queue for client buf = target.recv(self.buffer_size) if len(buf) == 0: - self.vmsg("%s:%s: Target closed connection" %( - self.target_host, self.target_port)) + if self.verbose: + self.log_message("%s:%s: Target closed connection" %( + self.server.target_host, self.server.target_port)) raise self.CClose(1000, "Target closed") cqueue.append(buf) self.print_traffic("{") +class WebSocketProxy(websocket.WebSocketServer): + """ + Proxy traffic to and from a WebSockets client to a normal TCP + socket server target. All traffic to/from the client is base64 + encoded/decoded to allow binary data to be sent/received to/from + the target. + """ + + buffer_size = 65536 + + def __init__(self, RequestHandlerClass=ProxyRequestHandler, *args, **kwargs): + # Save off proxy specific options + self.target_host = kwargs.pop('target_host', None) + self.target_port = kwargs.pop('target_port', None) + self.wrap_cmd = kwargs.pop('wrap_cmd', None) + self.wrap_mode = kwargs.pop('wrap_mode', None) + self.unix_target = kwargs.pop('unix_target', None) + self.ssl_target = kwargs.pop('ssl_target', None) + self.target_cfg = kwargs.pop('target_cfg', None) + # Last 3 timestamps command was run + self.wrap_times = [0, 0, 0] + + if self.wrap_cmd: + wsdir = os.path.dirname(sys.argv[0]) + rebinder_path = [os.path.join(wsdir, "..", "lib"), + os.path.join(wsdir, "..", "lib", "websockify"), + wsdir] + self.rebinder = None + + for rdir in rebinder_path: + rpath = os.path.join(rdir, "rebind.so") + if os.path.exists(rpath): + self.rebinder = rpath + break + + if not self.rebinder: + raise Exception("rebind.so not found, perhaps you need to run make") + self.rebinder = os.path.abspath(self.rebinder) + + self.target_host = "127.0.0.1" # Loopback + # Find a free high port + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.bind(('', 0)) + self.target_port = sock.getsockname()[1] + sock.close() + + os.environ.update({ + "LD_PRELOAD": self.rebinder, + "REBIND_OLD_PORT": str(kwargs['listen_port']), + "REBIND_NEW_PORT": str(self.target_port)}) + + websocket.WebSocketServer.__init__(self, RequestHandlerClass, *args, **kwargs) + + def run_wrap_cmd(self): + self.msg("Starting '%s'", " ".join(self.wrap_cmd)) + self.wrap_times.append(time.time()) + self.wrap_times.pop(0) + self.cmd = subprocess.Popen( + self.wrap_cmd, env=os.environ, preexec_fn=_subprocess_setup) + self.spawn_message = True + + def started(self): + """ + Called after Websockets server startup (i.e. after daemonize) + """ + # Need to call wrapped command after daemonization so we can + # know when the wrapped command exits + if self.wrap_cmd: + dst_string = "'%s' (port %s)" % (" ".join(self.wrap_cmd), self.target_port) + elif self.unix_target: + dst_string = self.unix_target + else: + dst_string = "%s:%s" % (self.target_host, self.target_port) + + if self.target_cfg: + msg = " - proxying from %s:%s to targets in %s" % ( + self.listen_host, self.listen_port, self.target_cfg) + else: + msg = " - proxying from %s:%s to %s" % ( + self.listen_host, self.listen_port, dst_string) + + if self.ssl_target: + msg += " (using SSL)" + + self.msg("%s", msg) + + if self.wrap_cmd: + self.run_wrap_cmd() + + def poll(self): + # If we are wrapping a command, check it's status + + if self.wrap_cmd and self.cmd: + ret = self.cmd.poll() + if ret != None: + self.vmsg("Wrapped command exited (or daemon). Returned %s" % ret) + self.cmd = None + + if self.wrap_cmd and self.cmd == None: + # Response to wrapped command being gone + if self.wrap_mode == "ignore": + pass + elif self.wrap_mode == "exit": + sys.exit(ret) + elif self.wrap_mode == "respawn": + now = time.time() + avg = sum(self.wrap_times)/len(self.wrap_times) + if (now - avg) < 10: + # 3 times in the last 10 seconds + if self.spawn_message: + self.warn("Command respawning too fast") + self.spawn_message = False + else: + self.run_wrap_cmd() def _subprocess_setup(): @@ -358,6 +357,8 @@ def websockify_init(): help="Configuration file containing valid targets " "in the form 'token: host:port' or, alternatively, a " "directory containing configuration files of this form") + parser.add_option("--libserver", action="store_true", + help="use Python library SocketServer engine") (opts, args) = parser.parse_args() if opts.verbose: @@ -406,8 +407,65 @@ def websockify_init(): opts.target_cfg = os.path.abspath(opts.target_cfg) # Create and start the WebSockets proxy - server = WebSocketProxy(**opts.__dict__) - server.start_server() + libserver = opts.libserver + del opts.libserver + if libserver: + # Use standard Python SocketServer framework + server = LibProxyServer(**opts.__dict__) + server.serve_forever() + else: + # Use internal service framework + server = WebSocketProxy(**opts.__dict__) + server.start_server() + + +class LibProxyServer(ForkingMixIn, HTTPServer): + """ + Just like WebSocketProxy, but uses standard Python SocketServer + framework. + """ + + def __init__(self, RequestHandlerClass=ProxyRequestHandler, **kwargs): + # Save off proxy specific options + self.target_host = kwargs.pop('target_host', None) + self.target_port = kwargs.pop('target_port', None) + self.wrap_cmd = kwargs.pop('wrap_cmd', None) + self.wrap_mode = kwargs.pop('wrap_mode', None) + self.unix_target = kwargs.pop('unix_target', None) + self.ssl_target = kwargs.pop('ssl_target', None) + self.target_cfg = kwargs.pop('target_cfg', None) + self.daemon = False + self.target_cfg = None + + # Server configuration + listen_host = kwargs.pop('listen_host', '') + listen_port = kwargs.pop('listen_port', None) + web = kwargs.pop('web', '') + + # Configuration affecting base request handler + self.only_upgrade = not web + self.verbose = kwargs.pop('verbose', False) + record = kwargs.pop('record', '') + if record: + self.record = os.path.abspath(record) + self.run_once = kwargs.pop('run_once', False) + self.handler_id = 0 + + for arg in kwargs.keys(): + print("warning: option %s ignored when using --libserver" % arg) + + if web: + os.chdir(web) + + HTTPServer.__init__(self, (listen_host, listen_port), + RequestHandlerClass) + + + def process_request(self, request, client_address): + """Override process_request to implement a counter""" + self.handler_id += 1 + ForkingMixIn.process_request(self, request, client_address) + if __name__ == '__main__': websockify_init() |