diff options
author | Matus Valo <matusvalo@users.noreply.github.com> | 2021-03-07 20:10:26 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-03-07 20:10:26 +0100 |
commit | 44dff855d98732f7a8e60e5ca9d121b2f08c5c9a (patch) | |
tree | e8bb4a5b75856fe4582044799f409be602ad0250 | |
parent | 39ccacf0dc986ee5842644f62da91a71fc2a768f (diff) | |
download | kombu-revert-1307-revert-1132-redis-improvements.tar.gz |
Revert "Revert "Port of redis code improvements from prior revision (#1132)" (#1307)"revert-1307-revert-1132-redis-improvements
This reverts commit d3ded0069cce8111bea24d4bb360cc2973be3804.
-rw-r--r-- | AUTHORS | 1 | ||||
-rw-r--r-- | kombu/transport/redis.py | 29 | ||||
-rw-r--r-- | requirements/test.txt | 1 | ||||
-rw-r--r-- | t/unit/transport/test_redis.py | 810 |
4 files changed, 581 insertions, 260 deletions
@@ -91,6 +91,7 @@ Marcin Lulek (ergo) <info@webreactor.eu> Marcin Puhacz <marcin.puhacz@gmail.com> Mark Lavin <mlavin@caktusgroup.com> markow <markow@red-sky.pl> +Matt Davis <matteius@gmail.com> Matt Wise <wise@wiredgeek.net> Maxime Rouyrre <rouyrre+git@gmail.com> mdk <luc.mdk@gmail.com> diff --git a/kombu/transport/redis.py b/kombu/transport/redis.py index 68f5a006..dc0628e5 100644 --- a/kombu/transport/redis.py +++ b/kombu/transport/redis.py @@ -766,7 +766,8 @@ class Channel(virtual.Channel): queues = self._queue_cycle.consume(len(self.active_queues)) if not queues: return - keys = [self._q_for_pri(queue, pri) for pri in self.priority_steps + _q_for_pri = self._queue_for_priority + keys = [_q_for_pri(queue, pri) for pri in self.priority_steps for queue in queues] + [timeout or 0] self._in_poll = self.client.connection self.client.connection.send_command('BRPOP', *keys) @@ -802,7 +803,8 @@ class Channel(virtual.Channel): def _get(self, queue): with self.conn_or_acquire() as client: for pri in self.priority_steps: - item = client.rpop(self._q_for_pri(queue, pri)) + queue_name = self._queue_for_priority(queue, pri) + item = client.rpop(queue_name) if item: return loads(bytes_to_str(item)) raise Empty() @@ -811,14 +813,15 @@ class Channel(virtual.Channel): with self.conn_or_acquire() as client: with client.pipeline() as pipe: for pri in self.priority_steps: - pipe = pipe.llen(self._q_for_pri(queue, pri)) + queue_name = self._queue_for_priority(queue, pri) + pipe = pipe.llen(queue_name) sizes = pipe.execute() - return sum(size for size in sizes - if isinstance(size, numbers.Integral)) + return sum(s for s in sizes + if isinstance(s, numbers.Integral)) - def _q_for_pri(self, queue, pri): - pri = self.priority(pri) - if pri: + def _queue_for_priority(self, queue, pri): + queue_priority = self.priority(pri) + if queue_priority: return f"{queue}{self.sep}{pri}" return queue @@ -831,7 +834,7 @@ class Channel(virtual.Channel): pri = self._get_message_priority(message, reverse=False) with self.conn_or_acquire() as client: - client.lpush(self._q_for_pri(queue, pri), dumps(message)) + client.lpush(self._queue_for_priority(queue, pri), dumps(message)) def _put_fanout(self, exchange, message, routing_key, **kwargs): """Deliver fanout message.""" @@ -866,14 +869,14 @@ class Channel(virtual.Channel): queue or ''])) with client.pipeline() as pipe: for pri in self.priority_steps: - pipe = pipe.delete(self._q_for_pri(queue, pri)) + pipe = pipe.delete(self._queue_for_priority(queue, pri)) pipe.execute() def _has_queue(self, queue, **kwargs): with self.conn_or_acquire() as client: with client.pipeline() as pipe: for pri in self.priority_steps: - pipe = pipe.exists(self._q_for_pri(queue, pri)) + pipe = pipe.exists(self._queue_for_priority(queue, pri)) return any(pipe.execute()) def get_table(self, exchange): @@ -888,8 +891,8 @@ class Channel(virtual.Channel): with self.conn_or_acquire() as client: with client.pipeline() as pipe: for pri in self.priority_steps: - priq = self._q_for_pri(queue, pri) - pipe = pipe.llen(priq).delete(priq) + priority_queue = self._queue_for_priority(queue, pri) + pipe = pipe.llen(priority_queue).delete(priority_queue) sizes = pipe.execute() return sum(sizes[::2]) diff --git a/requirements/test.txt b/requirements/test.txt index 4c6f990b..be1a407a 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -3,3 +3,4 @@ case>=1.5.2 pytest<=5.3.5 pytest-sugar Pyro4 +fakeredis==1.1.0 diff --git a/t/unit/transport/test_redis.py b/t/unit/transport/test_redis.py index 8a73db43..a39738e6 100644 --- a/t/unit/transport/test_redis.py +++ b/t/unit/transport/test_redis.py @@ -1,9 +1,12 @@ import pytest - +import fakeredis +import redis import socket -import types +from array import array +from case import ANY, ContextMock, Mock, call, mock, skip, patch from collections import defaultdict +from contextlib import contextmanager from itertools import count from queue import Empty, Queue as _Queue from unittest.mock import ANY, Mock, call, patch @@ -16,6 +19,16 @@ from kombu.utils import eventio # patch poll from kombu.utils.json import dumps +_fake_redis_client = None + + +def _get_fake_redis_client(): + global _fake_redis_client + if _fake_redis_client is None: + _fake_redis_client = FakeRedisClient() + return _fake_redis_client + + class _poll(eventio._select): def register(self, fd, flags): @@ -35,33 +48,177 @@ eventio.poll = _poll pytest.importorskip('redis') # must import after poller patch, pep8 complains -from kombu.transport import redis # noqa +from kombu.transport import redis as kombu_redis # noqa class ResponseError(Exception): pass -class Client: +class DummyParser: + + def __init__(self, *args, **kwargs): + self.socket_read_size = 1 + self.encoder = None + self._sock = None + self._buffer = None + + def on_disconnect(self): + self.socket_read_size = 1 + self.encoder = None + self._sock = None + self._buffer = None + + def on_connect(self, connection): + pass + + +class FakeRedisSocket(fakeredis._server.FakeSocket): + blocking = True + filenos = count(30) + + def __init__(self, server): + super().__init__(server) + self._server = server + self._fileno = next(self.filenos) + self.data = [] + self.connection = None + self.channel = None + self.transport_options = {} + self.hostname = None + self.port = None + self.password = None + self.virtual_host = '/' + self.max_connections = 10 + self.ssl = None + + +class FakeRedisConnection(fakeredis.FakeConnection): + disconnected = False + default_port = 6379 + channel_max = 65535 + + def __init__(self, client, server, **kwargs): + kwargs['parser_class'] = DummyParser + super(fakeredis.FakeConnection, self).__init__(**kwargs) + if client is None: + client = _get_fake_redis_client() + self.client = client + if server is None: + server = fakeredis.FakeServer() + self._server = server + self._sock = FakeRedisSocket(server=server) + try: + self.on_connect() + except redis.exceptions.RedisError: + # clean up after any error in on_connect + self.disconnect() + raise + self._parser = () + self._avail_channel_ids = array( + virtual.base.ARRAY_TYPE_H, range(self.channel_max, 0, -1), + ) + self.cycle = kombu_redis.MultiChannelPoller() + conn_errs, channel_errs = kombu_redis.get_redis_error_classes() + self.connection_errors, self.channel_errors = conn_errs, channel_errs + + def disconnect(self): + self.disconnected = True + + +class FakeRedisConnectionPool(redis.ConnectionPool): + def __init__(self, connection_class, max_connections=None, + **connection_kwargs): + connection_class = FakeRedisConnection + connection_kwargs['client'] = None + connection_kwargs['server'] = None + self._connections = [] + super().__init__( + connection_class=connection_class, + max_connections=max_connections, + **connection_kwargs + ) + + def get_connection(self, *args, **kwargs): + connection = self.connection_class(**self.connection_kwargs) + self._connections.append(connection) + return connection + + def release(self, connection): + pass + + +class FakeRedisClient(fakeredis.FakeStrictRedis): queues = {} - sets = defaultdict(set) - hashes = defaultdict(dict) shard_hint = None def __init__(self, db=None, port=None, connection_pool=None, **kwargs): self._called = [] self._connection = None self.bgsave_raises_ResponseError = False - self.connection = self._sconnection(self) + self.server = server = fakeredis.FakeServer() + connection_pool = FakeRedisConnectionPool(FakeRedisConnection) + self.connection_pool = connection_pool + super().__init__( + db=db, port=port, connection_pool=connection_pool, server=server) + self.connection = FakeRedisConnection(self, server) + self.response_callbacks = dict() + + def __del__(self, key=None): + if key: + self.delete(key) + + def ping(self, *args, **kwargs): + return True + + def pipeline(self): + return FakePipeline(self.server, self.connection_pool, [], '1234', '') + + def set_response_callback(self, command, callback): + pass - def bgsave(self): - self._called.append('BGSAVE') - if self.bgsave_raises_ResponseError: - raise ResponseError() + def _new_queue(self, queue, auto_delete=False, **kwargs): + self.queues[queue] = _Queue() + if auto_delete: + self.auto_delete_queues.add(queue) + + def rpop(self, key): + try: + return self.queues[key].get_nowait() + except (KeyError, Empty): + pass + + def llen(self, key): + try: + return self.queues[key].qsize() + except KeyError: + return 0 + + def lpush(self, key, value): + self.queues[key].put_nowait(value) + + def pubsub(self, *args, **kwargs): + self.connection_pool = FakeRedisConnectionPool(FakeRedisConnection) + return self def delete(self, key): self.queues.pop(key, None) + +class FakeRedisClientLite: + """The original FakeRedis client from Kombu to support the + Producer/Consumer TestCases, preferred to use FakeRedisClient.""" + queues = {} + sets = defaultdict(set) + hashes = defaultdict(dict) + shard_hint = None + + def __init__(self, db=None, port=None, connection_pool=None, **kwargs): + self._called = [] + self._connection = None + self.bgsave_raises_ResponseError = False + self.connection = self._sconnection(self) + def exists(self, key): return key in self.queues or key in self.sets @@ -78,18 +235,19 @@ class Client: self.sets[key].add(member) def zadd(self, key, *args): - if redis.redis.VERSION[0] >= 3: - (mapping,) = args - for item in mapping: - self.sets[key].add(item) - else: - # TODO: remove me when we drop support for Redis-py v2 - (score1, member1) = args - self.sets[key].add(member1) + (mapping,) = args + for item in mapping: + self.sets[key].add(item) def smembers(self, key): return self.sets.get(key, set()) + def sismember(self, name, value): + return value in self.sets.get(name, set()) + + def scard(self, key): + return len(self.sets.get(key, set())) + def ping(self, *args, **kwargs): return True @@ -137,7 +295,7 @@ class Client: return k in self._called def pipeline(self): - return Pipeline(self) + return FakePipelineLite(self) def encode(self, value): return str(value) @@ -172,22 +330,8 @@ class Client: def send_command(self, cmd, *args): self._sock.data.append((cmd, args)) - def info(self): - return {'foo': 1} - - def pubsub(self, *args, **kwargs): - connection = self.connection - - class ConnectionPool: - def get_connection(self, *args, **kwargs): - return connection - self.connection_pool = ConnectionPool() - - return self - - -class Pipeline: +class FakePipelineLite: def __init__(self, client): self.client = client @@ -215,10 +359,84 @@ class Pipeline: return [fun(*args, **kwargs) for fun, args, kwargs in stack] -class Channel(redis.Channel): +class FakePipeline(redis.client.Pipeline): + + def __init__(self, server, connection_pool, + response_callbacks, transaction, shard_hint): + if not server: + server = fakeredis.FakeServer() + self._server = server + correct_pool_instance = isinstance( + connection_pool, FakeRedisConnectionPool) + if connection_pool is not None and not correct_pool_instance: + connection_pool = FakeRedisConnectionPool(FakeRedisConnection) + self.connection_pool = connection_pool + self.connection = FakeRedisConnection(self, server) + self.client = connection_pool.get_connection().client + self.response_callbacks = response_callbacks + self.transaction = transaction + self.shard_hint = shard_hint + self.watching = False + self.reset() + + def __enter__(self): + return self + + def __exit__(self, *exc_info): + pass + + def __getattr__(self, key): + if key not in self.__dict__: + + def _add(*args, **kwargs): + self.command_stack.append( + (getattr(self.connection.client, key), args, kwargs)) + return self + + return _add + return self.__dict__[key] + + def reset(self): + # Do nothing with the real connection + self.command_stack = [] + self.scripts = set() + + def execute(self): + stack = list(self.command_stack) + all_cmds = self.connection.pack_commands( + [args for args, _ in self.command_stack]) + self.connection.send_packed_command(all_cmds) + + response = [] + for cmd in all_cmds: + try: + response.append( + self.parse_response(self.connection, cmd)) + except ResponseError: + import sys + response.append(sys.exc_info()[1]) + + self.raise_first_error(self.command_stack, response) + results = [] + for t, kwargs in stack: + redis_func_name = t[0] + redis_func_name = redis_func_name.lower() + if redis_func_name == 'del': + redis_func_name = 'delete' + args = t[1:] + fun = getattr(self.client, redis_func_name) + r = fun(*args, **kwargs) + results.append(r) + + self.command_stack[:] = [] + self.reset() + return results + + +class FakeRedisKombuChannelLite(kombu_redis.Channel): def _get_client(self): - return Client + return FakeRedisClientLite def _get_pool(self, asynchronous=False): return Mock() @@ -228,41 +446,97 @@ class Channel(redis.Channel): def _new_queue(self, queue, **kwargs): for pri in self.priority_steps: - self.client._new_queue(self._q_for_pri(queue, pri)) + self.client._new_queue(self._queue_for_priority(queue, pri)) def pipeline(self): - return Pipeline(Client()) + return FakePipelineLite(FakeRedisClientLite()) -class Transport(redis.Transport): - Channel = Channel +class FakeRedisKombuChannel(kombu_redis.Channel): + _fanout_queues = {} + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def _get_client(self): + return FakeRedisClient + + def _create_client(self, asynchronous=False): + global _fake_redis_client + if _fake_redis_client is None: + _fake_redis_client = self._get_client()() + return _fake_redis_client + + @contextmanager + def conn_or_acquire(self, client=None): + if client: + yield client + else: + yield self._create_client() + + def _get_pool(self, asynchronous=False): + params = self._connparams(asynchronous=asynchronous) + self.keyprefix_fanout = self.keyprefix_fanout.format(db=params['db']) + return FakeRedisConnectionPool(**params) + + def _get_response_error(self): + return ResponseError + + def _new_queue(self, queue, **kwargs): + for pri in self.priority_steps: + self.client._new_queue(self._queue_for_priority(queue, pri)) + + def pipeline(self): + yield _get_fake_redis_client().pipeline() + + def basic_publish(self, message, exchange='', routing_key='', **kwargs): + self._inplace_augment_message(message, exchange, routing_key) + # anon exchange: routing_key is the destination queue + return self._put(routing_key, message, **kwargs) + + +class FakeRedisKombuTransportLite(kombu_redis.Transport): + Channel = FakeRedisKombuChannelLite + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) def _get_errors(self): return ((KeyError,), (IndexError,)) -class test_Channel: +class FakeRedisKombuTransport(FakeRedisKombuTransportLite): + Channel = FakeRedisKombuChannel + + +@skip.unless_module('redis') +class test_RedisChannel: def setup(self): self.connection = self.create_connection() self.channel = self.connection.default_channel + def teardown(self): + self.connection = None + self.channel = None + global _fake_redis_client + _fake_redis_client = None + def create_connection(self, **kwargs): kwargs.setdefault('transport_options', {'fanout_patterns': True}) - return Connection(transport=Transport, **kwargs) + return Connection(transport=FakeRedisKombuTransport, **kwargs) def _get_one_delivery_tag(self, n='test_uniq_tag'): - with self.create_connection() as conn1: - chan = conn1.default_channel - chan.exchange_declare(n) - chan.queue_declare(n) - chan.queue_bind(n, n, n) - msg = chan.prepare_message('quick brown fox') - chan.basic_publish(msg, n, n) - payload = chan._get(n) - assert payload - pymsg = chan.message_to_python(payload) - return pymsg.delivery_tag + chan = self.connection.default_channel + chan.exchange_declare(n) + chan.queue_declare(n) + chan.queue_bind(n, n, n) + msg = chan.prepare_message('quick brown fox') + chan.basic_publish(msg, n, n) + payload = chan._get(n) + assert payload + pymsg = chan.message_to_python(payload) + return pymsg.delivery_tag def test_delivery_tag_is_uuid(self): seen = set() @@ -275,9 +549,10 @@ class test_Channel: assert len(tag) == 36 def test_disable_ack_emulation(self): - conn = Connection(transport=Transport, transport_options={ - 'ack_emulation': False, - }) + conn = Connection( + transport=FakeRedisKombuTransport, + transport_options={'ack_emulation': False} + ) chan = conn.channel() assert not chan.ack_emulation @@ -288,7 +563,7 @@ class test_Channel: pool_at_init = [pool] client = Mock(name='client') - class XChannel(Channel): + class XChannel(FakeRedisKombuChannel): def __init__(self, *args, **kwargs): self._pool = pool_at_init[0] @@ -297,7 +572,12 @@ class test_Channel: def _get_client(self): return lambda *_, **__: client - class XTransport(Transport): + def _create_client(self, asynchronous=False): + if asynchronous: + return self.Client(connection_pool=self.async_pool) + return self.Client(connection_pool=self.pool) + + class XTransport(FakeRedisKombuTransport): Channel = XChannel conn = Connection(transport=XTransport) @@ -404,14 +684,15 @@ class test_Channel: client = client() def pipe(*args, **kwargs): - return Pipeline(client) + return FakePipelineLite(client) + client.pipeline = pipe client.zrevrangebyscore.return_value = [ (1, 10), (2, 20), (3, 30), ] - qos = redis.QoS(self.channel) + qos = kombu_redis.QoS(self.channel) restore = qos.restore_by_tag = Mock(name='restore_by_tag') qos._vrestore_count = 1 qos.restore_visible() @@ -433,14 +714,13 @@ class test_Channel: assert qos._vrestore_count == 1 qos._vrestore_count = 0 - client.setnx.side_effect = redis.MutexHeld() + client.setnx.side_effect = kombu_redis.MutexHeld() qos.restore_visible() def test_basic_consume_when_fanout_queue(self): self.channel.exchange_declare(exchange='txconfan', type='fanout') self.channel.queue_declare(queue='txconfanq') self.channel.queue_bind(queue='txconfanq', exchange='txconfan') - assert 'txconfanq' in self.channel._fanout_queues self.channel.basic_consume('txconfanq', False, None, 1) assert 'txconfanq' in self.channel.active_fanout_queues @@ -566,7 +846,7 @@ class test_Channel: c = self.channel.client = Mock() c.parse_response.return_value = None - with pytest.raises(redis.Empty): + with pytest.raises(kombu_redis.Empty): self.channel._brpop_read() def test_poll_error(self): @@ -605,19 +885,19 @@ class test_Channel: self.channel._put('george', msg1) client().lpush.assert_called_with( - self.channel._q_for_pri('george', 3), dumps(msg1), + self.channel._queue_for_priority('george', 3), dumps(msg1), ) msg2 = {'properties': {'priority': 313}} self.channel._put('george', msg2) client().lpush.assert_called_with( - self.channel._q_for_pri('george', 9), dumps(msg2), + self.channel._queue_for_priority('george', 9), dumps(msg2), ) msg3 = {'properties': {}} self.channel._put('george', msg3) client().lpush.assert_called_with( - self.channel._q_for_pri('george', 0), dumps(msg3), + self.channel._queue_for_priority('george', 0), dumps(msg3), ) def test_delete(self): @@ -629,10 +909,12 @@ class test_Channel: x._delete('queue', 'exchange', 'routing_key', None) delete.assert_has_calls([ - call(x._q_for_pri('queue', pri)) for pri in redis.PRIORITY_STEPS + call(x._queue_for_priority('queue', pri)) + for pri in kombu_redis.PRIORITY_STEPS ]) - srem.assert_called_with(x.keyprefix_queue % ('exchange',), - x.sep.join(['routing_key', '', 'queue'])) + srem.assert_called_with( + x.keyprefix_queue % ('exchange',), + x.sep.join(['routing_key', '', 'queue'])) def test_has_queue(self): self.channel._create_client = Mock() @@ -641,8 +923,8 @@ class test_Channel: exists.return_value = True assert self.channel._has_queue('foo') exists.assert_has_calls([ - call(self.channel._q_for_pri('foo', pri)) - for pri in redis.PRIORITY_STEPS + call(self.channel._queue_for_priority('foo', pri)) + for pri in kombu_redis.PRIORITY_STEPS ]) exists.return_value = False @@ -686,9 +968,11 @@ class test_Channel: assert self.channel._connparams()['db'] == 124 def test_new_queue_with_auto_delete(self): - redis.Channel._new_queue(self.channel, 'george', auto_delete=False) + kombu_redis.Channel._new_queue( + self.channel, 'george', auto_delete=False) assert 'george' not in self.channel.auto_delete_queues - redis.Channel._new_queue(self.channel, 'elaine', auto_delete=True) + kombu_redis.Channel._new_queue( + self.channel, 'elaine', auto_delete=True) assert 'elaine' in self.channel.auto_delete_queues def test_connparams_regular_hostname(self): @@ -731,22 +1015,21 @@ class test_Channel: cycle.rotate('elaine') def test_get_client(self): - import redis as R - KombuRedis = redis.Channel._get_client(self.channel) - assert KombuRedis + kombu_redis_client = kombu_redis.Channel._get_client(self.channel) + assert kombu_redis_client - Rv = getattr(R, 'VERSION', None) + redis_version = getattr(redis, 'VERSION', None) try: - R.VERSION = (2, 4, 0) + redis.VERSION = (2, 4, 0) with pytest.raises(VersionMismatch): - redis.Channel._get_client(self.channel) + kombu_redis.Channel._get_client(self.channel) finally: - if Rv is not None: - R.VERSION = Rv + if redis_version is not None: + redis.VERSION = redis_version def test_get_response_error(self): - from redis.exceptions import ResponseError - assert redis.Channel._get_response_error(self.channel) is ResponseError + kombu_error = kombu_redis.Channel._get_response_error(self.channel) + assert kombu_error is redis.exceptions.ResponseError def test_avail_client(self): self.channel._pool = Mock() @@ -762,7 +1045,7 @@ class test_Channel: conn = Mock(name='conn') conn.client = Mock(name='client', transport_options={}) loop = Mock(name='loop') - redis.Transport.register_with_event_loop(transport, conn, loop) + kombu_redis.Transport.register_with_event_loop(transport, conn, loop) transport.cycle.on_poll_init.assert_called_with(loop.poller) loop.call_repeatedly.assert_has_calls([ call(10, transport.cycle.maybe_restore_messages), @@ -787,7 +1070,7 @@ class test_Channel: 'health_check_interval': 15, }) loop = Mock(name='loop') - redis.Transport.register_with_event_loop(transport, conn, loop) + kombu_redis.Transport.register_with_event_loop(transport, conn, loop) transport.cycle.on_poll_init.assert_called_with(loop.poller) loop.call_repeatedly.assert_has_calls([ call(10, transport.cycle.maybe_restore_messages), @@ -808,34 +1091,20 @@ class test_Channel: cycle = transport.cycle = Mock(name='cyle') cycle.on_readable.return_value = None - redis.Transport.on_readable(transport, 13) + kombu_redis.Transport.on_readable(transport, 13) cycle.on_readable.assert_called_with(13) def test_transport_get_errors(self): - assert redis.Transport._get_errors(self.connection.transport) + assert kombu_redis.Transport._get_errors(self.connection.transport) def test_transport_driver_version(self): - assert redis.Transport.driver_version(self.connection.transport) + assert kombu_redis.Transport.driver_version(self.connection.transport) def test_transport_get_errors_when_InvalidData_used(self): - from redis import exceptions - - class ID(Exception): - pass - - DataError = getattr(exceptions, 'DataError', None) - InvalidData = getattr(exceptions, 'InvalidData', None) - exceptions.InvalidData = ID - exceptions.DataError = None - try: - errors = redis.Transport._get_errors(self.connection.transport) - assert errors - assert ID in errors[1] - finally: - if DataError is not None: - exceptions.DataError = DataError - if InvalidData is not None: - exceptions.InvalidData = InvalidData + errors = kombu_redis.Transport._get_errors( + self.connection.transport) + assert errors + assert redis.exceptions.DataError in errors[1] def test_empty_queues_key(self): channel = self.channel @@ -843,13 +1112,14 @@ class test_Channel: key = channel.keyprefix_queue % 'celery' # Everything is fine, there is a list of queues. - channel.client.sadd(key, 'celery\x06\x16\x06\x16celery') + list_of_queues = 'celery\x06\x16\x06\x16celery' + channel.client.sadd(key, list_of_queues) assert channel.get_table('celery') == [ ('celery', '', 'celery'), ] # ... then for some reason, the _kombu.binding.celery key gets lost - channel.client.srem(key) + channel.client.srem(key, list_of_queues) # which raises a channel error so that the consumer/publisher # can recover by redeclaring the required entities. @@ -862,7 +1132,7 @@ class test_Channel: connparams = conn.default_channel._connparams() assert issubclass( connparams['connection_class'], - redis.redis.UnixDomainSocketConnection, + redis.UnixDomainSocketConnection, ) assert connparams['path'] == '/tmp/redis.sock' @@ -889,7 +1159,7 @@ class test_Channel: connparams = conn.default_channel._connparams() assert issubclass( connparams['connection_class'], - redis.redis.SSLConnection, + redis.SSLConnection, ) def test_rediss_connection(self): @@ -898,87 +1168,49 @@ class test_Channel: connparams = conn.default_channel._connparams() assert issubclass( connparams['connection_class'], - redis.redis.SSLConnection, + redis.SSLConnection, ) def test_sep_transport_option(self): - with Connection(transport=Transport, transport_options={ - 'sep': ':', - }) as conn: + with Connection( + transport=FakeRedisKombuTransport, + transport_options={'sep': ':'} + ) as conn: key = conn.default_channel.keyprefix_queue % 'celery' conn.default_channel.client.sadd(key, 'celery::celery') - assert conn.default_channel.sep == ':' assert conn.default_channel.get_table('celery') == [ ('celery', '', 'celery'), ] -class test_Redis: +@skip.unless_module('redis') +@mock.patch('redis.Connection', FakeRedisConnection) +class test_RedisConnections: def setup(self): - self.connection = Connection(transport=Transport) - self.exchange = Exchange('test_Redis', type='direct') + self.connection = self.create_connection() + self.exchange_name = exchange_name = 'test_Redis' + self.exchange = Exchange(exchange_name, type='direct') self.queue = Queue('test_Redis', self.exchange, 'test_Redis') + self.queue(self.connection.default_channel).declare() + self.real_scard = FakeRedisClient.scard def teardown(self): self.connection.close() + self.queue = None + self.exchange = None + global _fake_redis_client + _fake_redis_client = None + FakeRedisClient.scard = self.real_scard - @mock.replace_module_value(redis.redis, 'VERSION', [3, 0, 0]) - def test_publish__get_redispyv3(self): - channel = self.connection.channel() - producer = Producer(channel, self.exchange, routing_key='test_Redis') - self.queue(channel).declare() - - producer.publish({'hello': 'world'}) - - assert self.queue(channel).get().payload == {'hello': 'world'} - assert self.queue(channel).get() is None - assert self.queue(channel).get() is None - assert self.queue(channel).get() is None - - @mock.replace_module_value(redis.redis, 'VERSION', [2, 5, 10]) - def test_publish__get_redispyv2(self): - channel = self.connection.channel() - producer = Producer(channel, self.exchange, routing_key='test_Redis') - self.queue(channel).declare() - - producer.publish({'hello': 'world'}) - - assert self.queue(channel).get().payload == {'hello': 'world'} - assert self.queue(channel).get() is None - assert self.queue(channel).get() is None - assert self.queue(channel).get() is None - - def test_publish__consume(self): - connection = Connection(transport=Transport) - channel = connection.channel() - producer = Producer(channel, self.exchange, routing_key='test_Redis') - consumer = Consumer(channel, queues=[self.queue]) - - producer.publish({'hello2': 'world2'}) - _received = [] - - def callback(message_data, message): - _received.append(message_data) - message.ack() - - consumer.register_callback(callback) - consumer.consume() - - assert channel in channel.connection.cycle._channels - try: - connection.drain_events(timeout=1) - assert _received - with pytest.raises(socket.timeout): - connection.drain_events(timeout=0.01) - finally: - channel.close() + def create_connection(self, **kwargs): + kwargs.setdefault('transport_options', {'fanout_patterns': True}) + return Connection(transport=FakeRedisKombuTransport, **kwargs) def test_purge(self): - channel = self.connection.channel() + channel = self.connection.default_channel producer = Producer(channel, self.exchange, routing_key='test_Redis') - self.queue(channel).declare() for i in range(10): producer.publish({'hello': f'world-{i}'}) @@ -989,38 +1221,38 @@ class test_Redis: def test_db_values(self): Connection(virtual_host=1, - transport=Transport).channel() + transport=FakeRedisKombuTransport).channel() Connection(virtual_host='1', - transport=Transport).channel() + transport=FakeRedisKombuTransport).channel() Connection(virtual_host='/1', - transport=Transport).channel() + transport=FakeRedisKombuTransport).channel() with pytest.raises(Exception): Connection('redis:///foo').channel() def test_db_port(self): - c1 = Connection(port=None, transport=Transport).channel() + c1 = Connection(port=None, transport=FakeRedisKombuTransport).channel() c1.close() - c2 = Connection(port=9999, transport=Transport).channel() + c2 = Connection(port=9999, transport=FakeRedisKombuTransport).channel() c2.close() def test_close_poller_not_active(self): - c = Connection(transport=Transport).channel() + c = Connection(transport=FakeRedisKombuTransport).channel() cycle = c.connection.cycle c.client.connection c.close() assert c not in cycle._channels def test_close_ResponseError(self): - c = Connection(transport=Transport).channel() + c = Connection(transport=FakeRedisKombuTransport).channel() c.client.bgsave_raises_ResponseError = True c.close() def test_close_disconnects(self): - c = Connection(transport=Transport).channel() + c = Connection(transport=FakeRedisKombuTransport).channel() conn1 = c.client.connection conn2 = c.subclient.connection c.close() @@ -1034,61 +1266,87 @@ class test_Redis: channel.close() def test_get_client(self): - with mock.module_exists(*_redis_modules()): - conn = Connection(transport=Transport) - chan = conn.channel() - assert chan.Client - assert chan.ResponseError - assert conn.transport.connection_errors - assert conn.transport.channel_errors + conn = Connection(transport=FakeRedisKombuTransport) + chan = conn.channel() + assert chan.Client + assert chan.ResponseError + assert conn.transport.connection_errors + assert conn.transport.channel_errors def test_check_at_least_we_try_to_connect_and_fail(self): - import redis connection = Connection('redis://localhost:65534/') with pytest.raises(redis.exceptions.ConnectionError): chan = connection.channel() chan._size('some_queue') - -def _redis_modules(): - - class ConnectionError(Exception): - pass - - class AuthenticationError(Exception): - pass - - class InvalidData(Exception): - pass - - class InvalidResponse(Exception): - pass - - class ResponseError(Exception): - pass - - exceptions = types.ModuleType('redis.exceptions') - exceptions.ConnectionError = ConnectionError - exceptions.AuthenticationError = AuthenticationError - exceptions.InvalidData = InvalidData - exceptions.InvalidResponse = InvalidResponse - exceptions.ResponseError = ResponseError - - class Redis: - pass - - myredis = types.ModuleType('redis') - myredis.exceptions = exceptions - myredis.Redis = Redis - - return myredis, exceptions - - -class test_MultiChannelPoller: + def test_redis_queue_lookup_gets_queue_when_exchange_doesnot_exist(self): + # Given: A test redis client and channel + redis_channel = self.connection.default_channel + fake_redis_client = self.connection.default_channel.client + # Given: The default queue is set: + default_queue = 'default_queue' + redis_channel.deadletter_queue = default_queue + # Determine the routing key + routing_key = redis_channel.keyprefix_queue % self.exchange + fake_redis_client.sadd(routing_key, routing_key) + lookup_queue_result = redis_channel._lookup( + exchange=None, + routing_key=routing_key, + default=default_queue) + assert lookup_queue_result == [routing_key] + + def test_redis_queue_lookup_gets_default_when_route_doesnot_exist(self): + # Given: A test redis client and channel + redis_channel = self.connection.default_channel + fake_redis_client = self.connection.default_channel.client + # Given: The default queue is set: + default_queue = 'default_queue' + redis_channel.deadletter_queue = default_queue + # Determine the routing key + routing_key = redis_channel.keyprefix_queue % self.exchange + fake_redis_client.sadd(routing_key, "DoesNotExist") + lookup_queue_result = redis_channel._lookup( + exchange=None, + routing_key=None, + default=None) + assert lookup_queue_result == [default_queue] + + def test_redis_queue_lookup_client_raises_key_error_gets_default(self): + redis_channel = self.connection.default_channel + fake_redis_client = self.connection.default_channel.client + fake_redis_client.smembers = Mock(side_effect=KeyError) + routing_key = redis_channel.keyprefix_queue % self.exchange + redis_channel.queue_bind(routing_key, self.exchange_name, routing_key) + fake_redis_client.sadd(routing_key, routing_key) + default_queue_name = 'default_queue' + lookup_queue_result = redis_channel._lookup( + exchange=self.exchange_name, + routing_key=routing_key, + default=default_queue_name) + assert lookup_queue_result == [default_queue_name] + + def test_redis_queue_lookup_client_raises_key_error_gets_deadletter(self): + fake_redis_client = self.connection.default_channel.client + fake_redis_client.smembers = Mock(side_effect=KeyError) + redis_channel = self.connection.default_channel + routing_key = redis_channel.keyprefix_queue % self.exchange + redis_channel.queue_bind(routing_key, self.exchange_name, routing_key) + fake_redis_client.sadd(routing_key, routing_key) + default_queue_name = 'deadletter_queue' + redis_channel.deadletter_queue = default_queue_name + lookup_queue_result = redis_channel._lookup( + exchange=self.exchange_name, + routing_key=routing_key, + default=None) + assert lookup_queue_result == [default_queue_name] + + +@skip.unless_module('redis') +class test_KombuRedisMultiChannelPoller: def setup(self): - self.Poller = redis.MultiChannelPoller + self.Poller = kombu_redis.MultiChannelPoller def test_on_poll_start(self): p = self.Poller() @@ -1139,17 +1397,17 @@ class test_MultiChannelPoller: chan.handlers = {'BRPOP': Mock(name='BRPOP')} chan.qos.can_consume.return_value = False - p.handle_event(13, redis.READ) + p.handle_event(13, kombu_redis.READ) chan.handlers['BRPOP'].assert_not_called() chan.qos.can_consume.return_value = True - p.handle_event(13, redis.READ) + p.handle_event(13, kombu_redis.READ) chan.handlers['BRPOP'].assert_called_with() - p.handle_event(13, redis.ERR) + p.handle_event(13, kombu_redis.ERR) chan._poll_error.assert_called_with('BRPOP') - p.handle_event(13, ~(redis.READ | redis.ERR)) + p.handle_event(13, ~(kombu_redis.READ | kombu_redis.ERR)) def test_fds(self): p = self.Poller() @@ -1276,12 +1534,12 @@ class test_MultiChannelPoller: def test_get_no_actions(self): p, channel = self.create_get() - with pytest.raises(redis.Empty): + with pytest.raises(kombu_redis.Empty): p.get(Mock()) def test_qos_reject(self): p, channel = self.create_get() - qos = redis.QoS(channel) + qos = kombu_redis.QoS(channel) qos.ack = Mock(name='Qos.ack') qos.reject(1234) qos.ack.assert_called_with(1234) @@ -1290,7 +1548,7 @@ class test_MultiChannelPoller: p, channel = self.create_get(queues=['a_queue']) channel.qos.can_consume.return_value = True - with pytest.raises(redis.Empty): + with pytest.raises(kombu_redis.Empty): p.get(Mock()) p._register_BRPOP.assert_called_with(channel) @@ -1299,7 +1557,7 @@ class test_MultiChannelPoller: p, channel = self.create_get(queues=['a_queue']) channel.qos.can_consume.return_value = False - with pytest.raises(redis.Empty): + with pytest.raises(kombu_redis.Empty): p.get(Mock()) p._register_BRPOP.assert_not_called() @@ -1307,7 +1565,7 @@ class test_MultiChannelPoller: def test_get_listen(self): p, channel = self.create_get(fanouts=['f_queue']) - with pytest.raises(redis.Empty): + with pytest.raises(kombu_redis.Empty): p.get(Mock()) p._register_LISTEN.assert_called_with(channel) @@ -1316,7 +1574,7 @@ class test_MultiChannelPoller: p, channel = self.create_get(events=[(1, eventio.ERR)]) p._fd_to_chan[1] = (channel, 'BRPOP') - with pytest.raises(redis.Empty): + with pytest.raises(kombu_redis.Empty): p.get(Mock()) channel._poll_error.assert_called_with('BRPOP') @@ -1326,13 +1584,14 @@ class test_MultiChannelPoller: (1, eventio.ERR)]) p._fd_to_chan[1] = (channel, 'BRPOP') - with pytest.raises(redis.Empty): + with pytest.raises(kombu_redis.Empty): p.get(Mock()) channel._poll_error.assert_called_with('BRPOP') -class test_Mutex: +@skip.unless_module('redis') +class test_KombuRedisMutex: def test_mutex(self, lock_id='xxx'): client = Mock(name='client') @@ -1341,7 +1600,7 @@ class test_Mutex: # Won lock.acquire.return_value = True held = False - with redis.Mutex(client, 'foo1', 100): + with kombu_redis.Mutex(client, 'foo1', 100): held = True assert held lock.acquire.assert_called_with(blocking=False) @@ -1353,8 +1612,8 @@ class test_Mutex: # Did not win lock.acquire.return_value = False held = False - with pytest.raises(redis.MutexHeld): - with redis.Mutex(client, 'foo1', 100): + with pytest.raises(kombu_redis.MutexHeld): + with kombu_redis.Mutex(client, 'foo1', 100): held = True assert not held lock.acquire.assert_called_with(blocking=False) @@ -1365,19 +1624,78 @@ class test_Mutex: # Wins but raises LockNotOwnedError (and that is ignored) lock.acquire.return_value = True - lock.release.side_effect = redis.redis.exceptions.LockNotOwnedError() + lock.release.side_effect = redis.exceptions.LockNotOwnedError() held = False - with redis.Mutex(client, 'foo1', 100): + with kombu_redis.Mutex(client, 'foo1', 100): held = True assert held +class test_RedisProducerConsumer: + def setup(self): + self.connection = self.create_connection() + self.channel = self.connection.default_channel + self.routing_key = routing_key = 'test_redis_producer' + self.exchange_name = exchange_name = 'test_redis_producer' + self.exchange = Exchange(exchange_name, type='direct') + self.queue = Queue(routing_key, self.exchange, routing_key) + + self.queue(self.connection.default_channel).declare() + self.channel.queue_bind(routing_key, self.exchange_name, routing_key) + + def create_connection(self, **kwargs): + kwargs.setdefault('transport_options', {'fanout_patterns': True}) + return Connection(transport=FakeRedisKombuTransportLite, **kwargs) + + def teardown(self): + self.connection.close() + + def test_publish__get(self): + channel = self.connection.channel() + producer = Producer(channel, self.exchange, + routing_key=self.routing_key) + self.queue(channel).declare() + + producer.publish({'hello': 'world'}) + + assert self.queue(channel).get().payload == {'hello': 'world'} + assert self.queue(channel).get() is None + assert self.queue(channel).get() is None + assert self.queue(channel).get() is None + + def test_publish__consume(self): + connection = self.create_connection() + channel = connection.default_channel + producer = Producer(channel, self.exchange, + routing_key=self.routing_key) + consumer = Consumer(channel, queues=[self.queue]) + + producer.publish({'hello2': 'world2'}) + _received = [] + + def callback(message_data, message): + _received.append(message_data) + message.ack() + + consumer.register_callback(callback) + consumer.consume() + + assert channel in channel.connection.cycle._channels + try: + connection.drain_events(timeout=1) + assert _received + with pytest.raises(socket.timeout): + connection.drain_events(timeout=0.01) + finally: + channel.close() + + +@skip.unless_module('redis.sentinel') class test_RedisSentinel: def test_method_called(self): - from kombu.transport.redis import SentinelChannel - - with patch.object(SentinelChannel, '_sentinel_managed_pool') as p: + with patch.object(kombu_redis.SentinelChannel, + '_sentinel_managed_pool') as p: connection = Connection( 'sentinel://localhost:65534/', transport_options={ @@ -1445,15 +1763,13 @@ class test_RedisSentinel: master_for().connection_pool.get_connection.assert_called() def test_can_create_connection(self): - from redis.exceptions import ConnectionError - connection = Connection( 'sentinel://localhost:65534/', transport_options={ 'master_name': 'not_important', }, ) - with pytest.raises(ConnectionError): + with pytest.raises(redis.exceptions.ConnectionError): connection.channel() def test_missing_master_name_transport_option(self): |