summaryrefslogtreecommitdiff
path: root/jwt/api.py
blob: e509e2385c733b8f414cb7d344b668462805266f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
import binascii
import json

from calendar import timegm
from datetime import datetime, timedelta

try:
    from collections.abc import Mapping
except ImportError:  # Python 2: the ABCs live directly in ``collections``
    from collections import Mapping

from .compat import string_types, text_type, timedelta_total_seconds
from .exceptions import (
    DecodeError, ExpiredSignatureError,
    InvalidAudienceError, InvalidIssuerError
)
from .utils import base64url_decode, base64url_encode


# Registry of algorithm handlers, keyed by the name that appears in a
# token's "alg" header field.
_algorithms = {}


def register_algorithm(alg_id, alg_obj):
    """
    Add ``alg_obj`` to the registry under the identifier ``alg_id``.

    Raises ValueError if ``alg_id`` already has a handler, and TypeError if
    ``alg_obj`` is not an ``Algorithm`` instance. The duplicate check runs
    first.
    """
    already_registered = alg_id in _algorithms
    if already_registered:
        raise ValueError('Algorithm already has a handler.')

    is_algorithm = isinstance(alg_obj, Algorithm)
    if not is_algorithm:
        raise TypeError('Object is not of type `Algorithm`')

    _algorithms.update({alg_id: alg_obj})

# Imported here rather than at the top of the module, i.e. only after
# register_algorithm() and the _algorithms registry exist. register_algorithm()
# resolves ``Algorithm`` from this import when it is called.
# NOTE(review): presumably this placement avoids a circular import
# (jwt.algorithms likely imports from this module) — confirm before moving it.
from jwt.algorithms import Algorithm, _register_default_algorithms  # NOQA
# Populate the registry at import time. Presumably this calls
# register_algorithm() for each built-in algorithm — TODO confirm.
_register_default_algorithms()


def encode(payload, key, algorithm='HS256', headers=None, json_encoder=None):
    """
    Build a signed JSON Web Token and return it as bytes.

    :param payload: Mapping of claims. Any ``exp``, ``iat`` or ``nbf`` value
        that is a ``datetime`` is written to the token as an integer UTC
        timestamp (naive datetimes are treated as UTC). The caller's mapping
        is never modified.
    :param key: signing key, passed through the algorithm's ``prepare_key``.
    :param algorithm: name of a registered algorithm; ``None`` means
        ``'none'``.
    :param headers: optional extra header fields, merged over the default
        ``{'typ': 'JWT', 'alg': algorithm}`` header.
    :param json_encoder: optional ``json.JSONEncoder`` subclass, used for
        both the header and the payload.
    :returns: ``header.payload.signature``, each segment base64url-encoded.
    :raises TypeError: if ``payload`` is not a mapping.
    :raises NotImplementedError: if ``algorithm`` is not registered.
    """
    segments = []

    if algorithm is None:
        algorithm = 'none'

    # Check that we get a mapping
    if not isinstance(payload, Mapping):
        raise TypeError('Expecting a mapping object, as json web token only '
                        'support json objects.')

    # Header
    header = {'typ': 'JWT', 'alg': algorithm}
    if headers:
        # NOTE(review): ``headers`` can override 'alg'; the token is still
        # signed with ``algorithm`` regardless of what the header says.
        header.update(headers)

    json_header = json.dumps(
        header,
        separators=(',', ':'),
        cls=json_encoder
    ).encode('utf-8')

    segments.append(base64url_encode(json_header))

    # Payload: convert datetime values of the time-format claims to IntDate
    # (seconds since the epoch, UTC). Apply the conversions to a copy so the
    # caller's mapping is not mutated and read-only Mappings are accepted.
    converted = {}
    for time_claim in ('exp', 'iat', 'nbf'):
        value = payload.get(time_claim)
        if isinstance(value, datetime):
            converted[time_claim] = timegm(value.utctimetuple())
    if converted:
        payload = dict(payload)
        payload.update(converted)

    json_payload = json.dumps(
        payload,
        separators=(',', ':'),
        cls=json_encoder
    ).encode('utf-8')

    segments.append(base64url_encode(json_payload))

    # Segments
    signing_input = b'.'.join(segments)

    # Keep only the registry lookup inside the try. A KeyError raised by
    # prepare_key() or sign() must not be reported as an unsupported
    # algorithm.
    try:
        alg_obj = _algorithms[algorithm]
    except KeyError:
        raise NotImplementedError('Algorithm not supported')

    key = alg_obj.prepare_key(key)
    signature = alg_obj.sign(signing_input, key)

    segments.append(base64url_encode(signature))

    return b'.'.join(segments)


def decode(jwt, key='', verify=True, **kwargs):
    """
    Parse ``jwt`` and return its payload mapping.

    When ``verify`` is true (the default), the parsed token is checked by
    ``verify_signature`` against ``key``. Any extra keyword arguments are
    forwarded to it unchanged. With ``verify=False`` no signature or claim
    checks run, and the parsed payload is returned as-is.
    """
    parts = load(jwt)
    claims = parts[0]

    if not verify:
        return claims

    verify_signature(*parts, key=key, **kwargs)
    return claims


