summaryrefslogtreecommitdiff
path: root/jwt/utils.py
blob: 0e8f210aa914d69fd642f43808128118350514d2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import base64
import binascii
import struct

try:
    from cryptography.hazmat.primitives.asymmetric.utils import (
        decode_dss_signature,
        encode_dss_signature,
    )
except ImportError:
    pass


def force_unicode(value):
    """Return *value* as text.

    ``str`` input is returned unchanged; ``bytes`` input is decoded as
    UTF-8.

    Raises:
        TypeError: if *value* is neither ``str`` nor ``bytes``.
    """
    if isinstance(value, str):
        return value
    if isinstance(value, bytes):
        return value.decode("utf-8")
    raise TypeError("Expected a string value")


def force_bytes(value):
    """Return *value* as bytes.

    ``bytes`` input is returned unchanged; ``str`` input is encoded as
    UTF-8.

    Raises:
        TypeError: if *value* is neither ``bytes`` nor ``str``.
    """
    if isinstance(value, bytes):
        return value
    if isinstance(value, str):
        return value.encode("utf-8")
    raise TypeError("Expected a string value")


def base64url_decode(input):
    """Decode URL-safe base64 data whose ``=`` padding may be missing.

    ``str`` input is first encoded to ASCII. The value is then padded
    with ``=`` up to a multiple of four characters before being handed to
    ``base64.urlsafe_b64decode``, whose result is returned.
    """
    if isinstance(input, str):
        input = input.encode("ascii")

    # Number of "=" characters needed to reach the next multiple of four.
    missing = -len(input) % 4
    if missing:
        input += b"=" * missing

    return base64.urlsafe_b64decode(input)


def base64url_encode(input):
    """Encode *input* as URL-safe base64 with the ``=`` padding removed.

    Returns the encoded value as ``bytes``.
    """
    encoded = base64.urlsafe_b64encode(input)
    # "=" only ever occurs as trailing padding in base64 output.
    return encoded.rstrip(b"=")


def to_base64url_uint(val):
    """Encode the integer *val* as unpadded URL-safe base64.

    The integer is turned into bytes by ``bytes_from_int`` and the result
    is passed to ``base64url_encode``. If ``bytes_from_int`` yields an
    empty byte string, a single zero byte is encoded instead.

    Raises:
        ValueError: if *val* is negative.
    """
    if val < 0:
        raise ValueError("Must be a positive integer")

    # An empty (falsy) byte string is replaced by one zero byte.
    payload = bytes_from_int(val) or b"\x00"
    return base64url_encode(payload)


def from_base64url_uint(val):
    """Decode an unpadded URL-safe base64 value into an integer.

    ``str`` input is encoded to ASCII, the value is decoded with
    ``base64url_decode``, and the resulting bytes are read as one
    hexadecimal number with the first byte most significant.

    Raises:
        ValueError: if the decoded data is empty, because ``int`` cannot
            parse an empty hexadecimal string.
    """
    if isinstance(val, str):
        val = val.encode("ascii")

    decoded = base64url_decode(val)
    # bytes.hex() gives two lowercase hex digits per byte, in byte order.
    return int(decoded.hex(), 16)


def merge_dict(original, updates):
    """Return a copy of *original* with *updates* applied on top.

    When *updates* is falsy (``None``, empty, ...) *original* itself is
    returned and no copy is made. Otherwise *original* is left unmodified
    and a new, updated copy is returned.

    Raises:
        TypeError: if copying or updating raises ``AttributeError`` or
            ``ValueError``.
    """
    if updates:
        try:
            combined = original.copy()
            combined.update(updates)
        except (AttributeError, ValueError) as e:
            raise TypeError("original and updates must be a dictionary: %s" % e)
        return combined

    return original


def number_to_bytes(num, num_bytes):
    """Return *num* as big-endian bytes, zero-padded to *num_bytes* bytes.

    The number is rendered as hexadecimal, left-padded with zeros to
    ``2 * num_bytes`` digits, and converted to bytes. A value that needs
    more digits than that is not truncated.
    """
    hex_digits = "%0*x" % (2 * num_bytes, num)
    return binascii.unhexlify(hex_digits.encode("ascii"))


def bytes_to_number(string):
    """Interpret the bytes in *string* as one big-endian unsigned integer.

    The bytes are converted to their hexadecimal representation, which is
    then parsed as a base-16 integer.
    """
    hex_digits = binascii.hexlify(string)
    return int(hex_digits, 16)


def bytes_from_int(val):
    """Return the shortest big-endian unsigned byte encoding of *val*.

    Zero encodes to an empty byte string, since no bytes are needed to
    represent it.

    Raises:
        ValueError: if *val* is negative.
    """
    if val < 0:
        # Right-shifting a negative int never reaches zero (-1 >> 8 == -1),
        # so a negative value has no finite unsigned length; reject it here.
        raise ValueError("Must be a non-negative integer")

    # Round the bit count up to whole bytes.
    byte_length = (val.bit_length() + 7) // 8
    return val.to_bytes(byte_length, "big", signed=False)


def der_to_raw_signature(der_sig, curve):
    """Convert *der_sig* into the concatenation of its two integers.

    ``decode_dss_signature`` splits *der_sig* into a pair of integers.
    Each is written with ``number_to_bytes`` using
    ``ceil(curve.key_size / 8)`` bytes, and the two results are joined in
    order.
    """
    component_size = (curve.key_size + 7) // 8

    r, s = decode_dss_signature(der_sig)

    return b"".join(number_to_bytes(part, component_size) for part in (r, s))


def raw_to_der_signature(raw_sig, curve):
    """Convert a fixed-width raw signature into ``encode_dss_signature`` output.

    *raw_sig* must be exactly twice ``ceil(curve.key_size / 8)`` bytes
    long. Its first and second halves are each converted to an integer
    with ``bytes_to_number`` and passed, in that order, to
    ``encode_dss_signature``.

    Raises:
        ValueError: if *raw_sig* does not have the expected length.
    """
    half = (curve.key_size + 7) // 8

    if len(raw_sig) != 2 * half:
        raise ValueError("Invalid signature")

    first_half, second_half = raw_sig[:half], raw_sig[half:]
    return encode_dss_signature(
        bytes_to_number(first_half),
        bytes_to_number(second_half),
    )