diff options
Diffstat (limited to 'kazoo/protocol')
-rw-r--r-- | kazoo/protocol/connection.py | 261 | ||||
-rw-r--r-- | kazoo/protocol/paths.py | 37 | ||||
-rw-r--r-- | kazoo/protocol/serialization.py | 146 | ||||
-rw-r--r-- | kazoo/protocol/states.py | 44 |
4 files changed, 295 insertions, 193 deletions
diff --git a/kazoo/protocol/connection.py b/kazoo/protocol/connection.py index 726f645..e1bd996 100644 --- a/kazoo/protocol/connection.py +++ b/kazoo/protocol/connection.py @@ -50,6 +50,7 @@ from kazoo.retry import ( try: import puresasl import puresasl.client + PURESASL_AVAILABLE = True except ImportError: PURESASL_AVAILABLE = False @@ -76,12 +77,14 @@ AUTH_XID = -4 CLOSE_RESPONSE = Close.type -if sys.version_info > (3, ): # pragma: nocover +if sys.version_info > (3,): # pragma: nocover + def buffer(obj, offset=0): return memoryview(obj)[offset:] advance_iterator = next else: # pragma: nocover + def advance_iterator(it): return it.next() @@ -99,6 +102,7 @@ class RWPinger(object): the iterator will yield False if called too soon. """ + def __init__(self, hosts, connection_func, socket_handling): self.hosts = hosts self.connection = connection_func @@ -126,7 +130,7 @@ class RWPinger(object): sock.sendall(b"isro") result = sock.recv(8192) sock.close() - if result == b'rw': + if result == b"rw": return (host, port) else: return False @@ -145,6 +149,7 @@ class RWServerAvailable(Exception): class ConnectionHandler(object): """Zookeeper connection handler""" + def __init__(self, client, retry_sleeper, logger=None, sasl_options=None): self.client = client self.handler = client.handler @@ -178,7 +183,7 @@ class ConnectionHandler(object): try: yield except (socket.error, select.error) as e: - err = getattr(e, 'strerror', e) + err = getattr(e, "strerror", e) raise ConnectionDropped("socket connection error: %s" % (err,)) def start(self): @@ -188,8 +193,9 @@ class ConnectionHandler(object): self._read_sock, self._write_sock = rw_sockets self.connection_closed.clear() if self._connection_routine: - raise Exception("Unable to start, connection routine already " - "active.") + raise Exception( + "Unable to start, connection routine already " "active." + ) self._connection_routine = self.handler.spawn(self.zk_loop) def stop(self, timeout=None): @@ -218,8 +224,11 @@ class ConnectionHandler(object): def _server_pinger(self): """Returns a server pinger iterable, that will ping the next server in the list, and apply a back-off between attempts.""" - return RWPinger(self.client.hosts, self.handler.create_connection, - self._socket_error_handling) + return RWPinger( + self.client.hosts, + self.handler.create_connection, + self._socket_error_handling, + ) def _read_header(self, timeout): b = self._read(4, timeout) @@ -238,8 +247,10 @@ class ConnectionHandler(object): # have anything to select, but the wrapped object may still # have something to read as it has previously gotten enough # data from the underlying socket. - if (hasattr(self._socket, "pending") and - self._socket.pending() > 0): + if ( + hasattr(self._socket, "pending") + and self._socket.pending() > 0 + ): pass else: s = self.handler.select([self._socket], [], [], timeout)[0] @@ -247,17 +258,20 @@ class ConnectionHandler(object): # If the read list is empty, we got a timeout. We don't # have to check wlist and xlist as we don't set any raise self.handler.timeout_exception( - "socket time-out during read") + "socket time-out during read" + ) try: chunk = self._socket.recv(remaining) except ssl.SSLError as e: - if e.errno in (ssl.SSL_ERROR_WANT_READ, - ssl.SSL_ERROR_WANT_WRITE): + if e.errno in ( + ssl.SSL_ERROR_WANT_READ, + ssl.SSL_ERROR_WANT_WRITE, + ): continue else: raise - if chunk == b'': - raise ConnectionDropped('socket connection broken') + if chunk == b"": + raise ConnectionDropped("socket connection broken") msgparts.append(chunk) remaining -= len(chunk) return b"".join(msgparts) @@ -270,14 +284,18 @@ class ConnectionHandler(object): if xid: header, buffer, offset = self._read_header(timeout) if header.xid != xid: - raise RuntimeError('xids do not match, expected %r ' - 'received %r', xid, header.xid) + raise RuntimeError( + "xids do not match, expected %r " "received %r", + xid, + header.xid, + ) if header.zxid > 0: zxid = header.zxid if header.err: callback_exception = EXCEPTIONS[header.err]() self.logger.debug( - 'Received error(xid=%s) %r', xid, callback_exception) + "Received error(xid=%s) %r", xid, callback_exception + ) raise callback_exception return zxid @@ -285,17 +303,19 @@ class ConnectionHandler(object): length = int_struct.unpack(msg)[0] msg = self._read(length, timeout) - if hasattr(request, 'deserialize'): + if hasattr(request, "deserialize"): try: obj, _ = request.deserialize(msg, 0) except Exception: self.logger.exception( "Exception raised during deserialization " - "of request: %s", request) + "of request: %s", + request, + ) # raise ConnectionDropped so connect loop will retry - raise ConnectionDropped('invalid server response') - self.logger.log(BLATHER, 'Read response %s', obj) + raise ConnectionDropped("invalid server response") + self.logger.log(BLATHER, "Read response %s", obj) return obj, zxid return zxid @@ -311,7 +331,10 @@ class ConnectionHandler(object): b += request.serialize() self.logger.log( (BLATHER if isinstance(request, Ping) else logging.DEBUG), - "Sending request(xid=%s): %s", xid, request) + "Sending request(xid=%s): %s", + xid, + request, + ) self._write(int_struct.pack(len(b)) + b, timeout) def _write(self, msg, timeout): @@ -324,19 +347,22 @@ class ConnectionHandler(object): if not s: # pragma: nocover # If the write list is empty, we got a timeout. We don't # have to check rlist and xlist as we don't set any - raise self.handler.timeout_exception("socket time-out" - " during write") + raise self.handler.timeout_exception( + "socket time-out" " during write" + ) msg_slice = buffer(msg, sent) try: bytes_sent = self._socket.send(msg_slice) except ssl.SSLError as e: - if e.errno in (ssl.SSL_ERROR_WANT_READ, - ssl.SSL_ERROR_WANT_WRITE): + if e.errno in ( + ssl.SSL_ERROR_WANT_READ, + ssl.SSL_ERROR_WANT_WRITE, + ): continue else: raise if not bytes_sent: - raise ConnectionDropped('socket connection broken') + raise ConnectionDropped("socket connection broken") sent += bytes_sent def _read_watch_event(self, buffer, offset): @@ -344,7 +370,7 @@ class ConnectionHandler(object): watch, offset = Watch.deserialize(buffer, offset) path = watch.path - self.logger.debug('Received EVENT: %s', watch) + self.logger.debug("Received EVENT: %s", watch) watchers = [] @@ -356,7 +382,7 @@ class ConnectionHandler(object): elif watch.type == CHILD_EVENT: watchers.extend(client._child_watchers.pop(path, [])) else: - self.logger.warn('Received unknown event %r', watch.type) + self.logger.warn("Received unknown event %r", watch.type) return # Strip the chroot if needed @@ -369,7 +395,7 @@ class ConnectionHandler(object): # Dump the watchers to the watch thread for watch in watchers: - client.handler.dispatch_callback(Callback('watch', watch, (ev,))) + client.handler.dispatch_callback(Callback("watch", watch, (ev,))) def _read_response(self, header, buffer, offset): client = self.client @@ -377,20 +403,25 @@ class ConnectionHandler(object): if header.zxid and header.zxid > 0: client.last_zxid = header.zxid if header.xid != xid: - exc = RuntimeError('xids do not match, expected %r ' - 'received %r', xid, header.xid) + exc = RuntimeError( + "xids do not match, expected %r " "received %r", + xid, + header.xid, + ) async_object.set_exception(exc) raise exc # Determine if its an exists request and a no node error - exists_error = (header.err == NoNodeError.code and - request.type == Exists.type) + exists_error = ( + header.err == NoNodeError.code and request.type == Exists.type + ) # Set the exception if its not an exists error if header.err and not exists_error: callback_exception = EXCEPTIONS[header.err]() self.logger.debug( - 'Received error(xid=%s) %r', xid, callback_exception) + "Received error(xid=%s) %r", xid, callback_exception + ) if async_object: async_object.set_exception(callback_exception) elif request and async_object: @@ -404,11 +435,14 @@ class ConnectionHandler(object): except Exception as exc: self.logger.exception( "Exception raised during deserialization " - "of request: %s", request) + "of request: %s", + request, + ) async_object.set_exception(exc) return self.logger.debug( - 'Received response(xid=%s): %r', xid, response) + "Received response(xid=%s): %r", xid, response + ) # We special case a Transaction as we have to unchroot things if request.type == Transaction.type: @@ -417,7 +451,7 @@ class ConnectionHandler(object): async_object.set(response) # Determine if watchers should be registered - watcher = getattr(request, 'watcher', None) + watcher = getattr(request, "watcher", None) if not client._stopped.is_set() and watcher: if isinstance(request, (GetChildren, GetChildren2)): client._child_watchers[request.path].add(watcher) @@ -425,7 +459,7 @@ class ConnectionHandler(object): client._data_watchers[request.path].add(watcher) if isinstance(request, Close): - self.logger.log(BLATHER, 'Read close response') + self.logger.log(BLATHER, "Read close response") return CLOSE_RESPONSE def _read_socket(self, read_timeout): @@ -434,10 +468,10 @@ class ConnectionHandler(object): header, buffer, offset = self._read_header(read_timeout) if header.xid == PING_XID: - self.logger.log(BLATHER, 'Received Ping') + self.logger.log(BLATHER, "Received Ping") self.ping_outstanding.clear() elif header.xid == AUTH_XID: - self.logger.log(BLATHER, 'Received AUTH') + self.logger.log(BLATHER, "Received AUTH") request, async_object, xid = client._pending.popleft() if header.err: @@ -448,7 +482,7 @@ class ConnectionHandler(object): elif header.xid == WATCH_XID: self._read_watch_event(buffer, offset) else: - self.logger.log(BLATHER, 'Reading for header %r', header) + self.logger.log(BLATHER, "Reading for header %r", header) return self._read_response(header, buffer, offset) @@ -501,7 +535,7 @@ class ConnectionHandler(object): def zk_loop(self): """Main Zookeeper handling loop""" - self.logger.log(BLATHER, 'ZK loop started') + self.logger.log(BLATHER, "ZK loop started") self.connection_stopped.clear() @@ -512,12 +546,14 @@ class ConnectionHandler(object): if retry(self._connect_loop, retry) is STOP_CONNECTING: break except RetryFailedError: - self.logger.warning("Failed connecting to Zookeeper " - "within the connection retry policy.") + self.logger.warning( + "Failed connecting to Zookeeper " + "within the connection retry policy." + ) finally: self.connection_stopped.set() self.client._session_callback(KeeperState.CLOSED) - self.logger.log(BLATHER, 'Connection stopped') + self.logger.log(BLATHER, "Connection stopped") def _expand_client_hosts(self): # Expand the entire list in advance so we can randomize it if needed @@ -525,8 +561,9 @@ class ConnectionHandler(object): for host, port in self.client.hosts: try: host = host.strip() - for rhost in socket.getaddrinfo(host, port, 0, 0, - socket.IPPROTO_TCP): + for rhost in socket.getaddrinfo( + host, port, 0, 0, socket.IPPROTO_TCP + ): host_ports.append((host, rhost[4][0], rhost[4][1])) except socket.gaierror as e: # Skip hosts that don't resolve @@ -543,7 +580,7 @@ class ConnectionHandler(object): # Check for an empty hostlist, indicating none resolved if len(host_ports) == 0: - raise ForceRetryError('No host resolved. Reconnecting') + raise ForceRetryError("No host resolved. Reconnecting") for host, hostip, port in host_ports: if self.client._stopped.is_set(): @@ -556,7 +593,7 @@ class ConnectionHandler(object): if status is STOP_CONNECTING: return STOP_CONNECTING else: - raise ForceRetryError('Reconnecting') + raise ForceRetryError("Reconnecting") def _connect_attempt(self, host, hostip, port, retry): client = self.client @@ -566,8 +603,9 @@ class ConnectionHandler(object): # Were we given a r/w server? If so, use that instead if self._rw_server: - self.logger.log(BLATHER, - "Found r/w server to use, %s:%s", host, port) + self.logger.log( + BLATHER, "Found r/w server to use, %s:%s", host, port + ) host, port = self._rw_server self._rw_server = None @@ -589,14 +627,16 @@ class ConnectionHandler(object): deadline = last_send + read_timeout / 2.0 - jitter_time # Ensure our timeout is positive timeout = max([deadline - time.time(), jitter_time]) - s = self.handler.select([self._socket, self._read_sock], - [], [], timeout)[0] + s = self.handler.select( + [self._socket, self._read_sock], [], [], timeout + )[0] if not s: if self.ping_outstanding.is_set(): self.ping_outstanding.clear() raise ConnectionDropped( - "outstanding heartbeat ping not received") + "outstanding heartbeat ping not received" + ) else: if self._socket in s: response = self._read_socket(read_timeout) @@ -614,32 +654,32 @@ class ConnectionHandler(object): if time.time() >= deadline: self._send_ping(connect_timeout) last_send = time.time() - self.logger.info('Closing connection to %s:%s', host, port) + self.logger.info("Closing connection to %s:%s", host, port) client._session_callback(KeeperState.CLOSED) return STOP_CONNECTING except (ConnectionDropped, KazooTimeoutError) as e: if isinstance(e, ConnectionDropped): - self.logger.warning('Connection dropped: %s', e) + self.logger.warning("Connection dropped: %s", e) else: - self.logger.warning('Connection time-out: %s', e) + self.logger.warning("Connection time-out: %s", e) if client._state != KeeperState.CONNECTING: self.logger.warning("Transition to CONNECTING") client._session_callback(KeeperState.CONNECTING) except AuthFailedError as err: retry.reset() - self.logger.warning('AUTH_FAILED closing: %s', err) + self.logger.warning("AUTH_FAILED closing: %s", err) client._session_callback(KeeperState.AUTH_FAILED) return STOP_CONNECTING except SessionExpiredError: retry.reset() - self.logger.warning('Session has expired') + self.logger.warning("Session has expired") client._session_callback(KeeperState.EXPIRED_SESSION) except RWServerAvailable: retry.reset() - self.logger.warning('Found a RW server, dropping connection') + self.logger.warning("Found a RW server, dropping connection") client._session_callback(KeeperState.CONNECTING) except Exception: - self.logger.exception('Unhandled exception in connection loop') + self.logger.exception("Unhandled exception in connection loop") raise finally: if self._socket is not None: @@ -647,13 +687,20 @@ class ConnectionHandler(object): def _connect(self, host, hostip, port): client = self.client - self.logger.info('Connecting to %s(%s):%s, use_ssl: %r', - host, hostip, port, self.client.use_ssl) + self.logger.info( + "Connecting to %s(%s):%s, use_ssl: %r", + host, + hostip, + port, + self.client.use_ssl, + ) - self.logger.log(BLATHER, - ' Using session_id: %r session_passwd: %s', - client._session_id, - hexlify(client._session_passwd)) + self.logger.log( + BLATHER, + " Using session_id: %r session_passwd: %s", + client._session_id, + hexlify(client._session_passwd), + ) with self._socket_error_handling(): self._socket = self.handler.create_connection( @@ -669,12 +716,18 @@ class ConnectionHandler(object): self._socket.setblocking(0) - connect = Connect(0, client.last_zxid, client._session_timeout, - client._session_id or 0, client._session_passwd, - client.read_only) + connect = Connect( + 0, + client.last_zxid, + client._session_timeout, + client._session_id or 0, + client._session_passwd, + client.read_only, + ) connect_result, zxid = self._invoke( - client._session_timeout / 1000.0 / len(client.hosts), connect) + client._session_timeout / 1000.0 / len(client.hosts), connect + ) if connect_result.time_out <= 0: raise SessionExpiredError("Session has expired") @@ -690,14 +743,18 @@ class ConnectionHandler(object): read_timeout = negotiated_session_timeout * 2.0 / 3.0 client._session_passwd = connect_result.passwd - self.logger.log(BLATHER, - 'Session created, session_id: %r session_passwd: %s\n' - ' negotiated session timeout: %s\n' - ' connect timeout: %s\n' - ' read timeout: %s', client._session_id, - hexlify(client._session_passwd), - negotiated_session_timeout, connect_timeout, - read_timeout) + self.logger.log( + BLATHER, + "Session created, session_id: %r session_passwd: %s\n" + " negotiated session timeout: %s\n" + " connect timeout: %s\n" + " read timeout: %s", + client._session_id, + hexlify(client._session_passwd), + negotiated_session_timeout, + connect_timeout, + read_timeout, + ) if connect_result.read_only: client._session_callback(KeeperState.CONNECTED_RO) @@ -722,24 +779,22 @@ class ConnectionHandler(object): return read_timeout, connect_timeout def _authenticate_with_sasl(self, host, timeout): - """Establish a SASL authenticated connection to the server. - """ + """Establish a SASL authenticated connection to the server.""" if not PURESASL_AVAILABLE: - raise SASLException('Missing SASL support') + raise SASLException("Missing SASL support") - if 'service' not in self.sasl_options: - self.sasl_options['service'] = 'zookeeper' + if "service" not in self.sasl_options: + self.sasl_options["service"] = "zookeeper" # NOTE: Zookeeper hardcoded the domain for Digest authentication # instead of using the hostname. See # zookeeper/util/SecurityUtils.java#L74 and Server/Client # initializations. - if self.sasl_options['mechanism'] == 'DIGEST-MD5': - host = 'zk-sasl-md5' + if self.sasl_options["mechanism"] == "DIGEST-MD5": + host = "zk-sasl-md5" sasl_cli = self.client.sasl_cli = puresasl.client.SASLClient( - host=host, - **self.sasl_options + host=host, **self.sasl_options ) # Inititalize the process with an empty challenge token @@ -755,26 +810,26 @@ class ConnectionHandler(object): except puresasl.SASLError as err: six.reraise( SASLException, - SASLException('library error: %s' % err.message), - sys.exc_info()[2] + SASLException("library error: %s" % err.message), + sys.exc_info()[2], ) except puresasl.SASLProtocolException as err: six.reraise( AuthFailedError, - AuthFailedError('protocol error: %s' % err.message), - sys.exc_info()[2] + AuthFailedError("protocol error: %s" % err.message), + sys.exc_info()[2], ) except Exception as err: six.reraise( AuthFailedError, - AuthFailedError('Unknown error: %s' % err), - sys.exc_info()[2] + AuthFailedError("Unknown error: %s" % err), + sys.exc_info()[2], ) if sasl_cli.complete and not response: break elif response is None: - response = b'' + response = b"" xid = (xid % 2147483647) + 1 @@ -787,13 +842,16 @@ class ConnectionHandler(object): # Zookeeper simply drops connections with failed authentication six.reraise( AuthFailedError, - AuthFailedError('Connection dropped in SASL'), - sys.exc_info()[2] + AuthFailedError("Connection dropped in SASL"), + sys.exc_info()[2], ) if header.xid != xid: - raise RuntimeError('xids do not match, expected %r ' - 'received %r', xid, header.xid) + raise RuntimeError( + "xids do not match, expected %r " "received %r", + xid, + header.xid, + ) if header.zxid > 0: self.client.last_zxid = header.zxid @@ -801,7 +859,8 @@ class ConnectionHandler(object): if header.err: callback_exception = EXCEPTIONS[header.err]() self.logger.debug( - 'Received error(xid=%s) %r', xid, callback_exception) + "Received error(xid=%s) %r", xid, callback_exception + ) raise callback_exception challenge, _ = SASL.deserialize(buffer, offset) diff --git a/kazoo/protocol/paths.py b/kazoo/protocol/paths.py index 7fe961c..b8bf665 100644 --- a/kazoo/protocol/paths.py +++ b/kazoo/protocol/paths.py @@ -1,18 +1,18 @@ def normpath(path, trailing=False): """Normalize path, eliminating double slashes, etc.""" - comps = path.split('/') + comps = path.split("/") new_comps = [] for comp in comps: - if comp == '': + if comp == "": continue - if comp in ('.', '..'): - raise ValueError('relative paths not allowed') + if comp in (".", ".."): + raise ValueError("relative paths not allowed") new_comps.append(comp) - new_path = '/'.join(new_comps) - if trailing is True and path.endswith('/'): - new_path += '/' - if path.startswith('/') and new_path != '/': - return '/' + new_path + new_path = "/".join(new_comps) + if trailing is True and path.endswith("/"): + new_path += "/" + if path.startswith("/") and new_path != "/": + return "/" + new_path return new_path @@ -25,31 +25,32 @@ def join(a, *p): """ path = a for b in p: - if b.startswith('/'): + if b.startswith("/"): path = b - elif path == '' or path.endswith('/'): + elif path == "" or path.endswith("/"): path += b else: - path += '/' + b + path += "/" + b return path def isabs(s): """Test whether a path is absolute""" - return s.startswith('/') + return s.startswith("/") def basename(p): """Returns the final component of a pathname""" - i = p.rfind('/') + 1 + i = p.rfind("/") + 1 return p[i:] def _prefix_root(root, path, trailing=False): - """Prepend a root to a path. """ - return normpath(join(_norm_root(root), path.lstrip('/')), - trailing=trailing) + """Prepend a root to a path.""" + return normpath( + join(_norm_root(root), path.lstrip("/")), trailing=trailing + ) def _norm_root(root): - return normpath(join('/', root)) + return normpath(join("/", root)) diff --git a/kazoo/protocol/serialization.py b/kazoo/protocol/serialization.py index 80fa4d1..c702318 100644 --- a/kazoo/protocol/serialization.py +++ b/kazoo/protocol/serialization.py @@ -11,16 +11,16 @@ from kazoo.security import Id # Struct objects with formats compiled -bool_struct = struct.Struct('B') -int_struct = struct.Struct('!i') -int_int_struct = struct.Struct('!ii') -int_int_long_struct = struct.Struct('!iiq') +bool_struct = struct.Struct("B") +int_struct = struct.Struct("!i") +int_int_struct = struct.Struct("!ii") +int_int_long_struct = struct.Struct("!iiq") -int_long_int_long_struct = struct.Struct('!iqiq') -long_struct = struct.Struct('!q') -multiheader_struct = struct.Struct('!iBi') -reply_header_struct = struct.Struct('!iqi') -stat_struct = struct.Struct('!qqqqiiiqiiq') +int_long_int_long_struct = struct.Struct("!iqiq") +long_struct = struct.Struct("!q") +multiheader_struct = struct.Struct("!iBi") +reply_header_struct = struct.Struct("!iqi") +stat_struct = struct.Struct("!qqqqiiiqiiq") def read_string(buffer, offset): @@ -33,7 +33,7 @@ def read_string(buffer, offset): else: index = offset offset += length - return buffer[index:index + length].decode('utf-8'), offset + return buffer[index : index + length].decode("utf-8"), offset def read_acl(bytes, offset): @@ -48,7 +48,7 @@ def write_string(bytes): if not bytes: return int_struct.pack(-1) else: - utf8_str = bytes.encode('utf-8') + utf8_str = bytes.encode("utf-8") return int_struct.pack(len(utf8_str)) + utf8_str @@ -67,38 +67,50 @@ def read_buffer(bytes, offset): else: index = offset offset += length - return bytes[index:index + length], offset + return bytes[index : index + length], offset -class Close(namedtuple('Close', '')): +class Close(namedtuple("Close", "")): type = -11 @classmethod def serialize(cls): - return b'' + return b"" + CloseInstance = Close() -class Ping(namedtuple('Ping', '')): +class Ping(namedtuple("Ping", "")): type = 11 @classmethod def serialize(cls): - return b'' + return b"" + PingInstance = Ping() -class Connect(namedtuple('Connect', 'protocol_version last_zxid_seen' - ' time_out session_id passwd read_only')): +class Connect( + namedtuple( + "Connect", + "protocol_version last_zxid_seen" + " time_out session_id passwd read_only", + ) +): type = None def serialize(self): b = bytearray() - b.extend(int_long_int_long_struct.pack( - self.protocol_version, self.last_zxid_seen, self.time_out, - self.session_id)) + b.extend( + int_long_int_long_struct.pack( + self.protocol_version, + self.last_zxid_seen, + self.time_out, + self.session_id, + ) + ) b.extend(write_buffer(self.passwd)) b.extend([1 if self.read_only else 0]) return b @@ -106,7 +118,8 @@ class Connect(namedtuple('Connect', 'protocol_version last_zxid_seen' @classmethod def deserialize(cls, bytes, offset): proto_version, timeout, session_id = int_int_long_struct.unpack_from( - bytes, offset) + bytes, offset + ) offset += int_int_long_struct.size password, offset = read_buffer(bytes, offset) @@ -115,11 +128,13 @@ class Connect(namedtuple('Connect', 'protocol_version last_zxid_seen' offset += bool_struct.size except struct.error: read_only = False - return cls(proto_version, 0, timeout, session_id, password, - read_only), offset + return ( + cls(proto_version, 0, timeout, session_id, password, read_only), + offset, + ) -class Create(namedtuple('Create', 'path data acl flags')): +class Create(namedtuple("Create", "path data acl flags")): type = 1 def serialize(self): @@ -128,8 +143,11 @@ class Create(namedtuple('Create', 'path data acl flags')): b.extend(write_buffer(self.data)) b.extend(int_struct.pack(len(self.acl))) for acl in self.acl: - b.extend(int_struct.pack(acl.perms) + - write_string(acl.id.scheme) + write_string(acl.id.id)) + b.extend( + int_struct.pack(acl.perms) + + write_string(acl.id.scheme) + + write_string(acl.id.id) + ) b.extend(int_struct.pack(self.flags)) return b @@ -138,7 +156,7 @@ class Create(namedtuple('Create', 'path data acl flags')): return read_string(bytes, offset)[0] -class Delete(namedtuple('Delete', 'path version')): +class Delete(namedtuple("Delete", "path version")): type = 2 def serialize(self): @@ -152,7 +170,7 @@ class Delete(namedtuple('Delete', 'path version')): return True -class Exists(namedtuple('Exists', 'path watcher')): +class Exists(namedtuple("Exists", "path watcher")): type = 3 def serialize(self): @@ -167,7 +185,7 @@ class Exists(namedtuple('Exists', 'path watcher')): return stat if stat.czxid != -1 else None -class GetData(namedtuple('GetData', 'path watcher')): +class GetData(namedtuple("GetData", "path watcher")): type = 4 def serialize(self): @@ -183,7 +201,7 @@ class GetData(namedtuple('GetData', 'path watcher')): return data, stat -class SetData(namedtuple('SetData', 'path data version')): +class SetData(namedtuple("SetData", "path data version")): type = 5 def serialize(self): @@ -198,7 +216,7 @@ class SetData(namedtuple('SetData', 'path data version')): return ZnodeStat._make(stat_struct.unpack_from(bytes, offset)) -class GetACL(namedtuple('GetACL', 'path')): +class GetACL(namedtuple("GetACL", "path")): type = 6 def serialize(self): @@ -219,7 +237,7 @@ class GetACL(namedtuple('GetACL', 'path')): return acls, stat -class SetACL(namedtuple('SetACL', 'path acls version')): +class SetACL(namedtuple("SetACL", "path acls version")): type = 7 def serialize(self): @@ -227,8 +245,11 @@ class SetACL(namedtuple('SetACL', 'path acls version')): b.extend(write_string(self.path)) b.extend(int_struct.pack(len(self.acls))) for acl in self.acls: - b.extend(int_struct.pack(acl.perms) + - write_string(acl.id.scheme) + write_string(acl.id.id)) + b.extend( + int_struct.pack(acl.perms) + + write_string(acl.id.scheme) + + write_string(acl.id.id) + ) b.extend(int_struct.pack(self.version)) return b @@ -237,7 +258,7 @@ class SetACL(namedtuple('SetACL', 'path acls version')): return ZnodeStat._make(stat_struct.unpack_from(bytes, offset)) -class GetChildren(namedtuple('GetChildren', 'path watcher')): +class GetChildren(namedtuple("GetChildren", "path watcher")): type = 8 def serialize(self): @@ -260,7 +281,7 @@ class GetChildren(namedtuple('GetChildren', 'path watcher')): return children -class Sync(namedtuple('Sync', 'path')): +class Sync(namedtuple("Sync", "path")): type = 9 def serialize(self): @@ -271,7 +292,7 @@ class Sync(namedtuple('Sync', 'path')): return read_string(buffer, offset)[0] -class GetChildren2(namedtuple('GetChildren2', 'path watcher')): +class GetChildren2(namedtuple("GetChildren2", "path watcher")): type = 12 def serialize(self): @@ -295,7 +316,7 @@ class GetChildren2(namedtuple('GetChildren2', 'path watcher')): return children, stat -class CheckVersion(namedtuple('CheckVersion', 'path version')): +class CheckVersion(namedtuple("CheckVersion", "path version")): type = 13 def serialize(self): @@ -305,14 +326,15 @@ class CheckVersion(namedtuple('CheckVersion', 'path version')): return b -class Transaction(namedtuple('Transaction', 'operations')): +class Transaction(namedtuple("Transaction", "operations")): type = 14 def serialize(self): b = bytearray() for op in self.operations: - b.extend(MultiHeader(op.type, False, -1).serialize() + - op.serialize()) + b.extend( + MultiHeader(op.type, False, -1).serialize() + op.serialize() + ) return b + multiheader_struct.pack(-1, True, -1) @classmethod @@ -327,7 +349,8 @@ class Transaction(namedtuple('Transaction', 'operations')): response = True elif header.type == SetData.type: response = ZnodeStat._make( - stat_struct.unpack_from(bytes, offset)) + stat_struct.unpack_from(bytes, offset) + ) offset += stat_struct.size elif header.type == CheckVersion.type: response = True @@ -351,7 +374,7 @@ class Transaction(namedtuple('Transaction', 'operations')): return resp -class Create2(namedtuple('Create2', 'path data acl flags')): +class Create2(namedtuple("Create2", "path data acl flags")): type = 15 def serialize(self): @@ -360,8 +383,11 @@ class Create2(namedtuple('Create2', 'path data acl flags')): b.extend(write_buffer(self.data)) b.extend(int_struct.pack(len(self.acl))) for acl in self.acl: - b.extend(int_struct.pack(acl.perms) + - write_string(acl.id.scheme) + write_string(acl.id.id)) + b.extend( + int_struct.pack(acl.perms) + + write_string(acl.id.scheme) + + write_string(acl.id.id) + ) b.extend(int_struct.pack(self.flags)) return b @@ -372,8 +398,9 @@ class Create2(namedtuple('Create2', 'path data acl flags')): return path, stat -class Reconfig(namedtuple('Reconfig', - 'joining leaving new_members config_id')): +class Reconfig( + namedtuple("Reconfig", "joining leaving new_members config_id") +): type = 16 def serialize(self): @@ -391,15 +418,18 @@ class Reconfig(namedtuple('Reconfig', return data, stat -class Auth(namedtuple('Auth', 'auth_type scheme auth')): +class Auth(namedtuple("Auth", "auth_type scheme auth")): type = 100 def serialize(self): - return (int_struct.pack(self.auth_type) + write_string(self.scheme) + - write_string(self.auth)) + return ( + int_struct.pack(self.auth_type) + + write_string(self.scheme) + + write_string(self.auth) + ) -class SASL(namedtuple('SASL', 'challenge')): +class SASL(namedtuple("SASL", "challenge")): type = 102 def serialize(self): @@ -413,7 +443,7 @@ class SASL(namedtuple('SASL', 'challenge')): return challenge, offset -class Watch(namedtuple('Watch', 'type state path')): +class Watch(namedtuple("Watch", "type state path")): @classmethod def deserialize(cls, bytes, offset): """Given bytes and the current bytes offset, return the @@ -424,17 +454,19 @@ class Watch(namedtuple('Watch', 'type state path')): return cls(type, state, path), offset -class ReplyHeader(namedtuple('ReplyHeader', 'xid, zxid, err')): +class ReplyHeader(namedtuple("ReplyHeader", "xid, zxid, err")): @classmethod def deserialize(cls, bytes, offset): """Given bytes and the current bytes offset, return a :class:`ReplyHeader` instance and the new offset""" new_offset = offset + reply_header_struct.size - return cls._make( - reply_header_struct.unpack_from(bytes, offset)), new_offset + return ( + cls._make(reply_header_struct.unpack_from(bytes, offset)), + new_offset, + ) -class MultiHeader(namedtuple('MultiHeader', 'type done err')): +class MultiHeader(namedtuple("MultiHeader", "type done err")): def serialize(self): b = bytearray() b.extend(int_struct.pack(self.type)) diff --git a/kazoo/protocol/states.py b/kazoo/protocol/states.py index 66a8425..480a586 100644 --- a/kazoo/protocol/states.py +++ b/kazoo/protocol/states.py @@ -27,6 +27,7 @@ class KazooState(object): use, they can be considered lost as well. """ + SUSPENDED = "SUSPENDED" CONNECTED = "CONNECTED" LOST = "LOST" @@ -60,12 +61,13 @@ class KeeperState(object): gone. """ - AUTH_FAILED = 'AUTH_FAILED' - CONNECTED = 'CONNECTED' - CONNECTED_RO = 'CONNECTED_RO' - CONNECTING = 'CONNECTING' - CLOSED = 'CLOSED' - EXPIRED_SESSION = 'EXPIRED_SESSION' + + AUTH_FAILED = "AUTH_FAILED" + CONNECTED = "CONNECTED" + CONNECTED_RO = "CONNECTED_RO" + CONNECTING = "CONNECTING" + CLOSED = "CLOSED" + EXPIRED_SESSION = "EXPIRED_SESSION" class EventType(object): @@ -98,22 +100,24 @@ class EventType(object): The connection state has been altered. """ - CREATED = 'CREATED' - DELETED = 'DELETED' - CHANGED = 'CHANGED' - CHILD = 'CHILD' - NONE = 'NONE' + + CREATED = "CREATED" + DELETED = "DELETED" + CHANGED = "CHANGED" + CHILD = "CHILD" + NONE = "NONE" + EVENT_TYPE_MAP = { -1: EventType.NONE, 1: EventType.CREATED, 2: EventType.DELETED, 3: EventType.CHANGED, - 4: EventType.CHILD + 4: EventType.CHILD, } -class WatchedEvent(namedtuple('WatchedEvent', ('type', 'state', 'path'))): +class WatchedEvent(namedtuple("WatchedEvent", ("type", "state", "path"))): """A change on ZooKeeper that a Watcher is able to respond to. The :class:`WatchedEvent` includes exactly what happened, the @@ -137,7 +141,7 @@ class WatchedEvent(namedtuple('WatchedEvent', ('type', 'state', 'path'))): """ -class Callback(namedtuple('Callback', ('type', 'func', 'args'))): +class Callback(namedtuple("Callback", ("type", "func", "args"))): """A callback that is handed to a handler for dispatch :param type: Type of the callback, currently is only 'watch' @@ -147,9 +151,14 @@ class Callback(namedtuple('Callback', ('type', 'func', 'args'))): """ -class ZnodeStat(namedtuple('ZnodeStat', 'czxid mzxid ctime mtime version' - ' cversion aversion ephemeralOwner dataLength' - ' numChildren pzxid')): +class ZnodeStat( + namedtuple( + "ZnodeStat", + "czxid mzxid ctime mtime version" + " cversion aversion ephemeralOwner dataLength" + " numChildren pzxid", + ) +): """A ZnodeStat structure with convenience properties When getting the value of a znode from Zookeeper, the properties for @@ -206,6 +215,7 @@ class ZnodeStat(namedtuple('ZnodeStat', 'czxid mzxid ctime mtime version' The number of children of this znode. """ + @property def acl_version(self): return self.aversion |