summaryrefslogtreecommitdiff
path: root/jwt/__init__.py
blob: 2cd16ca9d045bffe85778a79d9cef126d205bd0b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
""" JSON Web Token implementation

Minimal implementation based on this draft of the spec:
http://self-issued.info/docs/draft-jones-json-web-token-01.html
"""
import base64
import hashlib
import hmac
import numbers

from time import time
from datetime import datetime
from calendar import timegm

try:
    import json
except ImportError:
    import simplejson as json

# Public names exported by a star-import. ExpiredSignature is included because
# decode() raises it and callers need it in scope to catch it.
__all__ = ['encode', 'decode', 'DecodeError', 'ExpiredSignature']


class DecodeError(Exception):
    """Raised when a token cannot be parsed, names an unsupported
    algorithm, or fails signature verification."""


class ExpiredSignature(Exception):
    """Raised by decode() when the payload's ``exp`` claim lies further in
    the past than the permitted leeway."""


def _hmac_signer(digestmod):
    """Build a ``sign(msg, key)`` callable returning the raw HMAC digest."""
    def sign(msg, key):
        return hmac.new(key, msg, digestmod).digest()
    return sign


# Maps a JWT "alg" header value to a function that signs ``msg`` with ``key``
# and returns the raw (not yet base64url-encoded) signature.
signing_methods = {
    'HS256': _hmac_signer(hashlib.sha256),
    'HS384': _hmac_signer(hashlib.sha384),
    'HS512': _hmac_signer(hashlib.sha512),
}


def constant_time_compare(val1, val2):
    """
    Returns True if the two strings are equal, False otherwise.

    The time taken is independent of the number of characters that match.
    Delegates to the standard library's ``hmac.compare_digest`` when it is
    available and accepts the arguments. Otherwise it falls back to a manual
    XOR loop that handles both text (1-character items) and byte strings
    (which iterate as ints on Python 3). A length mismatch returns early, so
    timing can reveal the lengths but not the contents.
    """
    compare_digest = getattr(hmac, 'compare_digest', None)
    if compare_digest is not None:
        try:
            return compare_digest(val1, val2)
        except TypeError:
            # Mixed str/bytes or non-ASCII text is rejected by compare_digest;
            # use the manual comparison below instead.
            pass
    if len(val1) != len(val2):
        return False
    result = 0
    for x, y in zip(val1, val2):
        # Bytes yield ints on Python 3; text yields single characters.
        x = x if isinstance(x, int) else ord(x)
        y = y if isinstance(y, int) else ord(y)
        result |= x ^ y
    return result == 0


def base64url_decode(input):
    """Decode base64url data whose trailing ``=`` padding was stripped.

    ``input`` may be text or bytes. The missing padding is restored using a
    literal of the same type, so the concatenation works for both. Malformed
    input propagates whatever ``base64.urlsafe_b64decode`` raises.
    """
    rem = len(input) % 4
    if rem > 0:
        padding = '=' * (4 - rem)
        if isinstance(input, bytes):
            padding = padding.encode('ascii')
        input += padding
    return base64.urlsafe_b64decode(input)


def base64url_encode(input):
    """Base64url-encode ``input`` and strip the trailing ``=`` padding.

    Returns the same type as ``base64.urlsafe_b64encode`` (``str`` on
    Python 2, ``bytes`` on Python 3).
    """
    encoded = base64.urlsafe_b64encode(input)
    # '=' only ever occurs as trailing padding. A bytes literal strips it from
    # both Python 2 str and Python 3 bytes results.
    return encoded.rstrip(b'=')


def header(jwt):
    """Return the decoded JSON header of ``jwt`` without verifying the token.

    Only the text before the first '.' is examined. Raises DecodeError when
    that segment is not valid base64url-encoded JSON.
    """
    first_segment, _, _ = jwt.partition('.')
    try:
        parsed = json.loads(base64url_decode(first_segment))
    except (ValueError, TypeError):
        raise DecodeError("Invalid header encoding")
    return parsed


