diff options
| author | Asif Saif Uddin <auvipy@gmail.com> | 2023-04-08 22:45:08 +0600 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-04-08 22:45:08 +0600 |
| commit | 973dc3790ac25b9da7b6d2641ac72d95470f6ed8 (patch) | |
| tree | 9e7ba02d8520994a06efc37dde05fba722138189 /t/unit | |
| parent | 7ceb675bb69917fae182ebdaf9a2298a308c3fa4 (diff) | |
| parent | 2de7f9f038dd62e097e490cb3fa609067c1c3c36 (diff) | |
| download | kombu-py310.tar.gz | |
Merge branch 'main' into py310py310
Diffstat (limited to 't/unit')
60 files changed, 1152 insertions, 227 deletions
diff --git a/t/unit/asynchronous/aws/case.py b/t/unit/asynchronous/aws/case.py index 56c70812..220cd700 100644 --- a/t/unit/asynchronous/aws/case.py +++ b/t/unit/asynchronous/aws/case.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pytest import t.skip diff --git a/t/unit/asynchronous/aws/sqs/test_connection.py b/t/unit/asynchronous/aws/sqs/test_connection.py index c3dd184b..0c5d2ac9 100644 --- a/t/unit/asynchronous/aws/sqs/test_connection.py +++ b/t/unit/asynchronous/aws/sqs/test_connection.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from unittest.mock import MagicMock, Mock from kombu.asynchronous.aws.ext import boto3 diff --git a/t/unit/asynchronous/aws/sqs/test_queue.py b/t/unit/asynchronous/aws/sqs/test_queue.py index 56812831..70f10a75 100644 --- a/t/unit/asynchronous/aws/sqs/test_queue.py +++ b/t/unit/asynchronous/aws/sqs/test_queue.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from unittest.mock import Mock import pytest diff --git a/t/unit/asynchronous/aws/test_aws.py b/t/unit/asynchronous/aws/test_aws.py index 93d92e4b..736fdf8a 100644 --- a/t/unit/asynchronous/aws/test_aws.py +++ b/t/unit/asynchronous/aws/test_aws.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from unittest.mock import Mock from kombu.asynchronous.aws import connect_sqs diff --git a/t/unit/asynchronous/aws/test_connection.py b/t/unit/asynchronous/aws/test_connection.py index 68e3c746..03fc5412 100644 --- a/t/unit/asynchronous/aws/test_connection.py +++ b/t/unit/asynchronous/aws/test_connection.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from contextlib import contextmanager from io import StringIO from unittest.mock import Mock diff --git a/t/unit/asynchronous/http/test_curl.py b/t/unit/asynchronous/http/test_curl.py index db8f5f91..51f9128e 100644 --- a/t/unit/asynchronous/http/test_curl.py +++ b/t/unit/asynchronous/http/test_curl.py @@ -1,4 +1,7 @@ -from unittest.mock import Mock, call, patch +from __future__ import annotations + +from io import BytesIO +from unittest.mock import ANY, Mock, call, patch import pytest @@ -131,3 +134,24 @@ class test_CurlClient: x._on_event.assert_called_with(fd, _pycurl.CSELECT_IN) x.on_writable(fd, _pycurl=_pycurl) x._on_event.assert_called_with(fd, _pycurl.CSELECT_OUT) + + def test_setup_request_sets_proxy_when_specified(self): + with patch('kombu.asynchronous.http.curl.pycurl') as _pycurl: + x = self.Client() + proxy_host = 'http://www.example.com' + request = Mock( + name='request', headers={}, auth_mode=None, proxy_host=None + ) + proxied_request = Mock( + name='request', headers={}, auth_mode=None, + proxy_host=proxy_host, proxy_port=123 + ) + x._setup_request( + x.Curl, request, BytesIO(), x.Headers(), _pycurl=_pycurl + ) + with pytest.raises(AssertionError): + x.Curl.setopt.assert_any_call(_pycurl.PROXY, ANY) + x._setup_request( + x.Curl, proxied_request, BytesIO(), x.Headers(), _pycurl + ) + x.Curl.setopt.assert_any_call(_pycurl.PROXY, proxy_host) diff --git a/t/unit/asynchronous/http/test_http.py b/t/unit/asynchronous/http/test_http.py index 6e6abdcb..816bf89d 100644 --- a/t/unit/asynchronous/http/test_http.py +++ b/t/unit/asynchronous/http/test_http.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from io import BytesIO from unittest.mock import Mock diff --git a/t/unit/asynchronous/test_hub.py b/t/unit/asynchronous/test_hub.py index eae25357..27b048b9 100644 --- a/t/unit/asynchronous/test_hub.py +++ b/t/unit/asynchronous/test_hub.py @@ -1,5 +1,7 @@ +from __future__ import annotations + import errno -from unittest.mock import Mock, call, patch +from unittest.mock import ANY, Mock, call, patch import pytest from vine import promise @@ -187,6 +189,12 @@ class test_Hub: assert promise() in self.hub._ready assert ret is promise() + def test_call_soon_uses_lock(self): + callback = Mock(name='callback') + with patch.object(self.hub, '_ready_lock', autospec=True) as lock: + self.hub.call_soon(callback) + assert lock.__enter__.called_once() + def test_call_soon__promise_argument(self): callback = promise(Mock(name='callback'), (1, 2, 3)) ret = self.hub.call_soon(callback) @@ -533,3 +541,31 @@ class test_Hub: callbacks[0].assert_called_once_with() callbacks[1].assert_called_once_with() deferred.assert_not_called() + + def test_loop__no_todo_tick_delay(self): + cb = Mock(name='parent') + cb.todo, cb.tick, cb.poller = Mock(), Mock(), Mock() + cb.poller.poll.side_effect = lambda obj: () + self.hub.poller = cb.poller + self.hub.add(2, Mock(), READ) + self.hub.call_soon(cb.todo) + self.hub.on_tick = [cb.tick] + + next(self.hub.loop) + + cb.assert_has_calls([ + call.todo(), + call.tick(), + call.poller.poll(ANY), + ]) + + def test__pop_ready_pops_ready_items(self): + self.hub._ready.add(None) + ret = self.hub._pop_ready() + assert ret == {None} + assert self.hub._ready == set() + + def test__pop_ready_uses_lock(self): + with patch.object(self.hub, '_ready_lock', autospec=True) as lock: + self.hub._pop_ready() + assert lock.__enter__.called_once() diff --git a/t/unit/asynchronous/test_semaphore.py b/t/unit/asynchronous/test_semaphore.py index 8767ca91..5c41a6d8 100644 --- a/t/unit/asynchronous/test_semaphore.py +++ b/t/unit/asynchronous/test_semaphore.py @@ -1,11 +1,13 @@ +from __future__ import annotations + from kombu.asynchronous.semaphore import LaxBoundedSemaphore class test_LaxBoundedSemaphore: - def test_over_release(self): + def test_over_release(self) -> None: x = LaxBoundedSemaphore(2) - calls = [] + calls: list[int] = [] for i in range(1, 21): x.acquire(calls.append, i) x.release() diff --git a/t/unit/asynchronous/test_timer.py b/t/unit/asynchronous/test_timer.py index 20411784..531b3d2e 100644 --- a/t/unit/asynchronous/test_timer.py +++ b/t/unit/asynchronous/test_timer.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from datetime import datetime from unittest.mock import Mock, patch diff --git a/t/unit/conftest.py b/t/unit/conftest.py index b798e3e5..15e31366 100644 --- a/t/unit/conftest.py +++ b/t/unit/conftest.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import atexit import builtins import io diff --git a/t/unit/test_clocks.py b/t/unit/test_clocks.py index b4392440..8f2d1340 100644 --- a/t/unit/test_clocks.py +++ b/t/unit/test_clocks.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pickle from heapq import heappush from time import time @@ -8,7 +10,7 @@ from kombu.clocks import LamportClock, timetuple class test_LamportClock: - def test_clocks(self): + def test_clocks(self) -> None: c1 = LamportClock() c2 = LamportClock() @@ -29,12 +31,12 @@ class test_LamportClock: c1.adjust(c2.value) assert c1.value == c2.value + 1 - def test_sort(self): + def test_sort(self) -> None: c = LamportClock() pid1 = 'a.example.com:312' pid2 = 'b.example.com:311' - events = [] + events: list[tuple[int, str]] = [] m1 = (c.forward(), pid1) heappush(events, m1) @@ -56,15 +58,15 @@ class test_LamportClock: class test_timetuple: - def test_repr(self): + def test_repr(self) -> None: x = timetuple(133, time(), 'id', Mock()) assert repr(x) - def test_pickleable(self): + def test_pickleable(self) -> None: x = timetuple(133, time(), 'id', 'obj') assert pickle.loads(pickle.dumps(x)) == tuple(x) - def test_order(self): + def test_order(self) -> None: t1 = time() t2 = time() + 300 # windows clock not reliable a = timetuple(133, t1, 'A', 'obj') @@ -81,5 +83,6 @@ class test_timetuple: NotImplemented) assert timetuple(134, t2, 'A', 'obj') > timetuple(133, t1, 'A', 'obj') assert timetuple(134, t1, 'B', 'obj') > timetuple(134, t1, 'A', 'obj') - assert (timetuple(None, t2, 'B', 'obj') > - timetuple(None, t1, 'A', 'obj')) + assert ( + timetuple(None, t2, 'B', 'obj') > timetuple(None, t1, 'A', 'obj') + ) diff --git a/t/unit/test_common.py b/t/unit/test_common.py index 0f669b7d..fd20243f 100644 --- a/t/unit/test_common.py +++ b/t/unit/test_common.py @@ -1,4 +1,7 @@ +from __future__ import annotations + import socket +from typing import TYPE_CHECKING from unittest.mock import Mock, patch import pytest @@ -10,6 +13,9 @@ from kombu.common import (PREFETCH_COUNT_MAX, Broadcast, QoS, collect_replies, maybe_declare, send_reply) from t.mocks import ContextMock, MockPool +if TYPE_CHECKING: + from types import TracebackType + def test_generate_oid(): from uuid import NAMESPACE_OID @@ -338,7 +344,12 @@ class MockConsumer: self.consumers.add(self) return self - def __exit__(self, *exc_info): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None + ) -> None: self.consumers.discard(self) diff --git a/t/unit/test_compat.py b/t/unit/test_compat.py index d75ce5df..837d6f22 100644 --- a/t/unit/test_compat.py +++ b/t/unit/test_compat.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from unittest.mock import Mock, patch import pytest @@ -115,12 +117,14 @@ class test_Publisher: pub.close() def test__enter__exit__(self): - pub = compat.Publisher(self.connection, - exchange='test_Publisher_send', - routing_key='rkey') - x = pub.__enter__() - assert x is pub - x.__exit__() + pub = compat.Publisher( + self.connection, + exchange='test_Publisher_send', + routing_key='rkey' + ) + with pub as x: + assert x is pub + assert pub._closed @@ -158,11 +162,14 @@ class test_Consumer: assert q2.exchange.auto_delete def test__enter__exit__(self, n='test__enter__exit__'): - c = compat.Consumer(self.connection, queue=n, exchange=n, - routing_key='rkey') - x = c.__enter__() - assert x is c - x.__exit__() + c = compat.Consumer( + self.connection, + queue=n, + exchange=n, + routing_key='rkey' + ) + with c as x: + assert x is c assert c._closed def test_revive(self, n='test_revive'): diff --git a/t/unit/test_compression.py b/t/unit/test_compression.py index f1f426b7..95139811 100644 --- a/t/unit/test_compression.py +++ b/t/unit/test_compression.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import sys import pytest diff --git a/t/unit/test_connection.py b/t/unit/test_connection.py index 0b184d3b..c2daee3b 100644 --- a/t/unit/test_connection.py +++ b/t/unit/test_connection.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pickle import socket from copy import copy, deepcopy @@ -9,7 +11,7 @@ from kombu import Connection, Consumer, Producer, parse_url from kombu.connection import Resource from kombu.exceptions import OperationalError from kombu.utils.functional import lazy -from t.mocks import Transport +from t.mocks import TimeoutingTransport, Transport class test_connection_utils: @@ -99,6 +101,19 @@ class test_connection_utils: # see Appendix A of http://www.rabbitmq.com/uri-spec.html self.assert_info(Connection(url), **expected) + @pytest.mark.parametrize('url,expected', [ + ('sqs://user:pass@', + {'userid': None, 'password': None, 'hostname': None, + 'port': None, 'virtual_host': '/'}), + ('sqs://', + {'userid': None, 'password': None, 'hostname': None, + 'port': None, 'virtual_host': '/'}), + ]) + def test_sqs_example_urls(self, url, expected, caplog): + pytest.importorskip('boto3') + self.assert_info(Connection('sqs://'), **expected) + assert not caplog.records + @pytest.mark.skip('TODO: urllib cannot parse ipv6 urls') def test_url_IPV6(self): self.assert_info( @@ -293,7 +308,9 @@ class test_Connection: assert not c.is_evented def test_register_with_event_loop(self): - c = Connection(transport=Mock) + transport = Mock(name='transport') + transport.connection_errors = [] + c = Connection(transport=transport) loop = Mock(name='loop') c.register_with_event_loop(loop) c.transport.register_with_event_loop.assert_called_with( @@ -383,14 +400,12 @@ class test_Connection: qsms.assert_called_with(self.conn.connection) def test__enter____exit__(self): - conn = self.conn - context = conn.__enter__() - assert context is conn - conn.connect() - assert conn.connection.connected - conn.__exit__() - assert conn.connection is None - conn.close() # again + with self.conn as context: + assert context is self.conn + self.conn.connect() + assert self.conn.connection.connected + assert self.conn.connection is None + self.conn.close() # again def test_close_survives_connerror(self): @@ -477,15 +492,52 @@ class test_Connection: def publish(): raise _ConnectionError('failed connection') - self.conn.transport.connection_errors = (_ConnectionError,) + self.conn.get_transport_cls().connection_errors = (_ConnectionError,) ensured = self.conn.ensure(self.conn, publish) with pytest.raises(OperationalError): ensured() + def test_ensure_retry_errors_is_not_looping_infinitely(self): + class _MessageNacked(Exception): + pass + + def publish(): + raise _MessageNacked('NACK') + + with pytest.raises(ValueError): + self.conn.ensure( + self.conn, + publish, + retry_errors=(_MessageNacked,) + ) + + def test_ensure_retry_errors_is_limited_by_max_retries(self): + class _MessageNacked(Exception): + pass + + tries = 0 + + def publish(): + nonlocal tries + tries += 1 + if tries <= 3: + raise _MessageNacked('NACK') + # On the 4th try, we let it pass + return 'ACK' + + ensured = self.conn.ensure( + self.conn, + publish, + max_retries=3, # 3 retries + 1 initial try = 4 tries + retry_errors=(_MessageNacked,) + ) + + assert ensured() == 'ACK' + def test_autoretry(self): myfun = Mock() - self.conn.transport.connection_errors = (KeyError,) + self.conn.get_transport_cls().connection_errors = (KeyError,) def on_call(*args, **kwargs): myfun.side_effect = None @@ -571,6 +623,18 @@ class test_Connection: conn = Connection(transport=MyTransport) assert conn.channel_errors == (KeyError, ValueError) + def test_channel_errors__exception_no_cache(self): + """Ensure the channel_errors can be retrieved without an initialized + transport. + """ + + class MyTransport(Transport): + channel_errors = (KeyError,) + + conn = Connection(transport=MyTransport) + MyTransport.__init__ = Mock(side_effect=Exception) + assert conn.channel_errors == (KeyError,) + def test_connection_errors(self): class MyTransport(Transport): @@ -579,6 +643,80 @@ class test_Connection: conn = Connection(transport=MyTransport) assert conn.connection_errors == (KeyError, ValueError) + def test_connection_errors__exception_no_cache(self): + """Ensure the connection_errors can be retrieved without an + initialized transport. + """ + + class MyTransport(Transport): + connection_errors = (KeyError,) + + conn = Connection(transport=MyTransport) + MyTransport.__init__ = Mock(side_effect=Exception) + assert conn.connection_errors == (KeyError,) + + def test_recoverable_connection_errors(self): + + class MyTransport(Transport): + recoverable_connection_errors = (KeyError, ValueError) + + conn = Connection(transport=MyTransport) + assert conn.recoverable_connection_errors == (KeyError, ValueError) + + def test_recoverable_connection_errors__fallback(self): + """Ensure missing recoverable_connection_errors on the Transport does + not cause a fatal error. + """ + + class MyTransport(Transport): + connection_errors = (KeyError,) + channel_errors = (ValueError,) + + conn = Connection(transport=MyTransport) + assert conn.recoverable_connection_errors == (KeyError, ValueError) + + def test_recoverable_connection_errors__exception_no_cache(self): + """Ensure the recoverable_connection_errors can be retrieved without + an initialized transport. + """ + + class MyTransport(Transport): + recoverable_connection_errors = (KeyError,) + + conn = Connection(transport=MyTransport) + MyTransport.__init__ = Mock(side_effect=Exception) + assert conn.recoverable_connection_errors == (KeyError,) + + def test_recoverable_channel_errors(self): + + class MyTransport(Transport): + recoverable_channel_errors = (KeyError, ValueError) + + conn = Connection(transport=MyTransport) + assert conn.recoverable_channel_errors == (KeyError, ValueError) + + def test_recoverable_channel_errors__fallback(self): + """Ensure missing recoverable_channel_errors on the Transport does not + cause a fatal error. + """ + + class MyTransport(Transport): + pass + + conn = Connection(transport=MyTransport) + assert conn.recoverable_channel_errors == () + + def test_recoverable_channel_errors__exception_no_cache(self): + """Ensure the recoverable_channel_errors can be retrieved without an + initialized transport. + """ + class MyTransport(Transport): + recoverable_channel_errors = (KeyError,) + + conn = Connection(transport=MyTransport) + MyTransport.__init__ = Mock(side_effect=Exception) + assert conn.recoverable_channel_errors == (KeyError,) + def test_multiple_urls_hostname(self): conn = Connection(['example.com;amqp://example.com']) assert conn.as_uri() == 'amqp://guest:**@example.com:5672//' @@ -587,6 +725,47 @@ class test_Connection: conn = Connection('example.com;example.com;') assert conn.as_uri() == 'amqp://guest:**@example.com:5672//' + def test_connection_respect_its_timeout(self): + invalid_port = 1222 + with Connection( + f'amqp://guest:guest@localhost:{invalid_port}//', + transport_options={'max_retries': 2}, + connect_timeout=1 + ) as conn: + with pytest.raises(OperationalError): + conn.default_channel + + def test_connection_failover_without_total_timeout(self): + with Connection( + ['server1', 'server2'], + transport=TimeoutingTransport, + connect_timeout=1, + transport_options={'interval_start': 0, 'interval_step': 0}, + ) as conn: + conn._establish_connection = Mock( + side_effect=conn._establish_connection + ) + with pytest.raises(OperationalError): + conn.default_channel + # Never retried, because `retry_over_time` `timeout` is equal + # to `connect_timeout` + conn._establish_connection.assert_called_once() + + def test_connection_failover_with_total_timeout(self): + with Connection( + ['server1', 'server2'], + transport=TimeoutingTransport, + connect_timeout=1, + transport_options={'connect_retries_timeout': 2, + 'interval_start': 0, 'interval_step': 0}, + ) as conn: + conn._establish_connection = Mock( + side_effect=conn._establish_connection + ) + with pytest.raises(OperationalError): + conn.default_channel + assert conn._establish_connection.call_count == 2 + class test_Connection_with_transport_options: diff --git a/t/unit/test_entity.py b/t/unit/test_entity.py index 52c42b2b..fcb0afb9 100644 --- a/t/unit/test_entity.py +++ b/t/unit/test_entity.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pickle from unittest.mock import Mock, call @@ -10,13 +12,13 @@ from kombu.serialization import registry from t.mocks import Transport -def get_conn(): +def get_conn() -> Connection: return Connection(transport=Transport) class test_binding: - def test_constructor(self): + def test_constructor(self) -> None: x = binding( Exchange('foo'), 'rkey', arguments={'barg': 'bval'}, @@ -27,31 +29,31 @@ class test_binding: assert x.arguments == {'barg': 'bval'} assert x.unbind_arguments == {'uarg': 'uval'} - def test_declare(self): + def test_declare(self) -> None: chan = get_conn().channel() x = binding(Exchange('foo'), 'rkey') x.declare(chan) assert 'exchange_declare' in chan - def test_declare_no_exchange(self): + def test_declare_no_exchange(self) -> None: chan = get_conn().channel() x = binding() x.declare(chan) assert 'exchange_declare' not in chan - def test_bind(self): + def test_bind(self) -> None: chan = get_conn().channel() x = binding(Exchange('foo')) x.bind(Exchange('bar')(chan)) assert 'exchange_bind' in chan - def test_unbind(self): + def test_unbind(self) -> None: chan = get_conn().channel() x = binding(Exchange('foo')) x.unbind(Exchange('bar')(chan)) assert 'exchange_unbind' in chan - def test_repr(self): + def test_repr(self) -> None: b = binding(Exchange('foo'), 'rkey') assert 'foo' in repr(b) assert 'rkey' in repr(b) @@ -59,7 +61,7 @@ class test_binding: class test_Exchange: - def test_bound(self): + def test_bound(self) -> None: exchange = Exchange('foo', 'direct') assert not exchange.is_bound assert '<unbound' in repr(exchange) @@ -70,11 +72,11 @@ class test_Exchange: assert bound.channel is chan assert f'bound to chan:{chan.channel_id!r}' in repr(bound) - def test_hash(self): + def test_hash(self) -> None: assert hash(Exchange('a')) == hash(Exchange('a')) assert hash(Exchange('a')) != hash(Exchange('b')) - def test_can_cache_declaration(self): + def test_can_cache_declaration(self) -> None: assert Exchange('a', durable=True).can_cache_declaration assert Exchange('a', durable=False).can_cache_declaration assert not Exchange('a', auto_delete=True).can_cache_declaration @@ -82,12 +84,12 @@ class test_Exchange: 'a', durable=True, auto_delete=True, ).can_cache_declaration - def test_pickle(self): + def test_pickle(self) -> None: e1 = Exchange('foo', 'direct') e2 = pickle.loads(pickle.dumps(e1)) assert e1 == e2 - def test_eq(self): + def test_eq(self) -> None: e1 = Exchange('foo', 'direct') e2 = Exchange('foo', 'direct') assert e1 == e2 @@ -97,7 +99,7 @@ class test_Exchange: assert e1.__eq__(True) == NotImplemented - def test_revive(self): + def test_revive(self) -> None: exchange = Exchange('foo', 'direct') conn = get_conn() chan = conn.channel() @@ -116,7 +118,7 @@ class test_Exchange: assert bound.is_bound assert bound._channel is chan2 - def test_assert_is_bound(self): + def test_assert_is_bound(self) -> None: exchange = Exchange('foo', 'direct') with pytest.raises(NotBoundError): exchange.declare() @@ -126,80 +128,80 @@ class test_Exchange: exchange.bind(chan).declare() assert 'exchange_declare' in chan - def test_set_transient_delivery_mode(self): + def test_set_transient_delivery_mode(self) -> None: exc = Exchange('foo', 'direct', delivery_mode='transient') assert exc.delivery_mode == Exchange.TRANSIENT_DELIVERY_MODE - def test_set_passive_mode(self): + def test_set_passive_mode(self) -> None: exc = Exchange('foo', 'direct', passive=True) assert exc.passive - def test_set_persistent_delivery_mode(self): + def test_set_persistent_delivery_mode(self) -> None: exc = Exchange('foo', 'direct', delivery_mode='persistent') assert exc.delivery_mode == Exchange.PERSISTENT_DELIVERY_MODE - def test_bind_at_instantiation(self): + def test_bind_at_instantiation(self) -> None: assert Exchange('foo', channel=get_conn().channel()).is_bound - def test_create_message(self): + def test_create_message(self) -> None: chan = get_conn().channel() Exchange('foo', channel=chan).Message({'foo': 'bar'}) assert 'prepare_message' in chan - def test_publish(self): + def test_publish(self) -> None: chan = get_conn().channel() Exchange('foo', channel=chan).publish('the quick brown fox') assert 'basic_publish' in chan - def test_delete(self): + def test_delete(self) -> None: chan = get_conn().channel() Exchange('foo', channel=chan).delete() assert 'exchange_delete' in chan - def test__repr__(self): + def test__repr__(self) -> None: b = Exchange('foo', 'topic') assert 'foo(topic)' in repr(b) assert 'Exchange' in repr(b) - def test_bind_to(self): + def test_bind_to(self) -> None: chan = get_conn().channel() foo = Exchange('foo', 'topic') bar = Exchange('bar', 'topic') foo(chan).bind_to(bar) assert 'exchange_bind' in chan - def test_bind_to_by_name(self): + def test_bind_to_by_name(self) -> None: chan = get_conn().channel() foo = Exchange('foo', 'topic') foo(chan).bind_to('bar') assert 'exchange_bind' in chan - def test_unbind_from(self): + def test_unbind_from(self) -> None: chan = get_conn().channel() foo = Exchange('foo', 'topic') bar = Exchange('bar', 'topic') foo(chan).unbind_from(bar) assert 'exchange_unbind' in chan - def test_unbind_from_by_name(self): + def test_unbind_from_by_name(self) -> None: chan = get_conn().channel() foo = Exchange('foo', 'topic') foo(chan).unbind_from('bar') assert 'exchange_unbind' in chan - def test_declare__no_declare(self): + def test_declare__no_declare(self) -> None: chan = get_conn().channel() foo = Exchange('foo', 'topic', no_declare=True) foo(chan).declare() assert 'exchange_declare' not in chan - def test_declare__internal_exchange(self): + def test_declare__internal_exchange(self) -> None: chan = get_conn().channel() foo = Exchange('amq.rabbitmq.trace', 'topic') foo(chan).declare() assert 'exchange_declare' not in chan - def test_declare(self): + def test_declare(self) -> None: chan = get_conn().channel() foo = Exchange('foo', 'topic', no_declare=False) foo(chan).declare() @@ -208,33 +210,33 @@ class test_Exchange: class test_Queue: - def setup(self): + def setup(self) -> None: self.exchange = Exchange('foo', 'direct') - def test_constructor_with_actual_exchange(self): + def test_constructor_with_actual_exchange(self) -> None: exchange = Exchange('exchange_name', 'direct') queue = Queue(name='queue_name', exchange=exchange) assert queue.exchange == exchange - def test_constructor_with_string_exchange(self): + def test_constructor_with_string_exchange(self) -> None: exchange_name = 'exchange_name' queue = Queue(name='queue_name', exchange=exchange_name) assert queue.exchange == Exchange(exchange_name) - def test_constructor_with_default_exchange(self): + def test_constructor_with_default_exchange(self) -> None: queue = Queue(name='queue_name') assert queue.exchange == Exchange('') - def test_hash(self): + def test_hash(self) -> None: assert hash(Queue('a')) == hash(Queue('a')) assert hash(Queue('a')) != hash(Queue('b')) - def test_repr_with_bindings(self): + def test_repr_with_bindings(self) -> None: ex = Exchange('foo') x = Queue('foo', bindings=[ex.binding('A'), ex.binding('B')]) assert repr(x) - def test_anonymous(self): + def test_anonymous(self) -> None: chan = Mock() x = Queue(bindings=[binding(Exchange('foo'), 'rkey')]) chan.queue_declare.return_value = 'generated', 0, 0 @@ -242,7 +244,7 @@ class test_Queue: xx.declare() assert xx.name == 'generated' - def test_basic_get__accept_disallowed(self): + def test_basic_get__accept_disallowed(self) -> None: conn = Connection('memory://') q = Queue('foo', exchange=self.exchange) p = Producer(conn) @@ -257,7 +259,7 @@ class test_Queue: with pytest.raises(q.ContentDisallowed): message.decode() - def test_basic_get__accept_allowed(self): + def test_basic_get__accept_allowed(self) -> None: conn = Connection('memory://') q = Queue('foo', exchange=self.exchange) p = Producer(conn) @@ -272,12 +274,12 @@ class test_Queue: payload = message.decode() assert payload['complex'] - def test_when_bound_but_no_exchange(self): + def test_when_bound_but_no_exchange(self) -> None: q = Queue('a') q.exchange = None assert q.when_bound() is None - def test_declare_but_no_exchange(self): + def test_declare_but_no_exchange(self) -> None: q = Queue('a') q.queue_declare = Mock() q.queue_bind = Mock() @@ -287,7 +289,7 @@ class test_Queue: q.queue_declare.assert_called_with( channel=None, nowait=False, passive=False) - def test_declare__no_declare(self): + def test_declare__no_declare(self) -> None: q = Queue('a', no_declare=True) q.queue_declare = Mock() q.queue_bind = Mock() @@ -297,19 +299,19 @@ class test_Queue: q.queue_declare.assert_not_called() q.queue_bind.assert_not_called() - def test_bind_to_when_name(self): + def test_bind_to_when_name(self) -> None: chan = Mock() q = Queue('a') q(chan).bind_to('ex') chan.queue_bind.assert_called() - def test_get_when_no_m2p(self): + def test_get_when_no_m2p(self) -> None: chan = Mock() q = Queue('a')(chan) chan.message_to_python = None assert q.get() - def test_multiple_bindings(self): + def test_multiple_bindings(self) -> None: chan = Mock() q = Queue('mul', [ binding(Exchange('mul1'), 'rkey1'), @@ -327,14 +329,14 @@ class test_Queue: durable=True, ) in chan.exchange_declare.call_args_list - def test_can_cache_declaration(self): + def test_can_cache_declaration(self) -> None: assert Queue('a', durable=True).can_cache_declaration assert Queue('a', durable=False).can_cache_declaration assert not Queue( 'a', queue_arguments={'x-expires': 100} ).can_cache_declaration - def test_eq(self): + def test_eq(self) -> None: q1 = Queue('xxx', Exchange('xxx', 'direct'), 'xxx') q2 = Queue('xxx', Exchange('xxx', 'direct'), 'xxx') assert q1 == q2 @@ -343,14 +345,14 @@ class test_Queue: q3 = Queue('yyy', Exchange('xxx', 'direct'), 'xxx') assert q1 != q3 - def test_exclusive_implies_auto_delete(self): + def test_exclusive_implies_auto_delete(self) -> None: assert Queue('foo', self.exchange, exclusive=True).auto_delete - def test_binds_at_instantiation(self): + def test_binds_at_instantiation(self) -> None: assert Queue('foo', self.exchange, channel=get_conn().channel()).is_bound - def test_also_binds_exchange(self): + def test_also_binds_exchange(self) -> None: chan = get_conn().channel() b = Queue('foo', self.exchange) assert not b.is_bound @@ -361,7 +363,7 @@ class test_Queue: assert b.channel is b.exchange.channel assert b.exchange is not self.exchange - def test_declare(self): + def test_declare(self) -> None: chan = get_conn().channel() b = Queue('foo', self.exchange, 'foo', channel=chan) assert b.is_bound @@ -370,49 +372,49 @@ class test_Queue: assert 'queue_declare' in chan assert 'queue_bind' in chan - def test_get(self): + def test_get(self) -> None: b = Queue('foo', self.exchange, 'foo', channel=get_conn().channel()) b.get() assert 'basic_get' in b.channel - def test_purge(self): + def test_purge(self) -> None: b = Queue('foo', self.exchange, 'foo', channel=get_conn().channel()) b.purge() assert 'queue_purge' in b.channel - def test_consume(self): + def test_consume(self) -> None: b = Queue('foo', self.exchange, 'foo', channel=get_conn().channel()) b.consume('fifafo', None) assert 'basic_consume' in b.channel - def test_cancel(self): + def test_cancel(self) -> None: b = Queue('foo', self.exchange, 'foo', channel=get_conn().channel()) b.cancel('fifafo') assert 'basic_cancel' in b.channel - def test_delete(self): + def test_delete(self) -> None: b = Queue('foo', self.exchange, 'foo', channel=get_conn().channel()) b.delete() assert 'queue_delete' in b.channel - def test_queue_unbind(self): + def test_queue_unbind(self) -> None: b = Queue('foo', self.exchange, 'foo', channel=get_conn().channel()) b.queue_unbind() assert 'queue_unbind' in b.channel - def test_as_dict(self): + def test_as_dict(self) -> None: q = Queue('foo', self.exchange, 'rk') d = q.as_dict(recurse=True) assert d['exchange']['name'] == self.exchange.name - def test_queue_dump(self): + def test_queue_dump(self) -> None: b = binding(self.exchange, 'rk') q = Queue('foo', self.exchange, 'rk', bindings=[b]) d = q.as_dict(recurse=True) assert d['bindings'][0]['routing_key'] == 'rk' registry.dumps(d) - def test__repr__(self): + def test__repr__(self) -> None: b = Queue('foo', self.exchange, 'foo') assert 'foo' in repr(b) assert 'Queue' in repr(b) @@ -420,5 +422,5 @@ class test_Queue: class test_MaybeChannelBound: - def test_repr(self): + def test_repr(self) -> None: assert repr(MaybeChannelBound()) diff --git a/t/unit/test_exceptions.py b/t/unit/test_exceptions.py index bba72a83..7e67fc6f 100644 --- a/t/unit/test_exceptions.py +++ b/t/unit/test_exceptions.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from unittest.mock import Mock from kombu.exceptions import HttpError @@ -5,5 +7,5 @@ from kombu.exceptions import HttpError class test_HttpError: - def test_str(self): + def test_str(self) -> None: assert str(HttpError(200, 'msg', Mock(name='response'))) diff --git a/t/unit/test_log.py b/t/unit/test_log.py index 4a8cd94c..30c6796f 100644 --- a/t/unit/test_log.py +++ b/t/unit/test_log.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging import sys from unittest.mock import ANY, Mock, patch diff --git a/t/unit/test_matcher.py b/t/unit/test_matcher.py index 2100fa74..37ae5207 100644 --- a/t/unit/test_matcher.py +++ b/t/unit/test_matcher.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pytest from kombu.matcher import (MatcherNotInstalled, fnmatch, match, register, diff --git a/t/unit/test_message.py b/t/unit/test_message.py index 5b0833dd..4c53cac2 100644 --- a/t/unit/test_message.py +++ b/t/unit/test_message.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import sys from unittest.mock import Mock, patch diff --git a/t/unit/test_messaging.py b/t/unit/test_messaging.py index f8ed437c..4bd467c2 100644 --- a/t/unit/test_messaging.py +++ b/t/unit/test_messaging.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pickle import sys from collections import defaultdict @@ -188,9 +190,8 @@ class test_Producer: def test_enter_exit(self): p = self.connection.Producer() p.release = Mock() - - assert p.__enter__() is p - p.__exit__() + with p as x: + assert x is p p.release.assert_called_with() def test_connection_property_handles_AttributeError(self): diff --git a/t/unit/test_mixins.py b/t/unit/test_mixins.py index 04a56a6c..39b7370f 100644 --- a/t/unit/test_mixins.py +++ b/t/unit/test_mixins.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import socket from unittest.mock import Mock, patch diff --git a/t/unit/test_pidbox.py b/t/unit/test_pidbox.py index fac46139..cf8a748a 100644 --- a/t/unit/test_pidbox.py +++ b/t/unit/test_pidbox.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import socket import warnings from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor diff --git a/t/unit/test_pools.py b/t/unit/test_pools.py index eb2a556e..1557da95 100644 --- a/t/unit/test_pools.py +++ b/t/unit/test_pools.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from unittest.mock import Mock import pytest @@ -139,7 +141,7 @@ class test_PoolGroup: def test_delitem(self): g = self.MyGroup() g['foo'] - del(g['foo']) + del g['foo'] assert 'foo' not in g def test_Connections(self): diff --git a/t/unit/test_serialization.py b/t/unit/test_serialization.py index 14952e5e..d3fd5c20 100644 --- a/t/unit/test_serialization.py +++ b/t/unit/test_serialization.py @@ -1,5 +1,7 @@ #!/usr/bin/python +from __future__ import annotations + from base64 import b64decode from unittest.mock import call, patch diff --git a/t/unit/test_simple.py b/t/unit/test_simple.py index a5cd899a..50ea880b 100644 --- a/t/unit/test_simple.py +++ b/t/unit/test_simple.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from unittest.mock import Mock import pytest @@ -91,9 +93,8 @@ class SimpleBase: def test_enter_exit(self): q = self.Queue('test_enter_exit') q.close = Mock() - - assert q.__enter__() is q - q.__exit__() + with q as x: + assert x is q q.close.assert_called_with() def test_qsize(self): diff --git a/t/unit/transport/test_SQS.py b/t/unit/transport/test_SQS.py index 944728f1..2b1219fc 100644 --- a/t/unit/transport/test_SQS.py +++ b/t/unit/transport/test_SQS.py @@ -4,6 +4,8 @@ NOTE: The SQSQueueMock and SQSConnectionMock classes originally come from http://github.com/pcsforeducation/sqs-mock-python. They have been patched slightly. """ +from __future__ import annotations + import base64 import os import random @@ -38,6 +40,11 @@ example_predefined_queues = { 'access_key_id': 'c', 'secret_access_key': 'd', }, + 'queue-3.fifo': { + 'url': 'https://sqs.us-east-1.amazonaws.com/xxx/queue-3.fifo', + 'access_key_id': 'e', + 'secret_access_key': 'f', + } } @@ -151,6 +158,7 @@ class test_Channel: predefined_queues_sqs_conn_mocks = { 'queue-1': SQSClientMock(QueueName='queue-1'), 'queue-2': SQSClientMock(QueueName='queue-2'), + 'queue-3.fifo': SQSClientMock(QueueName='queue-3.fifo') } def mock_sqs(): @@ -330,13 +338,13 @@ class test_Channel: with pytest.raises(Empty): self.channel._get_bulk(self.queue_name) - def test_is_base64_encoded(self): + def test_optional_b64_decode(self): raw = b'{"id": "4cc7438e-afd4-4f8f-a2f3-f46567e7ca77","task": "celery.task.PingTask",' \ b'"args": [],"kwargs": {},"retries": 0,"eta": "2009-11-17T12:30:56.527191"}' # noqa b64_enc = base64.b64encode(raw) - assert self.channel._Channel__b64_encoded(b64_enc) - assert not self.channel._Channel__b64_encoded(raw) - assert not self.channel._Channel__b64_encoded(b"test123") + assert self.channel._optional_b64_decode(b64_enc) == raw + assert self.channel._optional_b64_decode(raw) == raw + assert self.channel._optional_b64_decode(b"test123") == b"test123" def test_messages_to_python(self): from kombu.asynchronous.aws.sqs.message import Message @@ -738,6 +746,77 @@ class test_Channel: QueueUrl='https://sqs.us-east-1.amazonaws.com/xxx/queue-1', ReceiptHandle='test_message_id', VisibilityTimeout=20) + def test_predefined_queues_put_to_fifo_queue(self): + connection = Connection(transport=SQS.Transport, transport_options={ + 'predefined_queues': example_predefined_queues, + }) + channel = connection.channel() + + queue_name = 'queue-3.fifo' + + exchange = Exchange('test_SQS', type='direct') + p = messaging.Producer(channel, exchange, routing_key=queue_name) + + queue = Queue(queue_name, exchange, queue_name) + queue(channel).declare() + + channel.sqs = Mock() + sqs_queue_mock = Mock() + channel.sqs.return_value = sqs_queue_mock + p.publish('message') + + sqs_queue_mock.send_message.assert_called_once() + assert 'MessageGroupId' in sqs_queue_mock.send_message.call_args[1] + assert 'MessageDeduplicationId' in \ + sqs_queue_mock.send_message.call_args[1] + + def test_predefined_queues_put_to_queue(self): + connection = Connection(transport=SQS.Transport, transport_options={ + 'predefined_queues': example_predefined_queues, + }) + channel = connection.channel() + + queue_name = 'queue-2' + + exchange = Exchange('test_SQS', type='direct') + p = messaging.Producer(channel, exchange, routing_key=queue_name) + + queue = Queue(queue_name, exchange, queue_name) + queue(channel).declare() + + channel.sqs = Mock() + sqs_queue_mock = Mock() + channel.sqs.return_value = sqs_queue_mock + p.publish('message', DelaySeconds=10) + + sqs_queue_mock.send_message.assert_called_once() + + assert 'DelaySeconds' in sqs_queue_mock.send_message.call_args[1] + assert sqs_queue_mock.send_message.call_args[1]['DelaySeconds'] == 10 + + @pytest.mark.parametrize('predefined_queues', ( + { + 'invalid-fifo-queue-name': { + 'url': 'https://sqs.us-east-1.amazonaws.com/xxx/queue.fifo', + 'access_key_id': 'a', + 'secret_access_key': 'b' + } + }, + { + 'standard-queue.fifo': { + 'url': 'https://sqs.us-east-1.amazonaws.com/xxx/queue', + 'access_key_id': 'a', + 'secret_access_key': 'b' + } + } + )) + def test_predefined_queues_invalid_configuration(self, predefined_queues): + connection = Connection(transport=SQS.Transport, transport_options={ + 'predefined_queues': predefined_queues, + }) + with pytest.raises(SQS.InvalidQueueException): + connection.channel() + def test_sts_new_session(self): # Arrange connection = Connection(transport=SQS.Transport, transport_options={ diff --git a/t/unit/transport/test_azureservicebus.py b/t/unit/transport/test_azureservicebus.py index 97775d06..5de93c2f 100644 --- a/t/unit/transport/test_azureservicebus.py +++ b/t/unit/transport/test_azureservicebus.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import base64 import json import random @@ -201,14 +203,35 @@ MockQueue = namedtuple( ) +@pytest.fixture(autouse=True) +def sbac_class_patch(): + with patch('kombu.transport.azureservicebus.ServiceBusAdministrationClient') as sbac: # noqa + yield sbac + + +@pytest.fixture(autouse=True) +def sbc_class_patch(): + with patch('kombu.transport.azureservicebus.ServiceBusClient') as sbc: # noqa + yield sbc + + +@pytest.fixture(autouse=True) +def mock_clients( + sbc_class_patch, + sbac_class_patch, + mock_asb, + mock_asb_management +): + sbc_class_patch.from_connection_string.return_value = mock_asb + sbac_class_patch.from_connection_string.return_value = mock_asb_management + + @pytest.fixture def mock_queue(mock_asb, mock_asb_management, random_queue) -> MockQueue: exchange = Exchange('test_servicebus', type='direct') queue = Queue(random_queue, exchange, random_queue) conn = Connection(URL_CREDS, transport=azureservicebus.Transport) channel = conn.channel() - channel._queue_service = mock_asb - channel._queue_mgmt_service = mock_asb_management queue(channel).declare() producer = messaging.Producer(channel, exchange, routing_key=random_queue) diff --git a/t/unit/transport/test_azurestoragequeues.py b/t/unit/transport/test_azurestoragequeues.py new file mode 100644 index 00000000..0c9ef32a --- /dev/null +++ b/t/unit/transport/test_azurestoragequeues.py @@ -0,0 +1,85 @@ +from __future__ import annotations + +from unittest.mock import patch + +import pytest +from azure.identity import DefaultAzureCredential, ManagedIdentityCredential + +from kombu import Connection + +pytest.importorskip('azure.storage.queue') +from kombu.transport import azurestoragequeues # noqa + +URL_NOCREDS = 'azurestoragequeues://' +URL_CREDS = 'azurestoragequeues://sas/key%@https://STORAGE_ACCOUNT_NAME.queue.core.windows.net/' # noqa +AZURITE_CREDS = 'azurestoragequeues://Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==@http://localhost:10001/devstoreaccount1' # noqa +AZURITE_CREDS_DOCKER_COMPOSE = 'azurestoragequeues://Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==@http://azurite:10001/devstoreaccount1' # noqa +DEFAULT_AZURE_URL_CREDS = 'azurestoragequeues://DefaultAzureCredential@https://STORAGE_ACCOUNT_NAME.queue.core.windows.net/' # noqa +MANAGED_IDENTITY_URL_CREDS = 'azurestoragequeues://ManagedIdentityCredential@https://STORAGE_ACCOUNT_NAME.queue.core.windows.net/' # noqa + + +def test_queue_service_nocredentials(): + conn = Connection(URL_NOCREDS, transport=azurestoragequeues.Transport) + with pytest.raises( + ValueError, + match='Need a URI like azurestoragequeues://{SAS or access key}@{URL}' + ): + conn.channel() + + +def test_queue_service(): + # Test gettings queue service without credentials + conn = Connection(URL_CREDS, transport=azurestoragequeues.Transport) + with patch('kombu.transport.azurestoragequeues.QueueServiceClient'): + channel = conn.channel() + + # Check the SAS token "sas/key%" has been parsed from the url correctly + assert channel._credential == 'sas/key%' + assert channel._url == 'https://STORAGE_ACCOUNT_NAME.queue.core.windows.net/' # noqa + + +@pytest.mark.parametrize( + "creds, hostname", + [ + (AZURITE_CREDS, 'localhost'), + (AZURITE_CREDS_DOCKER_COMPOSE, 'azurite'), + ] +) +def test_queue_service_works_for_azurite(creds, hostname): + conn = Connection(creds, transport=azurestoragequeues.Transport) + with patch('kombu.transport.azurestoragequeues.QueueServiceClient'): + channel = conn.channel() + + assert channel._credential == { + 'account_name': 'devstoreaccount1', + 'account_key': 'Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==' # noqa + } + assert channel._url == f'http://{hostname}:10001/devstoreaccount1' # noqa + + +def test_queue_service_works_for_default_azure_credentials(): + conn = Connection( + DEFAULT_AZURE_URL_CREDS, transport=azurestoragequeues.Transport + ) + with patch("kombu.transport.azurestoragequeues.QueueServiceClient"): + channel = conn.channel() + + assert isinstance(channel._credential, DefaultAzureCredential) + assert ( + channel._url + == "https://STORAGE_ACCOUNT_NAME.queue.core.windows.net/" + ) + + +def test_queue_service_works_for_managed_identity_credentials(): + conn = Connection( + MANAGED_IDENTITY_URL_CREDS, transport=azurestoragequeues.Transport + ) + with patch("kombu.transport.azurestoragequeues.QueueServiceClient"): + channel = conn.channel() + + assert isinstance(channel._credential, ManagedIdentityCredential) + assert ( + channel._url + == "https://STORAGE_ACCOUNT_NAME.queue.core.windows.net/" + ) diff --git a/t/unit/transport/test_base.py b/t/unit/transport/test_base.py index 7df12c9e..5beae3c6 100644 --- a/t/unit/transport/test_base.py +++ b/t/unit/transport/test_base.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from unittest.mock import Mock import pytest diff --git a/t/unit/transport/test_consul.py b/t/unit/transport/test_consul.py index ce6c4fcb..ff110e11 100644 --- a/t/unit/transport/test_consul.py +++ b/t/unit/transport/test_consul.py @@ -1,3 +1,6 @@ +from __future__ import annotations + +from array import array from queue import Empty from unittest.mock import Mock @@ -12,6 +15,8 @@ class test_Consul: def setup(self): self.connection = Mock() + self.connection._used_channel_ids = array('H') + self.connection.channel_max = 65535 self.connection.client.transport_options = {} self.connection.client.port = 303 self.consul = self.patching('consul.Consul').return_value diff --git a/t/unit/transport/test_etcd.py b/t/unit/transport/test_etcd.py index 6c75a033..f3fad035 100644 --- a/t/unit/transport/test_etcd.py +++ b/t/unit/transport/test_etcd.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from queue import Empty from unittest.mock import Mock, patch diff --git a/t/unit/transport/test_filesystem.py b/t/unit/transport/test_filesystem.py index a8d1708b..20c7f47a 100644 --- a/t/unit/transport/test_filesystem.py +++ b/t/unit/transport/test_filesystem.py @@ -1,4 +1,9 @@ +from __future__ import annotations + import tempfile +from fcntl import LOCK_EX, LOCK_SH +from queue import Empty +from unittest.mock import call, patch import pytest @@ -138,3 +143,162 @@ class test_FilesystemTransport: assert self.q2(consumer_channel).get() self.q2(consumer_channel).purge() assert self.q2(consumer_channel).get() is None + + +@t.skip.if_win32 +class test_FilesystemFanout: + def setup(self): + try: + data_folder_in = tempfile.mkdtemp() + data_folder_out = tempfile.mkdtemp() + control_folder = tempfile.mkdtemp() + except Exception: + pytest.skip("filesystem transport: cannot create tempfiles") + + self.consumer_connection = Connection( + transport="filesystem", + transport_options={ + "data_folder_in": data_folder_in, + "data_folder_out": data_folder_out, + "control_folder": control_folder, + }, + ) + self.consume_channel = self.consumer_connection.channel() + self.produce_connection = Connection( + transport="filesystem", + transport_options={ + "data_folder_in": data_folder_out, + "data_folder_out": data_folder_in, + "control_folder": control_folder, + }, + ) + self.producer_channel = self.produce_connection.channel() + self.exchange = Exchange("filesystem_exchange_fanout", type="fanout") + self.q1 = Queue("queue1", exchange=self.exchange) + self.q2 = Queue("queue2", exchange=self.exchange) + + def teardown(self): + # make sure we don't attempt to restore messages at shutdown. + for channel in [self.producer_channel, self.consumer_connection]: + try: + channel._qos._dirty.clear() + except AttributeError: + pass + try: + channel._qos._delivered.clear() + except AttributeError: + pass + + def test_produce_consume(self): + + producer = Producer(self.producer_channel, self.exchange) + consumer1 = Consumer(self.consume_channel, self.q1) + consumer2 = Consumer(self.consume_channel, self.q2) + self.q2(self.consume_channel).declare() + + for i in range(10): + producer.publish({"foo": i}) + + _received1 = [] + _received2 = [] + + def callback1(message_data, message): + _received1.append(message) + message.ack() + + def callback2(message_data, message): + _received2.append(message) + message.ack() + + consumer1.register_callback(callback1) + consumer2.register_callback(callback2) + + consumer1.consume() + consumer2.consume() + + while 1: + try: + self.consume_channel.drain_events() + except Empty: + break + + assert len(_received1) + len(_received2) == 20 + + # queue.delete + for i in range(10): + producer.publish({"foo": i}) + assert self.q1(self.consume_channel).get() + self.q1(self.consume_channel).delete() + self.q1(self.consume_channel).declare() + assert self.q1(self.consume_channel).get() is None + + # queue.purge + assert self.q2(self.consume_channel).get() + self.q2(self.consume_channel).purge() + assert self.q2(self.consume_channel).get() is None + + +@t.skip.if_win32 +class test_FilesystemLock: + def setup(self): + try: + data_folder_in = tempfile.mkdtemp() + data_folder_out = tempfile.mkdtemp() + control_folder = tempfile.mkdtemp() + except Exception: + pytest.skip("filesystem transport: cannot create tempfiles") + + self.consumer_connection = Connection( + transport="filesystem", + transport_options={ + "data_folder_in": data_folder_in, + "data_folder_out": data_folder_out, + "control_folder": control_folder, + }, + ) + self.consume_channel = self.consumer_connection.channel() + self.produce_connection = Connection( + transport="filesystem", + transport_options={ + "data_folder_in": data_folder_out, + "data_folder_out": data_folder_in, + "control_folder": control_folder, + }, + ) + self.producer_channel = self.produce_connection.channel() + self.exchange = Exchange("filesystem_exchange_lock", type="fanout") + self.q = Queue("queue1", exchange=self.exchange) + + def teardown(self): + # make sure we don't attempt to restore messages at shutdown. + for channel in [self.producer_channel, self.consumer_connection]: + try: + channel._qos._dirty.clear() + except AttributeError: + pass + try: + channel._qos._delivered.clear() + except AttributeError: + pass + + def test_lock_during_process(self): + producer = Producer(self.producer_channel, self.exchange) + + with patch("kombu.transport.filesystem.lock") as lock_m, patch( + "kombu.transport.filesystem.unlock" + ) as unlock_m: + Consumer(self.consume_channel, self.q) + assert unlock_m.call_count == 1 + lock_m.assert_called_once_with(unlock_m.call_args[0][0], LOCK_EX) + + self.q(self.consume_channel).declare() + with patch("kombu.transport.filesystem.lock") as lock_m, patch( + "kombu.transport.filesystem.unlock" + ) as unlock_m: + producer.publish({"foo": 1}) + assert unlock_m.call_count == 2 + assert lock_m.call_count == 2 + exchange_file_obj = unlock_m.call_args_list[0][0][0] + msg_file_obj = unlock_m.call_args_list[1][0][0] + assert lock_m.call_args_list == [call(exchange_file_obj, LOCK_SH), + call(msg_file_obj, LOCK_EX)] diff --git a/t/unit/transport/test_librabbitmq.py b/t/unit/transport/test_librabbitmq.py index 58ee7e1e..84f8691e 100644 --- a/t/unit/transport/test_librabbitmq.py +++ b/t/unit/transport/test_librabbitmq.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from unittest.mock import Mock, patch import pytest diff --git a/t/unit/transport/test_memory.py b/t/unit/transport/test_memory.py index 2c1fe83f..c707d34c 100644 --- a/t/unit/transport/test_memory.py +++ b/t/unit/transport/test_memory.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import socket import pytest @@ -131,8 +133,8 @@ class test_MemoryTransport: with pytest.raises(socket.timeout): self.c.drain_events(timeout=0.1) - del(c1) # so pyflakes doesn't complain. - del(c2) + del c1 # so pyflakes doesn't complain. + del c2 def test_drain_events_unregistered_queue(self): c1 = self.c.channel() diff --git a/t/unit/transport/test_mongodb.py b/t/unit/transport/test_mongodb.py index 39976988..6bb5f1f9 100644 --- a/t/unit/transport/test_mongodb.py +++ b/t/unit/transport/test_mongodb.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import datetime from queue import Empty from unittest.mock import MagicMock, call, patch @@ -151,16 +153,15 @@ class test_mongodb_channel(BaseMongoDBChannelCase): def test_get(self): - self.set_operation_return_value('messages', 'find_and_modify', { + self.set_operation_return_value('messages', 'find_one_and_delete', { '_id': 'docId', 'payload': '{"some": "data"}', }) event = self.channel._get('foobar') self.assert_collection_accessed('messages') self.assert_operation_called_with( - 'messages', 'find_and_modify', - query={'queue': 'foobar'}, - remove=True, + 'messages', 'find_one_and_delete', + {'queue': 'foobar'}, sort=[ ('priority', pymongo.ASCENDING), ], @@ -168,7 +169,11 @@ class test_mongodb_channel(BaseMongoDBChannelCase): assert event == {'some': 'data'} - self.set_operation_return_value('messages', 'find_and_modify', None) + self.set_operation_return_value( + 'messages', + 'find_one_and_delete', + None, + ) with pytest.raises(Empty): self.channel._get('foobar') @@ -188,7 +193,7 @@ class test_mongodb_channel(BaseMongoDBChannelCase): self.channel._put('foobar', {'some': 'data'}) self.assert_collection_accessed('messages') - self.assert_operation_called_with('messages', 'insert', { + self.assert_operation_called_with('messages', 'insert_one', { 'queue': 'foobar', 'priority': 9, 'payload': '{"some": "data"}', @@ -200,17 +205,17 @@ class test_mongodb_channel(BaseMongoDBChannelCase): self.channel._put_fanout('foobar', {'some': 'data'}, 'foo') self.assert_collection_accessed('messages.broadcast') - self.assert_operation_called_with('broadcast', 'insert', { + self.assert_operation_called_with('broadcast', 'insert_one', { 'queue': 'foobar', 'payload': '{"some": "data"}', }) def test_size(self): - self.set_operation_return_value('messages', 'find.count', 77) + self.set_operation_return_value('messages', 'count_documents', 77) result = self.channel._size('foobar') self.assert_collection_accessed('messages') self.assert_operation_called_with( - 'messages', 'find', {'queue': 'foobar'}, + 'messages', 'count_documents', {'queue': 'foobar'}, ) assert result == 77 @@ -227,7 +232,7 @@ class test_mongodb_channel(BaseMongoDBChannelCase): assert result == 77 def test_purge(self): - self.set_operation_return_value('messages', 'find.count', 77) + self.set_operation_return_value('messages', 'count_documents', 77) result = self.channel._purge('foobar') self.assert_collection_accessed('messages') @@ -276,11 +281,11 @@ class test_mongodb_channel(BaseMongoDBChannelCase): self.channel._queue_bind('test_exchange', 'foo', '*', 'foo') self.assert_collection_accessed('messages.routing') self.assert_operation_called_with( - 'routing', 'update', - {'queue': 'foo', 'pattern': '*', - 'routing_key': 'foo', 'exchange': 'test_exchange'}, + 'routing', 'update_one', {'queue': 'foo', 'pattern': '*', 'routing_key': 'foo', 'exchange': 'test_exchange'}, + {'$set': {'queue': 'foo', 'pattern': '*', + 'routing_key': 'foo', 'exchange': 'test_exchange'}}, upsert=True, ) @@ -317,16 +322,16 @@ class test_mongodb_channel(BaseMongoDBChannelCase): self.channel._ensure_indexes(self.channel.client) self.assert_operation_called_with( - 'messages', 'ensure_index', + 'messages', 'create_index', [('queue', 1), ('priority', 1), ('_id', 1)], background=True, ) self.assert_operation_called_with( - 'broadcast', 'ensure_index', + 'broadcast', 'create_index', [('queue', 1)], ) self.assert_operation_called_with( - 'routing', 'ensure_index', [('queue', 1), ('exchange', 1)], + 'routing', 'create_index', [('queue', 1), ('exchange', 1)], ) def test_create_broadcast_cursor(self): @@ -381,9 +386,9 @@ class test_mongodb_channel_ttl(BaseMongoDBChannelCase): self.channel._new_queue('foobar') self.assert_operation_called_with( - 'queues', 'update', + 'queues', 'update_one', {'_id': 'foobar'}, - {'_id': 'foobar', 'options': {}, 'expire_at': None}, + {'$set': {'_id': 'foobar', 'options': {}, 'expire_at': None}}, upsert=True, ) @@ -393,25 +398,23 @@ class test_mongodb_channel_ttl(BaseMongoDBChannelCase): '_id': 'docId', 'options': {'arguments': {'x-expires': 777}}, }) - self.set_operation_return_value('messages', 'find_and_modify', { + self.set_operation_return_value('messages', 'find_one_and_delete', { '_id': 'docId', 'payload': '{"some": "data"}', }) self.channel._get('foobar') self.assert_collection_accessed('messages', 'messages.queues') self.assert_operation_called_with( - 'messages', 'find_and_modify', - query={'queue': 'foobar'}, - remove=True, + 'messages', 'find_one_and_delete', + {'queue': 'foobar'}, sort=[ ('priority', pymongo.ASCENDING), ], ) self.assert_operation_called_with( - 'routing', 'update', + 'routing', 'update_many', {'queue': 'foobar'}, {'$set': {'expire_at': self.expire_at}}, - multi=True, ) def test_put(self): @@ -422,7 +425,7 @@ class test_mongodb_channel_ttl(BaseMongoDBChannelCase): self.channel._put('foobar', {'some': 'data'}) self.assert_collection_accessed('messages') - self.assert_operation_called_with('messages', 'insert', { + self.assert_operation_called_with('messages', 'insert_one', { 'queue': 'foobar', 'priority': 9, 'payload': '{"some": "data"}', @@ -437,12 +440,14 @@ class test_mongodb_channel_ttl(BaseMongoDBChannelCase): self.channel._queue_bind('test_exchange', 'foo', '*', 'foo') self.assert_collection_accessed('messages.routing') self.assert_operation_called_with( - 'routing', 'update', + 'routing', 'update_one', {'queue': 'foo', 'pattern': '*', 'routing_key': 'foo', 'exchange': 'test_exchange'}, - {'queue': 'foo', 'pattern': '*', - 'routing_key': 'foo', 'exchange': 'test_exchange', - 'expire_at': self.expire_at}, + {'$set': { + 'queue': 'foo', 'pattern': '*', + 'routing_key': 'foo', 'exchange': 'test_exchange', + 'expire_at': self.expire_at + }}, upsert=True, ) @@ -456,18 +461,18 @@ class test_mongodb_channel_ttl(BaseMongoDBChannelCase): self.channel._ensure_indexes(self.channel.client) self.assert_operation_called_with( - 'messages', 'ensure_index', [('expire_at', 1)], + 'messages', 'create_index', [('expire_at', 1)], expireAfterSeconds=0) self.assert_operation_called_with( - 'routing', 'ensure_index', [('expire_at', 1)], + 'routing', 'create_index', [('expire_at', 1)], expireAfterSeconds=0) self.assert_operation_called_with( - 'queues', 'ensure_index', [('expire_at', 1)], expireAfterSeconds=0) + 'queues', 'create_index', [('expire_at', 1)], expireAfterSeconds=0) - def test_get_expire(self): - result = self.channel._get_expire( + def test_get_queue_expire(self): + result = self.channel._get_queue_expire( {'arguments': {'x-expires': 777}}, 'x-expires') self.channel.client.assert_not_called() @@ -478,9 +483,15 @@ class test_mongodb_channel_ttl(BaseMongoDBChannelCase): '_id': 'docId', 'options': {'arguments': {'x-expires': 777}}, }) - result = self.channel._get_expire('foobar', 'x-expires') + result = self.channel._get_queue_expire('foobar', 'x-expires') assert result == self.expire_at + def test_get_message_expire(self): + assert self.channel._get_message_expire({ + 'properties': {'expiration': 777}, + }) == self.expire_at + assert self.channel._get_message_expire({}) is None + def test_update_queues_expire(self): self.set_operation_return_value('queues', 'find_one', { '_id': 'docId', 'options': {'arguments': {'x-expires': 777}}, @@ -489,16 +500,14 @@ class test_mongodb_channel_ttl(BaseMongoDBChannelCase): self.assert_collection_accessed('messages.routing', 'messages.queues') self.assert_operation_called_with( - 'routing', 'update', + 'routing', 'update_many', {'queue': 'foobar'}, {'$set': {'expire_at': self.expire_at}}, - multi=True, ) self.assert_operation_called_with( - 'queues', 'update', + 'queues', 'update_many', {'_id': 'foobar'}, {'$set': {'expire_at': self.expire_at}}, - multi=True, ) @@ -515,7 +524,7 @@ class test_mongodb_channel_calc_queue_size(BaseMongoDBChannelCase): # Tests def test_size(self): - self.set_operation_return_value('messages', 'find.count', 77) + self.set_operation_return_value('messages', 'count_documents', 77) result = self.channel._size('foobar') diff --git a/t/unit/transport/test_pyamqp.py b/t/unit/transport/test_pyamqp.py index d5f6d7e2..bd402395 100644 --- a/t/unit/transport/test_pyamqp.py +++ b/t/unit/transport/test_pyamqp.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import sys from itertools import count from unittest.mock import MagicMock, Mock, patch diff --git a/t/unit/transport/test_pyro.py b/t/unit/transport/test_pyro.py index 325f81ce..258abc9e 100644 --- a/t/unit/transport/test_pyro.py +++ b/t/unit/transport/test_pyro.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import socket import pytest @@ -59,8 +61,8 @@ class test_PyroTransport: with pytest.raises(socket.timeout): self.c.drain_events(timeout=0.1) - del(c1) # so pyflakes doesn't complain. - del(c2) + del c1 # so pyflakes doesn't complain. + del c2 @pytest.mark.skip("requires running Pyro nameserver and Kombu Broker") def test_drain_events_unregistered_queue(self): diff --git a/t/unit/transport/test_qpid.py b/t/unit/transport/test_qpid.py index 351a929b..0048fd38 100644 --- a/t/unit/transport/test_qpid.py +++ b/t/unit/transport/test_qpid.py @@ -1,10 +1,11 @@ +from __future__ import annotations + import select import socket import ssl import sys import time import uuid -from collections import OrderedDict from collections.abc import Callable from itertools import count from queue import Empty @@ -33,7 +34,7 @@ class QpidException(Exception): """ def __init__(self, code=None, text=None): - super(Exception, self).__init__(self) + super().__init__(self) self.code = code self.text = text @@ -57,7 +58,7 @@ class test_QoS__init__: assert qos_limit_two.prefetch_count == 1 def test__init___not_yet_acked_is_initialized(self): - assert isinstance(self.qos._not_yet_acked, OrderedDict) + assert isinstance(self.qos._not_yet_acked, dict) @pytest.mark.skip(reason='Not supported in Python3') diff --git a/t/unit/transport/test_redis.py b/t/unit/transport/test_redis.py index c7ea8f67..b14408a6 100644 --- a/t/unit/transport/test_redis.py +++ b/t/unit/transport/test_redis.py @@ -1,9 +1,14 @@ +from __future__ import annotations + +import base64 +import copy import socket import types from collections import defaultdict from itertools import count from queue import Empty from queue import Queue as _Queue +from typing import TYPE_CHECKING from unittest.mock import ANY, Mock, call, patch import pytest @@ -13,7 +18,9 @@ from kombu.exceptions import VersionMismatch from kombu.transport import virtual from kombu.utils import eventio # patch poll from kombu.utils.json import dumps -from t.mocks import ContextMock + +if TYPE_CHECKING: + from types import TracebackType def _redis_modules(): @@ -230,7 +237,12 @@ class Pipeline: def __enter__(self): return self - def __exit__(self, *exc_info): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None + ) -> None: pass def __getattr__(self, key): @@ -270,9 +282,8 @@ class Channel(redis.Channel): class Transport(redis.Transport): Channel = Channel - - def _get_errors(self): - return ((KeyError,), (IndexError,)) + connection_errors = (KeyError,) + channel_errors = (IndexError,) class test_Channel: @@ -401,38 +412,117 @@ class test_Channel: ) crit.assert_called() - def test_restore(self): + def test_do_restore_message_celery(self): + # Payload value from real Celery project + payload = { + "body": base64.b64encode(dumps([ + [], + {}, + { + "callbacks": None, + "errbacks": None, + "chain": None, + "chord": None, + }, + ]).encode()).decode(), + "content-encoding": "utf-8", + "content-type": "application/json", + "headers": { + "lang": "py", + "task": "common.tasks.test_task", + "id": "980ad2bf-104c-4ce0-8643-67d1947173f6", + "shadow": None, + "eta": None, + "expires": None, + "group": None, + "group_index": None, + "retries": 0, + "timelimit": [None, None], + "root_id": "980ad2bf-104c-4ce0-8643-67d1947173f6", + "parent_id": None, + "argsrepr": "()", + "kwargsrepr": "{}", + "origin": "gen3437@Desktop", + "ignore_result": False, + }, + "properties": { + "correlation_id": "980ad2bf-104c-4ce0-8643-67d1947173f6", + "reply_to": "512f2489-ca40-3585-bc10-9b801a981782", + "delivery_mode": 2, + "delivery_info": { + "exchange": "", + "routing_key": "celery", + }, + "priority": 0, + "body_encoding": "base64", + "delivery_tag": "badb725e-9c3e-45be-b0a4-07e44630519f", + }, + } + result_payload = copy.deepcopy(payload) + result_payload['headers']['redelivered'] = True + result_payload['properties']['delivery_info']['redelivered'] = True + queue = 'celery' + + client = Mock(name='client') + lookup = self.channel._lookup = Mock(name='_lookup') + lookup.return_value = [queue] + + self.channel._do_restore_message( + payload, 'exchange', 'routing_key', client, + ) + + client.rpush.assert_called_with(queue, dumps(result_payload)) + + def test_restore_no_messages(self): message = Mock(name='message') + with patch('kombu.transport.redis.loads') as loads: - loads.return_value = 'M', 'EX', 'RK' + def transaction_handler(restore_transaction, unacked_key): + assert unacked_key == self.channel.unacked_key + pipe = Mock(name='pipe') + pipe.hget.return_value = None + + restore_transaction(pipe) + + pipe.multi.assert_called_once_with() + pipe.hdel.assert_called_once_with( + unacked_key, message.delivery_tag) + loads.assert_not_called() + client = self.channel._create_client = Mock(name='client') client = client() - client.pipeline = ContextMock() - restore = self.channel._do_restore_message = Mock( - name='_do_restore_message', - ) - pipe = client.pipeline.return_value - pipe_hget = Mock(name='pipe.hget') - pipe.hget.return_value = pipe_hget - pipe_hget_hdel = Mock(name='pipe.hget.hdel') - pipe_hget.hdel.return_value = pipe_hget_hdel - result = Mock(name='result') - pipe_hget_hdel.execute.return_value = None, None - + client.transaction.side_effect = transaction_handler self.channel._restore(message) - client.pipeline.assert_called_with() - unacked_key = self.channel.unacked_key - loads.assert_not_called() + client.transaction.assert_called() + + def test_restore_messages(self): + message = Mock(name='message') + + with patch('kombu.transport.redis.loads') as loads: + + def transaction_handler(restore_transaction, unacked_key): + assert unacked_key == self.channel.unacked_key + restore = self.channel._do_restore_message = Mock( + name='_do_restore_message', + ) + result = Mock(name='result') + loads.return_value = 'M', 'EX', 'RK' + pipe = Mock(name='pipe') + pipe.hget.return_value = result + + restore_transaction(pipe) - tag = message.delivery_tag - pipe.hget.assert_called_with(unacked_key, tag) - pipe_hget.hdel.assert_called_with(unacked_key, tag) - pipe_hget_hdel.execute.assert_called_with() + loads.assert_called_with(result) + pipe.multi.assert_called_once_with() + pipe.hdel.assert_called_once_with( + unacked_key, message.delivery_tag) + loads.assert_called() + restore.assert_called_with('M', 'EX', 'RK', pipe, False) - pipe_hget_hdel.execute.return_value = result, None + client = self.channel._create_client = Mock(name='client') + client = client() + client.transaction.side_effect = transaction_handler self.channel._restore(message) - loads.assert_called_with(result) - restore.assert_called_with('M', 'EX', 'RK', client, False) def test_qos_restore_visible(self): client = self.channel._create_client = Mock(name='client') @@ -837,6 +927,26 @@ class test_Channel: call(13, transport.on_readable, 13), ]) + @pytest.mark.parametrize('fds', [{12: 'LISTEN', 13: 'BRPOP'}, {}]) + def test_register_with_event_loop__on_disconnect__loop_cleanup(self, fds): + """Ensure event loop polling stops on disconnect (if started).""" + transport = self.connection.transport + self.connection._sock = None + transport.cycle = Mock(name='cycle') + transport.cycle.fds = fds + conn = Mock(name='conn') + conn.client = Mock(name='client', transport_options={}) + loop = Mock(name='loop') + loop.on_tick = set() + redis.Transport.register_with_event_loop(transport, conn, loop) + assert len(loop.on_tick) == 1 + transport.cycle._on_connection_disconnect(self.connection) + if fds: + assert len(loop.on_tick) == 0 + else: + # on_tick shouldn't be cleared when polling hasn't started + assert len(loop.on_tick) == 1 + def test_configurable_health_check(self): transport = self.connection.transport transport.cycle = Mock(name='cycle') @@ -870,15 +980,22 @@ class test_Channel: redis.Transport.on_readable(transport, 13) cycle.on_readable.assert_called_with(13) - def test_transport_get_errors(self): - assert redis.Transport._get_errors(self.connection.transport) + def test_transport_connection_errors(self): + """Ensure connection_errors are populated.""" + assert redis.Transport.connection_errors + + def test_transport_channel_errors(self): + """Ensure connection_errors are populated.""" + assert redis.Transport.channel_errors def test_transport_driver_version(self): assert redis.Transport.driver_version(self.connection.transport) - def test_transport_get_errors_when_InvalidData_used(self): + def test_transport_errors_when_InvalidData_used(self): from redis import exceptions + from kombu.transport.redis import get_redis_error_classes + class ID(Exception): pass @@ -887,7 +1004,7 @@ class test_Channel: exceptions.InvalidData = ID exceptions.DataError = None try: - errors = redis.Transport._get_errors(self.connection.transport) + errors = get_redis_error_classes() assert errors assert ID in errors[1] finally: @@ -1008,6 +1125,57 @@ class test_Channel: '\x06\x16\x06\x16queue' ) + @patch("redis.client.PubSub.execute_command") + def test_global_keyprefix_pubsub(self, mock_execute_command): + from kombu.transport.redis import PrefixedStrictRedis + + with Connection(transport=Transport) as conn: + client = PrefixedStrictRedis(global_keyprefix='foo_') + + channel = conn.channel() + channel.global_keyprefix = 'foo_' + channel._create_client = Mock() + channel._create_client.return_value = client + channel.subclient.connection = Mock() + channel.active_fanout_queues.add('a') + + channel._subscribe() + mock_execute_command.assert_called_with( + 'PSUBSCRIBE', + 'foo_/{db}.a', + ) + + @patch("redis.client.Pipeline.execute_command") + def test_global_keyprefix_transaction(self, mock_execute_command): + from kombu.transport.redis import PrefixedStrictRedis + + with Connection(transport=Transport) as conn: + def pipeline(transaction=True, shard_hint=None): + pipeline_obj = original_pipeline( + transaction=transaction, shard_hint=shard_hint + ) + mock_execute_command.side_effect = [ + None, None, pipeline_obj, pipeline_obj + ] + return pipeline_obj + + client = PrefixedStrictRedis(global_keyprefix='foo_') + original_pipeline = client.pipeline + client.pipeline = pipeline + + channel = conn.channel() + channel._create_client = Mock() + channel._create_client.return_value = client + + channel.qos.restore_by_tag('test-tag') + assert mock_execute_command is not None + assert mock_execute_command.mock_calls == [ + call('WATCH', 'foo_unacked'), + call('HGET', 'foo_unacked', 'test-tag'), + call('ZREM', 'foo_unacked_index', 'test-tag'), + call('HDEL', 'foo_unacked', 'test-tag') + ] + class test_Redis: diff --git a/t/unit/transport/test_sqlalchemy.py b/t/unit/transport/test_sqlalchemy.py index 5ddca5ac..aa0907f7 100644 --- a/t/unit/transport/test_sqlalchemy.py +++ b/t/unit/transport/test_sqlalchemy.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from unittest.mock import patch import pytest diff --git a/t/unit/transport/test_transport.py b/t/unit/transport/test_transport.py index ca84dd80..b5b5e6eb 100644 --- a/t/unit/transport/test_transport.py +++ b/t/unit/transport/test_transport.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from unittest.mock import Mock, patch from kombu import transport diff --git a/t/unit/transport/test_zookeeper.py b/t/unit/transport/test_zookeeper.py index 21fcac42..8b6d159c 100644 --- a/t/unit/transport/test_zookeeper.py +++ b/t/unit/transport/test_zookeeper.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pytest from kombu import Connection diff --git a/t/unit/transport/virtual/test_base.py b/t/unit/transport/virtual/test_base.py index 681841a0..124e19dd 100644 --- a/t/unit/transport/virtual/test_base.py +++ b/t/unit/transport/virtual/test_base.py @@ -1,6 +1,9 @@ +from __future__ import annotations + import io import socket import warnings +from array import array from time import monotonic from unittest.mock import MagicMock, Mock, patch @@ -178,13 +181,19 @@ class test_Channel: if self.channel._qos is not None: self.channel._qos._on_collect.cancel() - def test_exceeds_channel_max(self): - c = client() - t = c.transport - avail = t._avail_channel_ids = Mock(name='_avail_channel_ids') - avail.pop.side_effect = IndexError() + def test_get_free_channel_id(self): + conn = client() + channel = conn.channel() + assert channel.channel_id == 1 + assert channel._get_free_channel_id() == 2 + + def test_get_free_channel_id__exceeds_channel_max(self): + conn = client() + conn.transport.channel_max = 2 + channel = conn.channel() + channel._get_free_channel_id() with pytest.raises(ResourceError): - virtual.Channel(t) + channel._get_free_channel_id() def test_exchange_bind_interface(self): with pytest.raises(NotImplementedError): @@ -455,9 +464,8 @@ class test_Channel: assert 'could not be delivered' in log[0].message.args[0] def test_context(self): - x = self.channel.__enter__() - assert x is self.channel - x.__exit__() + with self.channel as x: + assert x is self.channel assert x.closed def test_cycle_property(self): @@ -574,8 +582,25 @@ class test_Transport: assert len(self.transport.channels) == 2 self.transport.close_connection(self.transport) assert not self.transport.channels - del(c1) # so pyflakes doesn't complain - del(c2) + del c1 # so pyflakes doesn't complain + del c2 + + def test_create_channel(self): + """Ensure create_channel can create channels successfully.""" + assert self.transport.channels == [] + created_channel = self.transport.create_channel(self.transport) + assert self.transport.channels == [created_channel] + + def test_close_channel(self): + """Ensure close_channel actually removes the channel and updates + _used_channel_ids. + """ + assert self.transport._used_channel_ids == array('H') + created_channel = self.transport.create_channel(self.transport) + assert self.transport._used_channel_ids == array('H', (1,)) + self.transport.close_channel(created_channel) + assert self.transport.channels == [] + assert self.transport._used_channel_ids == array('H') def test_drain_channel(self): channel = self.transport.create_channel(self.transport) diff --git a/t/unit/transport/virtual/test_exchange.py b/t/unit/transport/virtual/test_exchange.py index 55741445..5e5a61d7 100644 --- a/t/unit/transport/virtual/test_exchange.py +++ b/t/unit/transport/virtual/test_exchange.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from unittest.mock import Mock import pytest diff --git a/t/unit/utils/test_amq_manager.py b/t/unit/utils/test_amq_manager.py index ca6adb6e..22fb9355 100644 --- a/t/unit/utils/test_amq_manager.py +++ b/t/unit/utils/test_amq_manager.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from unittest.mock import patch import pytest diff --git a/t/unit/utils/test_compat.py b/t/unit/utils/test_compat.py index d3159b76..d1fa0055 100644 --- a/t/unit/utils/test_compat.py +++ b/t/unit/utils/test_compat.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import socket import sys import types @@ -14,10 +16,14 @@ def test_entrypoints(): 'kombu.utils.compat.importlib_metadata.entry_points', create=True ) as iterep: eps = [Mock(), Mock()] - iterep.return_value = {'kombu.test': eps} + iterep.return_value = ( + {'kombu.test': eps} if sys.version_info < (3, 10) else eps) assert list(entrypoints('kombu.test')) - iterep.assert_called_with() + if sys.version_info < (3, 10): + iterep.assert_called_with() + else: + iterep.assert_called_with(group='kombu.test') eps[0].load.assert_called_with() eps[1].load.assert_called_with() diff --git a/t/unit/utils/test_debug.py b/t/unit/utils/test_debug.py index 020bc849..a4955507 100644 --- a/t/unit/utils/test_debug.py +++ b/t/unit/utils/test_debug.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging from unittest.mock import Mock, patch diff --git a/t/unit/utils/test_div.py b/t/unit/utils/test_div.py index b29b6119..a6e988e8 100644 --- a/t/unit/utils/test_div.py +++ b/t/unit/utils/test_div.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pickle from io import BytesIO, StringIO diff --git a/t/unit/utils/test_encoding.py b/t/unit/utils/test_encoding.py index 26e3ef36..81358a7a 100644 --- a/t/unit/utils/test_encoding.py +++ b/t/unit/utils/test_encoding.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import sys from contextlib import contextmanager from unittest.mock import patch diff --git a/t/unit/utils/test_functional.py b/t/unit/utils/test_functional.py index 73a98e52..26f28733 100644 --- a/t/unit/utils/test_functional.py +++ b/t/unit/utils/test_functional.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pickle from itertools import count from unittest.mock import Mock diff --git a/t/unit/utils/test_imports.py b/t/unit/utils/test_imports.py index 8a4873df..8f515bd8 100644 --- a/t/unit/utils/test_imports.py +++ b/t/unit/utils/test_imports.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from unittest.mock import Mock import pytest diff --git a/t/unit/utils/test_json.py b/t/unit/utils/test_json.py index 6af1c13b..8dcc7e32 100644 --- a/t/unit/utils/test_json.py +++ b/t/unit/utils/test_json.py @@ -1,14 +1,17 @@ +from __future__ import annotations + +import uuid from collections import namedtuple from datetime import datetime from decimal import Decimal -from unittest.mock import MagicMock, Mock -from uuid import uuid4 import pytest import pytz +from hypothesis import given, settings +from hypothesis import strategies as st from kombu.utils.encoding import str_to_bytes -from kombu.utils.json import _DecodeError, dumps, loads +from kombu.utils.json import dumps, loads class Custom: @@ -21,35 +24,54 @@ class Custom: class test_JSONEncoder: - + @pytest.mark.freeze_time("2015-10-21") def test_datetime(self): now = datetime.utcnow() now_utc = now.replace(tzinfo=pytz.utc) - stripped = datetime(*now.timetuple()[:3]) - serialized = loads(dumps({ + + original = { 'datetime': now, 'tz': now_utc, 'date': now.date(), - 'time': now.time()}, - )) + 'time': now.time(), + } + + serialized = loads(dumps(original)) + + assert serialized == original + + @given(message=st.binary()) + @settings(print_blob=True) + def test_binary(self, message): + serialized = loads(dumps({ + 'args': (message,), + })) assert serialized == { - 'datetime': now.isoformat(), - 'tz': '{}Z'.format(now_utc.isoformat().split('+', 1)[0]), - 'time': now.time().isoformat(), - 'date': stripped.isoformat(), + 'args': [message], } def test_Decimal(self): - d = Decimal('3314132.13363235235324234123213213214134') - assert loads(dumps({'d': d})), {'d': str(d)} + original = {'d': Decimal('3314132.13363235235324234123213213214134')} + serialized = loads(dumps(original)) + + assert serialized == original def test_namedtuple(self): Foo = namedtuple('Foo', ['bar']) assert loads(dumps(Foo(123))) == [123] def test_UUID(self): - id = uuid4() - assert loads(dumps({'u': id})), {'u': str(id)} + constructors = [ + uuid.uuid1, + lambda: uuid.uuid3(uuid.NAMESPACE_URL, "https://example.org"), + uuid.uuid4, + lambda: uuid.uuid5(uuid.NAMESPACE_URL, "https://example.org"), + ] + for constructor in constructors: + id = constructor() + loaded_value = loads(dumps({'u': id})) + assert loaded_value == {'u': id} + assert loaded_value["u"].version == id.version def test_default(self): with pytest.raises(TypeError): @@ -81,9 +103,3 @@ class test_dumps_loads: assert loads( str_to_bytes(dumps({'x': 'z'})), decode_bytes=True) == {'x': 'z'} - - def test_loads_DecodeError(self): - _loads = Mock(name='_loads') - _loads.side_effect = _DecodeError( - MagicMock(), MagicMock(), MagicMock()) - assert loads(dumps({'x': 'z'}), _loads=_loads) == {'x': 'z'} diff --git a/t/unit/utils/test_objects.py b/t/unit/utils/test_objects.py index 93a88b4f..b9f1484a 100644 --- a/t/unit/utils/test_objects.py +++ b/t/unit/utils/test_objects.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from kombu.utils.objects import cached_property diff --git a/t/unit/utils/test_scheduling.py b/t/unit/utils/test_scheduling.py index 44cf01a2..7bc76b96 100644 --- a/t/unit/utils/test_scheduling.py +++ b/t/unit/utils/test_scheduling.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from unittest.mock import Mock import pytest diff --git a/t/unit/utils/test_time.py b/t/unit/utils/test_time.py index 660ae8ec..a8f7de0f 100644 --- a/t/unit/utils/test_time.py +++ b/t/unit/utils/test_time.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pytest from kombu.utils.time import maybe_s_to_ms diff --git a/t/unit/utils/test_url.py b/t/unit/utils/test_url.py index 71ea0f9b..f219002b 100644 --- a/t/unit/utils/test_url.py +++ b/t/unit/utils/test_url.py @@ -1,3 +1,5 @@ +from __future__ import annotations + try: from urllib.parse import urlencode except ImportError: diff --git a/t/unit/utils/test_utils.py b/t/unit/utils/test_utils.py index d118d46e..08f95083 100644 --- a/t/unit/utils/test_utils.py +++ b/t/unit/utils/test_utils.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pytest from kombu import version_info_t diff --git a/t/unit/utils/test_uuid.py b/t/unit/utils/test_uuid.py index 05d89125..bc69474a 100644 --- a/t/unit/utils/test_uuid.py +++ b/t/unit/utils/test_uuid.py @@ -1,12 +1,14 @@ +from __future__ import annotations + from kombu.utils.uuid import uuid class test_UUID: - def test_uuid4(self): + def test_uuid4(self) -> None: assert uuid() != uuid() - def test_uuid(self): + def test_uuid(self) -> None: i1 = uuid() i2 = uuid() assert isinstance(i1, str) |