def load(jwt):
    """
    Split a compact-serialized token into its parts without verifying it.

    ``jwt`` may be text (encoded to UTF-8 first) or bytes. Returns the tuple
    ``(payload, signing_input, header, signature)``:

    - ``payload`` and ``header`` are the decoded JSON objects.
    - ``signing_input`` is the raw ``header.payload`` bytes.
    - ``signature`` is the base64url-decoded signature.

    Raises DecodeError when the token has too few segments, a segment fails
    base64url decoding, a segment is not valid UTF-8 JSON, or the header or
    payload is not a JSON object.
    """
    token = jwt.encode('utf-8') if isinstance(jwt, text_type) else jwt

    try:
        signing_input, crypto_segment = token.rsplit(b'.', 1)
        header_segment, payload_segment = signing_input.split(b'.', 1)
    except ValueError:
        raise DecodeError('Not enough segments')

    header = _load_json_segment(header_segment, 'header')
    payload = _load_json_segment(payload_segment, 'payload')

    try:
        signature = base64url_decode(crypto_segment)
    except (TypeError, binascii.Error):
        raise DecodeError('Invalid crypto padding')

    return payload, signing_input, header, signature


def _load_json_segment(segment, label):
    """
    Base64url-decode ``segment`` and parse it as a JSON object (mapping).

    ``label`` ('header' or 'payload') is interpolated into the DecodeError
    messages, e.g. 'Invalid header padding'.
    """
    try:
        raw = base64url_decode(segment)
    except (TypeError, binascii.Error):
        raise DecodeError('Invalid %s padding' % label)

    try:
        parsed = json.loads(raw.decode('utf-8'))
    except ValueError as e:
        # Also covers UnicodeDecodeError, which subclasses ValueError.
        raise DecodeError('Invalid %s string: %s' % (label, e))

    if not isinstance(parsed, Mapping):
        raise DecodeError('Invalid %s string: must be a json object' % label)

    return parsed


def verify_signature(payload, signing_input, header, signature, key='',
                     verify_expiration=True, leeway=0, audience=None,
                     issuer=None, algorithms=None):
    """
    Verify a parsed token's signature and its registered claims.

    :param payload: decoded claims mapping.
    :param signing_input: the ``header.payload`` bytes that were signed.
    :param header: decoded header mapping; its ``alg`` selects the handler.
    :param signature: raw (base64url-decoded) signature bytes.
    :param key: verification key, passed through ``prepare_key``.
    :param verify_expiration: when true, check the ``nbf`` and ``exp`` claims.
    :param leeway: clock-skew allowance in seconds (number or ``timedelta``).
    :param audience: expected audience string, or ``None``.
    :param issuer: expected ``iss`` value, or ``None`` to skip that check.
    :param algorithms: optional allowlist (iterable of names, or a single
        name) of acceptable ``alg`` values. When given, a token whose header
        names any other algorithm is rejected before verification, so the
        token itself cannot choose how it is checked. ``None`` (the default)
        keeps the previous behaviour of accepting any registered algorithm.
    :raises TypeError: if ``audience`` is neither a string nor ``None``.
    :raises DecodeError: unsupported/disallowed algorithm or bad signature.
    :raises ExpiredSignatureError: token expired or not yet valid.
    :raises InvalidAudienceError: audience claim missing, malformed, or not
        matching ``audience``.
    :raises InvalidIssuerError: ``iss`` does not match ``issuer``.
    """
    if isinstance(leeway, timedelta):
        leeway = timedelta_total_seconds(leeway)

    if not isinstance(audience, (string_types, type(None))):
        raise TypeError('audience must be a string or None')

    # Only the registry lookup belongs in the try. A KeyError raised inside
    # prepare_key() or verify() must not be reported as an unsupported
    # algorithm. TypeError covers an unhashable "alg" such as a JSON list.
    try:
        alg_name = header['alg']
        alg_obj = _algorithms[alg_name]
    except (KeyError, TypeError):
        raise DecodeError('Algorithm not supported')

    if algorithms is not None:
        if isinstance(algorithms, string_types):
            # Avoid substring matching ('HS' in 'HS256').
            algorithms = [algorithms]
        if alg_name not in algorithms:
            raise DecodeError('The specified alg value is not allowed')

    key = alg_obj.prepare_key(key)

    if not alg_obj.verify(signing_input, key, signature):
        raise DecodeError('Signature verification failed')

    if verify_expiration and ('nbf' in payload or 'exp' in payload):
        # Read the clock once so both checks use the same instant.
        utc_timestamp = timegm(datetime.utcnow().utctimetuple())

        # NOTE(review): a not-yet-valid token is reported with
        # ExpiredSignatureError here; kept for caller compatibility.
        if 'nbf' in payload and payload['nbf'] > (utc_timestamp + leeway):
            raise ExpiredSignatureError('Signature not yet valid')

        if 'exp' in payload and payload['exp'] < (utc_timestamp - leeway):
            raise ExpiredSignatureError('Signature has expired')

    if 'aud' in payload:
        audience_claims = payload['aud']
        if isinstance(audience_claims, string_types):
            audience_claims = [audience_claims]
        if not isinstance(audience_claims, list):
            raise InvalidAudienceError('Invalid claim format in token')
        if any(not isinstance(c, string_types) for c in audience_claims):
            raise InvalidAudienceError('Invalid claim format in token')
        if audience not in audience_claims:
            raise InvalidAudienceError('Invalid audience')
    elif audience is not None:
        # Application specified an audience, but it could not be
        # verified since the token does not contain a claim.
        raise InvalidAudienceError('No audience claim in token')

    if issuer is not None:
        if payload.get('iss') != issuer:
            raise InvalidIssuerError('Invalid issuer')