summaryrefslogtreecommitdiff
path: root/kazoo/protocol/connection.py
diff options
context:
space:
mode:
Diffstat (limited to 'kazoo/protocol/connection.py')
-rw-r--r--kazoo/protocol/connection.py224
1 files changed, 128 insertions, 96 deletions
diff --git a/kazoo/protocol/connection.py b/kazoo/protocol/connection.py
index 67d5789..b4320f2 100644
--- a/kazoo/protocol/connection.py
+++ b/kazoo/protocol/connection.py
@@ -9,12 +9,15 @@ import socket
import sys
import time
+import six
+
from kazoo.exceptions import (
AuthFailedError,
ConnectionDropped,
EXCEPTIONS,
SessionExpiredError,
- NoNodeError
+ NoNodeError,
+ SASLException,
)
from kazoo.loggingsupport import BLATHER
from kazoo.protocol.serialization import (
@@ -30,7 +33,7 @@ from kazoo.protocol.serialization import (
SASL,
Transaction,
Watch,
- int_struct
+ int_struct,
)
from kazoo.protocol.states import (
Callback,
@@ -40,10 +43,12 @@ from kazoo.protocol.states import (
)
from kazoo.retry import (
ForceRetryError,
- RetryFailedError
+ RetryFailedError,
)
+
try:
- from puresasl.client import SASLClient
+ import puresasl
+ import puresasl.client
PURESASL_AVAILABLE = True
except ImportError:
PURESASL_AVAILABLE = False
@@ -139,7 +144,7 @@ class RWServerAvailable(Exception):
class ConnectionHandler(object):
"""Zookeeper connection handler"""
- def __init__(self, client, retry_sleeper, logger=None):
+ def __init__(self, client, retry_sleeper, logger=None, sasl_options=None):
self.client = client
self.handler = client.handler
self.retry_sleeper = retry_sleeper
@@ -159,10 +164,10 @@ class ConnectionHandler(object):
self._xid = None
self._rw_server = None
self._ro_mode = False
- self._ro = False
self._connection_routine = None
+ self.sasl_options = sasl_options
self.sasl_cli = None
# This is instance specific to avoid odd thread bug issues in Python
@@ -232,8 +237,8 @@ class ConnectionHandler(object):
# have anything to select, but the wrapped object may still
# have something to read as it has previously gotten enough
# data from the underlying socket.
- if (hasattr(self._socket, "pending")
- and self._socket.pending() > 0):
+ if (hasattr(self._socket, "pending") and
+ self._socket.pending() > 0):
pass
else:
s = self.handler.select([self._socket], [], [], timeout)[0]
@@ -427,24 +432,6 @@ class ConnectionHandler(object):
async_object.set(True)
elif header.xid == WATCH_XID:
self._read_watch_event(buffer, offset)
- elif self.sasl_cli and not self.sasl_cli.complete:
- # SASL authentication is not yet finished, this can only
- # be a SASL packet
- self.logger.log(BLATHER, 'Received SASL')
- try:
- challenge, _ = SASL.deserialize(buffer, offset)
- except Exception:
- raise ConnectionDropped('error while SASL authentication.')
- response = self.sasl_cli.process(challenge)
- if response:
- # authentication not yet finished, answering the challenge
- self._send_sasl_request(challenge=response,
- timeout=client._session_timeout)
- else:
- # authentication is ok, state is CONNECTED or CONNECTED_RO
- # remove sensible information from the object
- self._set_connected_ro_or_rw(client)
- self.sasl_cli.dispose()
else:
self.logger.log(BLATHER, 'Reading for header %r', header)
@@ -522,12 +509,13 @@ class ConnectionHandler(object):
host_ports = []
for host, port in self.client.hosts:
try:
- for rhost in socket.getaddrinfo(host.strip(), port, 0, 0,
+ host = host.strip()
+ for rhost in socket.getaddrinfo(host, port, 0, 0,
socket.IPPROTO_TCP):
- host_ports.append((rhost[4][0], rhost[4][1]))
+ host_ports.append((host, rhost[4][0], rhost[4][1]))
except socket.gaierror as e:
# Skip hosts that don't resolve
- self.logger.warning("Cannot resolve %s: %s", host.strip(), e)
+ self.logger.warning("Cannot resolve %s: %s", host, e)
pass
if self.client.randomize_hosts:
random.shuffle(host_ports)
@@ -542,11 +530,11 @@ class ConnectionHandler(object):
if len(host_ports) == 0:
return STOP_CONNECTING
- for host, port in host_ports:
+ for host, hostip, port in host_ports:
if self.client._stopped.is_set():
status = STOP_CONNECTING
break
- status = self._connect_attempt(host, port, retry)
+ status = self._connect_attempt(host, hostip, port, retry)
if status is STOP_CONNECTING:
break
@@ -555,7 +543,7 @@ class ConnectionHandler(object):
else:
raise ForceRetryError('Reconnecting')
- def _connect_attempt(self, host, port, retry):
+ def _connect_attempt(self, host, hostip, port, retry):
client = self.client
KazooTimeoutError = self.handler.timeout_exception
close_connection = False
@@ -574,7 +562,7 @@ class ConnectionHandler(object):
try:
self._xid = 0
- read_timeout, connect_timeout = self._connect(host, port)
+ read_timeout, connect_timeout = self._connect(host, hostip, port)
read_timeout = read_timeout / 1000.0
connect_timeout = connect_timeout / 1000.0
retry.reset()
@@ -611,9 +599,9 @@ class ConnectionHandler(object):
if client._state != KeeperState.CONNECTING:
self.logger.warning("Transition to CONNECTING")
client._session_callback(KeeperState.CONNECTING)
- except AuthFailedError:
+ except AuthFailedError as err:
retry.reset()
- self.logger.warning('AUTH_FAILED closing')
+ self.logger.warning('AUTH_FAILED closing: %s', err)
client._session_callback(KeeperState.AUTH_FAILED)
return STOP_CONNECTING
except SessionExpiredError:
@@ -631,10 +619,10 @@ class ConnectionHandler(object):
if self._socket is not None:
self._socket.close()
- def _connect(self, host, port):
+ def _connect(self, host, hostip, port):
client = self.client
- self.logger.info('Connecting to %s:%s, use_ssl: %r',
- host, port, self.client.use_ssl)
+ self.logger.info('Connecting to %s(%s):%s, use_ssl: %r',
+ host, hostip, port, self.client.use_ssl)
self.logger.log(BLATHER,
' Using session_id: %r session_passwd: %s',
@@ -643,7 +631,7 @@ class ConnectionHandler(object):
with self._socket_error_handling():
self._socket = self.handler.create_connection(
- address=(host, port),
+ address=(hostip, port),
timeout=client._session_timeout / 1000.0,
use_ssl=self.client.use_ssl,
keyfile=self.client.keyfile,
@@ -686,68 +674,112 @@ class ConnectionHandler(object):
read_timeout)
if connect_result.read_only:
- self._ro = True
+ client._session_callback(KeeperState.CONNECTED_RO)
+ self._ro_mode = iter(self._server_pinger())
+ else:
+ client._session_callback(KeeperState.CONNECTED)
+ self._ro_mode = None
+
+ if self.sasl_options is not None:
+ self._authenticate_with_sasl(host, connect_timeout / 1000.0)
# Get a copy of the auth data before iterating, in case it is
# changed.
client_auth_data_copy = copy.copy(client.auth_data)
- if client.use_sasl and self.sasl_cli is None:
- if PURESASL_AVAILABLE:
- for scheme, auth in client_auth_data_copy:
- if scheme == 'sasl':
- username, password = auth.split(":")
- self.sasl_cli = SASLClient(
- host=client.sasl_server_principal,
- service='zookeeper',
- mechanism='DIGEST-MD5',
- username=username,
- password=password
- )
- break
-
- # As described in rfc
- # https://tools.ietf.org/html/rfc2831#section-2.1
- # sending empty challenge
- self._send_sasl_request(challenge=b'',
- timeout=connect_timeout)
- else:
- self.logger.warn('Pure-sasl library is missing while sasl'
- ' authentification is configured. Please'
- ' install pure-sasl library to connect '
- 'using sasl. Now falling back '
- 'connecting WITHOUT any '
- 'authentification.')
- client.use_sasl = False
- self._set_connected_ro_or_rw(client)
- else:
- self._set_connected_ro_or_rw(client)
- for scheme, auth in client_auth_data_copy:
- if scheme == "digest":
- ap = Auth(0, scheme, auth)
- zxid = self._invoke(
- connect_timeout / 1000.0,
- ap,
- xid=AUTH_XID
- )
- if zxid:
- client.last_zxid = zxid
+ for scheme, auth in client_auth_data_copy:
+ ap = Auth(0, scheme, auth)
+ zxid = self._invoke(connect_timeout / 1000.0, ap, xid=AUTH_XID)
+ if zxid:
+ client.last_zxid = zxid
return read_timeout, connect_timeout
- def _send_sasl_request(self, challenge, timeout):
- """ Called when sending a SASL request, xid needs be to incremented """
- sasl_request = SASL(challenge)
- self._xid = (self._xid % 2147483647) + 1
- xid = self._xid
- self._submit(sasl_request, timeout / 1000.0, xid)
-
- def _set_connected_ro_or_rw(self, client):
- """ Called to decide whether to set the KeeperState to CONNECTED_RO
- or CONNECTED"""
- if self._ro:
- client._session_callback(KeeperState.CONNECTED_RO)
- self._ro_mode = iter(self._server_pinger())
- else:
- client._session_callback(KeeperState.CONNECTED)
- self._ro_mode = None
+ def _authenticate_with_sasl(self, host, timeout):
+ """Establish a SASL authenticated connection to the server.
+ """
+ if not PURESASL_AVAILABLE:
+ raise SASLException('Missing SASL support')
+
+ if 'service' not in self.sasl_options:
+ self.sasl_options['service'] = 'zookeeper'
+
+ # NOTE: Zookeeper hardcoded the domain for Digest authentication
+ # instead of using the hostname. See
+ # zookeeper/util/SecurityUtils.java#L74 and Server/Client
+ # initializations.
+ if self.sasl_options['mechanism'] == 'DIGEST-MD5':
+ host = 'zk-sasl-md5'
+
+ sasl_cli = self.client.sasl_cli = puresasl.client.SASLClient(
+ host=host,
+ **self.sasl_options
+ )
+
+ # Inititalize the process with an empty challenge token
+ challenge = None
+ xid = 0
+
+ while True:
+ if sasl_cli.complete:
+ break
+
+ try:
+ response = sasl_cli.process(challenge=challenge)
+ except puresasl.SASLError as err:
+ six.reraise(
+ SASLException,
+ SASLException('library error: %s' % err.message),
+ sys.exc_info()[2]
+ )
+ except puresasl.SASLProtocolException as err:
+ six.reraise(
+ AuthFailedError,
+ AuthFailedError('protocol error: %s' % err.message),
+ sys.exc_info()[2]
+ )
+ except Exception as err:
+ six.reraise(
+ AuthFailedError,
+ AuthFailedError('Unknown error: %s' % err),
+ sys.exc_info()[2]
+ )
+
+ if sasl_cli.complete and not response:
+ break
+ elif response is None:
+ response = b''
+
+ xid = (xid % 2147483647) + 1
+
+ request = SASL(response)
+ self._submit(request, timeout, xid)
+
+ try:
+ header, buffer, offset = self._read_header(timeout)
+ except ConnectionDropped:
+ # Zookeeper simply drops connections with failed authentication
+ six.reraise(
+ AuthFailedError,
+ AuthFailedError('Connection dropped in SASL'),
+ sys.exc_info()[2]
+ )
+
+ if header.xid != xid:
+ raise RuntimeError('xids do not match, expected %r '
+ 'received %r', xid, header.xid)
+
+ if header.zxid > 0:
+ self.client.last_zxid = header.zxid
+
+ if header.err:
+ callback_exception = EXCEPTIONS[header.err]()
+ self.logger.debug(
+ 'Received error(xid=%s) %r', xid, callback_exception)
+ raise callback_exception
+
+ challenge, _ = SASL.deserialize(buffer, offset)
+
+ # If we made it here, authentication is ok, and we are connected.
+ # Remove sensible information from the object.
+ sasl_cli.dispose()