def encode(payload, key, algorithm='HS256'):
    """Serialize ``payload`` into a signed compact JWT string.

    Args:
        payload: JSON-serializable mapping of claims. A ``datetime`` value
            for ``exp`` is written to the token as an integer UTC timestamp.
            The caller's mapping itself is never modified.
        key: HMAC secret. ``unicode`` keys are UTF-8 encoded before signing.
        algorithm: a key of ``signing_methods`` (``'HS256'`` by default).

    Returns:
        ``header.payload.signature``, each segment base64url-encoded
        without padding.

    Raises:
        NotImplementedError: ``algorithm`` is not in ``signing_methods``.
    """
    segments = []

    # Header
    header = {"typ": "JWT", "alg": algorithm}
    segments.append(base64url_encode(json.dumps(header)))

    # Payload: convert a datetime expiry on a copy so the caller's mapping
    # is not rewritten in place.
    exp = payload.get('exp')
    if isinstance(exp, datetime):
        payload = payload.copy()
        # utctimetuple() treats a naive datetime as already being UTC.
        payload['exp'] = timegm(exp.utctimetuple())
    segments.append(base64url_encode(json.dumps(payload)))

    # Signature over "header.payload"
    signing_input = '.'.join(segments)
    try:
        sign = signing_methods[algorithm]
    except KeyError:
        raise NotImplementedError("Algorithm not supported")
    if isinstance(key, unicode):
        key = key.encode('utf-8')
    segments.append(base64url_encode(sign(signing_input, key)))
    return '.'.join(segments)


def _load_json_segment(segment, name):
    """Base64url-decode and JSON-parse one token segment.

    ``name`` ("header" or "payload") is used in the DecodeError message.
    """
    try:
        return json.loads(base64url_decode(segment))
    except TypeError:
        raise DecodeError("Invalid %s padding" % name)
    except ValueError as e:
        raise DecodeError("Invalid %s string: %s" % (name, e))


def decode(jwt, key='', verify=True, verify_expiration=True, leeway=0):
    """Parse a compact JWT and return its decoded payload.

    Args:
        jwt: token of the form ``header.payload.signature`` (passed through
            ``str()`` first).
        key: HMAC secret. ``unicode`` keys are UTF-8 encoded before verifying.
        verify: when false, skip both the signature and expiration checks.
        verify_expiration: when ``verify`` is true, reject a payload whose
            ``exp`` claim is older than now minus ``leeway``.
        leeway: seconds of clock skew tolerated by the ``exp`` check.

    Returns:
        The JSON-decoded payload.

    Raises:
        DecodeError: malformed token, missing or unsupported ``alg``, failed
            signature check, or a non-numeric ``exp`` claim.
        ExpiredSignature: the ``exp`` claim has passed.
    """
    try:
        signing_input, crypto_segment = str(jwt).rsplit('.', 1)
        header_segment, payload_segment = signing_input.split('.', 1)
    except ValueError:
        raise DecodeError("Not enough segments")

    header = _load_json_segment(header_segment, 'header')
    payload = _load_json_segment(payload_segment, 'payload')

    try:
        signature = base64url_decode(crypto_segment)
    except (TypeError, ValueError):
        # Python 2 reports bad padding as TypeError. Python 3 raises
        # binascii.Error, which is a ValueError subclass.
        raise DecodeError("Invalid crypto padding")

    if not verify:
        return payload

    try:
        sign = signing_methods[header['alg']]
    except (KeyError, TypeError):
        # Missing or unknown "alg", or a header that is not a JSON object.
        raise DecodeError("Algorithm not supported")
    if isinstance(key, unicode):
        key = key.encode('utf-8')
    if not constant_time_compare(signature, sign(signing_input, key)):
        raise DecodeError("Signature verification failed")

    if verify_expiration and isinstance(payload, dict) and 'exp' in payload:
        exp = payload['exp']
        # Reject non-numeric expiries. On Python 2 a str compares greater
        # than any number, so a string exp would otherwise never expire.
        # bool is excluded because it is an int subclass.
        if isinstance(exp, bool) or not isinstance(exp, numbers.Real):
            raise DecodeError("Expiration Time claim (exp) must be a number")
        utc_timestamp = timegm(datetime.utcnow().utctimetuple())
        if exp < (utc_timestamp - leeway):
            raise ExpiredSignature("Signature has expired")
    return payload