summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHiroki Ohtani <hiro@HiroMacBookAir11.local>2011-01-05 09:19:19 +0900
committerHiroki Ohtani <hiro@HiroMacBookAir11.local>2011-01-05 09:19:19 +0900
commit5f615f694942748e54878cca3bf7de83844cc8e0 (patch)
treeede9d36cfcc64673a6991a01ece6e1048a5de42d
parent2be0632fd7e62e11efd4903e6b1a61329a44c14b (diff)
downloadwebsocket-client-5f615f694942748e54878cca3bf7de83844cc8e0.tar.gz
- support wss
-rw-r--r--test.py22
-rw-r--r--websocket.py51
2 files changed, 51 insertions, 22 deletions
diff --git a/test.py b/test.py
index eb1b94a..e2ff829 100644
--- a/test.py
+++ b/test.py
@@ -49,34 +49,44 @@ class WebSocketTest(unittest.TestCase):
self.assertEquals(p[0], "www.example.com")
self.assertEquals(p[1], 80)
self.assertEquals(p[2], "/r")
+ self.assertEquals(p[3], False)
p = ws._parse_url("ws://www.example.com/")
self.assertEquals(p[0], "www.example.com")
self.assertEquals(p[1], 80)
self.assertEquals(p[2], "/")
+ self.assertEquals(p[3], False)
p = ws._parse_url("ws://www.example.com")
self.assertEquals(p[0], "www.example.com")
self.assertEquals(p[1], 80)
self.assertEquals(p[2], "/")
+ self.assertEquals(p[3], False)
p = ws._parse_url("ws://www.example.com:8080/r")
self.assertEquals(p[0], "www.example.com")
self.assertEquals(p[1], 8080)
self.assertEquals(p[2], "/r")
+ self.assertEquals(p[3], False)
p = ws._parse_url("ws://www.example.com:8080/")
self.assertEquals(p[0], "www.example.com")
self.assertEquals(p[1], 8080)
self.assertEquals(p[2], "/")
+ self.assertEquals(p[3], False)
p = ws._parse_url("ws://www.example.com:8080")
self.assertEquals(p[0], "www.example.com")
self.assertEquals(p[1], 8080)
self.assertEquals(p[2], "/")
+ self.assertEquals(p[3], False)
+
+ p = ws._parse_url("wss://www.example.com:8080/r")
+ self.assertEquals(p[0], "www.example.com")
+ self.assertEquals(p[1], 8080)
+ self.assertEquals(p[2], "/r")
+ self.assertEquals(p[3], True)
- # we do not support wss for a while
- self.assertRaises(ValueError, ws._parse_url, "wss://www.example.com/r")
self.assertRaises(ValueError, ws._parse_url, "http://www.example.com/r")
def testWSKey(self):
@@ -127,7 +137,7 @@ class WebSocketTest(unittest.TestCase):
def testReadHeader(self):
sock = ws.WebSocket()
- sock.sock = HeaderSockMock("data/header01.txt")
+ sock.io_sock = sock.sock = HeaderSockMock("data/header01.txt")
status, header = sock._read_headers()
self.assertEquals(status, 101)
self.assertEquals(header["connection"], "upgrade")
@@ -135,12 +145,12 @@ class WebSocketTest(unittest.TestCase):
self.assertEquals(sock._get_resp(), "ssssss\r\naaaaaaaa")
- sock.sock = HeaderSockMock("data/header02.txt")
+ sock.io_sock = sock.sock = HeaderSockMock("data/header02.txt")
self.assertRaises(ws.WebSocketException, sock._read_headers)
def testSend(self):
sock = ws.WebSocket()
- s = sock.sock = HeaderSockMock("data/header01.txt")
+ s = sock.io_sock = sock.sock = HeaderSockMock("data/header01.txt")
sock.send("Hello")
self.assertEquals(s.sent[0], "\x00Hello\xff")
sock.send("こんにちは")
@@ -150,7 +160,7 @@ class WebSocketTest(unittest.TestCase):
def testRecv(self):
sock = ws.WebSocket()
- s = sock.sock = StringSockMock()
+ s = sock.io_sock = sock.sock = StringSockMock()
s.set_data("\x00こんにちは\xff")
data = sock.recv()
self.assertEquals(data, "こんにちは")
diff --git a/websocket.py b/websocket.py
index 2e7ae30..5a33c41 100644
--- a/websocket.py
+++ b/websocket.py
@@ -33,20 +33,27 @@ def getdefaulttimeout():
return default_timeout
def _parse_url(url):
+ """
+ parse url and the result is tuple of
+ (hostname, port, resource path and the flag of secure mode)
+ """
parsed = urlparse(url)
if parsed.hostname:
hostname = parsed.hostname
else:
raise ValueError("hostname is invalid")
-
+ port = 0
+ if parsed.port:
+ port = parsed.port
+
+ is_secure = False
if parsed.scheme == "ws":
- if parsed.port:
- port = parsed.port
- else:
+ if not port:
port = 80
elif parsed.scheme == "wss":
- # TODO: support wss
- raise ValueError("scheme wss is not supported")
+ is_secure = True
+ if not port:
+ port = 443
else:
raise ValueError("scheme %s is invalid" % parsed.scheme)
@@ -55,7 +62,7 @@ def _parse_url(url):
else:
resource = "/"
- return (hostname, port, resource)
+ return (hostname, port, resource, is_secure)
def create_connection(url, timeout=None, **options):
@@ -111,7 +118,16 @@ HEADERS_TO_EXIST_FOR_HIXIE75 = [
"websocket-origin",
"websocket-location",
]
+
+class SSLSocketWrapper(object):
+ def __init__(self, sock):
+ self.ssl = socket.ssl(sock)
+
+ def recv(self, bufsize):
+ return self.ssl.read(bufsize)
+ def send(self, payload):
+ return self.ssl.write(payload)
class WebSocket(object):
def __init__(self):
@@ -119,7 +135,7 @@ class WebSocket(object):
Initalize WebSocket object.
"""
self.connected = False
- self.sock = socket.socket()
+ self.io_sock = self.sock = socket.socket()
def settimeout(self, timeout):
"""
@@ -137,13 +153,15 @@ class WebSocket(object):
"""
Connect to url. url is websocket url scheme. ie. ws://host:port/resource
"""
- hostname, port, resource = _parse_url(url)
+ hostname, port, resource, is_secure = _parse_url(url)
# TODO: we need to support proxy
self.sock.connect((hostname, port))
+ if is_secure:
+ self.io_sock = SSLSocketWrapper(self.sock)
self._handshake(hostname, port, resource, **options)
def _handshake(self, host, port, resource, **options):
- sock = self.sock
+ sock = self.io_sock
headers = []
if "header" in options:
headers.extend(options["header"])
@@ -175,17 +193,17 @@ class WebSocket(object):
status, resp_headers = self._read_headers()
if status != 101:
- self.sock.close()
+ self.close()
raise WebSocketException("Handshake Status %d" % status)
success, secure = self._validate_header(resp_headers)
if not success:
- self.sock.close()
+ self.close()
raise WebSocketException("Invalid WebSocket Header")
if secure:
resp = self._get_resp()
if not self._validate_resp(number_1, number_2, key3, resp):
- self.sock.close()
+ self.close()
raise WebSocketException("challenge-response error")
self.connected = True
@@ -268,7 +286,7 @@ class WebSocket(object):
"""
if isinstance(payload, unicode):
payload = payload.encode("utf-8")
- self.sock.send("".join(["\x00", payload, "\xff"]))
+ self.io_sock.send("".join(["\x00", payload, "\xff"]))
def recv(self):
"""
@@ -307,16 +325,17 @@ class WebSocket(object):
"""
if self.connected:
try:
- self.sock.send("\xff\x00")
+ self.io_sock.send("\xff\x00")
result = self._recv(2)
if result != "\xff\x00":
logger.error("bad closing Handshake")
except:
pass
self.sock.close()
+ self.io_sock = self.sock
def _recv(self, bufsize):
- bytes = self.sock.recv(bufsize)
+ bytes = self.io_sock.recv(bufsize)
if not bytes:
raise ConnectionClosedException()
return bytes