diff options
Diffstat (limited to 'kazoo/tests/test_connection.py')
-rw-r--r-- | kazoo/tests/test_connection.py | 405 |
1 files changed, 405 insertions, 0 deletions
diff --git a/kazoo/tests/test_connection.py b/kazoo/tests/test_connection.py new file mode 100644 index 0000000..032b94b --- /dev/null +++ b/kazoo/tests/test_connection.py @@ -0,0 +1,405 @@ +from collections import namedtuple, deque +import os +import threading +import time +import uuid +from unittest.mock import patch +import struct +import sys + +import pytest + +from kazoo.exceptions import ConnectionLoss +from kazoo.protocol.serialization import ( + Connect, + int_struct, + write_string, +) +from kazoo.protocol.states import KazooState +from kazoo.protocol.connection import _CONNECTION_DROP +from kazoo.testing import KazooTestCase +from kazoo.tests.util import wait, CI_ZK_VERSION, CI + + +class Delete(namedtuple("Delete", "path version")): + type = 2 + + def serialize(self): + b = bytearray() + b.extend(write_string(self.path)) + b.extend(int_struct.pack(self.version)) + return b + + @classmethod + def deserialize(self, bytes, offset): + raise ValueError("oh my") + + +class TestConnectionHandler(KazooTestCase): + def test_bad_deserialization(self): + async_object = self.client.handler.async_result() + self.client._queue.append( + (Delete(self.client.chroot, -1), async_object) + ) + self.client._connection._write_sock.send(b"\0") + + with pytest.raises(ValueError): + async_object.get() + + def test_with_bad_sessionid(self): + ev = threading.Event() + + def expired(state): + if state == KazooState.CONNECTED: + ev.set() + + password = os.urandom(16) + client = self._get_client(client_id=(82838284824, password)) + client.add_listener(expired) + client.start() + try: + ev.wait(15) + assert ev.is_set() + finally: + client.stop() + + def test_connection_read_timeout(self): + client = self.client + ev = threading.Event() + path = "/" + uuid.uuid4().hex + handler = client.handler + _select = handler.select + _socket = client._connection._socket + + def delayed_select(*args, **kwargs): + result = _select(*args, **kwargs) + if len(args[0]) == 1 and _socket in args[0]: + # for any socket read, simulate a timeout + return [], [], [] + return result + + def back(state): + if state == KazooState.CONNECTED: + ev.set() + + client.add_listener(back) + client.create(path, b"1") + try: + handler.select = delayed_select + with pytest.raises(ConnectionLoss): + client.get(path) + finally: + handler.select = _select + # the client reconnects automatically + ev.wait(5) + assert ev.is_set() + assert client.get(path)[0] == b"1" + + def test_connection_write_timeout(self): + client = self.client + ev = threading.Event() + path = "/" + uuid.uuid4().hex + handler = client.handler + _select = handler.select + _socket = client._connection._socket + + def delayed_select(*args, **kwargs): + result = _select(*args, **kwargs) + if _socket in args[1]: + # for any socket write, simulate a timeout + return [], [], [] + return result + + def back(state): + if state == KazooState.CONNECTED: + ev.set() + + client.add_listener(back) + + try: + handler.select = delayed_select + with pytest.raises(ConnectionLoss): + client.create(path) + finally: + handler.select = _select + # the client reconnects automatically + ev.wait(5) + assert ev.is_set() + assert client.exists(path) is None + + def test_connection_deserialize_fail(self): + client = self.client + ev = threading.Event() + path = "/" + uuid.uuid4().hex + handler = client.handler + _select = handler.select + _socket = client._connection._socket + + def delayed_select(*args, **kwargs): + result = _select(*args, **kwargs) + if _socket in args[1]: + # for any socket write, simulate a timeout + return [], [], [] + return result + + def back(state): + if state == KazooState.CONNECTED: + ev.set() + + client.add_listener(back) + + deserialize_ev = threading.Event() + + def bad_deserialize(_bytes, offset): + deserialize_ev.set() + raise struct.error() + + # force the connection to die but, on reconnect, cause the + # server response to be non-deserializable. ensure that the client + # continues to retry. This partially reproduces a rare bug seen + # in production. + + with patch.object(Connect, "deserialize") as mock_deserialize: + mock_deserialize.side_effect = bad_deserialize + try: + handler.select = delayed_select + with pytest.raises(ConnectionLoss): + client.create(path) + finally: + handler.select = _select + # the client reconnects automatically but the first attempt will + # hit a deserialize failure. wait for that. + deserialize_ev.wait(5) + assert deserialize_ev.is_set() + + # this time should succeed + ev.wait(5) + assert ev.is_set() + assert client.exists(path) is None + + def test_connection_close(self): + with pytest.raises(Exception): + self.client.close() + self.client.stop() + self.client.close() + + # should be able to restart + self.client.start() + + def test_connection_sock(self): + client = self.client + read_sock = client._connection._read_sock + write_sock = client._connection._write_sock + + assert read_sock is not None + assert write_sock is not None + + # stop client and socket should not yet be closed + client.stop() + assert read_sock is not None + assert write_sock is not None + + read_sock.getsockname() + write_sock.getsockname() + + # close client, and sockets should be closed + client.close() + + # Todo check socket closing + + # start client back up. should get a new, valid socket + client.start() + read_sock = client._connection._read_sock + write_sock = client._connection._write_sock + + assert read_sock is not None + assert write_sock is not None + read_sock.getsockname() + write_sock.getsockname() + + def test_dirty_sock(self): + client = self.client + read_sock = client._connection._read_sock + write_sock = client._connection._write_sock + + # add a stray byte to the socket and ensure that doesn't + # blow up client. simulates case where some error leaves + # a byte in the socket which doesn't correspond to the + # request queue. + write_sock.send(b"\0") + + # eventually this byte should disappear from socket + wait(lambda: client.handler.select([read_sock], [], [], 0)[0] == []) + + +class TestConnectionDrop(KazooTestCase): + def test_connection_dropped(self): + ev = threading.Event() + + def back(state): + if state == KazooState.CONNECTED: + ev.set() + + # create a node with a large value and stop the ZK node + path = "/" + uuid.uuid4().hex + self.client.create(path) + self.client.add_listener(back) + result = self.client.set_async(path, b"a" * 1000 * 1024) + self.client._call(_CONNECTION_DROP, None) + + with pytest.raises(ConnectionLoss): + result.get() + # we have a working connection to a new node + ev.wait(30) + assert ev.is_set() + + +class TestReadOnlyMode(KazooTestCase): + def setUp(self): + os.environ["ZOOKEEPER_LOCAL_SESSION_RO"] = "true" + self.setup_zookeeper() + skip = False + if CI_ZK_VERSION and CI_ZK_VERSION < (3, 4): + skip = True + elif CI_ZK_VERSION and CI_ZK_VERSION >= (3, 4): + skip = False + else: + ver = self.client.server_version() + if ver[1] < 4: + skip = True + if skip: + pytest.skip("Must use Zookeeper 3.4 or above") + + def tearDown(self): + self.client.stop() + os.environ.pop("ZOOKEEPER_LOCAL_SESSION_RO", None) + + def test_read_only(self): + from kazoo.exceptions import NotReadOnlyCallError + from kazoo.protocol.states import KeeperState + + if CI: + # force some wait to make sure the data produced during the + # `setUp()` step are replicaed to all zk members + # if not done the `get_children()` test may fail because the + # node does not exist on the node that we will keep alive + time.sleep(15) + # do not keep the client started in the `setUp` step alive + self.client.stop() + client = self._get_client(connection_retry=None, read_only=True) + states = [] + ev = threading.Event() + + @client.add_listener + def listen(state): + states.append(state) + if client.client_state == KeeperState.CONNECTED_RO: + ev.set() + + client.start() + try: + # stopping both nodes at the same time + # else the test seems flaky when on CI hosts + zk_stop_threads = [] + zk_stop_threads.append( + threading.Thread(target=self.cluster[1].stop, daemon=True) + ) + zk_stop_threads.append( + threading.Thread(target=self.cluster[2].stop, daemon=True) + ) + for thread in zk_stop_threads: + thread.start() + for thread in zk_stop_threads: + thread.join() + # stopping the client is *mandatory*, else the client might try to + # reconnect using a xid that the server may endlessly refuse + # restarting the client makes sure the xid gets reset + client.stop() + client.start() + ev.wait(15) + assert ev.is_set() + assert client.client_state == KeeperState.CONNECTED_RO + + # Test read only command + assert client.get_children("/") == [] + + # Test error with write command + with pytest.raises(NotReadOnlyCallError): + client.create("/fred") + + # Wait for a ping + time.sleep(15) + finally: + client.remove_listener(listen) + self.cluster[1].run() + self.cluster[2].run() + + +class TestUnorderedXids(KazooTestCase): + def setUp(self): + super(TestUnorderedXids, self).setUp() + + self.connection = self.client._connection + self.connection_routine = self.connection._connection_routine + + self._pending = self.client._pending + self.client._pending = _naughty_deque() + + def tearDown(self): + self.client._pending = self._pending + super(TestUnorderedXids, self).tearDown() + + def _get_client(self, **kwargs): + # overrides for patching zk_loop + c = KazooTestCase._get_client(self, **kwargs) + self._zk_loop = c._connection.zk_loop + self._zk_loop_errors = [] + c._connection.zk_loop = self._zk_loop_func + return c + + def _zk_loop_func(self, *args, **kwargs): + # patched zk_loop which will catch and collect all RuntimeError + try: + self._zk_loop(*args, **kwargs) + except RuntimeError as e: + self._zk_loop_errors.append(e) + + def test_xids_mismatch(self): + from kazoo.protocol.states import KeeperState + + ev = threading.Event() + error_stack = [] + + @self.client.add_listener + def listen(state): + if self.client.client_state == KeeperState.CLOSED: + ev.set() + + def log_exception(*args): + error_stack.append((args, sys.exc_info())) + + self.connection.logger.exception = log_exception + + ev.clear() + with pytest.raises(RuntimeError): + self.client.get_children("/") + + ev.wait() + assert self.client.connected is False + assert self.client.state == "LOST" + assert self.client.client_state == KeeperState.CLOSED + + args, exc_info = error_stack[-1] + assert args == ("Unhandled exception in connection loop",) + assert exc_info[0] == RuntimeError + + self.client.handler.sleep_func(0.2) + assert not self.connection_routine.is_alive() + assert len(self._zk_loop_errors) == 1 + assert self._zk_loop_errors[0] == exc_info[1] + + +class _naughty_deque(deque): + def append(self, s): + request, async_object, xid = s + return deque.append(self, (request, async_object, xid + 1)) # +1s |