summaryrefslogtreecommitdiff
path: root/kafka/conn.py
diff options
context:
space:
mode:
authorSwen Wenzel <5111028+swenzel@users.noreply.github.com>2019-12-30 00:12:30 +0100
committerDana Powers <dana.powers@gmail.com>2019-12-29 15:12:30 -0800
commitee1c4a42ef3c7f0aa7c98f0c48b6ab0ae76d77da (patch)
treee536d4854bf88b0d24b7b2a2f5ee3d731eda9716 /kafka/conn.py
parent31f846c782b9dc6f2107340d269a7558e99bdfe2 (diff)
downloadkafka-python-ee1c4a42ef3c7f0aa7c98f0c48b6ab0ae76d77da.tar.gz
Enable SCRAM-SHA-256 and SCRAM-SHA-512 for sasl (#1918)
Diffstat (limited to 'kafka/conn.py')
-rw-r--r--kafka/conn.py147
1 files changed, 136 insertions, 11 deletions
diff --git a/kafka/conn.py b/kafka/conn.py
index d4c5464..e4938c7 100644
--- a/kafka/conn.py
+++ b/kafka/conn.py
@@ -1,12 +1,16 @@
from __future__ import absolute_import, division
-import collections
+import base64
import copy
import errno
+import hashlib
+import hmac
import io
import logging
from random import shuffle, uniform
+from uuid import uuid4
+
# selectors in stdlib as of py3.4
try:
import selectors # pylint: disable=import-error
@@ -16,7 +20,6 @@ except ImportError:
import socket
import struct
-import sys
import threading
import time
@@ -39,6 +42,12 @@ if six.PY2:
TimeoutError = socket.error
BlockingIOError = Exception
+ def xor_bytes(left, right):
+ return bytearray(ord(lb) ^ ord(rb) for lb, rb in zip(left, right))
+else:
+ def xor_bytes(left, right):
+ return bytes(lb ^ rb for lb, rb in zip(left, right))
+
log = logging.getLogger(__name__)
DEFAULT_KAFKA_PORT = 9092
@@ -98,6 +107,69 @@ class ConnectionStates(object):
AUTHENTICATING = '<authenticating>'
+class ScramClient:
+ MECHANISMS = {
+ 'SCRAM-SHA-256': hashlib.sha256,
+ 'SCRAM-SHA-512': hashlib.sha512
+ }
+
+ def __init__(self, user, password, mechanism):
+ self.nonce = str(uuid4()).replace('-', '')
+ self.auth_message = ''
+ self.salted_password = None
+ self.user = user
+ self.password = password.encode()
+ self.hashfunc = self.MECHANISMS[mechanism]
+ self.hashname = ''.join(mechanism.lower().split('-')[1:3])
+ self.stored_key = None
+ self.client_key = None
+ self.client_signature = None
+ self.client_proof = None
+ self.server_key = None
+ self.server_signature = None
+
+ def first_message(self):
+ client_first_bare = 'n={},r={}'.format(self.user, self.nonce)
+ self.auth_message += client_first_bare
+ return 'n,,' + client_first_bare
+
+ def process_server_first_message(self, server_first_message):
+ self.auth_message += ',' + server_first_message
+ params = dict(pair.split('=', 1) for pair in server_first_message.split(','))
+ server_nonce = params['r']
+ if not server_nonce.startswith(self.nonce):
+ raise ValueError("Server nonce, did not start with client nonce!")
+ self.nonce = server_nonce
+ self.auth_message += ',c=biws,r=' + self.nonce
+
+ salt = base64.b64decode(params['s'].encode())
+ iterations = int(params['i'])
+ self.create_salted_password(salt, iterations)
+
+ self.client_key = self.hmac(self.salted_password, b'Client Key')
+ self.stored_key = self.hashfunc(self.client_key).digest()
+ self.client_signature = self.hmac(self.stored_key, self.auth_message.encode())
+ self.client_proof = xor_bytes(self.client_key, self.client_signature)
+ self.server_key = self.hmac(self.salted_password, b'Server Key')
+ self.server_signature = self.hmac(self.server_key, self.auth_message.encode())
+
+ def hmac(self, key, msg):
+ return hmac.new(key, msg, digestmod=self.hashfunc).digest()
+
+ def create_salted_password(self, salt, iterations):
+ self.salted_password = hashlib.pbkdf2_hmac(
+ self.hashname, self.password, salt, iterations
+ )
+
+ def final_message(self):
+ client_final_no_proof = 'c=biws,r=' + self.nonce
+ return 'c=biws,r={},p={}'.format(self.nonce, base64.b64encode(self.client_proof).decode())
+
+ def process_server_final_message(self, server_final_message):
+ params = dict(pair.split('=', 1) for pair in server_final_message.split(','))
+ if self.server_signature != base64.b64decode(params['v'].encode()):
+ raise ValueError("Server sent wrong signature!")
+
class BrokerConnection(object):
"""Initialize a Kafka broker connection
@@ -178,11 +250,11 @@ class BrokerConnection(object):
metric_group_prefix (str): Prefix for metric names. Default: ''
sasl_mechanism (str): Authentication mechanism when security_protocol
is configured for SASL_PLAINTEXT or SASL_SSL. Valid values are:
- PLAIN, GSSAPI, OAUTHBEARER.
- sasl_plain_username (str): username for sasl PLAIN authentication.
- Required if sasl_mechanism is PLAIN.
- sasl_plain_password (str): password for sasl PLAIN authentication.
- Required if sasl_mechanism is PLAIN.
+ PLAIN, GSSAPI, OAUTHBEARER, SCRAM-SHA-256, SCRAM-SHA-512.
+ sasl_plain_username (str): username for sasl PLAIN and SCRAM authentication.
+ Required if sasl_mechanism is PLAIN or one of the SCRAM mechanisms.
+ sasl_plain_password (str): password for sasl PLAIN and SCRAM authentication.
+ Required if sasl_mechanism is PLAIN or one of the SCRAM mechanisms.
sasl_kerberos_service_name (str): Service name to include in GSSAPI
sasl mechanism handshake. Default: 'kafka'
sasl_kerberos_domain_name (str): kerberos domain name to use in GSSAPI
@@ -225,7 +297,7 @@ class BrokerConnection(object):
'sasl_oauth_token_provider': None
}
SECURITY_PROTOCOLS = ('PLAINTEXT', 'SSL', 'SASL_PLAINTEXT', 'SASL_SSL')
- SASL_MECHANISMS = ('PLAIN', 'GSSAPI', 'OAUTHBEARER')
+ SASL_MECHANISMS = ('PLAIN', 'GSSAPI', 'OAUTHBEARER', "SCRAM-SHA-256", "SCRAM-SHA-512")
def __init__(self, host, port, afi, **configs):
self.host = host
@@ -260,9 +332,13 @@ class BrokerConnection(object):
if self.config['security_protocol'] in ('SASL_PLAINTEXT', 'SASL_SSL'):
assert self.config['sasl_mechanism'] in self.SASL_MECHANISMS, (
'sasl_mechanism must be in ' + ', '.join(self.SASL_MECHANISMS))
- if self.config['sasl_mechanism'] == 'PLAIN':
- assert self.config['sasl_plain_username'] is not None, 'sasl_plain_username required for PLAIN sasl'
- assert self.config['sasl_plain_password'] is not None, 'sasl_plain_password required for PLAIN sasl'
+ if self.config['sasl_mechanism'] in ('PLAIN', 'SCRAM-SHA-256', 'SCRAM-SHA-512'):
+ assert self.config['sasl_plain_username'] is not None, (
+ 'sasl_plain_username required for PLAIN or SCRAM sasl'
+ )
+ assert self.config['sasl_plain_password'] is not None, (
+ 'sasl_plain_password required for PLAIN or SCRAM sasl'
+ )
if self.config['sasl_mechanism'] == 'GSSAPI':
assert gssapi is not None, 'GSSAPI lib not available'
assert self.config['sasl_kerberos_service_name'] is not None, 'sasl_kerberos_service_name required for GSSAPI sasl'
@@ -553,6 +629,8 @@ class BrokerConnection(object):
return self._try_authenticate_gssapi(future)
elif self.config['sasl_mechanism'] == 'OAUTHBEARER':
return self._try_authenticate_oauth(future)
+ elif self.config['sasl_mechanism'].startswith("SCRAM-SHA-"):
+ return self._try_authenticate_scram(future)
else:
return future.failure(
Errors.UnsupportedSaslMechanismError(
@@ -653,6 +731,53 @@ class BrokerConnection(object):
log.info('%s: Authenticated as %s via PLAIN', self, self.config['sasl_plain_username'])
return future.success(True)
+ def _try_authenticate_scram(self, future):
+ if self.config['security_protocol'] == 'SASL_PLAINTEXT':
+ log.warning('%s: Exchanging credentials in the clear', self)
+
+ scram_client = ScramClient(
+ self.config['sasl_plain_username'], self.config['sasl_plain_password'], self.config['sasl_mechanism']
+ )
+
+ err = None
+ close = False
+ with self._lock:
+ if not self._can_send_recv():
+ err = Errors.NodeNotReadyError(str(self))
+ close = False
+ else:
+ try:
+ client_first = scram_client.first_message().encode()
+ size = Int32.encode(len(client_first))
+ self._send_bytes_blocking(size + client_first)
+
+ (data_len,) = struct.unpack('>i', self._recv_bytes_blocking(4))
+ server_first = self._recv_bytes_blocking(data_len).decode()
+ scram_client.process_server_first_message(server_first)
+
+ client_final = scram_client.final_message().encode()
+ size = Int32.encode(len(client_final))
+ self._send_bytes_blocking(size + client_final)
+
+ (data_len,) = struct.unpack('>i', self._recv_bytes_blocking(4))
+ server_final = self._recv_bytes_blocking(data_len).decode()
+ scram_client.process_server_final_message(server_final)
+
+ except (ConnectionError, TimeoutError) as e:
+ log.exception("%s: Error receiving reply from server", self)
+ err = Errors.KafkaConnectionError("%s: %s" % (self, e))
+ close = True
+
+ if err is not None:
+ if close:
+ self.close(error=err)
+ return future.failure(err)
+
+ log.info(
+ '%s: Authenticated as %s via %s', self, self.config['sasl_plain_username'], self.config['sasl_mechanism']
+ )
+ return future.success(True)
+
def _try_authenticate_gssapi(self, future):
kerberos_damin_name = self.config['sasl_kerberos_domain_name'] or self.host
auth_id = self.config['sasl_kerberos_service_name'] + '@' + kerberos_damin_name