diff options
author | Hiroki Ohtani <hiro@HiroMacBookAir11.local> | 2011-01-05 09:19:19 +0900 |
---|---|---|
committer | Hiroki Ohtani <hiro@HiroMacBookAir11.local> | 2011-01-05 09:19:19 +0900 |
commit | 5f615f694942748e54878cca3bf7de83844cc8e0 (patch) | |
tree | ede9d36cfcc64673a6991a01ece6e1048a5de42d | |
parent | 2be0632fd7e62e11efd4903e6b1a61329a44c14b (diff) | |
download | websocket-client-5f615f694942748e54878cca3bf7de83844cc8e0.tar.gz |
- support wss
-rw-r--r-- | test.py | 22 | ||||
-rw-r--r-- | websocket.py | 51 |
2 files changed, 51 insertions, 22 deletions
@@ -49,34 +49,44 @@ class WebSocketTest(unittest.TestCase): self.assertEquals(p[0], "www.example.com") self.assertEquals(p[1], 80) self.assertEquals(p[2], "/r") + self.assertEquals(p[3], False) p = ws._parse_url("ws://www.example.com/") self.assertEquals(p[0], "www.example.com") self.assertEquals(p[1], 80) self.assertEquals(p[2], "/") + self.assertEquals(p[3], False) p = ws._parse_url("ws://www.example.com") self.assertEquals(p[0], "www.example.com") self.assertEquals(p[1], 80) self.assertEquals(p[2], "/") + self.assertEquals(p[3], False) p = ws._parse_url("ws://www.example.com:8080/r") self.assertEquals(p[0], "www.example.com") self.assertEquals(p[1], 8080) self.assertEquals(p[2], "/r") + self.assertEquals(p[3], False) p = ws._parse_url("ws://www.example.com:8080/") self.assertEquals(p[0], "www.example.com") self.assertEquals(p[1], 8080) self.assertEquals(p[2], "/") + self.assertEquals(p[3], False) p = ws._parse_url("ws://www.example.com:8080") self.assertEquals(p[0], "www.example.com") self.assertEquals(p[1], 8080) self.assertEquals(p[2], "/") + self.assertEquals(p[3], False) + + p = ws._parse_url("wss://www.example.com:8080/r") + self.assertEquals(p[0], "www.example.com") + self.assertEquals(p[1], 8080) + self.assertEquals(p[2], "/r") + self.assertEquals(p[3], True) - # we do not support wss for a while - self.assertRaises(ValueError, ws._parse_url, "wss://www.example.com/r") self.assertRaises(ValueError, ws._parse_url, "http://www.example.com/r") def testWSKey(self): @@ -127,7 +137,7 @@ class WebSocketTest(unittest.TestCase): def testReadHeader(self): sock = ws.WebSocket() - sock.sock = HeaderSockMock("data/header01.txt") + sock.io_sock = sock.sock = HeaderSockMock("data/header01.txt") status, header = sock._read_headers() self.assertEquals(status, 101) self.assertEquals(header["connection"], "upgrade") @@ -135,12 +145,12 @@ class WebSocketTest(unittest.TestCase): self.assertEquals(sock._get_resp(), "ssssss\r\naaaaaaaa") - sock.sock = HeaderSockMock("data/header02.txt") + sock.io_sock = sock.sock = HeaderSockMock("data/header02.txt") self.assertRaises(ws.WebSocketException, sock._read_headers) def testSend(self): sock = ws.WebSocket() - s = sock.sock = HeaderSockMock("data/header01.txt") + s = sock.io_sock = sock.sock = HeaderSockMock("data/header01.txt") sock.send("Hello") self.assertEquals(s.sent[0], "\x00Hello\xff") sock.send("こんにちは") @@ -150,7 +160,7 @@ class WebSocketTest(unittest.TestCase): def testRecv(self): sock = ws.WebSocket() - s = sock.sock = StringSockMock() + s = sock.io_sock = sock.sock = StringSockMock() s.set_data("\x00こんにちは\xff") data = sock.recv() self.assertEquals(data, "こんにちは") diff --git a/websocket.py b/websocket.py index 2e7ae30..5a33c41 100644 --- a/websocket.py +++ b/websocket.py @@ -33,20 +33,27 @@ def getdefaulttimeout(): return default_timeout def _parse_url(url): + """ + parse url and the result is tuple of + (hostname, port, resource path and the flag of secure mode) + """ parsed = urlparse(url) if parsed.hostname: hostname = parsed.hostname else: raise ValueError("hostname is invalid") - + port = 0 + if parsed.port: + port = parsed.port + + is_secure = False if parsed.scheme == "ws": - if parsed.port: - port = parsed.port - else: + if not port: port = 80 elif parsed.scheme == "wss": - # TODO: support wss - raise ValueError("scheme wss is not supported") + is_secure = True + if not port: + port = 443 else: raise ValueError("scheme %s is invalid" % parsed.scheme) @@ -55,7 +62,7 @@ def _parse_url(url): else: resource = "/" - return (hostname, port, resource) + return (hostname, port, resource, is_secure) def create_connection(url, timeout=None, **options): @@ -111,7 +118,16 @@ HEADERS_TO_EXIST_FOR_HIXIE75 = [ "websocket-origin", "websocket-location", ] + +class SSLSocketWrapper(object): + def __init__(self, sock): + self.ssl = socket.ssl(sock) + + def recv(self, bufsize): + return self.ssl.read(bufsize) + def send(self, payload): + return self.ssl.write(payload) class WebSocket(object): def __init__(self): @@ -119,7 +135,7 @@ class WebSocket(object): Initalize WebSocket object. """ self.connected = False - self.sock = socket.socket() + self.io_sock = self.sock = socket.socket() def settimeout(self, timeout): """ @@ -137,13 +153,15 @@ class WebSocket(object): """ Connect to url. url is websocket url scheme. ie. ws://host:port/resource """ - hostname, port, resource = _parse_url(url) + hostname, port, resource, is_secure = _parse_url(url) # TODO: we need to support proxy self.sock.connect((hostname, port)) + if is_secure: + self.io_sock = SSLSocketWrapper(self.sock) self._handshake(hostname, port, resource, **options) def _handshake(self, host, port, resource, **options): - sock = self.sock + sock = self.io_sock headers = [] if "header" in options: headers.extend(options["header"]) @@ -175,17 +193,17 @@ class WebSocket(object): status, resp_headers = self._read_headers() if status != 101: - self.sock.close() + self.close() raise WebSocketException("Handshake Status %d" % status) success, secure = self._validate_header(resp_headers) if not success: - self.sock.close() + self.close() raise WebSocketException("Invalid WebSocket Header") if secure: resp = self._get_resp() if not self._validate_resp(number_1, number_2, key3, resp): - self.sock.close() + self.close() raise WebSocketException("challenge-response error") self.connected = True @@ -268,7 +286,7 @@ class WebSocket(object): """ if isinstance(payload, unicode): payload = payload.encode("utf-8") - self.sock.send("".join(["\x00", payload, "\xff"])) + self.io_sock.send("".join(["\x00", payload, "\xff"])) def recv(self): """ @@ -307,16 +325,17 @@ class WebSocket(object): """ if self.connected: try: - self.sock.send("\xff\x00") + self.io_sock.send("\xff\x00") result = self._recv(2) if result != "\xff\x00": logger.error("bad closing Handshake") except: pass self.sock.close() + self.io_sock = self.sock def _recv(self, bufsize): - bytes = self.sock.recv(bufsize) + bytes = self.io_sock.recv(bufsize) if not bytes: raise ConnectionClosedException() return bytes |