diff options
author | Dana Powers <dana.powers@gmail.com> | 2019-12-29 15:40:28 -0800 |
---|---|---|
committer | Dana Powers <dana.powers@gmail.com> | 2019-12-29 15:45:02 -0800 |
commit | e3362aca8c12a07ebe88575b073c91475585f21d (patch) | |
tree | 7705b32f02f284c40853483adc18e79fe14d49a0 | |
parent | ee1c4a42ef3c7f0aa7c98f0c48b6ab0ae76d77da (diff) | |
download | kafka-python-e3362aca8c12a07ebe88575b073c91475585f21d.tar.gz |
Style updates to scram sasl support
-rw-r--r-- | kafka/conn.py | 83 | ||||
-rw-r--r-- | kafka/scram.py | 82 | ||||
-rw-r--r-- | test/fixtures.py | 10 |
3 files changed, 93 insertions, 82 deletions
diff --git a/kafka/conn.py b/kafka/conn.py index e4938c7..dfb8d78 100644 --- a/kafka/conn.py +++ b/kafka/conn.py @@ -1,16 +1,11 @@ from __future__ import absolute_import, division -import base64 import copy import errno -import hashlib -import hmac import io import logging from random import shuffle, uniform -from uuid import uuid4 - # selectors in stdlib as of py3.4 try: import selectors # pylint: disable=import-error @@ -34,6 +29,7 @@ from kafka.protocol.commit import OffsetFetchRequest from kafka.protocol.metadata import MetadataRequest from kafka.protocol.parser import KafkaProtocol from kafka.protocol.types import Int32, Int8 +from kafka.scram import ScramClient from kafka.version import __version__ @@ -42,12 +38,6 @@ if six.PY2: TimeoutError = socket.error BlockingIOError = Exception - def xor_bytes(left, right): - return bytearray(ord(lb) ^ ord(rb) for lb, rb in zip(left, right)) -else: - def xor_bytes(left, right): - return bytes(lb ^ rb for lb, rb in zip(left, right)) - log = logging.getLogger(__name__) DEFAULT_KAFKA_PORT = 9092 @@ -107,69 +97,6 @@ class ConnectionStates(object): AUTHENTICATING = '<authenticating>' -class ScramClient: - MECHANISMS = { - 'SCRAM-SHA-256': hashlib.sha256, - 'SCRAM-SHA-512': hashlib.sha512 - } - - def __init__(self, user, password, mechanism): - self.nonce = str(uuid4()).replace('-', '') - self.auth_message = '' - self.salted_password = None - self.user = user - self.password = password.encode() - self.hashfunc = self.MECHANISMS[mechanism] - self.hashname = ''.join(mechanism.lower().split('-')[1:3]) - self.stored_key = None - self.client_key = None - self.client_signature = None - self.client_proof = None - self.server_key = None - self.server_signature = None - - def first_message(self): - client_first_bare = 'n={},r={}'.format(self.user, self.nonce) - self.auth_message += client_first_bare - return 'n,,' + client_first_bare - - def process_server_first_message(self, server_first_message): - self.auth_message += ',' + server_first_message - params = dict(pair.split('=', 1) for pair in server_first_message.split(',')) - server_nonce = params['r'] - if not server_nonce.startswith(self.nonce): - raise ValueError("Server nonce, did not start with client nonce!") - self.nonce = server_nonce - self.auth_message += ',c=biws,r=' + self.nonce - - salt = base64.b64decode(params['s'].encode()) - iterations = int(params['i']) - self.create_salted_password(salt, iterations) - - self.client_key = self.hmac(self.salted_password, b'Client Key') - self.stored_key = self.hashfunc(self.client_key).digest() - self.client_signature = self.hmac(self.stored_key, self.auth_message.encode()) - self.client_proof = xor_bytes(self.client_key, self.client_signature) - self.server_key = self.hmac(self.salted_password, b'Server Key') - self.server_signature = self.hmac(self.server_key, self.auth_message.encode()) - - def hmac(self, key, msg): - return hmac.new(key, msg, digestmod=self.hashfunc).digest() - - def create_salted_password(self, salt, iterations): - self.salted_password = hashlib.pbkdf2_hmac( - self.hashname, self.password, salt, iterations - ) - - def final_message(self): - client_final_no_proof = 'c=biws,r=' + self.nonce - return 'c=biws,r={},p={}'.format(self.nonce, base64.b64encode(self.client_proof).decode()) - - def process_server_final_message(self, server_final_message): - params = dict(pair.split('=', 1) for pair in server_final_message.split(',')) - if self.server_signature != base64.b64decode(params['v'].encode()): - raise ValueError("Server sent wrong signature!") - class BrokerConnection(object): """Initialize a Kafka broker connection @@ -747,20 +674,20 @@ class BrokerConnection(object): close = False else: try: - client_first = scram_client.first_message().encode() + client_first = scram_client.first_message().encode('utf-8') size = Int32.encode(len(client_first)) self._send_bytes_blocking(size + client_first) (data_len,) = struct.unpack('>i', self._recv_bytes_blocking(4)) - server_first = self._recv_bytes_blocking(data_len).decode() + server_first = self._recv_bytes_blocking(data_len).decode('utf-8') scram_client.process_server_first_message(server_first) - client_final = scram_client.final_message().encode() + client_final = scram_client.final_message().encode('utf-8') size = Int32.encode(len(client_final)) self._send_bytes_blocking(size + client_final) (data_len,) = struct.unpack('>i', self._recv_bytes_blocking(4)) - server_final = self._recv_bytes_blocking(data_len).decode() + server_final = self._recv_bytes_blocking(data_len).decode('utf-8') scram_client.process_server_final_message(server_final) except (ConnectionError, TimeoutError) as e: diff --git a/kafka/scram.py b/kafka/scram.py new file mode 100644 index 0000000..684925c --- /dev/null +++ b/kafka/scram.py @@ -0,0 +1,82 @@ +from __future__ import absolute_import + +import base64 +import hashlib +import hmac +import uuid + +from kafka.vendor import six + + +if six.PY2: + def xor_bytes(left, right): + return bytearray(ord(lb) ^ ord(rb) for lb, rb in zip(left, right)) +else: + def xor_bytes(left, right): + return bytes(lb ^ rb for lb, rb in zip(left, right)) + + +class ScramClient: + MECHANISMS = { + 'SCRAM-SHA-256': hashlib.sha256, + 'SCRAM-SHA-512': hashlib.sha512 + } + + def __init__(self, user, password, mechanism): + self.nonce = str(uuid.uuid4()).replace('-', '') + self.auth_message = '' + self.salted_password = None + self.user = user + self.password = password.encode('utf-8') + self.hashfunc = self.MECHANISMS[mechanism] + self.hashname = ''.join(mechanism.lower().split('-')[1:3]) + self.stored_key = None + self.client_key = None + self.client_signature = None + self.client_proof = None + self.server_key = None + self.server_signature = None + + def first_message(self): + client_first_bare = 'n={},r={}'.format(self.user, self.nonce) + self.auth_message += client_first_bare + return 'n,,' + client_first_bare + + def process_server_first_message(self, server_first_message): + self.auth_message += ',' + server_first_message + params = dict(pair.split('=', 1) for pair in server_first_message.split(',')) + server_nonce = params['r'] + if not server_nonce.startswith(self.nonce): + raise ValueError("Server nonce, did not start with client nonce!") + self.nonce = server_nonce + self.auth_message += ',c=biws,r=' + self.nonce + + salt = base64.b64decode(params['s'].encode('utf-8')) + iterations = int(params['i']) + self.create_salted_password(salt, iterations) + + self.client_key = self.hmac(self.salted_password, b'Client Key') + self.stored_key = self.hashfunc(self.client_key).digest() + self.client_signature = self.hmac(self.stored_key, self.auth_message.encode('utf-8')) + self.client_proof = xor_bytes(self.client_key, self.client_signature) + self.server_key = self.hmac(self.salted_password, b'Server Key') + self.server_signature = self.hmac(self.server_key, self.auth_message.encode('utf-8')) + + def hmac(self, key, msg): + return hmac.new(key, msg, digestmod=self.hashfunc).digest() + + def create_salted_password(self, salt, iterations): + self.salted_password = hashlib.pbkdf2_hmac( + self.hashname, self.password, salt, iterations + ) + + def final_message(self): + client_final_no_proof = 'c=biws,r=' + self.nonce + return 'c=biws,r={},p={}'.format(self.nonce, base64.b64encode(self.client_proof).decode('utf-8')) + + def process_server_final_message(self, server_final_message): + params = dict(pair.split('=', 1) for pair in server_final_message.split(',')) + if self.server_signature != base64.b64decode(params['v'].encode('utf-8')): + raise ValueError("Server sent wrong signature!") + + diff --git a/test/fixtures.py b/test/fixtures.py index 78cdc5c..26fb5e8 100644 --- a/test/fixtures.py +++ b/test/fixtures.py @@ -318,8 +318,10 @@ class KafkaFixture(Fixture): if not self.sasl_enabled: return '' - sasl_config = "sasl.enabled.mechanisms={mechanism}\n" - sasl_config += "sasl.mechanism.inter.broker.protocol={mechanism}\n" + sasl_config = ( + 'sasl.enabled.mechanisms={mechanism}\n' + 'sasl.mechanism.inter.broker.protocol={mechanism}\n' + ) return sasl_config.format(mechanism=self.sasl_mechanism) def _jaas_config(self): @@ -328,12 +330,12 @@ class KafkaFixture(Fixture): elif self.sasl_mechanism == 'PLAIN': jaas_config = ( - "org.apache.kafka.common.security.plain.PlainLoginModule required\n" + 'org.apache.kafka.common.security.plain.PlainLoginModule required\n' ' username="{user}" password="{password}" user_{user}="{password}";\n' ) elif self.sasl_mechanism in ("SCRAM-SHA-256", "SCRAM-SHA-512"): jaas_config = ( - "org.apache.kafka.common.security.scram.ScramLoginModule required\n" + 'org.apache.kafka.common.security.scram.ScramLoginModule required\n' ' username="{user}" password="{password}";\n' ) else: |