summaryrefslogtreecommitdiff
path: root/kafka/scram.py
blob: 684925caaad6423ff7c3e40842ed21f657c4877f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from __future__ import absolute_import

import base64
import hashlib
import hmac
import uuid

from kafka.vendor import six


if not six.PY2:
    def xor_bytes(left, right):
        """Return the element-wise XOR of two byte strings as ``bytes``.

        The result is as long as the shorter of the two inputs.
        """
        # Iterating over bytes on Python 3 yields ints, so XOR applies directly.
        return bytes([a ^ b for a, b in zip(left, right)])
else:
    def xor_bytes(left, right):
        """Return the element-wise XOR of two byte strings as a ``bytearray``.

        The result is as long as the shorter of the two inputs.
        """
        # Iterating over a Python 2 str yields 1-character strings, hence ord().
        result = bytearray()
        for a, b in zip(left, right):
            result.append(ord(a) ^ ord(b))
        return result


class ScramClient(object):
    """Client side of a SCRAM (RFC 5802) exchange without channel binding.

    Intended call order::

        first_message()
        process_server_first_message(server_first)
        final_message()
        process_server_final_message(server_final)

    All messages are exchanged as text strings; encoding them for the wire
    is left to the caller.
    """

    # Supported mechanism names mapped to their hash constructors.
    MECHANISMS = {
        'SCRAM-SHA-256': hashlib.sha256,
        'SCRAM-SHA-512': hashlib.sha512
    }

    def __init__(self, user, password, mechanism):
        """Set up the state for one authentication exchange.

        Arguments:
            user: user name as a text string.
            password: password as a text string; stored UTF-8 encoded.
            mechanism: one of the keys of ``MECHANISMS``.

        Raises:
            KeyError: if ``mechanism`` is not a key of ``MECHANISMS``.
        """
        # Client nonce: the 32 hex digits of a random UUID.
        self.nonce = str(uuid.uuid4()).replace('-', '')
        # Accumulates the messages that get signed by both sides.
        self.auth_message = ''
        self.salted_password = None
        self.user = user
        self.password = password.encode('utf-8')
        self.hashfunc = self.MECHANISMS[mechanism]
        # 'SCRAM-SHA-256' -> 'sha256': the hash name hashlib.pbkdf2_hmac expects.
        self.hashname = ''.join(mechanism.lower().split('-')[1:3])
        self.stored_key = None
        self.client_key = None
        self.client_signature = None
        self.client_proof = None
        self.server_key = None
        self.server_signature = None

    @staticmethod
    def _escape_username(user):
        """Encode ',' and '=' in a user name as '=2C' and '=3D' (RFC 5802)."""
        # '=' must be replaced first, otherwise the '=' introduced by '=2C'
        # would itself be escaped.
        return user.replace('=', '=3D').replace(',', '=2C')

    def first_message(self):
        """Return the client-first message and start the auth message.

        The user name is escaped so that ',' or '=' inside it cannot be
        mistaken for attribute separators by the server.
        """
        client_first_bare = 'n={},r={}'.format(
            self._escape_username(self.user), self.nonce
        )
        self.auth_message += client_first_bare
        # 'n,,' is the GS2 header: no channel binding, no authorization id.
        return 'n,,' + client_first_bare

    def process_server_first_message(self, server_first_message):
        """Consume the server-first message and derive keys and signatures.

        Reads the ``r`` (combined nonce), ``s`` (base64 salt) and ``i``
        (iteration count) attributes and computes the client proof and the
        server signature that is expected in the server-final message.

        Raises:
            ValueError: if the server nonce does not start with the client
                nonce.
            KeyError: if one of the ``r``, ``s`` or ``i`` attributes is missing.
        """
        self.auth_message += ',' + server_first_message
        params = dict(pair.split('=', 1) for pair in server_first_message.split(','))
        server_nonce = params['r']
        if not server_nonce.startswith(self.nonce):
            raise ValueError("Server nonce, did not start with client nonce!")
        self.nonce = server_nonce
        # 'biws' is base64('n,,'), i.e. the GS2 header sent in first_message().
        self.auth_message += ',c=biws,r=' + self.nonce

        salt = base64.b64decode(params['s'].encode('utf-8'))
        iterations = int(params['i'])
        self.create_salted_password(salt, iterations)

        self.client_key = self.hmac(self.salted_password, b'Client Key')
        self.stored_key = self.hashfunc(self.client_key).digest()
        self.client_signature = self.hmac(self.stored_key, self.auth_message.encode('utf-8'))
        self.client_proof = xor_bytes(self.client_key, self.client_signature)
        self.server_key = self.hmac(self.salted_password, b'Server Key')
        self.server_signature = self.hmac(self.server_key, self.auth_message.encode('utf-8'))

    def hmac(self, key, msg):
        """Return the HMAC digest of ``msg`` under ``key`` using the mechanism's hash."""
        return hmac.new(key, msg, digestmod=self.hashfunc).digest()

    def create_salted_password(self, salt, iterations):
        """Derive ``self.salted_password`` from the password with PBKDF2-HMAC."""
        self.salted_password = hashlib.pbkdf2_hmac(
            self.hashname, self.password, salt, iterations
        )

    def final_message(self):
        """Return the client-final message carrying the base64 client proof.

        Requires ``process_server_first_message`` to have populated
        ``self.client_proof``.
        """
        return 'c=biws,r={},p={}'.format(self.nonce, base64.b64encode(self.client_proof).decode('utf-8'))

    def process_server_final_message(self, server_final_message):
        """Verify the server signature contained in the server-final message.

        Raises:
            ValueError: if the message carries no ``v`` attribute (for example
                because the server answered with an ``e=<error>`` attribute
                instead), or if the signature does not match the expected one.
        """
        params = dict(pair.split('=', 1) for pair in server_final_message.split(','))
        if 'v' not in params:
            # RFC 5802: the server reports a failed exchange as 'e=<reason>'
            # in place of the verifier.
            raise ValueError(
                "Server did not send a signature: {}".format(
                    params.get('e', 'no error given')
                )
            )
        received = base64.b64decode(params['v'].encode('utf-8'))
        expected = self.server_signature
        # compare_digest avoids leaking how many leading bytes matched; the
        # None check keeps the ValueError when no signature was ever computed.
        if expected is None or not hmac.compare_digest(expected, received):
            raise ValueError("Server sent wrong signature!")