diff options
Diffstat (limited to 'kazoo/protocol/connection.py')
-rw-r--r-- | kazoo/protocol/connection.py | 261 |
1 files changed, 160 insertions, 101 deletions
diff --git a/kazoo/protocol/connection.py b/kazoo/protocol/connection.py index 726f645..d7d84d1 100644 --- a/kazoo/protocol/connection.py +++ b/kazoo/protocol/connection.py @@ -50,6 +50,7 @@ from kazoo.retry import ( try: import puresasl import puresasl.client + PURESASL_AVAILABLE = True except ImportError: PURESASL_AVAILABLE = False @@ -76,12 +77,14 @@ AUTH_XID = -4 CLOSE_RESPONSE = Close.type -if sys.version_info > (3, ): # pragma: nocover +if sys.version_info > (3,): # pragma: nocover + def buffer(obj, offset=0): return memoryview(obj)[offset:] advance_iterator = next else: # pragma: nocover + def advance_iterator(it): return it.next() @@ -99,6 +102,7 @@ class RWPinger(object): the iterator will yield False if called too soon. """ + def __init__(self, hosts, connection_func, socket_handling): self.hosts = hosts self.connection = connection_func @@ -126,7 +130,7 @@ class RWPinger(object): sock.sendall(b"isro") result = sock.recv(8192) sock.close() - if result == b'rw': + if result == b"rw": return (host, port) else: return False @@ -145,6 +149,7 @@ class RWServerAvailable(Exception): class ConnectionHandler(object): """Zookeeper connection handler""" + def __init__(self, client, retry_sleeper, logger=None, sasl_options=None): self.client = client self.handler = client.handler @@ -178,7 +183,7 @@ class ConnectionHandler(object): try: yield except (socket.error, select.error) as e: - err = getattr(e, 'strerror', e) + err = getattr(e, "strerror", e) raise ConnectionDropped("socket connection error: %s" % (err,)) def start(self): @@ -188,8 +193,9 @@ class ConnectionHandler(object): self._read_sock, self._write_sock = rw_sockets self.connection_closed.clear() if self._connection_routine: - raise Exception("Unable to start, connection routine already " - "active.") + raise Exception( + "Unable to start, connection routine already " "active." + ) self._connection_routine = self.handler.spawn(self.zk_loop) def stop(self, timeout=None): @@ -218,8 +224,11 @@ class ConnectionHandler(object): def _server_pinger(self): """Returns a server pinger iterable, that will ping the next server in the list, and apply a back-off between attempts.""" - return RWPinger(self.client.hosts, self.handler.create_connection, - self._socket_error_handling) + return RWPinger( + self.client.hosts, + self.handler.create_connection, + self._socket_error_handling, + ) def _read_header(self, timeout): b = self._read(4, timeout) @@ -238,8 +247,10 @@ class ConnectionHandler(object): # have anything to select, but the wrapped object may still # have something to read as it has previously gotten enough # data from the underlying socket. - if (hasattr(self._socket, "pending") and - self._socket.pending() > 0): + if ( + hasattr(self._socket, "pending") + and self._socket.pending() > 0 + ): pass else: s = self.handler.select([self._socket], [], [], timeout)[0] @@ -247,17 +258,20 @@ class ConnectionHandler(object): # If the read list is empty, we got a timeout. We don't # have to check wlist and xlist as we don't set any raise self.handler.timeout_exception( - "socket time-out during read") + "socket time-out during read" + ) try: chunk = self._socket.recv(remaining) except ssl.SSLError as e: - if e.errno in (ssl.SSL_ERROR_WANT_READ, - ssl.SSL_ERROR_WANT_WRITE): + if e.errno in ( + ssl.SSL_ERROR_WANT_READ, + ssl.SSL_ERROR_WANT_WRITE, + ): continue else: raise - if chunk == b'': - raise ConnectionDropped('socket connection broken') + if chunk == b"": + raise ConnectionDropped("socket connection broken") msgparts.append(chunk) remaining -= len(chunk) return b"".join(msgparts) @@ -270,14 +284,18 @@ class ConnectionHandler(object): if xid: header, buffer, offset = self._read_header(timeout) if header.xid != xid: - raise RuntimeError('xids do not match, expected %r ' - 'received %r', xid, header.xid) + raise RuntimeError( + "xids do not match, expected %r " "received %r", + xid, + header.xid, + ) if header.zxid > 0: zxid = header.zxid if header.err: callback_exception = EXCEPTIONS[header.err]() self.logger.debug( - 'Received error(xid=%s) %r', xid, callback_exception) + "Received error(xid=%s) %r", xid, callback_exception + ) raise callback_exception return zxid @@ -285,17 +303,19 @@ class ConnectionHandler(object): length = int_struct.unpack(msg)[0] msg = self._read(length, timeout) - if hasattr(request, 'deserialize'): + if hasattr(request, "deserialize"): try: obj, _ = request.deserialize(msg, 0) except Exception: self.logger.exception( "Exception raised during deserialization " - "of request: %s", request) + "of request: %s", + request, + ) # raise ConnectionDropped so connect loop will retry - raise ConnectionDropped('invalid server response') - self.logger.log(BLATHER, 'Read response %s', obj) + raise ConnectionDropped("invalid server response") + self.logger.log(BLATHER, "Read response %s", obj) return obj, zxid return zxid @@ -311,7 +331,10 @@ class ConnectionHandler(object): b += request.serialize() self.logger.log( (BLATHER if isinstance(request, Ping) else logging.DEBUG), - "Sending request(xid=%s): %s", xid, request) + "Sending request(xid=%s): %s", + xid, + request, + ) self._write(int_struct.pack(len(b)) + b, timeout) def _write(self, msg, timeout): @@ -324,19 +347,22 @@ class ConnectionHandler(object): if not s: # pragma: nocover # If the write list is empty, we got a timeout. We don't # have to check rlist and xlist as we don't set any - raise self.handler.timeout_exception("socket time-out" - " during write") + raise self.handler.timeout_exception( + "socket time-out" " during write" + ) msg_slice = buffer(msg, sent) try: bytes_sent = self._socket.send(msg_slice) except ssl.SSLError as e: - if e.errno in (ssl.SSL_ERROR_WANT_READ, - ssl.SSL_ERROR_WANT_WRITE): + if e.errno in ( + ssl.SSL_ERROR_WANT_READ, + ssl.SSL_ERROR_WANT_WRITE, + ): continue else: raise if not bytes_sent: - raise ConnectionDropped('socket connection broken') + raise ConnectionDropped("socket connection broken") sent += bytes_sent def _read_watch_event(self, buffer, offset): @@ -344,7 +370,7 @@ class ConnectionHandler(object): watch, offset = Watch.deserialize(buffer, offset) path = watch.path - self.logger.debug('Received EVENT: %s', watch) + self.logger.debug("Received EVENT: %s", watch) watchers = [] @@ -356,7 +382,7 @@ class ConnectionHandler(object): elif watch.type == CHILD_EVENT: watchers.extend(client._child_watchers.pop(path, [])) else: - self.logger.warn('Received unknown event %r', watch.type) + self.logger.warn("Received unknown event %r", watch.type) return # Strip the chroot if needed @@ -369,7 +395,7 @@ class ConnectionHandler(object): # Dump the watchers to the watch thread for watch in watchers: - client.handler.dispatch_callback(Callback('watch', watch, (ev,))) + client.handler.dispatch_callback(Callback("watch", watch, (ev,))) def _read_response(self, header, buffer, offset): client = self.client @@ -377,20 +403,25 @@ class ConnectionHandler(object): if header.zxid and header.zxid > 0: client.last_zxid = header.zxid if header.xid != xid: - exc = RuntimeError('xids do not match, expected %r ' - 'received %r', xid, header.xid) + exc = RuntimeError( + "xids do not match, expected %r " "received %r", + xid, + header.xid, + ) async_object.set_exception(exc) raise exc # Determine if its an exists request and a no node error - exists_error = (header.err == NoNodeError.code and - request.type == Exists.type) + exists_error = ( + header.err == NoNodeError.code and request.type == Exists.type + ) # Set the exception if its not an exists error if header.err and not exists_error: callback_exception = EXCEPTIONS[header.err]() self.logger.debug( - 'Received error(xid=%s) %r', xid, callback_exception) + "Received error(xid=%s) %r", xid, callback_exception + ) if async_object: async_object.set_exception(callback_exception) elif request and async_object: @@ -404,11 +435,14 @@ class ConnectionHandler(object): except Exception as exc: self.logger.exception( "Exception raised during deserialization " - "of request: %s", request) + "of request: %s", + request, + ) async_object.set_exception(exc) return self.logger.debug( - 'Received response(xid=%s): %r', xid, response) + "Received response(xid=%s): %r", xid, response + ) # We special case a Transaction as we have to unchroot things if request.type == Transaction.type: @@ -417,7 +451,7 @@ class ConnectionHandler(object): async_object.set(response) # Determine if watchers should be registered - watcher = getattr(request, 'watcher', None) + watcher = getattr(request, "watcher", None) if not client._stopped.is_set() and watcher: if isinstance(request, (GetChildren, GetChildren2)): client._child_watchers[request.path].add(watcher) @@ -425,7 +459,7 @@ class ConnectionHandler(object): client._data_watchers[request.path].add(watcher) if isinstance(request, Close): - self.logger.log(BLATHER, 'Read close response') + self.logger.log(BLATHER, "Read close response") return CLOSE_RESPONSE def _read_socket(self, read_timeout): @@ -434,10 +468,10 @@ class ConnectionHandler(object): header, buffer, offset = self._read_header(read_timeout) if header.xid == PING_XID: - self.logger.log(BLATHER, 'Received Ping') + self.logger.log(BLATHER, "Received Ping") self.ping_outstanding.clear() elif header.xid == AUTH_XID: - self.logger.log(BLATHER, 'Received AUTH') + self.logger.log(BLATHER, "Received AUTH") request, async_object, xid = client._pending.popleft() if header.err: @@ -448,7 +482,7 @@ class ConnectionHandler(object): elif header.xid == WATCH_XID: self._read_watch_event(buffer, offset) else: - self.logger.log(BLATHER, 'Reading for header %r', header) + self.logger.log(BLATHER, "Reading for header %r", header) return self._read_response(header, buffer, offset) @@ -501,7 +535,7 @@ class ConnectionHandler(object): def zk_loop(self): """Main Zookeeper handling loop""" - self.logger.log(BLATHER, 'ZK loop started') + self.logger.log(BLATHER, "ZK loop started") self.connection_stopped.clear() @@ -512,12 +546,14 @@ class ConnectionHandler(object): if retry(self._connect_loop, retry) is STOP_CONNECTING: break except RetryFailedError: - self.logger.warning("Failed connecting to Zookeeper " - "within the connection retry policy.") + self.logger.warning( + "Failed connecting to Zookeeper " + "within the connection retry policy." + ) finally: self.connection_stopped.set() self.client._session_callback(KeeperState.CLOSED) - self.logger.log(BLATHER, 'Connection stopped') + self.logger.log(BLATHER, "Connection stopped") def _expand_client_hosts(self): # Expand the entire list in advance so we can randomize it if needed @@ -525,8 +561,9 @@ class ConnectionHandler(object): for host, port in self.client.hosts: try: host = host.strip() - for rhost in socket.getaddrinfo(host, port, 0, 0, - socket.IPPROTO_TCP): + for rhost in socket.getaddrinfo( + host, port, 0, 0, socket.IPPROTO_TCP + ): host_ports.append((host, rhost[4][0], rhost[4][1])) except socket.gaierror as e: # Skip hosts that don't resolve @@ -543,7 +580,7 @@ class ConnectionHandler(object): # Check for an empty hostlist, indicating none resolved if len(host_ports) == 0: - raise ForceRetryError('No host resolved. Reconnecting') + raise ForceRetryError("No host resolved. Reconnecting") for host, hostip, port in host_ports: if self.client._stopped.is_set(): @@ -556,7 +593,7 @@ class ConnectionHandler(object): if status is STOP_CONNECTING: return STOP_CONNECTING else: - raise ForceRetryError('Reconnecting') + raise ForceRetryError("Reconnecting") def _connect_attempt(self, host, hostip, port, retry): client = self.client @@ -566,8 +603,9 @@ class ConnectionHandler(object): # Were we given a r/w server? If so, use that instead if self._rw_server: - self.logger.log(BLATHER, - "Found r/w server to use, %s:%s", host, port) + self.logger.log( + BLATHER, "Found r/w server to use, %s:%s", host, port + ) host, port = self._rw_server self._rw_server = None @@ -589,14 +627,16 @@ class ConnectionHandler(object): deadline = last_send + read_timeout / 2.0 - jitter_time # Ensure our timeout is positive timeout = max([deadline - time.time(), jitter_time]) - s = self.handler.select([self._socket, self._read_sock], - [], [], timeout)[0] + s = self.handler.select( + [self._socket, self._read_sock], [], [], timeout + )[0] if not s: if self.ping_outstanding.is_set(): self.ping_outstanding.clear() raise ConnectionDropped( - "outstanding heartbeat ping not received") + "outstanding heartbeat ping not received" + ) else: if self._socket in s: response = self._read_socket(read_timeout) @@ -614,32 +654,32 @@ class ConnectionHandler(object): if time.time() >= deadline: self._send_ping(connect_timeout) last_send = time.time() - self.logger.info('Closing connection to %s:%s', host, port) + self.logger.info("Closing connection to %s:%s", host, port) client._session_callback(KeeperState.CLOSED) return STOP_CONNECTING except (ConnectionDropped, KazooTimeoutError) as e: if isinstance(e, ConnectionDropped): - self.logger.warning('Connection dropped: %s', e) + self.logger.warning("Connection dropped: %s", e) else: - self.logger.warning('Connection time-out: %s', e) + self.logger.warning("Connection time-out: %s", e) if client._state != KeeperState.CONNECTING: self.logger.warning("Transition to CONNECTING") client._session_callback(KeeperState.CONNECTING) except AuthFailedError as err: retry.reset() - self.logger.warning('AUTH_FAILED closing: %s', err) + self.logger.warning("AUTH_FAILED closing: %s", err) client._session_callback(KeeperState.AUTH_FAILED) return STOP_CONNECTING except SessionExpiredError: retry.reset() - self.logger.warning('Session has expired') + self.logger.warning("Session has expired") client._session_callback(KeeperState.EXPIRED_SESSION) except RWServerAvailable: retry.reset() - self.logger.warning('Found a RW server, dropping connection') + self.logger.warning("Found a RW server, dropping connection") client._session_callback(KeeperState.CONNECTING) except Exception: - self.logger.exception('Unhandled exception in connection loop') + self.logger.exception("Unhandled exception in connection loop") raise finally: if self._socket is not None: @@ -647,13 +687,20 @@ class ConnectionHandler(object): def _connect(self, host, hostip, port): client = self.client - self.logger.info('Connecting to %s(%s):%s, use_ssl: %r', - host, hostip, port, self.client.use_ssl) + self.logger.info( + "Connecting to %s(%s):%s, use_ssl: %r", + host, + hostip, + port, + self.client.use_ssl, + ) - self.logger.log(BLATHER, - ' Using session_id: %r session_passwd: %s', - client._session_id, - hexlify(client._session_passwd)) + self.logger.log( + BLATHER, + " Using session_id: %r session_passwd: %s", + client._session_id, + hexlify(client._session_passwd), + ) with self._socket_error_handling(): self._socket = self.handler.create_connection( @@ -669,12 +716,18 @@ class ConnectionHandler(object): self._socket.setblocking(0) - connect = Connect(0, client.last_zxid, client._session_timeout, - client._session_id or 0, client._session_passwd, - client.read_only) + connect = Connect( + 0, + client.last_zxid, + client._session_timeout, + client._session_id or 0, + client._session_passwd, + client.read_only, + ) connect_result, zxid = self._invoke( - client._session_timeout / 1000.0 / len(client.hosts), connect) + client._session_timeout / 1000.0 / len(client.hosts), connect + ) if connect_result.time_out <= 0: raise SessionExpiredError("Session has expired") @@ -690,14 +743,18 @@ class ConnectionHandler(object): read_timeout = negotiated_session_timeout * 2.0 / 3.0 client._session_passwd = connect_result.passwd - self.logger.log(BLATHER, - 'Session created, session_id: %r session_passwd: %s\n' - ' negotiated session timeout: %s\n' - ' connect timeout: %s\n' - ' read timeout: %s', client._session_id, - hexlify(client._session_passwd), - negotiated_session_timeout, connect_timeout, - read_timeout) + self.logger.log( + BLATHER, + "Session created, session_id: %r session_passwd: %s\n" + " negotiated session timeout: %s\n" + " connect timeout: %s\n" + " read timeout: %s", + client._session_id, + hexlify(client._session_passwd), + negotiated_session_timeout, + connect_timeout, + read_timeout, + ) if connect_result.read_only: client._session_callback(KeeperState.CONNECTED_RO) @@ -722,24 +779,22 @@ class ConnectionHandler(object): return read_timeout, connect_timeout def _authenticate_with_sasl(self, host, timeout): - """Establish a SASL authenticated connection to the server. - """ + """Establish a SASL authenticated connection to the server.""" if not PURESASL_AVAILABLE: - raise SASLException('Missing SASL support') + raise SASLException("Missing SASL support") - if 'service' not in self.sasl_options: - self.sasl_options['service'] = 'zookeeper' + if "service" not in self.sasl_options: + self.sasl_options["service"] = "zookeeper" # NOTE: Zookeeper hardcoded the domain for Digest authentication # instead of using the hostname. See # zookeeper/util/SecurityUtils.java#L74 and Server/Client # initializations. - if self.sasl_options['mechanism'] == 'DIGEST-MD5': - host = 'zk-sasl-md5' + if self.sasl_options["mechanism"] == "DIGEST-MD5": + host = "zk-sasl-md5" sasl_cli = self.client.sasl_cli = puresasl.client.SASLClient( - host=host, - **self.sasl_options + host=host, **self.sasl_options ) # Inititalize the process with an empty challenge token @@ -755,26 +810,26 @@ class ConnectionHandler(object): except puresasl.SASLError as err: six.reraise( SASLException, - SASLException('library error: %s' % err.message), - sys.exc_info()[2] + SASLException("library error: %s" % err), + sys.exc_info()[2], ) except puresasl.SASLProtocolException as err: six.reraise( AuthFailedError, - AuthFailedError('protocol error: %s' % err.message), - sys.exc_info()[2] + AuthFailedError("protocol error: %s" % err), + sys.exc_info()[2], ) except Exception as err: six.reraise( AuthFailedError, - AuthFailedError('Unknown error: %s' % err), - sys.exc_info()[2] + AuthFailedError("Unknown error: %s" % err), + sys.exc_info()[2], ) if sasl_cli.complete and not response: break elif response is None: - response = b'' + response = b"" xid = (xid % 2147483647) + 1 @@ -787,13 +842,16 @@ class ConnectionHandler(object): # Zookeeper simply drops connections with failed authentication six.reraise( AuthFailedError, - AuthFailedError('Connection dropped in SASL'), - sys.exc_info()[2] + AuthFailedError("Connection dropped in SASL"), + sys.exc_info()[2], ) if header.xid != xid: - raise RuntimeError('xids do not match, expected %r ' - 'received %r', xid, header.xid) + raise RuntimeError( + "xids do not match, expected %r " "received %r", + xid, + header.xid, + ) if header.zxid > 0: self.client.last_zxid = header.zxid @@ -801,7 +859,8 @@ class ConnectionHandler(object): if header.err: callback_exception = EXCEPTIONS[header.err]() self.logger.debug( - 'Received error(xid=%s) %r', xid, callback_exception) + "Received error(xid=%s) %r", xid, callback_exception + ) raise callback_exception challenge, _ = SASL.deserialize(buffer, offset) |