diff options
author | Patryk Zawadzki <patrys@room-303.com> | 2013-11-06 18:52:19 +0100 |
---|---|---|
committer | Patryk Zawadzki <patrys@room-303.com> | 2013-11-06 18:52:19 +0100 |
commit | 9872a811f33f7843411431f3e0e12b2cc56c7f2d (patch) | |
tree | c959501ddd3b74889af498f490ff44b5c310ef06 | |
parent | 2b5dd54dabe599676f3fa5b4f4c90b59856e908f (diff) | |
download | pyjwt-9872a811f33f7843411431f3e0e12b2cc56c7f2d.tar.gz |
Basic Python 3 compatibility
-rw-r--r-- | jwt/__init__.py | 54 | ||||
-rwxr-xr-x[-rw-r--r--] | setup.py | 2 | ||||
-rw-r--r-- | tests/test_jwt.py | 71 |
3 files changed, 88 insertions, 39 deletions
diff --git a/jwt/__init__.py b/jwt/__init__.py index 88aa56c..910f120 100644 --- a/jwt/__init__.py +++ b/jwt/__init__.py @@ -3,9 +3,12 @@ Minimum implementation based on this spec: http://self-issued.info/docs/draft-jones-json-web-token-01.html """ +from __future__ import unicode_literals import base64 +import binascii import hashlib import hmac +import sys from datetime import datetime from calendar import timegm @@ -19,6 +22,10 @@ except ImportError: __all__ = ['encode', 'decode', 'DecodeError'] +if sys.version_info >= (3, 0, 0): + unicode = str + + class DecodeError(Exception): pass @@ -43,26 +50,31 @@ def constant_time_compare(val1, val2): if len(val1) != len(val2): return False result = 0 - for x, y in zip(val1, val2): - result |= ord(x) ^ ord(y) + if sys.version_info >= (3, 0, 0): # bytes are numbers + for x, y in zip(val1, val2): + result |= x ^ y + else: + for x, y in zip(val1, val2): + result |= ord(x) ^ ord(y) return result == 0 def base64url_decode(input): rem = len(input) % 4 if rem > 0: - input += '=' * (4 - rem) + input += b'=' * (4 - rem) return base64.urlsafe_b64decode(input) def base64url_encode(input): - return base64.urlsafe_b64encode(input).replace('=', '') + return base64.urlsafe_b64encode(input).replace(b'=', b'') def header(jwt): - header_segment = jwt.split('.', 1)[0] + header_segment = jwt.split(b'.', 1)[0] try: - return json.loads(base64url_decode(header_segment)) + header_data = base64url_decode(header_segment).decode('utf-8') + return json.loads(header_data) except (ValueError, TypeError): raise DecodeError("Invalid header encoding") @@ -77,15 +89,17 @@ def encode(payload, key, algorithm='HS256'): # Header header = {"typ": "JWT", "alg": algorithm} - segments.append(base64url_encode(json.dumps(header))) + json_header = json.dumps(header).encode('utf-8') + segments.append(base64url_encode(json_header)) # Payload if isinstance(payload.get('exp'), datetime): payload['exp'] = timegm(payload['exp'].utctimetuple()) - segments.append(base64url_encode(json.dumps(payload))) + json_payload = json.dumps(payload).encode('utf-8') + segments.append(base64url_encode(json_payload)) # Segments - signing_input = '.'.join(segments) + signing_input = b'.'.join(segments) try: if isinstance(key, unicode): key = key.encode('utf-8') @@ -93,33 +107,39 @@ def encode(payload, key, algorithm='HS256'): except KeyError: raise NotImplementedError("Algorithm not supported") segments.append(base64url_encode(signature)) - return '.'.join(segments) + return b'.'.join(segments) def decode(jwt, key='', verify=True, verify_expiration=True, leeway=0): + if isinstance(jwt, unicode): + jwt = jwt.encode('utf-8') try: - signing_input, crypto_segment = str(jwt).rsplit('.', 1) - header_segment, payload_segment = signing_input.split('.', 1) + signing_input, crypto_segment = jwt.rsplit(b'.', 1) + header_segment, payload_segment = signing_input.split(b'.', 1) except ValueError: raise DecodeError("Not enough segments") try: - header = json.loads(base64url_decode(header_segment)) - except TypeError: + header_data = base64url_decode(header_segment) + except (TypeError, binascii.Error): raise DecodeError("Invalid header padding") + try: + header = json.loads(header_data.decode('utf-8')) except ValueError as e: raise DecodeError("Invalid header string: %s" % e) try: - payload = json.loads(base64url_decode(payload_segment)) - except TypeError: + payload_data = base64url_decode(payload_segment) + except (TypeError, binascii.Error): raise DecodeError("Invalid payload padding") + try: + payload = json.loads(payload_data.decode('utf-8')) except ValueError as e: raise DecodeError("Invalid payload string: %s" % e) try: signature = base64url_decode(crypto_segment) - except TypeError: + except (TypeError, binascii.Error): raise DecodeError("Invalid crypto padding") if verify: @@ -1,3 +1,4 @@ +#!/usr/bin/env python import os from setuptools import setup @@ -21,4 +22,5 @@ setup( "Topic :: Utilities", "License :: OSI Approved :: MIT License", ], + test_suite='tests.test_jwt' ) diff --git a/tests/test_jwt.py b/tests/test_jwt.py index aec0d22..74f1ece 100644 --- a/tests/test_jwt.py +++ b/tests/test_jwt.py @@ -1,10 +1,14 @@ -import unittest +from __future__ import unicode_literals +from calendar import timegm +from datetime import datetime +import sys import time +import unittest import jwt -from datetime import datetime -from calendar import timegm +if sys.version_info >= (3, 0, 0): + unicode = str def utc_timestamp(): @@ -14,7 +18,8 @@ def utc_timestamp(): class TestJWT(unittest.TestCase): def setUp(self): - self.payload = {"iss": "jeff", "exp": utc_timestamp() + 1, "claim": "insanity"} + self.payload = {"iss": "jeff", "exp": utc_timestamp() + 1, + "claim": "insanity"} def test_encode_decode(self): secret = 'secret' @@ -36,7 +41,8 @@ class TestJWT(unittest.TestCase): payload = {"exp": current_datetime} jwt_message = jwt.encode(payload, secret) decoded_payload = jwt.decode(jwt_message, secret, leeway=1) - self.assertEqual(decoded_payload['exp'], + self.assertEqual( + decoded_payload['exp'], timegm(current_datetime.utctimetuple())) def test_bad_secret(self): @@ -49,27 +55,29 @@ class TestJWT(unittest.TestCase): def test_decodes_valid_jwt(self): example_payload = {"hello": "world"} example_secret = "secret" - example_jwt = "eyJhbGciOiAiSFMyNTYiLCAidHlwIjogIkpXVCJ9.eyJoZWxsbyI6ICJ3b3JsZCJ9.tvagLDLoaiJKxOKqpBXSEGy7SYSifZhjntgm9ctpyj8" + example_jwt = ( + b"eyJhbGciOiAiSFMyNTYiLCAidHlwIjogIkpXVCJ9" + b".eyJoZWxsbyI6ICJ3b3JsZCJ9" + b".tvagLDLoaiJKxOKqpBXSEGy7SYSifZhjntgm9ctpyj8") decoded_payload = jwt.decode(example_jwt, example_secret) self.assertEqual(decoded_payload, example_payload) def test_allow_skip_verification(self): right_secret = 'foo' - bad_secret = 'bar' jwt_message = jwt.encode(self.payload, right_secret) decoded_payload = jwt.decode(jwt_message, verify=False) self.assertEqual(decoded_payload, self.payload) def test_no_secret(self): right_secret = 'foo' - bad_secret = 'bar' jwt_message = jwt.encode(self.payload, right_secret) with self.assertRaises(jwt.DecodeError): jwt.decode(jwt_message) def test_invalid_crypto_alg(self): - self.assertRaises(NotImplementedError, jwt.encode, self.payload, "secret", "HS1024") + self.assertRaises(NotImplementedError, jwt.encode, self.payload, + "secret", "HS1024") def test_unicode_secret(self): secret = u'\xc2' @@ -78,7 +86,7 @@ class TestJWT(unittest.TestCase): self.assertEqual(decoded_payload, self.payload) def test_nonascii_secret(self): - secret = '\xc2' # char value that ascii codec cannot decode + secret = '\xc2' # char value that ascii codec cannot decode jwt_message = jwt.encode(self.payload, secret) decoded_payload = jwt.decode(jwt_message, secret) self.assertEqual(decoded_payload, self.payload) @@ -86,39 +94,58 @@ class TestJWT(unittest.TestCase): def test_decode_unicode_value(self): example_payload = {"hello": "world"} example_secret = "secret" - example_jwt = unicode("eyJhbGciOiAiSFMyNTYiLCAidHlwIjogIkpXVCJ9.eyJoZWxsbyI6ICJ3b3JsZCJ9.tvagLDLoaiJKxOKqpBXSEGy7SYSifZhjntgm9ctpyj8") + example_jwt = ( + "eyJhbGciOiAiSFMyNTYiLCAidHlwIjogIkpXVCJ9" + ".eyJoZWxsbyI6ICJ3b3JsZCJ9" + ".tvagLDLoaiJKxOKqpBXSEGy7SYSifZhjntgm9ctpyj8") decoded_payload = jwt.decode(example_jwt, example_secret) self.assertEqual(decoded_payload, example_payload) def test_decode_invalid_header_padding(self): - example_jwt = unicode("aeyJhbGciOiAiSFMyNTYiLCAidHlwIjogIkpXVCJ9.eyJoZWxsbyI6ICJ3b3JsZCJ9.tvagLDLoaiJKxOKqpBXSEGy7SYSifZhjntgm9ctpyj8") + example_jwt = ( + "aeyJhbGciOiAiSFMyNTYiLCAidHlwIjogIkpXVCJ9" + ".eyJoZWxsbyI6ICJ3b3JsZCJ9" + ".tvagLDLoaiJKxOKqpBXSEGy7SYSifZhjntgm9ctpyj8") example_secret = "secret" with self.assertRaises(jwt.DecodeError): - jwt_message = jwt.decode(example_jwt, example_secret) + jwt.decode(example_jwt, example_secret) def test_decode_invalid_header_string(self): - example_jwt = unicode("eyJhbGciOiAiSFMyNTbpIiwgInR5cCI6ICJKV1QifQ==.eyJoZWxsbyI6ICJ3b3JsZCJ9.tvagLDLoaiJKxOKqpBXSEGy7SYSifZhjntgm9ctpyj8") + example_jwt = ( + "eyJhbGciOiAiSFMyNTbpIiwgInR5cCI6ICJKV1QifQ==" + ".eyJoZWxsbyI6ICJ3b3JsZCJ9" + ".tvagLDLoaiJKxOKqpBXSEGy7SYSifZhjntgm9ctpyj8") example_secret = "secret" with self.assertRaisesRegexp(jwt.DecodeError, "Invalid header string"): - jwt_message = jwt.decode(example_jwt, example_secret) + jwt.decode(example_jwt, example_secret) def test_decode_invalid_payload_padding(self): - example_jwt = unicode("eyJhbGciOiAiSFMyNTYiLCAidHlwIjogIkpXVCJ9.aeyJoZWxsbyI6ICJ3b3JsZCJ9.tvagLDLoaiJKxOKqpBXSEGy7SYSifZhjntgm9ctpyj8") + example_jwt = ( + "eyJhbGciOiAiSFMyNTYiLCAidHlwIjogIkpXVCJ9" + ".aeyJoZWxsbyI6ICJ3b3JsZCJ9" + ".tvagLDLoaiJKxOKqpBXSEGy7SYSifZhjntgm9ctpyj8") example_secret = "secret" with self.assertRaises(jwt.DecodeError): - jwt_message = jwt.decode(example_jwt, example_secret) + jwt.decode(example_jwt, example_secret) def test_decode_invalid_payload_string(self): - example_jwt = unicode("eyJhbGciOiAiSFMyNTYiLCAidHlwIjogIkpXVCJ9.eyJoZWxsb-kiOiAid29ybGQifQ==.tvagLDLoaiJKxOKqpBXSEGy7SYSifZhjntgm9ctpyj8") + example_jwt = ( + "eyJhbGciOiAiSFMyNTYiLCAidHlwIjogIkpXVCJ9" + ".eyJoZWxsb-kiOiAid29ybGQifQ==" + ".tvagLDLoaiJKxOKqpBXSEGy7SYSifZhjntgm9ctpyj8") example_secret = "secret" - with self.assertRaisesRegexp(jwt.DecodeError, "Invalid payload string"): - jwt_message = jwt.decode(example_jwt, example_secret) + with self.assertRaisesRegexp(jwt.DecodeError, + "Invalid payload string"): + jwt.decode(example_jwt, example_secret) def test_decode_invalid_crypto_padding(self): - example_jwt = unicode("eyJhbGciOiAiSFMyNTYiLCAidHlwIjogIkpXVCJ9.eyJoZWxsbyI6ICJ3b3JsZCJ9.aatvagLDLoaiJKxOKqpBXSEGy7SYSifZhjntgm9ctpyj8") + example_jwt = ( + "eyJhbGciOiAiSFMyNTYiLCAidHlwIjogIkpXVCJ9" + ".eyJoZWxsbyI6ICJ3b3JsZCJ9" + ".aatvagLDLoaiJKxOKqpBXSEGy7SYSifZhjntgm9ctpyj8") example_secret = "secret" with self.assertRaises(jwt.DecodeError): - jwt_message = jwt.decode(example_jwt, example_secret) + jwt.decode(example_jwt, example_secret) def test_decode_with_expiration(self): self.payload['exp'] = utc_timestamp() - 1 |