diff options
author | tifayuki <tifayuki@gmail.com> | 2016-12-21 16:25:52 -0800 |
---|---|---|
committer | tifayuki <tifayuki@gmail.com> | 2016-12-21 16:28:14 -0800 |
commit | e174c0350db6180e38e10c210078ad6515c5f039 (patch) | |
tree | 09ada81ff9730984a6913f0caf5fbf22d34fd70c /websocket | |
parent | 7d682da323f421241ed11dcf797d9baa15e2debe (diff) | |
download | websocket-client-e174c0350db6180e38e10c210078ad6515c5f039.tar.gz |
Implement simple cookie jar
Diffstat (limited to 'websocket')
-rw-r--r-- | websocket/_cookiejar.py | 52 | ||||
-rw-r--r-- | websocket/_handshake.py | 9 | ||||
-rw-r--r-- | websocket/tests/test_cookiejar.py | 98 |
3 files changed, 158 insertions, 1 deletions
diff --git a/websocket/_cookiejar.py b/websocket/_cookiejar.py new file mode 100644 index 0000000..8a30352 --- /dev/null +++ b/websocket/_cookiejar.py @@ -0,0 +1,52 @@ +try: + import Cookie +except: + import http.cookies as Cookie + + +class SimpleCookieJar(object): + def __init__(self): + self.jar = dict() + + def add(self, set_cookie): + if set_cookie: + try: + simpleCookie = Cookie.SimpleCookie(set_cookie) + except: + simpleCookie = Cookie.SimpleCookie(set_cookie.encode('ascii', 'ignore')) + + for k, v in simpleCookie.items(): + domain = v.get("domain") + if domain: + if not domain.startswith("."): + domain = "." + domain + cookie = self.jar.get(domain) if self.jar.get(domain) else Cookie.SimpleCookie() + cookie.update(simpleCookie) + self.jar[domain.lower()] = cookie + + def set(self, set_cookie): + if set_cookie: + try: + simpleCookie = Cookie.SimpleCookie(set_cookie) + except: + simpleCookie = Cookie.SimpleCookie(set_cookie.encode('ascii', 'ignore')) + + for k, v in simpleCookie.items(): + domain = v.get("domain") + if domain: + if not domain.startswith("."): + domain = "." + domain + self.jar[domain.lower()] = simpleCookie + + def get(self, host): + if not host: + return "" + + cookies = [] + for domain, simpleCookie in self.jar.items(): + host = host.lower() + if host.endswith(domain) or host == domain[1:]: + cookies.append(self.jar.get(domain)) + + return "; ".join(filter(None, ["%s=%s" % (k, v.value) for cookie in filter(None, sorted(cookies)) for k, v in + cookie.items()])) diff --git a/websocket/_handshake.py b/websocket/_handshake.py index f2c5352..d8116ed 100644 --- a/websocket/_handshake.py +++ b/websocket/_handshake.py @@ -25,6 +25,7 @@ import os import six +from ._cookiejar import SimpleCookieJar from ._exceptions import * from ._http import * from ._logging import * @@ -46,6 +47,8 @@ else: # websocket supported version. VERSION = 13 +CookieJar = SimpleCookieJar() + class handshake_response(object): @@ -53,6 +56,7 @@ class handshake_response(object): self.status = status self.headers = headers self.subprotocol = subprotocol + CookieJar.add(headers.get("set-cookie")) def handshake(sock, hostname, port, resource, **options): @@ -105,7 +109,10 @@ def _get_handshake_headers(resource, host, port, options): header = map(": ".join, header.items()) headers.extend(header) - cookie = options.get("cookie", None) + server_cookie = CookieJar.get(host) + client_cookie = options.get("cookie", None) + + cookie = "; ".join(filter(None, [server_cookie, client_cookie])) if cookie: headers.append("Cookie: %s" % cookie) diff --git a/websocket/tests/test_cookiejar.py b/websocket/tests/test_cookiejar.py new file mode 100644 index 0000000..c40a00b --- /dev/null +++ b/websocket/tests/test_cookiejar.py @@ -0,0 +1,98 @@ +import unittest + +from websocket._cookiejar import SimpleCookieJar + +try: + import Cookie +except: + import http.cookies as Cookie + + +class CookieJarTest(unittest.TestCase): + def testAdd(self): + cookie_jar = SimpleCookieJar() + cookie_jar.add("") + self.assertFalse(cookie_jar.jar, "Cookie with no domain should not be added to the jar") + + cookie_jar = SimpleCookieJar() + cookie_jar.add("a=b") + self.assertFalse(cookie_jar.jar, "Cookie with no domain should not be added to the jar") + + cookie_jar = SimpleCookieJar() + cookie_jar.add("a=b; domain=.abc") + self.assertTrue(".abc" in cookie_jar.jar) + + cookie_jar = SimpleCookieJar() + cookie_jar.add("a=b; domain=abc") + self.assertTrue(".abc" in cookie_jar.jar) + self.assertTrue("abc" not in cookie_jar.jar) + + cookie_jar = SimpleCookieJar() + cookie_jar.add("a=b; c=d; domain=abc") + self.assertEquals(cookie_jar.get("abc"), "a=b; c=d") + + cookie_jar = SimpleCookieJar() + cookie_jar.add("a=b; c=d; domain=abc") + cookie_jar.add("e=f; domain=abc") + self.assertEquals(cookie_jar.get("abc"), "a=b; c=d; e=f") + + cookie_jar = SimpleCookieJar() + cookie_jar.add("a=b; c=d; domain=abc") + cookie_jar.add("e=f; domain=.abc") + self.assertEquals(cookie_jar.get("abc"), "a=b; c=d; e=f") + + cookie_jar = SimpleCookieJar() + cookie_jar.add("a=b; c=d; domain=abc") + cookie_jar.add("e=f; domain=xyz") + self.assertEquals(cookie_jar.get("abc"), "a=b; c=d") + self.assertEquals(cookie_jar.get("xyz"), "e=f") + self.assertEquals(cookie_jar.get("something"), "") + + def testSet(self): + cookie_jar = SimpleCookieJar() + cookie_jar.set("a=b") + self.assertFalse(cookie_jar.jar, "Cookie with no domain should not be added to the jar") + + cookie_jar = SimpleCookieJar() + cookie_jar.set("a=b; domain=.abc") + self.assertTrue(".abc" in cookie_jar.jar) + + cookie_jar = SimpleCookieJar() + cookie_jar.set("a=b; domain=abc") + self.assertTrue(".abc" in cookie_jar.jar) + self.assertTrue("abc" not in cookie_jar.jar) + + cookie_jar = SimpleCookieJar() + cookie_jar.set("a=b; c=d; domain=abc") + self.assertEquals(cookie_jar.get("abc"), "a=b; c=d") + + cookie_jar = SimpleCookieJar() + cookie_jar.set("a=b; c=d; domain=abc") + cookie_jar.set("e=f; domain=abc") + self.assertEquals(cookie_jar.get("abc"), "e=f") + + cookie_jar = SimpleCookieJar() + cookie_jar.set("a=b; c=d; domain=abc") + cookie_jar.set("e=f; domain=.abc") + self.assertEquals(cookie_jar.get("abc"), "e=f") + + cookie_jar = SimpleCookieJar() + cookie_jar.set("a=b; c=d; domain=abc") + cookie_jar.set("e=f; domain=xyz") + self.assertEquals(cookie_jar.get("abc"), "a=b; c=d") + self.assertEquals(cookie_jar.get("xyz"), "e=f") + self.assertEquals(cookie_jar.get("something"), "") + + def testGet(self): + cookie_jar = SimpleCookieJar() + cookie_jar.set("a=b; c=d; domain=abc.com") + self.assertEquals(cookie_jar.get("abc.com"), "a=b; c=d") + self.assertEquals(cookie_jar.get("x.abc.com"), "a=b; c=d") + self.assertEquals(cookie_jar.get("abc.com.es"), "") + self.assertEquals(cookie_jar.get("xabc.com"), "") + + cookie_jar.set("a=b; c=d; domain=.abc.com") + self.assertEquals(cookie_jar.get("abc.com"), "a=b; c=d") + self.assertEquals(cookie_jar.get("x.abc.com"), "a=b; c=d") + self.assertEquals(cookie_jar.get("abc.com.es"), "") + self.assertEquals(cookie_jar.get("xabc.com"), "") |