summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJosé Padilla <jpadilla@webapplicate.com>2014-02-08 01:20:36 -0500
committerJosé Padilla <jpadilla@webapplicate.com>2014-02-08 01:20:36 -0500
commitd4c437e45ce7690317fc9752b85cbf11b44219e0 (patch)
treee95128236b244f63cb7c307a2e6e218c22658eaa
parent3bade2705b75909e7e589723d7d61c808fe16d3d (diff)
parentc7fb44820506c99df2d48a81660e501e9716e845 (diff)
downloadpyjwt-d4c437e45ce7690317fc9752b85cbf11b44219e0.tar.gz
Merge pull request #25 from dystedium/split_decode
refactor decode(), fix setup.py for automated sdist builds
-rw-r--r--jwt/__init__.py51
-rwxr-xr-xsetup.py7
-rw-r--r--tests/test_jwt.py75
3 files changed, 110 insertions, 23 deletions
diff --git a/jwt/__init__.py b/jwt/__init__.py
index 708d7ca..4e1d5d2 100644
--- a/jwt/__init__.py
+++ b/jwt/__init__.py
@@ -137,6 +137,16 @@ def encode(payload, key, algorithm='HS256'):
def decode(jwt, key='', verify=True, verify_expiration=True, leeway=0):
+ payload, signing_input, header, signature = load(jwt)
+
+ if verify:
+ verify_signature(payload, signing_input, header, signature, key,
+ verify_expiration, leeway)
+
+ return payload
+
+
+def load(jwt):
if isinstance(jwt, unicode):
jwt = jwt.encode('utf-8')
try:
@@ -168,22 +178,25 @@ def decode(jwt, key='', verify=True, verify_expiration=True, leeway=0):
except (TypeError, binascii.Error):
raise DecodeError("Invalid crypto padding")
- if verify:
- try:
- if isinstance(key, unicode):
- key = key.encode('utf-8')
- if header['alg'].startswith('HS'):
- expected = verify_methods[header['alg']](signing_input, key)
- if not constant_time_compare(signature, expected):
- raise DecodeError("Signature verification failed")
- else:
- if not verify_methods[header['alg']](signing_input, key, signature):
- raise DecodeError("Signature verification failed")
- except KeyError:
- raise DecodeError("Algorithm not supported")
-
- if 'exp' in payload and verify_expiration:
- utc_timestamp = timegm(datetime.utcnow().utctimetuple())
- if payload['exp'] < (utc_timestamp - leeway):
- raise ExpiredSignature("Signature has expired")
- return payload
+ return (payload, signing_input, header, signature)
+
+
+def verify_signature(payload, signing_input, header, signature, key='',
+ verify_expiration=True, leeway=0):
+ try:
+ if isinstance(key, unicode):
+ key = key.encode('utf-8')
+ if header['alg'].startswith('HS'):
+ expected = verify_methods[header['alg']](signing_input, key)
+ if not constant_time_compare(signature, expected):
+ raise DecodeError("Signature verification failed")
+ else:
+ if not verify_methods[header['alg']](signing_input, key, signature):
+ raise DecodeError("Signature verification failed")
+ except KeyError:
+ raise DecodeError("Algorithm not supported")
+
+ if 'exp' in payload and verify_expiration:
+ utc_timestamp = timegm(datetime.utcnow().utctimetuple())
+ if payload['exp'] < (utc_timestamp - leeway):
+ raise ExpiredSignature("Signature has expired")
diff --git a/setup.py b/setup.py
index 32f2271..cdd14a3 100755
--- a/setup.py
+++ b/setup.py
@@ -3,8 +3,9 @@ import os
from setuptools import setup
-def read(fname):
- return open(os.path.join(os.path.dirname(__file__), fname)).read()
+with open(os.path.join(os.path.dirname(__file__), 'README.md')) as readme:
+ long_description = readme.read()
+
setup(
name="PyJWT",
@@ -17,7 +18,7 @@ setup(
url="http://github.com/progrium/pyjwt",
packages=['jwt'],
scripts=['bin/jwt'],
- long_description=read('README.md'),
+ long_description=long_description,
classifiers=[
"Development Status :: 3 - Alpha",
"License :: OSI Approved :: MIT License",
diff --git a/tests/test_jwt.py b/tests/test_jwt.py
index fb83136..4192068 100644
--- a/tests/test_jwt.py
+++ b/tests/test_jwt.py
@@ -62,12 +62,29 @@ class TestJWT(unittest.TestCase):
decoded_payload = jwt.decode(example_jwt, example_secret)
self.assertEqual(decoded_payload, example_payload)
+ def test_load_verify_valid_jwt(self):
+ example_payload = {"hello": "world"}
+ example_secret = "secret"
+ example_jwt = (
+ b"eyJhbGciOiAiSFMyNTYiLCAidHlwIjogIkpXVCJ9"
+ b".eyJoZWxsbyI6ICJ3b3JsZCJ9"
+ b".tvagLDLoaiJKxOKqpBXSEGy7SYSifZhjntgm9ctpyj8")
+ decoded_payload, signing_input, header, signature = jwt.load(example_jwt)
+ jwt.verify_signature(decoded_payload, signing_input, header, signature, example_secret)
+ self.assertEqual(decoded_payload, example_payload)
+
def test_allow_skip_verification(self):
right_secret = 'foo'
jwt_message = jwt.encode(self.payload, right_secret)
decoded_payload = jwt.decode(jwt_message, verify=False)
self.assertEqual(decoded_payload, self.payload)
+ def test_load_no_verification(self):
+ right_secret = 'foo'
+ jwt_message = jwt.encode(self.payload, right_secret)
+ decoded_payload, signing_input, header, signature = jwt.load(jwt_message)
+ self.assertEqual(decoded_payload, self.payload)
+
def test_no_secret(self):
right_secret = 'foo'
jwt_message = jwt.encode(self.payload, right_secret)
@@ -75,6 +92,14 @@ class TestJWT(unittest.TestCase):
with self.assertRaises(jwt.DecodeError):
jwt.decode(jwt_message)
+ def test_verify_signature_no_secret(self):
+ right_secret = 'foo'
+ jwt_message = jwt.encode(self.payload, right_secret)
+ decoded_payload, signing_input, header, signature = jwt.load(jwt_message)
+
+ with self.assertRaises(jwt.DecodeError):
+ jwt.verify_signature(decoded_payload, signing_input, header, signature)
+
def test_invalid_crypto_alg(self):
self.assertRaises(NotImplementedError, jwt.encode, self.payload,
"secret", "HS1024")
@@ -82,15 +107,25 @@ class TestJWT(unittest.TestCase):
def test_unicode_secret(self):
secret = u'\xc2'
jwt_message = jwt.encode(self.payload, secret)
+
decoded_payload = jwt.decode(jwt_message, secret)
self.assertEqual(decoded_payload, self.payload)
+ decoded_payload, signing_input, header, signature = jwt.load(jwt_message)
+ jwt.verify_signature(decoded_payload, signing_input, header, signature, secret)
+ self.assertEqual(decoded_payload, self.payload)
+
def test_nonascii_secret(self):
secret = '\xc2' # char value that ascii codec cannot decode
jwt_message = jwt.encode(self.payload, secret)
+
decoded_payload = jwt.decode(jwt_message, secret)
self.assertEqual(decoded_payload, self.payload)
+ decoded_payload, signing_input, header, signature = jwt.load(jwt_message)
+ jwt.verify_signature(decoded_payload, signing_input, header, signature, secret)
+ self.assertEqual(decoded_payload, self.payload)
+
def test_decode_unicode_value(self):
example_payload = {"hello": "world"}
example_secret = "secret"
@@ -100,6 +135,8 @@ class TestJWT(unittest.TestCase):
".tvagLDLoaiJKxOKqpBXSEGy7SYSifZhjntgm9ctpyj8")
decoded_payload = jwt.decode(example_jwt, example_secret)
self.assertEqual(decoded_payload, example_payload)
+ decoded_payload, signing_input, header, signature = jwt.load(example_jwt)
+ self.assertEqual(decoded_payload, example_payload)
def test_decode_invalid_header_padding(self):
example_jwt = (
@@ -108,6 +145,8 @@ class TestJWT(unittest.TestCase):
".tvagLDLoaiJKxOKqpBXSEGy7SYSifZhjntgm9ctpyj8")
example_secret = "secret"
with self.assertRaises(jwt.DecodeError):
+ jwt.load(example_jwt)
+ with self.assertRaises(jwt.DecodeError):
jwt.decode(example_jwt, example_secret)
def test_decode_invalid_header_string(self):
@@ -117,6 +156,8 @@ class TestJWT(unittest.TestCase):
".tvagLDLoaiJKxOKqpBXSEGy7SYSifZhjntgm9ctpyj8")
example_secret = "secret"
with self.assertRaisesRegexp(jwt.DecodeError, "Invalid header string"):
+ jwt.load(example_jwt)
+ with self.assertRaisesRegexp(jwt.DecodeError, "Invalid header string"):
jwt.decode(example_jwt, example_secret)
def test_decode_invalid_payload_padding(self):
@@ -126,6 +167,8 @@ class TestJWT(unittest.TestCase):
".tvagLDLoaiJKxOKqpBXSEGy7SYSifZhjntgm9ctpyj8")
example_secret = "secret"
with self.assertRaises(jwt.DecodeError):
+ jwt.load(example_jwt)
+ with self.assertRaises(jwt.DecodeError):
jwt.decode(example_jwt, example_secret)
def test_decode_invalid_payload_string(self):
@@ -136,6 +179,9 @@ class TestJWT(unittest.TestCase):
example_secret = "secret"
with self.assertRaisesRegexp(jwt.DecodeError,
"Invalid payload string"):
+ jwt.load(example_jwt)
+ with self.assertRaisesRegexp(jwt.DecodeError,
+ "Invalid payload string"):
jwt.decode(example_jwt, example_secret)
def test_decode_invalid_crypto_padding(self):
@@ -145,33 +191,51 @@ class TestJWT(unittest.TestCase):
".aatvagLDLoaiJKxOKqpBXSEGy7SYSifZhjntgm9ctpyj8")
example_secret = "secret"
with self.assertRaises(jwt.DecodeError):
+ jwt.load(example_jwt)
+ with self.assertRaises(jwt.DecodeError):
jwt.decode(example_jwt, example_secret)
def test_decode_with_expiration(self):
self.payload['exp'] = utc_timestamp() - 1
secret = 'secret'
jwt_message = jwt.encode(self.payload, secret)
+
with self.assertRaises(jwt.ExpiredSignature):
jwt.decode(jwt_message, secret)
+ decoded_payload, signing_input, header, signature = jwt.load(jwt_message)
+ with self.assertRaises(jwt.ExpiredSignature):
+ jwt.verify_signature(decoded_payload, signing_input, header, signature, secret)
+
def test_decode_skip_expiration_verification(self):
self.payload['exp'] = time.time() - 1
secret = 'secret'
jwt_message = jwt.encode(self.payload, secret)
+
jwt.decode(jwt_message, secret, verify_expiration=False)
+ decoded_payload, signing_input, header, signature = jwt.load(jwt_message)
+ jwt.verify_signature(decoded_payload, signing_input, header, signature, secret, verify_expiration=False)
+
def test_decode_with_expiration_with_leeway(self):
self.payload['exp'] = utc_timestamp() - 2
secret = 'secret'
jwt_message = jwt.encode(self.payload, secret)
+ decoded_payload, signing_input, header, signature = jwt.load(jwt_message)
+
# With 3 seconds leeway, should be ok
jwt.decode(jwt_message, secret, leeway=3)
- # With 1 secondes, should fail
+ jwt.verify_signature(decoded_payload, signing_input, header, signature, secret, leeway=3)
+
+ # With 1 second, should fail
with self.assertRaises(jwt.ExpiredSignature):
jwt.decode(jwt_message, secret, leeway=1)
+ with self.assertRaises(jwt.ExpiredSignature):
+ jwt.verify_signature(decoded_payload, signing_input, header, signature, secret, leeway=1)
+
def test_encode_decode_with_rsa_sha256(self):
try:
from Crypto.PublicKey import RSA
@@ -183,6 +247,9 @@ class TestJWT(unittest.TestCase):
with open('tests/testkey.pub','r') as rsa_pub_file:
pub_rsakey = RSA.importKey(rsa_pub_file.read())
assert jwt.decode(jwt_message, pub_rsakey)
+
+ load_output = jwt.load(jwt_message)
+ jwt.verify_signature(key=pub_rsakey, *load_output)
except ImportError:
pass
@@ -197,6 +264,9 @@ class TestJWT(unittest.TestCase):
with open('tests/testkey.pub','r') as rsa_pub_file:
pub_rsakey = RSA.importKey(rsa_pub_file.read())
assert jwt.decode(jwt_message, pub_rsakey)
+
+ load_output = jwt.load(jwt_message)
+ jwt.verify_signature(key=pub_rsakey, *load_output)
except ImportError:
pass
@@ -211,6 +281,9 @@ class TestJWT(unittest.TestCase):
with open('tests/testkey.pub','r') as rsa_pub_file:
pub_rsakey = RSA.importKey(rsa_pub_file.read())
assert jwt.decode(jwt_message, pub_rsakey)
+
+ load_output = jwt.load(jwt_message)
+ jwt.verify_signature(key=pub_rsakey, *load_output)
except ImportError:
pass