diff options
Diffstat (limited to 'kazoo/protocol/connection.py')
-rw-r--r-- | kazoo/protocol/connection.py | 224 |
1 files changed, 128 insertions, 96 deletions
diff --git a/kazoo/protocol/connection.py b/kazoo/protocol/connection.py index 67d5789..b4320f2 100644 --- a/kazoo/protocol/connection.py +++ b/kazoo/protocol/connection.py @@ -9,12 +9,15 @@ import socket import sys import time +import six + from kazoo.exceptions import ( AuthFailedError, ConnectionDropped, EXCEPTIONS, SessionExpiredError, - NoNodeError + NoNodeError, + SASLException, ) from kazoo.loggingsupport import BLATHER from kazoo.protocol.serialization import ( @@ -30,7 +33,7 @@ from kazoo.protocol.serialization import ( SASL, Transaction, Watch, - int_struct + int_struct, ) from kazoo.protocol.states import ( Callback, @@ -40,10 +43,12 @@ from kazoo.protocol.states import ( ) from kazoo.retry import ( ForceRetryError, - RetryFailedError + RetryFailedError, ) + try: - from puresasl.client import SASLClient + import puresasl + import puresasl.client PURESASL_AVAILABLE = True except ImportError: PURESASL_AVAILABLE = False @@ -139,7 +144,7 @@ class RWServerAvailable(Exception): class ConnectionHandler(object): """Zookeeper connection handler""" - def __init__(self, client, retry_sleeper, logger=None): + def __init__(self, client, retry_sleeper, logger=None, sasl_options=None): self.client = client self.handler = client.handler self.retry_sleeper = retry_sleeper @@ -159,10 +164,10 @@ class ConnectionHandler(object): self._xid = None self._rw_server = None self._ro_mode = False - self._ro = False self._connection_routine = None + self.sasl_options = sasl_options self.sasl_cli = None # This is instance specific to avoid odd thread bug issues in Python @@ -232,8 +237,8 @@ class ConnectionHandler(object): # have anything to select, but the wrapped object may still # have something to read as it has previously gotten enough # data from the underlying socket. - if (hasattr(self._socket, "pending") - and self._socket.pending() > 0): + if (hasattr(self._socket, "pending") and + self._socket.pending() > 0): pass else: s = self.handler.select([self._socket], [], [], timeout)[0] @@ -427,24 +432,6 @@ class ConnectionHandler(object): async_object.set(True) elif header.xid == WATCH_XID: self._read_watch_event(buffer, offset) - elif self.sasl_cli and not self.sasl_cli.complete: - # SASL authentication is not yet finished, this can only - # be a SASL packet - self.logger.log(BLATHER, 'Received SASL') - try: - challenge, _ = SASL.deserialize(buffer, offset) - except Exception: - raise ConnectionDropped('error while SASL authentication.') - response = self.sasl_cli.process(challenge) - if response: - # authentication not yet finished, answering the challenge - self._send_sasl_request(challenge=response, - timeout=client._session_timeout) - else: - # authentication is ok, state is CONNECTED or CONNECTED_RO - # remove sensible information from the object - self._set_connected_ro_or_rw(client) - self.sasl_cli.dispose() else: self.logger.log(BLATHER, 'Reading for header %r', header) @@ -522,12 +509,13 @@ class ConnectionHandler(object): host_ports = [] for host, port in self.client.hosts: try: - for rhost in socket.getaddrinfo(host.strip(), port, 0, 0, + host = host.strip() + for rhost in socket.getaddrinfo(host, port, 0, 0, socket.IPPROTO_TCP): - host_ports.append((rhost[4][0], rhost[4][1])) + host_ports.append((host, rhost[4][0], rhost[4][1])) except socket.gaierror as e: # Skip hosts that don't resolve - self.logger.warning("Cannot resolve %s: %s", host.strip(), e) + self.logger.warning("Cannot resolve %s: %s", host, e) pass if self.client.randomize_hosts: random.shuffle(host_ports) @@ -542,11 +530,11 @@ class ConnectionHandler(object): if len(host_ports) == 0: return STOP_CONNECTING - for host, port in host_ports: + for host, hostip, port in host_ports: if self.client._stopped.is_set(): status = STOP_CONNECTING break - status = self._connect_attempt(host, port, retry) + status = self._connect_attempt(host, hostip, port, retry) if status is STOP_CONNECTING: break @@ -555,7 +543,7 @@ class ConnectionHandler(object): else: raise ForceRetryError('Reconnecting') - def _connect_attempt(self, host, port, retry): + def _connect_attempt(self, host, hostip, port, retry): client = self.client KazooTimeoutError = self.handler.timeout_exception close_connection = False @@ -574,7 +562,7 @@ class ConnectionHandler(object): try: self._xid = 0 - read_timeout, connect_timeout = self._connect(host, port) + read_timeout, connect_timeout = self._connect(host, hostip, port) read_timeout = read_timeout / 1000.0 connect_timeout = connect_timeout / 1000.0 retry.reset() @@ -611,9 +599,9 @@ class ConnectionHandler(object): if client._state != KeeperState.CONNECTING: self.logger.warning("Transition to CONNECTING") client._session_callback(KeeperState.CONNECTING) - except AuthFailedError: + except AuthFailedError as err: retry.reset() - self.logger.warning('AUTH_FAILED closing') + self.logger.warning('AUTH_FAILED closing: %s', err) client._session_callback(KeeperState.AUTH_FAILED) return STOP_CONNECTING except SessionExpiredError: @@ -631,10 +619,10 @@ class ConnectionHandler(object): if self._socket is not None: self._socket.close() - def _connect(self, host, port): + def _connect(self, host, hostip, port): client = self.client - self.logger.info('Connecting to %s:%s, use_ssl: %r', - host, port, self.client.use_ssl) + self.logger.info('Connecting to %s(%s):%s, use_ssl: %r', + host, hostip, port, self.client.use_ssl) self.logger.log(BLATHER, ' Using session_id: %r session_passwd: %s', @@ -643,7 +631,7 @@ class ConnectionHandler(object): with self._socket_error_handling(): self._socket = self.handler.create_connection( - address=(host, port), + address=(hostip, port), timeout=client._session_timeout / 1000.0, use_ssl=self.client.use_ssl, keyfile=self.client.keyfile, @@ -686,68 +674,112 @@ class ConnectionHandler(object): read_timeout) if connect_result.read_only: - self._ro = True + client._session_callback(KeeperState.CONNECTED_RO) + self._ro_mode = iter(self._server_pinger()) + else: + client._session_callback(KeeperState.CONNECTED) + self._ro_mode = None + + if self.sasl_options is not None: + self._authenticate_with_sasl(host, connect_timeout / 1000.0) # Get a copy of the auth data before iterating, in case it is # changed. client_auth_data_copy = copy.copy(client.auth_data) - if client.use_sasl and self.sasl_cli is None: - if PURESASL_AVAILABLE: - for scheme, auth in client_auth_data_copy: - if scheme == 'sasl': - username, password = auth.split(":") - self.sasl_cli = SASLClient( - host=client.sasl_server_principal, - service='zookeeper', - mechanism='DIGEST-MD5', - username=username, - password=password - ) - break - - # As described in rfc - # https://tools.ietf.org/html/rfc2831#section-2.1 - # sending empty challenge - self._send_sasl_request(challenge=b'', - timeout=connect_timeout) - else: - self.logger.warn('Pure-sasl library is missing while sasl' - ' authentification is configured. Please' - ' install pure-sasl library to connect ' - 'using sasl. Now falling back ' - 'connecting WITHOUT any ' - 'authentification.') - client.use_sasl = False - self._set_connected_ro_or_rw(client) - else: - self._set_connected_ro_or_rw(client) - for scheme, auth in client_auth_data_copy: - if scheme == "digest": - ap = Auth(0, scheme, auth) - zxid = self._invoke( - connect_timeout / 1000.0, - ap, - xid=AUTH_XID - ) - if zxid: - client.last_zxid = zxid + for scheme, auth in client_auth_data_copy: + ap = Auth(0, scheme, auth) + zxid = self._invoke(connect_timeout / 1000.0, ap, xid=AUTH_XID) + if zxid: + client.last_zxid = zxid return read_timeout, connect_timeout - def _send_sasl_request(self, challenge, timeout): - """ Called when sending a SASL request, xid needs be to incremented """ - sasl_request = SASL(challenge) - self._xid = (self._xid % 2147483647) + 1 - xid = self._xid - self._submit(sasl_request, timeout / 1000.0, xid) - - def _set_connected_ro_or_rw(self, client): - """ Called to decide whether to set the KeeperState to CONNECTED_RO - or CONNECTED""" - if self._ro: - client._session_callback(KeeperState.CONNECTED_RO) - self._ro_mode = iter(self._server_pinger()) - else: - client._session_callback(KeeperState.CONNECTED) - self._ro_mode = None + def _authenticate_with_sasl(self, host, timeout): + """Establish a SASL authenticated connection to the server. + """ + if not PURESASL_AVAILABLE: + raise SASLException('Missing SASL support') + + if 'service' not in self.sasl_options: + self.sasl_options['service'] = 'zookeeper' + + # NOTE: Zookeeper hardcoded the domain for Digest authentication + # instead of using the hostname. See + # zookeeper/util/SecurityUtils.java#L74 and Server/Client + # initializations. + if self.sasl_options['mechanism'] == 'DIGEST-MD5': + host = 'zk-sasl-md5' + + sasl_cli = self.client.sasl_cli = puresasl.client.SASLClient( + host=host, + **self.sasl_options + ) + + # Inititalize the process with an empty challenge token + challenge = None + xid = 0 + + while True: + if sasl_cli.complete: + break + + try: + response = sasl_cli.process(challenge=challenge) + except puresasl.SASLError as err: + six.reraise( + SASLException, + SASLException('library error: %s' % err.message), + sys.exc_info()[2] + ) + except puresasl.SASLProtocolException as err: + six.reraise( + AuthFailedError, + AuthFailedError('protocol error: %s' % err.message), + sys.exc_info()[2] + ) + except Exception as err: + six.reraise( + AuthFailedError, + AuthFailedError('Unknown error: %s' % err), + sys.exc_info()[2] + ) + + if sasl_cli.complete and not response: + break + elif response is None: + response = b'' + + xid = (xid % 2147483647) + 1 + + request = SASL(response) + self._submit(request, timeout, xid) + + try: + header, buffer, offset = self._read_header(timeout) + except ConnectionDropped: + # Zookeeper simply drops connections with failed authentication + six.reraise( + AuthFailedError, + AuthFailedError('Connection dropped in SASL'), + sys.exc_info()[2] + ) + + if header.xid != xid: + raise RuntimeError('xids do not match, expected %r ' + 'received %r', xid, header.xid) + + if header.zxid > 0: + self.client.last_zxid = header.zxid + + if header.err: + callback_exception = EXCEPTIONS[header.err]() + self.logger.debug( + 'Received error(xid=%s) %r', xid, callback_exception) + raise callback_exception + + challenge, _ = SASL.deserialize(buffer, offset) + + # If we made it here, authentication is ok, and we are connected. + # Remove sensible information from the object. + sasl_cli.dispose() |