diff options
Diffstat (limited to 't/unit/transport/test_redis.py')
-rw-r--r-- | t/unit/transport/test_redis.py | 1303 |
1 files changed, 1303 insertions, 0 deletions
diff --git a/t/unit/transport/test_redis.py b/t/unit/transport/test_redis.py new file mode 100644 index 00000000..fd47317f --- /dev/null +++ b/t/unit/transport/test_redis.py @@ -0,0 +1,1303 @@ +from __future__ import absolute_import, unicode_literals + +import pytest +import socket +import types + +from collections import defaultdict +from itertools import count + +from case import ANY, ContextMock, Mock, call, mock, skip, patch + +from kombu import Connection, Exchange, Queue, Consumer, Producer +from kombu.exceptions import InconsistencyError, VersionMismatch +from kombu.five import Empty, Queue as _Queue, bytes_if_py2 +from kombu.transport import virtual +from kombu.utils import eventio # patch poll +from kombu.utils.json import dumps + + +class _poll(eventio._select): + + def register(self, fd, flags): + if flags & eventio.READ: + self._rfd.add(fd) + + def poll(self, timeout): + events = [] + for fd in self._rfd: + if fd.data: + events.append((fd.fileno(), eventio.READ)) + return events + + +eventio.poll = _poll +# must import after poller patch, pep8 complains +from kombu.transport import redis # noqa + + +class ResponseError(Exception): + pass + + +class Client(object): + queues = {} + sets = defaultdict(set) + hashes = defaultdict(dict) + shard_hint = None + + def __init__(self, db=None, port=None, connection_pool=None, **kwargs): + self._called = [] + self._connection = None + self.bgsave_raises_ResponseError = False + self.connection = self._sconnection(self) + + def bgsave(self): + self._called.append('BGSAVE') + if self.bgsave_raises_ResponseError: + raise ResponseError() + + def delete(self, key): + self.queues.pop(key, None) + + def exists(self, key): + return key in self.queues or key in self.sets + + def hset(self, key, k, v): + self.hashes[key][k] = v + + def hget(self, key, k): + return self.hashes[key].get(k) + + def hdel(self, key, k): + self.hashes[key].pop(k, None) + + def sadd(self, key, member, *args): + self.sets[key].add(member) + + def zadd(self, key, score1, member1, *args): + self.sets[key].add(member1) + + def smembers(self, key): + return self.sets.get(key, set()) + + def ping(self, *args, **kwargs): + return True + + def srem(self, key, *args): + self.sets.pop(key, None) + zrem = srem + + def llen(self, key): + try: + return self.queues[key].qsize() + except KeyError: + return 0 + + def lpush(self, key, value): + self.queues[key].put_nowait(value) + + def parse_response(self, connection, type, **options): + cmd, queues = self.connection._sock.data.pop() + queues = list(queues) + assert cmd == type + self.connection._sock.data = [] + if type == 'BRPOP': + timeout = queues.pop() + item = self.brpop(queues, timeout) + if item: + return item + raise Empty() + + def brpop(self, keys, timeout=None): + for key in keys: + try: + item = self.queues[key].get_nowait() + except Empty: + pass + else: + return key, item + + def rpop(self, key): + try: + return self.queues[key].get_nowait() + except (KeyError, Empty): + pass + + def __contains__(self, k): + return k in self._called + + def pipeline(self): + return Pipeline(self) + + def encode(self, value): + return str(value) + + def _new_queue(self, key): + self.queues[key] = _Queue() + + class _sconnection(object): + disconnected = False + + class _socket(object): + blocking = True + filenos = count(30) + + def __init__(self, *args): + self._fileno = next(self.filenos) + self.data = [] + + def fileno(self): + return self._fileno + + def setblocking(self, blocking): + self.blocking = blocking + + def __init__(self, client): + self.client = client + self._sock = self._socket() + + def disconnect(self): + self.disconnected = True + + def send_command(self, cmd, *args): + self._sock.data.append((cmd, args)) + + def info(self): + return {'foo': 1} + + def pubsub(self, *args, **kwargs): + connection = self.connection + + class ConnectionPool(object): + + def get_connection(self, *args, **kwargs): + return connection + self.connection_pool = ConnectionPool() + + return self + + +class Pipeline(object): + + def __init__(self, client): + self.client = client + self.stack = [] + + def __enter__(self): + return self + + def __exit__(self, *exc_info): + pass + + def __getattr__(self, key): + if key not in self.__dict__: + + def _add(*args, **kwargs): + self.stack.append((getattr(self.client, key), args, kwargs)) + return self + + return _add + return self.__dict__[key] + + def execute(self): + stack = list(self.stack) + self.stack[:] = [] + return [fun(*args, **kwargs) for fun, args, kwargs in stack] + + +class Channel(redis.Channel): + + def _get_client(self): + return Client + + def _get_pool(self, async=False): + return Mock() + + def _get_response_error(self): + return ResponseError + + def _new_queue(self, queue, **kwargs): + for pri in self.priority_steps: + self.client._new_queue(self._q_for_pri(queue, pri)) + + def pipeline(self): + return Pipeline(Client()) + + +class Transport(redis.Transport): + Channel = Channel + + def _get_errors(self): + return ((KeyError,), (IndexError,)) + + +@skip.unless_module('redis') +class test_Channel: + + def setup(self): + self.connection = self.create_connection() + self.channel = self.connection.default_channel + + def create_connection(self, **kwargs): + kwargs.setdefault('transport_options', {'fanout_patterns': True}) + return Connection(transport=Transport, **kwargs) + + def _get_one_delivery_tag(self, n='test_uniq_tag'): + with self.create_connection() as conn1: + chan = conn1.default_channel + chan.exchange_declare(n) + chan.queue_declare(n) + chan.queue_bind(n, n, n) + msg = chan.prepare_message('quick brown fox') + chan.basic_publish(msg, n, n) + payload = chan._get(n) + assert payload + pymsg = chan.message_to_python(payload) + return pymsg.delivery_tag + + def test_delivery_tag_is_uuid(self): + seen = set() + for i in range(100): + tag = self._get_one_delivery_tag() + assert tag not in seen + seen.add(tag) + with pytest.raises(ValueError): + int(tag) + assert len(tag) == 36 + + def test_disable_ack_emulation(self): + conn = Connection(transport=Transport, transport_options={ + 'ack_emulation': False, + }) + + chan = conn.channel() + assert not chan.ack_emulation + assert chan.QoS == virtual.QoS + + def test_redis_ping_raises(self): + pool = Mock(name='pool') + pool_at_init = [pool] + client = Mock(name='client') + + class XChannel(Channel): + + def __init__(self, *args, **kwargs): + self._pool = pool_at_init[0] + super(XChannel, self).__init__(*args, **kwargs) + + def _get_client(self): + return lambda *_, **__: client + + class XTransport(Transport): + Channel = XChannel + + conn = Connection(transport=XTransport) + client.ping.side_effect = RuntimeError() + with pytest.raises(RuntimeError): + conn.channel() + pool.disconnect.assert_called_with() + pool.disconnect.reset_mock() + + pool_at_init = [None] + with pytest.raises(RuntimeError): + conn.channel() + pool.disconnect.assert_not_called() + + def test_after_fork(self): + self.channel._pool = None + self.channel._after_fork() + + pool = self.channel._pool = Mock(name='pool') + self.channel._after_fork() + pool.disconnect.assert_called_with() + + def test_next_delivery_tag(self): + assert (self.channel._next_delivery_tag() != + self.channel._next_delivery_tag()) + + def test_do_restore_message(self): + client = Mock(name='client') + pl1 = {'body': 'BODY'} + spl1 = dumps(pl1) + lookup = self.channel._lookup = Mock(name='_lookup') + lookup.return_value = {'george', 'elaine'} + self.channel._do_restore_message( + pl1, 'ex', 'rkey', client, + ) + client.rpush.assert_has_calls([ + call('george', spl1), call('elaine', spl1), + ], any_order=True) + + client = Mock(name='client') + pl2 = {'body': 'BODY2', 'headers': {'x-funny': 1}} + headers_after = dict(pl2['headers'], redelivered=True) + spl2 = dumps(dict(pl2, headers=headers_after)) + self.channel._do_restore_message( + pl2, 'ex', 'rkey', client, + ) + client.rpush.assert_any_call('george', spl2) + client.rpush.assert_any_call('elaine', spl2) + + client.rpush.side_effect = KeyError() + with patch('kombu.transport.redis.crit') as crit: + self.channel._do_restore_message( + pl2, 'ex', 'rkey', client, + ) + crit.assert_called() + + def test_restore(self): + message = Mock(name='message') + with patch('kombu.transport.redis.loads') as loads: + loads.return_value = 'M', 'EX', 'RK' + client = self.channel._create_client = Mock(name='client') + client = client() + client.pipeline = ContextMock() + restore = self.channel._do_restore_message = Mock( + name='_do_restore_message', + ) + pipe = client.pipeline.return_value + pipe_hget = Mock(name='pipe.hget') + pipe.hget.return_value = pipe_hget + pipe_hget_hdel = Mock(name='pipe.hget.hdel') + pipe_hget.hdel.return_value = pipe_hget_hdel + result = Mock(name='result') + pipe_hget_hdel.execute.return_value = None, None + + self.channel._restore(message) + client.pipeline.assert_called_with() + unacked_key = self.channel.unacked_key + loads.assert_not_called() + + tag = message.delivery_tag + pipe.hget.assert_called_with(unacked_key, tag) + pipe_hget.hdel.assert_called_with(unacked_key, tag) + pipe_hget_hdel.execute.assert_called_with() + + pipe_hget_hdel.execute.return_value = result, None + self.channel._restore(message) + loads.assert_called_with(result) + restore.assert_called_with('M', 'EX', 'RK', client, False) + + def test_qos_restore_visible(self): + client = self.channel._create_client = Mock(name='client') + client = client() + + def pipe(*args, **kwargs): + return Pipeline(client) + client.pipeline = pipe + client.zrevrangebyscore.return_value = [ + (1, 10), + (2, 20), + (3, 30), + ] + qos = redis.QoS(self.channel) + restore = qos.restore_by_tag = Mock(name='restore_by_tag') + qos._vrestore_count = 1 + qos.restore_visible() + client.zrevrangebyscore.assert_not_called() + assert qos._vrestore_count == 2 + + qos._vrestore_count = 0 + qos.restore_visible() + restore.assert_has_calls([ + call(1, client), call(2, client), call(3, client), + ]) + assert qos._vrestore_count == 1 + + qos._vrestore_count = 0 + restore.reset_mock() + client.zrevrangebyscore.return_value = [] + qos.restore_visible() + restore.assert_not_called() + assert qos._vrestore_count == 1 + + qos._vrestore_count = 0 + client.setnx.side_effect = redis.MutexHeld() + qos.restore_visible() + + def test_basic_consume_when_fanout_queue(self): + self.channel.exchange_declare(exchange='txconfan', type='fanout') + self.channel.queue_declare(queue='txconfanq') + self.channel.queue_bind(queue='txconfanq', exchange='txconfan') + + assert 'txconfanq' in self.channel._fanout_queues + self.channel.basic_consume('txconfanq', False, None, 1) + assert 'txconfanq' in self.channel.active_fanout_queues + assert self.channel._fanout_to_queue.get('txconfan') == 'txconfanq' + + def test_basic_cancel_unknown_delivery_tag(self): + assert self.channel.basic_cancel('txaseqwewq') is None + + def test_subscribe_no_queues(self): + self.channel.subclient = Mock() + self.channel.active_fanout_queues.clear() + self.channel._subscribe() + self.channel.subclient.subscribe.assert_not_called() + + def test_subscribe(self): + self.channel.subclient = Mock() + self.channel.active_fanout_queues.add('a') + self.channel.active_fanout_queues.add('b') + self.channel._fanout_queues.update(a=('a', ''), b=('b', '')) + + self.channel._subscribe() + self.channel.subclient.psubscribe.assert_called() + s_args, _ = self.channel.subclient.psubscribe.call_args + assert sorted(s_args[0]) == ['/{db}.a', '/{db}.b'] + + self.channel.subclient.connection._sock = None + self.channel._subscribe() + self.channel.subclient.connection.connect.assert_called_with() + + def test_handle_unsubscribe_message(self): + s = self.channel.subclient + s.subscribed = True + self.channel._handle_message(s, ['unsubscribe', 'a', 0]) + assert not s.subscribed + + def test_handle_pmessage_message(self): + res = self.channel._handle_message( + self.channel.subclient, + ['pmessage', 'pattern', 'channel', 'data'], + ) + assert res == { + 'type': 'pmessage', + 'pattern': 'pattern', + 'channel': 'channel', + 'data': 'data', + } + + def test_handle_message(self): + res = self.channel._handle_message( + self.channel.subclient, + ['type', 'channel', 'data'], + ) + assert res == { + 'type': 'type', + 'pattern': None, + 'channel': 'channel', + 'data': 'data', + } + + def test_brpop_start_but_no_queues(self): + assert self.channel._brpop_start() is None + + def test_receive(self): + s = self.channel.subclient = Mock() + self.channel._fanout_to_queue['a'] = 'b' + s.parse_response.return_value = ['message', 'a', + dumps({'hello': 'world'})] + payload, queue = self.channel._receive() + assert payload == {'hello': 'world'} + assert queue == 'b' + + def test_receive_raises_for_connection_error(self): + self.channel._in_listen = True + s = self.channel.subclient = Mock() + s.parse_response.side_effect = KeyError('foo') + + with pytest.raises(KeyError): + self.channel._receive() + assert not self.channel._in_listen + + def test_receive_empty(self): + s = self.channel.subclient = Mock() + s.parse_response.return_value = None + + with pytest.raises(redis.Empty): + self.channel._receive() + + def test_receive_different_message_Type(self): + s = self.channel.subclient = Mock() + s.parse_response.return_value = ['message', '/foo/', 0, 'data'] + + with pytest.raises(redis.Empty): + self.channel._receive() + + def test_brpop_read_raises(self): + c = self.channel.client = Mock() + c.parse_response.side_effect = KeyError('foo') + + with pytest.raises(KeyError): + self.channel._brpop_read() + + c.connection.disconnect.assert_called_with() + + def test_brpop_read_gives_None(self): + c = self.channel.client = Mock() + c.parse_response.return_value = None + + with pytest.raises(redis.Empty): + self.channel._brpop_read() + + def test_poll_error(self): + c = self.channel.client = Mock() + c.parse_response = Mock() + self.channel._poll_error('BRPOP') + + c.parse_response.assert_called_with(c.connection, 'BRPOP') + + c.parse_response.side_effect = KeyError('foo') + with pytest.raises(KeyError): + self.channel._poll_error('BRPOP') + + def test_poll_error_on_type_LISTEN(self): + c = self.channel.subclient = Mock() + c.parse_response = Mock() + self.channel._poll_error('LISTEN') + + c.parse_response.assert_called_with() + + c.parse_response.side_effect = KeyError('foo') + with pytest.raises(KeyError): + self.channel._poll_error('LISTEN') + + def test_put_fanout(self): + self.channel._in_poll = False + c = self.channel._create_client = Mock() + + body = {'hello': 'world'} + self.channel._put_fanout('exchange', body, '') + c().publish.assert_called_with('/{db}.exchange', dumps(body)) + + def test_put_priority(self): + client = self.channel._create_client = Mock(name='client') + msg1 = {'properties': {'priority': 3}} + + self.channel._put('george', msg1) + client().lpush.assert_called_with( + self.channel._q_for_pri('george', 6), dumps(msg1), + ) + + msg2 = {'properties': {'priority': 313}} + self.channel._put('george', msg2) + client().lpush.assert_called_with( + self.channel._q_for_pri('george', 0), dumps(msg2), + ) + + msg3 = {'properties': {}} + self.channel._put('george', msg3) + client().lpush.assert_called_with( + self.channel._q_for_pri('george', 9), dumps(msg3), + ) + + def test_delete(self): + x = self.channel + x._create_client = Mock() + x._create_client.return_value = x.client + delete = x.client.delete = Mock() + srem = x.client.srem = Mock() + + x._delete('queue', 'exchange', 'routing_key', None) + delete.assert_has_calls([ + call(x._q_for_pri('queue', pri)) for pri in redis.PRIORITY_STEPS + ]) + srem.assert_called_with(x.keyprefix_queue % ('exchange',), + x.sep.join(['routing_key', '', 'queue'])) + + def test_has_queue(self): + self.channel._create_client = Mock() + self.channel._create_client.return_value = self.channel.client + exists = self.channel.client.exists = Mock() + exists.return_value = True + assert self.channel._has_queue('foo') + exists.assert_has_calls([ + call(self.channel._q_for_pri('foo', pri)) + for pri in redis.PRIORITY_STEPS + ]) + + exists.return_value = False + assert not self.channel._has_queue('foo') + + def test_close_when_closed(self): + self.channel.closed = True + self.channel.close() + + def test_close_deletes_autodelete_fanout_queues(self): + self.channel._fanout_queues = {'foo': ('foo', ''), 'bar': ('bar', '')} + self.channel.auto_delete_queues = ['foo'] + self.channel.queue_delete = Mock(name='queue_delete') + + client = self.channel.client + self.channel.close() + self.channel.queue_delete.assert_has_calls([ + call('foo', client=client), + ]) + + def test_close_client_close_raises(self): + c = self.channel.client = Mock() + c.connection.disconnect.side_effect = self.channel.ResponseError() + + self.channel.close() + c.connection.disconnect.assert_called_with() + + def test_invalid_database_raises_ValueError(self): + + with pytest.raises(ValueError): + self.channel.connection.client.virtual_host = 'dwqeq' + self.channel._connparams() + + def test_connparams_allows_slash_in_db(self): + self.channel.connection.client.virtual_host = '/123' + assert self.channel._connparams()['db'] == 123 + + def test_connparams_db_can_be_int(self): + self.channel.connection.client.virtual_host = 124 + assert self.channel._connparams()['db'] == 124 + + def test_new_queue_with_auto_delete(self): + redis.Channel._new_queue(self.channel, 'george', auto_delete=False) + assert 'george' not in self.channel.auto_delete_queues + redis.Channel._new_queue(self.channel, 'elaine', auto_delete=True) + assert 'elaine' in self.channel.auto_delete_queues + + def test_connparams_regular_hostname(self): + self.channel.connection.client.hostname = 'george.vandelay.com' + assert self.channel._connparams()['host'] == 'george.vandelay.com' + + def test_rotate_cycle_ValueError(self): + cycle = self.channel._queue_cycle + cycle.update(['kramer', 'jerry']) + cycle.rotate('kramer') + assert cycle.items, ['jerry' == 'kramer'] + cycle.rotate('elaine') + + def test_get_client(self): + import redis as R + KombuRedis = redis.Channel._get_client(self.channel) + assert KombuRedis + + Rv = getattr(R, 'VERSION', None) + try: + R.VERSION = (2, 4, 0) + with pytest.raises(VersionMismatch): + redis.Channel._get_client(self.channel) + finally: + if Rv is not None: + R.VERSION = Rv + + def test_get_response_error(self): + from redis.exceptions import ResponseError + assert redis.Channel._get_response_error(self.channel) is ResponseError + + def test_avail_client(self): + self.channel._pool = Mock() + cc = self.channel._create_client = Mock() + client = cc.return_value = Mock() + + with self.channel.conn_or_acquire(): + pass + self.channel.pool.release.assert_called_with(client.connection) + cc.assert_called_with() + + def test_register_with_event_loop(self): + transport = self.connection.transport + transport.cycle = Mock(name='cycle') + transport.cycle.fds = {12: 'LISTEN', 13: 'BRPOP'} + conn = Mock(name='conn') + loop = Mock(name='loop') + redis.Transport.register_with_event_loop(transport, conn, loop) + transport.cycle.on_poll_init.assert_called_with(loop.poller) + loop.call_repeatedly.assert_called_with( + 10, transport.cycle.maybe_restore_messages, + ) + loop.on_tick.add.assert_called() + on_poll_start = loop.on_tick.add.call_args[0][0] + + on_poll_start() + transport.cycle.on_poll_start.assert_called_with() + loop.add_reader.assert_has_calls([ + call(12, transport.on_readable, 12), + call(13, transport.on_readable, 13), + ]) + + def test_transport_on_readable(self): + transport = self.connection.transport + cycle = transport.cycle = Mock(name='cyle') + cycle.on_readable.return_value = None + + redis.Transport.on_readable(transport, 13) + cycle.on_readable.assert_called_with(13) + cycle.on_readable.reset_mock() + + queue = Mock(name='queue') + ret = (Mock(name='message'), queue) + cycle.on_readable.return_value = ret + transport._reject_inbound_message = Mock(name='_reject_inbound') + redis.Transport.on_readable(transport, 14) + transport._reject_inbound_message.assert_called_with(ret[0]) + + cb = transport._callbacks[queue] = Mock(name='callback') + redis.Transport.on_readable(transport, 14) + cb.assert_called_with(ret[0]) + + def test_transport_get_errors(self): + assert redis.Transport._get_errors(self.connection.transport) + + def test_transport_driver_version(self): + assert redis.Transport.driver_version(self.connection.transport) + + def test_transport_get_errors_when_InvalidData_used(self): + from redis import exceptions + + class ID(Exception): + pass + + DataError = getattr(exceptions, 'DataError', None) + InvalidData = getattr(exceptions, 'InvalidData', None) + exceptions.InvalidData = ID + exceptions.DataError = None + try: + errors = redis.Transport._get_errors(self.connection.transport) + assert errors + assert ID in errors[1] + finally: + if DataError is not None: + exceptions.DataError = DataError + if InvalidData is not None: + exceptions.InvalidData = InvalidData + + def test_empty_queues_key(self): + channel = self.channel + channel._in_poll = False + key = channel.keyprefix_queue % 'celery' + + # Everything is fine, there is a list of queues. + channel.client.sadd(key, 'celery\x06\x16\x06\x16celery') + assert channel.get_table('celery') == [ + ('celery', '', 'celery'), + ] + + # ... then for some reason, the _kombu.binding.celery key gets lost + channel.client.srem(key) + + # which raises a channel error so that the consumer/publisher + # can recover by redeclaring the required entities. + with pytest.raises(InconsistencyError): + self.channel.get_table('celery') + + def test_socket_connection(self): + with patch('kombu.transport.redis.Channel._create_client'): + with Connection('redis+socket:///tmp/redis.sock') as conn: + connparams = conn.default_channel._connparams() + assert issubclass( + connparams['connection_class'], + redis.redis.UnixDomainSocketConnection, + ) + assert connparams['path'] == '/tmp/redis.sock' + + def test_ssl_argument__dict(self): + with patch('kombu.transport.redis.Channel._create_client'): + with Connection('redis://', ssl={'ca_cert': '/foo'}) as conn: + connparams = conn.default_channel._connparams() + assert connparams['ssl'] + assert connparams['ca_cert'] == '/foo' + + def test_ssl_argument__bool(self): + with patch('kombu.transport.redis.Channel._create_client'): + with Connection('redis://', ssl=True) as conn: + connparams = conn.default_channel._connparams() + assert connparams['ssl'] + + +@skip.unless_module('redis') +class test_Redis: + + def setup(self): + self.connection = Connection(transport=Transport) + self.exchange = Exchange('test_Redis', type='direct') + self.queue = Queue('test_Redis', self.exchange, 'test_Redis') + + def teardown(self): + self.connection.close() + + def test_publish__get(self): + channel = self.connection.channel() + producer = Producer(channel, self.exchange, routing_key='test_Redis') + self.queue(channel).declare() + + producer.publish({'hello': 'world'}) + + assert self.queue(channel).get().payload == {'hello': 'world'} + assert self.queue(channel).get() is None + assert self.queue(channel).get() is None + assert self.queue(channel).get() is None + + def test_publish__consume(self): + connection = Connection(transport=Transport) + channel = connection.channel() + producer = Producer(channel, self.exchange, routing_key='test_Redis') + consumer = Consumer(channel, queues=[self.queue]) + + producer.publish({'hello2': 'world2'}) + _received = [] + + def callback(message_data, message): + _received.append(message_data) + message.ack() + + consumer.register_callback(callback) + consumer.consume() + + assert channel in channel.connection.cycle._channels + try: + connection.drain_events(timeout=1) + assert _received + with pytest.raises(socket.timeout): + connection.drain_events(timeout=0.01) + finally: + channel.close() + + def test_purge(self): + channel = self.connection.channel() + producer = Producer(channel, self.exchange, routing_key='test_Redis') + self.queue(channel).declare() + + for i in range(10): + producer.publish({'hello': 'world-%s' % (i,)}) + + assert channel._size('test_Redis') == 10 + assert self.queue(channel).purge() == 10 + channel.close() + + def test_db_values(self): + Connection(virtual_host=1, + transport=Transport).channel() + + Connection(virtual_host='1', + transport=Transport).channel() + + Connection(virtual_host='/1', + transport=Transport).channel() + + with pytest.raises(Exception): + Connection('redis:///foo').channel() + + def test_db_port(self): + c1 = Connection(port=None, transport=Transport).channel() + c1.close() + + c2 = Connection(port=9999, transport=Transport).channel() + c2.close() + + def test_close_poller_not_active(self): + c = Connection(transport=Transport).channel() + cycle = c.connection.cycle + c.client.connection + c.close() + assert c not in cycle._channels + + def test_close_ResponseError(self): + c = Connection(transport=Transport).channel() + c.client.bgsave_raises_ResponseError = True + c.close() + + def test_close_disconnects(self): + c = Connection(transport=Transport).channel() + conn1 = c.client.connection + conn2 = c.subclient.connection + c.close() + assert conn1.disconnected + assert conn2.disconnected + + def test_get__Empty(self): + channel = self.connection.channel() + with pytest.raises(Empty): + channel._get('does-not-exist') + channel.close() + + def test_get_client(self): + with mock.module_exists(*_redis_modules()): + conn = Connection(transport=Transport) + chan = conn.channel() + assert chan.Client + assert chan.ResponseError + assert conn.transport.connection_errors + assert conn.transport.channel_errors + + def test_check_at_least_we_try_to_connect_and_fail(self): + import redis + connection = Connection('redis://localhost:65534/') + + with pytest.raises(redis.exceptions.ConnectionError): + chan = connection.channel() + chan._size('some_queue') + + +def _redis_modules(): + + class ConnectionError(Exception): + pass + + class AuthenticationError(Exception): + pass + + class InvalidData(Exception): + pass + + class InvalidResponse(Exception): + pass + + class ResponseError(Exception): + pass + + exceptions = types.ModuleType(bytes_if_py2('redis.exceptions')) + exceptions.ConnectionError = ConnectionError + exceptions.AuthenticationError = AuthenticationError + exceptions.InvalidData = InvalidData + exceptions.InvalidResponse = InvalidResponse + exceptions.ResponseError = ResponseError + + class Redis(object): + pass + + myredis = types.ModuleType(bytes_if_py2('redis')) + myredis.exceptions = exceptions + myredis.Redis = Redis + + return myredis, exceptions + + +@skip.unless_module('redis') +class test_MultiChannelPoller: + + def setup(self): + self.Poller = redis.MultiChannelPoller + + def test_on_poll_start(self): + p = self.Poller() + p._channels = [] + p.on_poll_start() + p._register_BRPOP = Mock(name='_register_BRPOP') + p._register_LISTEN = Mock(name='_register_LISTEN') + + chan1 = Mock(name='chan1') + p._channels = [chan1] + chan1.active_queues = [] + chan1.active_fanout_queues = [] + p.on_poll_start() + + chan1.active_queues = ['q1'] + chan1.active_fanout_queues = ['q2'] + chan1.qos.can_consume.return_value = False + + p.on_poll_start() + p._register_LISTEN.assert_called_with(chan1) + p._register_BRPOP.assert_not_called() + + chan1.qos.can_consume.return_value = True + p._register_LISTEN.reset_mock() + p.on_poll_start() + + p._register_BRPOP.assert_called_with(chan1) + p._register_LISTEN.assert_called_with(chan1) + + def test_on_poll_init(self): + p = self.Poller() + chan1 = Mock(name='chan1') + p._channels = [] + poller = Mock(name='poller') + p.on_poll_init(poller) + assert p.poller is poller + + p._channels = [chan1] + p.on_poll_init(poller) + chan1.qos.restore_visible.assert_called_with( + num=chan1.unacked_restore_limit, + ) + + def test_handle_event(self): + p = self.Poller() + chan = Mock(name='chan') + p._fd_to_chan[13] = chan, 'BRPOP' + chan.handlers = {'BRPOP': Mock(name='BRPOP')} + + chan.qos.can_consume.return_value = False + p.handle_event(13, redis.READ) + chan.handlers['BRPOP'].assert_not_called() + + chan.qos.can_consume.return_value = True + p.handle_event(13, redis.READ) + chan.handlers['BRPOP'].assert_called_with() + + p.handle_event(13, redis.ERR) + chan._poll_error.assert_called_with('BRPOP') + + p.handle_event(13, ~(redis.READ | redis.ERR)) + + def test_fds(self): + p = self.Poller() + p._fd_to_chan = {1: 2} + assert p.fds == p._fd_to_chan + + def test_close_unregisters_fds(self): + p = self.Poller() + poller = p.poller = Mock() + p._chan_to_sock.update({1: 1, 2: 2, 3: 3}) + + p.close() + + assert poller.unregister.call_count == 3 + u_args = poller.unregister.call_args_list + + assert sorted(u_args) == [ + ((1,), {}), + ((2,), {}), + ((3,), {}), + ] + + def test_close_when_unregister_raises_KeyError(self): + p = self.Poller() + p.poller = Mock() + p._chan_to_sock.update({1: 1}) + p.poller.unregister.side_effect = KeyError(1) + p.close() + + def test_close_resets_state(self): + p = self.Poller() + p.poller = Mock() + p._channels = Mock() + p._fd_to_chan = Mock() + p._chan_to_sock = Mock() + + p._chan_to_sock.itervalues.return_value = [] + p._chan_to_sock.values.return_value = [] # py3k + + p.close() + p._channels.clear.assert_called_with() + p._fd_to_chan.clear.assert_called_with() + p._chan_to_sock.clear.assert_called_with() + + def test_register_when_registered_reregisters(self): + p = self.Poller() + p.poller = Mock() + channel, client, type = Mock(), Mock(), Mock() + sock = client.connection._sock = Mock() + sock.fileno.return_value = 10 + + p._chan_to_sock = {(channel, client, type): 6} + p._register(channel, client, type) + p.poller.unregister.assert_called_with(6) + assert p._fd_to_chan[10] == (channel, type) + assert p._chan_to_sock[(channel, client, type)] == sock + p.poller.register.assert_called_with(sock, p.eventflags) + + # when client not connected yet + client.connection._sock = None + + def after_connected(): + client.connection._sock = Mock() + client.connection.connect.side_effect = after_connected + + p._register(channel, client, type) + client.connection.connect.assert_called_with() + + def test_register_BRPOP(self): + p = self.Poller() + channel = Mock() + channel.client.connection._sock = None + p._register = Mock() + + channel._in_poll = False + p._register_BRPOP(channel) + assert channel._brpop_start.call_count == 1 + assert p._register.call_count == 1 + + channel.client.connection._sock = Mock() + p._chan_to_sock[(channel, channel.client, 'BRPOP')] = True + channel._in_poll = True + p._register_BRPOP(channel) + assert channel._brpop_start.call_count == 1 + assert p._register.call_count == 1 + + def test_register_LISTEN(self): + p = self.Poller() + channel = Mock() + channel.subclient.connection._sock = None + channel._in_listen = False + p._register = Mock() + + p._register_LISTEN(channel) + p._register.assert_called_with(channel, channel.subclient, 'LISTEN') + assert p._register.call_count == 1 + assert channel._subscribe.call_count == 1 + + channel._in_listen = True + p._chan_to_sock[(channel, channel.subclient, 'LISTEN')] = 3 + channel.subclient.connection._sock = Mock() + p._register_LISTEN(channel) + assert p._register.call_count == 1 + assert channel._subscribe.call_count == 1 + + def create_get(self, events=None, queues=None, fanouts=None): + _pr = [] if events is None else events + _aq = [] if queues is None else queues + _af = [] if fanouts is None else fanouts + p = self.Poller() + p.poller = Mock() + p.poller.poll.return_value = _pr + + p._register_BRPOP = Mock() + p._register_LISTEN = Mock() + + channel = Mock() + p._channels = [channel] + channel.active_queues = _aq + channel.active_fanout_queues = _af + + return p, channel + + def test_get_no_actions(self): + p, channel = self.create_get() + + with pytest.raises(redis.Empty): + p.get() + + def test_qos_reject(self): + p, channel = self.create_get() + qos = redis.QoS(channel) + qos.ack = Mock(name='Qos.ack') + qos.reject(1234) + qos.ack.assert_called_with(1234) + + def test_get_brpop_qos_allow(self): + p, channel = self.create_get(queues=['a_queue']) + channel.qos.can_consume.return_value = True + + with pytest.raises(redis.Empty): + p.get() + + p._register_BRPOP.assert_called_with(channel) + + def test_get_brpop_qos_disallow(self): + p, channel = self.create_get(queues=['a_queue']) + channel.qos.can_consume.return_value = False + + with pytest.raises(redis.Empty): + p.get() + + p._register_BRPOP.assert_not_called() + + def test_get_listen(self): + p, channel = self.create_get(fanouts=['f_queue']) + + with pytest.raises(redis.Empty): + p.get() + + p._register_LISTEN.assert_called_with(channel) + + def test_get_receives_ERR(self): + p, channel = self.create_get(events=[(1, eventio.ERR)]) + p._fd_to_chan[1] = (channel, 'BRPOP') + + with pytest.raises(redis.Empty): + p.get() + + channel._poll_error.assert_called_with('BRPOP') + + def test_get_receives_multiple(self): + p, channel = self.create_get(events=[(1, eventio.ERR), + (1, eventio.ERR)]) + p._fd_to_chan[1] = (channel, 'BRPOP') + + with pytest.raises(redis.Empty): + p.get() + + channel._poll_error.assert_called_with('BRPOP') + + +@skip.unless_module('redis') +class test_Mutex: + + def test_mutex(self, lock_id='xxx'): + client = Mock(name='client') + with patch('kombu.transport.redis.uuid') as uuid: + # Won + uuid.return_value = lock_id + client.setnx.return_value = True + client.pipeline = ContextMock() + pipe = client.pipeline.return_value + pipe.get.return_value = lock_id + held = False + with redis.Mutex(client, 'foo1', 100): + held = True + assert held + client.setnx.assert_called_with('foo1', lock_id) + pipe.get.return_value = 'yyy' + held = False + with redis.Mutex(client, 'foo1', 100): + held = True + assert held + + # Did not win + client.expire.reset_mock() + pipe.get.return_value = lock_id + client.setnx.return_value = False + with pytest.raises(redis.MutexHeld): + held = False + with redis.Mutex(client, 'foo1', '100'): + held = True + assert not held + client.ttl.return_value = 0 + with pytest.raises(redis.MutexHeld): + held = False + with redis.Mutex(client, 'foo1', '100'): + held = True + assert not held + client.expire.assert_called() + + # Wins but raises WatchError (and that is ignored) + client.setnx.return_value = True + pipe.watch.side_effect = redis.redis.WatchError() + held = False + with redis.Mutex(client, 'foo1', 100): + held = True + assert held + + +@skip.unless_module('redis.sentinel') +class test_RedisSentinel: + + def test_method_called(self): + from kombu.transport.redis import SentinelChannel + + with patch.object(SentinelChannel, '_sentinel_managed_pool') as p: + connection = Connection( + 'sentinel://localhost:65534/', + transport_options={ + 'master_name': 'not_important', + }, + ) + + connection.channel() + p.assert_called() + + def test_getting_master_from_sentinel(self): + with patch('redis.sentinel.Sentinel') as patched: + connection = Connection( + 'sentinel://localhost:65534/', + transport_options={ + 'master_name': 'not_important', + }, + ) + + connection.channel() + assert patched + + master_for = patched.return_value.master_for + master_for.assert_called() + master_for.assert_called_with('not_important', ANY) + master_for().connection_pool.get_connection.assert_called() + + def test_can_create_connection(self): + from redis.exceptions import ConnectionError + + connection = Connection( + 'sentinel://localhost:65534/', + transport_options={ + 'master_name': 'not_important', + }, + ) + with pytest.raises(ConnectionError): + connection.channel() |