summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDana Powers <dana.powers@gmail.com>2019-12-29 15:40:28 -0800
committerDana Powers <dana.powers@gmail.com>2019-12-29 15:45:02 -0800
commite3362aca8c12a07ebe88575b073c91475585f21d (patch)
tree7705b32f02f284c40853483adc18e79fe14d49a0
parentee1c4a42ef3c7f0aa7c98f0c48b6ab0ae76d77da (diff)
downloadkafka-python-e3362aca8c12a07ebe88575b073c91475585f21d.tar.gz
Style updates to scram sasl support
-rw-r--r--kafka/conn.py83
-rw-r--r--kafka/scram.py82
-rw-r--r--test/fixtures.py10
3 files changed, 93 insertions, 82 deletions
diff --git a/kafka/conn.py b/kafka/conn.py
index e4938c7..dfb8d78 100644
--- a/kafka/conn.py
+++ b/kafka/conn.py
@@ -1,16 +1,11 @@
from __future__ import absolute_import, division
-import base64
import copy
import errno
-import hashlib
-import hmac
import io
import logging
from random import shuffle, uniform
-from uuid import uuid4
-
# selectors in stdlib as of py3.4
try:
import selectors # pylint: disable=import-error
@@ -34,6 +29,7 @@ from kafka.protocol.commit import OffsetFetchRequest
from kafka.protocol.metadata import MetadataRequest
from kafka.protocol.parser import KafkaProtocol
from kafka.protocol.types import Int32, Int8
+from kafka.scram import ScramClient
from kafka.version import __version__
@@ -42,12 +38,6 @@ if six.PY2:
TimeoutError = socket.error
BlockingIOError = Exception
- def xor_bytes(left, right):
- return bytearray(ord(lb) ^ ord(rb) for lb, rb in zip(left, right))
-else:
- def xor_bytes(left, right):
- return bytes(lb ^ rb for lb, rb in zip(left, right))
-
log = logging.getLogger(__name__)
DEFAULT_KAFKA_PORT = 9092
@@ -107,69 +97,6 @@ class ConnectionStates(object):
AUTHENTICATING = '<authenticating>'
-class ScramClient:
- MECHANISMS = {
- 'SCRAM-SHA-256': hashlib.sha256,
- 'SCRAM-SHA-512': hashlib.sha512
- }
-
- def __init__(self, user, password, mechanism):
- self.nonce = str(uuid4()).replace('-', '')
- self.auth_message = ''
- self.salted_password = None
- self.user = user
- self.password = password.encode()
- self.hashfunc = self.MECHANISMS[mechanism]
- self.hashname = ''.join(mechanism.lower().split('-')[1:3])
- self.stored_key = None
- self.client_key = None
- self.client_signature = None
- self.client_proof = None
- self.server_key = None
- self.server_signature = None
-
- def first_message(self):
- client_first_bare = 'n={},r={}'.format(self.user, self.nonce)
- self.auth_message += client_first_bare
- return 'n,,' + client_first_bare
-
- def process_server_first_message(self, server_first_message):
- self.auth_message += ',' + server_first_message
- params = dict(pair.split('=', 1) for pair in server_first_message.split(','))
- server_nonce = params['r']
- if not server_nonce.startswith(self.nonce):
- raise ValueError("Server nonce, did not start with client nonce!")
- self.nonce = server_nonce
- self.auth_message += ',c=biws,r=' + self.nonce
-
- salt = base64.b64decode(params['s'].encode())
- iterations = int(params['i'])
- self.create_salted_password(salt, iterations)
-
- self.client_key = self.hmac(self.salted_password, b'Client Key')
- self.stored_key = self.hashfunc(self.client_key).digest()
- self.client_signature = self.hmac(self.stored_key, self.auth_message.encode())
- self.client_proof = xor_bytes(self.client_key, self.client_signature)
- self.server_key = self.hmac(self.salted_password, b'Server Key')
- self.server_signature = self.hmac(self.server_key, self.auth_message.encode())
-
- def hmac(self, key, msg):
- return hmac.new(key, msg, digestmod=self.hashfunc).digest()
-
- def create_salted_password(self, salt, iterations):
- self.salted_password = hashlib.pbkdf2_hmac(
- self.hashname, self.password, salt, iterations
- )
-
- def final_message(self):
- client_final_no_proof = 'c=biws,r=' + self.nonce
- return 'c=biws,r={},p={}'.format(self.nonce, base64.b64encode(self.client_proof).decode())
-
- def process_server_final_message(self, server_final_message):
- params = dict(pair.split('=', 1) for pair in server_final_message.split(','))
- if self.server_signature != base64.b64decode(params['v'].encode()):
- raise ValueError("Server sent wrong signature!")
-
class BrokerConnection(object):
"""Initialize a Kafka broker connection
@@ -747,20 +674,20 @@ class BrokerConnection(object):
close = False
else:
try:
- client_first = scram_client.first_message().encode()
+ client_first = scram_client.first_message().encode('utf-8')
size = Int32.encode(len(client_first))
self._send_bytes_blocking(size + client_first)
(data_len,) = struct.unpack('>i', self._recv_bytes_blocking(4))
- server_first = self._recv_bytes_blocking(data_len).decode()
+ server_first = self._recv_bytes_blocking(data_len).decode('utf-8')
scram_client.process_server_first_message(server_first)
- client_final = scram_client.final_message().encode()
+ client_final = scram_client.final_message().encode('utf-8')
size = Int32.encode(len(client_final))
self._send_bytes_blocking(size + client_final)
(data_len,) = struct.unpack('>i', self._recv_bytes_blocking(4))
- server_final = self._recv_bytes_blocking(data_len).decode()
+ server_final = self._recv_bytes_blocking(data_len).decode('utf-8')
scram_client.process_server_final_message(server_final)
except (ConnectionError, TimeoutError) as e:
diff --git a/kafka/scram.py b/kafka/scram.py
new file mode 100644
index 0000000..684925c
--- /dev/null
+++ b/kafka/scram.py
@@ -0,0 +1,82 @@
+from __future__ import absolute_import
+
+import base64
+import hashlib
+import hmac
+import uuid
+
+from kafka.vendor import six
+
+
+if six.PY2:
+ def xor_bytes(left, right):
+ return bytearray(ord(lb) ^ ord(rb) for lb, rb in zip(left, right))
+else:
+ def xor_bytes(left, right):
+ return bytes(lb ^ rb for lb, rb in zip(left, right))
+
+
+class ScramClient:
+ MECHANISMS = {
+ 'SCRAM-SHA-256': hashlib.sha256,
+ 'SCRAM-SHA-512': hashlib.sha512
+ }
+
+ def __init__(self, user, password, mechanism):
+ self.nonce = str(uuid.uuid4()).replace('-', '')
+ self.auth_message = ''
+ self.salted_password = None
+ self.user = user
+ self.password = password.encode('utf-8')
+ self.hashfunc = self.MECHANISMS[mechanism]
+ self.hashname = ''.join(mechanism.lower().split('-')[1:3])
+ self.stored_key = None
+ self.client_key = None
+ self.client_signature = None
+ self.client_proof = None
+ self.server_key = None
+ self.server_signature = None
+
+ def first_message(self):
+ client_first_bare = 'n={},r={}'.format(self.user, self.nonce)
+ self.auth_message += client_first_bare
+ return 'n,,' + client_first_bare
+
+ def process_server_first_message(self, server_first_message):
+ self.auth_message += ',' + server_first_message
+ params = dict(pair.split('=', 1) for pair in server_first_message.split(','))
+ server_nonce = params['r']
+ if not server_nonce.startswith(self.nonce):
+ raise ValueError("Server nonce, did not start with client nonce!")
+ self.nonce = server_nonce
+ self.auth_message += ',c=biws,r=' + self.nonce
+
+ salt = base64.b64decode(params['s'].encode('utf-8'))
+ iterations = int(params['i'])
+ self.create_salted_password(salt, iterations)
+
+ self.client_key = self.hmac(self.salted_password, b'Client Key')
+ self.stored_key = self.hashfunc(self.client_key).digest()
+ self.client_signature = self.hmac(self.stored_key, self.auth_message.encode('utf-8'))
+ self.client_proof = xor_bytes(self.client_key, self.client_signature)
+ self.server_key = self.hmac(self.salted_password, b'Server Key')
+ self.server_signature = self.hmac(self.server_key, self.auth_message.encode('utf-8'))
+
+ def hmac(self, key, msg):
+ return hmac.new(key, msg, digestmod=self.hashfunc).digest()
+
+ def create_salted_password(self, salt, iterations):
+ self.salted_password = hashlib.pbkdf2_hmac(
+ self.hashname, self.password, salt, iterations
+ )
+
+ def final_message(self):
+ client_final_no_proof = 'c=biws,r=' + self.nonce
+ return 'c=biws,r={},p={}'.format(self.nonce, base64.b64encode(self.client_proof).decode('utf-8'))
+
+ def process_server_final_message(self, server_final_message):
+ params = dict(pair.split('=', 1) for pair in server_final_message.split(','))
+ if self.server_signature != base64.b64decode(params['v'].encode('utf-8')):
+ raise ValueError("Server sent wrong signature!")
+
+
diff --git a/test/fixtures.py b/test/fixtures.py
index 78cdc5c..26fb5e8 100644
--- a/test/fixtures.py
+++ b/test/fixtures.py
@@ -318,8 +318,10 @@ class KafkaFixture(Fixture):
if not self.sasl_enabled:
return ''
- sasl_config = "sasl.enabled.mechanisms={mechanism}\n"
- sasl_config += "sasl.mechanism.inter.broker.protocol={mechanism}\n"
+ sasl_config = (
+ 'sasl.enabled.mechanisms={mechanism}\n'
+ 'sasl.mechanism.inter.broker.protocol={mechanism}\n'
+ )
return sasl_config.format(mechanism=self.sasl_mechanism)
def _jaas_config(self):
@@ -328,12 +330,12 @@ class KafkaFixture(Fixture):
elif self.sasl_mechanism == 'PLAIN':
jaas_config = (
- "org.apache.kafka.common.security.plain.PlainLoginModule required\n"
+ 'org.apache.kafka.common.security.plain.PlainLoginModule required\n'
' username="{user}" password="{password}" user_{user}="{password}";\n'
)
elif self.sasl_mechanism in ("SCRAM-SHA-256", "SCRAM-SHA-512"):
jaas_config = (
- "org.apache.kafka.common.security.scram.ScramLoginModule required\n"
+ 'org.apache.kafka.common.security.scram.ScramLoginModule required\n'
' username="{user}" password="{password}";\n'
)
else: