From 74b044c3e729cbb6e65126e993b90902114f0293 Mon Sep 17 00:00:00 2001 From: Rajiv Bakulesh Shah Date: Mon, 11 Nov 2019 14:56:39 -0800 Subject: Add equality test on Redis client and conn pool (#1240) Add equality test on Redis client and connection pool --- redis/client.py | 6 ++++++ redis/connection.py | 6 ++++++ tests/test_client.py | 27 +++++++++++++++++++++++++++ tests/test_connection_pool.py | 41 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 80 insertions(+) create mode 100644 tests/test_client.py diff --git a/redis/client.py b/redis/client.py index 5adaa7c..d77ef01 100755 --- a/redis/client.py +++ b/redis/client.py @@ -707,6 +707,12 @@ class Redis(object): def __repr__(self): return "%s<%s>" % (type(self).__name__, repr(self.connection_pool)) + def __eq__(self, other): + return ( + isinstance(other, self.__class__) + and self.connection_pool == other.connection_pool + ) + def set_response_callback(self, command, callback): "Set a custom Response Callback" self.response_callbacks[command] = callback diff --git a/redis/connection.py b/redis/connection.py index feea041..44a9922 100755 --- a/redis/connection.py +++ b/redis/connection.py @@ -1044,6 +1044,12 @@ class ConnectionPool(object): repr(self.connection_class(**self.connection_kwargs)), ) + def __eq__(self, other): + return ( + isinstance(other, self.__class__) + and self.connection_kwargs == other.connection_kwargs + ) + def reset(self): self.pid = os.getpid() self._created_connections = 0 diff --git a/tests/test_client.py b/tests/test_client.py new file mode 100644 index 0000000..e8f79b1 --- /dev/null +++ b/tests/test_client.py @@ -0,0 +1,27 @@ +import redis + + +class TestClient(object): + def test_client_equality(self): + r1 = redis.Redis.from_url('redis://localhost:6379/9') + r2 = redis.Redis.from_url('redis://localhost:6379/9') + assert r1 == r2 + + def test_clients_unequal_if_different_types(self): + r = redis.Redis.from_url('redis://localhost:6379/9') + assert r != 0 + + def test_clients_unequal_if_different_hosts(self): + r1 = redis.Redis.from_url('redis://localhost:6379/9') + r2 = redis.Redis.from_url('redis://127.0.0.1:6379/9') + assert r1 != r2 + + def test_clients_unequal_if_different_ports(self): + r1 = redis.Redis.from_url('redis://localhost:6379/9') + r2 = redis.Redis.from_url('redis://localhost:6380/9') + assert r1 != r2 + + def test_clients_unequal_if_different_dbs(self): + r1 = redis.Redis.from_url('redis://localhost:6379/9') + r2 = redis.Redis.from_url('redis://localhost:6380/10') + assert r1 != r2 diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py index f580f71..406b5db 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -77,6 +77,47 @@ class TestConnectionPool(object): expected = 'ConnectionPool>' assert repr(pool) == expected + def test_pool_equality(self): + connection_kwargs = {'host': 'localhost', 'port': 6379, 'db': 1} + pool1 = self.get_pool(connection_kwargs=connection_kwargs, + connection_class=redis.Connection) + pool2 = self.get_pool(connection_kwargs=connection_kwargs, + connection_class=redis.Connection) + assert pool1 == pool2 + + def test_pools_unequal_if_different_types(self): + connection_kwargs = {'host': 'localhost', 'port': 6379, 'db': 1} + pool = self.get_pool(connection_kwargs=connection_kwargs, + connection_class=redis.Connection) + assert pool != 0 + + def test_pools_unequal_if_different_hosts(self): + connection_kwargs1 = {'host': 'localhost', 'port': 6379, 'db': 1} + connection_kwargs2 = {'host': '127.0.0.1', 'port': 6379, 'db': 1} + pool1 = self.get_pool(connection_kwargs=connection_kwargs1, + connection_class=redis.Connection) + pool2 = self.get_pool(connection_kwargs=connection_kwargs2, + connection_class=redis.Connection) + assert pool1 != pool2 + + def test_pools_unequal_if_different_ports(self): + connection_kwargs1 = {'host': 'localhost', 'port': 6379, 'db': 1} + connection_kwargs2 = {'host': 'localhost', 'port': 6380, 'db': 1} + pool1 = self.get_pool(connection_kwargs=connection_kwargs1, + connection_class=redis.Connection) + pool2 = self.get_pool(connection_kwargs=connection_kwargs2, + connection_class=redis.Connection) + assert pool1 != pool2 + + def test_pools_unequal_if_different_dbs(self): + connection_kwargs1 = {'host': 'localhost', 'port': 6379, 'db': 1} + connection_kwargs2 = {'host': 'localhost', 'port': 6379, 'db': 2} + pool1 = self.get_pool(connection_kwargs=connection_kwargs1, + connection_class=redis.Connection) + pool2 = self.get_pool(connection_kwargs=connection_kwargs2, + connection_class=redis.Connection) + assert pool1 != pool2 + class TestBlockingConnectionPool(object): def get_pool(self, connection_kwargs=None, max_connections=10, timeout=20): -- cgit v1.2.1