summaryrefslogtreecommitdiff
path: root/kazoo/protocol/connection.py
diff options
context:
space:
mode:
Diffstat (limited to 'kazoo/protocol/connection.py')
-rw-r--r--kazoo/protocol/connection.py261
1 files changed, 160 insertions, 101 deletions
diff --git a/kazoo/protocol/connection.py b/kazoo/protocol/connection.py
index 726f645..d7d84d1 100644
--- a/kazoo/protocol/connection.py
+++ b/kazoo/protocol/connection.py
@@ -50,6 +50,7 @@ from kazoo.retry import (
try:
import puresasl
import puresasl.client
+
PURESASL_AVAILABLE = True
except ImportError:
PURESASL_AVAILABLE = False
@@ -76,12 +77,14 @@ AUTH_XID = -4
CLOSE_RESPONSE = Close.type
-if sys.version_info > (3, ): # pragma: nocover
+if sys.version_info > (3,): # pragma: nocover
+
def buffer(obj, offset=0):
return memoryview(obj)[offset:]
advance_iterator = next
else: # pragma: nocover
+
def advance_iterator(it):
return it.next()
@@ -99,6 +102,7 @@ class RWPinger(object):
the iterator will yield False if called too soon.
"""
+
def __init__(self, hosts, connection_func, socket_handling):
self.hosts = hosts
self.connection = connection_func
@@ -126,7 +130,7 @@ class RWPinger(object):
sock.sendall(b"isro")
result = sock.recv(8192)
sock.close()
- if result == b'rw':
+ if result == b"rw":
return (host, port)
else:
return False
@@ -145,6 +149,7 @@ class RWServerAvailable(Exception):
class ConnectionHandler(object):
"""Zookeeper connection handler"""
+
def __init__(self, client, retry_sleeper, logger=None, sasl_options=None):
self.client = client
self.handler = client.handler
@@ -178,7 +183,7 @@ class ConnectionHandler(object):
try:
yield
except (socket.error, select.error) as e:
- err = getattr(e, 'strerror', e)
+ err = getattr(e, "strerror", e)
raise ConnectionDropped("socket connection error: %s" % (err,))
def start(self):
@@ -188,8 +193,9 @@ class ConnectionHandler(object):
self._read_sock, self._write_sock = rw_sockets
self.connection_closed.clear()
if self._connection_routine:
- raise Exception("Unable to start, connection routine already "
- "active.")
+ raise Exception(
+ "Unable to start, connection routine already " "active."
+ )
self._connection_routine = self.handler.spawn(self.zk_loop)
def stop(self, timeout=None):
@@ -218,8 +224,11 @@ class ConnectionHandler(object):
def _server_pinger(self):
"""Returns a server pinger iterable, that will ping the next
server in the list, and apply a back-off between attempts."""
- return RWPinger(self.client.hosts, self.handler.create_connection,
- self._socket_error_handling)
+ return RWPinger(
+ self.client.hosts,
+ self.handler.create_connection,
+ self._socket_error_handling,
+ )
def _read_header(self, timeout):
b = self._read(4, timeout)
@@ -238,8 +247,10 @@ class ConnectionHandler(object):
# have anything to select, but the wrapped object may still
# have something to read as it has previously gotten enough
# data from the underlying socket.
- if (hasattr(self._socket, "pending") and
- self._socket.pending() > 0):
+ if (
+ hasattr(self._socket, "pending")
+ and self._socket.pending() > 0
+ ):
pass
else:
s = self.handler.select([self._socket], [], [], timeout)[0]
@@ -247,17 +258,20 @@ class ConnectionHandler(object):
# If the read list is empty, we got a timeout. We don't
# have to check wlist and xlist as we don't set any
raise self.handler.timeout_exception(
- "socket time-out during read")
+ "socket time-out during read"
+ )
try:
chunk = self._socket.recv(remaining)
except ssl.SSLError as e:
- if e.errno in (ssl.SSL_ERROR_WANT_READ,
- ssl.SSL_ERROR_WANT_WRITE):
+ if e.errno in (
+ ssl.SSL_ERROR_WANT_READ,
+ ssl.SSL_ERROR_WANT_WRITE,
+ ):
continue
else:
raise
- if chunk == b'':
- raise ConnectionDropped('socket connection broken')
+ if chunk == b"":
+ raise ConnectionDropped("socket connection broken")
msgparts.append(chunk)
remaining -= len(chunk)
return b"".join(msgparts)
@@ -270,14 +284,18 @@ class ConnectionHandler(object):
if xid:
header, buffer, offset = self._read_header(timeout)
if header.xid != xid:
- raise RuntimeError('xids do not match, expected %r '
- 'received %r', xid, header.xid)
+ raise RuntimeError(
+ "xids do not match, expected %r " "received %r",
+ xid,
+ header.xid,
+ )
if header.zxid > 0:
zxid = header.zxid
if header.err:
callback_exception = EXCEPTIONS[header.err]()
self.logger.debug(
- 'Received error(xid=%s) %r', xid, callback_exception)
+ "Received error(xid=%s) %r", xid, callback_exception
+ )
raise callback_exception
return zxid
@@ -285,17 +303,19 @@ class ConnectionHandler(object):
length = int_struct.unpack(msg)[0]
msg = self._read(length, timeout)
- if hasattr(request, 'deserialize'):
+ if hasattr(request, "deserialize"):
try:
obj, _ = request.deserialize(msg, 0)
except Exception:
self.logger.exception(
"Exception raised during deserialization "
- "of request: %s", request)
+ "of request: %s",
+ request,
+ )
# raise ConnectionDropped so connect loop will retry
- raise ConnectionDropped('invalid server response')
- self.logger.log(BLATHER, 'Read response %s', obj)
+ raise ConnectionDropped("invalid server response")
+ self.logger.log(BLATHER, "Read response %s", obj)
return obj, zxid
return zxid
@@ -311,7 +331,10 @@ class ConnectionHandler(object):
b += request.serialize()
self.logger.log(
(BLATHER if isinstance(request, Ping) else logging.DEBUG),
- "Sending request(xid=%s): %s", xid, request)
+ "Sending request(xid=%s): %s",
+ xid,
+ request,
+ )
self._write(int_struct.pack(len(b)) + b, timeout)
def _write(self, msg, timeout):
@@ -324,19 +347,22 @@ class ConnectionHandler(object):
if not s: # pragma: nocover
# If the write list is empty, we got a timeout. We don't
# have to check rlist and xlist as we don't set any
- raise self.handler.timeout_exception("socket time-out"
- " during write")
+ raise self.handler.timeout_exception(
+ "socket time-out" " during write"
+ )
msg_slice = buffer(msg, sent)
try:
bytes_sent = self._socket.send(msg_slice)
except ssl.SSLError as e:
- if e.errno in (ssl.SSL_ERROR_WANT_READ,
- ssl.SSL_ERROR_WANT_WRITE):
+ if e.errno in (
+ ssl.SSL_ERROR_WANT_READ,
+ ssl.SSL_ERROR_WANT_WRITE,
+ ):
continue
else:
raise
if not bytes_sent:
- raise ConnectionDropped('socket connection broken')
+ raise ConnectionDropped("socket connection broken")
sent += bytes_sent
def _read_watch_event(self, buffer, offset):
@@ -344,7 +370,7 @@ class ConnectionHandler(object):
watch, offset = Watch.deserialize(buffer, offset)
path = watch.path
- self.logger.debug('Received EVENT: %s', watch)
+ self.logger.debug("Received EVENT: %s", watch)
watchers = []
@@ -356,7 +382,7 @@ class ConnectionHandler(object):
elif watch.type == CHILD_EVENT:
watchers.extend(client._child_watchers.pop(path, []))
else:
- self.logger.warn('Received unknown event %r', watch.type)
+ self.logger.warn("Received unknown event %r", watch.type)
return
# Strip the chroot if needed
@@ -369,7 +395,7 @@ class ConnectionHandler(object):
# Dump the watchers to the watch thread
for watch in watchers:
- client.handler.dispatch_callback(Callback('watch', watch, (ev,)))
+ client.handler.dispatch_callback(Callback("watch", watch, (ev,)))
def _read_response(self, header, buffer, offset):
client = self.client
@@ -377,20 +403,25 @@ class ConnectionHandler(object):
if header.zxid and header.zxid > 0:
client.last_zxid = header.zxid
if header.xid != xid:
- exc = RuntimeError('xids do not match, expected %r '
- 'received %r', xid, header.xid)
+ exc = RuntimeError(
+ "xids do not match, expected %r " "received %r",
+ xid,
+ header.xid,
+ )
async_object.set_exception(exc)
raise exc
# Determine if its an exists request and a no node error
- exists_error = (header.err == NoNodeError.code and
- request.type == Exists.type)
+ exists_error = (
+ header.err == NoNodeError.code and request.type == Exists.type
+ )
# Set the exception if its not an exists error
if header.err and not exists_error:
callback_exception = EXCEPTIONS[header.err]()
self.logger.debug(
- 'Received error(xid=%s) %r', xid, callback_exception)
+ "Received error(xid=%s) %r", xid, callback_exception
+ )
if async_object:
async_object.set_exception(callback_exception)
elif request and async_object:
@@ -404,11 +435,14 @@ class ConnectionHandler(object):
except Exception as exc:
self.logger.exception(
"Exception raised during deserialization "
- "of request: %s", request)
+ "of request: %s",
+ request,
+ )
async_object.set_exception(exc)
return
self.logger.debug(
- 'Received response(xid=%s): %r', xid, response)
+ "Received response(xid=%s): %r", xid, response
+ )
# We special case a Transaction as we have to unchroot things
if request.type == Transaction.type:
@@ -417,7 +451,7 @@ class ConnectionHandler(object):
async_object.set(response)
# Determine if watchers should be registered
- watcher = getattr(request, 'watcher', None)
+ watcher = getattr(request, "watcher", None)
if not client._stopped.is_set() and watcher:
if isinstance(request, (GetChildren, GetChildren2)):
client._child_watchers[request.path].add(watcher)
@@ -425,7 +459,7 @@ class ConnectionHandler(object):
client._data_watchers[request.path].add(watcher)
if isinstance(request, Close):
- self.logger.log(BLATHER, 'Read close response')
+ self.logger.log(BLATHER, "Read close response")
return CLOSE_RESPONSE
def _read_socket(self, read_timeout):
@@ -434,10 +468,10 @@ class ConnectionHandler(object):
header, buffer, offset = self._read_header(read_timeout)
if header.xid == PING_XID:
- self.logger.log(BLATHER, 'Received Ping')
+ self.logger.log(BLATHER, "Received Ping")
self.ping_outstanding.clear()
elif header.xid == AUTH_XID:
- self.logger.log(BLATHER, 'Received AUTH')
+ self.logger.log(BLATHER, "Received AUTH")
request, async_object, xid = client._pending.popleft()
if header.err:
@@ -448,7 +482,7 @@ class ConnectionHandler(object):
elif header.xid == WATCH_XID:
self._read_watch_event(buffer, offset)
else:
- self.logger.log(BLATHER, 'Reading for header %r', header)
+ self.logger.log(BLATHER, "Reading for header %r", header)
return self._read_response(header, buffer, offset)
@@ -501,7 +535,7 @@ class ConnectionHandler(object):
def zk_loop(self):
"""Main Zookeeper handling loop"""
- self.logger.log(BLATHER, 'ZK loop started')
+ self.logger.log(BLATHER, "ZK loop started")
self.connection_stopped.clear()
@@ -512,12 +546,14 @@ class ConnectionHandler(object):
if retry(self._connect_loop, retry) is STOP_CONNECTING:
break
except RetryFailedError:
- self.logger.warning("Failed connecting to Zookeeper "
- "within the connection retry policy.")
+ self.logger.warning(
+ "Failed connecting to Zookeeper "
+ "within the connection retry policy."
+ )
finally:
self.connection_stopped.set()
self.client._session_callback(KeeperState.CLOSED)
- self.logger.log(BLATHER, 'Connection stopped')
+ self.logger.log(BLATHER, "Connection stopped")
def _expand_client_hosts(self):
# Expand the entire list in advance so we can randomize it if needed
@@ -525,8 +561,9 @@ class ConnectionHandler(object):
for host, port in self.client.hosts:
try:
host = host.strip()
- for rhost in socket.getaddrinfo(host, port, 0, 0,
- socket.IPPROTO_TCP):
+ for rhost in socket.getaddrinfo(
+ host, port, 0, 0, socket.IPPROTO_TCP
+ ):
host_ports.append((host, rhost[4][0], rhost[4][1]))
except socket.gaierror as e:
# Skip hosts that don't resolve
@@ -543,7 +580,7 @@ class ConnectionHandler(object):
# Check for an empty hostlist, indicating none resolved
if len(host_ports) == 0:
- raise ForceRetryError('No host resolved. Reconnecting')
+ raise ForceRetryError("No host resolved. Reconnecting")
for host, hostip, port in host_ports:
if self.client._stopped.is_set():
@@ -556,7 +593,7 @@ class ConnectionHandler(object):
if status is STOP_CONNECTING:
return STOP_CONNECTING
else:
- raise ForceRetryError('Reconnecting')
+ raise ForceRetryError("Reconnecting")
def _connect_attempt(self, host, hostip, port, retry):
client = self.client
@@ -566,8 +603,9 @@ class ConnectionHandler(object):
# Were we given a r/w server? If so, use that instead
if self._rw_server:
- self.logger.log(BLATHER,
- "Found r/w server to use, %s:%s", host, port)
+ self.logger.log(
+ BLATHER, "Found r/w server to use, %s:%s", host, port
+ )
host, port = self._rw_server
self._rw_server = None
@@ -589,14 +627,16 @@ class ConnectionHandler(object):
deadline = last_send + read_timeout / 2.0 - jitter_time
# Ensure our timeout is positive
timeout = max([deadline - time.time(), jitter_time])
- s = self.handler.select([self._socket, self._read_sock],
- [], [], timeout)[0]
+ s = self.handler.select(
+ [self._socket, self._read_sock], [], [], timeout
+ )[0]
if not s:
if self.ping_outstanding.is_set():
self.ping_outstanding.clear()
raise ConnectionDropped(
- "outstanding heartbeat ping not received")
+ "outstanding heartbeat ping not received"
+ )
else:
if self._socket in s:
response = self._read_socket(read_timeout)
@@ -614,32 +654,32 @@ class ConnectionHandler(object):
if time.time() >= deadline:
self._send_ping(connect_timeout)
last_send = time.time()
- self.logger.info('Closing connection to %s:%s', host, port)
+ self.logger.info("Closing connection to %s:%s", host, port)
client._session_callback(KeeperState.CLOSED)
return STOP_CONNECTING
except (ConnectionDropped, KazooTimeoutError) as e:
if isinstance(e, ConnectionDropped):
- self.logger.warning('Connection dropped: %s', e)
+ self.logger.warning("Connection dropped: %s", e)
else:
- self.logger.warning('Connection time-out: %s', e)
+ self.logger.warning("Connection time-out: %s", e)
if client._state != KeeperState.CONNECTING:
self.logger.warning("Transition to CONNECTING")
client._session_callback(KeeperState.CONNECTING)
except AuthFailedError as err:
retry.reset()
- self.logger.warning('AUTH_FAILED closing: %s', err)
+ self.logger.warning("AUTH_FAILED closing: %s", err)
client._session_callback(KeeperState.AUTH_FAILED)
return STOP_CONNECTING
except SessionExpiredError:
retry.reset()
- self.logger.warning('Session has expired')
+ self.logger.warning("Session has expired")
client._session_callback(KeeperState.EXPIRED_SESSION)
except RWServerAvailable:
retry.reset()
- self.logger.warning('Found a RW server, dropping connection')
+ self.logger.warning("Found a RW server, dropping connection")
client._session_callback(KeeperState.CONNECTING)
except Exception:
- self.logger.exception('Unhandled exception in connection loop')
+ self.logger.exception("Unhandled exception in connection loop")
raise
finally:
if self._socket is not None:
@@ -647,13 +687,20 @@ class ConnectionHandler(object):
def _connect(self, host, hostip, port):
client = self.client
- self.logger.info('Connecting to %s(%s):%s, use_ssl: %r',
- host, hostip, port, self.client.use_ssl)
+ self.logger.info(
+ "Connecting to %s(%s):%s, use_ssl: %r",
+ host,
+ hostip,
+ port,
+ self.client.use_ssl,
+ )
- self.logger.log(BLATHER,
- ' Using session_id: %r session_passwd: %s',
- client._session_id,
- hexlify(client._session_passwd))
+ self.logger.log(
+ BLATHER,
+ " Using session_id: %r session_passwd: %s",
+ client._session_id,
+ hexlify(client._session_passwd),
+ )
with self._socket_error_handling():
self._socket = self.handler.create_connection(
@@ -669,12 +716,18 @@ class ConnectionHandler(object):
self._socket.setblocking(0)
- connect = Connect(0, client.last_zxid, client._session_timeout,
- client._session_id or 0, client._session_passwd,
- client.read_only)
+ connect = Connect(
+ 0,
+ client.last_zxid,
+ client._session_timeout,
+ client._session_id or 0,
+ client._session_passwd,
+ client.read_only,
+ )
connect_result, zxid = self._invoke(
- client._session_timeout / 1000.0 / len(client.hosts), connect)
+ client._session_timeout / 1000.0 / len(client.hosts), connect
+ )
if connect_result.time_out <= 0:
raise SessionExpiredError("Session has expired")
@@ -690,14 +743,18 @@ class ConnectionHandler(object):
read_timeout = negotiated_session_timeout * 2.0 / 3.0
client._session_passwd = connect_result.passwd
- self.logger.log(BLATHER,
- 'Session created, session_id: %r session_passwd: %s\n'
- ' negotiated session timeout: %s\n'
- ' connect timeout: %s\n'
- ' read timeout: %s', client._session_id,
- hexlify(client._session_passwd),
- negotiated_session_timeout, connect_timeout,
- read_timeout)
+ self.logger.log(
+ BLATHER,
+ "Session created, session_id: %r session_passwd: %s\n"
+ " negotiated session timeout: %s\n"
+ " connect timeout: %s\n"
+ " read timeout: %s",
+ client._session_id,
+ hexlify(client._session_passwd),
+ negotiated_session_timeout,
+ connect_timeout,
+ read_timeout,
+ )
if connect_result.read_only:
client._session_callback(KeeperState.CONNECTED_RO)
@@ -722,24 +779,22 @@ class ConnectionHandler(object):
return read_timeout, connect_timeout
def _authenticate_with_sasl(self, host, timeout):
- """Establish a SASL authenticated connection to the server.
- """
+ """Establish a SASL authenticated connection to the server."""
if not PURESASL_AVAILABLE:
- raise SASLException('Missing SASL support')
+ raise SASLException("Missing SASL support")
- if 'service' not in self.sasl_options:
- self.sasl_options['service'] = 'zookeeper'
+ if "service" not in self.sasl_options:
+ self.sasl_options["service"] = "zookeeper"
# NOTE: Zookeeper hardcoded the domain for Digest authentication
# instead of using the hostname. See
# zookeeper/util/SecurityUtils.java#L74 and Server/Client
# initializations.
- if self.sasl_options['mechanism'] == 'DIGEST-MD5':
- host = 'zk-sasl-md5'
+ if self.sasl_options["mechanism"] == "DIGEST-MD5":
+ host = "zk-sasl-md5"
sasl_cli = self.client.sasl_cli = puresasl.client.SASLClient(
- host=host,
- **self.sasl_options
+ host=host, **self.sasl_options
)
# Inititalize the process with an empty challenge token
@@ -755,26 +810,26 @@ class ConnectionHandler(object):
except puresasl.SASLError as err:
six.reraise(
SASLException,
- SASLException('library error: %s' % err.message),
- sys.exc_info()[2]
+ SASLException("library error: %s" % err),
+ sys.exc_info()[2],
)
except puresasl.SASLProtocolException as err:
six.reraise(
AuthFailedError,
- AuthFailedError('protocol error: %s' % err.message),
- sys.exc_info()[2]
+ AuthFailedError("protocol error: %s" % err),
+ sys.exc_info()[2],
)
except Exception as err:
six.reraise(
AuthFailedError,
- AuthFailedError('Unknown error: %s' % err),
- sys.exc_info()[2]
+ AuthFailedError("Unknown error: %s" % err),
+ sys.exc_info()[2],
)
if sasl_cli.complete and not response:
break
elif response is None:
- response = b''
+ response = b""
xid = (xid % 2147483647) + 1
@@ -787,13 +842,16 @@ class ConnectionHandler(object):
# Zookeeper simply drops connections with failed authentication
six.reraise(
AuthFailedError,
- AuthFailedError('Connection dropped in SASL'),
- sys.exc_info()[2]
+ AuthFailedError("Connection dropped in SASL"),
+ sys.exc_info()[2],
)
if header.xid != xid:
- raise RuntimeError('xids do not match, expected %r '
- 'received %r', xid, header.xid)
+ raise RuntimeError(
+ "xids do not match, expected %r " "received %r",
+ xid,
+ header.xid,
+ )
if header.zxid > 0:
self.client.last_zxid = header.zxid
@@ -801,7 +859,8 @@ class ConnectionHandler(object):
if header.err:
callback_exception = EXCEPTIONS[header.err]()
self.logger.debug(
- 'Received error(xid=%s) %r', xid, callback_exception)
+ "Received error(xid=%s) %r", xid, callback_exception
+ )
raise callback_exception
challenge, _ = SASL.deserialize(buffer, offset)