diff options
author | costasgambit <costas@gambitresearch.com> | 2017-09-11 18:28:57 +0100 |
---|---|---|
committer | Sergey Shepelev <temotor@gmail.com> | 2017-09-11 20:28:57 +0300 |
commit | b7d2a251ad55e1c161aa6c8aa236db456c4c4a21 (patch) | |
tree | 3b945af403dbd839ef08e7b65cf4dfac60fc8047 | |
parent | 82f1877ff5b950e1ee9debbb8be880ebedb8abcb (diff) | |
download | eventlet-b7d2a251ad55e1c161aa6c8aa236db456c4c4a21.tar.gz |
websocket: support permessage-deflate extension; Thanks to Costas Christofi and Peter Kovary
Support for compression extension as described in RFC7692 https://tools.ietf.org/html/rfc7692
https://github.com/eventlet/eventlet/pull/417
-rw-r--r-- | AUTHORS | 2 | ||||
-rw-r--r-- | eventlet/websocket.py | 183 | ||||
-rw-r--r-- | tests/websocket_new_test.py | 306 |
3 files changed, 471 insertions, 20 deletions
@@ -152,3 +152,5 @@ Thanks To * Aayush Kasurde * Linbing * Geoffrey Thomas +* Costas Christofi, adding permessage-deflate weboscket extension support +* Peter Kovary, adding permessage-deflate weboscket extension support diff --git a/eventlet/websocket.py b/eventlet/websocket.py index 1fdb3bf..857e94d 100644 --- a/eventlet/websocket.py +++ b/eventlet/websocket.py @@ -9,6 +9,8 @@ import struct import sys import time +import zlib + try: from hashlib import md5, sha1 except ImportError: # pragma NO COVER @@ -196,6 +198,76 @@ class WebSocketWSGI(object): sock.sendall(handshake_reply) return WebSocket(sock, environ, self.protocol_version) + def _parse_extension_header(self, header): + if header is None: + return None + res = {} + for ext in header.split(","): + parts = ext.split(";") + config = {} + for part in parts[1:]: + key_val = part.split("=") + if len(key_val) == 1: + config[key_val[0].strip().lower()] = True + else: + config[key_val[0].strip().lower()] = key_val[1].strip().strip('"').lower() + res.setdefault(parts[0].strip().lower(), []).append(config) + return res + + def _negotiate_permessage_deflate(self, extensions): + if not extensions: + return None + deflate = extensions.get("permessage-deflate") + if deflate is None: + return None + for config in deflate: + # We'll evaluate each config in the client's preferred order and pick + # the first that we can support. + want_config = { + # These are bool options, we can support both + "server_no_context_takeover": config.get("server_no_context_takeover", False), + "client_no_context_takeover": config.get("client_no_context_takeover", False) + } + # These are either bool OR int options. True means the client can accept a value + # for the option, a number means the client wants that specific value. + max_wbits = min(zlib.MAX_WBITS, 15) + mwb = config.get("server_max_window_bits") + if mwb is not None: + if mwb is True: + want_config["server_max_window_bits"] = max_wbits + else: + want_config["server_max_window_bits"] = \ + int(config.get("server_max_window_bits", max_wbits)) + if not (8 <= want_config["server_max_window_bits"] <= 15): + continue + mwb = config.get("client_max_window_bits") + if mwb is not None: + if mwb is True: + want_config["client_max_window_bits"] = max_wbits + else: + want_config["client_max_window_bits"] = \ + int(config.get("client_max_window_bits", max_wbits)) + if not (8 <= want_config["client_max_window_bits"] <= 15): + continue + return want_config + return None + + def _format_extension_header(self, parsed_extensions): + if not parsed_extensions: + return None + parts = [] + for name, config in parsed_extensions.items(): + ext_parts = [six.b(name)] + for key, value in config.items(): + if value is False: + pass + elif value is True: + ext_parts.append(six.b(key)) + else: + ext_parts.append(six.b("%s=%s" % (key, str(value)))) + parts.append(b"; ".join(ext_parts)) + return b", ".join(parts) + def _handle_hybi_request(self, environ): if 'eventlet.input' in environ: sock = environ['eventlet.input'].get_socket() @@ -226,9 +298,6 @@ class WebSocketWSGI(object): if p in self.supported_protocols: negotiated_protocol = p break - # extensions = environ.get('HTTP_SEC_WEBSOCKET_EXTENSIONS', None) - # if extensions: - # extensions = [i.strip() for i in extensions.split(',')] key = environ['HTTP_SEC_WEBSOCKET_KEY'] response = base64.b64encode(sha1(six.b(key) + PROTOCOL_GUID).digest()) @@ -238,9 +307,22 @@ class WebSocketWSGI(object): b"Sec-WebSocket-Accept: " + response] if negotiated_protocol: handshake_reply.append(b"Sec-WebSocket-Protocol: " + six.b(negotiated_protocol)) + + parsed_extensions = {} + extensions = self._parse_extension_header(environ.get("HTTP_SEC_WEBSOCKET_EXTENSIONS")) + + deflate = self._negotiate_permessage_deflate(extensions) + if deflate is not None: + parsed_extensions["permessage-deflate"] = deflate + + formatted_ext = self._format_extension_header(parsed_extensions) + if formatted_ext is not None: + handshake_reply.append(b"Sec-WebSocket-Extensions: " + formatted_ext) + sock.sendall(b'\r\n'.join(handshake_reply) + b'\r\n\r\n') return RFC6455WebSocket(sock, environ, self.protocol_version, - protocol=negotiated_protocol) + protocol=negotiated_protocol, + extensions=parsed_extensions) def _extract_number(self, value): """ @@ -296,8 +378,7 @@ class WebSocket(object): self._msgs = collections.deque() self._sendlock = semaphore.Semaphore() - @staticmethod - def _pack_message(message): + def _pack_message(self, message): """Pack the message inside ``00`` and ``FF`` As per the dataframing section (5.3) for the websocket spec @@ -409,11 +490,15 @@ class ProtocolError(ValueError): class RFC6455WebSocket(WebSocket): - def __init__(self, sock, environ, version=13, protocol=None, client=False): + def __init__(self, sock, environ, version=13, protocol=None, client=False, extensions=None): super(RFC6455WebSocket, self).__init__(sock, environ, version) self.iterator = self._iter_frames() self.client = client self.protocol = protocol + self.extensions = extensions or {} + + self._deflate_enc = None + self._deflate_dec = None class UTF8Decoder(object): def __init__(self): @@ -436,6 +521,45 @@ class RFC6455WebSocket(WebSocket): raise ValueError('Data is not valid unicode') return self.decoder.decode(data, final) + def _get_permessage_deflate_enc(self): + options = self.extensions.get("permessage-deflate") + if options is None: + return None + + def _make(): + return zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, + -options.get("client_max_window_bits" if self.client + else "server_max_window_bits", + zlib.MAX_WBITS)) + + if options.get("client_no_context_takeover" if self.client + else "server_no_context_takeover"): + # This option means we have to make a new one every time + return _make() + else: + if self._deflate_enc is None: + self._deflate_enc = _make() + return self._deflate_enc + + def _get_permessage_deflate_dec(self, rsv1): + options = self.extensions.get("permessage-deflate") + if options is None or not rsv1: + return None + + def _make(): + return zlib.decompressobj(-options.get("server_max_window_bits" if self.client + else "client_max_window_bits", + zlib.MAX_WBITS)) + + if options.get("server_no_context_takeover" if self.client + else "client_no_context_takeover"): + # This option means we have to make a new one every time + return _make() + else: + if self._deflate_dec is None: + self._deflate_dec = _make() + return self._deflate_dec + def _get_bytes(self, numbytes): data = b'' while len(data) < numbytes: @@ -446,20 +570,24 @@ class RFC6455WebSocket(WebSocket): return data class Message(object): - def __init__(self, opcode, decoder=None): + def __init__(self, opcode, decoder=None, decompressor=None): self.decoder = decoder self.data = [] self.finished = False self.opcode = opcode + self.decompressor = decompressor def push(self, data, final=False): - if self.decoder: - data = self.decoder.decode(data, final=final) self.finished = final self.data.append(data) def getvalue(self): - return ('' if self.decoder else b'').join(self.data) + data = b"".join(self.data) + if not self.opcode & 8 and self.decompressor: + data = self.decompressor.decompress(data + b'\x00\x00\xff\xff') + if self.decoder: + data = self.decoder.decode(data, self.finished) + return data @staticmethod def _apply_mask(data, mask, length=None, offset=0): @@ -523,16 +651,21 @@ class RFC6455WebSocket(WebSocket): def _recv_frame(self, message=None): recv = self._get_bytes + + # Unpacking the frame described in Section 5.2 of RFC6455 + # (https://tools.ietf.org/html/rfc6455#section-5.2) header = recv(2) a, b = struct.unpack('!BB', header) finished = a >> 7 == 1 rsv123 = a >> 4 & 7 + rsv1 = rsv123 & 4 if rsv123: - # must be zero - raise FailedConnectionError( - 1002, - "RSV1, RSV2, RSV3: MUST be 0 unless an extension is" - " negotiated that defines meanings for non-zero values.") + if rsv1 and "permessage-deflate" not in self.extensions: + # must be zero - unless it's compressed then rsv1 is true + raise FailedConnectionError( + 1002, + "RSV1, RSV2, RSV3: MUST be 0 unless an extension is" + " negotiated that defines meanings for non-zero values.") opcode = a & 15 if opcode not in (0, 1, 2, 8, 9, 0xA): raise FailedConnectionError(1002, "Unknown opcode received.") @@ -569,7 +702,8 @@ class RFC6455WebSocket(WebSocket): received = 0 if not message or opcode & 8: decoder = self.UTF8Decoder() if opcode == 1 else None - message = self.Message(opcode, decoder=decoder) + decompressor = self._get_permessage_deflate_dec(rsv1) + message = self.Message(opcode, decoder=decoder, decompressor=decompressor) if not length: message.push(b'', final=finished) else: @@ -588,13 +722,22 @@ class RFC6455WebSocket(WebSocket): 1007, "Text data must be valid utf-8") return message - @staticmethod - def _pack_message(message, masked=False, + def _pack_message(self, message, masked=False, continuation=False, final=True, control_code=None): is_text = False if isinstance(message, six.text_type): message = message.encode('utf-8') is_text = True + + compress_bit = 0 + compressor = self._get_permessage_deflate_enc() + if message and compressor: + message = compressor.compress(message) + message += compressor.flush(zlib.Z_SYNC_FLUSH) + assert message[-4:] == b"\x00\x00\xff\xff" + message = message[:-4] + compress_bit = 1 << 6 + length = len(message) if not length: # no point masking empty data @@ -608,7 +751,7 @@ class RFC6455WebSocket(WebSocket): raise ProtocolError('Control frame data too large (>125).') header = struct.pack('!B', control_code | 1 << 7) else: - opcode = 0 if continuation else (1 if is_text else 2) + opcode = 0 if continuation else ((1 if is_text else 2) | compress_bit) header = struct.pack('!B', opcode | (1 << 7 if final else 0)) lengthdata = 1 << 7 if masked else 0 if length > 65535: diff --git a/tests/websocket_new_test.py b/tests/websocket_new_test.py index 712bccd..cdbbf8b 100644 --- a/tests/websocket_new_test.py +++ b/tests/websocket_new_test.py @@ -1,5 +1,6 @@ import errno import struct +import re import eventlet from eventlet import event @@ -228,3 +229,308 @@ class TestWebSocket(tests.wsgi_test._TestBase): sock.sendall(b'\x07\xff') # Weird packet. done_with_request.wait() assert not error_detected[0] + + +class TestWebSocketWithCompression(tests.wsgi_test._TestBase): + TEST_TIMEOUT = 5 + + def set_site(self): + self.site = wsapp + + def setUp(self): + super(TestWebSocketWithCompression, self).setUp() + self.connect = '\r\n'.join([ + "GET /echo HTTP/1.1", + "Upgrade: websocket", + "Connection: upgrade", + "Host: %s:%s" % self.server_addr, + "Origin: http://%s:%s" % self.server_addr, + "Sec-WebSocket-Version: 13", + "Sec-WebSocket-Key: d9MXuOzlVQ0h+qRllvSCIg==", + "Sec-WebSocket-Extensions: %s", + '\r\n' + ]) + self.handshake_re = re.compile(six.b('\r\n'.join([ + 'HTTP/1.1 101 Switching Protocols', + 'Upgrade: websocket', + 'Connection: Upgrade', + 'Sec-WebSocket-Accept: ywSyWXCPNsDxLrQdQrn5RFNRfBU=', + 'Sec-WebSocket-Extensions: (.+)' + '\r\n', + ]))) + + @staticmethod + def get_deflated_reply(ws): + msg = ws._recv_frame(None) + msg.decompressor = None + return msg.getvalue() + + def test_accept_basic_deflate_ext_13(self): + for extension in [ + 'permessage-deflate', + 'PeRMessAGe-dEFlaTe', + ]: + sock = eventlet.connect(self.server_addr) + + sock.sendall(six.b(self.connect % extension)) + result = sock.recv(1024) + + # The server responds the correct Websocket handshake + # print('Extension offer: %r' % extension) + match = re.match(self.handshake_re, result) + assert match is not None + assert len(match.groups()) == 1 + + def test_accept_deflate_ext_context_takeover_13(self): + for extension in [ + 'permessage-deflate;CLient_No_conteXT_TAkeOver', + 'permessage-deflate; SerVER_No_conteXT_TAkeOver', + 'permessage-deflate; server_no_context_takeover; client_no_context_takeover', + ]: + sock = eventlet.connect(self.server_addr) + + sock.sendall(six.b(self.connect % extension)) + result = sock.recv(1024) + + # The server responds the correct Websocket handshake + # print('Extension offer: %r' % extension) + match = re.match(self.handshake_re, result) + assert match is not None + assert len(match.groups()) == 1 + offered_ext_parts = (ex.strip().lower() for ex in extension.split(';')) + accepted_ext_parts = match.groups()[0].decode().split('; ') + assert all(oep in accepted_ext_parts for oep in offered_ext_parts) + + def test_accept_deflate_ext_window_max_bits_13(self): + for extension_string, vals in [ + ('permessage-deflate; client_max_window_bits', [15]), + ('permessage-deflate; Server_Max_Window_Bits = 11', [11]), + ('permessage-deflate; server_max_window_bits; ' + 'client_max_window_bits=9', [15, 9]) + ]: + sock = eventlet.connect(self.server_addr) + + sock.sendall(six.b(self.connect % extension_string)) + result = sock.recv(1024) + + # The server responds the correct Websocket handshake + # print('Extension offer: %r' % extension_string) + match = re.match(self.handshake_re, result) + assert match is not None + assert len(match.groups()) == 1 + + offered_parts = [part.strip().lower() for part in extension_string.split(';')] + offered_parts_names = [part.split('=')[0].strip() for part in offered_parts] + offered_parts_dict = dict(zip(offered_parts_names[1:], vals)) + + accepted_ext_parts = match.groups()[0].decode().split('; ') + assert accepted_ext_parts[0] == 'permessage-deflate' + for param, val in (part.split('=') for part in accepted_ext_parts[1:]): + assert int(val) == offered_parts_dict[param] + + def test_reject_max_window_bits_out_of_range_13(self): + extension_string = ('permessage-deflate; client_max_window_bits=7,' + 'permessage-deflate; server_max_window_bits=16, ' + 'permessage-deflate; client_max_window_bits=16; ' + 'server_max_window_bits=7, ' + 'permessage-deflate') + sock = eventlet.connect(self.server_addr) + + sock.sendall(six.b(self.connect % extension_string)) + result = sock.recv(1024) + + # The server responds the correct Websocket handshake + # print('Extension offer: %r' % extension_string) + match = re.match(self.handshake_re, result) + assert match.groups()[0] == b'permessage-deflate' + + def test_server_compress_with_context_takeover_13(self): + extensions_string = 'permessage-deflate; client_no_context_takeover;' + extensions = {'permessage-deflate': { + 'client_no_context_takeover': True, + 'server_no_context_takeover': False}} + + sock = eventlet.connect(self.server_addr) + sock.sendall(six.b(self.connect % extensions_string)) + sock.recv(1024) + ws = websocket.RFC6455WebSocket(sock, {}, client=True, + extensions=extensions) + + # Deflated values taken from Section 7.2.3 of RFC 7692 + # https://tools.ietf.org/html/rfc7692#section-7.2.3 + ws.send(b'Hello') + msg1 = self.get_deflated_reply(ws) + assert msg1 == b'\xf2\x48\xcd\xc9\xc9\x07\x00' + + ws.send(b'Hello') + msg2 = self.get_deflated_reply(ws) + assert msg2 == b'\xf2\x00\x11\x00\x00' + + ws.close() + eventlet.sleep(0.01) + + def test_server_compress_no_context_takeover_13(self): + extensions_string = 'permessage-deflate; server_no_context_takeover;' + extensions = {'permessage-deflate': { + 'client_no_context_takeover': False, + 'server_no_context_takeover': True}} + + sock = eventlet.connect(self.server_addr) + sock.sendall(six.b(self.connect % extensions_string)) + sock.recv(1024) + ws = websocket.RFC6455WebSocket(sock, {}, client=True, + extensions=extensions) + + masked_msg1 = ws._pack_message(b'Hello', masked=True) + ws._send(masked_msg1) + masked_msg2 = ws._pack_message(b'Hello', masked=True) + ws._send(masked_msg2) + # Verify that client uses context takeover by checking + # that the second message + assert len(masked_msg2) < len(masked_msg1) + + # Verify that server drops context between messages + # Deflated values taken from Section 7.2.3 of RFC 7692 + # https://tools.ietf.org/html/rfc7692#section-7.2.3 + reply_msg1 = self.get_deflated_reply(ws) + assert reply_msg1 == b'\xf2\x48\xcd\xc9\xc9\x07\x00' + reply_msg2 = self.get_deflated_reply(ws) + assert reply_msg2 == b'\xf2\x48\xcd\xc9\xc9\x07\x00' + + def test_client_compress_with_context_takeover_13(self): + extensions = {'permessage-deflate': { + 'client_no_context_takeover': False, + 'server_no_context_takeover': True}} + ws = websocket.RFC6455WebSocket(None, {}, client=True, + extensions=extensions) + + # Deflated values taken from Section 7.2.3 of RFC 7692 + # modified opcode to Binary instead of Text + # https://tools.ietf.org/html/rfc7692#section-7.2.3 + packed_msg_1 = ws._pack_message(b'Hello', masked=False) + assert packed_msg_1 == b'\xc2\x07\xf2\x48\xcd\xc9\xc9\x07\x00' + packed_msg_2 = ws._pack_message(b'Hello', masked=False) + assert packed_msg_2 == b'\xc2\x05\xf2\x00\x11\x00\x00' + + eventlet.sleep(0.01) + + def test_client_compress_no_context_takeover_13(self): + extensions = {'permessage-deflate': { + 'client_no_context_takeover': True, + 'server_no_context_takeover': False}} + ws = websocket.RFC6455WebSocket(None, {}, client=True, + extensions=extensions) + + # Deflated values taken from Section 7.2.3 of RFC 7692 + # modified opcode to Binary instead of Text + # https://tools.ietf.org/html/rfc7692#section-7.2.3 + packed_msg_1 = ws._pack_message(b'Hello', masked=False) + assert packed_msg_1 == b'\xc2\x07\xf2\x48\xcd\xc9\xc9\x07\x00' + packed_msg_2 = ws._pack_message(b'Hello', masked=False) + assert packed_msg_2 == b'\xc2\x07\xf2\x48\xcd\xc9\xc9\x07\x00' + + def test_compressed_send_recv_13(self): + extensions_string = 'permessage-deflate' + extensions = {'permessage-deflate': { + 'client_no_context_takeover': False, + 'server_no_context_takeover': False}} + + sock = eventlet.connect(self.server_addr) + sock.sendall(six.b(self.connect % extensions_string)) + sock.recv(1024) + ws = websocket.RFC6455WebSocket(sock, {}, client=True, extensions=extensions) + + ws.send(b'hello') + assert ws.wait() == b'hello' + ws.send(b'hello world!') + ws.send(u'hello world again!') + assert ws.wait() == b'hello world!' + assert ws.wait() == u'hello world again!' + + ws.close() + eventlet.sleep(0.01) + + def test_send_uncompressed_msg_13(self): + extensions_string = 'permessage-deflate' + extensions = {'permessage-deflate': { + 'client_no_context_takeover': False, + 'server_no_context_takeover': False}} + + sock = eventlet.connect(self.server_addr) + sock.sendall(six.b(self.connect % extensions_string)) + sock.recv(1024) + + # Send without using deflate, having rsv1 unset + ws = websocket.RFC6455WebSocket(sock, {}, client=True) + ws.send(b'Hello') + + # Adding extensions to recognise deflated response + ws.extensions = extensions + assert ws.wait() == b'Hello' + + ws.close() + eventlet.sleep(0.01) + + def test_compressed_send_recv_client_no_context_13(self): + extensions_string = 'permessage-deflate; client_no_context_takeover' + extensions = {'permessage-deflate': { + 'client_no_context_takeover': True, + 'server_no_context_takeover': False}} + + sock = eventlet.connect(self.server_addr) + sock.sendall(six.b(self.connect % extensions_string)) + sock.recv(1024) + ws = websocket.RFC6455WebSocket(sock, {}, client=True, extensions=extensions) + + ws.send(b'hello') + assert ws.wait() == b'hello' + ws.send(b'hello world!') + ws.send(u'hello world again!') + assert ws.wait() == b'hello world!' + assert ws.wait() == u'hello world again!' + + ws.close() + eventlet.sleep(0.01) + + def test_compressed_send_recv_server_no_context_13(self): + extensions_string = 'permessage-deflate; server_no_context_takeover' + extensions = {'permessage-deflate': { + 'client_no_context_takeover': False, + 'server_no_context_takeover': False}} + + sock = eventlet.connect(self.server_addr) + sock.sendall(six.b(self.connect % extensions_string)) + sock.recv(1024) + ws = websocket.RFC6455WebSocket(sock, {}, client=True, extensions=extensions) + + ws.send(b'hello') + assert ws.wait() == b'hello' + ws.send(b'hello world!') + ws.send(u'hello world again!') + assert ws.wait() == b'hello world!' + assert ws.wait() == u'hello world again!' + + ws.close() + eventlet.sleep(0.01) + + def test_compressed_send_recv_both_no_context_13(self): + extensions_string = ('permessage-deflate;' + ' server_no_context_takeover; client_no_context_takeover') + extensions = {'permessage-deflate': { + 'client_no_context_takeover': True, + 'server_no_context_takeover': True}} + + sock = eventlet.connect(self.server_addr) + sock.sendall(six.b(self.connect % extensions_string)) + sock.recv(1024) + ws = websocket.RFC6455WebSocket(sock, {}, client=True, extensions=extensions) + + ws.send(b'hello') + assert ws.wait() == b'hello' + ws.send(b'hello world!') + ws.send(u'hello world again!') + assert ws.wait() == b'hello world!' + assert ws.wait() == u'hello world again!' + + ws.close() + eventlet.sleep(0.01) |