diff options
-rw-r--r-- | .gitignore | 6 | ||||
-rw-r--r-- | MANIFEST.in | 6 | ||||
-rw-r--r-- | kombu/tests/__init__.py | 62 | ||||
-rw-r--r-- | kombu/tests/test_clocks.py | 104 | ||||
-rw-r--r-- | kombu/tests/test_exceptions.py | 11 | ||||
-rw-r--r-- | kombu/tests/transport/test_qpid.py | 1953 | ||||
-rw-r--r-- | kombu/tests/transport/virtual/test_exchange.py | 200 | ||||
-rw-r--r-- | kombu/tests/utils/test_div.py | 51 | ||||
-rw-r--r-- | kombu/tests/utils/test_encoding.py | 109 | ||||
-rw-r--r-- | kombu/tests/utils/test_scheduling.py | 112 | ||||
-rw-r--r-- | kombu/tests/utils/test_url.py | 50 | ||||
-rw-r--r-- | kombu/tests/utils/test_utils.py | 42 | ||||
-rw-r--r-- | kombu/transport/redis.py | 7 | ||||
-rw-r--r-- | kombu/transport/virtual/__init__.py | 1001 | ||||
-rw-r--r-- | kombu/transport/virtual/base.py | 989 | ||||
-rw-r--r-- | requirements/test-ci.txt | 2 | ||||
-rw-r--r-- | requirements/test.txt | 3 | ||||
-rw-r--r-- | setup.cfg | 7 | ||||
-rw-r--r-- | setup.py | 53 | ||||
-rw-r--r-- | t/__init__.py (renamed from kombu/tests/async/__init__.py) | 0 | ||||
-rw-r--r-- | t/conftest.py | 98 | ||||
-rw-r--r-- | t/mocks.py (renamed from kombu/tests/mocks.py) | 24 | ||||
-rw-r--r-- | t/unit/__init__.py | 1 | ||||
-rw-r--r-- | t/unit/async/__init__.py (renamed from kombu/tests/async/aws/__init__.py) | 0 | ||||
-rw-r--r-- | t/unit/async/aws/__init__.py (renamed from kombu/tests/async/aws/sqs/__init__.py) | 0 | ||||
-rw-r--r-- | t/unit/async/aws/case.py (renamed from kombu/tests/async/aws/case.py) | 7 | ||||
-rw-r--r-- | t/unit/async/aws/sqs/__init__.py (renamed from kombu/tests/async/http/__init__.py) | 0 | ||||
-rw-r--r-- | t/unit/async/aws/sqs/test_connection.py (renamed from kombu/tests/async/aws/sqs/test_connection.py) | 17 | ||||
-rw-r--r-- | t/unit/async/aws/sqs/test_message.py (renamed from kombu/tests/async/aws/sqs/test_message.py) | 19 | ||||
-rw-r--r-- | t/unit/async/aws/sqs/test_queue.py (renamed from kombu/tests/async/aws/sqs/test_queue.py) | 41 | ||||
-rw-r--r-- | t/unit/async/aws/sqs/test_sqs.py (renamed from kombu/tests/async/aws/sqs/test_sqs.py) | 21 | ||||
-rw-r--r-- | t/unit/async/aws/test_aws.py (renamed from kombu/tests/async/aws/test_aws.py) | 7 | ||||
-rw-r--r-- | t/unit/async/aws/test_connection.py (renamed from kombu/tests/async/aws/test_connection.py) | 148 | ||||
-rw-r--r-- | t/unit/async/http/__init__.py (renamed from kombu/tests/transport/__init__.py) | 0 | ||||
-rw-r--r-- | t/unit/async/http/test_curl.py (renamed from kombu/tests/async/http/test_curl.py) | 47 | ||||
-rw-r--r-- | t/unit/async/http/test_http.py (renamed from kombu/tests/async/http/test_http.py) | 91 | ||||
-rw-r--r-- | t/unit/async/test_hub.py (renamed from kombu/tests/async/test_hub.py) | 203 | ||||
-rw-r--r-- | t/unit/async/test_semaphore.py (renamed from kombu/tests/async/test_semaphore.py) | 24 | ||||
-rw-r--r-- | t/unit/async/test_timer.py (renamed from kombu/tests/async/test_timer.py) | 73 | ||||
-rw-r--r-- | t/unit/case.py (renamed from kombu/tests/case.py) | 0 | ||||
-rw-r--r-- | t/unit/test_clocks.py | 89 | ||||
-rw-r--r-- | t/unit/test_common.py (renamed from kombu/tests/test_common.py) | 183 | ||||
-rw-r--r-- | t/unit/test_compat.py (renamed from kombu/tests/test_compat.py) | 155 | ||||
-rw-r--r-- | t/unit/test_compression.py (renamed from kombu/tests/test_compression.py) | 20 | ||||
-rw-r--r-- | t/unit/test_connection.py (renamed from kombu/tests/test_connection.py) | 384 | ||||
-rw-r--r-- | t/unit/test_entity.py (renamed from kombu/tests/test_entity.py) | 220 | ||||
-rw-r--r-- | t/unit/test_exceptions.py | 11 | ||||
-rw-r--r-- | t/unit/test_log.py (renamed from kombu/tests/test_log.py) | 62 | ||||
-rw-r--r-- | t/unit/test_message.py (renamed from kombu/tests/test_message.py) | 21 | ||||
-rw-r--r-- | t/unit/test_messaging.py (renamed from kombu/tests/test_messaging.py) | 286 | ||||
-rw-r--r-- | t/unit/test_mixins.py (renamed from kombu/tests/test_mixins.py) | 45 | ||||
-rw-r--r-- | t/unit/test_pidbox.py (renamed from kombu/tests/test_pidbox.py) | 125 | ||||
-rw-r--r-- | t/unit/test_pools.py (renamed from kombu/tests/test_pools.py) | 84 | ||||
-rw-r--r-- | t/unit/test_serialization.py (renamed from kombu/tests/test_serialization.py) | 209 | ||||
-rw-r--r-- | t/unit/test_simple.py (renamed from kombu/tests/test_simple.py) | 69 | ||||
-rw-r--r-- | t/unit/test_syn.py (renamed from kombu/tests/test_syn.py) | 24 | ||||
-rw-r--r-- | t/unit/transport/__init__.py (renamed from kombu/tests/transport/virtual/__init__.py) | 0 | ||||
-rw-r--r-- | t/unit/transport/test_SQS.py (renamed from kombu/tests/transport/test_SQS.py) | 73 | ||||
-rw-r--r-- | t/unit/transport/test_base.py (renamed from kombu/tests/transport/test_base.py) | 52 | ||||
-rw-r--r-- | t/unit/transport/test_consul.py (renamed from kombu/tests/transport/test_consul.py) | 38 | ||||
-rw-r--r-- | t/unit/transport/test_filesystem.py (renamed from kombu/tests/transport/test_filesystem.py) | 50 | ||||
-rw-r--r-- | t/unit/transport/test_librabbitmq.py (renamed from kombu/tests/transport/test_librabbitmq.py) | 66 | ||||
-rw-r--r-- | t/unit/transport/test_memory.py (renamed from kombu/tests/transport/test_memory.py) | 37 | ||||
-rw-r--r-- | t/unit/transport/test_mongodb.py (renamed from kombu/tests/transport/test_mongodb.py) | 52 | ||||
-rw-r--r-- | t/unit/transport/test_pyamqp.py (renamed from kombu/tests/transport/test_pyamqp.py) | 50 | ||||
-rw-r--r-- | t/unit/transport/test_redis.py (renamed from kombu/tests/transport/test_redis.py) | 299 | ||||
-rw-r--r-- | t/unit/transport/test_transport.py (renamed from kombu/tests/transport/test_transport.py) | 19 | ||||
-rw-r--r-- | t/unit/transport/virtual/__init__.py (renamed from kombu/tests/utils/__init__.py) | 0 | ||||
-rw-r--r-- | t/unit/transport/virtual/test_base.py (renamed from kombu/tests/transport/virtual/test_base.py) | 271 | ||||
-rw-r--r-- | t/unit/transport/virtual/test_exchange.py | 169 | ||||
-rw-r--r-- | t/unit/utils/__init__.py | 0 | ||||
-rw-r--r-- | t/unit/utils/test_amq_manager.py (renamed from kombu/tests/utils/test_amq_manager.py) | 14 | ||||
-rw-r--r-- | t/unit/utils/test_compat.py (renamed from kombu/tests/utils/test_compat.py) | 24 | ||||
-rw-r--r-- | t/unit/utils/test_debug.py (renamed from kombu/tests/utils/test_debug.py) | 20 | ||||
-rw-r--r-- | t/unit/utils/test_div.py | 49 | ||||
-rw-r--r-- | t/unit/utils/test_encoding.py | 105 | ||||
-rw-r--r-- | t/unit/utils/test_functional.py (renamed from kombu/tests/utils/test_functional.py) | 209 | ||||
-rw-r--r-- | t/unit/utils/test_imports.py (renamed from kombu/tests/utils/test_imports.py) | 27 | ||||
-rw-r--r-- | t/unit/utils/test_json.py (renamed from kombu/tests/utils/test_json.py) | 51 | ||||
-rw-r--r-- | t/unit/utils/test_objects.py (renamed from kombu/tests/utils/test_objects.py) | 20 | ||||
-rw-r--r-- | t/unit/utils/test_scheduling.py | 102 | ||||
-rw-r--r-- | t/unit/utils/test_url.py | 39 | ||||
-rw-r--r-- | t/unit/utils/test_utils.py | 22 | ||||
-rw-r--r-- | t/unit/utils/test_uuid.py (renamed from kombu/tests/utils/test_uuid.py) | 10 | ||||
-rw-r--r-- | tox.ini | 3 |
85 files changed, 3674 insertions, 5773 deletions
@@ -8,7 +8,7 @@ settings_local.py build/ .build/ _build/ -.*.sw[pon] +.*.sw* dist/ *.egg-info pip-log.txt @@ -26,3 +26,7 @@ kombu/tests/coverage.xml .coverage dump.rdb .idea/ +.cache/ +htmlcov/ +test.db +coverage.xml diff --git a/MANIFEST.in b/MANIFEST.in index 35c0ac82..ff42284e 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -12,6 +12,12 @@ include setup.cfg recursive-include extra * recursive-include docs * recursive-include kombu *.py +recursive-include t *.py recursive-include requirements *.txt recursive-include funtests *.py setup.cfg recursive-include examples *.py + +recursive-exclude docs/_build * +recursive-exclude * __pycache__ +recursive-exclude * *.py[co] +recursive-exclude * .*.sw* diff --git a/kombu/tests/__init__.py b/kombu/tests/__init__.py deleted file mode 100644 index 17ac49e3..00000000 --- a/kombu/tests/__init__.py +++ /dev/null @@ -1,62 +0,0 @@ -from __future__ import absolute_import, unicode_literals - -import atexit -import os -import sys - -from kombu.exceptions import VersionMismatch - - -def teardown(): - # Workaround for multiprocessing bug where logging - # is attempted after global already collected at shutdown. - canceled = set() - try: - import multiprocessing.util - canceled.add(multiprocessing.util._exit_function) - except (AttributeError, ImportError): - pass - - try: - atexit._exithandlers[:] = [ - e for e in atexit._exithandlers if e[0] not in canceled - ] - except AttributeError: # pragma: no cover - pass # Py3 missing _exithandlers - - -def find_distribution_modules(name=__name__, file=__file__): - current_dist_depth = len(name.split('.')) - 1 - current_dist = os.path.join(os.path.dirname(file), - *([os.pardir] * current_dist_depth)) - abs = os.path.abspath(current_dist) - dist_name = os.path.basename(abs) - - for dirpath, dirnames, filenames in os.walk(abs): - package = (dist_name + dirpath[len(abs):]).replace('/', '.') - if '__init__.py' in filenames: - yield package - for filename in filenames: - if filename.endswith('.py') and filename != '__init__.py': - yield '.'.join([package, filename])[:-3] - - -def import_all_modules(name=__name__, file=__file__, skip=[]): - for module in find_distribution_modules(name, file): - if module not in skip: - print('preimporting %r for coverage...' % (module,)) - try: - __import__(module) - except (ImportError, VersionMismatch, AttributeError): - pass - - -def is_in_coverage(): - return (os.environ.get('COVER_ALL_MODULES') or - '--with-coverage3' in sys.argv) - - -def setup(): - # so coverage sees all our modules. - if is_in_coverage(): - import_all_modules() diff --git a/kombu/tests/test_clocks.py b/kombu/tests/test_clocks.py deleted file mode 100644 index ef4ab87d..00000000 --- a/kombu/tests/test_clocks.py +++ /dev/null @@ -1,104 +0,0 @@ -from __future__ import absolute_import, unicode_literals - -import pickle - -from heapq import heappush -from time import time - -from kombu.clocks import LamportClock, timetuple - -from .case import Mock, Case - - -class test_LamportClock(Case): - - def test_clocks(self): - c1 = LamportClock() - c2 = LamportClock() - - c1.forward() - c2.forward() - c1.forward() - c1.forward() - c2.adjust(c1.value) - self.assertEqual(c2.value, c1.value + 1) - self.assertTrue(repr(c1)) - - c2_val = c2.value - c2.forward() - c2.forward() - c2.adjust(c1.value) - self.assertEqual(c2.value, c2_val + 2 + 1) - - c1.adjust(c2.value) - self.assertEqual(c1.value, c2.value + 1) - - def test_sort(self): - c = LamportClock() - pid1 = 'a.example.com:312' - pid2 = 'b.example.com:311' - - events = [] - - m1 = (c.forward(), pid1) - heappush(events, m1) - m2 = (c.forward(), pid2) - heappush(events, m2) - m3 = (c.forward(), pid1) - heappush(events, m3) - m4 = (30, pid1) - heappush(events, m4) - m5 = (30, pid2) - heappush(events, m5) - - self.assertEqual(str(c), str(c.value)) - - self.assertEqual(c.sort_heap(events), m1) - self.assertEqual(c.sort_heap([m4, m5]), m4) - self.assertEqual(c.sort_heap([m4, m5, m1]), m4) - - -class test_timetuple(Case): - - def test_repr(self): - x = timetuple(133, time(), 'id', Mock()) - self.assertTrue(repr(x)) - - def test_pickleable(self): - x = timetuple(133, time(), 'id', 'obj') - self.assertEqual(pickle.loads(pickle.dumps(x)), tuple(x)) - - def test_order(self): - t1 = time() - t2 = time() + 300 # windows clock not reliable - a = timetuple(133, t1, 'A', 'obj') - b = timetuple(140, t1, 'A', 'obj') - self.assertTrue(a.__getnewargs__()) - self.assertEqual(a.clock, 133) - self.assertEqual(a.timestamp, t1) - self.assertEqual(a.id, 'A') - self.assertEqual(a.obj, 'obj') - self.assertTrue( - a <= b, - ) - self.assertTrue( - b >= a, - ) - - self.assertEqual( - timetuple(134, time(), 'A', 'obj').__lt__(tuple()), - NotImplemented, - ) - self.assertGreater( - timetuple(134, t2, 'A', 'obj'), - timetuple(133, t1, 'A', 'obj'), - ) - self.assertGreater( - timetuple(134, t1, 'B', 'obj'), - timetuple(134, t1, 'A', 'obj'), - ) - - self.assertGreater( - timetuple(None, t2, 'B', 'obj'), - timetuple(None, t1, 'A', 'obj'), - ) diff --git a/kombu/tests/test_exceptions.py b/kombu/tests/test_exceptions.py deleted file mode 100644 index 1d04b359..00000000 --- a/kombu/tests/test_exceptions.py +++ /dev/null @@ -1,11 +0,0 @@ -from __future__ import absolute_import, unicode_literals - -from kombu.exceptions import HttpError - -from kombu.tests.case import Case, Mock - - -class test_HttpError(Case): - - def test_str(self): - self.assertTrue(str(HttpError(200, 'msg', Mock(name='response')))) diff --git a/kombu/tests/transport/test_qpid.py b/kombu/tests/transport/test_qpid.py deleted file mode 100644 index 0e913d84..00000000 --- a/kombu/tests/transport/test_qpid.py +++ /dev/null @@ -1,1953 +0,0 @@ -from __future__ import absolute_import, unicode_literals - -import select -import ssl -import socket -import sys -import time -import uuid - -from collections import Callable, OrderedDict -from itertools import count -from functools import wraps - -from mock import call -from nose import SkipTest - -from kombu.five import Empty, keys, range, monotonic -from kombu.transport.qpid import (AuthenticationFailure, Channel, Connection, - ConnectionError, Message, NotFound, QoS, - Transport) -from kombu.transport.virtual import Base64 -from kombu.tests.case import Case, Mock -from kombu.tests.case import patch - - -QPID_MODULE = 'kombu.transport.qpid' -PY3 = sys.version_info[0] == 3 - - -def case_no_pypy(cls): - setup = cls.setUp - - @wraps(setup) - def around_setup(self): - if getattr(sys, 'pypy_version_info', None): - raise SkipTest('pypy incompatible') - setup(self) - cls.setUp = around_setup - return cls - - -def case_no_python3(cls): - setup = cls.setUp - - @wraps(setup) - def around_setup(self): - if PY3: - raise SkipTest('Python3 incompatible') - setup(self) - cls.setUp = around_setup - return cls - - -def disable_runtime_dependency_check(cls): - """A decorator to disable runtime dependency checking""" - setup = cls.setUp - teardown = cls.tearDown - dependency_is_none_patcher = patch(QPID_MODULE + '.dependency_is_none') - - @wraps(setup) - def around_setup(self): - mock_dependency_is_none = dependency_is_none_patcher.start() - mock_dependency_is_none.return_value = False - setup(self) - - @wraps(setup) - def around_teardown(self): - dependency_is_none_patcher.stop() - teardown(self) - - cls.setUp = around_setup - cls.tearDown = around_teardown - return cls - - -class ExtraAssertionsMixin(object): - """A mixin class adding assertDictEqual and assertDictContainsSubset""" - - def assertDictEqual(self, a, b, msg=None): - """ - Test that two dictionaries are equal. - - Implemented here because this method was not available until Python - 2.6. This asserts that the unique set of keys are the same in a and b. - Also asserts that the value of each key is the same in a and b using - the is operator. - """ - self.assertEqual(set(keys(a)), set(keys(b))) - for key in keys(a): - self.assertEqual(a[key], b[key]) - - def assertDictContainsSubset(self, a, b, msg=None): - """ - Assert that all the key/value pairs in a exist in b. - """ - for key in keys(a): - self.assertIn(key, b) - self.assertEqual(a[key], b[key]) - - -class QpidException(Exception): - """ - An object used to mock Exceptions provided by qpid.messaging.exceptions - """ - - def __init__(self, code=None, text=None): - super(Exception, self).__init__(self) - self.code = code - self.text = text - - -class BreakOutException(Exception): - pass - - -@case_no_python3 -@case_no_pypy -class TestQoS__init__(Case): - - def setUp(self): - self.mock_session = Mock() - self.qos = QoS(self.mock_session) - - def test__init__prefetch_default_set_correct_without_prefetch_value(self): - self.assertEqual(self.qos.prefetch_count, 1) - - def test__init__prefetch_is_hard_set_to_one(self): - qos_limit_two = QoS(self.mock_session) - self.assertEqual(qos_limit_two.prefetch_count, 1) - - def test__init___not_yet_acked_is_initialized(self): - self.assertIsInstance(self.qos._not_yet_acked, OrderedDict) - - -@case_no_python3 -@case_no_pypy -class TestQoSCanConsume(Case): - - def setUp(self): - session = Mock() - self.qos = QoS(session) - - def test_True_when_prefetch_limit_is_zero(self): - self.qos.prefetch_count = 0 - self.qos._not_yet_acked = [] - self.assertTrue(self.qos.can_consume()) - - def test_True_when_len_of__not_yet_acked_is_lt_prefetch_count(self): - self.qos.prefetch_count = 3 - self.qos._not_yet_acked = ['a', 'b'] - self.assertTrue(self.qos.can_consume()) - - def test_False_when_len_of__not_yet_acked_is_eq_prefetch_count(self): - self.qos.prefetch_count = 3 - self.qos._not_yet_acked = ['a', 'b', 'c'] - self.assertFalse(self.qos.can_consume()) - - -@case_no_python3 -@case_no_pypy -class TestQoSCanConsumeMaxEstimate(Case): - - def setUp(self): - self.mock_session = Mock() - self.qos = QoS(self.mock_session) - - def test_return_one_when_prefetch_count_eq_zero(self): - self.qos.prefetch_count = 0 - self.assertEqual(self.qos.can_consume_max_estimate(), 1) - - def test_return_prefetch_count_sub_len__not_yet_acked(self): - self.qos._not_yet_acked = ['a', 'b'] - self.qos.prefetch_count = 4 - self.assertEqual(self.qos.can_consume_max_estimate(), 2) - - -@case_no_python3 -@case_no_pypy -class TestQoSAck(Case): - - def setUp(self): - self.mock_session = Mock() - self.qos = QoS(self.mock_session) - - def test_ack_pops__not_yet_acked(self): - message = Mock() - self.qos.append(message, 1) - self.assertIn(1, self.qos._not_yet_acked) - self.qos.ack(1) - self.assertNotIn(1, self.qos._not_yet_acked) - - def test_ack_calls_session_acknowledge_with_message(self): - message = Mock() - self.qos.append(message, 1) - self.qos.ack(1) - self.qos.session.acknowledge.assert_called_with(message=message) - - -@case_no_python3 -@case_no_pypy -class TestQoSReject(Case): - - def setUp(self): - self.mock_session = Mock() - self.mock_message = Mock() - self.qos = QoS(self.mock_session) - self.patch_qpid = patch(QPID_MODULE + '.qpid') - self.mock_qpid = self.patch_qpid.start() - self.mock_Disposition = self.mock_qpid.messaging.Disposition - self.mock_RELEASED = self.mock_qpid.messaging.RELEASED - self.mock_REJECTED = self.mock_qpid.messaging.REJECTED - - def tearDown(self): - self.patch_qpid.stop() - - def test_reject_pops__not_yet_acked(self): - self.qos.append(self.mock_message, 1) - self.assertIn(1, self.qos._not_yet_acked) - self.qos.reject(1) - self.assertNotIn(1, self.qos._not_yet_acked) - - def test_reject_requeue_true(self): - self.qos.append(self.mock_message, 1) - self.qos.reject(1, requeue=True) - self.mock_Disposition.assert_called_with(self.mock_RELEASED) - self.qos.session.acknowledge.assert_called_with( - message=self.mock_message, - disposition=self.mock_Disposition.return_value, - ) - - def test_reject_requeue_false(self): - message = Mock() - self.qos.append(message, 1) - self.qos.reject(1, requeue=False) - self.mock_Disposition.assert_called_with(self.mock_REJECTED) - self.qos.session.acknowledge.assert_called_with( - message=message, disposition=self.mock_Disposition.return_value, - ) - - -@case_no_python3 -@case_no_pypy -class TestQoS(Case): - - def mock_message_factory(self): - """Create and return a mock message tag and delivery_tag.""" - m_delivery_tag = self.delivery_tag_generator.next() - m = 'message %s' % (m_delivery_tag, ) - return m, m_delivery_tag - - def add_n_messages_to_qos(self, n, qos): - """Add N mock messages into the passed in qos object""" - for i in range(n): - self.add_message_to_qos(qos) - - def add_message_to_qos(self, qos): - """Add a single mock message into the passed in qos object. - - Uses the mock_message_factory() to create the message and - delivery_tag. - """ - m, m_delivery_tag = self.mock_message_factory() - qos.append(m, m_delivery_tag) - - def setUp(self): - self.mock_session = Mock() - self.qos_no_limit = QoS(self.mock_session) - self.qos_limit_2 = QoS(self.mock_session, prefetch_count=2) - self.delivery_tag_generator = count(1) - - def test_append(self): - """Append two messages and check inside the QoS object that they - were put into the internal data structures correctly - """ - qos = self.qos_no_limit - m1, m1_tag = self.mock_message_factory() - m2, m2_tag = self.mock_message_factory() - qos.append(m1, m1_tag) - length_not_yet_acked = len(qos._not_yet_acked) - self.assertEqual(length_not_yet_acked, 1) - checked_message1 = qos._not_yet_acked[m1_tag] - self.assertIs(m1, checked_message1) - qos.append(m2, m2_tag) - length_not_yet_acked = len(qos._not_yet_acked) - self.assertEqual(length_not_yet_acked, 2) - checked_message2 = qos._not_yet_acked[m2_tag] - self.assertIs(m2, checked_message2) - - def test_get(self): - """Append two messages, and use get to receive them""" - qos = self.qos_no_limit - m1, m1_tag = self.mock_message_factory() - m2, m2_tag = self.mock_message_factory() - qos.append(m1, m1_tag) - qos.append(m2, m2_tag) - message1 = qos.get(m1_tag) - message2 = qos.get(m2_tag) - self.assertIs(m1, message1) - self.assertIs(m2, message2) - - -@case_no_python3 -@case_no_pypy -class ConnectionTestBase(Case): - - @patch(QPID_MODULE + '.qpid') - def setUp(self, mock_qpid): - self.connection_options = { - 'host': 'localhost', - 'port': 5672, - 'transport': 'tcp', - 'timeout': 10, - 'sasl_mechanisms': 'ANONYMOUS', - } - self.mock_qpid_connection = mock_qpid.messaging.Connection - self.conn = Connection(**self.connection_options) - - -@case_no_python3 -@case_no_pypy -class TestConnectionInit(ExtraAssertionsMixin, ConnectionTestBase): - - def test_stores_connection_options(self): - # ensure that only one mech was passed into connection. The other - # options should all be passed through as-is - modified_conn_opts = self.connection_options - self.assertDictEqual( - modified_conn_opts, self.conn.connection_options, - ) - - def test_class_variables(self): - self.assertIsInstance(self.conn.channels, list) - self.assertIsInstance(self.conn._callbacks, dict) - - def test_establishes_connection(self): - modified_conn_opts = self.connection_options - self.mock_qpid_connection.establish.assert_called_with( - **modified_conn_opts - ) - - def test_saves_established_connection(self): - created_conn = self.mock_qpid_connection.establish.return_value - self.assertIs(self.conn._qpid_conn, created_conn) - - @patch(QPID_MODULE + '.ConnectionError', new=(QpidException, )) - @patch(QPID_MODULE + '.sys.exc_info') - @patch(QPID_MODULE + '.qpid') - def test_mutates_ConnError_by_message(self, mock_qpid, mock_exc_info): - text = 'connection-forced: Authentication failed(320)' - my_conn_error = QpidException(text=text) - mock_qpid.messaging.Connection.establish.side_effect = my_conn_error - mock_exc_info.return_value = 'a', 'b', None - try: - self.conn = Connection(**self.connection_options) - except AuthenticationFailure as error: - exc_info = sys.exc_info() - self.assertNotIsInstance(error, QpidException) - self.assertIs(exc_info[1], 'b') - self.assertIsNone(exc_info[2]) - else: - self.fail('ConnectionError type was not mutated correctly') - - @patch(QPID_MODULE + '.ConnectionError', new=(QpidException, )) - @patch(QPID_MODULE + '.sys.exc_info') - @patch(QPID_MODULE + '.qpid') - def test_mutates_ConnError_by_code(self, mock_qpid, mock_exc_info): - my_conn_error = QpidException(code=320, text='someothertext') - mock_qpid.messaging.Connection.establish.side_effect = my_conn_error - mock_exc_info.return_value = 'a', 'b', None - try: - self.conn = Connection(**self.connection_options) - except AuthenticationFailure as error: - exc_info = sys.exc_info() - self.assertNotIsInstance(error, QpidException) - self.assertIs(exc_info[1], 'b') - self.assertIsNone(exc_info[2]) - else: - self.fail('ConnectionError type was not mutated correctly') - - @patch(QPID_MODULE + '.ConnectionError', new=(QpidException, )) - @patch(QPID_MODULE + '.sys.exc_info') - @patch(QPID_MODULE + '.qpid') - def test_connection__init__mutates_ConnError_by_message2(self, mock_qpid, - mock_exc_info): - """ - Test for PLAIN connection via python-saslwrapper, sans cyrus-sasl-plain - - This test is specific for what is returned when we attempt to connect - with PLAIN mech and python-saslwrapper is installed, but - cyrus-sasl-plain is not installed. - """ - my_conn_error = QpidException() - my_conn_error.text = 'Error in sasl_client_start (-4) SASL(-4): no '\ - 'mechanism available' - mock_qpid.messaging.Connection.establish.side_effect = my_conn_error - mock_exc_info.return_value = ('a', 'b', None) - try: - self.conn = Connection(**self.connection_options) - except AuthenticationFailure as error: - exc_info = sys.exc_info() - self.assertTrue(not isinstance(error, QpidException)) - self.assertTrue(exc_info[1] is 'b') - self.assertTrue(exc_info[2] is None) - else: - self.fail('ConnectionError type was not mutated correctly') - - @patch(QPID_MODULE + '.ConnectionError', new=(QpidException, )) - @patch(QPID_MODULE + '.sys.exc_info') - @patch(QPID_MODULE + '.qpid') - def test_unknown_connection_error(self, mock_qpid, mock_exc_info): - # If we get a connection error that we don't understand, - # bubble it up as-is - my_conn_error = QpidException(code=999, text='someothertext') - mock_qpid.messaging.Connection.establish.side_effect = my_conn_error - mock_exc_info.return_value = 'a', 'b', None - try: - self.conn = Connection(**self.connection_options) - except Exception as error: - self.assertTrue(error.code == 999) - else: - self.fail('Connection should have thrown an exception') - - @patch.object(Transport, 'channel_errors', new=(QpidException, )) - @patch(QPID_MODULE + '.qpid') - @patch(QPID_MODULE + '.ConnectionError', new=IOError) - def test_non_qpid_error_raises(self, mock_qpid): - mock_Qpid_Connection = mock_qpid.messaging.Connection - my_conn_error = SyntaxError() - my_conn_error.text = 'some non auth related error message' - mock_Qpid_Connection.establish.side_effect = my_conn_error - with self.assertRaises(SyntaxError): - Connection(**self.connection_options) - - @patch(QPID_MODULE + '.qpid') - @patch(QPID_MODULE + '.ConnectionError', new=IOError) - def test_non_auth_conn_error_raises(self, mock_qpid): - mock_Qpid_Connection = mock_qpid.messaging.Connection - my_conn_error = IOError() - my_conn_error.text = 'some non auth related error message' - mock_Qpid_Connection.establish.side_effect = my_conn_error - with self.assertRaises(IOError): - Connection(**self.connection_options) - - -@case_no_python3 -@case_no_pypy -class TestConnectionClassAttributes(ConnectionTestBase): - - def test_connection_verify_class_attributes(self): - self.assertEqual(Channel, Connection.Channel) - - -@case_no_python3 -@case_no_pypy -class TestConnectionGetQpidConnection(ConnectionTestBase): - - def test_connection_get_qpid_connection(self): - self.conn._qpid_conn = Mock() - returned_connection = self.conn.get_qpid_connection() - self.assertIs(self.conn._qpid_conn, returned_connection) - - -@case_no_python3 -@case_no_pypy -class TestConnectionClose(ConnectionTestBase): - - def test_connection_close(self): - self.conn._qpid_conn = Mock() - self.conn.close() - self.conn._qpid_conn.close.assert_called_once_with() - - -@case_no_python3 -@case_no_pypy -class TestConnectionCloseChannel(ConnectionTestBase): - - def setUp(self): - super(TestConnectionCloseChannel, self).setUp() - self.conn.channels = Mock() - - def test_connection_close_channel_removes_channel_from_channel_list(self): - mock_channel = Mock() - self.conn.close_channel(mock_channel) - self.conn.channels.remove.assert_called_once_with(mock_channel) - - def test_connection_close_channel_handles_ValueError_being_raised(self): - self.conn.channels.remove = Mock(side_effect=ValueError()) - self.conn.close_channel(Mock()) - - def test_connection_close_channel_set_channel_connection_to_None(self): - mock_channel = Mock() - mock_channel.connection = False - self.conn.channels.remove = Mock(side_effect=ValueError()) - self.conn.close_channel(mock_channel) - self.assertIsNone(mock_channel.connection) - - -@case_no_python3 -@case_no_pypy -class ChannelTestBase(Case): - - def setUp(self): - self.patch_qpidtoollibs = patch(QPID_MODULE + '.qpidtoollibs') - self.mock_qpidtoollibs = self.patch_qpidtoollibs.start() - self.mock_broker_agent = self.mock_qpidtoollibs.BrokerAgent - self.conn = Mock() - self.transport = Mock() - self.channel = Channel(self.conn, self.transport) - - def tearDown(self): - self.patch_qpidtoollibs.stop() - - -@case_no_python3 -@case_no_pypy -class TestChannelPurge(ChannelTestBase): - - def setUp(self): - super(TestChannelPurge, self).setUp() - self.mock_queue = Mock() - - def test_gets_queue(self): - self.channel._purge(self.mock_queue) - getQueue = self.mock_broker_agent.return_value.getQueue - getQueue.assert_called_once_with(self.mock_queue) - - def test_does_not_call_purge_if_message_count_is_zero(self): - values = {'msgDepth': 0} - queue_obj = self.mock_broker_agent.return_value.getQueue.return_value - queue_obj.values = values - self.channel._purge(self.mock_queue) - self.assertFalse(queue_obj.purge.called) - - def test_purges_all_messages_from_queue(self): - values = {'msgDepth': 5} - queue_obj = self.mock_broker_agent.return_value.getQueue.return_value - queue_obj.values = values - self.channel._purge(self.mock_queue) - queue_obj.purge.assert_called_with(5) - - def test_returns_message_count(self): - values = {'msgDepth': 5} - queue_obj = self.mock_broker_agent.return_value.getQueue.return_value - queue_obj.values = values - result = self.channel._purge(self.mock_queue) - self.assertEqual(result, 5) - - @patch(QPID_MODULE + '.NotFound', new=QpidException) - def test_raises_channel_error_if_queue_does_not_exist(self): - self.mock_broker_agent.return_value.getQueue.return_value = None - self.assertRaises(QpidException, self.channel._purge, self.mock_queue) - - -@case_no_python3 -@case_no_pypy -class TestChannelPut(ChannelTestBase): - - @patch(QPID_MODULE + '.qpid') - def test_channel__put_onto_queue(self, mock_qpid): - routing_key = 'routingkey' - mock_message = Mock() - mock_Message_cls = mock_qpid.messaging.Message - - self.channel._put(routing_key, mock_message) - - address_str = '{0}; {{assert: always, node: {{type: queue}}}}'.format( - routing_key, - ) - self.transport.session.sender.assert_called_with(address_str) - mock_Message_cls.assert_called_with( - content=mock_message, subject=None, - ) - mock_sender = self.transport.session.sender.return_value - mock_sender.send.assert_called_with( - mock_Message_cls.return_value, sync=True, - ) - mock_sender.close.assert_called_with() - - @patch(QPID_MODULE + '.qpid') - def test_channel__put_onto_exchange(self, mock_qpid): - mock_routing_key = 'routingkey' - mock_exchange_name = 'myexchange' - mock_message = Mock() - mock_Message_cls = mock_qpid.messaging.Message - - self.channel._put(mock_routing_key, mock_message, mock_exchange_name) - - addrstr = '{0}/{1}; {{assert: always, node: {{type: topic}}}}'.format( - mock_exchange_name, mock_routing_key, - ) - self.transport.session.sender.assert_called_with(addrstr) - mock_Message_cls.assert_called_with( - content=mock_message, subject=mock_routing_key, - ) - mock_sender = self.transport.session.sender.return_value - mock_sender.send.assert_called_with( - mock_Message_cls.return_value, sync=True, - ) - mock_sender.close.assert_called_with() - - -@case_no_python3 -@case_no_pypy -class TestChannelGet(ChannelTestBase): - - def test_channel__get(self): - mock_queue = Mock() - - result = self.channel._get(mock_queue) - - self.transport.session.receiver.assert_called_once_with(mock_queue) - mock_rx = self.transport.session.receiver.return_value - mock_rx.fetch.assert_called_once_with(timeout=0) - mock_rx.close.assert_called_once_with() - self.assertIs(mock_rx.fetch.return_value, result) - - -@case_no_python3 -@case_no_pypy -class TestChannelClose(ChannelTestBase): - - def setUp(self): - super(TestChannelClose, self).setUp() - self.patch_basic_cancel = patch.object(self.channel, 'basic_cancel') - self.mock_basic_cancel = self.patch_basic_cancel.start() - self.mock_receiver1 = Mock() - self.mock_receiver2 = Mock() - self.channel._receivers = { - 1: self.mock_receiver1, 2: self.mock_receiver2, - } - self.channel.closed = False - - def tearDown(self): - self.patch_basic_cancel.stop() - super(TestChannelClose, self).tearDown() - - def test_channel_close_sets_close_attribute(self): - self.channel.close() - self.assertTrue(self.channel.closed) - - def test_channel_close_calls_basic_cancel_on_all_receivers(self): - self.channel.close() - self.mock_basic_cancel.assert_has_calls([call(1), call(2)]) - - def test_channel_close_calls_close_channel_on_connection(self): - self.channel.close() - self.conn.close_channel.assert_called_once_with(self.channel) - - def test_channel_close_calls_close_on_broker_agent(self): - self.channel.close() - self.channel._broker.close.assert_called_once_with() - - def test_channel_close_does_nothing_if_already_closed(self): - self.channel.closed = True - self.channel.close() - self.assertFalse(self.mock_basic_cancel.called) - - def test_channel_close_does_not_call_close_channel_if_conn_is_None(self): - self.channel.connection = None - self.channel.close() - self.assertFalse(self.conn.close_channel.called) - - -@case_no_python3 -@case_no_pypy -class TestChannelBasicQoS(ChannelTestBase): - - def test_channel_basic_qos_always_returns_one(self): - self.channel.basic_qos(2) - self.assertEqual(self.channel.qos.prefetch_count, 1) - - -@case_no_python3 -@case_no_pypy -class TestChannelBasicGet(ChannelTestBase): - - def setUp(self): - super(TestChannelBasicGet, self).setUp() - self.channel.Message = Mock() - self.channel._get = Mock() - - def test_channel_basic_get_calls__get_with_queue(self): - mock_queue = Mock() - self.channel.basic_get(mock_queue) - self.channel._get.assert_called_once_with(mock_queue) - - def test_channel_basic_get_creates_Message_correctly(self): - mock_queue = Mock() - self.channel.basic_get(mock_queue) - mock_raw_message = self.channel._get.return_value.content - self.channel.Message.assert_called_once_with( - self.channel, mock_raw_message, - ) - - def test_channel_basic_get_acknowledges_message_by_default(self): - mock_queue = Mock() - self.channel.basic_get(mock_queue) - mock_qpid_message = self.channel._get.return_value - acknowledge = self.transport.session.acknowledge - acknowledge.assert_called_once_with(message=mock_qpid_message) - - def test_channel_basic_get_acknowledges_message_with_no_ack_False(self): - mock_queue = Mock() - self.channel.basic_get(mock_queue, no_ack=False) - mock_qpid_message = self.channel._get.return_value - acknowledge = self.transport.session.acknowledge - acknowledge.assert_called_once_with(message=mock_qpid_message) - - def test_channel_basic_get_acknowledges_message_with_no_ack_True(self): - mock_queue = Mock() - self.channel.basic_get(mock_queue, no_ack=True) - mock_qpid_message = self.channel._get.return_value - acknowledge = self.transport.session.acknowledge - acknowledge.assert_called_once_with(message=mock_qpid_message) - - def test_channel_basic_get_returns_correct_message(self): - mock_queue = Mock() - basic_get_result = self.channel.basic_get(mock_queue) - expected_message = self.channel.Message.return_value - self.assertIs(expected_message, basic_get_result) - - def test_basic_get_returns_None_when_channel__get_raises_Empty(self): - mock_queue = Mock() - self.channel._get = Mock(side_effect=Empty) - basic_get_result = self.channel.basic_get(mock_queue) - self.assertEqual(self.channel.Message.call_count, 0) - self.assertIsNone(basic_get_result) - - -@case_no_python3 -@case_no_pypy -class TestChannelBasicCancel(ChannelTestBase): - - def setUp(self): - super(TestChannelBasicCancel, self).setUp() - self.channel._receivers = {1: Mock()} - - def test_channel_basic_cancel_no_error_if_consumer_tag_not_found(self): - self.channel.basic_cancel(2) - - def test_channel_basic_cancel_pops_receiver(self): - self.channel.basic_cancel(1) - self.assertNotIn(1, self.channel._receivers) - - def test_channel_basic_cancel_closes_receiver(self): - mock_receiver = self.channel._receivers[1] - self.channel.basic_cancel(1) - mock_receiver.close.assert_called_once_with() - - def test_channel_basic_cancel_pops__tag_to_queue(self): - self.channel._tag_to_queue = Mock() - self.channel.basic_cancel(1) - self.channel._tag_to_queue.pop.assert_called_once_with(1, None) - - def test_channel_basic_cancel_pops_connection__callbacks(self): - self.channel._tag_to_queue = Mock() - self.channel.basic_cancel(1) - mock_queue = self.channel._tag_to_queue.pop.return_value - self.conn._callbacks.pop.assert_called_once_with(mock_queue, None) - - -@case_no_python3 -@case_no_pypy -class TestChannelInit(ChannelTestBase, ExtraAssertionsMixin): - - def test_channel___init__sets_variables_as_expected(self): - self.assertIs(self.conn, self.channel.connection) - self.assertIs(self.transport, self.channel.transport) - self.assertFalse(self.channel.closed) - self.conn.get_qpid_connection.assert_called_once_with() - expected_broker_agent = self.mock_broker_agent.return_value - self.assertIs(self.channel._broker, expected_broker_agent) - self.assertDictEqual(self.channel._tag_to_queue, {}) - self.assertDictEqual(self.channel._receivers, {}) - self.assertIs(self.channel._qos, None) - - -@case_no_python3 -@case_no_pypy -class TestChannelBasicConsume(ChannelTestBase, ExtraAssertionsMixin): - - def setUp(self): - super(TestChannelBasicConsume, self).setUp() - self.conn._callbacks = {} - - def test_channel_basic_consume_adds_queue_to__tag_to_queue(self): - mock_tag = Mock() - mock_queue = Mock() - self.channel.basic_consume(mock_queue, Mock(), Mock(), mock_tag) - expected_dict = {mock_tag: mock_queue} - self.assertDictEqual(expected_dict, self.channel._tag_to_queue) - - def test_channel_basic_consume_adds_entry_to_connection__callbacks(self): - mock_queue = Mock() - self.channel.basic_consume(mock_queue, Mock(), Mock(), Mock()) - self.assertIn(mock_queue, self.conn._callbacks) - self.assertIsInstance(self.conn._callbacks[mock_queue], Callable) - - def test_channel_basic_consume_creates_new_receiver(self): - mock_queue = Mock() - self.channel.basic_consume(mock_queue, Mock(), Mock(), Mock()) - self.transport.session.receiver.assert_called_once_with(mock_queue) - - def test_channel_basic_consume_saves_new_receiver(self): - mock_tag = Mock() - self.channel.basic_consume(Mock(), Mock(), Mock(), mock_tag) - new_mock_receiver = self.transport.session.receiver.return_value - expected_dict = {mock_tag: new_mock_receiver} - self.assertDictEqual(expected_dict, self.channel._receivers) - - def test_channel_basic_consume_sets_capacity_on_new_receiver(self): - mock_prefetch_count = Mock() - self.channel.qos.prefetch_count = mock_prefetch_count - self.channel.basic_consume(Mock(), Mock(), Mock(), Mock()) - new_receiver = self.transport.session.receiver.return_value - self.assertTrue(new_receiver.capacity is mock_prefetch_count) - - def get_callback(self, no_ack=Mock(), original_cb=Mock()): - self.channel.Message = Mock() - mock_queue = Mock() - self.channel.basic_consume(mock_queue, no_ack, original_cb, Mock()) - return self.conn._callbacks[mock_queue] - - def test_channel_basic_consume_callback_creates_Message_correctly(self): - callback = self.get_callback() - mock_qpid_message = Mock() - callback(mock_qpid_message) - mock_content = mock_qpid_message.content - self.channel.Message.assert_called_once_with( - self.channel, mock_content, - ) - - def test_channel_basic_consume_callback_adds_message_to_QoS(self): - self.channel._qos = Mock() - callback = self.get_callback() - mock_qpid_message = Mock() - callback(mock_qpid_message) - mock_delivery_tag = self.channel.Message.return_value.delivery_tag - self.channel._qos.append.assert_called_once_with( - mock_qpid_message, mock_delivery_tag, - ) - - def test_channel_basic_consume_callback_gratuitously_acks(self): - self.channel.basic_ack = Mock() - callback = self.get_callback() - mock_qpid_message = Mock() - callback(mock_qpid_message) - mock_delivery_tag = self.channel.Message.return_value.delivery_tag - self.channel.basic_ack.assert_called_once_with(mock_delivery_tag) - - def test_channel_basic_consume_callback_does_not_ack_when_needed(self): - self.channel.basic_ack = Mock() - callback = self.get_callback(no_ack=False) - mock_qpid_message = Mock() - callback(mock_qpid_message) - self.assertFalse(self.channel.basic_ack.called) - - def test_channel_basic_consume_callback_calls_real_callback(self): - self.channel.basic_ack = Mock() - mock_original_callback = Mock() - callback = self.get_callback(original_cb=mock_original_callback) - mock_qpid_message = Mock() - callback(mock_qpid_message) - expected_message = self.channel.Message.return_value - mock_original_callback.assert_called_once_with(expected_message) - - -@case_no_python3 -@case_no_pypy -class TestChannelQueueDelete(ChannelTestBase): - - def setUp(self): - super(TestChannelQueueDelete, self).setUp() - self.patch__has_queue = patch.object(self.channel, '_has_queue') - self.mock__has_queue = self.patch__has_queue.start() - self.patch__size = patch.object(self.channel, '_size') - self.mock__size = self.patch__size.start() - self.patch__delete = patch.object(self.channel, '_delete') - self.mock__delete = self.patch__delete.start() - self.mock_queue = Mock() - - def tearDown(self): - self.patch__has_queue.stop() - self.patch__size.stop() - self.patch__delete.stop() - super(TestChannelQueueDelete, self).tearDown() - - def test_checks_if_queue_exists(self): - self.channel.queue_delete(self.mock_queue) - self.mock__has_queue.assert_called_once_with(self.mock_queue) - - def test_does_nothing_if_queue_does_not_exist(self): - self.mock__has_queue.return_value = False - self.channel.queue_delete(self.mock_queue) - self.assertFalse(self.mock__delete.called) - - def test_not_empty_and_if_empty_True_no_delete(self): - self.mock__size.return_value = 1 - self.channel.queue_delete(self.mock_queue, if_empty=True) - mock_broker = self.mock_broker_agent.return_value - self.assertFalse(mock_broker.getQueue.called) - - def test_calls_get_queue(self): - self.channel.queue_delete(self.mock_queue) - getQueue = self.mock_broker_agent.return_value.getQueue - getQueue.assert_called_once_with(self.mock_queue) - - def test_gets_queue_attribute(self): - self.channel.queue_delete(self.mock_queue) - queue_obj = self.mock_broker_agent.return_value.getQueue.return_value - queue_obj.getAttributes.assert_called_once_with() - - def test_queue_in_use_and_if_unused_no_delete(self): - queue_obj = self.mock_broker_agent.return_value.getQueue.return_value - queue_obj.getAttributes.return_value = {'consumerCount': 1} - self.channel.queue_delete(self.mock_queue, if_unused=True) - self.assertFalse(self.mock__delete.called) - - def test_calls__delete_with_queue(self): - self.channel.queue_delete(self.mock_queue) - self.mock__delete.assert_called_once_with(self.mock_queue) - - -@case_no_python3 -@case_no_pypy -class TestChannel(ExtraAssertionsMixin, Case): - - @patch(QPID_MODULE + '.qpidtoollibs') - def setUp(self, mock_qpidtoollibs): - self.mock_connection = Mock() - self.mock_qpid_connection = Mock() - self.mock_qpid_session = Mock() - self.mock_qpid_connection.session = Mock( - return_value=self.mock_qpid_session, - ) - self.mock_connection.get_qpid_connection = Mock( - return_value=self.mock_qpid_connection, - ) - self.mock_transport = Mock() - self.mock_broker = Mock() - self.mock_Message = Mock() - self.mock_BrokerAgent = mock_qpidtoollibs.BrokerAgent - self.mock_BrokerAgent.return_value = self.mock_broker - self.my_channel = Channel( - self.mock_connection, self.mock_transport, - ) - self.my_channel.Message = self.mock_Message - - def test_verify_QoS_class_attribute(self): - """Verify that the class attribute QoS refers to the QoS object""" - self.assertIs(QoS, Channel.QoS) - - def test_verify_Message_class_attribute(self): - """Verify that the class attribute Message refers to the Message - object.""" - self.assertIs(Message, Channel.Message) - - def test_body_encoding_class_attribute(self): - """Verify that the class attribute body_encoding is set to base64""" - self.assertEqual('base64', Channel.body_encoding) - - def test_codecs_class_attribute(self): - """Verify that the codecs class attribute has a correct key and - value.""" - self.assertIsInstance(Channel.codecs, dict) - self.assertIn('base64', Channel.codecs) - self.assertIsInstance(Channel.codecs['base64'], Base64) - - def test_size(self): - """Test getting the number of messages in a queue specified by - name and returning them.""" - message_count = 5 - mock_queue = Mock() - mock_queue_to_check = Mock() - mock_queue_to_check.values = {'msgDepth': message_count} - self.mock_broker.getQueue.return_value = mock_queue_to_check - result = self.my_channel._size(mock_queue) - self.mock_broker.getQueue.assert_called_with(mock_queue) - self.assertEqual(message_count, result) - - def test_delete(self): - """Test deleting a queue calls purge and delQueue with queue name.""" - mock_queue = Mock() - self.my_channel._purge = Mock() - result = self.my_channel._delete(mock_queue) - self.my_channel._purge.assert_called_with(mock_queue) - self.mock_broker.delQueue.assert_called_with(mock_queue) - self.assertIsNone(result) - - def test_has_queue_true(self): - """Test checking if a queue exists, and it does.""" - mock_queue = Mock() - self.mock_broker.getQueue.return_value = True - result = self.my_channel._has_queue(mock_queue) - self.assertTrue(result) - - def test_has_queue_false(self): - """Test checking if a queue exists, and it does not.""" - mock_queue = Mock() - self.mock_broker.getQueue.return_value = False - result = self.my_channel._has_queue(mock_queue) - self.assertFalse(result) - - @patch('amqp.protocol.queue_declare_ok_t') - def test_queue_declare_with_exception_raised(self, - mock_queue_declare_ok_t): - """Test declare_queue, where an exception is raised and silenced.""" - mock_queue = Mock() - mock_passive = Mock() - mock_durable = Mock() - mock_exclusive = Mock() - mock_auto_delete = Mock() - mock_nowait = Mock() - mock_arguments = Mock() - mock_msg_count = Mock() - mock_queue.startswith.return_value = False - mock_queue.endswith.return_value = False - options = { - 'passive': mock_passive, - 'durable': mock_durable, - 'exclusive': mock_exclusive, - 'auto-delete': mock_auto_delete, - 'arguments': mock_arguments, - } - mock_consumer_count = Mock() - mock_return_value = Mock() - values_dict = { - 'msgDepth': mock_msg_count, - 'consumerCount': mock_consumer_count, - } - mock_queue_data = Mock() - mock_queue_data.values = values_dict - exception_to_raise = Exception('The foo object already exists.') - self.mock_broker.addQueue.side_effect = exception_to_raise - self.mock_broker.getQueue.return_value = mock_queue_data - mock_queue_declare_ok_t.return_value = mock_return_value - result = self.my_channel.queue_declare( - mock_queue, - passive=mock_passive, - durable=mock_durable, - exclusive=mock_exclusive, - auto_delete=mock_auto_delete, - nowait=mock_nowait, - arguments=mock_arguments, - ) - self.mock_broker.addQueue.assert_called_with( - mock_queue, options=options, - ) - mock_queue_declare_ok_t.assert_called_with( - mock_queue, mock_msg_count, mock_consumer_count, - ) - self.assertIs(mock_return_value, result) - - def test_queue_declare_set_ring_policy_for_celeryev(self): - """Test declare_queue sets ring_policy for celeryev.""" - mock_queue = Mock() - mock_queue.startswith.return_value = True - mock_queue.endswith.return_value = False - expected_default_options = { - 'passive': False, - 'durable': False, - 'exclusive': False, - 'auto-delete': True, - 'arguments': None, - 'qpid.policy_type': 'ring', - } - mock_msg_count = Mock() - mock_consumer_count = Mock() - values_dict = { - 'msgDepth': mock_msg_count, - 'consumerCount': mock_consumer_count, - } - mock_queue_data = Mock() - mock_queue_data.values = values_dict - self.mock_broker.addQueue.return_value = None - self.mock_broker.getQueue.return_value = mock_queue_data - self.my_channel.queue_declare(mock_queue) - mock_queue.startswith.assert_called_with('celeryev') - self.mock_broker.addQueue.assert_called_with( - mock_queue, options=expected_default_options, - ) - - def test_queue_declare_set_ring_policy_for_pidbox(self): - """Test declare_queue sets ring_policy for pidbox.""" - mock_queue = Mock() - mock_queue.startswith.return_value = False - mock_queue.endswith.return_value = True - expected_default_options = { - 'passive': False, - 'durable': False, - 'exclusive': False, - 'auto-delete': True, - 'arguments': None, - 'qpid.policy_type': 'ring', - } - mock_msg_count = Mock() - mock_consumer_count = Mock() - values_dict = { - 'msgDepth': mock_msg_count, - 'consumerCount': mock_consumer_count, - } - mock_queue_data = Mock() - mock_queue_data.values = values_dict - self.mock_broker.addQueue.return_value = None - self.mock_broker.getQueue.return_value = mock_queue_data - self.my_channel.queue_declare(mock_queue) - mock_queue.endswith.assert_called_with('pidbox') - self.mock_broker.addQueue.assert_called_with( - mock_queue, options=expected_default_options, - ) - - def test_queue_declare_ring_policy_not_set_as_expected(self): - """Test declare_queue does not set ring_policy as expected.""" - mock_queue = Mock() - mock_queue.startswith.return_value = False - mock_queue.endswith.return_value = False - expected_default_options = { - 'passive': False, - 'durable': False, - 'exclusive': False, - 'auto-delete': True, - 'arguments': None, - } - mock_msg_count = Mock() - mock_consumer_count = Mock() - values_dict = { - 'msgDepth': mock_msg_count, - 'consumerCount': mock_consumer_count, - } - mock_queue_data = Mock() - mock_queue_data.values = values_dict - self.mock_broker.addQueue.return_value = None - self.mock_broker.getQueue.return_value = mock_queue_data - self.my_channel.queue_declare(mock_queue) - mock_queue.startswith.assert_called_with('celeryev') - mock_queue.endswith.assert_called_with('pidbox') - self.mock_broker.addQueue.assert_called_with( - mock_queue, options=expected_default_options, - ) - - def test_queue_declare_test_defaults(self): - """Test declare_queue defaults.""" - mock_queue = Mock() - mock_queue.startswith.return_value = False - mock_queue.endswith.return_value = False - expected_default_options = { - 'passive': False, - 'durable': False, - 'exclusive': False, - 'auto-delete': True, - 'arguments': None, - } - mock_msg_count = Mock() - mock_consumer_count = Mock() - values_dict = { - 'msgDepth': mock_msg_count, - 'consumerCount': mock_consumer_count, - } - mock_queue_data = Mock() - mock_queue_data.values = values_dict - self.mock_broker.addQueue.return_value = None - self.mock_broker.getQueue.return_value = mock_queue_data - self.my_channel.queue_declare(mock_queue) - self.mock_broker.addQueue.assert_called_with( - mock_queue, - options=expected_default_options, - ) - - def test_queue_declare_raises_exception_not_silenced(self): - unique_exception = Exception('This exception should not be silenced') - mock_queue = Mock() - self.mock_broker.addQueue.side_effect = unique_exception - with self.assertRaises(unique_exception.__class__): - self.my_channel.queue_declare(mock_queue) - self.mock_broker.addQueue.assert_called_once_with( - mock_queue, - options={ - 'exclusive': False, - 'durable': False, - 'qpid.policy_type': 'ring', - 'passive': False, - 'arguments': None, - 'auto-delete': True - }) - - def test_exchange_declare_raises_exception_and_silenced(self): - """Create exchange where an exception is raised and then silenced""" - self.mock_broker.addExchange.side_effect = Exception( - 'The foo object already exists.', - ) - self.my_channel.exchange_declare() - - def test_exchange_declare_raises_exception_not_silenced(self): - """Create Exchange where an exception is raised and not silenced.""" - unique_exception = Exception('This exception should not be silenced') - self.mock_broker.addExchange.side_effect = unique_exception - with self.assertRaises(unique_exception.__class__): - self.my_channel.exchange_declare() - - def test_exchange_declare(self): - """Create Exchange where an exception is NOT raised.""" - mock_exchange = Mock() - mock_type = Mock() - mock_durable = Mock() - options = {'durable': mock_durable} - result = self.my_channel.exchange_declare( - mock_exchange, mock_type, mock_durable, - ) - self.mock_broker.addExchange.assert_called_with( - mock_type, mock_exchange, options, - ) - self.assertIsNone(result) - - def test_exchange_delete(self): - """Test the deletion of an exchange by name.""" - mock_exchange = Mock() - result = self.my_channel.exchange_delete(mock_exchange) - self.mock_broker.delExchange.assert_called_with(mock_exchange) - self.assertIsNone(result) - - def test_queue_bind(self): - """Test binding a queue to an exchange using a routing key.""" - mock_queue = Mock() - mock_exchange = Mock() - mock_routing_key = Mock() - self.my_channel.queue_bind( - mock_queue, mock_exchange, mock_routing_key, - ) - self.mock_broker.bind.assert_called_with( - mock_exchange, mock_queue, mock_routing_key, - ) - - def test_queue_unbind(self): - """Test unbinding a queue from an exchange using a routing key.""" - mock_queue = Mock() - mock_exchange = Mock() - mock_routing_key = Mock() - self.my_channel.queue_unbind( - mock_queue, mock_exchange, mock_routing_key, - ) - self.mock_broker.unbind.assert_called_with( - mock_exchange, mock_queue, mock_routing_key, - ) - - def test_queue_purge(self): - """Test purging a queue by name.""" - mock_queue = Mock() - purge_result = Mock() - self.my_channel._purge = Mock(return_value=purge_result) - result = self.my_channel.queue_purge(mock_queue) - self.my_channel._purge.assert_called_with(mock_queue) - self.assertIs(purge_result, result) - - @patch(QPID_MODULE + '.Channel.qos') - def test_basic_ack(self, mock_qos): - """Test that basic_ack calls the QoS object properly.""" - mock_delivery_tag = Mock() - self.my_channel.basic_ack(mock_delivery_tag) - mock_qos.ack.assert_called_with(mock_delivery_tag) - - @patch(QPID_MODULE + '.Channel.qos') - def test_basic_reject(self, mock_qos): - """Test that basic_reject calls the QoS object properly.""" - mock_delivery_tag = Mock() - mock_requeue_value = Mock() - self.my_channel.basic_reject(mock_delivery_tag, mock_requeue_value) - mock_qos.reject.assert_called_with( - mock_delivery_tag, requeue=mock_requeue_value, - ) - - def test_qos_manager_is_none(self): - """Test the qos property if the QoS object did not already exist.""" - self.my_channel._qos = None - result = self.my_channel.qos - self.assertIsInstance(result, QoS) - self.assertEqual(result, self.my_channel._qos) - - def test_qos_manager_already_exists(self): - """Test the qos property if the QoS object already exists.""" - mock_existing_qos = Mock() - self.my_channel._qos = mock_existing_qos - result = self.my_channel.qos - self.assertIs(mock_existing_qos, result) - - def test_prepare_message(self): - """Test that prepare_message() returns the correct result.""" - mock_body = Mock() - mock_priority = Mock() - mock_content_encoding = Mock() - mock_content_type = Mock() - mock_header1 = Mock() - mock_header2 = Mock() - mock_properties1 = Mock() - mock_properties2 = Mock() - headers = {'header1': mock_header1, 'header2': mock_header2} - properties = {'properties1': mock_properties1, - 'properties2': mock_properties2} - result = self.my_channel.prepare_message( - mock_body, - priority=mock_priority, - content_type=mock_content_type, - content_encoding=mock_content_encoding, - headers=headers, - properties=properties) - self.assertIs(mock_body, result['body']) - self.assertIs(mock_content_encoding, result['content-encoding']) - self.assertIs(mock_content_type, result['content-type']) - self.assertDictEqual(headers, result['headers']) - self.assertDictContainsSubset(properties, result['properties']) - self.assertIs( - mock_priority, result['properties']['delivery_info']['priority'], - ) - - @patch('__builtin__.buffer') - @patch(QPID_MODULE + '.Channel.body_encoding') - @patch(QPID_MODULE + '.Channel.encode_body') - @patch(QPID_MODULE + '.Channel._put') - def test_basic_publish(self, mock_put, - mock_encode_body, - mock_body_encoding, - mock_buffer): - """Test basic_publish().""" - mock_original_body = Mock() - mock_encoded_body = 'this is my encoded body' - mock_message = {'body': mock_original_body, - 'properties': {'delivery_info': {}}} - mock_encode_body.return_value = ( - mock_encoded_body, mock_body_encoding, - ) - mock_exchange = Mock() - mock_routing_key = Mock() - mock_encoded_buffered_body = Mock() - mock_buffer.return_value = mock_encoded_buffered_body - self.my_channel.basic_publish( - mock_message, mock_exchange, mock_routing_key, - ) - mock_encode_body.assert_called_once_with( - mock_original_body, mock_body_encoding, - ) - mock_buffer.assert_called_once_with(mock_encoded_body) - self.assertIs(mock_message['body'], mock_encoded_buffered_body) - self.assertIs( - mock_message['properties']['body_encoding'], mock_body_encoding, - ) - self.assertIsInstance( - mock_message['properties']['delivery_tag'], uuid.UUID, - ) - self.assertIs( - mock_message['properties']['delivery_info']['exchange'], - mock_exchange, - ) - self.assertIs( - mock_message['properties']['delivery_info']['routing_key'], - mock_routing_key, - ) - mock_put.assert_called_with( - mock_routing_key, mock_message, mock_exchange, - ) - - @patch(QPID_MODULE + '.Channel.codecs') - def test_encode_body_expected_encoding(self, mock_codecs): - """Test if encode_body() works when encoding is set correctly""" - mock_body = Mock() - mock_encoder = Mock() - mock_encoded_result = Mock() - mock_codecs.get.return_value = mock_encoder - mock_encoder.encode.return_value = mock_encoded_result - result = self.my_channel.encode_body(mock_body, encoding='base64') - expected_result = (mock_encoded_result, 'base64') - self.assertEqual(expected_result, result) - - @patch(QPID_MODULE + '.Channel.codecs') - def test_encode_body_not_expected_encoding(self, mock_codecs): - """Test if encode_body() works when encoding is not set correctly.""" - mock_body = Mock() - result = self.my_channel.encode_body(mock_body, encoding=None) - expected_result = mock_body, None - self.assertEqual(expected_result, result) - - @patch(QPID_MODULE + '.Channel.codecs') - def test_decode_body_expected_encoding(self, mock_codecs): - """Test if decode_body() works when encoding is set correctly.""" - mock_body = Mock() - mock_decoder = Mock() - mock_decoded_result = Mock() - mock_codecs.get.return_value = mock_decoder - mock_decoder.decode.return_value = mock_decoded_result - result = self.my_channel.decode_body(mock_body, encoding='base64') - self.assertEqual(mock_decoded_result, result) - - @patch(QPID_MODULE + '.Channel.codecs') - def test_decode_body_not_expected_encoding(self, mock_codecs): - """Test if decode_body() works when encoding is not set correctly.""" - mock_body = Mock() - result = self.my_channel.decode_body(mock_body, encoding=None) - self.assertEqual(mock_body, result) - - def test_typeof_exchange_exists(self): - """Test that typeof() finds an exchange that already exists.""" - mock_exchange = Mock() - mock_qpid_exchange = Mock() - mock_attributes = {} - mock_type = Mock() - mock_attributes['type'] = mock_type - mock_qpid_exchange.getAttributes.return_value = mock_attributes - self.mock_broker.getExchange.return_value = mock_qpid_exchange - result = self.my_channel.typeof(mock_exchange) - self.assertIs(mock_type, result) - - def test_typeof_exchange_does_not_exist(self): - """Test that typeof() finds an exchange that does not exists.""" - mock_exchange = Mock() - mock_default = Mock() - self.mock_broker.getExchange.return_value = None - result = self.my_channel.typeof(mock_exchange, default=mock_default) - self.assertIs(mock_default, result) - - -@case_no_python3 -@case_no_pypy -@disable_runtime_dependency_check -class TestTransportInit(Case): - - def setUp(self): - self.patch_a = patch.object(Transport, 'verify_runtime_environment') - self.mock_verify_runtime_environment = self.patch_a.start() - - self.patch_b = patch(QPID_MODULE + '.base.Transport.__init__') - self.mock_base_Transport__init__ = self.patch_b.start() - - def tearDown(self): - self.patch_a.stop() - self.patch_b.stop() - - def test_Transport___init___calls_verify_runtime_environment(self): - Transport(Mock()) - self.mock_verify_runtime_environment.assert_called_once_with() - - def test_transport___init___calls_parent_class___init__(self): - m = Mock() - Transport(m) - self.mock_base_Transport__init__.assert_called_once_with(m) - - def test_transport___init___sets_use_async_interface_False(self): - transport = Transport(Mock()) - self.assertFalse(transport.use_async_interface) - - -@case_no_python3 -@case_no_pypy -@disable_runtime_dependency_check -class TestTransportDrainEvents(Case): - - def setUp(self): - self.transport = Transport(Mock()) - self.transport.session = Mock() - self.mock_queue = Mock() - self.mock_message = Mock() - self.mock_conn = Mock() - self.mock_callback = Mock() - self.mock_conn._callbacks = {self.mock_queue: self.mock_callback} - - def mock_next_receiver(self, timeout): - time.sleep(0.3) - mock_receiver = Mock() - mock_receiver.source = self.mock_queue - mock_receiver.fetch.return_value = self.mock_message - return mock_receiver - - def test_socket_timeout_raised_when_all_receivers_empty(self): - with patch(QPID_MODULE + '.QpidEmpty', new=QpidException): - self.transport.session.next_receiver.side_effect = QpidException() - with self.assertRaises(socket.timeout): - self.transport.drain_events(Mock()) - - def test_socket_timeout_raised_when_by_timeout(self): - self.transport.session.next_receiver = self.mock_next_receiver - with self.assertRaises(socket.timeout): - self.transport.drain_events(self.mock_conn, timeout=1) - - def test_timeout_returns_no_earlier_then_asked_for(self): - self.transport.session.next_receiver = self.mock_next_receiver - start_time = monotonic() - try: - self.transport.drain_events(self.mock_conn, timeout=1) - except socket.timeout: - pass - elapsed_time_in_s = monotonic() - start_time - self.assertGreaterEqual(elapsed_time_in_s, 1.0) - - def test_callback_is_called(self): - self.transport.session.next_receiver = self.mock_next_receiver - try: - self.transport.drain_events(self.mock_conn, timeout=1) - except socket.timeout: - pass - self.mock_callback.assert_called_with(self.mock_message) - - -@case_no_python3 -@case_no_pypy -@disable_runtime_dependency_check -class TestTransportCreateChannel(Case): - - def setUp(self): - self.transport = Transport(Mock()) - self.mock_conn = Mock() - self.mock_new_channel = Mock() - self.mock_conn.Channel.return_value = self.mock_new_channel - self.returned_channel = self.transport.create_channel(self.mock_conn) - - def test_new_channel_created_from_connection(self): - self.assertIs(self.mock_new_channel, self.returned_channel) - self.mock_conn.Channel.assert_called_with( - self.mock_conn, self.transport, - ) - - def test_new_channel_added_to_connection_channel_list(self): - append_method = self.mock_conn.channels.append - append_method.assert_called_with(self.mock_new_channel) - - -@case_no_python3 -@case_no_pypy -@disable_runtime_dependency_check -class TestTransportEstablishConnection(Case): - - def setUp(self): - - class MockClient(object): - pass - - self.client = MockClient() - self.client.connect_timeout = 4 - self.client.ssl = False - self.client.transport_options = {} - self.client.userid = None - self.client.password = None - self.client.login_method = None - self.transport = Transport(self.client) - self.mock_conn = Mock() - self.transport.Connection = self.mock_conn - - def test_transport_establish_conn_new_option_overwrites_default(self): - self.client.userid = 'new-userid' - self.client.password = 'new-password' - self.transport.establish_connection() - self.mock_conn.assert_called_once_with( - username=self.client.userid, - password=self.client.password, - sasl_mechanisms='PLAIN', - host='localhost', - timeout=4, - port=5672, - transport='tcp', - ) - - def test_transport_establish_conn_empty_client_is_default(self): - self.transport.establish_connection() - self.mock_conn.assert_called_once_with( - sasl_mechanisms='ANONYMOUS', - host='localhost', - timeout=4, - port=5672, - transport='tcp', - ) - - def test_transport_establish_conn_additional_transport_option(self): - new_param_value = 'mynewparam' - self.client.transport_options['new_param'] = new_param_value - self.transport.establish_connection() - self.mock_conn.assert_called_once_with( - sasl_mechanisms='ANONYMOUS', - host='localhost', - timeout=4, - new_param=new_param_value, - port=5672, - transport='tcp', - ) - - def test_transport_establish_conn_transform_localhost_to_127_0_0_1(self): - self.client.hostname = 'localhost' - self.transport.establish_connection() - self.mock_conn.assert_called_once_with( - sasl_mechanisms='ANONYMOUS', - host='localhost', - timeout=4, - port=5672, - transport='tcp', - ) - - def test_transport_password_no_userid_raises_exception(self): - self.client.password = 'somepass' - self.assertRaises(Exception, self.transport.establish_connection) - - def test_transport_userid_no_password_raises_exception(self): - self.client.userid = 'someusername' - self.assertRaises(Exception, self.transport.establish_connection) - - def test_transport_overrides_sasl_mech_from_login_method(self): - self.client.login_method = 'EXTERNAL' - self.transport.establish_connection() - self.mock_conn.assert_called_once_with( - sasl_mechanisms='EXTERNAL', - host='localhost', - timeout=4, - port=5672, - transport='tcp', - ) - - def test_transport_overrides_sasl_mech_has_username(self): - self.client.userid = 'new-userid' - self.client.login_method = 'EXTERNAL' - self.transport.establish_connection() - self.mock_conn.assert_called_once_with( - username=self.client.userid, - sasl_mechanisms='EXTERNAL', - host='localhost', - timeout=4, - port=5672, - transport='tcp', - ) - - def test_transport_establish_conn_set_password(self): - self.client.userid = 'someuser' - self.client.password = 'somepass' - self.transport.establish_connection() - self.mock_conn.assert_called_once_with( - username='someuser', - password='somepass', - sasl_mechanisms='PLAIN', - host='localhost', - timeout=4, - port=5672, - transport='tcp', - ) - - def test_transport_establish_conn_no_ssl_sets_transport_tcp(self): - self.client.ssl = False - self.transport.establish_connection() - self.mock_conn.assert_called_once_with( - sasl_mechanisms='ANONYMOUS', - host='localhost', - timeout=4, - port=5672, - transport='tcp', - ) - - def test_transport_establish_conn_with_ssl_with_hostname_check(self): - self.client.ssl = { - 'keyfile': 'my_keyfile', - 'certfile': 'my_certfile', - 'ca_certs': 'my_cacerts', - 'cert_reqs': ssl.CERT_REQUIRED, - } - self.transport.establish_connection() - self.mock_conn.assert_called_once_with( - ssl_certfile='my_certfile', - ssl_trustfile='my_cacerts', - timeout=4, - ssl_skip_hostname_check=False, - sasl_mechanisms='ANONYMOUS', - host='localhost', - ssl_keyfile='my_keyfile', - port=5672, transport='ssl', - ) - - def test_transport_establish_conn_with_ssl_skip_hostname_check(self): - self.client.ssl = { - 'keyfile': 'my_keyfile', - 'certfile': 'my_certfile', - 'ca_certs': 'my_cacerts', - 'cert_reqs': ssl.CERT_OPTIONAL, - } - self.transport.establish_connection() - self.mock_conn.assert_called_once_with( - ssl_certfile='my_certfile', - ssl_trustfile='my_cacerts', - timeout=4, - ssl_skip_hostname_check=True, - sasl_mechanisms='ANONYMOUS', - host='localhost', - ssl_keyfile='my_keyfile', - port=5672, transport='ssl', - ) - - def test_transport_establish_conn_sets_client_on_connection_object(self): - self.transport.establish_connection() - self.assertIs(self.mock_conn.return_value.client, self.client) - - def test_transport_establish_conn_creates_session_on_transport(self): - self.transport.establish_connection() - qpid_conn = self.mock_conn.return_value.get_qpid_connection - new_mock_session = qpid_conn.return_value.session.return_value - self.assertIs(self.transport.session, new_mock_session) - - def test_transport_establish_conn_returns_new_connection_object(self): - new_conn = self.transport.establish_connection() - self.assertIs(new_conn, self.mock_conn.return_value) - - def test_transport_establish_conn_uses_hostname_if_not_default(self): - self.client.hostname = 'some_other_hostname' - self.transport.establish_connection() - self.mock_conn.assert_called_once_with( - sasl_mechanisms='ANONYMOUS', - host='some_other_hostname', - timeout=4, - port=5672, - transport='tcp', - ) - - def test_transport_sets_qpid_message_ready_handler(self): - self.transport.establish_connection() - qpid_conn_call = self.mock_conn.return_value.get_qpid_connection - mock_session = qpid_conn_call.return_value.session.return_value - mock_set_callback = mock_session.set_message_received_notify_handler - expected_msg_callback = self.transport._qpid_message_ready_handler - mock_set_callback.assert_called_once_with(expected_msg_callback) - - def test_transport_sets_session_exception_handler(self): - self.transport.establish_connection() - qpid_conn_call = self.mock_conn.return_value.get_qpid_connection - mock_session = qpid_conn_call.return_value.session.return_value - mock_set_callback = mock_session.set_async_exception_notify_handler - exc_callback = self.transport._qpid_async_exception_notify_handler - mock_set_callback.assert_called_once_with(exc_callback) - - def test_transport_sets_connection_exception_handler(self): - self.transport.establish_connection() - qpid_conn_call = self.mock_conn.return_value.get_qpid_connection - qpid_conn = qpid_conn_call.return_value - mock_set_callback = qpid_conn.set_async_exception_notify_handler - exc_callback = self.transport._qpid_async_exception_notify_handler - mock_set_callback.assert_called_once_with(exc_callback) - - -@case_no_python3 -@case_no_pypy -class TestTransportClassAttributes(Case): - - def test_verify_Connection_attribute(self): - self.assertIs(Connection, Transport.Connection) - - def test_verify_polling_disabled(self): - self.assertIsNone(Transport.polling_interval) - - def test_transport_verify_supports_asynchronous_events(self): - self.assertTrue(Transport.supports_ev) - - def test_verify_driver_type_and_name(self): - self.assertEqual('qpid', Transport.driver_type) - self.assertEqual('qpid', Transport.driver_name) - - def test_transport_verify_recoverable_connection_errors(self): - connection_errors = Transport.recoverable_connection_errors - self.assertIn(ConnectionError, connection_errors) - self.assertIn(select.error, connection_errors) - - def test_transport_verify_recoverable_channel_errors(self): - channel_errors = Transport.recoverable_channel_errors - self.assertIn(NotFound, channel_errors) - - def test_transport_verify_pre_kombu_3_0_exception_labels(self): - self.assertEqual(Transport.recoverable_channel_errors, - Transport.channel_errors) - self.assertEqual(Transport.recoverable_connection_errors, - Transport.connection_errors) - - -@case_no_python3 -@case_no_pypy -@disable_runtime_dependency_check -class TestTransportRegisterWithEventLoop(Case): - - def test_transport_register_with_event_loop_calls_add_reader(self): - transport = Transport(Mock()) - mock_connection = Mock() - mock_loop = Mock() - transport.register_with_event_loop(mock_connection, mock_loop) - mock_loop.add_reader.assert_called_with( - transport.r, transport.on_readable, mock_connection, mock_loop, - ) - - -@case_no_python3 -@case_no_pypy -@disable_runtime_dependency_check -class TestTransportQpidCallbackHandlersAsync(Case): - - def setUp(self): - self.patch_a = patch(QPID_MODULE + '.os.write') - self.mock_os_write = self.patch_a.start() - self.transport = Transport(Mock()) - self.transport.register_with_event_loop(Mock(), Mock()) - - def tearDown(self): - self.patch_a.stop() - - def test__qpid_message_ready_handler_writes_symbol_to_fd(self): - self.transport._qpid_message_ready_handler(Mock()) - self.mock_os_write.assert_called_once_with(self.transport._w, '0') - - def test__qpid_async_exception_notify_handler_writes_symbol_to_fd(self): - self.transport._qpid_async_exception_notify_handler(Mock(), Mock()) - self.mock_os_write.assert_called_once_with(self.transport._w, 'e') - - -@case_no_python3 -@case_no_pypy -@disable_runtime_dependency_check -class TestTransportQpidCallbackHandlersSync(Case): - - def setUp(self): - self.patch_a = patch(QPID_MODULE + '.os.write') - self.mock_os_write = self.patch_a.start() - self.transport = Transport(Mock()) - - def tearDown(self): - self.patch_a.stop() - - def test__qpid_message_ready_handler_dows_not_write(self): - self.transport._qpid_message_ready_handler(Mock()) - self.assertTrue(not self.mock_os_write.called) - - def test__qpid_async_exception_notify_handler_does_not_write(self): - self.transport._qpid_async_exception_notify_handler(Mock(), Mock()) - self.assertTrue(not self.mock_os_write.called) - - -@case_no_python3 -@case_no_pypy -@disable_runtime_dependency_check -class TestTransportOnReadable(Case): - - def setUp(self): - self.patch_a = patch(QPID_MODULE + '.os.read') - self.mock_os_read = self.patch_a.start() - - self.patch_b = patch.object(Transport, 'drain_events') - self.mock_drain_events = self.patch_b.start() - self.transport = Transport(Mock()) - self.transport.register_with_event_loop(Mock(), Mock()) - - def tearDown(self): - self.patch_a.stop() - self.patch_b.stop() - - def test_transport_on_readable_reads_symbol_from_fd(self): - self.transport.on_readable(Mock(), Mock()) - self.mock_os_read.assert_called_once_with(self.transport.r, 1) - - def test_transport_on_readable_calls_drain_events(self): - mock_connection = Mock() - self.transport.on_readable(mock_connection, Mock()) - self.mock_drain_events.assert_called_with(mock_connection) - - def test_transport_on_readable_catches_socket_timeout(self): - self.mock_drain_events.side_effect = socket.timeout() - self.transport.on_readable(Mock(), Mock()) - - def test_transport_on_readable_ignores_non_socket_timeout_exception(self): - self.mock_drain_events.side_effect = IOError() - with self.assertRaises(IOError): - self.transport.on_readable(Mock(), Mock()) - - -@case_no_python3 -@case_no_pypy -@disable_runtime_dependency_check -class TestTransportVerifyRuntimeEnvironment(Case): - - def setUp(self): - self.verify_runtime_environment = Transport.verify_runtime_environment - self.patch_a = patch.object(Transport, 'verify_runtime_environment') - self.patch_a.start() - self.transport = Transport(Mock()) - - def tearDown(self): - self.patch_a.stop() - - @patch(QPID_MODULE + '.PY3', new=True) - def test_raises_exception_for_Python3(self): - with self.assertRaises(RuntimeError): - self.verify_runtime_environment(self.transport) - - @patch('__builtin__.getattr') - def test_raises_exc_for_PyPy(self, mock_getattr): - mock_getattr.return_value = True - with self.assertRaises(RuntimeError): - self.verify_runtime_environment(self.transport) - - @patch(QPID_MODULE + '.dependency_is_none') - def test_raises_exc_dep_missing(self, mock_dep_is_none): - mock_dep_is_none.return_value = True - with self.assertRaises(RuntimeError): - self.verify_runtime_environment(self.transport) - - @patch(QPID_MODULE + '.dependency_is_none') - def test_calls_dependency_is_none(self, mock_dep_is_none): - mock_dep_is_none.return_value = False - self.verify_runtime_environment(self.transport) - self.assertTrue(mock_dep_is_none.called) - - def test_raises_no_exception(self): - self.verify_runtime_environment(self.transport) - - -@case_no_python3 -@case_no_pypy -@disable_runtime_dependency_check -class TestTransport(ExtraAssertionsMixin, Case): - - def setUp(self): - """Creates a mock_client to be used in testing.""" - self.mock_client = Mock() - - def test_close_connection(self): - """Test that close_connection calls close on the connection.""" - my_transport = Transport(self.mock_client) - mock_connection = Mock() - my_transport.close_connection(mock_connection) - mock_connection.close.assert_called_once_with() - - def test_default_connection_params(self): - """Test that the default_connection_params are correct""" - correct_params = { - 'hostname': 'localhost', - 'port': 5672, - } - my_transport = Transport(self.mock_client) - result_params = my_transport.default_connection_params - self.assertDictEqual(correct_params, result_params) - - @patch(QPID_MODULE + '.os.close') - def test_del_sync(self, close): - my_transport = Transport(self.mock_client) - my_transport.__del__() - self.assertFalse(close.called) - - @patch(QPID_MODULE + '.os.close') - def test_del_async(self, close): - my_transport = Transport(self.mock_client) - my_transport.register_with_event_loop(Mock(), Mock()) - my_transport.__del__() - self.assertTrue(close.called) - - @patch(QPID_MODULE + '.os.close') - def test_del_async_failed(self, close): - close.side_effect = OSError() - my_transport = Transport(self.mock_client) - my_transport.register_with_event_loop(Mock(), Mock()) - my_transport.__del__() - self.assertTrue(close.called) diff --git a/kombu/tests/transport/virtual/test_exchange.py b/kombu/tests/transport/virtual/test_exchange.py deleted file mode 100644 index e5f5df39..00000000 --- a/kombu/tests/transport/virtual/test_exchange.py +++ /dev/null @@ -1,200 +0,0 @@ -from __future__ import absolute_import, unicode_literals - -from kombu import Connection -from kombu.transport.virtual import exchange - -from kombu.tests.case import Case, Mock -from kombu.tests.mocks import Transport - - -class ExchangeCase(Case): - type = None - - def setup(self): - if self.type: - self.e = self.type(Connection(transport=Transport).channel()) - - -class test_Direct(ExchangeCase): - type = exchange.DirectExchange - table = [('rFoo', None, 'qFoo'), - ('rFoo', None, 'qFox'), - ('rBar', None, 'qBar'), - ('rBaz', None, 'qBaz')] - - def test_lookup(self): - self.assertSetEqual( - self.e.lookup(self.table, 'eFoo', 'rFoo', None), - {'qFoo', 'qFox'}, - ) - self.assertSetEqual( - self.e.lookup(self.table, 'eMoz', 'rMoz', 'DEFAULT'), - set(), - ) - self.assertSetEqual( - self.e.lookup(self.table, 'eBar', 'rBar', None), - {'qBar'}, - ) - - -class test_Fanout(ExchangeCase): - type = exchange.FanoutExchange - table = [(None, None, 'qFoo'), - (None, None, 'qFox'), - (None, None, 'qBar')] - - def test_lookup(self): - self.assertSetEqual( - self.e.lookup(self.table, 'eFoo', 'rFoo', None), - {'qFoo', 'qFox', 'qBar'}, - ) - - def test_deliver_when_fanout_supported(self): - self.e.channel = Mock() - self.e.channel.supports_fanout = True - message = Mock() - - self.e.deliver(message, 'exchange', 'rkey') - self.e.channel._put_fanout.assert_called_with( - 'exchange', message, 'rkey', - ) - - def test_deliver_when_fanout_unsupported(self): - self.e.channel = Mock() - self.e.channel.supports_fanout = False - - self.e.deliver(Mock(), 'exchange', None) - self.e.channel._put_fanout.assert_not_called() - - -class test_Topic(ExchangeCase): - type = exchange.TopicExchange - table = [ - ('stock.#', None, 'rFoo'), - ('stock.us.*', None, 'rBar'), - ] - - def setup(self): - super(test_Topic, self).setup() - self.table = [(rkey, self.e.key_to_pattern(rkey), queue) - for rkey, _, queue in self.table] - - def test_prepare_bind(self): - x = self.e.prepare_bind('qFoo', 'eFoo', 'stock.#', {}) - self.assertTupleEqual(x, ('stock.#', r'^stock\..*?$', 'qFoo')) - - def test_lookup(self): - self.assertSetEqual( - self.e.lookup(self.table, 'eFoo', 'stock.us.nasdaq', None), - {'rFoo', 'rBar'}, - ) - self.assertTrue(self.e._compiled) - self.assertSetEqual( - self.e.lookup(self.table, 'eFoo', 'stock.europe.OSE', None), - {'rFoo'}, - ) - self.assertSetEqual( - self.e.lookup(self.table, 'eFoo', 'stockxeuropexOSE', None), - set(), - ) - self.assertSetEqual( - self.e.lookup(self.table, 'eFoo', - 'candy.schleckpulver.snap_crackle', None), - set(), - ) - - def test_deliver(self): - self.e.channel = Mock() - self.e.channel._lookup.return_value = ('a', 'b') - message = Mock() - self.e.deliver(message, 'exchange', 'rkey') - - expected = [(('a', message), {}), - (('b', message), {})] - self.assertListEqual(self.e.channel._put.call_args_list, expected) - - -class test_TopicMultibind(ExchangeCase): - # Testing message delivery in case of multiple overlapping - # bindings for the same queue. As AMQP states, in case of - # overlapping bindings, a message must be delivered once to - # each matching queue. - type = exchange.TopicExchange - table = [ - ('stock', None, 'rFoo'), - ('stock.#', None, 'rFoo'), - ('stock.us.*', None, 'rFoo'), - ('#', None, 'rFoo'), - ] - - def setup(self): - super(test_TopicMultibind, self).setup() - self.table = [(rkey, self.e.key_to_pattern(rkey), queue) - for rkey, _, queue in self.table] - - def test_lookup(self): - self.assertSetEqual( - self.e.lookup(self.table, 'eFoo', 'stock.us.nasdaq', None), - {'rFoo'}, - ) - self.assertTrue(self.e._compiled) - self.assertSetEqual( - self.e.lookup(self.table, 'eFoo', 'stock.europe.OSE', None), - {'rFoo'}, - ) - self.assertSetEqual( - self.e.lookup(self.table, 'eFoo', 'stockxeuropexOSE', None), - {'rFoo'}, - ) - self.assertSetEqual( - self.e.lookup(self.table, 'eFoo', - 'candy.schleckpulver.snap_crackle', None), - {'rFoo'}, - ) - - -class test_ExchangeType(ExchangeCase): - type = exchange.ExchangeType - - def test_lookup(self): - with self.assertRaises(NotImplementedError): - self.e.lookup([], 'eFoo', 'rFoo', None) - - def test_prepare_bind(self): - self.assertTupleEqual( - self.e.prepare_bind('qFoo', 'eFoo', 'rFoo', {}), - ('rFoo', None, 'qFoo'), - ) - - def test_equivalent(self): - e1 = dict( - type='direct', - durable=True, - auto_delete=True, - arguments={}, - ) - self.assertTrue( - self.e.equivalent(e1, 'eFoo', 'direct', True, True, {}), - ) - self.assertFalse( - self.e.equivalent(e1, 'eFoo', 'topic', True, True, {}), - ) - self.assertFalse( - self.e.equivalent(e1, 'eFoo', 'direct', False, True, {}), - ) - self.assertFalse( - self.e.equivalent(e1, 'eFoo', 'direct', True, False, {}), - ) - self.assertFalse( - self.e.equivalent(e1, 'eFoo', 'direct', True, True, - {'expires': 3000}), - ) - e2 = dict(e1, arguments={'expires': 3000}) - self.assertTrue( - self.e.equivalent(e2, 'eFoo', 'direct', True, True, - {'expires': 3000}), - ) - self.assertFalse( - self.e.equivalent(e2, 'eFoo', 'direct', True, True, - {'expires': 6000}), - ) diff --git a/kombu/tests/utils/test_div.py b/kombu/tests/utils/test_div.py deleted file mode 100644 index 513cb1b7..00000000 --- a/kombu/tests/utils/test_div.py +++ /dev/null @@ -1,51 +0,0 @@ -from __future__ import absolute_import, unicode_literals - -import pickle - -from io import StringIO, BytesIO - -from kombu.utils.div import emergency_dump_state - -from kombu.tests.case import Case, mock - - -class MyStringIO(StringIO): - - def close(self): - pass - - -class MyBytesIO(BytesIO): - - def close(self): - pass - - -class test_emergency_dump_state(Case): - - @mock.stdouts - def test_dump(self, stdout, stderr): - fh = MyBytesIO() - - emergency_dump_state( - {'foo': 'bar'}, open_file=lambda n, m: fh) - self.assertDictEqual( - pickle.loads(fh.getvalue()), {'foo': 'bar'}) - self.assertTrue(stderr.getvalue()) - self.assertFalse(stdout.getvalue()) - - @mock.stdouts - def test_dump_second_strategy(self, stdout, stderr): - fh = MyStringIO() - - def raise_something(*args, **kwargs): - raise KeyError('foo') - - emergency_dump_state( - {'foo': 'bar'}, - open_file=lambda n, m: fh, dump=raise_something - ) - self.assertIn('foo', fh.getvalue()) - self.assertIn('bar', fh.getvalue()) - self.assertTrue(stderr.getvalue()) - self.assertFalse(stdout.getvalue()) diff --git a/kombu/tests/utils/test_encoding.py b/kombu/tests/utils/test_encoding.py deleted file mode 100644 index 67bf4ad4..00000000 --- a/kombu/tests/utils/test_encoding.py +++ /dev/null @@ -1,109 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import absolute_import, unicode_literals - -import sys - -from contextlib import contextmanager - -from kombu.five import bytes_t, string_t -from kombu.utils.encoding import ( - get_default_encoding_file, safe_str, - set_default_encoding_file, default_encoding, -) - -from kombu.tests.case import Case, patch, skip - - -@contextmanager -def clean_encoding(): - old_encoding = sys.modules.pop('kombu.utils.encoding', None) - import kombu.utils.encoding - try: - yield kombu.utils.encoding - finally: - if old_encoding: - sys.modules['kombu.utils.encoding'] = old_encoding - - -class test_default_encoding(Case): - - def test_set_default_file(self): - prev = get_default_encoding_file() - try: - set_default_encoding_file('/foo.txt') - self.assertEqual(get_default_encoding_file(), '/foo.txt') - finally: - set_default_encoding_file(prev) - - @patch('sys.getfilesystemencoding') - def test_default(self, getdefaultencoding): - getdefaultencoding.return_value = 'ascii' - with clean_encoding() as encoding: - enc = encoding.default_encoding() - if sys.platform.startswith('java'): - self.assertEqual(enc, 'utf-8') - else: - self.assertEqual(enc, 'ascii') - getdefaultencoding.assert_called_with() - - -@skip.if_python3 -class test_encoding_utils(Case): - - def test_str_to_bytes(self): - with clean_encoding() as e: - self.assertIsInstance(e.str_to_bytes('foobar'), bytes_t) - - def test_from_utf8(self): - with clean_encoding() as e: - self.assertIsInstance(e.from_utf8('foobar'), bytes_t) - - def test_default_encode(self): - with clean_encoding() as e: - self.assertTrue(e.default_encode(b'foo')) - - -class test_safe_str(Case): - - def setup(self): - self._cencoding = patch('sys.getfilesystemencoding') - self._encoding = self._cencoding.__enter__() - self._encoding.return_value = 'ascii' - - def teardown(self): - self._cencoding.__exit__() - - def test_when_bytes(self): - self.assertEqual(safe_str('foo'), 'foo') - - def test_when_unicode(self): - self.assertIsInstance(safe_str('foo'), string_t) - - def test_when_encoding_utf8(self): - with patch('sys.getfilesystemencoding') as encoding: - encoding.return_value = 'utf-8' - self.assertEqual(default_encoding(), 'utf-8') - s = 'The quiæk fåx jømps øver the lazy dåg' - res = safe_str(s) - self.assertIsInstance(res, str) - - def test_when_containing_high_chars(self): - with patch('sys.getfilesystemencoding') as encoding: - encoding.return_value = 'ascii' - s = 'The quiæk fåx jømps øver the lazy dåg' - res = safe_str(s) - self.assertIsInstance(res, str) - self.assertEqual(len(s), len(res)) - - def test_when_not_string(self): - o = object() - self.assertEqual(safe_str(o), repr(o)) - - def test_when_unrepresentable(self): - - class O(object): - - def __repr__(self): - raise KeyError('foo') - - self.assertIn('<Unrepresentable', safe_str(O())) diff --git a/kombu/tests/utils/test_scheduling.py b/kombu/tests/utils/test_scheduling.py deleted file mode 100644 index ed668714..00000000 --- a/kombu/tests/utils/test_scheduling.py +++ /dev/null @@ -1,112 +0,0 @@ -from __future__ import absolute_import, unicode_literals - -from kombu.utils.scheduling import FairCycle, cycle_by_name - -from kombu.tests.case import Case - - -class MyEmpty(Exception): - pass - - -def consume(fun, n): - r = [] - for i in range(n): - r.append(fun()) - return r - - -class test_FairCycle(Case): - - def test_cycle(self): - resources = ['a', 'b', 'c', 'd', 'e'] - - def echo(r, timeout=None): - return r - - # cycle should be ['a', 'b', 'c', 'd', 'e', ... repeat] - cycle = FairCycle(echo, resources, MyEmpty) - for i in range(len(resources)): - self.assertEqual(cycle.get(), (resources[i], - resources[i])) - for i in range(len(resources)): - self.assertEqual(cycle.get(), (resources[i], - resources[i])) - - def test_cycle_breaks(self): - resources = ['a', 'b', 'c', 'd', 'e'] - - def echo(r): - if r == 'c': - raise MyEmpty(r) - return r - - cycle = FairCycle(echo, resources, MyEmpty) - self.assertEqual( - consume(cycle.get, len(resources)), - [('a', 'a'), ('b', 'b'), ('d', 'd'), - ('e', 'e'), ('a', 'a')], - ) - self.assertEqual( - consume(cycle.get, len(resources)), - [('b', 'b'), ('d', 'd'), ('e', 'e'), - ('a', 'a'), ('b', 'b')], - ) - cycle2 = FairCycle(echo, ['c', 'c'], MyEmpty) - with self.assertRaises(MyEmpty): - consume(cycle2.get, 3) - - def test_cycle_no_resources(self): - cycle = FairCycle(None, [], MyEmpty) - cycle.pos = 10 - - with self.assertRaises(MyEmpty): - cycle._next() - - def test__repr__(self): - self.assertTrue(repr(FairCycle(lambda x: x, [1, 2, 3], MyEmpty))) - - -class test_round_robin_cycle(Case): - - def test_round_robin_cycle(self): - it = cycle_by_name('round_robin')(['A', 'B', 'C']) - self.assertListEqual(it.consume(3), ['A', 'B', 'C']) - it.rotate('B') - self.assertListEqual(it.consume(3), ['A', 'C', 'B']) - it.rotate('A') - self.assertListEqual(it.consume(3), ['C', 'B', 'A']) - it.rotate('A') - self.assertListEqual(it.consume(3), ['C', 'B', 'A']) - it.rotate('C') - self.assertListEqual(it.consume(3), ['B', 'A', 'C']) - - -class test_priority_cycle(Case): - - def test_priority_cycle(self): - it = cycle_by_name('priority')(['A', 'B', 'C']) - self.assertListEqual(it.consume(3), ['A', 'B', 'C']) - it.rotate('B') - self.assertListEqual(it.consume(3), ['A', 'B', 'C']) - it.rotate('A') - self.assertListEqual(it.consume(3), ['A', 'B', 'C']) - it.rotate('A') - self.assertListEqual(it.consume(3), ['A', 'B', 'C']) - it.rotate('C') - self.assertListEqual(it.consume(3), ['A', 'B', 'C']) - - -class test_sorted_cycle(Case): - - def test_sorted_cycle(self): - it = cycle_by_name('sorted')(['B', 'C', 'A']) - self.assertListEqual(it.consume(3), ['A', 'B', 'C']) - it.rotate('B') - self.assertListEqual(it.consume(3), ['A', 'B', 'C']) - it.rotate('A') - self.assertListEqual(it.consume(3), ['A', 'B', 'C']) - it.rotate('A') - self.assertListEqual(it.consume(3), ['A', 'B', 'C']) - it.rotate('C') - self.assertListEqual(it.consume(3), ['A', 'B', 'C']) diff --git a/kombu/tests/utils/test_url.py b/kombu/tests/utils/test_url.py deleted file mode 100644 index 67c7efce..00000000 --- a/kombu/tests/utils/test_url.py +++ /dev/null @@ -1,50 +0,0 @@ -from __future__ import absolute_import, unicode_literals - -from kombu.utils.url import as_url, parse_url, maybe_sanitize_url - -from kombu.tests.case import Case - - -class test_parse_url(Case): - - def test_parse_url(self): - result = parse_url('amqp://user:pass@localhost:5672/my/vhost') - self.assertDictEqual(result, { - 'transport': 'amqp', - 'userid': 'user', - 'password': 'pass', - 'hostname': 'localhost', - 'port': 5672, - 'virtual_host': 'my/vhost', - }) - - -class test_as_url(Case): - - def test_as_url(self): - self.assertEqual(as_url('https'), 'https:///') - self.assertEqual(as_url('https', 'e.com'), 'https://e.com/') - self.assertEqual(as_url('https', 'e.com', 80), 'https://e.com:80/') - self.assertEqual( - as_url('https', 'e.com', 80, 'u'), 'https://u@e.com:80/', - ) - self.assertEqual( - as_url('https', 'e.com', 80, 'u', 'p'), 'https://u:p@e.com:80/', - ) - self.assertEqual( - as_url('https', 'e.com', 80, None, 'p'), 'https://:p@e.com:80/', - ) - self.assertEqual( - as_url('https', 'e.com', 80, None, 'p', '/foo'), - 'https://:p@e.com:80//foo', - ) - - -class test_maybe_sanitize_url(Case): - - def test_maybe_sanitize_url(self): - self.assertEqual(maybe_sanitize_url('foo'), 'foo') - self.assertEqual( - maybe_sanitize_url('http://u:p@e.com//foo'), - 'http://u:**@e.com//foo', - ) diff --git a/kombu/tests/utils/test_utils.py b/kombu/tests/utils/test_utils.py deleted file mode 100644 index 84adf6c9..00000000 --- a/kombu/tests/utils/test_utils.py +++ /dev/null @@ -1,42 +0,0 @@ -from __future__ import absolute_import, unicode_literals - -from kombu import version_info_t -from kombu.utils.text import version_string_as_tuple - -from kombu.tests.case import Case - - -class test_kombu_module(Case): - - def test_dir(self): - import kombu - self.assertTrue(dir(kombu)) - - -class test_version_string_as_tuple(Case): - - def test_versions(self): - self.assertTupleEqual( - version_string_as_tuple('3'), - version_info_t(3, 0, 0, '', ''), - ) - self.assertTupleEqual( - version_string_as_tuple('3.3'), - version_info_t(3, 3, 0, '', ''), - ) - self.assertTupleEqual( - version_string_as_tuple('3.3.1'), - version_info_t(3, 3, 1, '', ''), - ) - self.assertTupleEqual( - version_string_as_tuple('3.3.1a3'), - version_info_t(3, 3, 1, 'a3', ''), - ) - self.assertTupleEqual( - version_string_as_tuple('3.3.1a3-40c32'), - version_info_t(3, 3, 1, 'a3', '40c32'), - ) - self.assertEqual( - version_string_as_tuple('3.3.1.a3.40c32'), - version_info_t(3, 3, 1, 'a3', '40c32'), - ) diff --git a/kombu/transport/redis.py b/kombu/transport/redis.py index b6a818fe..d7546a60 100644 --- a/kombu/transport/redis.py +++ b/kombu/transport/redis.py @@ -1073,15 +1073,14 @@ class SentinelChannel(Channel): additional_params = connparams.copy() - del additional_params['host'] - del additional_params['port'] + additional_params.pop('host', None) + additional_params.pop('port', None) sentinel_inst = sentinel.Sentinel( [(connparams['host'], connparams['port'])], min_other_sentinels=getattr(self, 'min_other_sentinels', 0), sentinel_kwargs=getattr(self, 'sentinel_kwargs', {}), - **additional_params - ) + **additional_params) master_name = getattr(self, 'master_name', None) diff --git a/kombu/transport/virtual/__init__.py b/kombu/transport/virtual/__init__.py index ccfab350..6960104c 100644 --- a/kombu/transport/virtual/__init__.py +++ b/kombu/transport/virtual/__init__.py @@ -1,988 +1,13 @@ -"""Virtual transport implementation. - -Emulates the AMQ API for non-AMQ transports. -""" -from __future__ import absolute_import, print_function, unicode_literals - -import base64 -import socket -import sys -import warnings - -from array import array -from collections import OrderedDict, defaultdict, namedtuple -from itertools import count -from multiprocessing.util import Finalize -from time import sleep - -from amqp.protocol import queue_declare_ok_t - -from kombu.exceptions import ResourceError, ChannelError -from kombu.five import Empty, items, monotonic -from kombu.log import get_logger -from kombu.utils.encoding import str_to_bytes, bytes_to_str -from kombu.utils.div import emergency_dump_state -from kombu.utils.scheduling import FairCycle -from kombu.utils.uuid import uuid - -from kombu.transport import base - -from .exchange import STANDARD_EXCHANGE_TYPES - -ARRAY_TYPE_H = 'H' if sys.version_info[0] == 3 else b'H' - -UNDELIVERABLE_FMT = """\ -Message could not be delivered: No queues bound to exchange {exchange!r} \ -using binding key {routing_key!r}. -""" - -NOT_EQUIVALENT_FMT = """\ -Cannot redeclare exchange {0!r} in vhost {1!r} with \ -different type, durable, autodelete or arguments value.\ -""" - -W_NO_CONSUMERS = """\ -Requeuing undeliverable message for queue %r: No consumers.\ -""" - -RESTORING_FMT = 'Restoring {0!r} unacknowledged message(s)' -RESTORE_PANIC_FMT = 'UNABLE TO RESTORE {0} MESSAGES: {1}' - -logger = get_logger(__name__) - -#: Key format used for queue argument lookups in BrokerState.bindings. -binding_key_t = namedtuple('binding_key_t', ( - 'queue', 'exchange', 'routing_key', -)) - -#: BrokerState.queue_bindings generates tuples in this format. -queue_binding_t = namedtuple('queue_binding_t', ( - 'exchange', 'routing_key', 'arguments', -)) - - -class Base64(object): - - def encode(self, s): - return bytes_to_str(base64.b64encode(str_to_bytes(s))) - - def decode(self, s): - return base64.b64decode(str_to_bytes(s)) - - -class NotEquivalentError(Exception): - """Entity declaration is not equivalent to the previous declaration.""" - pass - - -class UndeliverableWarning(UserWarning): - """The message could not be delivered to a queue.""" - pass - - -class BrokerState(object): - - #: Mapping of exchange name to - #: :class:`kombu.transport.virtual.exchange.ExchangeType` - exchanges = None - - #: This is the actual bindings registry, used to store bindings and to - #: test 'in' relationships in constant time. It has the following - #: structure:: - #: - #: { - #: (queue, exchange, routing_key): arguments, - #: # ..., - #: } - bindings = None - - #: The queue index is used to access directly (constant time) - #: all the bindings of a certain queue. It has the following structure:: - #: { - #: queue: { - #: (queue, exchange, routing_key), - #: # ..., - #: }, - #: # ..., - #: } - queue_index = None - - def __init__(self, exchanges=None): - self.exchanges = {} if exchanges is None else exchanges - self.bindings = {} - self.queue_index = defaultdict(set) - - def clear(self): - self.exchanges.clear() - self.bindings.clear() - self.queue_index.clear() - - def has_binding(self, queue, exchange, routing_key): - return (queue, exchange, routing_key) in self.bindings - - def binding_declare(self, queue, exchange, routing_key, arguments): - key = binding_key_t(queue, exchange, routing_key) - self.bindings.setdefault(key, arguments) - self.queue_index[queue].add(key) - - def binding_delete(self, queue, exchange, routing_key): - key = binding_key_t(queue, exchange, routing_key) - try: - del self.bindings[key] - except KeyError: - pass - else: - self.queue_index[queue].remove(key) - - def queue_bindings_delete(self, queue): - try: - bindings = self.queue_index.pop(queue) - except KeyError: - pass - else: - [self.bindings.pop(binding, None) for binding in bindings] - - def queue_bindings(self, queue): - return ( - queue_binding_t(key.exchange, key.routing_key, self.bindings[key]) - for key in self.queue_index[queue] - ) - - -class QoS(object): - """Quality of Service guarantees. - - Only supports `prefetch_count` at this point. - - Arguments: - channel (ChannelT): Connection channel. - prefetch_count (int): Initial prefetch count (defaults to 0). - """ - - #: current prefetch count value - prefetch_count = 0 - - #: :class:`~collections.OrderedDict` of active messages. - #: *NOTE*: Can only be modified by the consuming thread. - _delivered = None - - #: acks can be done by other threads than the consuming thread. - #: Instead of a mutex, which doesn't perform well here, we mark - #: the delivery tags as dirty, so subsequent calls to append() can remove - #: them. - _dirty = None - - #: If disabled, unacked messages won't be restored at shutdown. - restore_at_shutdown = True - - def __init__(self, channel, prefetch_count=0): - self.channel = channel - self.prefetch_count = prefetch_count or 0 - - self._delivered = OrderedDict() - self._delivered.restored = False - self._dirty = set() - self._quick_ack = self._dirty.add - self._quick_append = self._delivered.__setitem__ - self._on_collect = Finalize( - self, self.restore_unacked_once, exitpriority=1, - ) - - def can_consume(self): - """Return true if the channel can be consumed from. - - Used to ensure the client adhers to currently active - prefetch limits. - """ - pcount = self.prefetch_count - return not pcount or len(self._delivered) - len(self._dirty) < pcount - - def can_consume_max_estimate(self): - """Returns the maximum number of messages allowed to be returned. - - Returns an estimated number of messages that a consumer may be allowed - to consume at once from the broker. This is used for services where - bulk 'get message' calls are preferred to many individual 'get message' - calls - like SQS. - - Returns: - int: greater than zero. - """ - pcount = self.prefetch_count - if pcount: - return max(pcount - (len(self._delivered) - len(self._dirty)), 0) - - def append(self, message, delivery_tag): - """Append message to transactional state.""" - if self._dirty: - self._flush() - self._quick_append(delivery_tag, message) - - def get(self, delivery_tag): - return self._delivered[delivery_tag] - - def _flush(self): - """Flush dirty (acked/rejected) tags from.""" - dirty = self._dirty - delivered = self._delivered - while 1: - try: - dirty_tag = dirty.pop() - except KeyError: - break - delivered.pop(dirty_tag, None) - - def ack(self, delivery_tag): - """Acknowledge message and remove from transactional state.""" - self._quick_ack(delivery_tag) - - def reject(self, delivery_tag, requeue=False): - """Remove from transactional state and requeue message.""" - if requeue: - self.channel._restore_at_beginning(self._delivered[delivery_tag]) - self._quick_ack(delivery_tag) - - def restore_unacked(self): - """Restore all unacknowledged messages.""" - self._flush() - delivered = self._delivered - errors = [] - restore = self.channel._restore - pop_message = delivered.popitem - - while delivered: - try: - _, message = pop_message() - except KeyError: # pragma: no cover - break - - try: - restore(message) - except BaseException as exc: - errors.append((exc, message)) - delivered.clear() - return errors - - def restore_unacked_once(self, stderr=None): - """Restores all unacknowledged messages at shutdown/gc collect. - - Note: - Can only be called once for each instance, subsequent - calls will be ignored. - """ - self._on_collect.cancel() - self._flush() - stderr = sys.stderr if stderr is None else stderr - state = self._delivered - - if not self.restore_at_shutdown or not self.channel.do_restore: - return - if getattr(state, 'restored', None): - assert not state - return - try: - if state: - print(RESTORING_FMT.format(len(self._delivered)), - file=stderr) - unrestored = self.restore_unacked() - - if unrestored: - errors, messages = list(zip(*unrestored)) - print(RESTORE_PANIC_FMT.format(len(errors), errors), - file=stderr) - emergency_dump_state(messages, stderr=stderr) - finally: - state.restored = True - - def restore_visible(self, *args, **kwargs): - """Restore any pending unackwnowledged messages for visibility_timeout - style implementations. - - Note: - This is implementation optional, and currently only - used by the Redis transport. - """ - pass - - -class Message(base.Message): - - def __init__(self, channel, payload, **kwargs): - self._raw = payload - properties = payload['properties'] - body = payload.get('body') - if body: - body = channel.decode_body(body, properties.get('body_encoding')) - kwargs.update({ - 'body': body, - 'delivery_tag': properties['delivery_tag'], - 'content_type': payload.get('content-type'), - 'content_encoding': payload.get('content-encoding'), - 'headers': payload.get('headers'), - 'properties': properties, - 'delivery_info': properties.get('delivery_info'), - 'postencode': 'utf-8', - }) - super(Message, self).__init__(channel, **kwargs) - - def serializable(self): - props = self.properties - body, _ = self.channel.encode_body(self.body, - props.get('body_encoding')) - headers = dict(self.headers) - # remove compression header - headers.pop('compression', None) - return { - 'body': body, - 'properties': props, - 'content-type': self.content_type, - 'content-encoding': self.content_encoding, - 'headers': headers, - } - - -class AbstractChannel(object): - """This is an abstract class defining the channel methods - you'd usually want to implement in a virtual channel. - - Note: - Do not subclass directly, but rather inherit - from :class:`Channel`. - """ - - def _get(self, queue, timeout=None): - """Get next message from `queue`.""" - raise NotImplementedError('Virtual channels must implement _get') - - def _put(self, queue, message): - """Put `message` onto `queue`.""" - raise NotImplementedError('Virtual channels must implement _put') - - def _purge(self, queue): - """Remove all messages from `queue`.""" - raise NotImplementedError('Virtual channels must implement _purge') - - def _size(self, queue): - """Return the number of messages in `queue` as an :class:`int`.""" - return 0 - - def _delete(self, queue, *args, **kwargs): - """Delete `queue`. - - Note: - This just purges the queue, if you need to do more you can - override this method. - """ - self._purge(queue) - - def _new_queue(self, queue, **kwargs): - """Create new queue. - - Note: - Your transport can override this method if it needs - to do something whenever a new queue is declared. - """ - pass - - def _has_queue(self, queue, **kwargs): - """Verify that queue exists. - - Returns: - bool: Should return :const:`True` if the queue exists - or :const:`False` otherwise. - """ - return True - - def _poll(self, cycle, timeout=None): - """Poll a list of queues for available messages.""" - return cycle.get() - - -class Channel(AbstractChannel, base.StdChannel): - """Virtual channel. - - Arguments: - connection (ConnectionT): The transport instance this - channel is part of. - """ - #: message class used. - Message = Message - - #: QoS class used. - QoS = QoS - - #: flag to restore unacked messages when channel - #: goes out of scope. - do_restore = True - - #: mapping of exchange types and corresponding classes. - exchange_types = dict(STANDARD_EXCHANGE_TYPES) - - #: flag set if the channel supports fanout exchanges. - supports_fanout = False - - #: Binary <-> ASCII codecs. - codecs = {'base64': Base64()} - - #: Default body encoding. - #: NOTE: ``transport_options['body_encoding']`` will override this value. - body_encoding = 'base64' - - #: counter used to generate delivery tags for this channel. - _delivery_tags = count(1) - - #: Optional queue where messages with no route is delivered. - #: Set by ``transport_options['deadletter_queue']``. - deadletter_queue = None - - # List of options to transfer from :attr:`transport_options`. - from_transport_options = ('body_encoding', 'deadletter_queue') - - # Priority defaults - default_priority = 0 - min_priority = 0 - max_priority = 9 - - def __init__(self, connection, **kwargs): - self.connection = connection - self._consumers = set() - self._cycle = None - self._tag_to_queue = {} - self._active_queues = [] - self._qos = None - self.closed = False - - # instantiate exchange types - self.exchange_types = dict( - (typ, cls(self)) for typ, cls in items(self.exchange_types) - ) - - try: - self.channel_id = self.connection._avail_channel_ids.pop() - except IndexError: - raise ResourceError( - 'No free channel ids, current={0}, channel_max={1}'.format( - len(self.connection.channels), - self.connection.channel_max), (20, 10), - ) - - topts = self.connection.client.transport_options - for opt_name in self.from_transport_options: - try: - setattr(self, opt_name, topts[opt_name]) - except KeyError: - pass - - def exchange_declare(self, exchange=None, type='direct', durable=False, - auto_delete=False, arguments=None, - nowait=False, passive=False): - """Declare exchange.""" - type = type or 'direct' - exchange = exchange or 'amq.%s' % type - if passive: - if exchange not in self.state.exchanges: - raise ChannelError( - 'NOT_FOUND - no exchange {0!r} in vhost {1!r}'.format( - exchange, self.connection.client.virtual_host or '/'), - (50, 10), 'Channel.exchange_declare', '404', - ) - return - try: - prev = self.state.exchanges[exchange] - if not self.typeof(exchange).equivalent(prev, exchange, type, - durable, auto_delete, - arguments): - raise NotEquivalentError(NOT_EQUIVALENT_FMT.format( - exchange, self.connection.client.virtual_host or '/')) - except KeyError: - self.state.exchanges[exchange] = { - 'type': type, - 'durable': durable, - 'auto_delete': auto_delete, - 'arguments': arguments or {}, - 'table': [], - } - - def exchange_delete(self, exchange, if_unused=False, nowait=False): - """Delete `exchange` and all its bindings.""" - for rkey, _, queue in self.get_table(exchange): - self.queue_delete(queue, if_unused=True, if_empty=True) - self.state.exchanges.pop(exchange, None) - - def queue_declare(self, queue=None, passive=False, **kwargs): - """Declare queue.""" - queue = queue or 'amq.gen-%s' % uuid() - if passive and not self._has_queue(queue, **kwargs): - raise ChannelError( - 'NOT_FOUND - no queue {0!r} in vhost {1!r}'.format( - queue, self.connection.client.virtual_host or '/'), - (50, 10), 'Channel.queue_declare', '404', - ) - else: - self._new_queue(queue, **kwargs) - return queue_declare_ok_t(queue, self._size(queue), 0) - - def queue_delete(self, queue, if_unused=False, if_empty=False, **kwargs): - """Delete queue.""" - if if_empty and self._size(queue): - return - for exchange, routing_key, args in self.state.queue_bindings(queue): - meta = self.typeof(exchange).prepare_bind( - queue, exchange, routing_key, args, - ) - self._delete(queue, exchange, *meta, **kwargs) - self.state.queue_bindings_delete(queue) - - def after_reply_message_received(self, queue): - self.queue_delete(queue) - - def exchange_bind(self, destination, source='', routing_key='', - nowait=False, arguments=None): - raise NotImplementedError('transport does not support exchange_bind') - - def exchange_unbind(self, destination, source='', routing_key='', - nowait=False, arguments=None): - raise NotImplementedError('transport does not support exchange_unbind') - - def queue_bind(self, queue, exchange=None, routing_key='', - arguments=None, **kwargs): - """Bind `queue` to `exchange` with `routing key`.""" - exchange = exchange or 'amq.direct' - if self.state.has_binding(queue, exchange, routing_key): - return - # Add binding: - self.state.binding_declare(queue, exchange, routing_key, arguments) - # Update exchange's routing table: - table = self.state.exchanges[exchange].setdefault('table', []) - meta = self.typeof(exchange).prepare_bind( - queue, exchange, routing_key, arguments, - ) - table.append(meta) - if self.supports_fanout: - self._queue_bind(exchange, *meta) - - def queue_unbind(self, queue, exchange=None, routing_key='', - arguments=None, **kwargs): - # Remove queue binding: - self.state.binding_delete(queue, exchange, routing_key) - try: - table = self.get_table(exchange) - except KeyError: - return - binding_meta = self.typeof(exchange).prepare_bind( - queue, exchange, routing_key, arguments, - ) - # TODO: the complexity of this operation is O(number of bindings). - # Should be optimized. Modifying table in place. - table[:] = [meta for meta in table if meta != binding_meta] - - def list_bindings(self): - return ((queue, exchange, rkey) - for exchange in self.state.exchanges - for rkey, pattern, queue in self.get_table(exchange)) - - def queue_purge(self, queue, **kwargs): - """Remove all ready messages from queue.""" - return self._purge(queue) - - def _next_delivery_tag(self): - return uuid() - - def basic_publish(self, message, exchange, routing_key, **kwargs): - """Publish message.""" - message['body'], body_encoding = self.encode_body( - message['body'], self.body_encoding, - ) - props = message['properties'] - props.update( - body_encoding=body_encoding, - delivery_tag=self._next_delivery_tag(), - ) - props['delivery_info'].update( - exchange=exchange, - routing_key=routing_key, - ) - if exchange: - return self.typeof(exchange).deliver( - message, exchange, routing_key, **kwargs - ) - # anon exchange: routing_key is the destination queue - return self._put(routing_key, message, **kwargs) - - def basic_consume(self, queue, no_ack, callback, consumer_tag, **kwargs): - """Consume from `queue`""" - self._tag_to_queue[consumer_tag] = queue - self._active_queues.append(queue) - - def _callback(raw_message): - message = self.Message(self, raw_message) - if not no_ack: - self.qos.append(message, message.delivery_tag) - return callback(message) - - self.connection._callbacks[queue] = _callback - self._consumers.add(consumer_tag) - - self._reset_cycle() - - def basic_cancel(self, consumer_tag): - """Cancel consumer by consumer tag.""" - if consumer_tag in self._consumers: - self._consumers.remove(consumer_tag) - self._reset_cycle() - queue = self._tag_to_queue.pop(consumer_tag, None) - try: - self._active_queues.remove(queue) - except ValueError: - pass - self.connection._callbacks.pop(queue, None) - - def basic_get(self, queue, no_ack=False, **kwargs): - """Get message by direct access (synchronous).""" - try: - message = self.Message(self, self._get(queue)) - if not no_ack: - self.qos.append(message, message.delivery_tag) - return message - except Empty: - pass - - def basic_ack(self, delivery_tag, multiple=False): - """Acknowledge message.""" - self.qos.ack(delivery_tag) - - def basic_recover(self, requeue=False): - """Recover unacked messages.""" - if requeue: - return self.qos.restore_unacked() - raise NotImplementedError('Does not support recover(requeue=False)') - - def basic_reject(self, delivery_tag, requeue=False): - """Reject message.""" - self.qos.reject(delivery_tag, requeue=requeue) - - def basic_qos(self, prefetch_size=0, prefetch_count=0, - apply_global=False): - """Change QoS settings for this channel. - - Note: - Only `prefetch_count` is supported. - """ - self.qos.prefetch_count = prefetch_count - - def get_exchanges(self): - return list(self.state.exchanges) - - def get_table(self, exchange): - """Get table of bindings for `exchange`.""" - return self.state.exchanges[exchange]['table'] - - def typeof(self, exchange, default='direct'): - """Get the exchange type instance for `exchange`.""" - try: - type = self.state.exchanges[exchange]['type'] - except KeyError: - type = default - return self.exchange_types[type] - - def _lookup(self, exchange, routing_key, default=None): - """Find all queues matching `routing_key` for the given `exchange`. - - Returns: - str: queue name -- must return the string `default` - if no queues matched. - """ - if default is None: - default = self.deadletter_queue - if not exchange: # anon exchange - return [routing_key or default] - - try: - R = self.typeof(exchange).lookup( - self.get_table(exchange), - exchange, routing_key, default, - ) - except KeyError: - R = [] - - if not R and default is not None: - warnings.warn(UndeliverableWarning(UNDELIVERABLE_FMT.format( - exchange=exchange, routing_key=routing_key)), - ) - self._new_queue(default) - R = [default] - return R - - def _restore(self, message): - """Redeliver message to its original destination.""" - delivery_info = message.delivery_info - message = message.serializable() - message['redelivered'] = True - for queue in self._lookup( - delivery_info['exchange'], delivery_info['routing_key']): - self._put(queue, message) - - def _restore_at_beginning(self, message): - return self._restore(message) - - def drain_events(self, timeout=None): - if self._consumers and self.qos.can_consume(): - if hasattr(self, '_get_many'): - return self._get_many(self._active_queues, timeout=timeout) - return self._poll(self.cycle, timeout=timeout) - raise Empty() - - def message_to_python(self, raw_message): - """Convert raw message to :class:`Message` instance.""" - if not isinstance(raw_message, self.Message): - return self.Message(self, payload=raw_message) - return raw_message - - def prepare_message(self, body, priority=None, content_type=None, - content_encoding=None, headers=None, properties=None): - """Prepare message data.""" - properties = properties or {} - properties.setdefault('delivery_info', {}) - properties.setdefault('priority', priority or self.default_priority) - - return {'body': body, - 'content-encoding': content_encoding, - 'content-type': content_type, - 'headers': headers or {}, - 'properties': properties or {}} - - def flow(self, active=True): - """Enable/disable message flow. - - Raises: - NotImplementedError: as flow - is not implemented by the base virtual implementation. - """ - raise NotImplementedError('virtual channels do not support flow.') - - def close(self): - """Close channel, cancel all consumers, and requeue unacked - messages.""" - if not self.closed: - self.closed = True - for consumer in list(self._consumers): - self.basic_cancel(consumer) - if self._qos: - self._qos.restore_unacked_once() - if self._cycle is not None: - self._cycle.close() - self._cycle = None - if self.connection is not None: - self.connection.close_channel(self) - self.exchange_types = None - - def encode_body(self, body, encoding=None): - if encoding: - return self.codecs.get(encoding).encode(body), encoding - return body, encoding - - def decode_body(self, body, encoding=None): - if encoding: - return self.codecs.get(encoding).decode(body) - return body - - def _reset_cycle(self): - self._cycle = FairCycle(self._get, self._active_queues, Empty) - - def __enter__(self): - return self - - def __exit__(self, *exc_info): - self.close() - - @property - def state(self): - """Broker state containing exchanges and bindings.""" - return self.connection.state - - @property - def qos(self): - """:class:`QoS` manager for this channel.""" - if self._qos is None: - self._qos = self.QoS(self) - return self._qos - - @property - def cycle(self): - if self._cycle is None: - self._reset_cycle() - return self._cycle - - def _get_message_priority(self, message, reverse=False): - """Get priority from message and limit the value within a - boundary of 0 to 9. - - Note: - Higher value has more priority. - """ - try: - priority = max( - min(int(message['properties']['priority']), - self.max_priority), - self.min_priority, - ) - except (TypeError, ValueError, KeyError): - priority = self.default_priority - - return (self.max_priority - priority) if reverse else priority - - -class Management(base.Management): - - def __init__(self, transport): - super(Management, self).__init__(transport) - self.channel = transport.client.channel() - - def get_bindings(self): - return [dict(destination=q, source=e, routing_key=r) - for q, e, r in self.channel.list_bindings()] - - def close(self): - self.channel.close() - - -class Transport(base.Transport): - """Virtual transport. - - Arguments: - client (kombu.Connection): The client this is a transport for. - """ - Channel = Channel - Cycle = FairCycle - Management = Management - - #: :class:`BrokerState` containing declared exchanges and - #: bindings (set by constructor). - state = BrokerState() - - #: :class:`~kombu.utils.scheduling.FairCycle` instance - #: used to fairly drain events from channels (set by constructor). - cycle = None - - #: port number used when no port is specified. - default_port = None - - #: active channels. - channels = None - - #: queue/callback map. - _callbacks = None - - #: Time to sleep between unsuccessful polls. - polling_interval = 1.0 - - #: Max number of channels - channel_max = 65535 - - implements = base.Transport.implements.extend( - async=False, - exchange_type=frozenset(['direct', 'topic']), - heartbeats=False, - ) - - def __init__(self, client, **kwargs): - self.client = client - self.channels = [] - self._avail_channels = [] - self._callbacks = {} - self.cycle = self.Cycle(self._drain_channel, self.channels, Empty) - polling_interval = client.transport_options.get('polling_interval') - if polling_interval is not None: - self.polling_interval = polling_interval - self._avail_channel_ids = array( - ARRAY_TYPE_H, range(self.channel_max, 0, -1), - ) - - def create_channel(self, connection): - try: - return self._avail_channels.pop() - except IndexError: - channel = self.Channel(connection) - self.channels.append(channel) - return channel - - def close_channel(self, channel): - try: - self._avail_channel_ids.append(channel.channel_id) - try: - self.channels.remove(channel) - except ValueError: - pass - finally: - channel.connection = None - - def establish_connection(self): - # creates channel to verify connection. - # this channel is then used as the next requested channel. - # (returned by ``create_channel``). - self._avail_channels.append(self.create_channel(self)) - return self # for drain events - - def close_connection(self, connection): - self.cycle.close() - for l in self._avail_channels, self.channels: - while l: - try: - channel = l.pop() - except LookupError: # pragma: no cover - pass - else: - channel.close() - - def drain_events(self, connection, timeout=None): - loop = 0 - time_start = monotonic() - get = self.cycle.get - polling_interval = self.polling_interval - while 1: - try: - item, channel = get(timeout=timeout) - except Empty: - if timeout and monotonic() - time_start >= timeout: - raise socket.timeout() - loop += 1 - if polling_interval is not None: - sleep(polling_interval) - else: - break - self._deliver(*item) - - def _deliver(self, message, queue): - if not queue: - raise KeyError( - 'Received message without destination queue: {0}'.format( - message)) - try: - callback = self._callbacks[queue] - except KeyError: - logger.warn(W_NO_CONSUMERS, queue) - self._reject_inbound_message(message) - else: - callback(message) - - def _reject_inbound_message(self, raw_message): - for channel in self.channels: - if channel: - message = channel.Message(channel, raw_message) - channel.qos.append(message, message.delivery_tag) - channel.basic_reject(message.delivery_tag, requeue=True) - break - - def on_message_ready(self, channel, message, queue): - if not queue or queue not in self._callbacks: - raise KeyError( - 'Message for queue {0!r} without consumers: {1}'.format( - queue, message)) - self._callbacks[queue](message) - - def _drain_channel(self, channel, timeout=None): - return channel.drain_events(timeout=timeout) - - @property - def default_connection_params(self): - return {'port': self.default_port, 'hostname': 'localhost'} +from __future__ import absolute_import, unicode_literals + +from .base import ( + Base64, NotEquivalentError, UndeliverableWarning, BrokerState, + QoS, Message, AbstractChannel, Channel, Management, Transport, + Empty, binding_key_t, queue_binding_t, +) + +__all__ = [ + 'Base64', 'NotEquivalentError', 'UndeliverableWarning', 'BrokerState', + 'QoS', 'Message', 'AbstractChannel', 'Channel', 'Management', 'Transport', + 'Empty', 'binding_key_t', 'queue_binding_t', +] diff --git a/kombu/transport/virtual/base.py b/kombu/transport/virtual/base.py new file mode 100644 index 00000000..3811ddb5 --- /dev/null +++ b/kombu/transport/virtual/base.py @@ -0,0 +1,989 @@ +"""Virtual transport implementation. + +Emulates the AMQ API for non-AMQ transports. +""" +from __future__ import absolute_import, print_function, unicode_literals + +import base64 +import socket +import sys +import warnings + +from array import array +from collections import OrderedDict, defaultdict, namedtuple +from itertools import count +from multiprocessing.util import Finalize +from time import sleep + +from amqp.protocol import queue_declare_ok_t + +from kombu.exceptions import ResourceError, ChannelError +from kombu.five import Empty, items, monotonic +from kombu.log import get_logger +from kombu.utils.encoding import str_to_bytes, bytes_to_str +from kombu.utils.div import emergency_dump_state +from kombu.utils.scheduling import FairCycle +from kombu.utils.uuid import uuid + +from kombu.transport import base + +from .exchange import STANDARD_EXCHANGE_TYPES + +ARRAY_TYPE_H = 'H' if sys.version_info[0] == 3 else b'H' + +UNDELIVERABLE_FMT = """\ +Message could not be delivered: No queues bound to exchange {exchange!r} \ +using binding key {routing_key!r}. +""" + +NOT_EQUIVALENT_FMT = """\ +Cannot redeclare exchange {0!r} in vhost {1!r} with \ +different type, durable, autodelete or arguments value.\ +""" + +W_NO_CONSUMERS = """\ +Requeuing undeliverable message for queue %r: No consumers.\ +""" + +RESTORING_FMT = 'Restoring {0!r} unacknowledged message(s)' +RESTORE_PANIC_FMT = 'UNABLE TO RESTORE {0} MESSAGES: {1}' + +logger = get_logger(__name__) + +#: Key format used for queue argument lookups in BrokerState.bindings. +binding_key_t = namedtuple('binding_key_t', ( + 'queue', 'exchange', 'routing_key', +)) + +#: BrokerState.queue_bindings generates tuples in this format. +queue_binding_t = namedtuple('queue_binding_t', ( + 'exchange', 'routing_key', 'arguments', +)) + + +class Base64(object): + + def encode(self, s): + return bytes_to_str(base64.b64encode(str_to_bytes(s))) + + def decode(self, s): + return base64.b64decode(str_to_bytes(s)) + + +class NotEquivalentError(Exception): + """Entity declaration is not equivalent to the previous declaration.""" + pass + + +class UndeliverableWarning(UserWarning): + """The message could not be delivered to a queue.""" + pass + + +class BrokerState(object): + + #: Mapping of exchange name to + #: :class:`kombu.transport.virtual.exchange.ExchangeType` + exchanges = None + + #: This is the actual bindings registry, used to store bindings and to + #: test 'in' relationships in constant time. It has the following + #: structure:: + #: + #: { + #: (queue, exchange, routing_key): arguments, + #: # ..., + #: } + bindings = None + + #: The queue index is used to access directly (constant time) + #: all the bindings of a certain queue. It has the following structure:: + #: { + #: queue: { + #: (queue, exchange, routing_key), + #: # ..., + #: }, + #: # ..., + #: } + queue_index = None + + def __init__(self, exchanges=None): + self.exchanges = {} if exchanges is None else exchanges + self.bindings = {} + self.queue_index = defaultdict(set) + + def clear(self): + self.exchanges.clear() + self.bindings.clear() + self.queue_index.clear() + + def has_binding(self, queue, exchange, routing_key): + return (queue, exchange, routing_key) in self.bindings + + def binding_declare(self, queue, exchange, routing_key, arguments): + key = binding_key_t(queue, exchange, routing_key) + self.bindings.setdefault(key, arguments) + self.queue_index[queue].add(key) + + def binding_delete(self, queue, exchange, routing_key): + key = binding_key_t(queue, exchange, routing_key) + try: + del self.bindings[key] + except KeyError: + pass + else: + self.queue_index[queue].remove(key) + + def queue_bindings_delete(self, queue): + try: + bindings = self.queue_index.pop(queue) + except KeyError: + pass + else: + [self.bindings.pop(binding, None) for binding in bindings] + + def queue_bindings(self, queue): + return ( + queue_binding_t(key.exchange, key.routing_key, self.bindings[key]) + for key in self.queue_index[queue] + ) + + +class QoS(object): + """Quality of Service guarantees. + + Only supports `prefetch_count` at this point. + + Arguments: + channel (ChannelT): Connection channel. + prefetch_count (int): Initial prefetch count (defaults to 0). + """ + + #: current prefetch count value + prefetch_count = 0 + + #: :class:`~collections.OrderedDict` of active messages. + #: *NOTE*: Can only be modified by the consuming thread. + _delivered = None + + #: acks can be done by other threads than the consuming thread. + #: Instead of a mutex, which doesn't perform well here, we mark + #: the delivery tags as dirty, so subsequent calls to append() can remove + #: them. + _dirty = None + + #: If disabled, unacked messages won't be restored at shutdown. + restore_at_shutdown = True + + def __init__(self, channel, prefetch_count=0): + self.channel = channel + self.prefetch_count = prefetch_count or 0 + + self._delivered = OrderedDict() + self._delivered.restored = False + self._dirty = set() + self._quick_ack = self._dirty.add + self._quick_append = self._delivered.__setitem__ + self._on_collect = Finalize( + self, self.restore_unacked_once, exitpriority=1, + ) + + def can_consume(self): + """Return true if the channel can be consumed from. + + Used to ensure the client adhers to currently active + prefetch limits. + """ + pcount = self.prefetch_count + return not pcount or len(self._delivered) - len(self._dirty) < pcount + + def can_consume_max_estimate(self): + """Returns the maximum number of messages allowed to be returned. + + Returns an estimated number of messages that a consumer may be allowed + to consume at once from the broker. This is used for services where + bulk 'get message' calls are preferred to many individual 'get message' + calls - like SQS. + + Returns: + int: greater than zero. + """ + pcount = self.prefetch_count + if pcount: + return max(pcount - (len(self._delivered) - len(self._dirty)), 0) + + def append(self, message, delivery_tag): + """Append message to transactional state.""" + if self._dirty: + self._flush() + self._quick_append(delivery_tag, message) + + def get(self, delivery_tag): + return self._delivered[delivery_tag] + + def _flush(self): + """Flush dirty (acked/rejected) tags from.""" + dirty = self._dirty + delivered = self._delivered + while 1: + try: + dirty_tag = dirty.pop() + except KeyError: + break + delivered.pop(dirty_tag, None) + + def ack(self, delivery_tag): + """Acknowledge message and remove from transactional state.""" + self._quick_ack(delivery_tag) + + def reject(self, delivery_tag, requeue=False): + """Remove from transactional state and requeue message.""" + if requeue: + self.channel._restore_at_beginning(self._delivered[delivery_tag]) + self._quick_ack(delivery_tag) + + def restore_unacked(self): + """Restore all unacknowledged messages.""" + self._flush() + delivered = self._delivered + errors = [] + restore = self.channel._restore + pop_message = delivered.popitem + + while delivered: + try: + _, message = pop_message() + except KeyError: # pragma: no cover + break + + try: + restore(message) + except BaseException as exc: + errors.append((exc, message)) + delivered.clear() + return errors + + def restore_unacked_once(self, stderr=None): + """Restores all unacknowledged messages at shutdown/gc collect. + + Note: + Can only be called once for each instance, subsequent + calls will be ignored. + """ + self._on_collect.cancel() + self._flush() + stderr = sys.stderr if stderr is None else stderr + state = self._delivered + + if not self.restore_at_shutdown or not self.channel.do_restore: + return + if getattr(state, 'restored', None): + assert not state + return + try: + if state: + print('GOING TO RESTORE') + print(RESTORING_FMT.format(len(self._delivered)), + file=stderr) + unrestored = self.restore_unacked() + + if unrestored: + errors, messages = list(zip(*unrestored)) + print(RESTORE_PANIC_FMT.format(len(errors), errors), + file=stderr) + emergency_dump_state(messages, stderr=stderr) + finally: + state.restored = True + + def restore_visible(self, *args, **kwargs): + """Restore any pending unackwnowledged messages for visibility_timeout + style implementations. + + Note: + This is implementation optional, and currently only + used by the Redis transport. + """ + pass + + +class Message(base.Message): + + def __init__(self, channel, payload, **kwargs): + self._raw = payload + properties = payload['properties'] + body = payload.get('body') + if body: + body = channel.decode_body(body, properties.get('body_encoding')) + kwargs.update({ + 'body': body, + 'delivery_tag': properties['delivery_tag'], + 'content_type': payload.get('content-type'), + 'content_encoding': payload.get('content-encoding'), + 'headers': payload.get('headers'), + 'properties': properties, + 'delivery_info': properties.get('delivery_info'), + 'postencode': 'utf-8', + }) + super(Message, self).__init__(channel, **kwargs) + + def serializable(self): + props = self.properties + body, _ = self.channel.encode_body(self.body, + props.get('body_encoding')) + headers = dict(self.headers) + # remove compression header + headers.pop('compression', None) + return { + 'body': body, + 'properties': props, + 'content-type': self.content_type, + 'content-encoding': self.content_encoding, + 'headers': headers, + } + + +class AbstractChannel(object): + """This is an abstract class defining the channel methods + you'd usually want to implement in a virtual channel. + + Note: + Do not subclass directly, but rather inherit + from :class:`Channel`. + """ + + def _get(self, queue, timeout=None): + """Get next message from `queue`.""" + raise NotImplementedError('Virtual channels must implement _get') + + def _put(self, queue, message): + """Put `message` onto `queue`.""" + raise NotImplementedError('Virtual channels must implement _put') + + def _purge(self, queue): + """Remove all messages from `queue`.""" + raise NotImplementedError('Virtual channels must implement _purge') + + def _size(self, queue): + """Return the number of messages in `queue` as an :class:`int`.""" + return 0 + + def _delete(self, queue, *args, **kwargs): + """Delete `queue`. + + Note: + This just purges the queue, if you need to do more you can + override this method. + """ + self._purge(queue) + + def _new_queue(self, queue, **kwargs): + """Create new queue. + + Note: + Your transport can override this method if it needs + to do something whenever a new queue is declared. + """ + pass + + def _has_queue(self, queue, **kwargs): + """Verify that queue exists. + + Returns: + bool: Should return :const:`True` if the queue exists + or :const:`False` otherwise. + """ + return True + + def _poll(self, cycle, timeout=None): + """Poll a list of queues for available messages.""" + return cycle.get() + + +class Channel(AbstractChannel, base.StdChannel): + """Virtual channel. + + Arguments: + connection (ConnectionT): The transport instance this + channel is part of. + """ + #: message class used. + Message = Message + + #: QoS class used. + QoS = QoS + + #: flag to restore unacked messages when channel + #: goes out of scope. + do_restore = True + + #: mapping of exchange types and corresponding classes. + exchange_types = dict(STANDARD_EXCHANGE_TYPES) + + #: flag set if the channel supports fanout exchanges. + supports_fanout = False + + #: Binary <-> ASCII codecs. + codecs = {'base64': Base64()} + + #: Default body encoding. + #: NOTE: ``transport_options['body_encoding']`` will override this value. + body_encoding = 'base64' + + #: counter used to generate delivery tags for this channel. + _delivery_tags = count(1) + + #: Optional queue where messages with no route is delivered. + #: Set by ``transport_options['deadletter_queue']``. + deadletter_queue = None + + # List of options to transfer from :attr:`transport_options`. + from_transport_options = ('body_encoding', 'deadletter_queue') + + # Priority defaults + default_priority = 0 + min_priority = 0 + max_priority = 9 + + def __init__(self, connection, **kwargs): + self.connection = connection + self._consumers = set() + self._cycle = None + self._tag_to_queue = {} + self._active_queues = [] + self._qos = None + self.closed = False + + # instantiate exchange types + self.exchange_types = dict( + (typ, cls(self)) for typ, cls in items(self.exchange_types) + ) + + try: + self.channel_id = self.connection._avail_channel_ids.pop() + except IndexError: + raise ResourceError( + 'No free channel ids, current={0}, channel_max={1}'.format( + len(self.connection.channels), + self.connection.channel_max), (20, 10), + ) + + topts = self.connection.client.transport_options + for opt_name in self.from_transport_options: + try: + setattr(self, opt_name, topts[opt_name]) + except KeyError: + pass + + def exchange_declare(self, exchange=None, type='direct', durable=False, + auto_delete=False, arguments=None, + nowait=False, passive=False): + """Declare exchange.""" + type = type or 'direct' + exchange = exchange or 'amq.%s' % type + if passive: + if exchange not in self.state.exchanges: + raise ChannelError( + 'NOT_FOUND - no exchange {0!r} in vhost {1!r}'.format( + exchange, self.connection.client.virtual_host or '/'), + (50, 10), 'Channel.exchange_declare', '404', + ) + return + try: + prev = self.state.exchanges[exchange] + if not self.typeof(exchange).equivalent(prev, exchange, type, + durable, auto_delete, + arguments): + raise NotEquivalentError(NOT_EQUIVALENT_FMT.format( + exchange, self.connection.client.virtual_host or '/')) + except KeyError: + self.state.exchanges[exchange] = { + 'type': type, + 'durable': durable, + 'auto_delete': auto_delete, + 'arguments': arguments or {}, + 'table': [], + } + + def exchange_delete(self, exchange, if_unused=False, nowait=False): + """Delete `exchange` and all its bindings.""" + for rkey, _, queue in self.get_table(exchange): + self.queue_delete(queue, if_unused=True, if_empty=True) + self.state.exchanges.pop(exchange, None) + + def queue_declare(self, queue=None, passive=False, **kwargs): + """Declare queue.""" + queue = queue or 'amq.gen-%s' % uuid() + if passive and not self._has_queue(queue, **kwargs): + raise ChannelError( + 'NOT_FOUND - no queue {0!r} in vhost {1!r}'.format( + queue, self.connection.client.virtual_host or '/'), + (50, 10), 'Channel.queue_declare', '404', + ) + else: + self._new_queue(queue, **kwargs) + return queue_declare_ok_t(queue, self._size(queue), 0) + + def queue_delete(self, queue, if_unused=False, if_empty=False, **kwargs): + """Delete queue.""" + if if_empty and self._size(queue): + return + for exchange, routing_key, args in self.state.queue_bindings(queue): + meta = self.typeof(exchange).prepare_bind( + queue, exchange, routing_key, args, + ) + self._delete(queue, exchange, *meta, **kwargs) + self.state.queue_bindings_delete(queue) + + def after_reply_message_received(self, queue): + self.queue_delete(queue) + + def exchange_bind(self, destination, source='', routing_key='', + nowait=False, arguments=None): + raise NotImplementedError('transport does not support exchange_bind') + + def exchange_unbind(self, destination, source='', routing_key='', + nowait=False, arguments=None): + raise NotImplementedError('transport does not support exchange_unbind') + + def queue_bind(self, queue, exchange=None, routing_key='', + arguments=None, **kwargs): + """Bind `queue` to `exchange` with `routing key`.""" + exchange = exchange or 'amq.direct' + if self.state.has_binding(queue, exchange, routing_key): + return + # Add binding: + self.state.binding_declare(queue, exchange, routing_key, arguments) + # Update exchange's routing table: + table = self.state.exchanges[exchange].setdefault('table', []) + meta = self.typeof(exchange).prepare_bind( + queue, exchange, routing_key, arguments, + ) + table.append(meta) + if self.supports_fanout: + self._queue_bind(exchange, *meta) + + def queue_unbind(self, queue, exchange=None, routing_key='', + arguments=None, **kwargs): + # Remove queue binding: + self.state.binding_delete(queue, exchange, routing_key) + try: + table = self.get_table(exchange) + except KeyError: + return + binding_meta = self.typeof(exchange).prepare_bind( + queue, exchange, routing_key, arguments, + ) + # TODO: the complexity of this operation is O(number of bindings). + # Should be optimized. Modifying table in place. + table[:] = [meta for meta in table if meta != binding_meta] + + def list_bindings(self): + return ((queue, exchange, rkey) + for exchange in self.state.exchanges + for rkey, pattern, queue in self.get_table(exchange)) + + def queue_purge(self, queue, **kwargs): + """Remove all ready messages from queue.""" + return self._purge(queue) + + def _next_delivery_tag(self): + return uuid() + + def basic_publish(self, message, exchange, routing_key, **kwargs): + """Publish message.""" + message['body'], body_encoding = self.encode_body( + message['body'], self.body_encoding, + ) + props = message['properties'] + props.update( + body_encoding=body_encoding, + delivery_tag=self._next_delivery_tag(), + ) + props['delivery_info'].update( + exchange=exchange, + routing_key=routing_key, + ) + if exchange: + return self.typeof(exchange).deliver( + message, exchange, routing_key, **kwargs + ) + # anon exchange: routing_key is the destination queue + return self._put(routing_key, message, **kwargs) + + def basic_consume(self, queue, no_ack, callback, consumer_tag, **kwargs): + """Consume from `queue`""" + self._tag_to_queue[consumer_tag] = queue + self._active_queues.append(queue) + + def _callback(raw_message): + message = self.Message(self, raw_message) + if not no_ack: + self.qos.append(message, message.delivery_tag) + return callback(message) + + self.connection._callbacks[queue] = _callback + self._consumers.add(consumer_tag) + + self._reset_cycle() + + def basic_cancel(self, consumer_tag): + """Cancel consumer by consumer tag.""" + if consumer_tag in self._consumers: + self._consumers.remove(consumer_tag) + self._reset_cycle() + queue = self._tag_to_queue.pop(consumer_tag, None) + try: + self._active_queues.remove(queue) + except ValueError: + pass + self.connection._callbacks.pop(queue, None) + + def basic_get(self, queue, no_ack=False, **kwargs): + """Get message by direct access (synchronous).""" + try: + message = self.Message(self, self._get(queue)) + if not no_ack: + self.qos.append(message, message.delivery_tag) + return message + except Empty: + pass + + def basic_ack(self, delivery_tag, multiple=False): + """Acknowledge message.""" + self.qos.ack(delivery_tag) + + def basic_recover(self, requeue=False): + """Recover unacked messages.""" + if requeue: + return self.qos.restore_unacked() + raise NotImplementedError('Does not support recover(requeue=False)') + + def basic_reject(self, delivery_tag, requeue=False): + """Reject message.""" + self.qos.reject(delivery_tag, requeue=requeue) + + def basic_qos(self, prefetch_size=0, prefetch_count=0, + apply_global=False): + """Change QoS settings for this channel. + + Note: + Only `prefetch_count` is supported. + """ + self.qos.prefetch_count = prefetch_count + + def get_exchanges(self): + return list(self.state.exchanges) + + def get_table(self, exchange): + """Get table of bindings for `exchange`.""" + return self.state.exchanges[exchange]['table'] + + def typeof(self, exchange, default='direct'): + """Get the exchange type instance for `exchange`.""" + try: + type = self.state.exchanges[exchange]['type'] + except KeyError: + type = default + return self.exchange_types[type] + + def _lookup(self, exchange, routing_key, default=None): + """Find all queues matching `routing_key` for the given `exchange`. + + Returns: + str: queue name -- must return the string `default` + if no queues matched. + """ + if default is None: + default = self.deadletter_queue + if not exchange: # anon exchange + return [routing_key or default] + + try: + R = self.typeof(exchange).lookup( + self.get_table(exchange), + exchange, routing_key, default, + ) + except KeyError: + R = [] + + if not R and default is not None: + warnings.warn(UndeliverableWarning(UNDELIVERABLE_FMT.format( + exchange=exchange, routing_key=routing_key)), + ) + self._new_queue(default) + R = [default] + return R + + def _restore(self, message): + """Redeliver message to its original destination.""" + delivery_info = message.delivery_info + message = message.serializable() + message['redelivered'] = True + for queue in self._lookup( + delivery_info['exchange'], delivery_info['routing_key']): + self._put(queue, message) + + def _restore_at_beginning(self, message): + return self._restore(message) + + def drain_events(self, timeout=None): + if self._consumers and self.qos.can_consume(): + if hasattr(self, '_get_many'): + return self._get_many(self._active_queues, timeout=timeout) + return self._poll(self.cycle, timeout=timeout) + raise Empty() + + def message_to_python(self, raw_message): + """Convert raw message to :class:`Message` instance.""" + if not isinstance(raw_message, self.Message): + return self.Message(self, payload=raw_message) + return raw_message + + def prepare_message(self, body, priority=None, content_type=None, + content_encoding=None, headers=None, properties=None): + """Prepare message data.""" + properties = properties or {} + properties.setdefault('delivery_info', {}) + properties.setdefault('priority', priority or self.default_priority) + + return {'body': body, + 'content-encoding': content_encoding, + 'content-type': content_type, + 'headers': headers or {}, + 'properties': properties or {}} + + def flow(self, active=True): + """Enable/disable message flow. + + Raises: + NotImplementedError: as flow + is not implemented by the base virtual implementation. + """ + raise NotImplementedError('virtual channels do not support flow.') + + def close(self): + """Close channel, cancel all consumers, and requeue unacked + messages.""" + if not self.closed: + self.closed = True + for consumer in list(self._consumers): + self.basic_cancel(consumer) + if self._qos: + self._qos.restore_unacked_once() + if self._cycle is not None: + self._cycle.close() + self._cycle = None + if self.connection is not None: + self.connection.close_channel(self) + self.exchange_types = None + + def encode_body(self, body, encoding=None): + if encoding: + return self.codecs.get(encoding).encode(body), encoding + return body, encoding + + def decode_body(self, body, encoding=None): + if encoding: + return self.codecs.get(encoding).decode(body) + return body + + def _reset_cycle(self): + self._cycle = FairCycle(self._get, self._active_queues, Empty) + + def __enter__(self): + return self + + def __exit__(self, *exc_info): + self.close() + + @property + def state(self): + """Broker state containing exchanges and bindings.""" + return self.connection.state + + @property + def qos(self): + """:class:`QoS` manager for this channel.""" + if self._qos is None: + self._qos = self.QoS(self) + return self._qos + + @property + def cycle(self): + if self._cycle is None: + self._reset_cycle() + return self._cycle + + def _get_message_priority(self, message, reverse=False): + """Get priority from message and limit the value within a + boundary of 0 to 9. + + Note: + Higher value has more priority. + """ + try: + priority = max( + min(int(message['properties']['priority']), + self.max_priority), + self.min_priority, + ) + except (TypeError, ValueError, KeyError): + priority = self.default_priority + + return (self.max_priority - priority) if reverse else priority + + +class Management(base.Management): + + def __init__(self, transport): + super(Management, self).__init__(transport) + self.channel = transport.client.channel() + + def get_bindings(self): + return [dict(destination=q, source=e, routing_key=r) + for q, e, r in self.channel.list_bindings()] + + def close(self): + self.channel.close() + + +class Transport(base.Transport): + """Virtual transport. + + Arguments: + client (kombu.Connection): The client this is a transport for. + """ + Channel = Channel + Cycle = FairCycle + Management = Management + + #: :class:`BrokerState` containing declared exchanges and + #: bindings (set by constructor). + state = BrokerState() + + #: :class:`~kombu.utils.scheduling.FairCycle` instance + #: used to fairly drain events from channels (set by constructor). + cycle = None + + #: port number used when no port is specified. + default_port = None + + #: active channels. + channels = None + + #: queue/callback map. + _callbacks = None + + #: Time to sleep between unsuccessful polls. + polling_interval = 1.0 + + #: Max number of channels + channel_max = 65535 + + implements = base.Transport.implements.extend( + async=False, + exchange_type=frozenset(['direct', 'topic']), + heartbeats=False, + ) + + def __init__(self, client, **kwargs): + self.client = client + self.channels = [] + self._avail_channels = [] + self._callbacks = {} + self.cycle = self.Cycle(self._drain_channel, self.channels, Empty) + polling_interval = client.transport_options.get('polling_interval') + if polling_interval is not None: + self.polling_interval = polling_interval + self._avail_channel_ids = array( + ARRAY_TYPE_H, range(self.channel_max, 0, -1), + ) + + def create_channel(self, connection): + try: + return self._avail_channels.pop() + except IndexError: + channel = self.Channel(connection) + self.channels.append(channel) + return channel + + def close_channel(self, channel): + try: + self._avail_channel_ids.append(channel.channel_id) + try: + self.channels.remove(channel) + except ValueError: + pass + finally: + channel.connection = None + + def establish_connection(self): + # creates channel to verify connection. + # this channel is then used as the next requested channel. + # (returned by ``create_channel``). + self._avail_channels.append(self.create_channel(self)) + return self # for drain events + + def close_connection(self, connection): + self.cycle.close() + for l in self._avail_channels, self.channels: + while l: + try: + channel = l.pop() + except LookupError: # pragma: no cover + pass + else: + channel.close() + + def drain_events(self, connection, timeout=None): + loop = 0 + time_start = monotonic() + get = self.cycle.get + polling_interval = self.polling_interval + while 1: + try: + item, channel = get(timeout=timeout) + except Empty: + if timeout and monotonic() - time_start >= timeout: + raise socket.timeout() + loop += 1 + if polling_interval is not None: + sleep(polling_interval) + else: + break + self._deliver(*item) + + def _deliver(self, message, queue): + if not queue: + raise KeyError( + 'Received message without destination queue: {0}'.format( + message)) + try: + callback = self._callbacks[queue] + except KeyError: + logger.warn(W_NO_CONSUMERS, queue) + self._reject_inbound_message(message) + else: + callback(message) + + def _reject_inbound_message(self, raw_message): + for channel in self.channels: + if channel: + message = channel.Message(channel, raw_message) + channel.qos.append(message, message.delivery_tag) + channel.basic_reject(message.delivery_tag, requeue=True) + break + + def on_message_ready(self, channel, message, queue): + if not queue or queue not in self._callbacks: + raise KeyError( + 'Message for queue {0!r} without consumers: {1}'.format( + queue, message)) + self._callbacks[queue](message) + + def _drain_channel(self, channel, timeout=None): + return channel.drain_events(timeout=timeout) + + @property + def default_connection_params(self): + return {'port': self.default_port, 'hostname': 'localhost'} diff --git a/requirements/test-ci.txt b/requirements/test-ci.txt index 485e8a79..0684adfc 100644 --- a/requirements/test-ci.txt +++ b/requirements/test-ci.txt @@ -1,4 +1,4 @@ -coverage>=3.0 +pytest-cov codecov redis PyYAML diff --git a/requirements/test.txt b/requirements/test.txt index cb284d73..b093eab7 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -1,2 +1,3 @@ pytz>dev -case>=1.2.2 +case>=1.3.1 +pytest @@ -1,7 +1,6 @@ -[nosetests] -verbosity = 1 -detailed-errors = 1 -where = kombu/tests +[tool:pytest] +testpaths = t/ +python_classes = test_* [build_sphinx] source-dir = docs/ @@ -5,6 +5,9 @@ import re import sys import codecs +import setuptools +import setuptools.command.test + from distutils.command.install import INSTALL_SCHEMES if sys.version_info < (2, 7): @@ -44,7 +47,7 @@ finally: meta_fh.close() # -- -packages, data_files = [], [] +data_files = [] root_dir = os.path.dirname(__file__) if root_dir != '': os.chdir(root_dir) @@ -71,9 +74,7 @@ for dirpath, dirnames, filenames in os.walk(src_dir): if dirname.startswith('.'): del dirnames[i] for filename in filenames: - if filename.endswith('.py'): - packages.append('.'.join(fullsplit(dirpath))) - else: + if not filename.endswith('.py'): data_files.append( [dirpath, [os.path.join(dirpath, f) for f in filenames]], ) @@ -104,20 +105,46 @@ def reqs(*f): def extras(*p): return reqs('extras', *p) + +class pytest(setuptools.command.test.test): + user_options = [('pytest-args=', 'a', 'Arguments to pass to py.test')] + + def initialize_options(self): + setuptools.command.test.test.initialize_options(self) + self.pytest_args = [] + + def run_tests(self): + import pytest + sys.exit(pytest.main(self.pytest_args)) + setup( name='kombu', + packages=setuptools.find_packages(exclude=['t', 't.*']), version=meta['version'], description=meta['doc'], + long_description=long_description, author=meta['author'], author_email=meta['contact'], url=meta['homepage'], platforms=['any'], - packages=packages, data_files=data_files, zip_safe=False, - test_suite='nose.collector', + cmdclass={'test': pytest}, install_requires=reqs('default.txt'), tests_require=reqs('test.txt'), + extras_require={ + 'msgpack': extras('msgpack.txt'), + 'yaml': extras('yaml.txt'), + 'redis': extras('redis.txt'), + 'mongodb': extras('mongodb.txt'), + 'sqs': extras('sqs.txt'), + 'zookeeper': extras('zookeeper.txt'), + 'librabbitmq': extras('librabbitmq.txt'), + 'pyro': extras('pyro.txt'), + 'slmq': extras('slmq.txt'), + 'qpid': extras('qpid.txt'), + 'consul': extras('consul.txt'), + }, classifiers=[ 'Development Status :: 5 - Production/Stable', 'License :: OSI Approved :: BSD License', @@ -137,18 +164,4 @@ setup( 'Topic :: System :: Networking', 'Topic :: Software Development :: Libraries :: Python Modules', ], - long_description=long_description, - extras_require={ - 'msgpack': extras('msgpack.txt'), - 'yaml': extras('yaml.txt'), - 'redis': extras('redis.txt'), - 'mongodb': extras('mongodb.txt'), - 'sqs': extras('sqs.txt'), - 'zookeeper': extras('zookeeper.txt'), - 'librabbitmq': extras('librabbitmq.txt'), - 'pyro': extras('pyro.txt'), - 'slmq': extras('slmq.txt'), - 'qpid': extras('qpid.txt'), - 'consul': extras('consul.txt'), - }, ) diff --git a/kombu/tests/async/__init__.py b/t/__init__.py index e69de29b..e69de29b 100644 --- a/kombu/tests/async/__init__.py +++ b/t/__init__.py diff --git a/t/conftest.py b/t/conftest.py new file mode 100644 index 00000000..9e8e959f --- /dev/null +++ b/t/conftest.py @@ -0,0 +1,98 @@ +from __future__ import absolute_import, unicode_literals + +import atexit +import os +import pytest +import sys + +from kombu.exceptions import VersionMismatch + + +@pytest.fixture(scope='session') +def multiprocessing_workaround(request): + def fin(): + # Workaround for multiprocessing bug where logging + # is attempted after global already collected at shutdown. + canceled = set() + try: + import multiprocessing.util + canceled.add(multiprocessing.util._exit_function) + except (AttributeError, ImportError): + pass + + try: + atexit._exithandlers[:] = [ + e for e in atexit._exithandlers if e[0] not in canceled + ] + except AttributeError: # pragma: no cover + pass # Py3 missing _exithandlers + request.addfinalizer(fin) + + +@pytest.fixture(autouse=True) +def zzzz_test_cases_calls_setup_teardown(request): + if request.instance: + # we set the .patching attribute for every test class. + setup = getattr(request.instance, 'setup', None) + # we also call .setup() and .teardown() after every test method. + teardown = getattr(request.instance, 'teardown', None) + setup and setup() + teardown and request.addfinalizer(teardown) + + +@pytest.fixture(autouse=True) +def test_cases_has_patching(request, patching): + if request.instance: + request.instance.patching = patching + + +@pytest.fixture +def hub(request): + from kombu.async import Hub, get_event_loop, set_event_loop + _prev_hub = get_event_loop() + hub = Hub() + set_event_loop(hub) + + def fin(): + if _prev_hub is not None: + set_event_loop(_prev_hub) + request.addfinalizer(fin) + return hub + + +def find_distribution_modules(name=__name__, file=__file__): + current_dist_depth = len(name.split('.')) - 1 + current_dist = os.path.join(os.path.dirname(file), + *([os.pardir] * current_dist_depth)) + abs = os.path.abspath(current_dist) + dist_name = os.path.basename(abs) + + for dirpath, dirnames, filenames in os.walk(abs): + package = (dist_name + dirpath[len(abs):]).replace('/', '.') + if '__init__.py' in filenames: + yield package + for filename in filenames: + if filename.endswith('.py') and filename != '__init__.py': + yield '.'.join([package, filename])[:-3] + + +def import_all_modules(name=__name__, file=__file__, skip=[]): + for module in find_distribution_modules(name, file): + if module not in skip: + print('preimporting %r for coverage...' % (module,)) + try: + __import__(module) + except (ImportError, VersionMismatch, AttributeError): + pass + + +def is_in_coverage(): + return (os.environ.get('COVER_ALL_MODULES') or + any('--cov' in arg for arg in sys.argv)) + + +@pytest.fixture(scope='session') +def cover_all_modules(): + # so coverage sees all our modules. + if is_in_coverage(): + import_all_modules() diff --git a/kombu/tests/mocks.py b/t/mocks.py index dac2b761..ea58eb22 100644 --- a/kombu/tests/mocks.py +++ b/t/mocks.py @@ -2,10 +2,34 @@ from __future__ import absolute_import, unicode_literals from itertools import count +from case import ContextMock, Mock + from kombu.transport import base from kombu.utils import json +def PromiseMock(*args, **kwargs): + m = Mock(*args, **kwargs) + + def on_throw(exc=None, *args, **kwargs): + if exc: + raise exc + raise + m.throw.side_effect = on_throw + m.set_error_state.side_effect = on_throw + m.throw1.side_effect = on_throw + return m + + +class MockPool(object): + + def __init__(self, value=None): + self.value = value or ContextMock() + + def acquire(self, **kwargs): + return self.value + + class Message(base.Message): def __init__(self, *args, **kwargs): diff --git a/t/unit/__init__.py b/t/unit/__init__.py new file mode 100644 index 00000000..01e6d4f4 --- /dev/null +++ b/t/unit/__init__.py @@ -0,0 +1 @@ +from __future__ import absolute_import, unicode_literals diff --git a/kombu/tests/async/aws/__init__.py b/t/unit/async/__init__.py index e69de29b..e69de29b 100644 --- a/kombu/tests/async/aws/__init__.py +++ b/t/unit/async/__init__.py diff --git a/kombu/tests/async/aws/sqs/__init__.py b/t/unit/async/aws/__init__.py index e69de29b..e69de29b 100644 --- a/kombu/tests/async/aws/sqs/__init__.py +++ b/t/unit/async/aws/__init__.py diff --git a/kombu/tests/async/aws/case.py b/t/unit/async/aws/case.py index e7e1a20c..d0af4faf 100644 --- a/kombu/tests/async/aws/case.py +++ b/t/unit/async/aws/case.py @@ -1,11 +1,14 @@ # -*- coding: utf-8 -*- from __future__ import absolute_import, unicode_literals -from kombu.tests.case import HubCase, skip +import pytest + +from case import skip @skip.if_pypy() @skip.unless_module('boto') @skip.unless_module('pycurl') -class AWSCase(HubCase): +@pytest.mark.usefixtures('hub') +class AWSCase: pass diff --git a/kombu/tests/async/http/__init__.py b/t/unit/async/aws/sqs/__init__.py index e69de29b..e69de29b 100644 --- a/kombu/tests/async/http/__init__.py +++ b/t/unit/async/aws/sqs/__init__.py diff --git a/kombu/tests/async/aws/sqs/test_connection.py b/t/unit/async/aws/sqs/test_connection.py index 9a9dcdf0..ecc22623 100644 --- a/kombu/tests/async/aws/sqs/test_connection.py +++ b/t/unit/async/aws/sqs/test_connection.py @@ -1,6 +1,10 @@ # -*- coding: utf-8 -*- from __future__ import absolute_import, unicode_literals +import pytest + +from case import Mock + from kombu.async.aws.sqs.connection import ( AsyncSQSConnection, Attributes, BatchResults, ) @@ -8,8 +12,9 @@ from kombu.async.aws.sqs.message import AsyncMessage from kombu.async.aws.sqs.queue import AsyncQueue from kombu.utils.uuid import uuid -from kombu.tests.async.aws.case import AWSCase -from kombu.tests.case import PromiseMock, Mock +from t.mocks import PromiseMock + +from ..case import AWSCase class test_AsyncSQSConnection(AWSCase): @@ -25,16 +30,14 @@ class test_AsyncSQSConnection(AWSCase): from kombu.async.aws.sqs import connection prev, connection.boto = connection.boto, None try: - with self.assertRaises(ImportError): + with pytest.raises(ImportError): AsyncSQSConnection('ak', 'sk', http_client=Mock()) finally: connection.boto = prev def test_default_region(self): - self.assertTrue(self.x.region) - self.assertTrue(issubclass( - self.x.region.connection_cls, AsyncSQSConnection, - )) + assert self.x.region + assert issubclass(self.x.region.connection_cls, AsyncSQSConnection) def test_create_queue(self): self.x.create_queue('foo', callback=self.callback) diff --git a/kombu/tests/async/aws/sqs/test_message.py b/t/unit/async/aws/sqs/test_message.py index 0f1a033b..44a0ac32 100644 --- a/kombu/tests/async/aws/sqs/test_message.py +++ b/t/unit/async/aws/sqs/test_message.py @@ -1,12 +1,15 @@ # -*- coding: utf-8 -*- from __future__ import absolute_import, unicode_literals -from kombu.async.aws.sqs.message import AsyncMessage +from case import Mock -from kombu.tests.async.aws.case import AWSCase -from kombu.tests.case import PromiseMock, Mock +from kombu.async.aws.sqs.message import AsyncMessage from kombu.utils.uuid import uuid +from t.mocks import PromiseMock + +from ..case import AWSCase + class test_AsyncMessage(AWSCase): @@ -17,20 +20,18 @@ class test_AsyncMessage(AWSCase): self.x.receipt_handle = uuid() def test_delete(self): - self.assertTrue(self.x.delete(callback=self.callback)) + assert self.x.delete(callback=self.callback) self.x.queue.delete_message.assert_called_with( self.x, self.callback, ) self.x.queue = None - self.assertIsNone(self.x.delete(callback=self.callback)) + assert self.x.delete(callback=self.callback) is None def test_change_visibility(self): - self.assertTrue(self.x.change_visibility(303, callback=self.callback)) + assert self.x.change_visibility(303, callback=self.callback) self.x.queue.connection.change_message_visibility.assert_called_with( self.x.queue, self.x.receipt_handle, 303, self.callback, ) self.x.queue = None - self.assertIsNone(self.x.change_visibility( - 303, callback=self.callback, - )) + assert self.x.change_visibility(303, callback=self.callback) is None diff --git a/kombu/tests/async/aws/sqs/test_queue.py b/t/unit/async/aws/sqs/test_queue.py index d34eb2e8..635016a6 100644 --- a/kombu/tests/async/aws/sqs/test_queue.py +++ b/t/unit/async/aws/sqs/test_queue.py @@ -1,11 +1,16 @@ # -*- coding: utf-8 -*- from __future__ import absolute_import, unicode_literals +import pytest + +from case import Mock + from kombu.async.aws.sqs.message import AsyncMessage from kombu.async.aws.sqs.queue import AsyncQueue -from kombu.tests.async.aws.case import AWSCase -from kombu.tests.case import PromiseMock, Mock +from t.mocks import PromiseMock + +from ..case import AWSCase class test_AsyncQueue(AWSCase): @@ -16,7 +21,7 @@ class test_AsyncQueue(AWSCase): self.callback = PromiseMock(name='callback') def test_message_class(self): - self.assertTrue(issubclass(self.x.message_class, AsyncMessage)) + assert issubclass(self.x.message_class, AsyncMessage) def test_get_attributes(self): self.x.get_attributes(attributes='QueueSize', callback=self.callback) @@ -50,10 +55,10 @@ class test_AsyncQueue(AWSCase): ) on_ready(808) self.callback.assert_called_with(808) - self.assertEqual(self.x.visibility_timeout, 808) + assert self.x.visibility_timeout == 808 on_ready(None) - self.assertEqual(self.x.visibility_timeout, 808) + assert self.x.visibility_timeout == 808 def test_add_permission(self): self.x.add_permission( @@ -101,8 +106,8 @@ class test_AsyncQueue(AWSCase): new_message = self.MockMessage('id2', 'digest2') on_ready(new_message) - self.assertEqual(message.id, 'id2') - self.assertEqual(message.md5, 'digest2') + assert message.id == 'id2' + assert message.md5 == 'digest2' def test_write_batch(self): messages = [('id1', 'A', 0), ('id2', 'B', 303)] @@ -158,45 +163,45 @@ class test_AsyncQueue(AWSCase): self.callback.assert_called_with(909) def test_interface__count_slow(self): - with self.assertRaises(NotImplementedError): + with pytest.raises(NotImplementedError): self.x.count_slow() def test_interface__dump(self): - with self.assertRaises(NotImplementedError): + with pytest.raises(NotImplementedError): self.x.dump() def test_interface__save_to_file(self): - with self.assertRaises(NotImplementedError): + with pytest.raises(NotImplementedError): self.x.save_to_file() def test_interface__save_to_filename(self): - with self.assertRaises(NotImplementedError): + with pytest.raises(NotImplementedError): self.x.save_to_filename() def test_interface__save(self): - with self.assertRaises(NotImplementedError): + with pytest.raises(NotImplementedError): self.x.save() def test_interface__save_to_s3(self): - with self.assertRaises(NotImplementedError): + with pytest.raises(NotImplementedError): self.x.save_to_s3() def test_interface__load_from_s3(self): - with self.assertRaises(NotImplementedError): + with pytest.raises(NotImplementedError): self.x.load_from_s3() def test_interface__load_from_file(self): - with self.assertRaises(NotImplementedError): + with pytest.raises(NotImplementedError): self.x.load_from_file() def test_interface__load_from_filename(self): - with self.assertRaises(NotImplementedError): + with pytest.raises(NotImplementedError): self.x.load_from_filename() def test_interface__load(self): - with self.assertRaises(NotImplementedError): + with pytest.raises(NotImplementedError): self.x.load() def test_interface__clear(self): - with self.assertRaises(NotImplementedError): + with pytest.raises(NotImplementedError): self.x.clear() diff --git a/kombu/tests/async/aws/sqs/test_sqs.py b/t/unit/async/aws/sqs/test_sqs.py index 433ffdad..ea58596d 100644 --- a/kombu/tests/async/aws/sqs/test_sqs.py +++ b/t/unit/async/aws/sqs/test_sqs.py @@ -1,23 +1,26 @@ # -*- coding: utf-8 -*- from __future__ import absolute_import, unicode_literals +import pytest + +from case import Mock, patch + from kombu.async.aws.sqs import regions, connect_to_region from kombu.async.aws.sqs.connection import AsyncSQSConnection -from kombu.tests.async.aws.case import AWSCase -from kombu.tests.case import Mock, patch, set_module_symbol +from ..case import AWSCase class test_connect_to_region(AWSCase): - def test_when_no_boto_installed(self): - with set_module_symbol('kombu.async.aws.sqs', 'boto', None): - with self.assertRaises(ImportError): - regions() + def test_when_no_boto_installed(self, patching): + patching('kombu.async.aws.sqs.boto', None) + with pytest.raises(ImportError): + regions() def test_using_async_connection(self): for region in regions(): - self.assertIs(region.connection_cls, AsyncSQSConnection) + assert region.connection_cls is AsyncSQSConnection def test_connect_to_region(self): with patch('kombu.async.aws.sqs.regions') as regions: @@ -25,7 +28,7 @@ class test_connect_to_region(AWSCase): region.name = 'us-west-1' regions.return_value = [region] conn = connect_to_region('us-west-1', kw=3.33) - self.assertIs(conn, region.connect.return_value) + assert conn is region.connect.return_value region.connect.assert_called_with(kw=3.33) - self.assertIsNone(connect_to_region('foo')) + assert connect_to_region('foo') is None diff --git a/kombu/tests/async/aws/test_aws.py b/t/unit/async/aws/test_aws.py index b16b6bb8..f5ed9aef 100644 --- a/kombu/tests/async/aws/test_aws.py +++ b/t/unit/async/aws/test_aws.py @@ -1,8 +1,9 @@ # -*- coding: utf-8 -*- from __future__ import absolute_import, unicode_literals +from case import Mock + from kombu.async.aws import connect_sqs -from kombu.tests.case import Mock from .case import AWSCase @@ -11,5 +12,5 @@ class test_connect_sqs(AWSCase): def test_connection(self): x = connect_sqs('AAKI', 'ASAK', http_client=Mock()) - self.assertTrue(x) - self.assertTrue(x.connection) + assert x + assert x.connection diff --git a/kombu/tests/async/aws/test_connection.py b/t/unit/async/aws/test_connection.py index 8304af05..5de76aa5 100644 --- a/kombu/tests/async/aws/test_connection.py +++ b/t/unit/async/aws/test_connection.py @@ -1,8 +1,11 @@ # -*- coding: utf-8 -*- from __future__ import absolute_import, unicode_literals +import pytest + from contextlib import contextmanager +from case import Mock, patch from vine.abstract import Thenable from kombu.exceptions import HttpError @@ -18,10 +21,10 @@ from kombu.async.aws.connection import ( AsyncAWSQueryConnection, ) -from kombu.tests.case import PromiseMock, Mock, patch, set_module_symbol - from .case import AWSCase +from t.mocks import PromiseMock + # Not currently working VALIDATES_CERT = False @@ -39,56 +42,56 @@ class test_AsyncHTTPConnection(AWSCase): def test_AsyncHTTPSConnection(self): x = AsyncHTTPSConnection('aws.vandelay.com') - self.assertEqual(x.scheme, 'https') + assert x.scheme == 'https' def test_http_client(self): x = AsyncHTTPConnection('aws.vandelay.com') - self.assertIs(x.http_client, http.get_client()) + assert x.http_client is http.get_client() client = Mock(name='http_client') y = AsyncHTTPConnection('aws.vandelay.com', http_client=client) - self.assertIs(y.http_client, client) + assert y.http_client is client def test_args(self): x = AsyncHTTPConnection( 'aws.vandelay.com', 8083, strict=True, timeout=33.3, ) - self.assertEqual(x.host, 'aws.vandelay.com') - self.assertEqual(x.port, 8083) - self.assertTrue(x.strict) - self.assertEqual(x.timeout, 33.3) - self.assertEqual(x.scheme, 'http') + assert x.host == 'aws.vandelay.com' + assert x.port == 8083 + assert x.strict + assert x.timeout == 33.3 + assert x.scheme == 'http' def test_request(self): x = AsyncHTTPConnection('aws.vandelay.com') x.request('PUT', '/importer-exporter') - self.assertEqual(x.path, '/importer-exporter') - self.assertEqual(x.method, 'PUT') + assert x.path == '/importer-exporter' + assert x.method == 'PUT' def test_request_with_body_buffer(self): x = AsyncHTTPConnection('aws.vandelay.com') body = Mock(name='body') body.read.return_value = 'Vandelay Industries' x.request('PUT', '/importer-exporter', body) - self.assertEqual(x.method, 'PUT') - self.assertEqual(x.path, '/importer-exporter') - self.assertEqual(x.body, 'Vandelay Industries') + assert x.method == 'PUT' + assert x.path == '/importer-exporter' + assert x.body == 'Vandelay Industries' body.read.assert_called_with() def test_request_with_body_text(self): x = AsyncHTTPConnection('aws.vandelay.com') x.request('PUT', '/importer-exporter', 'Vandelay Industries') - self.assertEqual(x.method, 'PUT') - self.assertEqual(x.path, '/importer-exporter') - self.assertEqual(x.body, 'Vandelay Industries') + assert x.method == 'PUT' + assert x.path == '/importer-exporter' + assert x.body == 'Vandelay Industries' def test_request_with_headers(self): x = AsyncHTTPConnection('aws.vandelay.com') headers = {'Proxy': 'proxy.vandelay.com'} x.request('PUT', '/importer-exporter', None, headers) - self.assertIn('Proxy', dict(x.headers)) - self.assertEqual(dict(x.headers)['Proxy'], 'proxy.vandelay.com') + assert 'Proxy' in dict(x.headers) + assert dict(x.headers)['Proxy'] == 'proxy.vandelay.com' - def assertRequestCreatedWith(self, url, conn): + def assert_request_created_with(self, url, conn): conn.Request.assert_called_with( url, method=conn.method, headers=http.Headers(conn.headers), body=conn.body, @@ -100,18 +103,18 @@ class test_AsyncHTTPConnection(AWSCase): x = AsyncHTTPSConnection('aws.vandelay.com') x.Request = Mock(name='Request') x.getrequest() - self.assertRequestCreatedWith('https://aws.vandelay.com/', x) + self.assert_request_created_with('https://aws.vandelay.com/', x) def test_getrequest_nondefault_port(self): x = AsyncHTTPConnection('aws.vandelay.com', port=8080) x.Request = Mock(name='Request') x.getrequest() - self.assertRequestCreatedWith('http://aws.vandelay.com:8080/', x) + self.assert_request_created_with('http://aws.vandelay.com:8080/', x) y = AsyncHTTPSConnection('aws.vandelay.com', port=8443) y.Request = Mock(name='Request') y.getrequest() - self.assertRequestCreatedWith('https://aws.vandelay.com:8443/', y) + self.assert_request_created_with('https://aws.vandelay.com:8443/', y) def test_getresponse(self): client = Mock(name='client') @@ -120,8 +123,8 @@ class test_AsyncHTTPConnection(AWSCase): x.Response = Mock(name='x.Response') request = x.getresponse() x.http_client.add_request.assert_called_with(request) - self.assertIsInstance(request, Thenable) - self.assertIsInstance(request.on_ready, Thenable) + assert isinstance(request, Thenable) + assert isinstance(request.on_ready, Thenable) response = Mock(name='Response') request.on_ready(response) @@ -145,46 +148,46 @@ class test_AsyncHTTPConnection(AWSCase): callback.assert_called() wresponse = callback.call_args[0][0] - self.assertEqual(wresponse.read(), 'The quick brown fox jumps') - self.assertEqual(wresponse.status, 200) - self.assertEqual(wresponse.getheader('X-Foo'), 'Hello') - self.assertDictEqual(dict(wresponse.getheaders()), headers) - self.assertTrue(wresponse.msg) - self.assertTrue(wresponse.msg) - self.assertTrue(repr(wresponse)) + assert wresponse.read() == 'The quick brown fox jumps' + assert wresponse.status == 200 + assert wresponse.getheader('X-Foo') == 'Hello' + assert dict(wresponse.getheaders()) == headers + assert wresponse.msg + assert wresponse.msg + assert repr(wresponse) def test_repr(self): - self.assertTrue(repr(AsyncHTTPConnection('aws.vandelay.com'))) + assert repr(AsyncHTTPConnection('aws.vandelay.com')) def test_putrequest(self): x = AsyncHTTPConnection('aws.vandelay.com') x.putrequest('UPLOAD', '/new') - self.assertEqual(x.method, 'UPLOAD') - self.assertEqual(x.path, '/new') + assert x.method == 'UPLOAD' + assert x.path == '/new' def test_putheader(self): x = AsyncHTTPConnection('aws.vandelay.com') x.putheader('X-Foo', 'bar') - self.assertListEqual(x.headers, [('X-Foo', 'bar')]) + assert x.headers == [('X-Foo', 'bar')] x.putheader('X-Bar', 'baz') - self.assertListEqual(x.headers, [ + assert x.headers == [ ('X-Foo', 'bar'), ('X-Bar', 'baz'), - ]) + ] def test_send(self): x = AsyncHTTPConnection('aws.vandelay.com') x.send('foo') - self.assertEqual(x.body, 'foo') + assert x.body == 'foo' x.send('bar') - self.assertEqual(x.body, 'foobar') + assert x.body == 'foobar' def test_interface(self): x = AsyncHTTPConnection('aws.vandelay.com') - self.assertIsNone(x.set_debuglevel(3)) - self.assertIsNone(x.connect()) - self.assertIsNone(x.close()) - self.assertIsNone(x.endheaders()) + assert x.set_debuglevel(3) is None + assert x.connect() is None + assert x.close() is None + assert x.endheaders() is None class test_AsyncHTTPResponse(AWSCase): @@ -193,41 +196,41 @@ class test_AsyncHTTPResponse(AWSCase): r = Mock(name='response') r.error = HttpError(404, 'NotFound') x = AsyncHTTPResponse(r) - self.assertEqual(x.reason, 'NotFound') + assert x.reason == 'NotFound' r.error = None - self.assertFalse(x.reason) + assert not x.reason class test_AsyncConnection(AWSCase): - def test_when_boto_missing(self): - with set_module_symbol('kombu.async.aws.connection', 'boto', None): - with self.assertRaises(ImportError): - AsyncConnection(Mock(name='client')) + def test_when_boto_missing(self, patching): + patching('kombu.async.aws.connection.boto', None) + with pytest.raises(ImportError): + AsyncConnection(Mock(name='client')) def test_client(self): x = AsyncConnection() - self.assertIs(x._httpclient, http.get_client()) + assert x._httpclient is http.get_client() client = Mock(name='client') y = AsyncConnection(http_client=client) - self.assertIs(y._httpclient, client) + assert y._httpclient is client def test_get_http_connection(self): x = AsyncConnection(client=Mock(name='client')) - self.assertIsInstance( + assert isinstance( x.get_http_connection('aws.vandelay.com', 80, False), AsyncHTTPConnection, ) - self.assertIsInstance( + assert isinstance( x.get_http_connection('aws.vandelay.com', 443, True), AsyncHTTPSConnection, ) conn = x.get_http_connection('aws.vandelay.com', 80, False) - self.assertIs(conn.http_client, x._httpclient) - self.assertEqual(conn.host, 'aws.vandelay.com') - self.assertEqual(conn.port, 80) + assert conn.http_client is x._httpclient + assert conn.host == 'aws.vandelay.com' + assert conn.port == 80 class test_AsyncAWSAuthConnection(AWSCase): @@ -239,7 +242,7 @@ class test_AsyncAWSAuthConnection(AWSCase): Conn = x.get_http_connection = Mock(name='get_http_connection') callback = PromiseMock(name='callback') ret = x.make_request('GET', '/foo', callback=callback) - self.assertIs(ret, callback) + assert ret is callback Conn.return_value.request.assert_called() Conn.return_value.getresponse.assert_called_with( callback=callback, @@ -261,9 +264,8 @@ class test_AsyncAWSAuthConnection(AWSCase): ) no_callback_ret = x._mexe(request) - self.assertIsInstance( - no_callback_ret, Thenable, '_mexe always returns promise', - ) + # _mexe always returns promise + assert isinstance(no_callback_ret, Thenable) @patch('boto.log', create=True) def test_mexe__with_sender(self, _): @@ -296,11 +298,11 @@ class test_AsyncAWSQueryConnection(AWSCase): ) self.x._mexe.assert_called() request = self.x._mexe.call_args[0][0] - self.assertEqual(request.params['Action'], 'action') - self.assertEqual(request.params['Version'], self.x.APIVersion) + assert request.params['Action'] == 'action' + assert request.params['Version'] == self.x.APIVersion ret = _mexe(request, callback=callback) - self.assertIs(ret, callback) + assert ret is callback Conn.return_value.request.assert_called() Conn.return_value.getresponse.assert_called_with( callback=callback, @@ -316,8 +318,8 @@ class test_AsyncAWSQueryConnection(AWSCase): ) self.x._mexe.assert_called() request = self.x._mexe.call_args[0][0] - self.assertNotIn('Action', request.params) - self.assertEqual(request.params['Version'], self.x.APIVersion) + assert 'Action' not in request.params + assert request.params['Version'] == self.x.APIVersion @contextmanager def mock_sax_parse(self, parser): @@ -364,7 +366,7 @@ class test_AsyncAWSQueryConnection(AWSCase): self.x.get_list('action', {'p': 3.3}, ['m'], callback=callback) on_ready = self.assert_make_request_called() - with self.assertRaises(self.x.ResponseError): + with pytest.raises(self.x.ResponseError): on_ready(self.Response(404, 'Not found')) def test_get_object(self): @@ -388,15 +390,15 @@ class test_AsyncAWSQueryConnection(AWSCase): callback.assert_called() result = callback.call_args[0][0] - self.assertEqual(result.value, 42) - self.assertTrue(result.parent) + assert result.value == 42 + assert result.parent def test_get_object_error(self): with self.mock_make_request() as callback: self.x.get_object('action', {'p': 3.3}, object, callback=callback) on_ready = self.assert_make_request_called() - with self.assertRaises(self.x.ResponseError): + with pytest.raises(self.x.ResponseError): on_ready(self.Response(404, 'Not found')) def test_get_status(self): @@ -422,7 +424,7 @@ class test_AsyncAWSQueryConnection(AWSCase): self.x.get_status('action', {'p': 3.3}, callback=callback) on_ready = self.assert_make_request_called() - with self.assertRaises(self.x.ResponseError): + with pytest.raises(self.x.ResponseError): on_ready(self.Response(404, 'Not found')) def test_get_status_error_empty_body(self): @@ -430,5 +432,5 @@ class test_AsyncAWSQueryConnection(AWSCase): self.x.get_status('action', {'p': 3.3}, callback=callback) on_ready = self.assert_make_request_called() - with self.assertRaises(self.x.ResponseError): + with pytest.raises(self.x.ResponseError): on_ready(self.Response(200, '')) diff --git a/kombu/tests/transport/__init__.py b/t/unit/async/http/__init__.py index e69de29b..e69de29b 100644 --- a/kombu/tests/transport/__init__.py +++ b/t/unit/async/http/__init__.py diff --git a/kombu/tests/async/http/test_curl.py b/t/unit/async/http/test_curl.py index 7bfc312e..70f38cfa 100644 --- a/kombu/tests/async/http/test_curl.py +++ b/t/unit/async/http/test_curl.py @@ -1,41 +1,40 @@ # -*- coding: utf-8 -*- from __future__ import absolute_import, unicode_literals -from kombu.async.http.curl import READ, WRITE, CurlClient +import pytest + +from case import Mock, call, patch, skip -from kombu.tests.case import ( - HubCase, Mock, call, patch, set_module_symbol, skip, -) +from kombu.async.http.curl import READ, WRITE, CurlClient @skip.if_pypy() @skip.unless_module('pycurl') -class test_CurlClient(HubCase): +@pytest.mark.usefixtures('hub') +class test_CurlClient: class Client(CurlClient): Curl = Mock(name='Curl') - def test_when_pycurl_missing(self): - with set_module_symbol('kombu.async.http.curl', 'pycurl', None): - with self.assertRaises(ImportError): - self.Client() + def test_when_pycurl_missing(self, patching): + patching('kombu.async.http.curl.pycurl', None) + with pytest.raises(ImportError): + self.Client() def test_max_clients_set(self): x = self.Client(max_clients=303) - self.assertEqual(x.max_clients, 303) + assert x.max_clients == 303 def test_init(self): with patch('kombu.async.http.curl.pycurl') as _pycurl: x = self.Client() - self.assertIsNotNone(x._multi) - self.assertIsNotNone(x._pending) - self.assertIsNotNone(x._free_list) - self.assertIsNotNone(x._fds) - self.assertEqual( - x._socket_action, x._multi.socket_action, - ) - self.assertEqual(len(x._curls), x.max_clients) - self.assertTrue(x._timeout_check_tref) + assert x._multi is not None + assert x._pending is not None + assert x._free_list is not None + assert x._fds is not None + assert x._socket_action == x._multi.socket_action + assert len(x._curls) == x.max_clients + assert x._timeout_check_tref x._multi.setopt.assert_has_calls([ call(_pycurl.M_TIMERFUNCTION, x._set_timeout), @@ -59,7 +58,7 @@ class test_CurlClient(HubCase): x._set_timeout = Mock(name='_set_timeout') request = Mock(name='request') x.add_request(request) - self.assertIn(request, x._pending) + assert request in x._pending x._process_queue.assert_called_with() x._set_timeout.assert_called_with(0) @@ -73,7 +72,7 @@ class test_CurlClient(HubCase): x._fds[fd] = fd x._handle_socket(_pycurl.POLL_REMOVE, fd, x._multi, None, _pycurl) hub.remove.assert_called_with(fd) - self.assertNotIn(fd, x._fds) + assert fd not in x._fds x._handle_socket(_pycurl.POLL_REMOVE, fd, x._multi, None, _pycurl) # POLL_IN @@ -83,20 +82,20 @@ class test_CurlClient(HubCase): x._handle_socket(_pycurl.POLL_IN, fd, x._multi, None, _pycurl) hub.remove.assert_has_calls([call(fd)]) hub.add_reader.assert_called_with(fd, x.on_readable, fd) - self.assertEqual(x._fds[fd], READ) + assert x._fds[fd] == READ # POLL_OUT hub = x.hub = Mock(name='hub') x._handle_socket(_pycurl.POLL_OUT, fd, x._multi, None, _pycurl) hub.add_writer.assert_called_with(fd, x.on_writable, fd) - self.assertEqual(x._fds[fd], WRITE) + assert x._fds[fd] == WRITE # POLL_INOUT hub = x.hub = Mock(name='hub') x._handle_socket(_pycurl.POLL_INOUT, fd, x._multi, None, _pycurl) hub.add_reader.assert_called_with(fd, x.on_readable, fd) hub.add_writer.assert_called_with(fd, x.on_writable, fd) - self.assertEqual(x._fds[fd], READ | WRITE) + assert x._fds[fd] == READ | WRITE # UNKNOWN EVENT hub = x.hub = Mock(name='hub') diff --git a/kombu/tests/async/http/test_http.py b/t/unit/async/http/test_http.py index 6177a05c..9e279142 100644 --- a/kombu/tests/async/http/test_http.py +++ b/t/unit/async/http/test_http.py @@ -1,38 +1,42 @@ from __future__ import absolute_import, unicode_literals +import pytest + from io import BytesIO from vine import promise +from case import Mock, skip + from kombu.async import http from kombu.async.http.base import BaseClient, normalize_header from kombu.exceptions import HttpError -from kombu.tests.case import HubCase, Mock, PromiseMock, skip +from t.mocks import PromiseMock -class test_Headers(HubCase): +class test_Headers: def test_normalize(self): - self.assertEqual(normalize_header('accept-encoding'), - 'Accept-Encoding') + assert normalize_header('accept-encoding') == 'Accept-Encoding' -class test_Request(HubCase): +@pytest.mark.usefixtures('hub') +class test_Request: def test_init(self): x = http.Request('http://foo', method='POST') - self.assertEqual(x.url, 'http://foo') - self.assertEqual(x.method, 'POST') + assert x.url == 'http://foo' + assert x.method == 'POST' x = http.Request('x', max_redirects=100) - self.assertEqual(x.max_redirects, 100) + assert x.max_redirects == 100 - self.assertIsInstance(x.headers, http.Headers) + assert isinstance(x.headers, http.Headers) h = http.Headers() x = http.Request('x', headers=h) - self.assertIs(x.headers, h) - self.assertIsInstance(x.on_ready, promise) + assert x.headers is h + assert isinstance(x.on_ready, promise) def test_then(self): callback = PromiseMock(name='callback') @@ -43,23 +47,24 @@ class test_Request(HubCase): callback.assert_called_with(1) -class test_Response(HubCase): +@pytest.mark.usefixtures('hub') +class test_Response: def test_init(self): req = http.Request('http://foo') r = http.Response(req, 200) - self.assertEqual(r.status, 'OK') - self.assertEqual(r.effective_url, 'http://foo') + assert r.status == 'OK' + assert r.effective_url == 'http://foo' r.raise_for_error() def test_raise_for_error(self): req = http.Request('http://foo') r = http.Response(req, 404) - self.assertEqual(r.status, 'Not Found') - self.assertTrue(r.error) + assert r.status == 'Not Found' + assert r.error - with self.assertRaises(HttpError): + with pytest.raises(HttpError): r.raise_for_error() def test_get_body(self): @@ -68,21 +73,25 @@ class test_Response(HubCase): req.buffer.write(b'hello') rn = http.Response(req, 200, buffer=None) - self.assertIsNone(rn.body) + assert rn.body is None r = http.Response(req, 200, buffer=req.buffer) - self.assertIsNone(r._body) - self.assertEqual(r.body, b'hello') - self.assertEqual(r._body, b'hello') - self.assertEqual(r.body, b'hello') + assert r._body is None + assert r.body == b'hello' + assert r._body == b'hello' + assert r.body == b'hello' + +class test_BaseClient: -class test_BaseClient(HubCase): + @pytest.fixture(autouse=True) + def setup_hub(self, hub): + self.hub = hub def test_init(self): c = BaseClient(Mock(name='hub')) - self.assertTrue(c.hub) - self.assertTrue(c._header_parser) + assert c.hub + assert c._header_parser def test_perform(self): c = BaseClient(Mock(name='hub')) @@ -90,7 +99,7 @@ class test_BaseClient(HubCase): c.perform('http://foo') c.add_request.assert_called() - self.assertIsInstance(c.add_request.call_args[0][0], http.Request) + assert isinstance(c.add_request.call_args[0][0], http.Request) req = http.Request('http://bar') c.perform(req) @@ -98,7 +107,7 @@ class test_BaseClient(HubCase): def test_add_request(self): c = BaseClient(Mock(name='hub')) - with self.assertRaises(NotImplementedError): + with pytest.raises(NotImplementedError): c.add_request(Mock(name='request')) def test_header_parser(self): @@ -109,23 +118,21 @@ class test_BaseClient(HubCase): c.on_header(headers, 'HTTP/1.1') c.on_header(headers, 'x-foo-bar: 123') c.on_header(headers, 'People: George Costanza') - self.assertEqual(headers._prev_key, 'People') + assert headers._prev_key == 'People' c.on_header(headers, ' Jerry Seinfeld') c.on_header(headers, ' Elaine Benes') c.on_header(headers, ' Cosmo Kramer') - self.assertFalse(headers.complete) + assert not headers.complete c.on_header(headers, '') - self.assertTrue(headers.complete) + assert headers.complete - with self.assertRaises(KeyError): + with pytest.raises(KeyError): parser.throw(KeyError('foo')) c.on_header(headers, '') - self.assertEqual(headers['X-Foo-Bar'], '123') - self.assertEqual( - headers['People'], - 'George Costanza Jerry Seinfeld Elaine Benes Cosmo Kramer', - ) + assert headers['X-Foo-Bar'] == '123' + assert (headers['People'] == + 'George Costanza Jerry Seinfeld Elaine Benes Cosmo Kramer') def test_close(self): BaseClient(Mock(name='hub')).close() @@ -140,11 +147,11 @@ class test_BaseClient(HubCase): @skip.if_pypy() @skip.unless_module('pycurl') -class test_Client(HubCase): +class test_Client: - def test_get_client(self): + def test_get_client(self, hub): client = http.get_client() - self.assertIs(client.hub, self.hub) - client2 = http.get_client(self.hub) - self.assertIs(client2, client) - self.assertIs(client2.hub, self.hub) + assert client.hub is hub + client2 = http.get_client(hub) + assert client2 is client + assert client2.hub is hub diff --git a/kombu/tests/async/test_hub.py b/t/unit/async/test_hub.py index 94cf6b04..961de8fe 100644 --- a/kombu/tests/async/test_hub.py +++ b/t/unit/async/test_hub.py @@ -1,7 +1,9 @@ from __future__ import absolute_import, unicode_literals import errno +import pytest +from case import Mock, call, patch from vine import promise from kombu.async import hub as _hub @@ -13,8 +15,6 @@ from kombu.async.hub import ( ) from kombu.async.semaphore import DummyLock, LaxBoundedSemaphore -from kombu.tests.case import Case, Mock, call, patch - class File(object): @@ -33,63 +33,60 @@ class File(object): return hash(self.fd) -class test_DummyLock(Case): - - def test_context(self): - mutex = DummyLock() - with mutex: - pass +def test_DummyLock(): + with DummyLock(): + pass -class test_LaxBoundedSemaphore(Case): +class test_LaxBoundedSemaphore: def test_acquire_release(self): x = LaxBoundedSemaphore(2) c1 = Mock() x.acquire(c1, 1) - self.assertEqual(x.value, 1) + assert x.value == 1 c1.assert_called_with(1) c2 = Mock() x.acquire(c2, 2) - self.assertEqual(x.value, 0) + assert x.value == 0 c2.assert_called_with(2) c3 = Mock() x.acquire(c3, 3) - self.assertEqual(x.value, 0) + assert x.value == 0 c3.assert_not_called() x.release() - self.assertEqual(x.value, 0) + assert x.value == 0 x.release() - self.assertEqual(x.value, 1) + assert x.value == 1 x.release() - self.assertEqual(x.value, 2) + assert x.value == 2 c3.assert_called_with(3) def test_repr(self): - self.assertTrue(repr(LaxBoundedSemaphore(2))) + assert repr(LaxBoundedSemaphore(2)) def test_bounded(self): x = LaxBoundedSemaphore(2) for i in range(100): x.release() - self.assertEqual(x.value, 2) + assert x.value == 2 def test_grow_shrink(self): x = LaxBoundedSemaphore(1) - self.assertEqual(x.initial_value, 1) + assert x.initial_value == 1 cb1 = Mock() x.acquire(cb1, 1) cb1.assert_called_with(1) - self.assertEqual(x.value, 0) + assert x.value == 0 cb2 = Mock() x.acquire(cb2, 2) cb2.assert_not_called() - self.assertEqual(x.value, 0) + assert x.value == 0 cb3 = Mock() x.acquire(cb3, 3) @@ -98,39 +95,39 @@ class test_LaxBoundedSemaphore(Case): x.grow(2) cb2.assert_called_with(2) cb3.assert_called_with(3) - self.assertEqual(x.value, 2) - self.assertEqual(x.initial_value, 3) + assert x.value == 2 + assert x.initial_value == 3 - self.assertFalse(x._waiting) + assert not x._waiting x.grow(3) for i in range(x.initial_value): - self.assertTrue(x.acquire(Mock())) - self.assertFalse(x.acquire(Mock())) + assert x.acquire(Mock()) + assert not x.acquire(Mock()) x.clear() x.shrink(3) for i in range(x.initial_value): - self.assertTrue(x.acquire(Mock())) - self.assertFalse(x.acquire(Mock())) - self.assertEqual(x.value, 0) + assert x.acquire(Mock()) + assert not x.acquire(Mock()) + assert x.value == 0 for i in range(100): x.release() - self.assertEqual(x.value, x.initial_value) + assert x.value == x.initial_value def test_clear(self): x = LaxBoundedSemaphore(10) for i in range(11): x.acquire(Mock()) - self.assertTrue(x._waiting) - self.assertEqual(x.value, 0) + assert x._waiting + assert x.value == 0 x.clear() - self.assertFalse(x._waiting) - self.assertEqual(x.value, x.initial_value) + assert not x._waiting + assert x.value == x.initial_value -class test_Utils(Case): +class test_Utils: def setup(self): self._prev_loop = get_event_loop() @@ -140,23 +137,23 @@ class test_Utils(Case): def test_get_set_event_loop(self): set_event_loop(None) - self.assertIsNone(_hub._current_loop) - self.assertIsNone(get_event_loop()) + assert _hub._current_loop is None + assert get_event_loop() is None hub = Hub() set_event_loop(hub) - self.assertIs(_hub._current_loop, hub) - self.assertIs(get_event_loop(), hub) + assert _hub._current_loop is hub + assert get_event_loop() is hub def test_dummy_context(self): with _dummy_context(): pass def test_raise_stop_error(self): - with self.assertRaises(Stop): + with pytest.raises(Stop): _raise_stop_error() -class test_Hub(Case): +class test_Hub: def setup(self): self.hub = Hub() @@ -179,7 +176,7 @@ class test_Hub(Case): poller = self.hub.poller = Mock(name='poller') self.hub._close_poller() poller.close.assert_called_with() - self.assertIsNone(self.hub.poller) + assert self.hub.poller is None def test_stop(self): self.hub.call_soon = Mock(name='call_soon') @@ -191,14 +188,14 @@ class test_Hub(Case): callback = Mock(name='callback') ret = self.hub.call_soon(callback, 1, 2, 3) promise.assert_called_with(callback, (1, 2, 3)) - self.assertIn(promise(), self.hub._ready) - self.assertIs(ret, promise()) + assert promise() in self.hub._ready + assert ret is promise() def test_call_soon__promise_argument(self): callback = promise(Mock(name='callback'), (1, 2, 3)) ret = self.hub.call_soon(callback) - self.assertIs(ret, callback) - self.assertIn(ret, self.hub._ready) + assert ret is callback + assert ret in self.hub._ready def test_call_later(self): callback = Mock(name='callback') @@ -213,24 +210,24 @@ class test_Hub(Case): self.hub.timer.call_at.assert_called_with(21231122, callback, (1, 2)) def test_repr(self): - self.assertTrue(repr(self.hub)) + assert repr(self.hub) def test_repr_flag(self): - self.assertEqual(repr_flag(READ), 'R') - self.assertEqual(repr_flag(WRITE), 'W') - self.assertEqual(repr_flag(ERR), '!') - self.assertEqual(repr_flag(READ | WRITE), 'RW') - self.assertEqual(repr_flag(READ | ERR), 'R!') - self.assertEqual(repr_flag(WRITE | ERR), 'W!') - self.assertEqual(repr_flag(READ | WRITE | ERR), 'RW!') + assert repr_flag(READ) == 'R' + assert repr_flag(WRITE) == 'W' + assert repr_flag(ERR) == '!' + assert repr_flag(READ | WRITE) == 'RW' + assert repr_flag(READ | ERR) == 'R!' + assert repr_flag(WRITE | ERR) == 'W!' + assert repr_flag(READ | WRITE | ERR) == 'RW!' def test_repr_callback_rcb(self): def f(): pass - self.assertEqual(_rcb(f), f.__name__) - self.assertEqual(_rcb('foo'), 'foo') + assert _rcb(f) == f.__name__ + assert _rcb('foo') == 'foo' @patch('kombu.async.hub.poll') def test_start_stop(self, poll): @@ -245,14 +242,12 @@ class test_Hub(Case): def test_fire_timers(self): self.hub.timer = Mock() self.hub.timer._queue = [] - self.assertEqual( - self.hub.fire_timers(min_delay=42.324, max_delay=32.321), - 32.321, - ) + assert self.hub.fire_timers( + min_delay=42.324, max_delay=32.321) == 32.321 self.hub.timer._queue = [1] self.hub.scheduler = iter([(3.743, None)]) - self.assertEqual(self.hub.fire_timers(), 3.743) + assert self.hub.fire_timers() == 3.743 e1, e2, e3 = Mock(), Mock(), Mock() entries = [e1, e2, e3] @@ -267,21 +262,19 @@ class test_Hub(Case): yield 3.982, None self.hub.scheduler = se() - self.assertEqual(self.hub.fire_timers(max_timers=10), 3.982) + assert self.hub.fire_timers(max_timers=10) == 3.982 for E in [e3, e2, e1]: E.assert_called_with() reset() entries[:] = [Mock() for _ in range(11)] keep = list(entries) - self.assertEqual( - self.hub.fire_timers(max_timers=10, min_delay=1.13), - 1.13, - ) + assert self.hub.fire_timers( + max_timers=10, min_delay=1.13) == 1.13 for E in reversed(keep[1:]): E.assert_called_with() reset() - self.assertEqual(self.hub.fire_timers(max_timers=10), 3.982) + assert self.hub.fire_timers(max_timers=10) == 3.982 keep[0].assert_called_with() def test_fire_timers_raises(self): @@ -289,32 +282,32 @@ class test_Hub(Case): eback.side_effect = KeyError('foo') self.hub.timer = Mock() self.hub.scheduler = iter([(0, eback)]) - with self.assertRaises(KeyError): + with pytest.raises(KeyError): self.hub.fire_timers(propagate=(KeyError,)) eback.side_effect = ValueError('foo') self.hub.scheduler = iter([(0, eback)]) with patch('kombu.async.hub.logger') as logger: - with self.assertRaises(StopIteration): + with pytest.raises(StopIteration): self.hub.fire_timers() logger.error.assert_called() eback.side_effect = MemoryError('foo') self.hub.scheduler = iter([(0, eback)]) - with self.assertRaises(MemoryError): + with pytest.raises(MemoryError): self.hub.fire_timers() eback.side_effect = OSError() eback.side_effect.errno = errno.ENOMEM self.hub.scheduler = iter([(0, eback)]) - with self.assertRaises(OSError): + with pytest.raises(OSError): self.hub.fire_timers() eback.side_effect = OSError() eback.side_effect.errno = errno.ENOENT self.hub.scheduler = iter([(0, eback)]) with patch('kombu.async.hub.logger') as logger: - with self.assertRaises(StopIteration): + with pytest.raises(StopIteration): self.hub.fire_timers() logger.error.assert_called() @@ -322,7 +315,7 @@ class test_Hub(Case): self.hub.poller = Mock(name='hub.poller') self.hub.poller.register.side_effect = ValueError() self.hub._discard = Mock(name='hub.discard') - with self.assertRaises(ValueError): + with pytest.raises(ValueError): self.hub.add(2, Mock(), READ) self.hub._discard.assert_called_with(2) @@ -331,34 +324,34 @@ class test_Hub(Case): self.hub.add(2, Mock(), READ) self.hub.add(2, Mock(), WRITE) self.hub.remove_reader(2) - self.assertNotIn(2, self.hub.readers) - self.assertIn(2, self.hub.writers) + assert 2 not in self.hub.readers + assert 2 in self.hub.writers def test_remove_reader__not_writeable(self): self.hub.poller = Mock(name='hub.poller') self.hub.add(2, Mock(), READ) self.hub.remove_reader(2) - self.assertNotIn(2, self.hub.readers) + assert 2 not in self.hub.readers def test_remove_writer(self): self.hub.poller = Mock(name='hub.poller') self.hub.add(2, Mock(), READ) self.hub.add(2, Mock(), WRITE) self.hub.remove_writer(2) - self.assertIn(2, self.hub.readers) - self.assertNotIn(2, self.hub.writers) + assert 2 in self.hub.readers + assert 2 not in self.hub.writers def test_remove_writer__not_readable(self): self.hub.poller = Mock(name='hub.poller') self.hub.add(2, Mock(), WRITE) self.hub.remove_writer(2) - self.assertNotIn(2, self.hub.writers) + assert 2 not in self.hub.writers def test_add__consolidate(self): self.hub.poller = Mock(name='hub.poller') self.hub.add(2, Mock(), WRITE, consolidate=True) - self.assertIn(2, self.hub.consolidate) - self.assertIsNone(self.hub.writers[2]) + assert 2 in self.hub.consolidate + assert self.hub.writers[2] is None @patch('kombu.async.hub.logger') def test_on_callback_error(self, logger): @@ -368,8 +361,8 @@ class test_Hub(Case): def test_loop_property(self): self.hub._loop = None self.hub.create_loop = Mock(name='hub.create_loop') - self.assertIs(self.hub.loop, self.hub.create_loop()) - self.assertIs(self.hub._loop, self.hub.create_loop()) + assert self.hub.loop is self.hub.create_loop() + assert self.hub._loop is self.hub.create_loop() def test_run_forever(self): self.hub.run_once = Mock(name='hub.run_once') @@ -380,7 +373,7 @@ class test_Hub(Case): self.hub._loop = iter([1]) self.hub.run_once() self.hub.run_once() - self.assertIsNone(self.hub._loop) + assert self.hub._loop is None def test_repr_active(self): self.hub.readers = {1: Mock(), 2: Mock()} @@ -388,7 +381,7 @@ class test_Hub(Case): for value in list( self.hub.readers.values()) + list(self.hub.writers.values()): value.__name__ = 'mock' - self.assertTrue(self.hub.repr_active()) + assert self.hub.repr_active() def test_repr_events(self): self.hub.readers = {6: Mock(), 7: Mock(), 8: Mock()} @@ -396,24 +389,24 @@ class test_Hub(Case): for value in list( self.hub.readers.values()) + list(self.hub.writers.values()): value.__name__ = 'mock' - self.assertTrue(self.hub.repr_events([ + assert self.hub.repr_events([ (6, READ), (7, ERR), (8, READ | ERR), (9, WRITE), (10, 13213), - ])) + ]) def test_callback_for(self): reader, writer = Mock(), Mock() self.hub.readers = {6: reader} self.hub.writers = {7: writer} - self.assertEqual(callback_for(self.hub, 6, READ), reader) - self.assertEqual(callback_for(self.hub, 7, WRITE), writer) - with self.assertRaises(KeyError): + assert callback_for(self.hub, 6, READ) == reader + assert callback_for(self.hub, 7, WRITE) == writer + with pytest.raises(KeyError): callback_for(self.hub, 6, WRITE) - self.assertEqual(callback_for(self.hub, 6, WRITE, 'foo'), 'foo') + assert callback_for(self.hub, 6, WRITE, 'foo') == 'foo' def test_add_remove_readers(self): P = self.hub.poller = Mock() @@ -428,13 +421,13 @@ class test_Hub(Case): call(11, self.hub.READ | self.hub.ERR), ], any_order=True) - self.assertEqual(self.hub.readers[10], (read_A, (10,))) - self.assertEqual(self.hub.readers[11], (read_B, (11,))) + assert self.hub.readers[10] == (read_A, (10,)) + assert self.hub.readers[11] == (read_B, (11,)) self.hub.remove(10) - self.assertNotIn(10, self.hub.readers) + assert 10 not in self.hub.readers self.hub.remove(File(11)) - self.assertNotIn(11, self.hub.readers) + assert 11 not in self.hub.readers P.unregister.assert_has_calls([ call(10), call(11), ]) @@ -463,13 +456,13 @@ class test_Hub(Case): call(21, self.hub.WRITE), ], any_order=True) - self.assertEqual(self.hub.writers[20], (write_A, ())) - self.assertEqual(self.hub.writers[21], (write_B, ())) + assert self.hub.writers[20], (write_A == ()) + assert self.hub.writers[21], (write_B == ()) self.hub.remove(20) - self.assertNotIn(20, self.hub.writers) + assert 20 not in self.hub.writers self.hub.remove(File(21)) - self.assertNotIn(21, self.hub.writers) + assert 21 not in self.hub.writers P.unregister.assert_has_calls([ call(20), call(21), ]) @@ -488,13 +481,13 @@ class test_Hub(Case): write_B = Mock() self.hub.add_writer(20, write_A) self.hub.add_writer(File(21), write_B) - self.assertTrue(self.hub.readers) - self.assertTrue(self.hub.writers) + assert self.hub.readers + assert self.hub.writers finally: assert self.hub.poller self.hub.close() - self.assertFalse(self.hub.readers) - self.assertFalse(self.hub.writers) + assert not self.hub.readers + assert not self.hub.writers P.unregister.assert_has_calls([ call(10), call(11), call(20), call(21), @@ -504,7 +497,7 @@ class test_Hub(Case): def test_scheduler_property(self): hub = Hub(timer=[1, 2, 3]) - self.assertEqual(list(hub.scheduler), [1, 2, 3]) + assert list(hub.scheduler), [1, 2 == 3] def test_loop__tick_callbacks(self): self.hub._ready = Mock(name='_ready') @@ -512,7 +505,7 @@ class test_Hub(Case): ticks = [Mock(name='cb1'), Mock(name='cb2')] self.hub.on_tick = list(ticks) - with self.assertRaises(RuntimeError): + with pytest.raises(RuntimeError): next(self.hub.loop) ticks[0].assert_called_once_with() @@ -528,7 +521,7 @@ class test_Hub(Case): self.hub.call_soon(cb) self.hub._ready.add(None) - with self.assertRaises(RuntimeError): + with pytest.raises(RuntimeError): next(self.hub.loop) callbacks[0].assert_called_once_with() diff --git a/kombu/tests/async/test_semaphore.py b/t/unit/async/test_semaphore.py index 703acf61..3db792fa 100644 --- a/kombu/tests/async/test_semaphore.py +++ b/t/unit/async/test_semaphore.py @@ -2,10 +2,8 @@ from __future__ import absolute_import, unicode_literals from kombu.async.semaphore import LaxBoundedSemaphore -from kombu.tests.case import Case - -class test_LaxBoundedSemaphore(Case): +class test_LaxBoundedSemaphore: def test_over_release(self): x = LaxBoundedSemaphore(2) @@ -17,29 +15,29 @@ class test_LaxBoundedSemaphore(Case): x.release() x.acquire(calls.append, 'y') - self.assertEqual(calls, [1, 2, 3, 4]) + assert calls, [1, 2, 3 == 4] for i in range(30): x.release() - self.assertEqual(calls, list(range(1, 21)) + ['x', 'y']) - self.assertEqual(x.value, x.initial_value) + assert calls, list(range(1, 21)) + ['x' == 'y'] + assert x.value == x.initial_value calls[:] = [] for i in range(1, 11): x.acquire(calls.append, i) for i in range(1, 11): x.release() - self.assertEqual(calls, list(range(1, 11))) + assert calls, list(range(1 == 11)) calls[:] = [] - self.assertEqual(x.value, x.initial_value) + assert x.value == x.initial_value x.acquire(calls.append, 'x') - self.assertEqual(x.value, 1) + assert x.value == 1 x.acquire(calls.append, 'y') - self.assertEqual(x.value, 0) + assert x.value == 0 x.release() - self.assertEqual(x.value, 1) + assert x.value == 1 x.release() - self.assertEqual(x.value, 2) + assert x.value == 2 x.release() - self.assertEqual(x.value, 2) + assert x.value == 2 diff --git a/kombu/tests/async/test_timer.py b/t/unit/async/test_timer.py index f52fc93e..589ac54b 100644 --- a/kombu/tests/async/test_timer.py +++ b/t/unit/async/test_timer.py @@ -1,50 +1,46 @@ from __future__ import absolute_import, unicode_literals +import pytest + from datetime import datetime -from kombu.five import bytes_if_py2 +from case import Mock, patch from kombu.async.timer import Entry, Timer, to_timestamp - -from kombu.tests.case import Case, Mock, mock, patch +from kombu.five import bytes_if_py2 -class test_to_timestamp(Case): +class test_to_timestamp: def test_timestamp(self): - self.assertIs(to_timestamp(3.13), 3.13) + assert to_timestamp(3.13) is 3.13 def test_datetime(self): - self.assertTrue(to_timestamp(datetime.utcnow())) + assert to_timestamp(datetime.utcnow()) -class test_Entry(Case): +class test_Entry: def test_call(self): - scratch = [None] - - def timed(x, y, moo='foo'): - scratch[0] = (x, y, moo) - - tref = Entry(timed, (4, 4), {'moo': 'baz'}) + fun = Mock(name='fun') + tref = Entry(fun, (4, 4), {'moo': 'baz'}) tref() - - self.assertTupleEqual(scratch[0], (4, 4, 'baz')) + fun.assert_called_with(4, 4, moo='baz') def test_cancel(self): tref = Entry(lambda x: x, (1,), {}) - self.assertFalse(tref.canceled) - self.assertFalse(tref.cancelled) + assert not tref.canceled + assert not tref.cancelled tref.cancel() - self.assertTrue(tref.canceled) - self.assertTrue(tref.cancelled) + assert tref.canceled + assert tref.cancelled def test_repr(self): tref = Entry(lambda x: x(1,), {}) - self.assertTrue(repr(tref)) + assert repr(tref) def test_hash(self): - self.assertTrue(hash(Entry(lambda: None))) + assert hash(Entry(lambda: None)) def test_ordering(self): # we don't care about results, just that it's possible @@ -56,11 +52,11 @@ class test_Entry(Case): def test_eq(self): x = Entry(lambda x: 1) y = Entry(lambda x: 1) - self.assertEqual(x, x) - self.assertNotEqual(x, y) + assert x == x + assert x != y -class test_Timer(Case): +class test_Timer: def test_enter_exit(self): x = Timer() @@ -77,14 +73,11 @@ class test_Timer(Case): x.cancel(tref) tref.cancel.assert_called_with() - self.assertIs(x.schedule, x) + assert x.schedule is x def test_handle_error(self): from datetime import datetime - scratch = [None] - - def on_error(exc_info): - scratch[0] = exc_info + on_error = Mock(name='on_error') s = Timer(on_error=on_error) @@ -94,11 +87,12 @@ class test_Timer(Case): eta=datetime.now()) s.enter_at(Entry(lambda: None, (), {}), eta=None) s.on_error = None - with self.assertRaises(OverflowError): + with pytest.raises(OverflowError): s.enter_at(Entry(lambda: None, (), {}), eta=datetime.now()) - exc = scratch[0] - self.assertIsInstance(exc, OverflowError) + on_error.assert_called_once() + exc = on_error.call_args[0][0] + assert isinstance(exc, OverflowError) def test_call_repeatedly(self): t = Timer() @@ -109,20 +103,20 @@ class test_Timer(Case): myfun.__name__ = bytes_if_py2('myfun') t.call_repeatedly(0.03, myfun) - self.assertEqual(t.schedule.enter_after.call_count, 1) + assert t.schedule.enter_after.call_count == 1 args1, _ = t.schedule.enter_after.call_args_list[0] sec1, tref1, _ = args1 - self.assertEqual(sec1, 0.03) + assert sec1 == 0.03 tref1() - self.assertEqual(t.schedule.enter_after.call_count, 2) + assert t.schedule.enter_after.call_count == 2 args2, _ = t.schedule.enter_after.call_args_list[1] sec2, tref2, _ = args2 - self.assertEqual(sec2, 0.03) + assert sec2 == 0.03 tref2.canceled = True tref2() - self.assertEqual(t.schedule.enter_after.call_count, 2) + assert t.schedule.enter_after.call_count == 2 finally: t.stop() @@ -137,8 +131,7 @@ class test_Timer(Case): t.schedule.apply_entry(fun) logger.error.assert_called() - @mock.stdouts - def test_apply_entry_error_not_handled(self, stdout, stderr): + def test_apply_entry_error_not_handled(self, stdouts): t = Timer() t.schedule.on_error = Mock() @@ -146,7 +139,7 @@ class test_Timer(Case): fun.side_effect = ValueError() t.schedule.apply_entry(fun) fun.assert_called_with() - self.assertFalse(stderr.getvalue()) + assert not stdouts.stderr.getvalue() def test_enter_after(self): t = Timer() diff --git a/kombu/tests/case.py b/t/unit/case.py index 057229e1..057229e1 100644 --- a/kombu/tests/case.py +++ b/t/unit/case.py diff --git a/t/unit/test_clocks.py b/t/unit/test_clocks.py new file mode 100644 index 00000000..5ed30bf1 --- /dev/null +++ b/t/unit/test_clocks.py @@ -0,0 +1,89 @@ +from __future__ import absolute_import, unicode_literals + +import pickle + +from heapq import heappush +from time import time + +from case import Mock + +from kombu.clocks import LamportClock, timetuple + + +class test_LamportClock: + + def test_clocks(self): + c1 = LamportClock() + c2 = LamportClock() + + c1.forward() + c2.forward() + c1.forward() + c1.forward() + c2.adjust(c1.value) + assert c2.value == c1.value + 1 + assert repr(c1) + + c2_val = c2.value + c2.forward() + c2.forward() + c2.adjust(c1.value) + assert c2.value == c2_val + 2 + 1 + + c1.adjust(c2.value) + assert c1.value == c2.value + 1 + + def test_sort(self): + c = LamportClock() + pid1 = 'a.example.com:312' + pid2 = 'b.example.com:311' + + events = [] + + m1 = (c.forward(), pid1) + heappush(events, m1) + m2 = (c.forward(), pid2) + heappush(events, m2) + m3 = (c.forward(), pid1) + heappush(events, m3) + m4 = (30, pid1) + heappush(events, m4) + m5 = (30, pid2) + heappush(events, m5) + + assert str(c) == str(c.value) + + assert c.sort_heap(events) == m1 + assert c.sort_heap([m4, m5]) == m4 + assert c.sort_heap([m4, m5, m1]) == m4 + + +class test_timetuple: + + def test_repr(self): + x = timetuple(133, time(), 'id', Mock()) + assert repr(x) + + def test_pickleable(self): + x = timetuple(133, time(), 'id', 'obj') + assert pickle.loads(pickle.dumps(x)) == tuple(x) + + def test_order(self): + t1 = time() + t2 = time() + 300 # windows clock not reliable + a = timetuple(133, t1, 'A', 'obj') + b = timetuple(140, t1, 'A', 'obj') + assert a.__getnewargs__() + assert a.clock == 133 + assert a.timestamp == t1 + assert a.id == 'A' + assert a.obj == 'obj' + assert a <= b + assert b >= a + + assert (timetuple(134, time(), 'A', 'obj').__lt__(tuple()) is + NotImplemented) + assert timetuple(134, t2, 'A', 'obj') > timetuple(133, t1, 'A', 'obj') + assert timetuple(134, t1, 'B', 'obj') > timetuple(134, t1, 'A', 'obj') + assert (timetuple(None, t2, 'B', 'obj') > + timetuple(None, t1, 'A', 'obj')) diff --git a/kombu/tests/test_common.py b/t/unit/test_common.py index acbba0eb..8aee3d9e 100644 --- a/kombu/tests/test_common.py +++ b/t/unit/test_common.py @@ -1,8 +1,10 @@ from __future__ import absolute_import, unicode_literals +import pytest import socket from amqp import RecoverableConnectionError +from case import ContextMock, Mock, patch from kombu import common from kombu.common import ( @@ -12,64 +14,61 @@ from kombu.common import ( QoS, PREFETCH_COUNT_MAX, ) -from .case import Case, ContextMock, Mock, MockPool, patch +from t.mocks import MockPool -class test_ignore_errors(Case): +def test_ignore_errors(): + connection = Mock() + connection.channel_errors = (KeyError,) + connection.connection_errors = (KeyError,) - def test_ignored(self): - connection = Mock() - connection.channel_errors = (KeyError,) - connection.connection_errors = (KeyError,) + with ignore_errors(connection): + raise KeyError() - with ignore_errors(connection): - raise KeyError() - - def raising(): - raise KeyError() + def raising(): + raise KeyError() - ignore_errors(connection, raising) + ignore_errors(connection, raising) - connection.channel_errors = connection.connection_errors = \ - () + connection.channel_errors = connection.connection_errors = () - with self.assertRaises(KeyError): - with ignore_errors(connection): - raise KeyError() + with pytest.raises(KeyError): + with ignore_errors(connection): + raise KeyError() -class test_declaration_cached(Case): +class test_declaration_cached: def test_when_cached(self): chan = Mock() chan.connection.client.declared_entities = ['foo'] - self.assertTrue(declaration_cached('foo', chan)) + assert declaration_cached('foo', chan) def test_when_not_cached(self): chan = Mock() chan.connection.client.declared_entities = ['bar'] - self.assertFalse(declaration_cached('foo', chan)) + assert not declaration_cached('foo', chan) -class test_Broadcast(Case): +class test_Broadcast: def test_arguments(self): q = Broadcast(name='test_Broadcast') - self.assertTrue(q.name.startswith('bcast.')) - self.assertEqual(q.alias, 'test_Broadcast') - self.assertTrue(q.auto_delete) - self.assertEqual(q.exchange.name, 'test_Broadcast') - self.assertEqual(q.exchange.type, 'fanout') + assert q.name.startswith('bcast.') + assert q.alias == 'test_Broadcast' + assert q.auto_delete + assert q.exchange.name == 'test_Broadcast' + assert q.exchange.type == 'fanout' q = Broadcast('test_Broadcast', 'explicit_queue_name') - self.assertEqual(q.name, 'explicit_queue_name') - self.assertEqual(q.exchange.name, 'test_Broadcast') + assert q.name == 'explicit_queue_name' + assert q.exchange.name == 'test_Broadcast' q2 = q(Mock()) - self.assertEqual(q2.name, q.name) + assert q2.name == q.name -class test_maybe_declare(Case): +class test_maybe_declare: def test_cacheable(self): channel = Mock() @@ -82,16 +81,14 @@ class test_maybe_declare(Case): entity.channel = channel maybe_declare(entity, channel) - self.assertEqual(entity.declare.call_count, 1) - self.assertIn( - hash(entity), channel.connection.client.declared_entities, - ) + assert entity.declare.call_count == 1 + assert hash(entity) in channel.connection.client.declared_entities maybe_declare(entity, channel) - self.assertEqual(entity.declare.call_count, 1) + assert entity.declare.call_count == 1 entity.channel.connection = None - with self.assertRaises(RecoverableConnectionError): + with pytest.raises(RecoverableConnectionError): maybe_declare(entity) def test_binds_entities(self): @@ -116,10 +113,10 @@ class test_maybe_declare(Case): entity.channel = channel maybe_declare(entity, channel, retry=True) - self.assertTrue(channel.connection.client.ensure.call_count) + assert channel.connection.client.ensure.call_count -class test_replies(Case): +class test_replies: def test_send_reply(self): req = Mock() @@ -136,16 +133,18 @@ class test_replies(Case): producer.channel.connection.client.declared_entities = set() send_reply(exchange, req, {'hello': 'world'}, producer) - self.assertTrue(producer.publish.call_count) + assert producer.publish.call_count args = producer.publish.call_args - self.assertDictEqual(args[0][0], {'hello': 'world'}) - self.assertDictEqual(args[1], {'exchange': exchange, - 'routing_key': 'hello', - 'correlation_id': 'world', - 'serializer': 'json', - 'retry': False, - 'retry_policy': None, - 'content_encoding': 'binary'}) + assert args[0][0] == {'hello': 'world'} + assert args[1] == { + 'exchange': exchange, + 'routing_key': 'hello', + 'correlation_id': 'world', + 'serializer': 'json', + 'retry': False, + 'retry_policy': None, + 'content_encoding': 'binary', + } @patch('kombu.common.itermessages') def test_collect_replies_with_ack(self, itermessages): @@ -154,11 +153,11 @@ class test_replies(Case): itermessages.return_value = [(body, message)] it = collect_replies(conn, channel, queue, no_ack=False) m = next(it) - self.assertIs(m, body) + assert m is body itermessages.assert_called_with(conn, channel, queue, no_ack=False) message.ack.assert_called_with() - with self.assertRaises(StopIteration): + with pytest.raises(StopIteration): next(it) channel.after_reply_message_received.assert_called_with(queue.name) @@ -170,7 +169,7 @@ class test_replies(Case): itermessages.return_value = [(body, message)] it = collect_replies(conn, channel, queue) m = next(it) - self.assertIs(m, body) + assert m is body itermessages.assert_called_with(conn, channel, queue, no_ack=True) message.ack.assert_not_called() @@ -179,13 +178,12 @@ class test_replies(Case): conn, channel, queue = Mock(), Mock(), Mock() itermessages.return_value = [] it = collect_replies(conn, channel, queue) - with self.assertRaises(StopIteration): + with pytest.raises(StopIteration): next(it) - channel.after_reply_message_received.assert_not_called() -class test_insured(Case): +class test_insured: @patch('kombu.common.logger') def test_ensure_errback(self, logger): @@ -212,22 +210,21 @@ class test_insured(Case): conn, pool, fun, insured = self.get_insured_mocks() ret = common.insured(pool, fun, (2, 2), {'foo': 'bar'}) - self.assertEqual(ret, 'works') + assert ret == 'works' conn.ensure_connection.assert_called_with( errback=common._ensure_errback, ) insured.assert_called() i_args, i_kwargs = insured.call_args - self.assertTupleEqual(i_args, (2, 2)) - self.assertDictEqual(i_kwargs, {'foo': 'bar', - 'connection': conn}) + assert i_args == (2, 2) + assert i_kwargs == {'foo': 'bar', 'connection': conn} conn.autoretry.assert_called() ar_args, ar_kwargs = conn.autoretry.call_args - self.assertTupleEqual(ar_args, (fun, conn.default_channel)) - self.assertTrue(ar_kwargs.get('on_revive')) - self.assertTrue(ar_kwargs.get('errback')) + assert ar_args == (fun, conn.default_channel) + assert ar_kwargs.get('on_revive') + assert ar_kwargs.get('errback') def test_insured_custom_errback(self): conn, pool, fun, insured = self.get_insured_mocks() @@ -254,7 +251,7 @@ class MockConsumer(object): self.consumers.discard(self) -class test_itermessages(Case): +class test_itermessages: class MockConnection(object): should_raise_timeout = False @@ -274,9 +271,9 @@ class test_itermessages(Case): it = common.itermessages(conn, channel, 'q', limit=1) ret = next(it) - self.assertTupleEqual(ret, ('body', 'message')) + assert ret == ('body', 'message') - with self.assertRaises(StopIteration): + with pytest.raises(StopIteration): next(it) def test_when_raises_socket_timeout(self): @@ -287,7 +284,7 @@ class test_itermessages(Case): conn.Consumer = MockConsumer it = common.itermessages(conn, channel, 'q', limit=1) - with self.assertRaises(StopIteration): + with pytest.raises(StopIteration): next(it) @patch('kombu.common.deque') @@ -299,11 +296,11 @@ class test_itermessages(Case): conn.Consumer = MockConsumer it = common.itermessages(conn, channel, 'q', limit=1) - with self.assertRaises(StopIteration): + with pytest.raises(StopIteration): next(it) -class test_QoS(Case): +class test_QoS: class _QoS(QoS): def __init__(self, value): @@ -326,20 +323,20 @@ class test_QoS(Case): def test_qos_increment_decrement(self): qos = self._QoS(10) - self.assertEqual(qos.increment_eventually(), 11) - self.assertEqual(qos.increment_eventually(3), 14) - self.assertEqual(qos.increment_eventually(-30), 14) - self.assertEqual(qos.decrement_eventually(7), 7) - self.assertEqual(qos.decrement_eventually(), 6) + assert qos.increment_eventually() == 11 + assert qos.increment_eventually(3) == 14 + assert qos.increment_eventually(-30) == 14 + assert qos.decrement_eventually(7) == 7 + assert qos.decrement_eventually() == 6 def test_qos_disabled_increment_decrement(self): qos = self._QoS(0) - self.assertEqual(qos.increment_eventually(), 0) - self.assertEqual(qos.increment_eventually(3), 0) - self.assertEqual(qos.increment_eventually(-30), 0) - self.assertEqual(qos.decrement_eventually(7), 0) - self.assertEqual(qos.decrement_eventually(), 0) - self.assertEqual(qos.decrement_eventually(10), 0) + assert qos.increment_eventually() == 0 + assert qos.increment_eventually(3) == 0 + assert qos.increment_eventually(-30) == 0 + assert qos.decrement_eventually(7) == 0 + assert qos.decrement_eventually() == 0 + assert qos.decrement_eventually(10) == 0 def test_qos_thread_safe(self): qos = self._QoS(10) @@ -361,59 +358,59 @@ class test_QoS(Case): thread.join() threaded([add, add]) - self.assertEqual(qos.value, 2010) + assert qos.value == 2010 qos.value = 1000 threaded([add, sub]) # n = 2 - self.assertEqual(qos.value, 1000) + assert qos.value == 1000 def test_exceeds_short(self): qos = QoS(Mock(), PREFETCH_COUNT_MAX - 1) qos.update() - self.assertEqual(qos.value, PREFETCH_COUNT_MAX - 1) + assert qos.value == PREFETCH_COUNT_MAX - 1 qos.increment_eventually() - self.assertEqual(qos.value, PREFETCH_COUNT_MAX) + assert qos.value == PREFETCH_COUNT_MAX qos.increment_eventually() - self.assertEqual(qos.value, PREFETCH_COUNT_MAX + 1) + assert qos.value == PREFETCH_COUNT_MAX + 1 qos.decrement_eventually() - self.assertEqual(qos.value, PREFETCH_COUNT_MAX) + assert qos.value == PREFETCH_COUNT_MAX qos.decrement_eventually() - self.assertEqual(qos.value, PREFETCH_COUNT_MAX - 1) + assert qos.value == PREFETCH_COUNT_MAX - 1 def test_consumer_increment_decrement(self): mconsumer = Mock() qos = QoS(mconsumer.qos, 10) qos.update() - self.assertEqual(qos.value, 10) + assert qos.value == 10 mconsumer.qos.assert_called_with(prefetch_count=10) qos.decrement_eventually() qos.update() - self.assertEqual(qos.value, 9) + assert qos.value == 9 mconsumer.qos.assert_called_with(prefetch_count=9) qos.decrement_eventually() - self.assertEqual(qos.value, 8) + assert qos.value == 8 mconsumer.qos.assert_called_with(prefetch_count=9) - self.assertIn({'prefetch_count': 9}, mconsumer.qos.call_args) + assert {'prefetch_count': 9} in mconsumer.qos.call_args # Does not decrement 0 value qos.value = 0 qos.decrement_eventually() - self.assertEqual(qos.value, 0) + assert qos.value == 0 qos.increment_eventually() - self.assertEqual(qos.value, 0) + assert qos.value == 0 def test_consumer_decrement_eventually(self): mconsumer = Mock() qos = QoS(mconsumer.qos, 10) qos.decrement_eventually() - self.assertEqual(qos.value, 9) + assert qos.value == 9 qos.value = 0 qos.decrement_eventually() - self.assertEqual(qos.value, 0) + assert qos.value == 0 def test_set(self): mconsumer = Mock() qos = QoS(mconsumer.qos, 10) qos.set(12) - self.assertEqual(qos.prev, 12) + assert qos.prev == 12 qos.set(qos.prev) diff --git a/kombu/tests/test_compat.py b/t/unit/test_compat.py index c8abc85a..485625d0 100644 --- a/kombu/tests/test_compat.py +++ b/t/unit/test_compat.py @@ -1,13 +1,16 @@ from __future__ import absolute_import, unicode_literals +import pytest + +from case import Mock, patch + from kombu import Connection, Exchange, Queue from kombu import compat -from .case import Case, Mock, patch -from .mocks import Transport, Channel +from t.mocks import Transport, Channel -class test_misc(Case): +class test_misc: def test_iterconsume(self): @@ -27,11 +30,11 @@ class test_misc(Case): conn = MyConnection() consumer = Consumer() it = compat._iterconsume(conn, consumer) - self.assertEqual(next(it), 1) - self.assertTrue(consumer.active) + assert next(it) == 1 + assert consumer.active it2 = compat._iterconsume(conn, consumer, limit=10) - self.assertEqual(list(it2), [2, 3, 4, 5, 6, 7, 8, 9, 10, 11]) + assert list(it2), [2, 3, 4, 5, 6, 7, 8, 9, 10 == 11] def test_Queue_from_dict(self): defs = {'binding_key': 'foo.#', @@ -41,40 +44,40 @@ class test_misc(Case): 'auto_delete': False} q1 = Queue.from_dict('foo', **dict(defs)) - self.assertEqual(q1.name, 'foo') - self.assertEqual(q1.routing_key, 'foo.#') - self.assertEqual(q1.exchange.name, 'fooex') - self.assertEqual(q1.exchange.type, 'topic') - self.assertTrue(q1.durable) - self.assertTrue(q1.exchange.durable) - self.assertFalse(q1.auto_delete) - self.assertFalse(q1.exchange.auto_delete) + assert q1.name == 'foo' + assert q1.routing_key == 'foo.#' + assert q1.exchange.name == 'fooex' + assert q1.exchange.type == 'topic' + assert q1.durable + assert q1.exchange.durable + assert not q1.auto_delete + assert not q1.exchange.auto_delete q2 = Queue.from_dict('foo', **dict(defs, exchange_durable=False)) - self.assertTrue(q2.durable) - self.assertFalse(q2.exchange.durable) + assert q2.durable + assert not q2.exchange.durable q3 = Queue.from_dict('foo', **dict(defs, exchange_auto_delete=True)) - self.assertFalse(q3.auto_delete) - self.assertTrue(q3.exchange.auto_delete) + assert not q3.auto_delete + assert q3.exchange.auto_delete q4 = Queue.from_dict('foo', **dict(defs, queue_durable=False)) - self.assertFalse(q4.durable) - self.assertTrue(q4.exchange.durable) + assert not q4.durable + assert q4.exchange.durable q5 = Queue.from_dict('foo', **dict(defs, queue_auto_delete=True)) - self.assertTrue(q5.auto_delete) - self.assertFalse(q5.exchange.auto_delete) + assert q5.auto_delete + assert not q5.exchange.auto_delete - self.assertEqual(Queue.from_dict('foo', **dict(defs)), - Queue.from_dict('foo', **dict(defs))) + assert (Queue.from_dict('foo', **dict(defs)) == + Queue.from_dict('foo', **dict(defs))) -class test_Publisher(Case): +class test_Publisher: def setup(self): self.connection = Connection(transport=Transport) @@ -83,25 +86,25 @@ class test_Publisher(Case): pub = compat.Publisher(self.connection, exchange='test_Publisher_constructor', routing_key='rkey') - self.assertIsInstance(pub.backend, Channel) - self.assertEqual(pub.exchange.name, 'test_Publisher_constructor') - self.assertTrue(pub.exchange.durable) - self.assertFalse(pub.exchange.auto_delete) - self.assertEqual(pub.exchange.type, 'direct') + assert isinstance(pub.backend, Channel) + assert pub.exchange.name == 'test_Publisher_constructor' + assert pub.exchange.durable + assert not pub.exchange.auto_delete + assert pub.exchange.type == 'direct' pub2 = compat.Publisher(self.connection, exchange='test_Publisher_constructor2', routing_key='rkey', auto_delete=True, durable=False) - self.assertTrue(pub2.exchange.auto_delete) - self.assertFalse(pub2.exchange.durable) + assert pub2.exchange.auto_delete + assert not pub2.exchange.durable explicit = Exchange('test_Publisher_constructor_explicit', type='topic') pub3 = compat.Publisher(self.connection, exchange=explicit) - self.assertEqual(pub3.exchange, explicit) + assert pub3.exchange == explicit compat.Publisher(self.connection, exchange='test_Publisher_constructor3', @@ -112,7 +115,7 @@ class test_Publisher(Case): exchange='test_Publisher_send', routing_key='rkey') pub.send({'foo': 'bar'}) - self.assertIn('basic_publish', pub.backend) + assert 'basic_publish' in pub.backend pub.close() def test__enter__exit__(self): @@ -120,12 +123,12 @@ class test_Publisher(Case): exchange='test_Publisher_send', routing_key='rkey') x = pub.__enter__() - self.assertIs(x, pub) + assert x is pub x.__exit__() - self.assertTrue(pub._closed) + assert pub._closed -class test_Consumer(Case): +class test_Consumer: def setup(self): self.connection = Connection(transport=Transport) @@ -139,39 +142,39 @@ class test_Consumer(Case): def test_constructor(self, n='test_Consumer_constructor'): c = compat.Consumer(self.connection, queue=n, exchange=n, routing_key='rkey') - self.assertIsInstance(c.backend, Channel) + assert isinstance(c.backend, Channel) q = c.queues[0] - self.assertTrue(q.durable) - self.assertTrue(q.exchange.durable) - self.assertFalse(q.auto_delete) - self.assertFalse(q.exchange.auto_delete) - self.assertEqual(q.name, n) - self.assertEqual(q.exchange.name, n) + assert q.durable + assert q.exchange.durable + assert not q.auto_delete + assert not q.exchange.auto_delete + assert q.name == n + assert q.exchange.name == n c2 = compat.Consumer(self.connection, queue=n + '2', exchange=n + '2', routing_key='rkey', durable=False, auto_delete=True, exclusive=True) q2 = c2.queues[0] - self.assertFalse(q2.durable) - self.assertFalse(q2.exchange.durable) - self.assertTrue(q2.auto_delete) - self.assertTrue(q2.exchange.auto_delete) + assert not q2.durable + assert not q2.exchange.durable + assert q2.auto_delete + assert q2.exchange.auto_delete def test__enter__exit__(self, n='test__enter__exit__'): c = compat.Consumer(self.connection, queue=n, exchange=n, routing_key='rkey') x = c.__enter__() - self.assertIs(x, c) + assert x is c x.__exit__() - self.assertTrue(c._closed) + assert c._closed def test_revive(self, n='test_revive'): c = compat.Consumer(self.connection, queue=n, exchange=n) with self.connection.channel() as c2: c.revive(c2) - self.assertIs(c.backend, c2) + assert c.backend is c2 def test__iter__(self, n='test__iter__'): c = compat.Consumer(self.connection, queue=n, exchange=n) @@ -188,7 +191,7 @@ class test_Consumer(Case): def test_process_next(self, n='test_process_next'): c = compat.Consumer(self.connection, queue=n, exchange=n, routing_key='rkey') - with self.assertRaises(NotImplementedError): + with pytest.raises(NotImplementedError): c.process_next() c.close() @@ -201,14 +204,14 @@ class test_Consumer(Case): c = compat.Consumer(self.connection, queue=n, exchange=n, routing_key='rkey') c.discard_all() - self.assertIn('queue_purge', c.backend) + assert 'queue_purge' in c.backend def test_fetch(self, n='test_fetch'): c = compat.Consumer(self.connection, queue=n, exchange=n, routing_key='rkey') - self.assertIsNone(c.fetch()) - self.assertIsNone(c.fetch(no_ack=True)) - self.assertIn('basic_get', c.backend) + assert c.fetch() is None + assert c.fetch(no_ack=True) is None + assert 'basic_get' in c.backend callback_called = [False] @@ -217,16 +220,16 @@ class test_Consumer(Case): c.backend.to_deliver.append('42') payload = c.fetch().payload - self.assertEqual(payload, '42') + assert payload == '42' c.backend.to_deliver.append('46') c.register_callback(receive) - self.assertEqual(c.fetch(enable_callbacks=True).payload, '46') - self.assertTrue(callback_called[0]) + assert c.fetch(enable_callbacks=True).payload == '46' + assert callback_called[0] def test_discard_all_filterfunc_not_supported(self, n='xjf21j21'): c = compat.Consumer(self.connection, queue=n, exchange=n, routing_key='rkey') - with self.assertRaises(NotImplementedError): + with pytest.raises(NotImplementedError): c.discard_all(filterfunc=lambda x: x) c.close() @@ -240,7 +243,7 @@ class test_Consumer(Case): c = C(self.connection, queue=n, exchange=n, routing_key='rkey') - self.assertEqual(c.wait(10), list(range(10))) + assert c.wait(10) == list(range(10)) c.close() def test_iterqueue(self, n='test_iterqueue'): @@ -255,11 +258,11 @@ class test_Consumer(Case): c = C(self.connection, queue=n, exchange=n, routing_key='rkey') - self.assertEqual(list(c.iterqueue(limit=10)), list(range(10))) + assert list(c.iterqueue(limit=10)) == list(range(10)) c.close() -class test_ConsumerSet(Case): +class test_ConsumerSet: def setup(self): self.connection = Connection(transport=Transport) @@ -267,8 +270,8 @@ class test_ConsumerSet(Case): def test_providing_channel(self): chan = Mock(name='channel') cs = compat.ConsumerSet(self.connection, channel=chan) - self.assertTrue(cs._provided_channel) - self.assertIs(cs.backend, chan) + assert cs._provided_channel + assert cs.backend is chan cs.cancel = Mock(name='cancel') cs.close() @@ -287,7 +290,7 @@ class test_ConsumerSet(Case): with self.connection.channel() as c2: cs.revive(c2) - self.assertIs(cs.backend, c2) + assert cs.backend is c2 def test_constructor(self, prefix='0daf8h21'): dcon = {'%s.xyx' % prefix: {'exchange': '%s.xyx' % prefix, @@ -300,31 +303,31 @@ class test_ConsumerSet(Case): c = compat.ConsumerSet(self.connection, consumers=consumers) c2 = compat.ConsumerSet(self.connection, from_dict=dcon) - self.assertEqual(len(c.queues), 3) - self.assertEqual(len(c2.queues), 2) + assert len(c.queues) == 3 + assert len(c2.queues) == 2 c.add_consumer(compat.Consumer(self.connection, queue=prefix + 'xaxxxa', exchange=prefix + 'xaxxxa')) - self.assertEqual(len(c.queues), 4) + assert len(c.queues) == 4 for cq in c.queues: - self.assertIs(cq.channel, c.channel) + assert cq.channel is c.channel c2.add_consumer_from_dict( '%s.xxx' % prefix, exchange='%s.xxx' % prefix, routing_key='xxx', ) - self.assertEqual(len(c2.queues), 3) + assert len(c2.queues) == 3 for c2q in c2.queues: - self.assertIs(c2q.channel, c2.channel) + assert c2q.channel is c2.channel c.discard_all() - self.assertEqual(c.channel.called.count('queue_purge'), 4) + assert c.channel.called.count('queue_purge') == 4 c.consume() c.close() c2.close() - self.assertIn('basic_cancel', c.channel) - self.assertIn('close', c.channel) - self.assertIn('close', c2.channel) + assert 'basic_cancel' in c.channel + assert 'close' in c.channel + assert 'close' in c2.channel diff --git a/kombu/tests/test_compression.py b/t/unit/test_compression.py index d98da690..f3fc625e 100644 --- a/kombu/tests/test_compression.py +++ b/t/unit/test_compression.py @@ -2,41 +2,41 @@ from __future__ import absolute_import, unicode_literals import sys -from kombu import compression +from case import mock, skip -from .case import Case, mock, skip +from kombu import compression -class test_compression(Case): +class test_compression: @mock.mask_modules('bz2') def test_no_bz2(self): c = sys.modules.pop('kombu.compression') try: import kombu.compression - self.assertFalse(hasattr(kombu.compression, 'bz2')) + assert not hasattr(kombu.compression, 'bz2') finally: if c is not None: sys.modules['kombu.compression'] = c def test_encoders__gzip(self): - self.assertIn('application/x-gzip', compression.encoders()) + assert 'application/x-gzip' in compression.encoders() @skip.unless_module('bz2') def test_encoders__bz2(self): - self.assertIn('application/x-bz2', compression.encoders()) + assert 'application/x-bz2' in compression.encoders() def test_compress__decompress__zlib(self): text = b'The Quick Brown Fox Jumps Over The Lazy Dog' c, ctype = compression.compress(text, 'zlib') - self.assertNotEqual(text, c) + assert text != c d = compression.decompress(c, ctype) - self.assertEqual(d, text) + assert d == text @skip.unless_module('bz2') def test_compress__decompress__bzip2(self): text = b'The Brown Quick Fox Over The Lazy Dog Jumps' c, ctype = compression.compress(text, 'bzip2') - self.assertNotEqual(text, c) + assert text != c d = compression.decompress(c, ctype) - self.assertEqual(d, text) + assert d == text diff --git a/kombu/tests/test_connection.py b/t/unit/test_connection.py index 548263d8..576181cf 100644 --- a/kombu/tests/test_connection.py +++ b/t/unit/test_connection.py @@ -1,21 +1,23 @@ from __future__ import absolute_import, unicode_literals import pickle +import pytest import socket from copy import copy, deepcopy +from case import Mock, patch, skip + from kombu import Connection, Consumer, Producer, parse_url from kombu.connection import Resource from kombu.exceptions import OperationalError from kombu.five import items, range from kombu.utils.functional import lazy -from .case import Case, Mock, patch, skip -from .mocks import Transport +from t.mocks import Transport -class test_connection_utils(Case): +class test_connection_utils: def setup(self): self.url = 'amqp://user:pass@localhost:5672/my/vhost' @@ -31,106 +33,74 @@ class test_connection_utils(Case): def test_parse_url(self): result = parse_url(self.url) - self.assertDictEqual(result, self.expected) + assert result == self.expected def test_parse_generated_as_uri(self): conn = Connection(self.url) info = conn.info() for k, v in self.expected.items(): - self.assertEqual(info[k], v) + assert info[k] == v # by default almost the same- no password - self.assertEqual(conn.as_uri(), self.nopass) - self.assertEqual(conn.as_uri(include_password=True), self.url) + assert conn.as_uri() == self.nopass + assert conn.as_uri(include_password=True) == self.url @skip.unless_module('redis') def test_as_uri_when_prefix(self): conn = Connection('redis+socket:///var/spool/x/y/z/redis.sock') - self.assertEqual( - conn.as_uri(), 'redis+socket:///var/spool/x/y/z/redis.sock', - ) + assert conn.as_uri() == 'redis+socket:///var/spool/x/y/z/redis.sock' @skip.unless_module('pymongo') def test_as_uri_when_mongodb(self): x = Connection('mongodb://localhost') - self.assertTrue(x.as_uri()) + assert x.as_uri() def test_bogus_scheme(self): - with self.assertRaises(KeyError): + with pytest.raises(KeyError): Connection('bogus://localhost:7421').transport def assert_info(self, conn, **fields): info = conn.info() for field, expected in items(fields): - self.assertEqual(info[field], expected) - - def test_rabbitmq_example_urls(self): + assert info[field] == expected + + @pytest.mark.parametrize('url,expected', [ + ('amqp://user:pass@host:10000/vhost', + dict(userid='user', password='pass', hostname='host', + port=10000, virtual_host='vhost')), + ('amqp://user%61:%61pass@ho%61st:10000/v%2fhost', + dict(userid='usera', password='apass', hostname='hoast', + port=10000, virtual_host='v/host')), + ('amqp://', + dict(userid='guest', password='guest', hostname='localhost', + port=5672, virtual_host='/')), + ('amqp://:@/', + dict(userid='guest', password='guest', hostname='localhost', + port=5672, virtual_host='/')), + ('amqp://user@/', + dict(userid='user', password='guest', hostname='localhost', + port=5672, virtual_host='/')), + ('amqp://user:pass@/', + dict(userid='user', password='pass', hostname='localhost', + port=5672, virtual_host='/')), + ('amqp://host', + dict(userid='guest', password='guest', hostname='host', + port=5672, virtual_host='/')), + ('amqp://:10000', + dict(userid='guest', password='guest', hostname='localhost', + port=10000, virtual_host='/')), + ('amqp:///vhost', + dict(userid='guest', password='guest', hostname='localhost', + port=5672, virtual_host='vhost')), + ('amqp://host/', + dict(userid='guest', password='guest', hostname='host', + port=5672, virtual_host='/')), + ('amqp://host/%2f', + dict(userid='guest', password='guest', hostname='host', + port=5672, virtual_host='/')), + ]) + def test_rabbitmq_example_urls(self, url, expected): # see Appendix A of http://www.rabbitmq.com/uri-spec.html - - self.assert_info( - Connection('amqp://user:pass@host:10000/vhost'), - userid='user', password='pass', hostname='host', - port=10000, virtual_host='vhost', - ) - - self.assert_info( - Connection('amqp://user%61:%61pass@ho%61st:10000/v%2fhost'), - userid='usera', password='apass', hostname='hoast', - port=10000, virtual_host='v/host', - ) - - self.assert_info( - Connection('amqp://'), - userid='guest', password='guest', hostname='localhost', - port=5672, virtual_host='/', - ) - - self.assert_info( - Connection('amqp://:@/'), - userid='guest', password='guest', hostname='localhost', - port=5672, virtual_host='/', - ) - - self.assert_info( - Connection('amqp://user@/'), - userid='user', password='guest', hostname='localhost', - port=5672, virtual_host='/', - ) - - self.assert_info( - Connection('amqp://user:pass@/'), - userid='user', password='pass', hostname='localhost', - port=5672, virtual_host='/', - ) - - self.assert_info( - Connection('amqp://host'), - userid='guest', password='guest', hostname='host', - port=5672, virtual_host='/', - ) - - self.assert_info( - Connection('amqp://:10000'), - userid='guest', password='guest', hostname='localhost', - port=10000, virtual_host='/', - ) - - self.assert_info( - Connection('amqp:///vhost'), - userid='guest', password='guest', hostname='localhost', - port=5672, virtual_host='vhost', - ) - - self.assert_info( - Connection('amqp://host/'), - userid='guest', password='guest', hostname='host', - port=5672, virtual_host='/', - ) - - self.assert_info( - Connection('amqp://host/%2f'), - userid='guest', password='guest', hostname='host', - port=5672, virtual_host='/', - ) + self.assert_info(Connection(url), **expected) @skip.todo('urllib cannot parse ipv6 urls') def test_url_IPV6(self): @@ -143,10 +113,10 @@ class test_connection_utils(Case): def test_connection_copy(self): conn = Connection(self.url, alternates=['amqp://host']) clone = deepcopy(conn) - self.assertEqual(clone.alt, ['amqp://host']) + assert clone.alt == ['amqp://host'] -class test_Connection(Case): +class test_Connection: def setup(self): self.conn = Connection(port=5672, transport=Transport) @@ -154,24 +124,24 @@ class test_Connection(Case): def test_establish_connection(self): conn = self.conn conn.connect() - self.assertTrue(conn.connection.connected) - self.assertEqual(conn.host, 'localhost:5672') + assert conn.connection.connected + assert conn.host == 'localhost:5672' channel = conn.channel() - self.assertTrue(channel.open) - self.assertEqual(conn.drain_events(), 'event') + assert channel.open + assert conn.drain_events() == 'event' _connection = conn.connection conn.close() - self.assertFalse(_connection.connected) - self.assertIsInstance(conn.transport, Transport) + assert not _connection.connected + assert isinstance(conn.transport, Transport) def test_multiple_urls(self): conn1 = Connection('amqp://foo;amqp://bar') - self.assertEqual(conn1.hostname, 'foo') - self.assertListEqual(conn1.alt, ['amqp://foo', 'amqp://bar']) + assert conn1.hostname == 'foo' + assert conn1.alt == ['amqp://foo', 'amqp://bar'] conn2 = Connection(['amqp://foo', 'amqp://bar']) - self.assertEqual(conn2.hostname, 'foo') - self.assertListEqual(conn2.alt, ['amqp://foo', 'amqp://bar']) + assert conn2.hostname == 'foo' + assert conn2.alt == ['amqp://foo', 'amqp://bar'] def test_collect(self): connection = Connection('memory://') @@ -185,9 +155,9 @@ class test_Connection(Case): _close.assert_not_called() _collect.assert_called_with(uconn) connection.declared_entities.clear.assert_called_with() - self.assertIsNone(trans.client) - self.assertIsNone(connection._transport) - self.assertIsNone(connection._connection) + assert trans.client is None + assert connection._transport is None + assert connection._connection is None def test_collect_no_transport(self): connection = Connection('memory://') @@ -213,7 +183,7 @@ class test_Connection(Case): connection.collect() collect.assert_called_with(uconn) - self.assertIsNone(connection._transport) + assert connection._transport is None def test_uri_passthrough(self): transport = Mock(name='transport') @@ -222,17 +192,17 @@ class test_Connection(Case): transport.can_parse_url = True with patch('kombu.connection.parse_url') as parse_url: c = Connection('foo+mysql://some_host') - self.assertEqual(c.transport_cls, 'foo') + assert c.transport_cls == 'foo' parse_url.assert_not_called() - self.assertEqual(c.hostname, 'mysql://some_host') - self.assertTrue(c.as_uri().startswith('foo+')) + assert c.hostname == 'mysql://some_host' + assert c.as_uri().startswith('foo+') with patch('kombu.connection.parse_url') as parse_url: c = Connection('mysql://some_host', transport='foo') - self.assertEqual(c.transport_cls, 'foo') + assert c.transport_cls == 'foo' parse_url.assert_not_called() - self.assertEqual(c.hostname, 'mysql://some_host') + assert c.hostname == 'mysql://some_host' c = Connection('pyamqp+sqlite://some_host') - self.assertTrue(c.as_uri().startswith('pyamqp+')) + assert c.as_uri().startswith('pyamqp+') def test_default_ensure_callback(self): with patch('kombu.connection.logger') as logger: @@ -249,31 +219,31 @@ class test_Connection(Case): args = rot.call_args[0] cb = args[4] intervals = iter([1, 2, 3, 4, 5]) - self.assertEqual(cb(KeyError(), intervals, 0), 0) - self.assertEqual(cb(KeyError(), intervals, 1), 1) - self.assertEqual(cb(KeyError(), intervals, 2), 0) - self.assertEqual(cb(KeyError(), intervals, 3), 2) - self.assertEqual(cb(KeyError(), intervals, 4), 0) - self.assertEqual(cb(KeyError(), intervals, 5), 3) - self.assertEqual(cb(KeyError(), intervals, 6), 0) - self.assertEqual(cb(KeyError(), intervals, 7), 4) + assert cb(KeyError(), intervals, 0) == 0 + assert cb(KeyError(), intervals, 1) == 1 + assert cb(KeyError(), intervals, 2) == 0 + assert cb(KeyError(), intervals, 3) == 2 + assert cb(KeyError(), intervals, 4) == 0 + assert cb(KeyError(), intervals, 5) == 3 + assert cb(KeyError(), intervals, 6) == 0 + assert cb(KeyError(), intervals, 7) == 4 errback = Mock() c.ensure_connection(errback=errback) args = rot.call_args[0] cb = args[4] - self.assertEqual(cb(KeyError(), intervals, 0), 0) + assert cb(KeyError(), intervals, 0) == 0 errback.assert_called() def test_supports_heartbeats(self): c = Connection(transport=Mock) c.transport.implements.heartbeats = False - self.assertFalse(c.supports_heartbeats) + assert not c.supports_heartbeats def test_is_evented(self): c = Connection(transport=Mock) c.transport.implements.async = False - self.assertFalse(c.is_evented) + assert not c.is_evented def test_register_with_event_loop(self): c = Connection(transport=Mock) @@ -285,41 +255,41 @@ class test_Connection(Case): def test_manager(self): c = Connection(transport=Mock) - self.assertIs(c.manager, c.transport.manager) + assert c.manager is c.transport.manager def test_copy(self): c = Connection('amqp://example.com') - self.assertEqual(copy(c).info(), c.info()) + assert copy(c).info() == c.info() def test_copy_multiples(self): c = Connection('amqp://A.example.com;amqp://B.example.com') - self.assertTrue(c.alt) + assert c.alt d = copy(c) - self.assertEqual(d.alt, c.alt) + assert d.alt == c.alt def test_switch(self): c = Connection('amqp://foo') c._closed = True c.switch('redis://example.com//3') - self.assertFalse(c._closed) - self.assertEqual(c.hostname, 'example.com') - self.assertEqual(c.transport_cls, 'redis') - self.assertEqual(c.virtual_host, '/3') + assert not c._closed + assert c.hostname == 'example.com' + assert c.transport_cls == 'redis' + assert c.virtual_host == '/3' def test_maybe_switch_next(self): c = Connection('amqp://foo;redis://example.com//3') c.maybe_switch_next() - self.assertFalse(c._closed) - self.assertEqual(c.hostname, 'example.com') - self.assertEqual(c.transport_cls, 'redis') - self.assertEqual(c.virtual_host, '/3') + assert not c._closed + assert c.hostname == 'example.com' + assert c.transport_cls == 'redis' + assert c.virtual_host == '/3' def test_maybe_switch_next_no_cycle(self): c = Connection('amqp://foo') c.maybe_switch_next() - self.assertFalse(c._closed) - self.assertEqual(c.hostname, 'foo') - self.assertIn(c.transport_cls, ('librabbitmq', 'pyamqp', 'amqp')) + assert not c._closed + assert c.hostname == 'foo' + assert c.transport_cls, ('librabbitmq', 'pyamqp' in 'amqp') def test_heartbeat_check(self): c = Connection(transport=Transport) @@ -329,45 +299,40 @@ class test_Connection(Case): def test_completes_cycle_no_cycle(self): c = Connection('amqp://') - self.assertTrue(c.completes_cycle(0)) - self.assertTrue(c.completes_cycle(1)) + assert c.completes_cycle(0) + assert c.completes_cycle(1) def test_completes_cycle(self): c = Connection('amqp://a;amqp://b;amqp://c') - self.assertFalse(c.completes_cycle(0)) - self.assertFalse(c.completes_cycle(1)) - self.assertTrue(c.completes_cycle(2)) + assert not c.completes_cycle(0) + assert not c.completes_cycle(1) + assert c.completes_cycle(2) def test_get_heartbeat_interval(self): self.conn.transport.get_heartbeat_interval = Mock(name='ghi') - self.assertIs( - self.conn.get_heartbeat_interval(), - self.conn.transport.get_heartbeat_interval.return_value, - ) + assert (self.conn.get_heartbeat_interval() is + self.conn.transport.get_heartbeat_interval.return_value) self.conn.transport.get_heartbeat_interval.assert_called_with( self.conn.connection) def test_supports_exchange_type(self): self.conn.transport.implements.exchange_type = {'topic'} - self.assertTrue(self.conn.supports_exchange_type('topic')) - self.assertFalse(self.conn.supports_exchange_type('fanout')) + assert self.conn.supports_exchange_type('topic') + assert not self.conn.supports_exchange_type('fanout') def test_qos_semantics_matches_spec(self): qsms = self.conn.transport.qos_semantics_matches_spec = Mock() - self.assertIs( - self.conn.qos_semantics_matches_spec, - qsms.return_value, - ) + assert self.conn.qos_semantics_matches_spec is qsms.return_value qsms.assert_called_with(self.conn.connection) def test__enter____exit__(self): conn = self.conn context = conn.__enter__() - self.assertIs(context, conn) + assert context is conn conn.connect() - self.assertTrue(conn.connection.connected) + assert conn.connection.connected conn.__exit__() - self.assertIsNone(conn.connection) + assert conn.connection is None conn.close() # again def test_close_survives_connerror(self): @@ -384,7 +349,7 @@ class test_Connection(Case): conn = Connection(transport=MyTransport) conn.connect() conn.close() - self.assertTrue(conn._closed) + assert conn._closed def test_close_when_default_channel(self): conn = self.conn @@ -413,17 +378,17 @@ class test_Connection(Case): conn.revive(Mock()) defchan.close.assert_called_with() - self.assertIsNone(conn._default_channel) + assert conn._default_channel is None def test_ensure_connection(self): - self.assertTrue(self.conn.ensure_connection()) + assert self.conn.ensure_connection() def test_ensure_success(self): def publish(): return 'foobar' ensured = self.conn.ensure(None, publish) - self.assertEqual(ensured(), 'foobar') + assert ensured() == 'foobar' def test_ensure_failure(self): class _CustomError(Exception): @@ -433,7 +398,7 @@ class test_Connection(Case): raise _CustomError('bar') ensured = self.conn.ensure(None, publish) - with self.assertRaises(_CustomError): + with pytest.raises(_CustomError): ensured() def test_ensure_connection_failure(self): @@ -445,7 +410,7 @@ class test_Connection(Case): self.conn.transport.connection_errors = (_ConnectionError,) ensured = self.conn.ensure(self.conn, publish) - with self.assertRaises(OperationalError): + with pytest.raises(OperationalError): ensured() def test_autoretry(self): @@ -466,36 +431,37 @@ class test_Connection(Case): def test_SimpleQueue(self): conn = self.conn q = conn.SimpleQueue('foo') - self.assertIs(q.channel, conn.default_channel) + assert q.channel is conn.default_channel chan = conn.channel() q2 = conn.SimpleQueue('foo', channel=chan) - self.assertIs(q2.channel, chan) + assert q2.channel is chan def test_SimpleBuffer(self): conn = self.conn q = conn.SimpleBuffer('foo') - self.assertIs(q.channel, conn.default_channel) + assert q.channel is conn.default_channel chan = conn.channel() q2 = conn.SimpleBuffer('foo', channel=chan) - self.assertIs(q2.channel, chan) + assert q2.channel is chan def test_Producer(self): conn = self.conn - self.assertIsInstance(conn.Producer(), Producer) - self.assertIsInstance(conn.Producer(conn.default_channel), Producer) + assert isinstance(conn.Producer(), Producer) + assert isinstance(conn.Producer(conn.default_channel), Producer) def test_Consumer(self): conn = self.conn - self.assertIsInstance(conn.Consumer(queues=[]), Consumer) - self.assertIsInstance(conn.Consumer(queues=[], - channel=conn.default_channel), Consumer) + assert isinstance(conn.Consumer(queues=[]), Consumer) + assert isinstance( + conn.Consumer(queues=[], channel=conn.default_channel), + Consumer) def test__repr__(self): - self.assertTrue(repr(self.conn)) + assert repr(self.conn) def test__reduce__(self): x = pickle.loads(pickle.dumps(self.conn)) - self.assertDictEqual(x.info(), self.conn.info()) + assert x.info() == self.conn.info() def test_channel_errors(self): @@ -503,7 +469,7 @@ class test_Connection(Case): channel_errors = (KeyError, ValueError) conn = Connection(transport=MyTransport) - self.assertTupleEqual(conn.channel_errors, (KeyError, ValueError)) + assert conn.channel_errors == (KeyError, ValueError) def test_connection_errors(self): @@ -511,10 +477,10 @@ class test_Connection(Case): connection_errors = (KeyError, ValueError) conn = Connection(transport=MyTransport) - self.assertTupleEqual(conn.connection_errors, (KeyError, ValueError)) + assert conn.connection_errors == (KeyError, ValueError) -class test_Connection_with_transport_options(Case): +class test_Connection_with_transport_options: transport_options = {'pool_recycler': 3600, 'echo': True} @@ -524,7 +490,7 @@ class test_Connection_with_transport_options(Case): def test_establish_connection(self): conn = self.conn - self.assertEqual(conn.transport_options, self.transport_options) + assert conn.transport_options == self.transport_options class xResource(Resource): @@ -533,56 +499,46 @@ class xResource(Resource): pass -class ResourceCase(Case): - abstract = True +class ResourceCase: def create_resource(self, limit): raise NotImplementedError('subclass responsibility') - def assertState(self, P, avail, dirty): - self.assertEqual(P._resource.qsize(), avail) - self.assertEqual(len(P._dirty), dirty) + def assert_state(self, P, avail, dirty): + assert P._resource.qsize() == avail + assert len(P._dirty) == dirty def test_setup(self): - if self.abstract: - with self.assertRaises(NotImplementedError): - Resource() + with pytest.raises(NotImplementedError): + Resource() def test_acquire__release(self): - if self.abstract: - return P = self.create_resource(10) - self.assertState(P, 10, 0) + self.assert_state(P, 10, 0) chans = [P.acquire() for _ in range(10)] - self.assertState(P, 0, 10) - with self.assertRaises(P.LimitExceeded): + self.assert_state(P, 0, 10) + with pytest.raises(P.LimitExceeded): P.acquire() chans.pop().release() - self.assertState(P, 1, 9) + self.assert_state(P, 1, 9) [chan.release() for chan in chans] - self.assertState(P, 10, 0) + self.assert_state(P, 10, 0) def test_acquire_prepare_raises(self): - if self.abstract: - return P = self.create_resource(10) - self.assertEqual(len(P._resource.queue), 10) + assert len(P._resource.queue) == 10 P.prepare = Mock() P.prepare.side_effect = IOError() - with self.assertRaises(IOError): + with pytest.raises(IOError): P.acquire(block=True) - self.assertEqual(len(P._resource.queue), 10) + assert len(P._resource.queue) == 10 def test_acquire_no_limit(self): - if self.abstract: - return P = self.create_resource(None) P.acquire().release() def test_replace_when_limit(self): - if self.abstract: - return P = self.create_resource(10) r = P.acquire() P._dirty = Mock() @@ -593,8 +549,6 @@ class ResourceCase(Case): P.close_resource.assert_called_with(r) def test_replace_no_limit(self): - if self.abstract: - return P = self.create_resource(None) r = P.acquire() P._dirty = Mock() @@ -605,26 +559,20 @@ class ResourceCase(Case): P.close_resource.assert_called_with(r) def test_interface_prepare(self): - if not self.abstract: - return x = xResource() - self.assertEqual(x.prepare(10), 10) + assert x.prepare(10) == 10 def test_force_close_all_handles_AttributeError(self): - if self.abstract: - return P = self.create_resource(10) cr = P.collect_resource = Mock() cr.side_effect = AttributeError('x') P.acquire() - self.assertTrue(P._dirty) + assert P._dirty P.force_close_all() def test_force_close_all_no_mutex(self): - if self.abstract: - return P = self.create_resource(10) P.close_resource = Mock() @@ -635,17 +583,14 @@ class ResourceCase(Case): P.force_close_all() def test_add_when_empty(self): - if self.abstract: - return P = self.create_resource(None) P._resource.queue.clear() - self.assertFalse(P._resource.queue) + assert not P._resource.queue P._add_when_empty() - self.assertTrue(P._resource.queue) + assert P._resource.queue class test_ConnectionPool(ResourceCase): - abstract = False def create_resource(self, limit): return Connection(port=5672, transport=Transport).Pool(limit) @@ -666,9 +611,9 @@ class test_ConnectionPool(ResourceCase): def test_setup(self): P = self.create_resource(10) q = P._resource.queue - self.assertIsNone(q[0]()._connection) - self.assertIsNone(q[1]()._connection) - self.assertIsNone(q[2]()._connection) + assert q[0]()._connection is None + assert q[1]()._connection is None + assert q[2]()._connection is None def test_acquire_raises_evaluated(self): P = self.create_resource(1) @@ -678,7 +623,7 @@ class test_ConnectionPool(ResourceCase): P.prepare = Mock() P.prepare.side_effect = MemoryError() P.release = Mock() - with self.assertRaises(MemoryError): + with pytest.raises(MemoryError): with P.acquire(): assert False P.release.assert_called_with(r) @@ -691,22 +636,21 @@ class test_ConnectionPool(ResourceCase): def test_setup_no_limit(self): P = self.create_resource(None) - self.assertFalse(P._resource.queue) - self.assertIsNone(P.limit) + assert not P._resource.queue + assert P.limit is None def test_prepare_not_callable(self): P = self.create_resource(None) conn = Connection('memory://') - self.assertIs(P.prepare(conn), conn) + assert P.prepare(conn) is conn def test_acquire_channel(self): P = self.create_resource(10) with P.acquire_channel() as (conn, channel): - self.assertIs(channel, conn.default_channel) + assert channel is conn.default_channel class test_ChannelPool(ResourceCase): - abstract = False def create_resource(self, limit): return Connection(port=5672, transport=Transport).ChannelPool(limit) @@ -714,16 +658,16 @@ class test_ChannelPool(ResourceCase): def test_setup(self): P = self.create_resource(10) q = P._resource.queue - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): q[0].basic_consume def test_setup_no_limit(self): P = self.create_resource(None) - self.assertFalse(P._resource.queue) - self.assertIsNone(P.limit) + assert not P._resource.queue + assert P.limit is None def test_prepare_not_callable(self): P = self.create_resource(10) conn = Connection('memory://') chan = conn.default_channel - self.assertIs(P.prepare(chan), chan) + assert P.prepare(chan) is chan diff --git a/kombu/tests/test_entity.py b/t/unit/test_entity.py index 8a15f5b8..e2312c8c 100644 --- a/kombu/tests/test_entity.py +++ b/t/unit/test_entity.py @@ -1,21 +1,23 @@ from __future__ import absolute_import, unicode_literals import pickle +import pytest + +from case import Mock, call from kombu import Connection, Exchange, Producer, Queue, binding from kombu.abstract import MaybeChannelBound from kombu.exceptions import NotBoundError from kombu.serialization import registry -from .case import Case, Mock, call -from .mocks import Transport +from t.mocks import Transport def get_conn(): return Connection(transport=Transport) -class test_binding(Case): +class test_binding: def test_constructor(self): x = binding( @@ -23,83 +25,80 @@ class test_binding(Case): arguments={'barg': 'bval'}, unbind_arguments={'uarg': 'uval'}, ) - self.assertEqual(x.exchange, Exchange('foo')) - self.assertEqual(x.routing_key, 'rkey') - self.assertDictEqual(x.arguments, {'barg': 'bval'}) - self.assertDictEqual(x.unbind_arguments, {'uarg': 'uval'}) + assert x.exchange == Exchange('foo') + assert x.routing_key == 'rkey' + assert x.arguments == {'barg': 'bval'} + assert x.unbind_arguments == {'uarg': 'uval'} def test_declare(self): chan = get_conn().channel() x = binding(Exchange('foo'), 'rkey') x.declare(chan) - self.assertIn('exchange_declare', chan) + assert 'exchange_declare' in chan def test_declare_no_exchange(self): chan = get_conn().channel() x = binding() x.declare(chan) - self.assertNotIn('exchange_declare', chan) + assert 'exchange_declare' not in chan def test_bind(self): chan = get_conn().channel() x = binding(Exchange('foo')) x.bind(Exchange('bar')(chan)) - self.assertIn('exchange_bind', chan) + assert 'exchange_bind' in chan def test_unbind(self): chan = get_conn().channel() x = binding(Exchange('foo')) x.unbind(Exchange('bar')(chan)) - self.assertIn('exchange_unbind', chan) + assert 'exchange_unbind' in chan def test_repr(self): b = binding(Exchange('foo'), 'rkey') - self.assertIn('foo', repr(b)) - self.assertIn('rkey', repr(b)) + assert 'foo' in repr(b) + assert 'rkey' in repr(b) -class test_Exchange(Case): +class test_Exchange: def test_bound(self): exchange = Exchange('foo', 'direct') - self.assertFalse(exchange.is_bound) - self.assertIn('<unbound', repr(exchange)) + assert not exchange.is_bound + assert '<unbound' in repr(exchange) chan = get_conn().channel() bound = exchange.bind(chan) - self.assertTrue(bound.is_bound) - self.assertIs(bound.channel, chan) - self.assertIn('bound to chan:%r' % (chan.channel_id,), - repr(bound)) + assert bound.is_bound + assert bound.channel is chan + assert 'bound to chan:%r' % (chan.channel_id,) in repr(bound) def test_hash(self): - self.assertEqual(hash(Exchange('a')), hash(Exchange('a'))) - self.assertNotEqual(hash(Exchange('a')), hash(Exchange('b'))) + assert hash(Exchange('a')) == hash(Exchange('a')) + assert hash(Exchange('a')) != hash(Exchange('b')) def test_can_cache_declaration(self): - self.assertTrue(Exchange('a', durable=True).can_cache_declaration) - self.assertTrue(Exchange('a', durable=False).can_cache_declaration) - self.assertFalse(Exchange('a', auto_delete=True).can_cache_declaration) - self.assertFalse( - Exchange( - 'a', durable=True, auto_delete=True - ).can_cache_declaration, - ) + assert Exchange('a', durable=True).can_cache_declaration + assert Exchange('a', durable=False).can_cache_declaration + assert not Exchange('a', auto_delete=True).can_cache_declaration + assert not Exchange( + 'a', durable=True, auto_delete=True, + ).can_cache_declaration def test_pickle(self): e1 = Exchange('foo', 'direct') e2 = pickle.loads(pickle.dumps(e1)) - self.assertEqual(e1, e2) + assert e1 == e2 def test_eq(self): e1 = Exchange('foo', 'direct') e2 = Exchange('foo', 'direct') - self.assertEqual(e1, e2) + assert e1 == e2 e3 = Exchange('foo', 'topic') - self.assertNotEqual(e1, e3) + assert e1 != e3 - self.assertEqual(e1.__eq__(True), NotImplemented) + assert e1.__eq__(True) == NotImplemented def test_revive(self): exchange = Exchange('foo', 'direct') @@ -108,121 +107,121 @@ class test_Exchange(Case): # reviving unbound channel is a noop. exchange.revive(chan) - self.assertFalse(exchange.is_bound) - self.assertIsNone(exchange._channel) + assert not exchange.is_bound + assert exchange._channel is None bound = exchange.bind(chan) - self.assertTrue(bound.is_bound) - self.assertIs(bound.channel, chan) + assert bound.is_bound + assert bound.channel is chan chan2 = conn.channel() bound.revive(chan2) - self.assertTrue(bound.is_bound) - self.assertIs(bound._channel, chan2) + assert bound.is_bound + assert bound._channel is chan2 def test_assert_is_bound(self): exchange = Exchange('foo', 'direct') - with self.assertRaises(NotBoundError): + with pytest.raises(NotBoundError): exchange.declare() conn = get_conn() chan = conn.channel() exchange.bind(chan).declare() - self.assertIn('exchange_declare', chan) + assert 'exchange_declare' in chan def test_set_transient_delivery_mode(self): exc = Exchange('foo', 'direct', delivery_mode='transient') - self.assertEqual(exc.delivery_mode, Exchange.TRANSIENT_DELIVERY_MODE) + assert exc.delivery_mode == Exchange.TRANSIENT_DELIVERY_MODE def test_set_passive_mode(self): exc = Exchange('foo', 'direct', passive=True) - self.assertTrue(exc.passive) + assert exc.passive def test_set_persistent_delivery_mode(self): exc = Exchange('foo', 'direct', delivery_mode='persistent') - self.assertEqual(exc.delivery_mode, Exchange.PERSISTENT_DELIVERY_MODE) + assert exc.delivery_mode == Exchange.PERSISTENT_DELIVERY_MODE def test_bind_at_instantiation(self): - self.assertTrue(Exchange('foo', channel=get_conn().channel()).is_bound) + assert Exchange('foo', channel=get_conn().channel()).is_bound def test_create_message(self): chan = get_conn().channel() Exchange('foo', channel=chan).Message({'foo': 'bar'}) - self.assertIn('prepare_message', chan) + assert 'prepare_message' in chan def test_publish(self): chan = get_conn().channel() Exchange('foo', channel=chan).publish('the quick brown fox') - self.assertIn('basic_publish', chan) + assert 'basic_publish' in chan def test_delete(self): chan = get_conn().channel() Exchange('foo', channel=chan).delete() - self.assertIn('exchange_delete', chan) + assert 'exchange_delete' in chan def test__repr__(self): b = Exchange('foo', 'topic') - self.assertIn('foo(topic)', repr(b)) - self.assertIn('Exchange', repr(b)) + assert 'foo(topic)' in repr(b) + assert 'Exchange' in repr(b) def test_bind_to(self): chan = get_conn().channel() foo = Exchange('foo', 'topic') bar = Exchange('bar', 'topic') foo(chan).bind_to(bar) - self.assertIn('exchange_bind', chan) + assert 'exchange_bind' in chan def test_bind_to_by_name(self): chan = get_conn().channel() foo = Exchange('foo', 'topic') foo(chan).bind_to('bar') - self.assertIn('exchange_bind', chan) + assert 'exchange_bind' in chan def test_unbind_from(self): chan = get_conn().channel() foo = Exchange('foo', 'topic') bar = Exchange('bar', 'topic') foo(chan).unbind_from(bar) - self.assertIn('exchange_unbind', chan) + assert 'exchange_unbind' in chan def test_unbind_from_by_name(self): chan = get_conn().channel() foo = Exchange('foo', 'topic') foo(chan).unbind_from('bar') - self.assertIn('exchange_unbind', chan) + assert 'exchange_unbind' in chan def test_declare__no_declare(self): chan = get_conn().channel() foo = Exchange('foo', 'topic', no_declare=True) foo(chan).declare() - self.assertNotIn('exchange_declare', chan) + assert 'exchange_declare' not in chan def test_declare__internal_exchange(self): chan = get_conn().channel() foo = Exchange('amq.rabbitmq.trace', 'topic') foo(chan).declare() - self.assertNotIn('exchange_declare', chan) + assert 'exchange_declare' not in chan def test_declare(self): chan = get_conn().channel() foo = Exchange('foo', 'topic', no_declare=False) foo(chan).declare() - self.assertIn('exchange_declare', chan) + assert 'exchange_declare' in chan -class test_Queue(Case): +class test_Queue: def setup(self): self.exchange = Exchange('foo', 'direct') def test_hash(self): - self.assertEqual(hash(Queue('a')), hash(Queue('a'))) - self.assertNotEqual(hash(Queue('a')), hash(Queue('b'))) + assert hash(Queue('a')) == hash(Queue('a')) + assert hash(Queue('a')) != hash(Queue('b')) def test_repr_with_bindings(self): ex = Exchange('foo') x = Queue('foo', bindings=[ex.binding('A'), ex.binding('B')]) - self.assertTrue(repr(x)) + assert repr(x) def test_anonymous(self): chan = Mock() @@ -230,7 +229,7 @@ class test_Queue(Case): chan.queue_declare.return_value = 'generated', 0, 0 xx = x(chan) xx.declare() - self.assertEqual(xx.name, 'generated') + assert xx.name == 'generated' def test_basic_get__accept_disallowed(self): conn = Connection('memory://') @@ -242,9 +241,9 @@ class test_Queue(Case): ) message = q(conn).get(no_ack=True) - self.assertIsNotNone(message) + assert message is not None - with self.assertRaises(q.ContentDisallowed): + with pytest.raises(q.ContentDisallowed): message.decode() def test_basic_get__accept_allowed(self): @@ -257,15 +256,15 @@ class test_Queue(Case): ) message = q(conn).get(accept=['pickle'], no_ack=True) - self.assertIsNotNone(message) + assert message is not None payload = message.decode() - self.assertTrue(payload['complex']) + assert payload['complex'] def test_when_bound_but_no_exchange(self): q = Queue('a') q.exchange = None - self.assertIsNone(q.when_bound()) + assert q.when_bound() is None def test_declare_but_no_exchange(self): q = Queue('a') @@ -296,7 +295,7 @@ class test_Queue(Case): chan = Mock() q = Queue('a')(chan) chan.message_to_python = None - self.assertTrue(q.get()) + assert q.get() def test_multiple_bindings(self): chan = Mock() @@ -306,110 +305,105 @@ class test_Queue(Case): binding(Exchange('mul3'), 'rkey3'), ]) q(chan).declare() - self.assertIn( - call( - nowait=False, - exchange='mul1', - auto_delete=False, - passive=False, - arguments=None, - type='direct', - durable=True, - ), - chan.exchange_declare.call_args_list, - ) + assert call( + nowait=False, + exchange='mul1', + auto_delete=False, + passive=False, + arguments=None, + type='direct', + durable=True, + ) in chan.exchange_declare.call_args_list def test_can_cache_declaration(self): - self.assertTrue(Queue('a', durable=True).can_cache_declaration) - self.assertTrue(Queue('a', durable=False).can_cache_declaration) + assert Queue('a', durable=True).can_cache_declaration + assert Queue('a', durable=False).can_cache_declaration def test_eq(self): q1 = Queue('xxx', Exchange('xxx', 'direct'), 'xxx') q2 = Queue('xxx', Exchange('xxx', 'direct'), 'xxx') - self.assertEqual(q1, q2) - self.assertEqual(q1.__eq__(True), NotImplemented) + assert q1 == q2 + assert q1.__eq__(True) == NotImplemented q3 = Queue('yyy', Exchange('xxx', 'direct'), 'xxx') - self.assertNotEqual(q1, q3) + assert q1 != q3 def test_exclusive_implies_auto_delete(self): - self.assertTrue( - Queue('foo', self.exchange, exclusive=True).auto_delete, - ) + assert Queue('foo', self.exchange, exclusive=True).auto_delete def test_binds_at_instantiation(self): - self.assertTrue(Queue('foo', self.exchange, - channel=get_conn().channel()).is_bound) + assert Queue('foo', self.exchange, + channel=get_conn().channel()).is_bound def test_also_binds_exchange(self): chan = get_conn().channel() b = Queue('foo', self.exchange) - self.assertFalse(b.is_bound) - self.assertFalse(b.exchange.is_bound) + assert not b.is_bound + assert not b.exchange.is_bound b = b.bind(chan) - self.assertTrue(b.is_bound) - self.assertTrue(b.exchange.is_bound) - self.assertIs(b.channel, b.exchange.channel) - self.assertIsNot(b.exchange, self.exchange) + assert b.is_bound + assert b.exchange.is_bound + assert b.channel is b.exchange.channel + assert b.exchange is not self.exchange def test_declare(self): chan = get_conn().channel() b = Queue('foo', self.exchange, 'foo', channel=chan) - self.assertTrue(b.is_bound) + assert b.is_bound b.declare() - self.assertIn('exchange_declare', chan) - self.assertIn('queue_declare', chan) - self.assertIn('queue_bind', chan) + assert 'exchange_declare' in chan + assert 'queue_declare' in chan + assert 'queue_bind' in chan def test_get(self): b = Queue('foo', self.exchange, 'foo', channel=get_conn().channel()) b.get() - self.assertIn('basic_get', b.channel) + assert 'basic_get' in b.channel def test_purge(self): b = Queue('foo', self.exchange, 'foo', channel=get_conn().channel()) b.purge() - self.assertIn('queue_purge', b.channel) + assert 'queue_purge' in b.channel def test_consume(self): b = Queue('foo', self.exchange, 'foo', channel=get_conn().channel()) b.consume('fifafo', None) - self.assertIn('basic_consume', b.channel) + assert 'basic_consume' in b.channel def test_cancel(self): b = Queue('foo', self.exchange, 'foo', channel=get_conn().channel()) b.cancel('fifafo') - self.assertIn('basic_cancel', b.channel) + assert 'basic_cancel' in b.channel def test_delete(self): b = Queue('foo', self.exchange, 'foo', channel=get_conn().channel()) b.delete() - self.assertIn('queue_delete', b.channel) + assert 'queue_delete' in b.channel def test_queue_unbind(self): b = Queue('foo', self.exchange, 'foo', channel=get_conn().channel()) b.queue_unbind() - self.assertIn('queue_unbind', b.channel) + assert 'queue_unbind' in b.channel def test_as_dict(self): q = Queue('foo', self.exchange, 'rk') d = q.as_dict(recurse=True) - self.assertEqual(d['exchange']['name'], self.exchange.name) + assert d['exchange']['name'] == self.exchange.name def test_queue_dump(self): b = binding(self.exchange, 'rk') q = Queue('foo', self.exchange, 'rk', bindings=[b]) d = q.as_dict(recurse=True) - self.assertEqual(d['bindings'][0]['routing_key'], 'rk') + assert d['bindings'][0]['routing_key'] == 'rk' registry.dumps(d) def test__repr__(self): b = Queue('foo', self.exchange, 'foo') - self.assertIn('foo', repr(b)) - self.assertIn('Queue', repr(b)) + assert 'foo' in repr(b) + assert 'Queue' in repr(b) -class test_MaybeChannelBound(Case): +class test_MaybeChannelBound: def test_repr(self): - self.assertTrue(repr(MaybeChannelBound())) + assert repr(MaybeChannelBound()) diff --git a/t/unit/test_exceptions.py b/t/unit/test_exceptions.py new file mode 100644 index 00000000..f72f3d6d --- /dev/null +++ b/t/unit/test_exceptions.py @@ -0,0 +1,11 @@ +from __future__ import absolute_import, unicode_literals + +from case import Mock + +from kombu.exceptions import HttpError + + +class test_HttpError: + + def test_str(self): + assert str(HttpError(200, 'msg', Mock(name='response'))) diff --git a/kombu/tests/test_log.py b/t/unit/test_log.py index 08dc299e..5946755b 100644 --- a/kombu/tests/test_log.py +++ b/t/unit/test_log.py @@ -3,6 +3,8 @@ from __future__ import absolute_import, unicode_literals import logging import sys +from case import ANY, Mock, patch + from kombu.log import ( get_logger, get_loglevel, @@ -12,22 +14,20 @@ from kombu.log import ( setup_logging, ) -from .case import ANY, Case, Mock, patch - -class test_get_logger(Case): +class test_get_logger: def test_when_string(self): l = get_logger('foo') - self.assertIs(l, logging.getLogger('foo')) + assert l is logging.getLogger('foo') h1 = l.handlers[0] - self.assertIsInstance(h1, logging.NullHandler) + assert isinstance(h1, logging.NullHandler) def test_when_logger(self): l = get_logger(logging.getLogger('foo')) h1 = l.handlers[0] - self.assertIsInstance(h1, logging.NullHandler) + assert isinstance(h1, logging.NullHandler) def test_with_custom_handler(self): l = logging.getLogger('bar') @@ -35,28 +35,23 @@ class test_get_logger(Case): l.addHandler(handler) l = get_logger('bar') - self.assertIs(l.handlers[0], handler) + assert l.handlers[0] is handler def test_get_loglevel(self): - self.assertEqual(get_loglevel('DEBUG'), logging.DEBUG) - self.assertEqual(get_loglevel('ERROR'), logging.ERROR) - self.assertEqual(get_loglevel(logging.INFO), logging.INFO) + assert get_loglevel('DEBUG') == logging.DEBUG + assert get_loglevel('ERROR') == logging.ERROR + assert get_loglevel(logging.INFO) == logging.INFO -class test_safe_format(Case): +def test_safe_format(): + fmt = 'The %r jumped %x over the %s' + args = ['frog', 'foo', 'elephant'] - def test_formatting(self): - fmt = 'The %r jumped %x over the %s' - args = ['frog', 'foo', 'elephant'] + res = list(safeify_format(fmt, args)) + assert [x.strip('u') for x in res] == ["'frog'", 'foo', 'elephant'] - res = list(safeify_format(fmt, args)) - self.assertListEqual( - [x.strip('u') for x in res], - ["'frog'", 'foo', 'elephant'], - ) - -class test_LogMixin(Case): +class test_LogMixin: def setup(self): self.log = Log('Log', Mock()) @@ -96,22 +91,20 @@ class test_LogMixin(Case): log.DISABLE_TRACEBACKS = False def test_get_loglevel(self): - self.assertEqual(self.log.get_loglevel('DEBUG'), logging.DEBUG) - self.assertEqual(self.log.get_loglevel('ERROR'), logging.ERROR) - self.assertEqual(self.log.get_loglevel(logging.INFO), logging.INFO) + assert self.log.get_loglevel('DEBUG') == logging.DEBUG + assert self.log.get_loglevel('ERROR') == logging.ERROR + assert self.log.get_loglevel(logging.INFO) == logging.INFO def test_is_enabled_for(self): self.logger.isEnabledFor.return_value = True - self.assertTrue(self.log.is_enabled_for('DEBUG')) + assert self.log.is_enabled_for('DEBUG') self.logger.isEnabledFor.assert_called_with(logging.DEBUG) def test_LogMixin_get_logger(self): - self.assertIs(LogMixin().get_logger(), - logging.getLogger('LogMixin')) + assert LogMixin().get_logger() is logging.getLogger('LogMixin') def test_Log_get_logger(self): - self.assertIs(Log('test_Log').get_logger(), - logging.getLogger('test_Log')) + assert Log('test_Log').get_logger() is logging.getLogger('test_Log') def test_log_when_not_enabled(self): self.logger.isEnabledFor.return_value = False @@ -123,13 +116,10 @@ class test_LogMixin(Case): self.logger.log.assert_called_with( logging.DEBUG, 'Log - Host %s removed', ANY, ) - self.assertEqual( - self.logger.log.call_args[0][2].strip('u'), - "'example.com'", - ) + assert self.logger.log.call_args[0][2].strip('u') == "'example.com'" -class test_setup_logging(Case): +class test_setup_logging: @patch('logging.getLogger') def test_set_up_default_values(self, getLogger): @@ -141,8 +131,8 @@ class test_setup_logging(Case): logger.addHandler.assert_called() ah_args, _ = logger.addHandler.call_args handler = ah_args[0] - self.assertIsInstance(handler, logging.StreamHandler) - self.assertIs(handler.stream, sys.__stderr__) + assert isinstance(handler, logging.StreamHandler) + assert handler.stream is sys.__stderr__ @patch('logging.getLogger') @patch('kombu.log.WatchedFileHandler') diff --git a/kombu/tests/test_message.py b/t/unit/test_message.py index 320e2421..70d35b00 100644 --- a/kombu/tests/test_message.py +++ b/t/unit/test_message.py @@ -1,26 +1,27 @@ from __future__ import absolute_import, unicode_literals +import pytest import sys -from kombu.message import Message +from case import Mock, patch -from .case import Case, Mock, patch +from kombu.message import Message -class test_Message(Case): +class test_Message: def test_repr(self): - self.assertTrue(repr(Message(Mock(), 'b'))) + assert repr(Message(Mock(), 'b')) def test_decode(self): m = Message(Mock(), 'body') decode = m._decode = Mock() - self.assertIsNone(m._decoded_cache) - self.assertIs(m.decode(), m._decode.return_value) - self.assertIs(m._decoded_cache, m._decode.return_value) + assert m._decoded_cache is None + assert m.decode() is m._decode.return_value + assert m._decoded_cache is m._decode.return_value m._decode.assert_called_with() m._decode = Mock() - self.assertIs(m.decode(), decode.return_value) + assert m.decode() is decode.return_value def test_reraise_error(self): m = Message(Mock(), 'body') @@ -32,12 +33,12 @@ class test_Message(Case): m._reraise_error(callback) callback.assert_called() - with self.assertRaises(KeyError): + with pytest.raises(KeyError): m._reraise_error(None) @patch('kombu.message.decompress') def test_decompression_stores_error(self, decompress): decompress.side_effect = RuntimeError() m = Message(Mock(), 'body', headers={'compression': 'zlib'}) - with self.assertRaises(RuntimeError): + with pytest.raises(RuntimeError): m._reraise_error(None) diff --git a/kombu/tests/test_messaging.py b/t/unit/test_messaging.py index 8f48f7cd..fbc8d72b 100644 --- a/kombu/tests/test_messaging.py +++ b/t/unit/test_messaging.py @@ -1,41 +1,43 @@ from __future__ import absolute_import, unicode_literals import pickle +import pytest import sys from collections import defaultdict +from case import Mock, patch + from kombu import Connection, Consumer, Producer, Exchange, Queue from kombu.exceptions import MessageStateError from kombu.utils import json from kombu.utils.functional import ChannelPromise -from .case import Case, Mock, patch -from .mocks import Transport +from t.mocks import Transport -class test_Producer(Case): +class test_Producer: def setup(self): self.exchange = Exchange('foo', 'direct') self.connection = Connection(transport=Transport) self.connection.connect() - self.assertTrue(self.connection.connection.connected) - self.assertFalse(self.exchange.is_bound) + assert self.connection.connection.connected + assert not self.exchange.is_bound def test_repr(self): p = Producer(self.connection) - self.assertTrue(repr(p)) + assert repr(p) def test_pickle(self): chan = Mock() producer = Producer(chan, serializer='pickle') p2 = pickle.loads(pickle.dumps(producer)) - self.assertEqual(p2.serializer, producer.serializer) + assert p2.serializer == producer.serializer def test_no_channel(self): p = Producer(None) - self.assertFalse(p._channel) + assert not p._channel @patch('kombu.messaging.maybe_declare') def test_maybe_declare(self, maybe_declare): @@ -53,30 +55,30 @@ class test_Producer(Case): def test_auto_declare(self): channel = self.connection.channel() p = Producer(channel, self.exchange, auto_declare=True) - self.assertIsNot(p.exchange, self.exchange, - 'creates Exchange clone at bind') - self.assertTrue(p.exchange.is_bound) - self.assertIn('exchange_declare', channel, - 'auto_declare declares exchange') + # creates Exchange clone at bind + assert p.exchange is not self.exchange + assert p.exchange.is_bound + # auto_declare declares exchange' + assert 'exchange_declare' in channel def test_manual_declare(self): channel = self.connection.channel() p = Producer(channel, self.exchange, auto_declare=False) - self.assertTrue(p.exchange.is_bound) - self.assertNotIn('exchange_declare', channel, - 'auto_declare=False does not declare exchange') + assert p.exchange.is_bound + # auto_declare=False does not declare exchange + assert 'exchange_declare' not in channel + # p.declare() declares exchange') p.declare() - self.assertIn('exchange_declare', channel, - 'p.declare() declares exchange') + assert 'exchange_declare' in channel def test_prepare(self): message = {'the quick brown fox': 'jumps over the lazy dog'} channel = self.connection.channel() p = Producer(channel, self.exchange, serializer='json') m, ctype, cencoding = p._prepare(message, headers={}) - self.assertDictEqual(message, json.loads(m)) - self.assertEqual(ctype, 'application/json') - self.assertEqual(cencoding, 'utf-8') + assert json.loads(m) == message + assert ctype == 'application/json' + assert cencoding == 'utf-8' def test_prepare_compression(self): message = {'the quick brown fox': 'jumps over the lazy dog'} @@ -85,64 +87,59 @@ class test_Producer(Case): headers = {} m, ctype, cencoding = p._prepare(message, compression='zlib', headers=headers) - self.assertEqual(ctype, 'application/json') - self.assertEqual(cencoding, 'utf-8') - self.assertEqual(headers['compression'], 'application/x-gzip') + assert ctype == 'application/json' + assert cencoding == 'utf-8' + assert headers['compression'] == 'application/x-gzip' import zlib - self.assertEqual( - json.loads(zlib.decompress(m).decode('utf-8')), - message, - ) + assert json.loads(zlib.decompress(m).decode('utf-8')) == message def test_prepare_custom_content_type(self): message = 'the quick brown fox'.encode('utf-8') channel = self.connection.channel() p = Producer(channel, self.exchange, serializer='json') m, ctype, cencoding = p._prepare(message, content_type='custom') - self.assertEqual(m, message) - self.assertEqual(ctype, 'custom') - self.assertEqual(cencoding, 'binary') + assert m == message + assert ctype == 'custom' + assert cencoding == 'binary' m, ctype, cencoding = p._prepare(message, content_type='custom', content_encoding='alien') - self.assertEqual(m, message) - self.assertEqual(ctype, 'custom') - self.assertEqual(cencoding, 'alien') + assert m == message + assert ctype == 'custom' + assert cencoding == 'alien' def test_prepare_is_already_unicode(self): message = 'the quick brown fox' channel = self.connection.channel() p = Producer(channel, self.exchange, serializer='json') m, ctype, cencoding = p._prepare(message, content_type='text/plain') - self.assertEqual(m, message.encode('utf-8')) - self.assertEqual(ctype, 'text/plain') - self.assertEqual(cencoding, 'utf-8') + assert m == message.encode('utf-8') + assert ctype == 'text/plain' + assert cencoding == 'utf-8' m, ctype, cencoding = p._prepare(message, content_type='text/plain', content_encoding='utf-8') - self.assertEqual(m, message.encode('utf-8')) - self.assertEqual(ctype, 'text/plain') - self.assertEqual(cencoding, 'utf-8') + assert m == message.encode('utf-8') + assert ctype == 'text/plain' + assert cencoding == 'utf-8' def test_publish_with_Exchange_instance(self): p = self.connection.Producer() p.channel = Mock() p.publish('hello', exchange=Exchange('foo'), delivery_mode='transient') - self.assertEqual( - p._channel.basic_publish.call_args[1]['exchange'], 'foo', - ) + assert p._channel.basic_publish.call_args[1]['exchange'] == 'foo' def test_publish_with_expiration(self): p = self.connection.Producer() p.channel = Mock() p.publish('hello', exchange=Exchange('foo'), expiration=10) properties = p._channel.prepare_message.call_args[0][5] - self.assertEqual(properties['expiration'], '10000') + assert properties['expiration'] == '10000' def test_publish_with_reply_to(self): p = self.connection.Producer() p.channel = Mock() p.publish('hello', exchange=Exchange('foo'), reply_to=Queue('foo')) properties = p._channel.prepare_message.call_args[0][5] - self.assertEqual(properties['reply_to'], 'foo') + assert properties['reply_to'] == 'foo' def test_set_on_return(self): chan = Mock() @@ -173,14 +170,14 @@ class test_Producer(Case): defchan = new_conn.default_channel p.revive(new_conn) - self.assertIs(p.channel, defchan) + assert p.channel is defchan p.exchange.revive.assert_called_with(defchan) def test_enter_exit(self): p = self.connection.Producer() p.release = Mock() - self.assertIs(p.__enter__(), p) + assert p.__enter__() is p p.__exit__() p.release.assert_called_with() @@ -188,37 +185,37 @@ class test_Producer(Case): p = self.connection.Producer() p.channel = object() p.__connection__ = None - self.assertIsNone(p.connection) + assert p.connection is None def test_publish(self): channel = self.connection.channel() p = Producer(channel, self.exchange, serializer='json') message = {'the quick brown fox': 'jumps over the lazy dog'} ret = p.publish(message, routing_key='process') - self.assertIn('prepare_message', channel) - self.assertIn('basic_publish', channel) + assert 'prepare_message' in channel + assert 'basic_publish' in channel m, exc, rkey = ret - self.assertDictEqual(message, json.loads(m['body'])) - self.assertDictContainsSubset({'content_type': 'application/json', - 'content_encoding': 'utf-8', - 'priority': 0}, m) - self.assertDictContainsSubset({'delivery_mode': 2}, m['properties']) - self.assertEqual(exc, p.exchange.name) - self.assertEqual(rkey, 'process') + assert json.loads(m['body']) == message + assert m['content_type'] == 'application/json' + assert m['content_encoding'] == 'utf-8' + assert m['priority'] == 0 + assert m['properties']['delivery_mode'] == 2 + assert exc == p.exchange.name + assert rkey == 'process' def test_no_exchange(self): chan = self.connection.channel() p = Producer(chan) - self.assertFalse(p.exchange.name) + assert not p.exchange.name def test_revive(self): chan = self.connection.channel() p = Producer(chan) chan2 = self.connection.channel() p.revive(chan2) - self.assertIs(p.channel, chan2) - self.assertIs(p.exchange.channel, chan2) + assert p.channel is chan2 + assert p.exchange.channel is chan2 def test_on_return(self): chan = self.connection.channel() @@ -227,28 +224,27 @@ class test_Producer(Case): pass p = Producer(chan, on_return=on_return) - self.assertIn(on_return, chan.events['basic_return']) - self.assertTrue(p.on_return) + assert on_return in chan.events['basic_return'] + assert p.on_return -class test_Consumer(Case): +class test_Consumer: def setup(self): self.connection = Connection(transport=Transport) self.connection.connect() - self.assertTrue(self.connection.connection.connected) + assert self.connection.connection.connected self.exchange = Exchange('foo', 'direct') def test_accept(self): a = Consumer(self.connection) - self.assertIsNone(a.accept) + assert a.accept is None b = Consumer(self.connection, accept=['json', 'pickle']) - self.assertSetEqual( - b.accept, - {'application/json', 'application/x-python-serialize'}, - ) + assert b.accept == { + 'application/json', 'application/x-python-serialize', + } c = Consumer(self.connection, accept=b.accept) - self.assertSetEqual(b.accept, c.accept) + assert b.accept == c.accept def test_enter_exit_cancel_raises(self): c = Consumer(self.connection) @@ -269,7 +265,7 @@ class test_Consumer(Case): c._receive_callback(message) callback.assert_called_with(message) - self.assertSetEqual(message.accept, c.accept) + assert message.accept == c.accept def test_accept__content_disallowed(self): conn = Connection('memory://') @@ -282,7 +278,7 @@ class test_Consumer(Case): callback = Mock(name='callback') with conn.Consumer(queues=[q], callbacks=[callback]) as consumer: - with self.assertRaises(consumer.ContentDisallowed): + with pytest.raises(consumer.ContentDisallowed): conn.drain_events(timeout=1) callback.assert_not_called() @@ -301,26 +297,26 @@ class test_Consumer(Case): conn.drain_events(timeout=1) callback.assert_called() body, message = callback.call_args[0] - self.assertTrue(body['complex']) + assert body['complex'] def test_set_no_channel(self): c = Consumer(None) - self.assertIsNone(c.channel) + assert c.channel is None c.revive(Mock()) - self.assertTrue(c.channel) + assert c.channel def test_set_no_ack(self): channel = self.connection.channel() queue = Queue('qname', self.exchange, 'rkey') consumer = Consumer(channel, queue, auto_declare=True, no_ack=True) - self.assertTrue(consumer.no_ack) + assert consumer.no_ack def test_add_queue_when_auto_declare(self): consumer = self.connection.Consumer(auto_declare=True) q = Mock() q.return_value = q consumer.add_queue(q) - self.assertIn(q, consumer.queues) + assert q in consumer.queues q.declare.assert_called_with() def test_add_queue_when_not_auto_declare(self): @@ -328,26 +324,26 @@ class test_Consumer(Case): q = Mock() q.return_value = q consumer.add_queue(q) - self.assertIn(q, consumer.queues) - self.assertFalse(q.declare.call_count) + assert q in consumer.queues + assert not q.declare.call_count def test_consume_without_queues_returns(self): consumer = self.connection.Consumer() consumer.queues[:] = [] - self.assertIsNone(consumer.consume()) + assert consumer.consume() is None def test_consuming_from(self): consumer = self.connection.Consumer() consumer.queues[:] = [Queue('a'), Queue('b'), Queue('d')] consumer._active_tags = {'a': 1, 'b': 2} - self.assertFalse(consumer.consuming_from(Queue('c'))) - self.assertFalse(consumer.consuming_from('c')) - self.assertFalse(consumer.consuming_from(Queue('d'))) - self.assertFalse(consumer.consuming_from('d')) - self.assertTrue(consumer.consuming_from(Queue('a'))) - self.assertTrue(consumer.consuming_from(Queue('b'))) - self.assertTrue(consumer.consuming_from('b')) + assert not consumer.consuming_from(Queue('c')) + assert not consumer.consuming_from('c') + assert not consumer.consuming_from(Queue('d')) + assert not consumer.consuming_from('d') + assert consumer.consuming_from(Queue('a')) + assert consumer.consuming_from(Queue('b')) + assert consumer.consuming_from('b') def test_receive_callback_without_m2p(self): channel = self.connection.channel() @@ -374,7 +370,7 @@ class test_Consumer(Case): except KeyError: message.errors = [sys.exc_info()] message._reraise_error.side_effect = KeyError() - with self.assertRaises(KeyError): + with pytest.raises(KeyError): c._receive_callback(message) def test_set_callbacks(self): @@ -384,7 +380,7 @@ class test_Consumer(Case): lambda x, y: x] consumer = Consumer(channel, queue, auto_declare=True, callbacks=callbacks) - self.assertEqual(consumer.callbacks, callbacks) + assert consumer.callbacks == callbacks def test_auto_declare(self): channel = self.connection.channel() @@ -392,22 +388,22 @@ class test_Consumer(Case): consumer = Consumer(channel, queue, auto_declare=True) consumer.consume() consumer.consume() # twice is a noop - self.assertIsNot(consumer.queues[0], queue) - self.assertTrue(consumer.queues[0].is_bound) - self.assertTrue(consumer.queues[0].exchange.is_bound) - self.assertIsNot(consumer.queues[0].exchange, self.exchange) + assert consumer.queues[0] is not queue + assert consumer.queues[0].is_bound + assert consumer.queues[0].exchange.is_bound + assert consumer.queues[0].exchange is not self.exchange for meth in ('exchange_declare', 'queue_declare', 'queue_bind', 'basic_consume'): - self.assertIn(meth, channel) - self.assertEqual(channel.called.count('basic_consume'), 1) - self.assertTrue(consumer._active_tags) + assert meth in channel + assert channel.called.count('basic_consume') == 1 + assert consumer._active_tags consumer.cancel_by_queue(queue.name) consumer.cancel_by_queue(queue.name) - self.assertFalse(consumer._active_tags) + assert not consumer._active_tags def test_consumer_tag_prefix(self): channel = self.connection.channel() @@ -415,33 +411,31 @@ class test_Consumer(Case): consumer = Consumer(channel, queue, tag_prefix='consumer_') consumer.consume() - self.assertTrue( - consumer._active_tags[queue.name].startswith('consumer_'), - ) + assert consumer._active_tags[queue.name].startswith('consumer_') def test_manual_declare(self): channel = self.connection.channel() queue = Queue('qname', self.exchange, 'rkey') consumer = Consumer(channel, queue, auto_declare=False) - self.assertIsNot(consumer.queues[0], queue) - self.assertTrue(consumer.queues[0].is_bound) - self.assertTrue(consumer.queues[0].exchange.is_bound) - self.assertIsNot(consumer.queues[0].exchange, self.exchange) + assert consumer.queues[0] is not queue + assert consumer.queues[0].is_bound + assert consumer.queues[0].exchange.is_bound + assert consumer.queues[0].exchange is not self.exchange for meth in ('exchange_declare', 'queue_declare', 'basic_consume'): - self.assertNotIn(meth, channel) + assert meth not in channel consumer.declare() for meth in ('exchange_declare', 'queue_declare', 'queue_bind'): - self.assertIn(meth, channel) - self.assertNotIn('basic_consume', channel) + assert meth in channel + assert 'basic_consume' not in channel consumer.consume() - self.assertIn('basic_consume', channel) + assert 'basic_consume' in channel def test_consume__cancel(self): channel = self.connection.channel() @@ -449,34 +443,34 @@ class test_Consumer(Case): consumer = Consumer(channel, queue, auto_declare=True) consumer.consume() consumer.cancel() - self.assertIn('basic_cancel', channel) - self.assertFalse(consumer._active_tags) + assert 'basic_cancel' in channel + assert not consumer._active_tags def test___enter____exit__(self): channel = self.connection.channel() queue = Queue('qname', self.exchange, 'rkey') consumer = Consumer(channel, queue, auto_declare=True) context = consumer.__enter__() - self.assertIs(context, consumer) - self.assertTrue(consumer._active_tags) + assert context is consumer + assert consumer._active_tags res = consumer.__exit__(None, None, None) - self.assertFalse(res) - self.assertIn('basic_cancel', channel) - self.assertFalse(consumer._active_tags) + assert not res + assert 'basic_cancel' in channel + assert not consumer._active_tags def test_flow(self): channel = self.connection.channel() queue = Queue('qname', self.exchange, 'rkey') consumer = Consumer(channel, queue, auto_declare=True) consumer.flow(False) - self.assertIn('flow', channel) + assert 'flow' in channel def test_qos(self): channel = self.connection.channel() queue = Queue('qname', self.exchange, 'rkey') consumer = Consumer(channel, queue, auto_declare=True) consumer.qos(30, 10, False) - self.assertIn('basic_qos', channel) + assert 'basic_qos' in channel def test_purge(self): channel = self.connection.channel() @@ -486,7 +480,7 @@ class test_Consumer(Case): b4 = Queue('qname4', self.exchange, 'rkey') consumer = Consumer(channel, [b1, b2, b3, b4], auto_declare=True) consumer.purge() - self.assertEqual(channel.called.count('queue_purge'), 4) + assert channel.called.count('queue_purge') == 4 def test_multiple_queues(self): channel = self.connection.channel() @@ -496,14 +490,14 @@ class test_Consumer(Case): b4 = Queue('qname4', self.exchange, 'rkey') consumer = Consumer(channel, [b1, b2, b3, b4]) consumer.consume() - self.assertEqual(channel.called.count('exchange_declare'), 4) - self.assertEqual(channel.called.count('queue_declare'), 4) - self.assertEqual(channel.called.count('queue_bind'), 4) - self.assertEqual(channel.called.count('basic_consume'), 4) - self.assertEqual(len(consumer._active_tags), 4) + assert channel.called.count('exchange_declare') == 4 + assert channel.called.count('queue_declare') == 4 + assert channel.called.count('queue_bind') == 4 + assert channel.called.count('basic_consume') == 4 + assert len(consumer._active_tags) == 4 consumer.cancel() - self.assertEqual(channel.called.count('basic_cancel'), 4) - self.assertFalse(len(consumer._active_tags)) + assert channel.called.count('basic_cancel') == 4 + assert not len(consumer._active_tags) def test_receive_callback(self): channel = self.connection.channel() @@ -519,9 +513,9 @@ class test_Consumer(Case): consumer.register_callback(callback) consumer._receive_callback({'foo': 'bar'}) - self.assertIn('basic_ack', channel) - self.assertIn('message_to_python', channel) - self.assertEqual(received[0], {'foo': 'bar'}) + assert 'basic_ack' in channel + assert 'message_to_python' in channel + assert received[0] == {'foo': 'bar'} def test_basic_ack_twice(self): channel = self.connection.channel() @@ -533,7 +527,7 @@ class test_Consumer(Case): message.ack() consumer.register_callback(callback) - with self.assertRaises(MessageStateError): + with pytest.raises(MessageStateError): consumer._receive_callback({'foo': 'bar'}) def test_basic_reject(self): @@ -546,7 +540,7 @@ class test_Consumer(Case): consumer.register_callback(callback) consumer._receive_callback({'foo': 'bar'}) - self.assertIn('basic_reject', channel) + assert 'basic_reject' in channel def test_basic_reject_twice(self): channel = self.connection.channel() @@ -558,9 +552,9 @@ class test_Consumer(Case): message.reject() consumer.register_callback(callback) - with self.assertRaises(MessageStateError): + with pytest.raises(MessageStateError): consumer._receive_callback({'foo': 'bar'}) - self.assertIn('basic_reject', channel) + assert 'basic_reject' in channel def test_basic_reject__requeue(self): channel = self.connection.channel() @@ -572,7 +566,7 @@ class test_Consumer(Case): consumer.register_callback(callback) consumer._receive_callback({'foo': 'bar'}) - self.assertIn('basic_reject:requeue', channel) + assert 'basic_reject:requeue' in channel def test_basic_reject__requeue_twice(self): channel = self.connection.channel() @@ -584,15 +578,15 @@ class test_Consumer(Case): message.requeue() consumer.register_callback(callback) - with self.assertRaises(MessageStateError): + with pytest.raises(MessageStateError): consumer._receive_callback({'foo': 'bar'}) - self.assertIn('basic_reject:requeue', channel) + assert 'basic_reject:requeue' in channel def test_receive_without_callbacks_raises(self): channel = self.connection.channel() b1 = Queue('qname1', self.exchange, 'rkey') consumer = Consumer(channel, [b1]) - with self.assertRaises(NotImplementedError): + with pytest.raises(NotImplementedError): consumer.receive(1, 2) def test_decode_error(self): @@ -601,7 +595,7 @@ class test_Consumer(Case): consumer = Consumer(channel, [b1]) consumer.channel.throw_decode_error = True - with self.assertRaises(ValueError): + with pytest.raises(ValueError): consumer._receive_callback({'foo': 'bar'}) def test_on_decode_error_callback(self): @@ -616,17 +610,17 @@ class test_Consumer(Case): consumer.channel.throw_decode_error = True consumer._receive_callback({'foo': 'bar'}) - self.assertTrue(thrown) + assert thrown m, exc = thrown[0] - self.assertEqual(json.loads(m), {'foo': 'bar'}) - self.assertIsInstance(exc, ValueError) + assert json.loads(m) == {'foo': 'bar'} + assert isinstance(exc, ValueError) def test_recover(self): channel = self.connection.channel() b1 = Queue('qname1', self.exchange, 'rkey') consumer = Consumer(channel, [b1]) consumer.recover() - self.assertIn('basic_recover', channel) + assert 'basic_recover' in channel def test_revive(self): channel = self.connection.channel() @@ -634,9 +628,9 @@ class test_Consumer(Case): consumer = Consumer(channel, [b1]) channel2 = self.connection.channel() consumer.revive(channel2) - self.assertIs(consumer.channel, channel2) - self.assertIs(consumer.queues[0].channel, channel2) - self.assertIs(consumer.queues[0].exchange.channel, channel2) + assert consumer.channel is channel2 + assert consumer.queues[0].channel is channel2 + assert consumer.queues[0].exchange.channel is channel2 def test_revive__with_prefetch_count(self): channel = Mock(name='channel') @@ -647,9 +641,9 @@ class test_Consumer(Case): def test__repr__(self): channel = self.connection.channel() b1 = Queue('qname1', self.exchange, 'rkey') - self.assertTrue(repr(Consumer(channel, [b1]))) + assert repr(Consumer(channel, [b1])) def test_connection_property_handles_AttributeError(self): p = self.connection.Consumer() p.channel = object() - self.assertIsNone(p.connection) + assert p.connection is None diff --git a/kombu/tests/test_mixins.py b/t/unit/test_mixins.py index e3690806..372399e4 100644 --- a/kombu/tests/test_mixins.py +++ b/t/unit/test_mixins.py @@ -1,10 +1,11 @@ from __future__ import absolute_import, unicode_literals +import pytest import socket -from kombu.mixins import ConsumerMixin +from case import ContextMock, Mock, patch -from .case import Case, Mock, ContextMock, patch +from kombu.mixins import ConsumerMixin def Message(body, content_type='text/plain', content_encoding='utf-8'): @@ -31,7 +32,7 @@ class Cons(ConsumerMixin): self.extra_context.return_value = self.extra_context -class test_ConsumerMixin(Case): +class test_ConsumerMixin: def _context(self): Acons = ContextMock(name='consumerA') @@ -57,7 +58,7 @@ class test_ConsumerMixin(Case): next(it) next(it) c.should_stop = True - with self.assertRaises(StopIteration): + with pytest.raises(StopIteration): next(it) def test_consume_drain_raises_socket_error(self): @@ -65,7 +66,7 @@ class test_ConsumerMixin(Case): c.should_stop = False it = c.consume(no_ack=True) c.connection.drain_events.side_effect = socket.error - with self.assertRaises(socket.error): + with pytest.raises(socket.error): next(it) def se2(*args, **kwargs): @@ -73,7 +74,7 @@ class test_ConsumerMixin(Case): raise socket.error() c.connection.drain_events.side_effect = se2 it = c.consume(no_ack=True) - with self.assertRaises(StopIteration): + with pytest.raises(StopIteration): next(it) def test_consume_drain_raises_socket_timeout(self): @@ -85,49 +86,47 @@ class test_ConsumerMixin(Case): c.should_stop = True raise socket.timeout() c.connection.drain_events.side_effect = se - with self.assertRaises(socket.error): + with pytest.raises(socket.error): next(it) def test_Consumer_context(self): c, Acons, Bcons = self._context() with c.Consumer() as (conn, channel, consumer): - self.assertIs(conn, c.connection) - self.assertIs(channel, conn.default_channel) + assert conn is c.connection + assert channel is conn.default_channel c.on_connection_revived.assert_called_with() c.get_consumers.assert_called() cls = c.get_consumers.call_args[0][0] subcons = cls() - self.assertIs(subcons.on_decode_error, c.on_decode_error) - self.assertIs(subcons.channel, conn.default_channel) + assert subcons.on_decode_error is c.on_decode_error + assert subcons.channel is conn.default_channel Acons.__enter__.assert_called_with() Bcons.__enter__.assert_called_with() c.on_consume_end.assert_called_with(conn, channel) -class test_ConsumerMixin_interface(Case): +class test_ConsumerMixin_interface: def setup(self): self.c = ConsumerMixin() def test_get_consumers(self): - with self.assertRaises(NotImplementedError): + with pytest.raises(NotImplementedError): self.c.get_consumers(Mock(), Mock()) def test_on_connection_revived(self): - self.assertIsNone(self.c.on_connection_revived()) + assert self.c.on_connection_revived() is None def test_on_consume_ready(self): - self.assertIsNone(self.c.on_consume_ready( - Mock(), Mock(), [], - )) + assert self.c.on_consume_ready(Mock(), Mock(), []) is None def test_on_consume_end(self): - self.assertIsNone(self.c.on_consume_end(Mock(), Mock())) + assert self.c.on_consume_end(Mock(), Mock()) is None def test_on_iteration(self): - self.assertIsNone(self.c.on_iteration()) + assert self.c.on_iteration() is None def test_on_decode_error(self): message = Message('foo') @@ -146,15 +145,15 @@ class test_ConsumerMixin_interface(Case): pass def test_restart_limit(self): - self.assertTrue(self.c.restart_limit) + assert self.c.restart_limit def test_connection_errors(self): conn = Mock(name='connection') self.c.connection = conn conn.connection_errors = (KeyError,) - self.assertTupleEqual(self.c.connection_errors, conn.connection_errors) + assert self.c.connection_errors == conn.connection_errors conn.channel_errors = (ValueError,) - self.assertTupleEqual(self.c.channel_errors, conn.channel_errors) + assert self.c.channel_errors == conn.channel_errors def test__consume_from(self): a = ContextMock(name='A') @@ -173,7 +172,7 @@ class test_ConsumerMixin_interface(Case): self.c.connect_max_retries = 3 with self.c.establish_connection() as conn: - self.assertTrue(conn) + assert conn conn.ensure_connection.assert_called_with( self.c.on_connection_error, 3, ) diff --git a/kombu/tests/test_pidbox.py b/t/unit/test_pidbox.py index 852dbbf5..c2dfccb0 100644 --- a/kombu/tests/test_pidbox.py +++ b/t/unit/test_pidbox.py @@ -1,29 +1,34 @@ from __future__ import absolute_import, unicode_literals +import pytest import socket import warnings +from case import Mock, patch + from kombu import Connection from kombu import pidbox from kombu.exceptions import ContentDisallowed, InconsistencyError from kombu.utils.uuid import uuid -from .case import Case, Mock, patch +def is_cast(message): + return message['method'] -class test_Mailbox(Case): - def _handler(self, state): - return self.stats['var'] +def is_call(message): + return message['method'] and message['reply_to'] - def setup(self): - class Mailbox(pidbox.Mailbox): +class test_Mailbox: + + class Mailbox(pidbox.Mailbox): - def _collect(self, *args, **kwargs): - return 'COLLECTED' + def _collect(self, *args, **kwargs): + return 'COLLECTED' - self.mailbox = Mailbox('test_pidbox') + def setup(self): + self.mailbox = self.Mailbox('test_pidbox') self.connection = Connection(transport='memory') self.state = {'var': 1} self.handlers = {'mymethod': self._handler} @@ -35,6 +40,9 @@ class test_Mailbox(Case): channel=self.default_chan, ) + def _handler(self, state): + return self.stats['var'] + def test_publish_reply_ignores_InconsistencyError(self): mailbox = pidbox.Mailbox('test_reply__collect')(self.connection) with patch('kombu.pidbox.Producer') as Producer: @@ -60,17 +68,17 @@ class test_Mailbox(Case): reply = mailbox._collect(ticket, limit=1, callback=callback, channel=channel) - self.assertEqual(reply, [{'foo': 'bar'}]) - self.assertTrue(_callback_called[0]) + assert reply == [{'foo': 'bar'}] + assert _callback_called[0] ticket = uuid() mailbox._publish_reply({'biz': 'boz'}, exchange, mailbox.oid, ticket) reply = mailbox._collect(ticket, limit=1, channel=channel) - self.assertEqual(reply, [{'biz': 'boz'}]) + assert reply == [{'biz': 'boz'}] mailbox._publish_reply({'foo': 'BAM'}, exchange, mailbox.oid, 'doom', serializer='pickle') - with self.assertRaises(ContentDisallowed): + with pytest.raises(ContentDisallowed): reply = mailbox._collect('doom', limit=1, channel=channel) mailbox._publish_reply( {'foo': 'BAMBAM'}, exchange, mailbox.oid, 'doom', @@ -78,40 +86,40 @@ class test_Mailbox(Case): ) reply = mailbox._collect('doom', limit=1, channel=channel, accept=['pickle']) - self.assertEqual(reply[0]['foo'], 'BAMBAM') + assert reply[0]['foo'] == 'BAMBAM' de = mailbox.connection.drain_events = Mock() de.side_effect = socket.timeout mailbox._collect(ticket, limit=1, channel=channel) def test_constructor(self): - self.assertIsNone(self.mailbox.connection) - self.assertTrue(self.mailbox.exchange.name) - self.assertTrue(self.mailbox.reply_exchange.name) + assert self.mailbox.connection is None + assert self.mailbox.exchange.name + assert self.mailbox.reply_exchange.name def test_bound(self): bound = self.mailbox(self.connection) - self.assertIs(bound.connection, self.connection) + assert bound.connection is self.connection def test_Node(self): - self.assertTrue(self.node.hostname) - self.assertTrue(self.node.state) - self.assertIs(self.node.mailbox, self.bound) - self.assertTrue(self.handlers) + assert self.node.hostname + assert self.node.state + assert self.node.mailbox is self.bound + assert self.handlers # No initial handlers node2 = self.bound.Node('test_pidbox2', state=self.state) - self.assertDictEqual(node2.handlers, {}) + assert node2.handlers == {} def test_Node_consumer(self): consumer1 = self.node.Consumer() - self.assertIs(consumer1.channel, self.default_chan) - self.assertTrue(consumer1.no_ack) + assert consumer1.channel is self.default_chan + assert consumer1.no_ack chan2 = self.connection.channel() consumer2 = self.node.Consumer(channel=chan2, no_ack=False) - self.assertIs(consumer2.channel, chan2) - self.assertFalse(consumer2.no_ack) + assert consumer2.channel is chan2 + assert not consumer2.no_ack def test_Node_consumer_multiple_listeners(self): warnings.resetwarnings() @@ -119,12 +127,12 @@ class test_Mailbox(Case): q = consumer.queues[0] with warnings.catch_warnings(record=True) as log: q.on_declared('foo', 1, 1) - self.assertTrue(log) - self.assertIn('already using this', log[0].message.args[0]) + assert log + assert 'already using this' in log[0].message.args[0] with warnings.catch_warnings(record=True) as log: q.on_declared('foo', 1, 0) - self.assertFalse(log) + assert not log def test_handler(self): node = self.bound.Node('test_handler', state=self.state) @@ -133,7 +141,7 @@ class test_Mailbox(Case): def my_handler_name(state): return 42 - self.assertIn('my_handler_name', node.handlers) + assert 'my_handler_name' in node.handlers def test_dispatch(self): node = self.bound.Node('test_dispatch', state=self.state) @@ -142,8 +150,8 @@ class test_Mailbox(Case): def my_handler_name(state, x=None, y=None): return x + y - self.assertEqual(node.dispatch('my_handler_name', - arguments={'x': 10, 'y': 10}), 20) + assert node.dispatch('my_handler_name', + arguments={'x': 10, 'y': 10}) == 20 def test_dispatch_raising_SystemExit(self): node = self.bound.Node('test_dispatch_raising_SystemExit', @@ -153,7 +161,7 @@ class test_Mailbox(Case): def my_handler_name(state): raise SystemExit - with self.assertRaises(SystemExit): + with pytest.raises(SystemExit): node.dispatch('my_handler_name') def test_dispatch_raising(self): @@ -164,8 +172,8 @@ class test_Mailbox(Case): raise KeyError('foo') res = node.dispatch('my_handler_name') - self.assertIn('error', res) - self.assertIn('KeyError', res['error']) + assert 'error' in res + assert 'KeyError' in res['error'] def test_dispatch_replies(self): _replied = [False] @@ -183,7 +191,7 @@ class test_Mailbox(Case): node.dispatch('my_handler_name', arguments={'x': 10, 'y': 10}, reply_to={'exchange': 'foo', 'routing_key': 'bar'}) - self.assertTrue(_replied[0]) + assert _replied[0] def test_reply(self): _replied = [(None, None, None)] @@ -204,10 +212,10 @@ class test_Mailbox(Case): 'routing_key': 'rkey'}, ticket='TICKET') data, exchange, routing_key, ticket = _replied[0] - self.assertEqual(data, {'test_reply': 42}) - self.assertEqual(exchange, 'exchange') - self.assertEqual(routing_key, 'rkey') - self.assertEqual(ticket, 'TICKET') + assert data == {'test_reply': 42} + assert exchange == 'exchange' + assert routing_key == 'rkey' + assert ticket == 'TICKET' def test_handle_message(self): node = self.bound.Node('test_dispatch_from_message') @@ -219,11 +227,11 @@ class test_Mailbox(Case): body = {'method': 'my_handler_name', 'arguments': {'x': 64, 'y': 64}} - self.assertEqual(node.handle_message(body, None), 64 * 64) + assert node.handle_message(body, None) == 64 * 64 # message not for me should not be processed. body['destination'] = ['some_other_node'] - self.assertIsNone(node.handle_message(body, None)) + assert node.handle_message(body, None) is None def test_handle_message_adjusts_clock(self): node = self.bound.Node('test_adjusts_clock') @@ -239,49 +247,38 @@ class test_Mailbox(Case): node.adjust_clock = Mock(name='adjust_clock') res = node.handle_message(body, message) node.adjust_clock.assert_called_with(313) - self.assertEqual(res, 10) + assert res == 10 def test_listen(self): consumer = self.node.listen() - self.assertEqual(consumer.callbacks[0], - self.node.handle_message) - self.assertEqual(consumer.channel, self.default_chan) + assert consumer.callbacks[0] == self.node.handle_message + assert consumer.channel == self.default_chan def test_cast(self): self.bound.cast(['somenode'], 'mymethod') consumer = self.node.Consumer() - self.assertIsCast(self.get_next(consumer)) + assert is_cast(self.get_next(consumer)) def test_abcast(self): self.bound.abcast('mymethod') consumer = self.node.Consumer() - self.assertIsCast(self.get_next(consumer)) + assert is_cast(self.get_next(consumer)) def test_call_destination_must_be_sequence(self): - with self.assertRaises(ValueError): + with pytest.raises(ValueError): self.bound.call('some_node', 'mymethod') def test_call(self): - self.assertEqual( - self.bound.call(['some_node'], 'mymethod'), - 'COLLECTED', - ) + assert self.bound.call(['some_node'], 'mymethod') == 'COLLECTED' consumer = self.node.Consumer() - self.assertIsCall(self.get_next(consumer)) + assert is_call(self.get_next(consumer)) def test_multi_call(self): - self.assertEqual(self.bound.multi_call('mymethod'), 'COLLECTED') + assert self.bound.multi_call('mymethod') == 'COLLECTED' consumer = self.node.Consumer() - self.assertIsCall(self.get_next(consumer)) + assert is_call(self.get_next(consumer)) def get_next(self, consumer): m = consumer.queues[0].get() if m: return m.payload - - def assertIsCast(self, message): - self.assertTrue(message['method']) - - def assertIsCall(self, message): - self.assertTrue(message['method']) - self.assertTrue(message['reply_to']) diff --git a/kombu/tests/test_pools.py b/t/unit/test_pools.py index 7a632eee..c70cc89d 100644 --- a/kombu/tests/test_pools.py +++ b/t/unit/test_pools.py @@ -1,14 +1,16 @@ from __future__ import absolute_import, unicode_literals +import pytest + +from case import Mock + from kombu import Connection, Producer from kombu import pools from kombu.connection import ConnectionPool from kombu.utils.collections import eqhash -from .case import Case, Mock - -class test_ProducerPool(Case): +class test_ProducerPool: Pool = pools.ProducerPool class MyPool(pools.ProducerPool): @@ -32,7 +34,7 @@ class test_ProducerPool(Case): self.pool.Producer.side_effect = IOError() acq = self.pool._acquire_connection = Mock() conn = acq.return_value = Mock() - with self.assertRaises(IOError): + with pytest.raises(IOError): self.pool.create_producer() conn.release.assert_called_with() @@ -43,7 +45,7 @@ class test_ProducerPool(Case): acq = self.pool._acquire_connection = Mock() conn = acq.return_value = Mock() p._channel = None - with self.assertRaises(IOError): + with pytest.raises(IOError): self.pool.prepare(pp) conn.release.assert_called_with() @@ -56,10 +58,10 @@ class test_ProducerPool(Case): self.pool.release(p) def test_init(self): - self.assertIs(self.pool.connections, self.connections) + assert self.pool.connections is self.connections def test_Producer(self): - self.assertIsInstance(self.pool.Producer(Mock()), Producer) + assert isinstance(self.pool.Producer(Mock()), Producer) def test_acquire_connection(self): self.pool._acquire_connection() @@ -68,20 +70,20 @@ class test_ProducerPool(Case): def test_new(self): promise = self.pool.new() producer = promise() - self.assertIsInstance(producer, Producer) + assert isinstance(producer, Producer) self.connections.acquire.assert_called_with(block=True) def test_setup_unlimited(self): pool = self.Pool(self.connections, limit=None) pool.setup() - self.assertFalse(pool._resource.queue) + assert not pool._resource.queue def test_setup(self): - self.assertEqual(len(self.pool._resource.queue), self.pool.limit) + assert len(self.pool._resource.queue) == self.pool.limit first = self.pool._resource.get_nowait() producer = first() - self.assertIsInstance(producer, Producer) + assert isinstance(producer, Producer) def test_prepare(self): connection = self.connections.acquire.return_value = Mock() @@ -111,10 +113,10 @@ class test_ProducerPool(Case): p.__connection__ = Mock() self.pool.release(p) p.__connection__.release.assert_called_with() - self.assertIsNone(p.channel) + assert p.channel is None -class test_PoolGroup(Case): +class test_PoolGroup: Group = pools.PoolGroup class MyGroup(pools.PoolGroup): @@ -124,47 +126,47 @@ class test_PoolGroup(Case): def test_interface_create(self): g = self.Group() - with self.assertRaises(NotImplementedError): + with pytest.raises(NotImplementedError): g.create(Mock(), 10) def test_getitem_using_global_limit(self): g = self.MyGroup(limit=pools.use_global_limit) res = g['foo'] - self.assertTupleEqual(res, ('foo', pools.get_limit())) + assert res == ('foo', pools.get_limit()) def test_getitem_using_custom_limit(self): g = self.MyGroup(limit=102456) res = g['foo'] - self.assertTupleEqual(res, ('foo', 102456)) + assert res == ('foo', 102456) def test_delitem(self): g = self.MyGroup() g['foo'] del(g['foo']) - self.assertNotIn('foo', g) + assert 'foo' not in g def test_Connections(self): conn = Connection('memory://') p = pools.connections[conn] - self.assertTrue(p) - self.assertIsInstance(p, ConnectionPool) - self.assertIs(p.connection, conn) - self.assertEqual(p.limit, pools.get_limit()) + assert p + assert isinstance(p, ConnectionPool) + assert p.connection is conn + assert p.limit == pools.get_limit() def test_Producers(self): conn = Connection('memory://') p = pools.producers[conn] - self.assertTrue(p) - self.assertIsInstance(p, pools.ProducerPool) - self.assertIs(p.connections, pools.connections[conn]) - self.assertEqual(p.limit, p.connections.limit) - self.assertEqual(p.limit, pools.get_limit()) + assert p + assert isinstance(p, pools.ProducerPool) + assert p.connections is pools.connections[conn] + assert p.limit == p.connections.limit + assert p.limit == pools.get_limit() def test_all_groups(self): conn = Connection('memory://') pools.connections[conn] - self.assertTrue(list(pools._all_pools())) + assert list(pools._all_pools()) def test_reset(self): pools.reset() @@ -181,7 +183,7 @@ class test_PoolGroup(Case): pools.reset() p1.force_close_all.assert_called_with() - self.assertTrue(g1.clear_called) + assert g1.clear_called p1 = pools.connections['foo'] = Mock() p1.force_close_all.side_effect = KeyError() @@ -191,23 +193,23 @@ class test_PoolGroup(Case): pools.reset() pools.set_limit(34576) limit = pools.get_limit() - self.assertEqual(limit, 34576) + assert limit == 34576 conn = Connection('memory://') pool = pools.connections[conn] with pool.acquire(): pools.set_limit(limit + 1) - self.assertEqual(pools.get_limit(), limit + 1) + assert pools.get_limit() == limit + 1 limit = pools.get_limit() - with self.assertRaises(RuntimeError): + with pytest.raises(RuntimeError): pools.set_limit(limit - 1) pools.set_limit(limit - 1, force=True) - self.assertEqual(pools.get_limit(), limit - 1) + assert pools.get_limit() == limit - 1 pools.set_limit(pools.get_limit()) -class test_fun_PoolGroup(Case): +class test_fun_PoolGroup: def test_connections_behavior(self): c1u = 'memory://localhost:123' @@ -220,19 +222,19 @@ class test_fun_PoolGroup(Case): assert eqhash(c1) == eqhash(c3) c4 = Connection(c1u, transport_options={'confirm_publish': True}) - self.assertNotEqual(eqhash(c3), eqhash(c4)) + assert eqhash(c3) != eqhash(c4) p1 = pools.connections[c1] p2 = pools.connections[c2] p3 = pools.connections[c3] - self.assertIsNot(p1, p2) - self.assertIs(p1, p3) + assert p1 is not p2 + assert p1 is p3 r1 = p1.acquire() - self.assertTrue(p1._dirty) - self.assertTrue(p3._dirty) - self.assertFalse(p2._dirty) + assert p1._dirty + assert p3._dirty + assert not p2._dirty r1.release() - self.assertFalse(p1._dirty) - self.assertFalse(p3._dirty) + assert not p1._dirty + assert not p3._dirty diff --git a/kombu/tests/test_serialization.py b/t/unit/test_serialization.py index d98b007e..88af0860 100644 --- a/kombu/tests/test_serialization.py +++ b/t/unit/test_serialization.py @@ -2,10 +2,13 @@ # -*- coding: utf-8 -*- from __future__ import absolute_import, unicode_literals +import pytest import sys from base64 import b64decode +from case import call, mock, patch, skip + from kombu.exceptions import ContentDisallowed, EncodeError, DecodeError from kombu.five import text_t, bytes_t from kombu.serialization import ( @@ -17,8 +20,6 @@ from kombu.serialization import ( ) from kombu.utils.encoding import str_to_bytes -from .case import Case, call, mock, patch, skip - # For content_encoding tests unicode_string = 'abcdé\u8463' unicode_string_as_utf8 = unicode_string.encode('utf-8') @@ -72,38 +73,38 @@ registry.register('testS', lambda s: s, lambda s: 'decoded', 'application/testS', 'utf-8') -class test_Serialization(Case): +class test_Serialization: def test_disable(self): disabled = registry._disabled_content_types try: registry.disable('testS') - self.assertIn('application/testS', disabled) + assert 'application/testS' in disabled disabled.clear() registry.disable('application/testS') - self.assertIn('application/testS', disabled) + assert 'application/testS' in disabled finally: disabled.clear() def test_enable(self): registry._disabled_content_types.add('application/json') registry.enable('json') - self.assertNotIn('application/json', registry._disabled_content_types) + assert 'application/json' not in registry._disabled_content_types registry._disabled_content_types.add('application/json') registry.enable('application/json') - self.assertNotIn('application/json', registry._disabled_content_types) + assert 'application/json' not in registry._disabled_content_types def test_loads_when_disabled(self): disabled = registry._disabled_content_types try: registry.disable('testS') - with self.assertRaises(SerializerNotInstalled): + with pytest.raises(SerializerNotInstalled): loads('xxd', 'application/testS', 'utf-8', force=False) ret = loads('xxd', 'application/testS', 'utf-8', force=True) - self.assertEqual(ret, 'decoded') + assert ret == 'decoded' finally: disabled.clear() @@ -111,50 +112,36 @@ class test_Serialization(Case): loads(None, 'application/testS', 'utf-8') def test_content_type_decoding(self): - self.assertEqual( - unicode_string, - loads(unicode_string_as_utf8, - content_type='plain/text', content_encoding='utf-8'), - ) - self.assertEqual( - latin_string, - loads(latin_string_as_latin1, - content_type='application/data', content_encoding='latin-1'), - ) + assert loads( + unicode_string_as_utf8, + content_type='plain/text', + content_encoding='utf-8') == unicode_string + assert loads( + latin_string_as_latin1, + content_type='application/data', + content_encoding='latin-1') == latin_string def test_content_type_binary(self): - self.assertIsInstance( + assert isinstance( loads(unicode_string_as_utf8, content_type='application/data', content_encoding='binary'), - bytes_t, - ) + bytes_t) - self.assertEqual( + assert loads( unicode_string_as_utf8, - loads(unicode_string_as_utf8, - content_type='application/data', content_encoding='binary'), - ) + content_type='application/data', + content_encoding='binary') == unicode_string_as_utf8 def test_content_type_encoding(self): # Using the 'raw' serializer - self.assertEqual( - unicode_string_as_utf8, - dumps(unicode_string, serializer='raw')[-1], - ) - self.assertEqual( - latin_string_as_utf8, - dumps(latin_string, serializer='raw')[-1], - ) + assert (dumps(unicode_string, serializer='raw')[-1] == + unicode_string_as_utf8) + assert (dumps(latin_string, serializer='raw')[-1] == + latin_string_as_utf8) # And again w/o a specific serializer to check the # code where we force unicode objects into a string. - self.assertEqual( - unicode_string_as_utf8, - dumps(unicode_string)[-1], - ) - self.assertEqual( - latin_string_as_utf8, - dumps(latin_string)[-1], - ) + assert dumps(unicode_string)[-1] == unicode_string_as_utf8 + assert dumps(latin_string)[-1] == latin_string_as_utf8 def test_enable_insecure_serializers(self): with patch('kombu.serialization.registry') as registry: @@ -182,34 +169,31 @@ class test_Serialization(Case): ]) def test_reraises_EncodeError(self): - with self.assertRaises(EncodeError): + with pytest.raises(EncodeError): dumps([object()], serializer='json') def test_reraises_DecodeError(self): - with self.assertRaises(DecodeError): + with pytest.raises(DecodeError): loads(object(), content_type='application/json', content_encoding='utf-8') def test_json_loads(self): - self.assertEqual( - py_data, - loads(json_data, - content_type='application/json', content_encoding='utf-8'), - ) + assert loads(json_data, + content_type='application/json', + content_encoding='utf-8') == py_data def test_json_dumps(self): - self.assertEqual( - loads( - dumps(py_data, serializer='json')[-1], - content_type='application/json', - content_encoding='utf-8', - ), - loads( - json_data, - content_type='application/json', - content_encoding='utf-8', - ), + a = loads( + dumps(py_data, serializer='json')[-1], + content_type='application/json', + content_encoding='utf-8', + ) + b = loads( + json_data, + content_type='application/json', + content_encoding='utf-8', ) + assert a == b @skip.if_pypy() @skip.unless_module('msgpack', (ImportError, ValueError)) @@ -224,122 +208,109 @@ class test_Serialization(Case): res[k] = v.encode() if isinstance(v, (list, tuple)): res[k] = [i.encode() for i in v] - self.assertEqual( - msgpack_py_data, - res, - ) + assert res == msgpack_py_data @skip.if_pypy() @skip.unless_module('msgpack', (ImportError, ValueError)) def test_msgpack_dumps(self): register_msgpack() - self.assertEqual( - loads( - dumps(msgpack_py_data, serializer='msgpack')[-1], - content_type='application/x-msgpack', - content_encoding='binary', - ), - loads( - msgpack_data, - content_type='application/x-msgpack', - content_encoding='binary', - ), + a = loads( + dumps(msgpack_py_data, serializer='msgpack')[-1], + content_type='application/x-msgpack', + content_encoding='binary', ) + b = loads( + msgpack_data, + content_type='application/x-msgpack', + content_encoding='binary', + ) + assert a == b @skip.unless_module('yaml') def test_yaml_loads(self): register_yaml() - self.assertEqual( - py_data, - loads(yaml_data, - content_type='application/x-yaml', - content_encoding='utf-8'), - ) + assert loads( + yaml_data, + content_type='application/x-yaml', + content_encoding='utf-8') == py_data @skip.unless_module('yaml') def test_yaml_dumps(self): register_yaml() - self.assertEqual( - loads( - dumps(py_data, serializer='yaml')[-1], - content_type='application/x-yaml', - content_encoding='utf-8', - ), - loads( - yaml_data, - content_type='application/x-yaml', - content_encoding='utf-8', - ), + a = loads( + dumps(py_data, serializer='yaml')[-1], + content_type='application/x-yaml', + content_encoding='utf-8', + ) + b = loads( + yaml_data, + content_type='application/x-yaml', + content_encoding='utf-8', ) + assert a == b def test_pickle_loads(self): - self.assertEqual( - py_data, - loads(pickle_data, - content_type='application/x-python-serialize', - content_encoding='binary'), - ) + assert loads( + pickle_data, + content_type='application/x-python-serialize', + content_encoding='binary') == py_data def test_pickle_dumps(self): - self.assertEqual( - pickle.loads(pickle_data), - pickle.loads(dumps(py_data, serializer='pickle')[-1]), - ) + a = pickle.loads(pickle_data), + b = pickle.loads(dumps(py_data, serializer='pickle')[-1]), + assert a == b def test_register(self): register(None, None, None, None) def test_unregister(self): - with self.assertRaises(SerializerNotInstalled): + with pytest.raises(SerializerNotInstalled): unregister('nonexisting') dumps('foo', serializer='pickle') unregister('pickle') - with self.assertRaises(SerializerNotInstalled): + with pytest.raises(SerializerNotInstalled): dumps('foo', serializer='pickle') register_pickle() def test_set_default_serializer_missing(self): - with self.assertRaises(SerializerNotInstalled): + with pytest.raises(SerializerNotInstalled): registry._set_default_serializer('nonexisting') def test_dumps_missing(self): - with self.assertRaises(SerializerNotInstalled): + with pytest.raises(SerializerNotInstalled): dumps('foo', serializer='nonexisting') def test_dumps__no_serializer(self): ctyp, cenc, data = dumps(str_to_bytes('foo')) - self.assertEqual(ctyp, 'application/data') - self.assertEqual(cenc, 'binary') + assert ctyp == 'application/data' + assert cenc == 'binary' def test_loads__trusted_content(self): loads('tainted', 'application/data', 'binary', accept=[]) loads('tainted', 'application/text', 'utf-8', accept=[]) def test_loads__not_accepted(self): - with self.assertRaises(ContentDisallowed): + with pytest.raises(ContentDisallowed): loads('tainted', 'application/x-evil', 'binary', accept=[]) - with self.assertRaises(ContentDisallowed): + with pytest.raises(ContentDisallowed): loads('tainted', 'application/x-evil', 'binary', accept=['application/x-json']) - self.assertTrue( - loads('tainted', 'application/x-doomsday', 'binary', - accept=['application/x-doomsday']) - ) + assert loads('tainted', 'application/x-doomsday', 'binary', + accept=['application/x-doomsday']) def test_raw_encode(self): - self.assertTupleEqual( - raw_encode('foo'.encode('utf-8')), - ('application/data', 'binary', 'foo'.encode('utf-8')), + assert raw_encode('foo'.encode('utf-8')) == ( + 'application/data', 'binary', 'foo'.encode('utf-8'), ) @mock.mask_modules('yaml') def test_register_yaml__no_yaml(self): register_yaml() - with self.assertRaises(SerializerNotInstalled): + with pytest.raises(SerializerNotInstalled): loads('foo', 'application/x-yaml', 'utf-8') @mock.mask_modules('msgpack') def test_register_msgpack__no_msgpack(self): register_msgpack() - with self.assertRaises(SerializerNotInstalled): + with pytest.raises(SerializerNotInstalled): loads('foo', 'application/x-msgpack', 'utf-8') diff --git a/kombu/tests/test_simple.py b/t/unit/test_simple.py index a0d68eff..e1f7cb22 100644 --- a/kombu/tests/test_simple.py +++ b/t/unit/test_simple.py @@ -1,12 +1,13 @@ from __future__ import absolute_import, unicode_literals -from kombu import Connection, Exchange, Queue +import pytest + +from case import Mock -from .case import Case, Mock +from kombu import Connection, Exchange, Queue -class SimpleBase(Case): - abstract = True +class SimpleBase: def Queue(self, name, *args, **kwargs): q = name @@ -20,83 +21,67 @@ class SimpleBase(Case): raise NotImplementedError() def setup(self): - if not self.abstract: - self.connection = Connection(transport='memory') - with self.connection.channel() as channel: - channel.exchange_declare('amq.direct') - self.q = self.Queue(None, no_ack=True) + self.connection = Connection(transport='memory') + with self.connection.channel() as channel: + channel.exchange_declare('amq.direct') + self.q = self.Queue(None, no_ack=True) def teardown(self): - if not self.abstract: - self.q.close() - self.connection.close() + self.q.close() + self.connection.close() def test_produce__consume(self): - if self.abstract: - return q = self.Queue('test_produce__consume', no_ack=True) q.put({'hello': 'Simple'}) - self.assertEqual(q.get(timeout=1).payload, {'hello': 'Simple'}) - with self.assertRaises(q.Empty): + assert q.get(timeout=1).payload == {'hello': 'Simple'} + with pytest.raises(q.Empty): q.get(timeout=0.1) def test_produce__basic_get(self): - if self.abstract: - return q = self.Queue('test_produce__basic_get', no_ack=True) q.put({'hello': 'SimpleSync'}) - self.assertEqual(q.get_nowait().payload, {'hello': 'SimpleSync'}) - with self.assertRaises(q.Empty): + assert q.get_nowait().payload == {'hello': 'SimpleSync'} + with pytest.raises(q.Empty): q.get_nowait() q.put({'hello': 'SimpleSync'}) - self.assertEqual(q.get(block=False).payload, {'hello': 'SimpleSync'}) - with self.assertRaises(q.Empty): + assert q.get(block=False).payload == {'hello': 'SimpleSync'} + with pytest.raises(q.Empty): q.get(block=False) def test_clear(self): - if self.abstract: - return q = self.Queue('test_clear', no_ack=True) for i in range(10): q.put({'hello': 'SimplePurge%d' % (i,)}) - self.assertEqual(q.clear(), 10) + assert q.clear() == 10 def test_enter_exit(self): - if self.abstract: - return q = self.Queue('test_enter_exit') q.close = Mock() - self.assertIs(q.__enter__(), q) + assert q.__enter__() is q q.__exit__() q.close.assert_called_with() def test_qsize(self): - if self.abstract: - return q = self.Queue('test_clear', no_ack=True) for i in range(10): q.put({'hello': 'SimplePurge%d' % (i,)}) - self.assertEqual(q.qsize(), 10) - self.assertEqual(len(q), 10) + assert q.qsize() == 10 + assert len(q) == 10 def test_autoclose(self): - if self.abstract: - return channel = self.connection.channel() q = self.Queue('test_autoclose', no_ack=True, channel=channel) q.close() def test_custom_Queue(self): - if self.abstract: - return n = self.__class__.__name__ exchange = Exchange('%s-test.custom.Queue' % (n,)) queue = Queue('%s-test.custom.Queue' % (n,), @@ -104,33 +89,29 @@ class SimpleBase(Case): 'my.routing.key') q = self.Queue(queue) - self.assertEqual(q.consumer.queues[0], queue) + assert q.consumer.queues[0] == queue q.close() def test_bool(self): - if self.abstract: - return q = self.Queue('test_nonzero') - self.assertTrue(q) + assert q class test_SimpleQueue(SimpleBase): - abstract = False def _Queue(self, *args, **kwargs): return self.connection.SimpleQueue(*args, **kwargs) def test_is_ack(self): q = self.Queue('test_is_no_ack') - self.assertFalse(q.no_ack) + assert not q.no_ack class test_SimpleBuffer(SimpleBase): - abstract = False def Queue(self, *args, **kwargs): return self.connection.SimpleBuffer(*args, **kwargs) def test_is_no_ack(self): q = self.Queue('test_is_no_ack') - self.assertTrue(q.no_ack) + assert q.no_ack diff --git a/kombu/tests/test_syn.py b/t/unit/test_syn.py index bf5fd972..eaa80f90 100644 --- a/kombu/tests/test_syn.py +++ b/t/unit/test_syn.py @@ -4,45 +4,45 @@ import socket import sys import types +from case import mock, patch + from kombu import syn from kombu.five import bytes_if_py2 -from kombu.tests.case import Case, mock, patch - -class test_syn(Case): +class test_syn: def test_compat(self): - self.assertEqual(syn.blocking(lambda: 10), 10) + assert syn.blocking(lambda: 10) == 10 syn.select_blocking_method('foo') def test_detect_environment(self): try: syn._environment = None X = syn.detect_environment() - self.assertEqual(syn._environment, X) + assert syn._environment == X Y = syn.detect_environment() - self.assertEqual(Y, X) + assert Y == X finally: syn._environment = None @mock.module_exists('eventlet', 'eventlet.patcher') def test_detect_environment_eventlet(self): with patch('eventlet.patcher.is_monkey_patched', create=True) as m: - self.assertTrue(sys.modules['eventlet']) + assert sys.modules['eventlet'] m.return_value = True env = syn._detect_environment() m.assert_called_with(socket) - self.assertEqual(env, 'eventlet') + assert env == 'eventlet' @mock.module_exists('gevent') def test_detect_environment_gevent(self): with patch('gevent.socket', create=True) as m: prev, socket.socket = socket.socket, m.socket try: - self.assertTrue(sys.modules['gevent']) + assert sys.modules['gevent'] env = syn._detect_environment() - self.assertEqual(env, 'gevent') + assert env == 'gevent' finally: socket.socket = prev @@ -52,14 +52,14 @@ class test_syn(Case): bytes_if_py2('eventlet')) sys.modules['eventlet.patcher'] = types.ModuleType( bytes_if_py2('patcher')) - self.assertEqual(syn._detect_environment(), 'default') + assert syn._detect_environment() == 'default' finally: sys.modules.pop('eventlet.patcher', None) sys.modules.pop('eventlet', None) syn._detect_environment() try: sys.modules['gevent'] = types.ModuleType(bytes_if_py2('gevent')) - self.assertEqual(syn._detect_environment(), 'default') + assert syn._detect_environment() == 'default' finally: sys.modules.pop('gevent', None) syn._detect_environment() diff --git a/kombu/tests/transport/virtual/__init__.py b/t/unit/transport/__init__.py index e69de29b..e69de29b 100644 --- a/kombu/tests/transport/virtual/__init__.py +++ b/t/unit/transport/__init__.py diff --git a/kombu/tests/transport/test_SQS.py b/t/unit/transport/test_SQS.py index 98ea3401..5032d43c 100644 --- a/kombu/tests/transport/test_SQS.py +++ b/t/unit/transport/test_SQS.py @@ -7,6 +7,10 @@ slightly. from __future__ import absolute_import, unicode_literals +import pytest + +from case import skip + from kombu import five from kombu import messaging from kombu import Connection, Exchange, Queue @@ -14,8 +18,6 @@ from kombu import Connection, Exchange, Queue from kombu.transport import SQS from kombu.async.aws.ext import exception -from kombu.tests.case import Case, skip - class SQSQueueMock(object): @@ -95,7 +97,7 @@ class SQSConnectionMock(object): @skip.unless_module('boto') -class test_Channel(Case): +class test_Channel: def handleMessageCallback(self, message): self.callback_message = message @@ -141,9 +143,20 @@ class test_Channel(Case): callback=self.handleMessageCallback, consumer_tag='unittest') + def teardown(self): + # Removes QoS reserved messages so we don't restore msgs on shutdown. + try: + qos = self.channel._qos + except AttributeError: + pass + else: + if qos: + qos._dirty.clear() + qos._delivered.clear() + def test_init(self): """kombu.SQS.Channel instantiates correctly with mocked queues""" - self.assertIn(self.queue_name, self.channel._queue_cache) + assert self.queue_name in self.channel._queue_cache def test_auth_fail(self): normal_func = SQS.Channel.sqs.get_all_queues @@ -159,11 +172,11 @@ class test_Channel(Case): try: SQS.Channel.sqs.access_key = '1234' SQS.Channel.sqs.get_all_queues = get_all_queues_fail_403 - with self.assertRaises(RuntimeError) as context: + with pytest.raises(RuntimeError) as excinfo: self.channel = self.connection.channel() - self.assertIn('access_key=1234', str(context.exception)) + assert 'access_key=1234' in str(excinfo.value) SQS.Channel.sqs.get_all_queues = get_all_queues_fail_not_403 - with self.assertRaises(exception.SQSError) as context: + with pytest.raises(exception.SQSError): self.channel = self.connection.channel() finally: SQS.Channel.sqs.get_all_queues = normal_func @@ -171,7 +184,7 @@ class test_Channel(Case): def test_new_queue(self): queue_name = 'new_unittest_queue' self.channel._new_queue(queue_name) - self.assertIn(queue_name, self.sqs_conn_mock.queues) + assert queue_name in self.sqs_conn_mock.queues # For cleanup purposes, delete the queue and the queue file self.channel._delete(queue_name) @@ -181,16 +194,16 @@ class test_Channel(Case): # first 1000 queues sorted by name. queue_name = 'unittest_queue' self.channel._new_queue(queue_name) - self.assertIn(queue_name, self.sqs_conn_mock.queues) + assert queue_name in self.sqs_conn_mock.queues q = self.sqs_conn_mock.get_queue(queue_name) - self.assertEqual(1, q.count()) - self.assertEqual('hello', q.read()) + assert 1 == q.count() + assert 'hello' == q.read() def test_delete(self): queue_name = 'new_unittest_queue' self.channel._new_queue(queue_name) self.channel._delete(queue_name) - self.assertNotIn(queue_name, self.channel._queue_cache) + assert queue_name not in self.channel._queue_cache def test_get_from_sqs(self): # Test getting a single message @@ -198,7 +211,7 @@ class test_Channel(Case): self.producer.publish(message) q = self.channel._new_queue(self.queue_name) results = q.get_messages() - self.assertEqual(len(results), 1) + assert len(results) == 1 # Now test getting many messages for i in range(3): @@ -206,14 +219,14 @@ class test_Channel(Case): self.producer.publish(message) results = q.get_messages(num_messages=3) - self.assertEqual(len(results), 3) + assert len(results) == 3 def test_get_with_empty_list(self): - with self.assertRaises(five.Empty): + with pytest.raises(five.Empty): self.channel._get(self.queue_name) def test_get_bulk_raises_empty(self): - with self.assertRaises(five.Empty): + with pytest.raises(five.Empty): self.channel._get_bulk(self.queue_name) def test_messages_to_python(self): @@ -243,27 +256,27 @@ class test_Channel(Case): ) # We got the same number of payloads back, right? - self.assertEqual(len(kombu_payloads), kombu_message_count) - self.assertEqual(len(json_payloads), json_message_count) + assert len(kombu_payloads) == kombu_message_count + assert len(json_payloads) == json_message_count # Make sure they're payload-style objects for p in kombu_payloads: - self.assertIn('properties', p) + assert 'properties' in p for p in json_payloads: - self.assertIn('properties', p) + assert 'properties' in p def test_put_and_get(self): message = 'my test message' self.producer.publish(message) results = self.queue(self.channel).get().payload - self.assertEqual(message, results) + assert message == results def test_put_and_get_bulk(self): # With QoS.prefetch_count = 0 message = 'my test message' self.producer.publish(message) results = self.channel._get_bulk(self.queue_name) - self.assertEqual(1, len(results)) + assert 1 == len(results) def test_puts_and_get_bulk(self): # Generate 8 messages @@ -280,19 +293,19 @@ class test_Channel(Case): # Count how many messages are retrieved the first time. Should # be 5 (message_count). results = self.channel._get_bulk(self.queue_name) - self.assertEqual(5, len(results)) + assert 5 == len(results) for i, r in enumerate(results): self.channel.qos.append(r, i) # Now, do the get again, the number of messages returned should be 1. results = self.channel._get_bulk(self.queue_name) - self.assertEqual(len(results), 1) + assert len(results) == 1 def test_drain_events_with_empty_list(self): def mock_can_consume(): return False self.channel.qos.can_consume = mock_can_consume - with self.assertRaises(five.Empty): + with pytest.raises(five.Empty): self.channel.drain_events() def test_drain_events_with_prefetch_5(self): @@ -312,9 +325,8 @@ class test_Channel(Case): self.channel.drain_events() # How many times was the SQSConnectionMock get_message method called? - self.assertEqual( - expected_get_message_count, - self.channel._queue_cache[self.queue_name]._get_message_calls) + assert (expected_get_message_count == + self.channel._queue_cache[self.queue_name]._get_message_calls) def test_drain_events_with_prefetch_none(self): # Generate 20 messages @@ -333,6 +345,5 @@ class test_Channel(Case): self.channel.drain_events() # How many times was the SQSConnectionMock get_message method called? - self.assertEqual( - expected_get_message_count, - self.channel._queue_cache[self.queue_name]._get_message_calls) + assert (expected_get_message_count == + self.channel._queue_cache[self.queue_name]._get_message_calls) diff --git a/kombu/tests/transport/test_base.py b/t/unit/transport/test_base.py index 77d2b144..bf1d551c 100644 --- a/kombu/tests/transport/test_base.py +++ b/t/unit/transport/test_base.py @@ -1,14 +1,16 @@ from __future__ import absolute_import, unicode_literals +import pytest + +from case import Mock + from kombu import Connection, Consumer, Exchange, Producer, Queue from kombu.five import text_t from kombu.message import Message from kombu.transport.base import StdChannel, Transport, Management -from kombu.tests.case import Case, Mock - -class test_StdChannel(Case): +class test_StdChannel: def setup(self): self.conn = Connection('memory://') @@ -20,25 +22,23 @@ class test_StdChannel(Case): q = Queue('foo', Exchange('foo')) print(self.channel.queues) cons = self.channel.Consumer(q) - self.assertIsInstance(cons, Consumer) - self.assertIs(cons.channel, self.channel) + assert isinstance(cons, Consumer) + assert cons.channel is self.channel def test_Producer(self): prod = self.channel.Producer() - self.assertIsInstance(prod, Producer) - self.assertIs(prod.channel, self.channel) + assert isinstance(prod, Producer) + assert prod.channel is self.channel def test_interface_get_bindings(self): - with self.assertRaises(NotImplementedError): + with pytest.raises(NotImplementedError): StdChannel().get_bindings() def test_interface_after_reply_message_received(self): - self.assertIsNone( - StdChannel().after_reply_message_received(Queue('foo')), - ) + assert StdChannel().after_reply_message_received(Queue('foo')) is None -class test_Message(Case): +class test_Message: def setup(self): self.conn = Connection('memory://') @@ -47,7 +47,7 @@ class test_Message(Case): def test_postencode(self): m = Message(self.channel, text_t('FOO'), postencode='ccyzz') - with self.assertRaises(LookupError): + with pytest.raises(LookupError): m._reraise_error() m.ack() @@ -57,7 +57,7 @@ class test_Message(Case): ack = self.channel.basic_ack = Mock() self.message.ack() - self.assertNotEqual(self.message._state, 'ACK') + assert self.message._state != 'ACK' ack.assert_not_called() def test_ack_missing_consumer_tag(self): @@ -88,7 +88,7 @@ class test_Message(Case): self.message.ack_log_error(logger, KeyError) ack.assert_called_with(multiple=False) logger.critical.assert_called() - self.assertIn("Couldn't ack", logger.critical.call_args[0][0]) + assert "Couldn't ack" in logger.critical.call_args[0][0] def test_reject_log_error_when_no_error(self): reject = self.message.reject = Mock() @@ -102,36 +102,36 @@ class test_Message(Case): self.message.reject_log_error(logger, KeyError) reject.assert_called_with(requeue=False) logger.critical.assert_called() - self.assertIn("Couldn't reject", logger.critical.call_args[0][0]) + assert "Couldn't reject" in logger.critical.call_args[0][0] -class test_interface(Case): +class test_interface: def test_establish_connection(self): - with self.assertRaises(NotImplementedError): + with pytest.raises(NotImplementedError): Transport(None).establish_connection() def test_close_connection(self): - with self.assertRaises(NotImplementedError): + with pytest.raises(NotImplementedError): Transport(None).close_connection(None) def test_create_channel(self): - with self.assertRaises(NotImplementedError): + with pytest.raises(NotImplementedError): Transport(None).create_channel(None) def test_close_channel(self): - with self.assertRaises(NotImplementedError): + with pytest.raises(NotImplementedError): Transport(None).close_channel(None) def test_drain_events(self): - with self.assertRaises(NotImplementedError): + with pytest.raises(NotImplementedError): Transport(None).drain_events(None) def test_heartbeat_check(self): Transport(None).heartbeat_check(Mock(name='connection')) def test_driver_version(self): - self.assertTrue(Transport(None).driver_version()) + assert Transport(None).driver_version() def test_register_with_event_loop(self): Transport(None).register_with_event_loop( @@ -144,12 +144,12 @@ class test_interface(Case): ) def test_manager(self): - self.assertTrue(Transport(None).manager) + assert Transport(None).manager -class test_Management(Case): +class test_Management: def test_get_bindings(self): m = Management(Mock(name='transport')) - with self.assertRaises(NotImplementedError): + with pytest.raises(NotImplementedError): m.get_bindings() diff --git a/kombu/tests/transport/test_consul.py b/t/unit/transport/test_consul.py index ca2bb4bc..43b2b165 100644 --- a/kombu/tests/transport/test_consul.py +++ b/t/unit/transport/test_consul.py @@ -1,62 +1,60 @@ from __future__ import absolute_import, unicode_literals -from kombu.five import Empty +import pytest -from kombu.transport.consul import Channel, Transport +from case import Mock, skip -from kombu.tests.case import Case, Mock, skip +from kombu.five import Empty +from kombu.transport.consul import Channel, Transport @skip.unless_module('consul') -class test_Consul(Case): +class test_Consul: def setup(self): self.connection = Mock() self.connection.client.transport_options = {} self.connection.client.port = 303 - self.consul = self.patch('consul.Consul').return_value + self.consul = self.patching('consul.Consul').return_value self.channel = Channel(connection=self.connection) def test_driver_version(self): - self.assertTrue(Transport(self.connection.client).driver_version()) + assert Transport(self.connection.client).driver_version() def test_failed_get(self): self.channel._acquire_lock = Mock(return_value=False) self.channel.client.kv.get.return_value = (1, None) - with self.assertRaises(Empty): + with pytest.raises(Empty): self.channel._get('empty')() def test_test_purge(self): self.channel._destroy_session = Mock(return_value=True) self.consul.kv.delete = Mock(return_value=True) - self.assertTrue(self.channel._purge('foo')) + assert self.channel._purge('foo') def test_variables(self): - self.assertEqual(self.channel.session_ttl, 30) - self.assertEqual(self.channel.timeout, '10s') + assert self.channel.session_ttl == 30 + assert self.channel.timeout == '10s' def test_lock_key(self): key = self.channel._lock_key('myqueue') - self.assertEqual(key, 'kombu/myqueue.lock') + assert key == 'kombu/myqueue.lock' def test_key_prefix(self): key = self.channel._key_prefix('myqueue') - self.assertEqual(key, 'kombu/myqueue') + assert key == 'kombu/myqueue' def test_get_or_create_session(self): queue = 'myqueue' session_id = '123456' self.consul.session.create.return_value = session_id - self.assertEqual( - self.channel._get_or_create_session(queue), - session_id, - ) + assert self.channel._get_or_create_session(queue) == session_id def test_create_delete_queue(self): queue = 'mynewqueue' self.consul.kv.put.return_value = True - self.assertTrue(self.channel._new_queue(queue)) + assert self.channel._new_queue(queue) self.consul.kv.delete.return_value = True self.channel._destroy_session = Mock() @@ -64,7 +62,7 @@ class test_Consul(Case): def test_size(self): self.consul.kv.get.return_value = [(1, {}), (2, {})] - self.assertEqual(self.channel._size('q'), 2) + assert self.channel._size('q') == 2 def test_get(self): self.channel._obtain_lock = Mock(return_value=True) @@ -76,8 +74,8 @@ class test_Consul(Case): self.consul.kv.delete.return_value = True - self.assertIsNotNone(self.channel._get('myqueue')) + assert self.channel._get('myqueue') is not None def test_put(self): self.consul.kv.put.return_value = True - self.assertIsNone(self.channel._put('myqueue', 'mydata')) + assert self.channel._put('myqueue', 'mydata') is None diff --git a/kombu/tests/transport/test_filesystem.py b/t/unit/transport/test_filesystem.py index b2ab368d..52a925f0 100644 --- a/kombu/tests/transport/test_filesystem.py +++ b/t/unit/transport/test_filesystem.py @@ -2,16 +2,17 @@ from __future__ import absolute_import, unicode_literals import tempfile -from kombu import Connection, Exchange, Queue, Consumer, Producer - +from case import skip from case.skip import SkipTest -from kombu.tests.case import Case, skip + +from kombu import Connection, Exchange, Queue, Consumer, Producer @skip.if_win32() -class test_FilesystemTransport(Case): +class test_FilesystemTransport: def setup(self): + self.channels = set() try: data_folder_in = tempfile.mkdtemp() data_folder_out = tempfile.mkdtemp() @@ -22,11 +23,13 @@ class test_FilesystemTransport(Case): 'data_folder_in': data_folder_in, 'data_folder_out': data_folder_out, }) + self.channels.add(self.c.default_channel) self.p = Connection(transport='filesystem', transport_options={ 'data_folder_in': data_folder_out, 'data_folder_out': data_folder_in, }) + self.channels.add(self.p.default_channel) self.e = Exchange('test_transport_filesystem') self.q = Queue('test_transport_filesystem', exchange=self.e, @@ -35,9 +38,26 @@ class test_FilesystemTransport(Case): exchange=self.e, routing_key='test_transport_filesystem2') + def teardown(self): + # make sure we don't attempt to restore messages at shutdown. + for channel in self.channels: + try: + channel._qos._dirty.clear() + except AttributeError: + pass + try: + channel._qos._delivered.clear() + except AttributeError: + pass + + def _add_channel(self, channel): + self.channels.add(channel) + return channel + def test_produce_consume_noack(self): - producer = Producer(self.p.channel(), self.e) - consumer = Consumer(self.c.channel(), self.q, no_ack=True) + producer = Producer(self._add_channel(self.p.channel()), self.e) + consumer = Consumer(self._add_channel(self.c.channel()), self.q, + no_ack=True) for i in range(10): producer.publish({'foo': i}, @@ -56,11 +76,11 @@ class test_FilesystemTransport(Case): break self.c.drain_events() - self.assertEqual(len(_received), 10) + assert len(_received) == 10 def test_produce_consume(self): - producer_channel = self.p.channel() - consumer_channel = self.c.channel() + producer_channel = self._add_channel(self.p.channel()) + consumer_channel = self._add_channel(self.c.channel()) producer = Producer(producer_channel, self.e) consumer1 = Consumer(consumer_channel, self.q) consumer2 = Consumer(consumer_channel, self.q2) @@ -95,28 +115,28 @@ class test_FilesystemTransport(Case): break self.c.drain_events() - self.assertEqual(len(_received1) + len(_received2), 20) + assert len(_received1) + len(_received2) == 20 # compression producer.publish({'compressed': True}, routing_key='test_transport_filesystem', compression='zlib') m = self.q(consumer_channel).get() - self.assertDictEqual(m.payload, {'compressed': True}) + assert m.payload == {'compressed': True} # queue.delete for i in range(10): producer.publish({'foo': i}, routing_key='test_transport_filesystem') - self.assertTrue(self.q(consumer_channel).get()) + assert self.q(consumer_channel).get() self.q(consumer_channel).delete() self.q(consumer_channel).declare() - self.assertIsNone(self.q(consumer_channel).get()) + assert self.q(consumer_channel).get() is None # queue.purge for i in range(10): producer.publish({'foo': i}, routing_key='test_transport_filesystem2') - self.assertTrue(self.q2(consumer_channel).get()) + assert self.q2(consumer_channel).get() self.q2(consumer_channel).purge() - self.assertIsNone(self.q2(consumer_channel).get()) + assert self.q2(consumer_channel).get() is None diff --git a/kombu/tests/transport/test_librabbitmq.py b/t/unit/transport/test_librabbitmq.py index bbfbe961..26dbfbcd 100644 --- a/kombu/tests/transport/test_librabbitmq.py +++ b/t/unit/transport/test_librabbitmq.py @@ -1,5 +1,9 @@ from __future__ import absolute_import, unicode_literals +import pytest + +from case import Mock, patch, skip + try: import librabbitmq except ImportError: @@ -7,11 +11,9 @@ except ImportError: else: from kombu.transport import librabbitmq # noqa -from kombu.tests.case import Case, Mock, patch, skip - @skip.unless_module('librabbitmq') -class lrmqCase(Case): +class lrmqCase: pass @@ -22,9 +24,9 @@ class test_Message(lrmqCase): message = librabbitmq.Message( chan, {'prop': 42}, {'delivery_tag': 337}, 'body', ) - self.assertEqual(message.body, 'body') - self.assertEqual(message.delivery_tag, 337) - self.assertEqual(message.properties['prop'], 42) + assert message.body == 'body' + assert message.delivery_tag == 337 + assert message.properties['prop'] == 42 class test_Channel(lrmqCase): @@ -32,7 +34,7 @@ class test_Channel(lrmqCase): def test_prepare_message(self): conn = Mock(name='connection') chan = librabbitmq.Channel(conn, 1) - self.assertTrue(chan) + assert chan body = 'the quick brown fox...' properties = {'name': 'Elaine M.'} @@ -45,32 +47,31 @@ class test_Channel(lrmqCase): headers={'H': 2}, ) - self.assertEqual(props2['name'], 'Elaine M.') - self.assertEqual(props2['priority'], 999) - self.assertEqual(props2['content_type'], 'ctype') - self.assertEqual(props2['content_encoding'], 'cenc') - self.assertEqual(props2['headers'], {'H': 2}) - self.assertEqual(body2, body) + assert props2['name'] == 'Elaine M.' + assert props2['priority'] == 999 + assert props2['content_type'] == 'ctype' + assert props2['content_encoding'] == 'cenc' + assert props2['headers'] == {'H': 2} + assert body2 == body body3, props3 = chan.prepare_message(body, priority=777) - self.assertEqual(props3['priority'], 777) - self.assertEqual(body3, body) + assert props3['priority'] == 777 + assert body3 == body class test_Transport(lrmqCase): def setup(self): - super(test_Transport, self).setup() self.client = Mock(name='client') self.T = librabbitmq.Transport(self.client) def test_driver_version(self): - self.assertTrue(self.T.driver_version()) + assert self.T.driver_version() def test_create_channel(self): conn = Mock(name='connection') chan = self.T.create_channel(conn) - self.assertTrue(chan) + assert chan conn.channel.assert_called_with() def test_drain_events(self): @@ -80,7 +81,7 @@ class test_Transport(lrmqCase): def test_establish_connection_SSL_not_supported(self): self.client.ssl = True - with self.assertRaises(NotImplementedError): + with pytest.raises(NotImplementedError): self.T.establish_connection() def test_establish_connection(self): @@ -90,18 +91,15 @@ class test_Transport(lrmqCase): self.T.client.transport_options = {} conn = self.T.establish_connection() - self.assertEqual( - self.T.client.port, - self.T.default_connection_params['port'], - ) - self.assertEqual(conn.client, self.T.client) - self.assertEqual(self.T.client.drain_events, conn.drain_events) + assert self.T.client.port == self.T.default_connection_params['port'] + assert conn.client == self.T.client + assert self.T.client.drain_events == conn.drain_events def test_collect__no_conn(self): self.T.client.drain_events = 1234 self.T._collect(None) - self.assertIsNone(self.client.drain_events) - self.assertIsNone(self.T.client) + assert self.client.drain_events is None + assert self.T.client is None def test_collect__with_conn(self): self.T.client.drain_events = 1234 @@ -114,12 +112,12 @@ class test_Transport(lrmqCase): with patch('os.close') as close: self.T._collect(conn) close.assert_called_with(conn.fileno()) - self.assertFalse(conn.channels) - self.assertFalse(conn.callbacks) + assert not conn.channels + assert not conn.callbacks for chan in chans.values(): - self.assertIsNone(chan.connection) - self.assertIsNone(self.client.drain_events) - self.assertIsNone(self.T.client) + assert chan.connection is None + assert self.client.drain_events is None + assert self.T.client is None with patch('os.close') as close: self.T.client = self.client @@ -138,11 +136,11 @@ class test_Transport(lrmqCase): def test_verify_connection(self): conn = Mock(name='connection') conn.connected = True - self.assertTrue(self.T.verify_connection(conn)) + assert self.T.verify_connection(conn) def test_close_connection(self): conn = Mock(name='connection') self.client.drain_events = 1234 self.T.close_connection(conn) - self.assertIsNone(self.client.drain_events) + assert self.client.drain_events is None conn.close.assert_called_with() diff --git a/kombu/tests/transport/test_memory.py b/t/unit/transport/test_memory.py index 1ca83b9c..d4dc3390 100644 --- a/kombu/tests/transport/test_memory.py +++ b/t/unit/transport/test_memory.py @@ -1,13 +1,12 @@ from __future__ import absolute_import, unicode_literals +import pytest import socket from kombu import Connection, Exchange, Queue, Consumer, Producer -from kombu.tests.case import Case - -class test_MemoryTransport(Case): +class test_MemoryTransport: def setup(self): self.c = Connection(transport='memory') @@ -25,7 +24,7 @@ class test_MemoryTransport(Case): exchange=self.fanout) def test_driver_version(self): - self.assertTrue(self.c.transport.driver_version()) + assert self.c.transport.driver_version() def test_produce_consume_noack(self): channel = self.c.channel() @@ -48,7 +47,7 @@ class test_MemoryTransport(Case): break self.c.drain_events() - self.assertEqual(len(_received), 10) + assert len(_received) == 10 def test_produce_consume_fanout(self): producer = self.c.Producer() @@ -60,10 +59,10 @@ class test_MemoryTransport(Case): exchange=self.fanout, ) - self.assertEqual(self.q3(self.c).get().payload, {'hello': 'world'}) - self.assertEqual(self.q4(self.c).get().payload, {'hello': 'world'}) - self.assertIsNone(self.q3(self.c).get()) - self.assertIsNone(self.q4(self.c).get()) + assert self.q3(self.c).get().payload == {'hello': 'world'} + assert self.q4(self.c).get().payload == {'hello': 'world'} + assert self.q3(self.c).get() is None + assert self.q4(self.c).get() is None def test_produce_consume(self): channel = self.c.channel() @@ -99,38 +98,38 @@ class test_MemoryTransport(Case): break self.c.drain_events() - self.assertEqual(len(_received1) + len(_received2), 20) + assert len(_received1) + len(_received2) == 20 # compression producer.publish({'compressed': True}, routing_key='test_transport_memory', compression='zlib') m = self.q(channel).get() - self.assertDictEqual(m.payload, {'compressed': True}) + assert m.payload == {'compressed': True} # queue.delete for i in range(10): producer.publish({'foo': i}, routing_key='test_transport_memory') - self.assertTrue(self.q(channel).get()) + assert self.q(channel).get() self.q(channel).delete() self.q(channel).declare() - self.assertIsNone(self.q(channel).get()) + assert self.q(channel).get() is None # queue.purge for i in range(10): producer.publish({'foo': i}, routing_key='test_transport_memory2') - self.assertTrue(self.q2(channel).get()) + assert self.q2(channel).get() self.q2(channel).purge() - self.assertIsNone(self.q2(channel).get()) + assert self.q2(channel).get() is None def test_drain_events(self): - with self.assertRaises(socket.timeout): + with pytest.raises(socket.timeout): self.c.drain_events(timeout=0.1) c1 = self.c.channel() c2 = self.c.channel() - with self.assertRaises(socket.timeout): + with pytest.raises(socket.timeout): self.c.drain_events(timeout=0.1) del(c1) # so pyflakes doesn't complain. @@ -162,5 +161,5 @@ class test_MemoryTransport(Case): chan.queues.clear() x = chan._queue_for('foo') - self.assertTrue(x) - self.assertIs(chan._queue_for('foo'), x) + assert x + assert chan._queue_for('foo') is x diff --git a/kombu/tests/transport/test_mongodb.py b/t/unit/transport/test_mongodb.py index cbe8301a..b21be7f3 100644 --- a/kombu/tests/transport/test_mongodb.py +++ b/t/unit/transport/test_mongodb.py @@ -1,10 +1,12 @@ from __future__ import absolute_import, unicode_literals import datetime +import pytest + +from case import MagicMock, call, patch, skip from kombu import Connection from kombu.five import Empty -from kombu.tests.case import Case, MagicMock, call, patch, skip def _create_mock_connection(url='', **kwargs): @@ -44,7 +46,7 @@ def _create_mock_connection(url='', **kwargs): @skip.unless_module('pymongo') -class test_mongodb_uri_parsing(Case): +class test_mongodb_uri_parsing: def test_defaults(self): url = 'mongodb://' @@ -53,22 +55,22 @@ class test_mongodb_uri_parsing(Case): hostname, dbname, options = channel._parse_uri() - self.assertEqual(dbname, 'kombu_default') - self.assertEqual(hostname, 'mongodb://127.0.0.1') + assert dbname == 'kombu_default' + assert hostname == 'mongodb://127.0.0.1' def test_custom_host(self): url = 'mongodb://localhost' channel = _create_mock_connection(url).default_channel hostname, dbname, options = channel._parse_uri() - self.assertEqual(dbname, 'kombu_default') + assert dbname == 'kombu_default' def test_custom_database(self): url = 'mongodb://localhost/dbname' channel = _create_mock_connection(url).default_channel hostname, dbname, options = channel._parse_uri() - self.assertEqual(dbname, 'dbname') + assert dbname == 'dbname' def test_custom_credentials(self): url = 'mongodb://localhost/dbname' @@ -76,11 +78,11 @@ class test_mongodb_uri_parsing(Case): url, userid='foo', password='bar').default_channel hostname, dbname, options = channel._parse_uri() - self.assertEqual(hostname, 'mongodb://foo:bar@localhost/dbname') - self.assertEqual(dbname, 'dbname') + assert hostname == 'mongodb://foo:bar@localhost/dbname' + assert dbname == 'dbname' -class BaseMongoDBChannelCase(Case): +class BaseMongoDBChannelCase: def _get_method(self, cname, mname): collection = getattr(self.channel, cname) @@ -104,7 +106,7 @@ class BaseMongoDBChannelCase(Case): self.channel._queue_bind('fanout_exchange', 'foo', '*', queue) - self.assertIn(queue, self.channel._broadcast_cursors) + assert queue in self.channel._broadcast_cursors def get_broadcast(self, queue): return self.channel._broadcast_cursors[queue] @@ -162,10 +164,10 @@ class test_mongodb_channel(BaseMongoDBChannelCase): ], ) - self.assertDictEqual(event, {'some': 'data'}) + assert event == {'some': 'data'} self.set_operation_return_value('messages', 'find_and_modify', None) - with self.assertRaises(Empty): + with pytest.raises(Empty): self.channel._get('foobar') def test_get_fanout(self): @@ -175,9 +177,9 @@ class test_mongodb_channel(BaseMongoDBChannelCase): event = self.channel._get('foobar') self.assert_collection_accessed('messages.broadcast') - self.assertDictEqual(event, {'some': 'data'}) + assert event == {'some': 'data'} - with self.assertRaises(Empty): + with pytest.raises(Empty): self.channel._get('foobar') def test_put(self): @@ -209,7 +211,7 @@ class test_mongodb_channel(BaseMongoDBChannelCase): 'messages', 'find', {'queue': 'foobar'}, ) - self.assertEqual(result, 77) + assert result == 77 def test_size_fanout(self): self.declare_droadcast_queue('foobar') @@ -220,7 +222,7 @@ class test_mongodb_channel(BaseMongoDBChannelCase): result = self.channel._size('foobar') - self.assertEqual(result, 77) + assert result == 77 def test_purge(self): self.set_operation_return_value('messages', 'find.count', 77) @@ -231,7 +233,7 @@ class test_mongodb_channel(BaseMongoDBChannelCase): 'messages', 'remove', {'queue': 'foobar'}, ) - self.assertEqual(result, 77) + assert result == 77 def test_purge_fanout(self): self.declare_droadcast_queue('foobar') @@ -244,7 +246,7 @@ class test_mongodb_channel(BaseMongoDBChannelCase): cursor.purge.assert_any_call() - self.assertEqual(result, 77) + assert result == 77 def test_get_table(self): state_table = [('foo', '*', 'foo')] @@ -266,9 +268,7 @@ class test_mongodb_channel(BaseMongoDBChannelCase): 'routing', 'find', {'exchange': 'test_exchange'}, ) - self.assertSetEqual( - set(result), frozenset(state_table) | frozenset(stored_table), - ) + assert set(result) == frozenset(state_table) | frozenset(stored_table) def test_queue_bind(self): self.channel._queue_bind('test_exchange', 'foo', '*', 'foo') @@ -299,8 +299,8 @@ class test_mongodb_channel(BaseMongoDBChannelCase): cursor.close.assert_any_call() - self.assertNotIn('foobar', self.channel._broadcast_cursors) - self.assertNotIn('foobar', self.channel._fanout_queues) + assert 'foobar' not in self.channel._broadcast_cursors + assert 'foobar' not in self.channel._fanout_queues # Tests for channel internals @@ -468,14 +468,14 @@ class test_mongodb_channel_ttl(BaseMongoDBChannelCase): self.channel.client.assert_not_called() - self.assertEqual(result, self.expire_at) + assert result == self.expire_at self.set_operation_return_value('queues', 'find_one', { '_id': 'docId', 'options': {'arguments': {'x-expires': 777}}, }) result = self.channel._get_expire('foobar', 'x-expires') - self.assertEqual(result, self.expire_at) + assert result == self.expire_at def test_update_queues_expire(self): self.set_operation_return_value('queues', 'find_one', { @@ -518,4 +518,4 @@ class test_mongodb_channel_calc_queue_size(BaseMongoDBChannelCase): self.assert_operation_has_calls('messages', 'find', []) - self.assertEqual(result, 0) + assert result == 0 diff --git a/kombu/tests/transport/test_pyamqp.py b/t/unit/transport/test_pyamqp.py index 5bb8975f..1c8812b8 100644 --- a/kombu/tests/transport/test_pyamqp.py +++ b/t/unit/transport/test_pyamqp.py @@ -4,16 +4,11 @@ import sys from itertools import count -try: - import amqp # noqa -except ImportError: - pyamqp = None # noqa -else: - from kombu.transport import pyamqp +from case import Mock, mock, patch + from kombu import Connection from kombu.five import nextfun - -from kombu.tests.case import Case, Mock, mock, patch +from kombu.transport import pyamqp class MockConnection(dict): @@ -25,7 +20,7 @@ class MockConnection(dict): pass -class test_Channel(Case): +class test_Channel: def setup(self): @@ -47,39 +42,39 @@ class test_Channel(Case): self.channel = Channel(self.conn, 0) def test_init(self): - self.assertFalse(self.channel.no_ack_consumers) + assert not self.channel.no_ack_consumers def test_prepare_message(self): - self.assertTrue(self.channel.prepare_message( + assert self.channel.prepare_message( 'foobar', 10, 'application/data', 'utf-8', properties={}, - )) + ) def test_message_to_python(self): message = Mock() message.headers = {} message.properties = {} - self.assertTrue(self.channel.message_to_python(message)) + assert self.channel.message_to_python(message) def test_close_resolves_connection_cycle(self): - self.assertIsNotNone(self.channel.connection) + assert self.channel.connection is not None self.channel.close() - self.assertIsNone(self.channel.connection) + assert self.channel.connection is None def test_basic_consume_registers_ack_status(self): self.channel.wait_returns = 'my-consumer-tag' self.channel.basic_consume('foo', no_ack=True) - self.assertIn('my-consumer-tag', self.channel.no_ack_consumers) + assert 'my-consumer-tag' in self.channel.no_ack_consumers self.channel.wait_returns = 'other-consumer-tag' self.channel.basic_consume('bar', no_ack=False) - self.assertNotIn('other-consumer-tag', self.channel.no_ack_consumers) + assert 'other-consumer-tag' not in self.channel.no_ack_consumers self.channel.basic_cancel('my-consumer-tag') - self.assertNotIn('my-consumer-tag', self.channel.no_ack_consumers) + assert 'my-consumer-tag' not in self.channel.no_ack_consumers -class test_Transport(Case): +class test_Transport: def setup(self): self.connection = Connection('pyamqp://') @@ -91,7 +86,7 @@ class test_Transport(Case): connection.channel.assert_called_with() def test_driver_version(self): - self.assertTrue(self.transport.driver_version()) + assert self.transport.driver_version() def test_drain_events(self): connection = Mock() @@ -111,18 +106,18 @@ class test_Transport(Case): self.transport.Connection = Conn self.transport.client.hostname = 'localhost' conn1 = self.transport.establish_connection() - self.assertEqual(conn1.host, '127.0.0.1:5672') + assert conn1.host == '127.0.0.1:5672' self.transport.client.hostname = 'example.com' conn2 = self.transport.establish_connection() - self.assertEqual(conn2.host, 'example.com:5672') + assert conn2.host == 'example.com:5672' def test_close_connection(self): connection = Mock() connection.client = Mock() self.transport.close_connection(connection) - self.assertIsNone(connection.client) + assert connection.client is None connection.close.assert_called_with() @mock.mask_modules('ssl') @@ -130,13 +125,13 @@ class test_Transport(Case): pm = sys.modules.pop('amqp.connection') try: from amqp.connection import SSLError - self.assertEqual(SSLError.__module__, 'amqp.connection') + assert SSLError.__module__ == 'amqp.connection' finally: if pm is not None: sys.modules['amqp.connection'] = pm -class test_pyamqp(Case): +class test_pyamqp: def test_default_port(self): @@ -144,8 +139,7 @@ class test_pyamqp(Case): Connection = MockConnection c = Connection(port=None, transport=Transport).connect() - self.assertEqual(c['host'], - '127.0.0.1:%s' % (Transport.default_port,)) + assert c['host'] == '127.0.0.1:%s' % (Transport.default_port,) def test_custom_port(self): @@ -153,7 +147,7 @@ class test_pyamqp(Case): Connection = MockConnection c = Connection(port=1337, transport=Transport).connect() - self.assertEqual(c['host'], '127.0.0.1:1337') + assert c['host'] == '127.0.0.1:1337' def test_register_with_event_loop(self): t = pyamqp.Transport(Mock()) diff --git a/kombu/tests/transport/test_redis.py b/t/unit/transport/test_redis.py index c660badf..fd47317f 100644 --- a/kombu/tests/transport/test_redis.py +++ b/t/unit/transport/test_redis.py @@ -1,11 +1,14 @@ from __future__ import absolute_import, unicode_literals +import pytest import socket import types from collections import defaultdict from itertools import count +from case import ANY, ContextMock, Mock, call, mock, skip, patch + from kombu import Connection, Exchange, Queue, Consumer, Producer from kombu.exceptions import InconsistencyError, VersionMismatch from kombu.five import Empty, Queue as _Queue, bytes_if_py2 @@ -13,10 +16,6 @@ from kombu.transport import virtual from kombu.utils import eventio # patch poll from kombu.utils.json import dumps -from kombu.tests.case import ( - Case, ContextMock, Mock, call, mock, skip, patch, ANY, -) - class _poll(eventio._select): @@ -234,7 +233,7 @@ class Transport(redis.Transport): @skip.unless_module('redis') -class test_Channel(Case): +class test_Channel: def setup(self): self.connection = self.create_connection() @@ -253,7 +252,7 @@ class test_Channel(Case): msg = chan.prepare_message('quick brown fox') chan.basic_publish(msg, n, n) payload = chan._get(n) - self.assertTrue(payload) + assert payload pymsg = chan.message_to_python(payload) return pymsg.delivery_tag @@ -261,11 +260,11 @@ class test_Channel(Case): seen = set() for i in range(100): tag = self._get_one_delivery_tag() - self.assertNotIn(tag, seen) + assert tag not in seen seen.add(tag) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): int(tag) - self.assertEqual(len(tag), 36) + assert len(tag) == 36 def test_disable_ack_emulation(self): conn = Connection(transport=Transport, transport_options={ @@ -273,8 +272,8 @@ class test_Channel(Case): }) chan = conn.channel() - self.assertFalse(chan.ack_emulation) - self.assertEqual(chan.QoS, virtual.QoS) + assert not chan.ack_emulation + assert chan.QoS == virtual.QoS def test_redis_ping_raises(self): pool = Mock(name='pool') @@ -295,13 +294,13 @@ class test_Channel(Case): conn = Connection(transport=XTransport) client.ping.side_effect = RuntimeError() - with self.assertRaises(RuntimeError): + with pytest.raises(RuntimeError): conn.channel() pool.disconnect.assert_called_with() pool.disconnect.reset_mock() pool_at_init = [None] - with self.assertRaises(RuntimeError): + with pytest.raises(RuntimeError): conn.channel() pool.disconnect.assert_not_called() @@ -314,10 +313,8 @@ class test_Channel(Case): pool.disconnect.assert_called_with() def test_next_delivery_tag(self): - self.assertNotEqual( - self.channel._next_delivery_tag(), - self.channel._next_delivery_tag(), - ) + assert (self.channel._next_delivery_tag() != + self.channel._next_delivery_tag()) def test_do_restore_message(self): client = Mock(name='client') @@ -399,21 +396,21 @@ class test_Channel(Case): qos._vrestore_count = 1 qos.restore_visible() client.zrevrangebyscore.assert_not_called() - self.assertEqual(qos._vrestore_count, 2) + assert qos._vrestore_count == 2 qos._vrestore_count = 0 qos.restore_visible() restore.assert_has_calls([ call(1, client), call(2, client), call(3, client), ]) - self.assertEqual(qos._vrestore_count, 1) + assert qos._vrestore_count == 1 qos._vrestore_count = 0 restore.reset_mock() client.zrevrangebyscore.return_value = [] qos.restore_visible() restore.assert_not_called() - self.assertEqual(qos._vrestore_count, 1) + assert qos._vrestore_count == 1 qos._vrestore_count = 0 client.setnx.side_effect = redis.MutexHeld() @@ -424,14 +421,13 @@ class test_Channel(Case): self.channel.queue_declare(queue='txconfanq') self.channel.queue_bind(queue='txconfanq', exchange='txconfan') - self.assertIn('txconfanq', self.channel._fanout_queues) + assert 'txconfanq' in self.channel._fanout_queues self.channel.basic_consume('txconfanq', False, None, 1) - self.assertIn('txconfanq', self.channel.active_fanout_queues) - self.assertEqual(self.channel._fanout_to_queue.get('txconfan'), - 'txconfanq') + assert 'txconfanq' in self.channel.active_fanout_queues + assert self.channel._fanout_to_queue.get('txconfan') == 'txconfanq' def test_basic_cancel_unknown_delivery_tag(self): - self.assertIsNone(self.channel.basic_cancel('txaseqwewq')) + assert self.channel.basic_cancel('txaseqwewq') is None def test_subscribe_no_queues(self): self.channel.subclient = Mock() @@ -448,7 +444,7 @@ class test_Channel(Case): self.channel._subscribe() self.channel.subclient.psubscribe.assert_called() s_args, _ = self.channel.subclient.psubscribe.call_args - self.assertItemsEqual(s_args[0], ['/{db}.a', '/{db}.b']) + assert sorted(s_args[0]) == ['/{db}.a', '/{db}.b'] self.channel.subclient.connection._sock = None self.channel._subscribe() @@ -458,38 +454,34 @@ class test_Channel(Case): s = self.channel.subclient s.subscribed = True self.channel._handle_message(s, ['unsubscribe', 'a', 0]) - self.assertFalse(s.subscribed) + assert not s.subscribed def test_handle_pmessage_message(self): - self.assertDictEqual( - self.channel._handle_message( - self.channel.subclient, - ['pmessage', 'pattern', 'channel', 'data'], - ), - { - 'type': 'pmessage', - 'pattern': 'pattern', - 'channel': 'channel', - 'data': 'data', - }, + res = self.channel._handle_message( + self.channel.subclient, + ['pmessage', 'pattern', 'channel', 'data'], ) + assert res == { + 'type': 'pmessage', + 'pattern': 'pattern', + 'channel': 'channel', + 'data': 'data', + } def test_handle_message(self): - self.assertDictEqual( - self.channel._handle_message( - self.channel.subclient, - ['type', 'channel', 'data'], - ), - { - 'type': 'type', - 'pattern': None, - 'channel': 'channel', - 'data': 'data', - }, + res = self.channel._handle_message( + self.channel.subclient, + ['type', 'channel', 'data'], ) + assert res == { + 'type': 'type', + 'pattern': None, + 'channel': 'channel', + 'data': 'data', + } def test_brpop_start_but_no_queues(self): - self.assertIsNone(self.channel._brpop_start()) + assert self.channel._brpop_start() is None def test_receive(self): s = self.channel.subclient = Mock() @@ -497,37 +489,37 @@ class test_Channel(Case): s.parse_response.return_value = ['message', 'a', dumps({'hello': 'world'})] payload, queue = self.channel._receive() - self.assertDictEqual(payload, {'hello': 'world'}) - self.assertEqual(queue, 'b') + assert payload == {'hello': 'world'} + assert queue == 'b' def test_receive_raises_for_connection_error(self): self.channel._in_listen = True s = self.channel.subclient = Mock() s.parse_response.side_effect = KeyError('foo') - with self.assertRaises(KeyError): + with pytest.raises(KeyError): self.channel._receive() - self.assertFalse(self.channel._in_listen) + assert not self.channel._in_listen def test_receive_empty(self): s = self.channel.subclient = Mock() s.parse_response.return_value = None - with self.assertRaises(redis.Empty): + with pytest.raises(redis.Empty): self.channel._receive() def test_receive_different_message_Type(self): s = self.channel.subclient = Mock() s.parse_response.return_value = ['message', '/foo/', 0, 'data'] - with self.assertRaises(redis.Empty): + with pytest.raises(redis.Empty): self.channel._receive() def test_brpop_read_raises(self): c = self.channel.client = Mock() c.parse_response.side_effect = KeyError('foo') - with self.assertRaises(KeyError): + with pytest.raises(KeyError): self.channel._brpop_read() c.connection.disconnect.assert_called_with() @@ -536,7 +528,7 @@ class test_Channel(Case): c = self.channel.client = Mock() c.parse_response.return_value = None - with self.assertRaises(redis.Empty): + with pytest.raises(redis.Empty): self.channel._brpop_read() def test_poll_error(self): @@ -547,7 +539,7 @@ class test_Channel(Case): c.parse_response.assert_called_with(c.connection, 'BRPOP') c.parse_response.side_effect = KeyError('foo') - with self.assertRaises(KeyError): + with pytest.raises(KeyError): self.channel._poll_error('BRPOP') def test_poll_error_on_type_LISTEN(self): @@ -558,7 +550,7 @@ class test_Channel(Case): c.parse_response.assert_called_with() c.parse_response.side_effect = KeyError('foo') - with self.assertRaises(KeyError): + with pytest.raises(KeyError): self.channel._poll_error('LISTEN') def test_put_fanout(self): @@ -609,14 +601,14 @@ class test_Channel(Case): self.channel._create_client.return_value = self.channel.client exists = self.channel.client.exists = Mock() exists.return_value = True - self.assertTrue(self.channel._has_queue('foo')) + assert self.channel._has_queue('foo') exists.assert_has_calls([ call(self.channel._q_for_pri('foo', pri)) for pri in redis.PRIORITY_STEPS ]) exists.return_value = False - self.assertFalse(self.channel._has_queue('foo')) + assert not self.channel._has_queue('foo') def test_close_when_closed(self): self.channel.closed = True @@ -642,47 +634,44 @@ class test_Channel(Case): def test_invalid_database_raises_ValueError(self): - with self.assertRaises(ValueError): + with pytest.raises(ValueError): self.channel.connection.client.virtual_host = 'dwqeq' self.channel._connparams() def test_connparams_allows_slash_in_db(self): self.channel.connection.client.virtual_host = '/123' - self.assertEqual(self.channel._connparams()['db'], 123) + assert self.channel._connparams()['db'] == 123 def test_connparams_db_can_be_int(self): self.channel.connection.client.virtual_host = 124 - self.assertEqual(self.channel._connparams()['db'], 124) + assert self.channel._connparams()['db'] == 124 def test_new_queue_with_auto_delete(self): redis.Channel._new_queue(self.channel, 'george', auto_delete=False) - self.assertNotIn('george', self.channel.auto_delete_queues) + assert 'george' not in self.channel.auto_delete_queues redis.Channel._new_queue(self.channel, 'elaine', auto_delete=True) - self.assertIn('elaine', self.channel.auto_delete_queues) + assert 'elaine' in self.channel.auto_delete_queues def test_connparams_regular_hostname(self): self.channel.connection.client.hostname = 'george.vandelay.com' - self.assertEqual( - self.channel._connparams()['host'], - 'george.vandelay.com', - ) + assert self.channel._connparams()['host'] == 'george.vandelay.com' def test_rotate_cycle_ValueError(self): cycle = self.channel._queue_cycle cycle.update(['kramer', 'jerry']) cycle.rotate('kramer') - self.assertEqual(cycle.items, ['jerry', 'kramer']) + assert cycle.items, ['jerry' == 'kramer'] cycle.rotate('elaine') def test_get_client(self): import redis as R KombuRedis = redis.Channel._get_client(self.channel) - self.assertTrue(KombuRedis) + assert KombuRedis Rv = getattr(R, 'VERSION', None) try: R.VERSION = (2, 4, 0) - with self.assertRaises(VersionMismatch): + with pytest.raises(VersionMismatch): redis.Channel._get_client(self.channel) finally: if Rv is not None: @@ -690,8 +679,7 @@ class test_Channel(Case): def test_get_response_error(self): from redis.exceptions import ResponseError - self.assertIs(redis.Channel._get_response_error(self.channel), - ResponseError) + assert redis.Channel._get_response_error(self.channel) is ResponseError def test_avail_client(self): self.channel._pool = Mock() @@ -745,12 +733,10 @@ class test_Channel(Case): cb.assert_called_with(ret[0]) def test_transport_get_errors(self): - self.assertTrue(redis.Transport._get_errors(self.connection.transport)) + assert redis.Transport._get_errors(self.connection.transport) def test_transport_driver_version(self): - self.assertTrue( - redis.Transport.driver_version(self.connection.transport), - ) + assert redis.Transport.driver_version(self.connection.transport) def test_transport_get_errors_when_InvalidData_used(self): from redis import exceptions @@ -764,8 +750,8 @@ class test_Channel(Case): exceptions.DataError = None try: errors = redis.Transport._get_errors(self.connection.transport) - self.assertTrue(errors) - self.assertIn(ID, errors[1]) + assert errors + assert ID in errors[1] finally: if DataError is not None: exceptions.DataError = DataError @@ -779,43 +765,44 @@ class test_Channel(Case): # Everything is fine, there is a list of queues. channel.client.sadd(key, 'celery\x06\x16\x06\x16celery') - self.assertListEqual(channel.get_table('celery'), - [('celery', '', 'celery')]) + assert channel.get_table('celery') == [ + ('celery', '', 'celery'), + ] # ... then for some reason, the _kombu.binding.celery key gets lost channel.client.srem(key) # which raises a channel error so that the consumer/publisher # can recover by redeclaring the required entities. - with self.assertRaises(InconsistencyError): + with pytest.raises(InconsistencyError): self.channel.get_table('celery') def test_socket_connection(self): with patch('kombu.transport.redis.Channel._create_client'): with Connection('redis+socket:///tmp/redis.sock') as conn: connparams = conn.default_channel._connparams() - self.assertTrue(issubclass( + assert issubclass( connparams['connection_class'], redis.redis.UnixDomainSocketConnection, - )) - self.assertEqual(connparams['path'], '/tmp/redis.sock') + ) + assert connparams['path'] == '/tmp/redis.sock' def test_ssl_argument__dict(self): with patch('kombu.transport.redis.Channel._create_client'): with Connection('redis://', ssl={'ca_cert': '/foo'}) as conn: connparams = conn.default_channel._connparams() - self.assertTrue(connparams['ssl']) - self.assertEqual(connparams['ca_cert'], '/foo') + assert connparams['ssl'] + assert connparams['ca_cert'] == '/foo' def test_ssl_argument__bool(self): with patch('kombu.transport.redis.Channel._create_client'): with Connection('redis://', ssl=True) as conn: connparams = conn.default_channel._connparams() - self.assertTrue(connparams['ssl']) + assert connparams['ssl'] @skip.unless_module('redis') -class test_Redis(Case): +class test_Redis: def setup(self): self.connection = Connection(transport=Transport) @@ -832,11 +819,10 @@ class test_Redis(Case): producer.publish({'hello': 'world'}) - self.assertDictEqual(self.queue(channel).get().payload, - {'hello': 'world'}) - self.assertIsNone(self.queue(channel).get()) - self.assertIsNone(self.queue(channel).get()) - self.assertIsNone(self.queue(channel).get()) + assert self.queue(channel).get().payload == {'hello': 'world'} + assert self.queue(channel).get() is None + assert self.queue(channel).get() is None + assert self.queue(channel).get() is None def test_publish__consume(self): connection = Connection(transport=Transport) @@ -854,11 +840,11 @@ class test_Redis(Case): consumer.register_callback(callback) consumer.consume() - self.assertIn(channel, channel.connection.cycle._channels) + assert channel in channel.connection.cycle._channels try: connection.drain_events(timeout=1) - self.assertTrue(_received) - with self.assertRaises(socket.timeout): + assert _received + with pytest.raises(socket.timeout): connection.drain_events(timeout=0.01) finally: channel.close() @@ -871,8 +857,8 @@ class test_Redis(Case): for i in range(10): producer.publish({'hello': 'world-%s' % (i,)}) - self.assertEqual(channel._size('test_Redis'), 10) - self.assertEqual(self.queue(channel).purge(), 10) + assert channel._size('test_Redis') == 10 + assert self.queue(channel).purge() == 10 channel.close() def test_db_values(self): @@ -885,7 +871,7 @@ class test_Redis(Case): Connection(virtual_host='/1', transport=Transport).channel() - with self.assertRaises(Exception): + with pytest.raises(Exception): Connection('redis:///foo').channel() def test_db_port(self): @@ -900,7 +886,7 @@ class test_Redis(Case): cycle = c.connection.cycle c.client.connection c.close() - self.assertNotIn(c, cycle._channels) + assert c not in cycle._channels def test_close_ResponseError(self): c = Connection(transport=Transport).channel() @@ -912,12 +898,12 @@ class test_Redis(Case): conn1 = c.client.connection conn2 = c.subclient.connection c.close() - self.assertTrue(conn1.disconnected) - self.assertTrue(conn2.disconnected) + assert conn1.disconnected + assert conn2.disconnected def test_get__Empty(self): channel = self.connection.channel() - with self.assertRaises(Empty): + with pytest.raises(Empty): channel._get('does-not-exist') channel.close() @@ -925,16 +911,16 @@ class test_Redis(Case): with mock.module_exists(*_redis_modules()): conn = Connection(transport=Transport) chan = conn.channel() - self.assertTrue(chan.Client) - self.assertTrue(chan.ResponseError) - self.assertTrue(conn.transport.connection_errors) - self.assertTrue(conn.transport.channel_errors) + assert chan.Client + assert chan.ResponseError + assert conn.transport.connection_errors + assert conn.transport.channel_errors def test_check_at_least_we_try_to_connect_and_fail(self): import redis connection = Connection('redis://localhost:65534/') - with self.assertRaises(redis.exceptions.ConnectionError): + with pytest.raises(redis.exceptions.ConnectionError): chan = connection.channel() chan._size('some_queue') @@ -974,7 +960,7 @@ def _redis_modules(): @skip.unless_module('redis') -class test_MultiChannelPoller(Case): +class test_MultiChannelPoller: def setup(self): self.Poller = redis.MultiChannelPoller @@ -1013,7 +999,7 @@ class test_MultiChannelPoller(Case): p._channels = [] poller = Mock(name='poller') p.on_poll_init(poller) - self.assertIs(p.poller, poller) + assert p.poller is poller p._channels = [chan1] p.on_poll_init(poller) @@ -1043,7 +1029,7 @@ class test_MultiChannelPoller(Case): def test_fds(self): p = self.Poller() p._fd_to_chan = {1: 2} - self.assertDictEqual(p.fds, p._fd_to_chan) + assert p.fds == p._fd_to_chan def test_close_unregisters_fds(self): p = self.Poller() @@ -1052,12 +1038,14 @@ class test_MultiChannelPoller(Case): p.close() - self.assertEqual(poller.unregister.call_count, 3) + assert poller.unregister.call_count == 3 u_args = poller.unregister.call_args_list - self.assertItemsEqual(u_args, [((1,), {}), - ((2,), {}), - ((3,), {})]) + assert sorted(u_args) == [ + ((1,), {}), + ((2,), {}), + ((3,), {}), + ] def test_close_when_unregister_raises_KeyError(self): p = self.Poller() @@ -1091,8 +1079,8 @@ class test_MultiChannelPoller(Case): p._chan_to_sock = {(channel, client, type): 6} p._register(channel, client, type) p.poller.unregister.assert_called_with(6) - self.assertTupleEqual(p._fd_to_chan[10], (channel, type)) - self.assertEqual(p._chan_to_sock[(channel, client, type)], sock) + assert p._fd_to_chan[10] == (channel, type) + assert p._chan_to_sock[(channel, client, type)] == sock p.poller.register.assert_called_with(sock, p.eventflags) # when client not connected yet @@ -1113,15 +1101,15 @@ class test_MultiChannelPoller(Case): channel._in_poll = False p._register_BRPOP(channel) - self.assertEqual(channel._brpop_start.call_count, 1) - self.assertEqual(p._register.call_count, 1) + assert channel._brpop_start.call_count == 1 + assert p._register.call_count == 1 channel.client.connection._sock = Mock() p._chan_to_sock[(channel, channel.client, 'BRPOP')] = True channel._in_poll = True p._register_BRPOP(channel) - self.assertEqual(channel._brpop_start.call_count, 1) - self.assertEqual(p._register.call_count, 1) + assert channel._brpop_start.call_count == 1 + assert p._register.call_count == 1 def test_register_LISTEN(self): p = self.Poller() @@ -1132,15 +1120,15 @@ class test_MultiChannelPoller(Case): p._register_LISTEN(channel) p._register.assert_called_with(channel, channel.subclient, 'LISTEN') - self.assertEqual(p._register.call_count, 1) - self.assertEqual(channel._subscribe.call_count, 1) + assert p._register.call_count == 1 + assert channel._subscribe.call_count == 1 channel._in_listen = True p._chan_to_sock[(channel, channel.subclient, 'LISTEN')] = 3 channel.subclient.connection._sock = Mock() p._register_LISTEN(channel) - self.assertEqual(p._register.call_count, 1) - self.assertEqual(channel._subscribe.call_count, 1) + assert p._register.call_count == 1 + assert channel._subscribe.call_count == 1 def create_get(self, events=None, queues=None, fanouts=None): _pr = [] if events is None else events @@ -1163,7 +1151,7 @@ class test_MultiChannelPoller(Case): def test_get_no_actions(self): p, channel = self.create_get() - with self.assertRaises(redis.Empty): + with pytest.raises(redis.Empty): p.get() def test_qos_reject(self): @@ -1177,7 +1165,7 @@ class test_MultiChannelPoller(Case): p, channel = self.create_get(queues=['a_queue']) channel.qos.can_consume.return_value = True - with self.assertRaises(redis.Empty): + with pytest.raises(redis.Empty): p.get() p._register_BRPOP.assert_called_with(channel) @@ -1186,7 +1174,7 @@ class test_MultiChannelPoller(Case): p, channel = self.create_get(queues=['a_queue']) channel.qos.can_consume.return_value = False - with self.assertRaises(redis.Empty): + with pytest.raises(redis.Empty): p.get() p._register_BRPOP.assert_not_called() @@ -1194,7 +1182,7 @@ class test_MultiChannelPoller(Case): def test_get_listen(self): p, channel = self.create_get(fanouts=['f_queue']) - with self.assertRaises(redis.Empty): + with pytest.raises(redis.Empty): p.get() p._register_LISTEN.assert_called_with(channel) @@ -1203,7 +1191,7 @@ class test_MultiChannelPoller(Case): p, channel = self.create_get(events=[(1, eventio.ERR)]) p._fd_to_chan[1] = (channel, 'BRPOP') - with self.assertRaises(redis.Empty): + with pytest.raises(redis.Empty): p.get() channel._poll_error.assert_called_with('BRPOP') @@ -1213,14 +1201,14 @@ class test_MultiChannelPoller(Case): (1, eventio.ERR)]) p._fd_to_chan[1] = (channel, 'BRPOP') - with self.assertRaises(redis.Empty): + with pytest.raises(redis.Empty): p.get() channel._poll_error.assert_called_with('BRPOP') @skip.unless_module('redis') -class test_Mutex(Case): +class test_Mutex: def test_mutex(self, lock_id='xxx'): client = Mock(name='client') @@ -1234,29 +1222,29 @@ class test_Mutex(Case): held = False with redis.Mutex(client, 'foo1', 100): held = True - self.assertTrue(held) + assert held client.setnx.assert_called_with('foo1', lock_id) pipe.get.return_value = 'yyy' held = False with redis.Mutex(client, 'foo1', 100): held = True - self.assertTrue(held) + assert held # Did not win client.expire.reset_mock() pipe.get.return_value = lock_id client.setnx.return_value = False - with self.assertRaises(redis.MutexHeld): + with pytest.raises(redis.MutexHeld): held = False with redis.Mutex(client, 'foo1', '100'): held = True - self.assertFalse(held) + assert not held client.ttl.return_value = 0 - with self.assertRaises(redis.MutexHeld): + with pytest.raises(redis.MutexHeld): held = False with redis.Mutex(client, 'foo1', '100'): held = True - self.assertFalse(held) + assert not held client.expire.assert_called() # Wins but raises WatchError (and that is ignored) @@ -1265,14 +1253,11 @@ class test_Mutex(Case): held = False with redis.Mutex(client, 'foo1', 100): held = True - self.assertTrue(held) + assert held @skip.unless_module('redis.sentinel') -class test_RedisSentinel(Case): - - def setup(self): - pass +class test_RedisSentinel: def test_method_called(self): from kombu.transport.redis import SentinelChannel @@ -1289,9 +1274,7 @@ class test_RedisSentinel(Case): p.assert_called() def test_getting_master_from_sentinel(self): - from redis.sentinel import Sentinel - - with patch.object(Sentinel, '__new__') as patched: + with patch('redis.sentinel.Sentinel') as patched: connection = Connection( 'sentinel://localhost:65534/', transport_options={ @@ -1300,7 +1283,7 @@ class test_RedisSentinel(Case): ) connection.channel() - self.assertTrue(patched) + assert patched master_for = patched.return_value.master_for master_for.assert_called() @@ -1310,11 +1293,11 @@ class test_RedisSentinel(Case): def test_can_create_connection(self): from redis.exceptions import ConnectionError - with self.assertRaises(ConnectionError): - connection = Connection( - 'sentinel://localhost:65534/', - transport_options={ - 'master_name': 'not_important', - }, - ) + connection = Connection( + 'sentinel://localhost:65534/', + transport_options={ + 'master_name': 'not_important', + }, + ) + with pytest.raises(ConnectionError): connection.channel() diff --git a/kombu/tests/transport/test_transport.py b/t/unit/transport/test_transport.py index 2f7d2ede..26df9e19 100644 --- a/kombu/tests/transport/test_transport.py +++ b/t/unit/transport/test_transport.py @@ -1,26 +1,25 @@ from __future__ import absolute_import, unicode_literals -from kombu import transport +from case import Mock, patch -from kombu.tests.case import Case, Mock, patch +from kombu import transport -class test_supports_librabbitmq(Case): +class test_supports_librabbitmq: def test_eventlet(self): with patch('kombu.transport._detect_environment') as de: de.return_value = 'eventlet' - self.assertFalse(transport.supports_librabbitmq()) + assert not transport.supports_librabbitmq() -class test_transport(Case): +class test_transport: def test_resolve_transport(self): from kombu.transport.memory import Transport - self.assertIs(transport.resolve_transport( - 'kombu.transport.memory:Transport'), - Transport) - self.assertIs(transport.resolve_transport(Transport), Transport) + assert transport.resolve_transport( + 'kombu.transport.memory:Transport') is Transport + assert transport.resolve_transport(Transport) is Transport def test_resolve_transport_alias_callable(self): m = transport.TRANSPORT_ALIASES['George'] = Mock(name='lazyalias') @@ -31,4 +30,4 @@ class test_transport(Case): transport.TRANSPORT_ALIASES.pop('George') def test_resolve_transport_alias(self): - self.assertTrue(transport.resolve_transport('pyamqp')) + assert transport.resolve_transport('pyamqp') diff --git a/kombu/tests/utils/__init__.py b/t/unit/transport/virtual/__init__.py index e69de29b..e69de29b 100644 --- a/kombu/tests/utils/__init__.py +++ b/t/unit/transport/virtual/__init__.py diff --git a/kombu/tests/transport/virtual/test_base.py b/t/unit/transport/virtual/test_base.py index 8b73a7fa..5eb23a08 100644 --- a/kombu/tests/transport/virtual/test_base.py +++ b/t/unit/transport/virtual/test_base.py @@ -1,16 +1,18 @@ from __future__ import absolute_import, unicode_literals +import io +import pytest import sys import warnings +from case import MagicMock, Mock, patch + from kombu import Connection from kombu.compression import compress from kombu.exceptions import ResourceError, ChannelError from kombu.transport import virtual from kombu.utils.uuid import uuid -from kombu.tests.case import Case, MagicMock, Mock, mock, patch - PY3 = sys.version_info[0] == 3 PRINT_FQDN = 'builtins.print' if PY3 else '__builtin__.print' @@ -23,17 +25,15 @@ def memory_client(): return Connection(transport='memory') -class test_BrokerState(Case): - - def test_constructor(self): - s = virtual.BrokerState() - self.assertTrue(hasattr(s, 'exchanges')) +def test_BrokerState(): + s = virtual.BrokerState() + assert hasattr(s, 'exchanges') - t = virtual.BrokerState(exchanges=16) - self.assertEqual(t.exchanges, 16) + t = virtual.BrokerState(exchanges=16) + assert t.exchanges == 16 -class test_QoS(Case): +class test_QoS: def setup(self): self.q = virtual.QoS(client().channel(), prefetch_count=10) @@ -42,17 +42,17 @@ class test_QoS(Case): self.q._on_collect.cancel() def test_constructor(self): - self.assertTrue(self.q.channel) - self.assertTrue(self.q.prefetch_count) - self.assertFalse(self.q._delivered.restored) - self.assertTrue(self.q._on_collect) + assert self.q.channel + assert self.q.prefetch_count + assert not self.q._delivered.restored + assert self.q._on_collect def test_restore_visible__interface(self): qos = virtual.QoS(client().channel()) qos.restore_visible() - @mock.stdouts - def test_can_consume(self, stdout, stderr): + def test_can_consume(self, stdouts): + stderr = io.StringIO() _restored = [] class RestoreChannel(virtual.Channel): @@ -61,63 +61,64 @@ class test_QoS(Case): def _restore(self, message): _restored.append(message) - self.assertTrue(self.q.can_consume()) + assert self.q.can_consume() for i in range(self.q.prefetch_count - 1): self.q.append(i, uuid()) - self.assertTrue(self.q.can_consume()) + assert self.q.can_consume() self.q.append(i + 1, uuid()) - self.assertFalse(self.q.can_consume()) + assert not self.q.can_consume() tag1 = next(iter(self.q._delivered)) self.q.ack(tag1) - self.assertTrue(self.q.can_consume()) + assert self.q.can_consume() tag2 = uuid() self.q.append(i + 2, tag2) - self.assertFalse(self.q.can_consume()) + assert not self.q.can_consume() self.q.reject(tag2) - self.assertTrue(self.q.can_consume()) + assert self.q.can_consume() self.q.channel = RestoreChannel(self.q.channel.connection) tag3 = uuid() self.q.append(i + 3, tag3) self.q.reject(tag3, requeue=True) self.q._flush() - self.q.restore_unacked_once() - self.assertListEqual(_restored, [11, 9, 8, 7, 6, 5, 4, 3, 2, 1]) - self.assertTrue(self.q._delivered.restored) - self.assertFalse(self.q._delivered) - - self.q.restore_unacked_once() + assert self.q._delivered + assert not self.q._delivered.restored + self.q.restore_unacked_once(stderr=stderr) + assert _restored == [11, 9, 8, 7, 6, 5, 4, 3, 2, 1] + assert self.q._delivered.restored + assert not self.q._delivered + + self.q.restore_unacked_once(stderr=stderr) self.q._delivered.restored = False - self.q.restore_unacked_once() + self.q.restore_unacked_once(stderr=stderr) - self.assertTrue(stderr.getvalue()) - self.assertFalse(stdout.getvalue()) + assert stderr.getvalue() + assert not stdouts.stdout.getvalue() self.q.restore_at_shutdown = False self.q.restore_unacked_once() def test_get(self): self.q._delivered['foo'] = 1 - self.assertEqual(self.q.get('foo'), 1) + assert self.q.get('foo') == 1 -class test_Message(Case): +class test_Message: def test_create(self): c = client().channel() data = c.prepare_message('the quick brown fox...') tag = data['properties']['delivery_tag'] = uuid() message = c.message_to_python(data) - self.assertIsInstance(message, virtual.Message) - self.assertIs(message, c.message_to_python(message)) + assert isinstance(message, virtual.Message) + assert message is c.message_to_python(message) if message.errors: message._reraise_error() - self.assertEqual(message.body, - 'the quick brown fox...'.encode('utf-8')) - self.assertTrue(message.delivery_tag, tag) + assert message.body == 'the quick brown fox...'.encode('utf-8') + assert message.delivery_tag, tag def test_create_no_body(self): virtual.Message(Mock(), { @@ -131,46 +132,45 @@ class test_Message(Case): tag = data['properties']['delivery_tag'] = uuid() message = c.message_to_python(data) dict_ = message.serializable() - self.assertEqual(dict_['body'], - 'the quick brown fox...'.encode('utf-8')) - self.assertEqual(dict_['properties']['delivery_tag'], tag) - self.assertNotIn('compression', dict_['headers']) + assert dict_['body'] == 'the quick brown fox...'.encode('utf-8') + assert dict_['properties']['delivery_tag'] == tag + assert 'compression' not in dict_['headers'] -class test_AbstractChannel(Case): +class test_AbstractChannel: def test_get(self): - with self.assertRaises(NotImplementedError): + with pytest.raises(NotImplementedError): virtual.AbstractChannel()._get('queue') def test_put(self): - with self.assertRaises(NotImplementedError): + with pytest.raises(NotImplementedError): virtual.AbstractChannel()._put('queue', 'm') def test_size(self): - self.assertEqual(virtual.AbstractChannel()._size('queue'), 0) + assert virtual.AbstractChannel()._size('queue') == 0 def test_purge(self): - with self.assertRaises(NotImplementedError): + with pytest.raises(NotImplementedError): virtual.AbstractChannel()._purge('queue') def test_delete(self): - with self.assertRaises(NotImplementedError): + with pytest.raises(NotImplementedError): virtual.AbstractChannel()._delete('queue') def test_new_queue(self): - self.assertIsNone(virtual.AbstractChannel()._new_queue('queue')) + assert virtual.AbstractChannel()._new_queue('queue') is None def test_has_queue(self): - self.assertTrue(virtual.AbstractChannel()._has_queue('queue')) + assert virtual.AbstractChannel()._has_queue('queue') def test_poll(self): cycle = Mock(name='cycle') - self.assertTrue(virtual.AbstractChannel()._poll(cycle)) + assert virtual.AbstractChannel()._poll(cycle) cycle.get.assert_called() -class test_Channel(Case): +class test_Channel: def setup(self): self.channel = client().channel() @@ -184,15 +184,15 @@ class test_Channel(Case): t = c.transport avail = t._avail_channel_ids = Mock(name='_avail_channel_ids') avail.pop.side_effect = IndexError() - with self.assertRaises(ResourceError): + with pytest.raises(ResourceError): virtual.Channel(t) def test_exchange_bind_interface(self): - with self.assertRaises(NotImplementedError): + with pytest.raises(NotImplementedError): self.channel.exchange_bind('dest', 'src', 'key') def test_exchange_unbind_interface(self): - with self.assertRaises(NotImplementedError): + with pytest.raises(NotImplementedError): self.channel.exchange_unbind('dest', 'src', 'key') def test_queue_unbind_interface(self): @@ -200,28 +200,28 @@ class test_Channel(Case): def test_management(self): m = self.channel.connection.client.get_manager() - self.assertTrue(m) + assert m m.get_bindings() m.close() def test_exchange_declare(self): c = self.channel - with self.assertRaises(ChannelError): + with pytest.raises(ChannelError): c.exchange_declare('test_exchange_declare', 'direct', durable=True, auto_delete=True, passive=True) c.exchange_declare('test_exchange_declare', 'direct', durable=True, auto_delete=True) c.exchange_declare('test_exchange_declare', 'direct', durable=True, auto_delete=True, passive=True) - self.assertIn('test_exchange_declare', c.state.exchanges) + assert 'test_exchange_declare' in c.state.exchanges # can declare again with same values c.exchange_declare('test_exchange_declare', 'direct', durable=True, auto_delete=True) - self.assertIn('test_exchange_declare', c.state.exchanges) + assert 'test_exchange_declare' in c.state.exchanges # using different values raises NotEquivalentError - with self.assertRaises(virtual.NotEquivalentError): + with pytest.raises(virtual.NotEquivalentError): c.exchange_declare('test_exchange_declare', 'direct', durable=False, auto_delete=True) @@ -236,18 +236,18 @@ class test_Channel(Case): c = PurgeChannel(self.channel.connection) c.exchange_declare(ex, 'direct', durable=True, auto_delete=True) - self.assertIn(ex, c.state.exchanges) - self.assertFalse(c.state.has_binding(ex, ex, ex)) # no bindings yet + assert ex in c.state.exchanges + assert not c.state.has_binding(ex, ex, ex) # no bindings yet c.exchange_delete(ex) - self.assertNotIn(ex, c.state.exchanges) + assert ex not in c.state.exchanges c.exchange_declare(ex, 'direct', durable=True, auto_delete=True) c.queue_declare(ex) c.queue_bind(ex, ex, ex) - self.assertTrue(c.state.has_binding(ex, ex, ex)) + assert c.state.has_binding(ex, ex, ex) c.exchange_delete(ex) - self.assertFalse(c.state.has_binding(ex, ex, ex)) - self.assertIn(ex, c.purged) + assert not c.state.has_binding(ex, ex, ex) + assert ex in c.purged def test_queue_delete__if_empty(self, n='test_queue_delete__if_empty'): class PurgeChannel(virtual.Channel): @@ -268,12 +268,12 @@ class test_Channel(Case): c.queue_bind(n, n, n) c.queue_delete(n, if_empty=True) - self.assertTrue(c.state.has_binding(n, n, n)) + assert c.state.has_binding(n, n, n) c.size = 0 c.queue_delete(n, if_empty=True) - self.assertFalse(c.state.has_binding(n, n, n)) - self.assertIn(n, c.purged) + assert not c.state.has_binding(n, n, n) + assert n in c.purged def test_queue_purge(self, n='test_queue_purge'): @@ -288,7 +288,7 @@ class test_Channel(Case): c.queue_declare(n) c.queue_bind(n, n, n) c.queue_purge(n) - self.assertIn(n, c.purged) + assert n in c.purged def test_basic_publish__anon_exchange(self): c = memory_client().channel() @@ -315,10 +315,10 @@ class test_Channel(Case): r1 = c1.message_to_python(c1.basic_get(n)) r2 = c2.message_to_python(c2.basic_get(n)) - self.assertNotEqual(r1.delivery_tag, r2.delivery_tag) - with self.assertRaises(ValueError): + assert r1.delivery_tag != r2.delivery_tag + with pytest.raises(ValueError): int(r1.delivery_tag) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): int(r2.delivery_tag) def test_basic_publish__get__consume__restore(self, @@ -335,31 +335,29 @@ class test_Channel(Case): c.basic_publish(m, n, n) r1 = c.message_to_python(c.basic_get(n)) - self.assertTrue(r1) - self.assertEqual(r1.body, - 'nthex quick brown fox...'.encode('utf-8')) - self.assertIsNone(c.basic_get(n)) + assert r1 + assert r1.body == 'nthex quick brown fox...'.encode('utf-8') + assert c.basic_get(n) is None consumer_tag = uuid() c.basic_consume(n + '2', False, consumer_tag=consumer_tag, callback=lambda *a: None) - self.assertIn(n + '2', c._active_queues) + assert n + '2' in c._active_queues r2, _ = c.drain_events() r2 = c.message_to_python(r2) - self.assertEqual(r2.body, - 'nthex quick brown fox...'.encode('utf-8')) - self.assertEqual(r2.delivery_info['exchange'], n) - self.assertEqual(r2.delivery_info['routing_key'], n) - with self.assertRaises(virtual.Empty): + assert r2.body == 'nthex quick brown fox...'.encode('utf-8') + assert r2.delivery_info['exchange'] == n + assert r2.delivery_info['routing_key'] == n + with pytest.raises(virtual.Empty): c.drain_events() c.basic_cancel(consumer_tag) c._restore(r2) r3 = c.message_to_python(c.basic_get(n)) - self.assertTrue(r3) - self.assertEqual(r3.body, 'nthex quick brown fox...'.encode('utf-8')) - self.assertIsNone(c.basic_get(n)) + assert r3 + assert r3.body == 'nthex quick brown fox...'.encode('utf-8') + assert c.basic_get(n) is None def test_basic_ack(self): @@ -371,7 +369,7 @@ class test_Channel(Case): self.channel._qos = MockQoS(self.channel) self.channel.basic_ack('foo') - self.assertTrue(self.channel._qos.was_acked) + assert self.channel._qos.was_acked def test_basic_recover__requeue(self): @@ -383,7 +381,7 @@ class test_Channel(Case): self.channel._qos = MockQoS(self.channel) self.channel.basic_recover(requeue=True) - self.assertTrue(self.channel._qos.was_restored) + assert self.channel._qos.was_restored def test_restore_unacked_raises_BaseException(self): q = self.channel.qos @@ -394,11 +392,11 @@ class test_Channel(Case): q.channel._restore.side_effect = SystemExit errors = q.restore_unacked() - self.assertIsInstance(errors[0][0], SystemExit) - self.assertEqual(errors[0][1], 1) - self.assertFalse(q._delivered) + assert isinstance(errors[0][0], SystemExit) + assert errors[0][1] == 1 + assert not q._delivered - @patch('kombu.transport.virtual.emergency_dump_state') + @patch('kombu.transport.virtual.base.emergency_dump_state') @patch(PRINT_FQDN) def test_restore_unacked_once_when_unrestored(self, print_, emergency_dump_state): @@ -423,7 +421,7 @@ class test_Channel(Case): emergency_dump_state.assert_called() def test_basic_recover(self): - with self.assertRaises(NotImplementedError): + with pytest.raises(NotImplementedError): self.channel.basic_recover(requeue=False) def test_basic_reject(self): @@ -436,39 +434,38 @@ class test_Channel(Case): self.channel._qos = MockQoS(self.channel) self.channel.basic_reject('foo') - self.assertTrue(self.channel._qos.was_rejected) + assert self.channel._qos.was_rejected def test_basic_qos(self): self.channel.basic_qos(prefetch_count=128) - self.assertEqual(self.channel._qos.prefetch_count, 128) + assert self.channel._qos.prefetch_count == 128 def test_lookup__undeliverable(self, n='test_lookup__undeliverable'): warnings.resetwarnings() with warnings.catch_warnings(record=True) as log: - self.assertListEqual( - self.channel._lookup(n, n, 'ae.undeliver'), - ['ae.undeliver'], - ) - self.assertTrue(log) - self.assertIn('could not be delivered', log[0].message.args[0]) + assert self.channel._lookup(n, n, 'ae.undeliver') == [ + 'ae.undeliver', + ] + assert log + assert 'could not be delivered' in log[0].message.args[0] def test_context(self): x = self.channel.__enter__() - self.assertIs(x, self.channel) + assert x is self.channel x.__exit__() - self.assertTrue(x.closed) + assert x.closed def test_cycle_property(self): - self.assertTrue(self.channel.cycle) + assert self.channel.cycle def test_flow(self): - with self.assertRaises(NotImplementedError): + with pytest.raises(NotImplementedError): self.channel.flow(False) def test_close_when_no_connection(self): self.channel.connection = None self.channel.close() - self.assertTrue(self.channel.closed) + assert self.channel.closed def test_drain_events_has_get_many(self): c = self.channel @@ -483,7 +480,7 @@ class test_Channel(Case): def test_get_exchanges(self): self.channel.exchange_declare(exchange='foo') - self.assertTrue(self.channel.get_exchanges()) + assert self.channel.get_exchanges() def test_basic_cancel_not_in_active_queues(self): c = self.channel @@ -496,7 +493,7 @@ class test_Channel(Case): c._active_queues.remove.assert_called_with('foo') def test_basic_cancel_unknown_ctag(self): - self.assertIsNone(self.channel.basic_cancel('unknown-tag')) + assert self.channel.basic_cancel('unknown-tag') is None def test_list_bindings(self): c = self.channel @@ -504,7 +501,7 @@ class test_Channel(Case): c.queue_declare(queue='q') c.queue_bind(queue='q', exchange='foo', routing_key='rk') - self.assertIn(('q', 'foo', 'rk'), list(c.list_bindings())) + assert ('q', 'foo', 'rk') in list(c.list_bindings()) def test_after_reply_message_received(self): c = self.channel @@ -513,12 +510,12 @@ class test_Channel(Case): c.queue_delete.assert_called_with('foo') def test_queue_delete_unknown_queue(self): - self.assertIsNone(self.channel.queue_delete('xiwjqjwel')) + assert self.channel.queue_delete('xiwjqjwel') is None def test_queue_declare_passive(self): has_queue = self.channel._has_queue = Mock() has_queue.return_value = False - with self.assertRaises(ChannelError): + with pytest.raises(ChannelError): self.channel.queue_declare(queue='21wisdjwqe', passive=True) def test_get_message_priority(self): @@ -528,56 +525,46 @@ class test_Channel(Case): 'the message with priority', priority=priority, ) - self.assertEqual( - self.channel._get_message_priority(_message(5)), 5, - ) - self.assertEqual( - self.channel._get_message_priority( - _message(self.channel.min_priority - 10), - ), - self.channel.min_priority, - ) - self.assertEqual( - self.channel._get_message_priority( - _message(self.channel.max_priority + 10), - ), - self.channel.max_priority, - ) - self.assertEqual( - self.channel._get_message_priority(_message('foobar')), - self.channel.default_priority, - ) - self.assertEqual( - self.channel._get_message_priority(_message(2), reverse=True), - self.channel.max_priority - 2, - ) + assert self.channel._get_message_priority(_message(5)) == 5 + assert self.channel._get_message_priority( + _message(self.channel.min_priority - 10) + ) == self.channel.min_priority + assert self.channel._get_message_priority( + _message(self.channel.max_priority + 10), + ) == self.channel.max_priority + assert self.channel._get_message_priority( + _message('foobar'), + ) == self.channel.default_priority + assert self.channel._get_message_priority( + _message(2), reverse=True, + ) == self.channel.max_priority - 2 -class test_Transport(Case): +class test_Transport: def setup(self): self.transport = client().transport def test_custom_polling_interval(self): x = client(transport_options=dict(polling_interval=32.3)) - self.assertEqual(x.transport.polling_interval, 32.3) + assert x.transport.polling_interval == 32.3 def test_close_connection(self): c1 = self.transport.create_channel(self.transport) c2 = self.transport.create_channel(self.transport) - self.assertEqual(len(self.transport.channels), 2) + assert len(self.transport.channels) == 2 self.transport.close_connection(self.transport) - self.assertFalse(self.transport.channels) + assert not self.transport.channels del(c1) # so pyflakes doesn't complain del(c2) def test_drain_channel(self): channel = self.transport.create_channel(self.transport) - with self.assertRaises(virtual.Empty): + with pytest.raises(virtual.Empty): self.transport._drain_channel(channel) def test__deliver__no_queue(self): - with self.assertRaises(KeyError): + with pytest.raises(KeyError): self.transport._deliver(Mock(name='msg'), queue=None) def test__reject_inbound_message(self): @@ -601,12 +588,12 @@ class test_Transport(Case): callback.assert_called_with(msg) def test_on_message_ready__no_queue(self): - with self.assertRaises(KeyError): + with pytest.raises(KeyError): self.transport.on_message_ready( Mock(name='channel'), Mock(name='msg'), queue=None) def test_on_message_ready__no_callback(self): self.transport._callbacks = {} - with self.assertRaises(KeyError): + with pytest.raises(KeyError): self.transport.on_message_ready( Mock(name='channel'), Mock(name='msg'), queue='q1') diff --git a/t/unit/transport/virtual/test_exchange.py b/t/unit/transport/virtual/test_exchange.py new file mode 100644 index 00000000..93d48228 --- /dev/null +++ b/t/unit/transport/virtual/test_exchange.py @@ -0,0 +1,169 @@ +from __future__ import absolute_import, unicode_literals + +import pytest + +from case import Mock + +from kombu import Connection +from kombu.transport.virtual import exchange + +from t.mocks import Transport + + +class ExchangeCase: + type = None + + def setup(self): + if self.type: + self.e = self.type(Connection(transport=Transport).channel()) + + +class test_Direct(ExchangeCase): + type = exchange.DirectExchange + table = [('rFoo', None, 'qFoo'), + ('rFoo', None, 'qFox'), + ('rBar', None, 'qBar'), + ('rBaz', None, 'qBaz')] + + @pytest.mark.parametrize('exchange,routing_key,default,expected', [ + ('eFoo', 'rFoo', None, {'qFoo', 'qFox'}), + ('eMoz', 'rMoz', 'DEFAULT', set()), + ('eBar', 'rBar', None, {'qBar'}), + ]) + def test_lookup(self, exchange, routing_key, default, expected): + assert self.e.lookup( + self.table, exchange, routing_key, default) == expected + + +class test_Fanout(ExchangeCase): + type = exchange.FanoutExchange + table = [(None, None, 'qFoo'), + (None, None, 'qFox'), + (None, None, 'qBar')] + + def test_lookup(self): + assert self.e.lookup(self.table, 'eFoo', 'rFoo', None) == { + 'qFoo', 'qFox', 'qBar', + } + + def test_deliver_when_fanout_supported(self): + self.e.channel = Mock() + self.e.channel.supports_fanout = True + message = Mock() + + self.e.deliver(message, 'exchange', 'rkey') + self.e.channel._put_fanout.assert_called_with( + 'exchange', message, 'rkey', + ) + + def test_deliver_when_fanout_unsupported(self): + self.e.channel = Mock() + self.e.channel.supports_fanout = False + + self.e.deliver(Mock(), 'exchange', None) + self.e.channel._put_fanout.assert_not_called() + + +class test_Topic(ExchangeCase): + type = exchange.TopicExchange + table = [ + ('stock.#', None, 'rFoo'), + ('stock.us.*', None, 'rBar'), + ] + + def setup(self): + ExchangeCase.setup(self) + self.table = [(rkey, self.e.key_to_pattern(rkey), queue) + for rkey, _, queue in self.table] + + def test_prepare_bind(self): + x = self.e.prepare_bind('qFoo', 'eFoo', 'stock.#', {}) + assert x == ('stock.#', r'^stock\..*?$', 'qFoo') + + @pytest.mark.parametrize('exchange,routing_key,default,expected', [ + ('eFoo', 'stock.us.nasdaq', None, {'rFoo', 'rBar'}), + ('eFoo', 'stock.europe.OSE', None, {'rFoo'}), + ('eFoo', 'stockxeuropexOSE', None, set()), + ('eFoo', 'candy.schleckpulver.snap_crackle', None, set()), + ]) + def test_lookup(self, exchange, routing_key, default, expected): + assert self.e.lookup( + self.table, exchange, routing_key, default) == expected + assert self.e._compiled + + def test_deliver(self): + self.e.channel = Mock() + self.e.channel._lookup.return_value = ('a', 'b') + message = Mock() + self.e.deliver(message, 'exchange', 'rkey') + + assert self.e.channel._put.call_args_list == [ + (('a', message), {}), + (('b', message), {}), + ] + + +class test_TopicMultibind(ExchangeCase): + # Testing message delivery in case of multiple overlapping + # bindings for the same queue. As AMQP states, in case of + # overlapping bindings, a message must be delivered once to + # each matching queue. + type = exchange.TopicExchange + table = [ + ('stock', None, 'rFoo'), + ('stock.#', None, 'rFoo'), + ('stock.us.*', None, 'rFoo'), + ('#', None, 'rFoo'), + ] + + def setup(self): + ExchangeCase.setup(self) + self.table = [(rkey, self.e.key_to_pattern(rkey), queue) + for rkey, _, queue in self.table] + + @pytest.mark.parametrize('exchange,routing_key,default,expected', [ + ('eFoo', 'stock.us.nasdaq', None, {'rFoo'}), + ('eFoo', 'stock.europe.OSE', None, {'rFoo'}), + ('eFoo', 'stockxeuropexOSE', None, {'rFoo'}), + ('eFoo', 'candy.schleckpulver.snap_crackle', None, {'rFoo'}), + ]) + def test_lookup(self, exchange, routing_key, default, expected): + assert self.e._compiled + assert self.e.lookup( + self.table, exchange, routing_key, default) == expected + + +class test_ExchangeType(ExchangeCase): + type = exchange.ExchangeType + + def test_lookup(self): + with pytest.raises(NotImplementedError): + self.e.lookup([], 'eFoo', 'rFoo', None) + + def test_prepare_bind(self): + assert self.e.prepare_bind('qFoo', 'eFoo', 'rFoo', {}) == ( + 'rFoo', None, 'qFoo', + ) + + e1 = dict( + type='direct', + durable=True, + auto_delete=True, + arguments={}, + ) + e2 = dict(e1, arguments={'expires': 3000}) + + @pytest.mark.parametrize('ex,eq,name,type,durable,auto_delete,arguments', [ + (e1, True, 'eFoo', 'direct', True, True, {}), + (e1, False, 'eFoo', 'topic', True, True, {}), + (e1, False, 'eFoo', 'direct', False, True, {}), + (e1, False, 'eFoo', 'direct', True, False, {}), + (e1, False, 'eFoo', 'direct', True, True, {'expires': 3000}), + (e2, True, 'eFoo', 'direct', True, True, {'expires': 3000}), + (e2, False, 'eFoo', 'direct', True, True, {'expires': 6000}), + ]) + def test_equivalent( + self, ex, eq, name, type, durable, auto_delete, arguments): + is_eq = self.e.equivalent( + ex, name, type, durable, auto_delete, arguments) + assert is_eq if eq else not is_eq diff --git a/t/unit/utils/__init__.py b/t/unit/utils/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/t/unit/utils/__init__.py diff --git a/kombu/tests/utils/test_amq_manager.py b/t/unit/utils/test_amq_manager.py index 03eb634e..3ea1e41a 100644 --- a/kombu/tests/utils/test_amq_manager.py +++ b/t/unit/utils/test_amq_manager.py @@ -1,22 +1,24 @@ from __future__ import absolute_import, unicode_literals -from kombu import Connection +import pytest + +from case import mock, patch -from kombu.tests.case import Case, mock, patch +from kombu import Connection -class test_get_manager(Case): +class test_get_manager: @mock.mask_modules('pyrabbit') def test_without_pyrabbit(self): - with self.assertRaises(ImportError): + with pytest.raises(ImportError): Connection('amqp://').get_manager() @mock.module_exists('pyrabbit') def test_with_pyrabbit(self): with patch('pyrabbit.Client', create=True) as Client: manager = Connection('amqp://').get_manager() - self.assertIsNotNone(manager) + assert manager is not None Client.assert_called_with( 'localhost:15672', 'guest', 'guest', ) @@ -30,7 +32,7 @@ class test_get_manager(Case): 'manager_userid': 'george', 'manager_password': 'bosco', }).get_manager() - self.assertIsNotNone(manager) + assert manager is not None Client.assert_called_with( 'admin.mq.vandelay.com:808', 'george', 'bosco', ) diff --git a/kombu/tests/utils/test_compat.py b/t/unit/utils/test_compat.py index e4bd2c7a..963c50b6 100644 --- a/kombu/tests/utils/test_compat.py +++ b/t/unit/utils/test_compat.py @@ -1,32 +1,30 @@ from __future__ import absolute_import, unicode_literals -from kombu.utils.compat import entrypoints, maybe_fileno +from case import Mock, mock, patch -from kombu.tests.case import Case, Mock, mock, patch +from kombu.utils.compat import entrypoints, maybe_fileno -class test_entrypoints(Case): +class test_entrypoints: @mock.mask_modules('pkg_resources') def test_without_pkg_resources(self): - self.assertListEqual(list(entrypoints('kombu.test')), []) + assert list(entrypoints('kombu.test')) == [] @mock.module_exists('pkg_resources') def test_with_pkg_resources(self): with patch('pkg_resources.iter_entry_points', create=True) as iterep: eps = iterep.return_value = [Mock(), Mock()] - self.assertTrue(list(entrypoints('kombu.test'))) + assert list(entrypoints('kombu.test')) iterep.assert_called_with('kombu.test') eps[0].load.assert_called_with() eps[1].load.assert_called_with() -class test_maybe_fileno(Case): - - def test_maybe_fileno(self): - self.assertEqual(maybe_fileno(3), 3) - f = Mock(name='file') - self.assertIs(maybe_fileno(f), f.fileno()) - f.fileno.side_effect = ValueError() - self.assertIsNone(maybe_fileno(f)) +def test_maybe_fileno(): + assert maybe_fileno(3) == 3 + f = Mock(name='file') + assert maybe_fileno(f) is f.fileno() + f.fileno.side_effect = ValueError() + assert maybe_fileno(f) is None diff --git a/kombu/tests/utils/test_debug.py b/t/unit/utils/test_debug.py index 9f10cc76..732638cb 100644 --- a/kombu/tests/utils/test_debug.py +++ b/t/unit/utils/test_debug.py @@ -2,13 +2,13 @@ from __future__ import absolute_import, unicode_literals import logging +from case import Mock, patch + from kombu.five import bytes_if_py2 from kombu.utils.debug import Logwrapped, setup_logging -from kombu.tests.case import Case, Mock, patch - -class test_setup_logging(Case): +class test_setup_logging: def test_adds_handlers_sets_level(self): with patch('kombu.utils.debug.get_logger') as get_logger: @@ -21,7 +21,7 @@ class test_setup_logging(Case): logger.setLevel.assert_called_with(logging.DEBUG) -class test_Logwrapped(Case): +class test_Logwrapped: def test_wraps(self): with patch('kombu.utils.debug.get_logger') as get_logger: @@ -29,13 +29,13 @@ class test_Logwrapped(Case): W = Logwrapped(Mock(), 'kombu.test') get_logger.assert_called_with('kombu.test') - self.assertIsNotNone(W.instance) - self.assertIs(W.logger, logger) + assert W.instance is not None + assert W.logger is logger W.instance.__repr__ = lambda s: bytes_if_py2('foo') - self.assertEqual(repr(W), 'foo') + assert repr(W) == 'foo' W.instance.some_attr = 303 - self.assertEqual(W.some_attr, 303) + assert W.some_attr == 303 W.instance.some_method.__name__ = bytes_if_py2('some_method') W.some_method(1, 2, kw=1) @@ -50,6 +50,6 @@ class test_Logwrapped(Case): W.ident = 'ident' W.some_method(kw=1) logger.debug.assert_called() - self.assertIn('ident', logger.debug.call_args[0][0]) + assert 'ident' in logger.debug.call_args[0][0] - self.assertEqual(dir(W), dir(W.instance)) + assert dir(W) == dir(W.instance) diff --git a/t/unit/utils/test_div.py b/t/unit/utils/test_div.py new file mode 100644 index 00000000..f0b1a058 --- /dev/null +++ b/t/unit/utils/test_div.py @@ -0,0 +1,49 @@ +from __future__ import absolute_import, unicode_literals + +import pickle + +from io import StringIO, BytesIO + +from kombu.utils.div import emergency_dump_state + + +class MyStringIO(StringIO): + + def close(self): + pass + + +class MyBytesIO(BytesIO): + + def close(self): + pass + + +class test_emergency_dump_state: + + def test_dump(self, stdouts): + fh = MyBytesIO() + stderr = StringIO() + emergency_dump_state( + {'foo': 'bar'}, open_file=lambda n, m: fh, stderr=stderr) + assert pickle.loads(fh.getvalue()) == {'foo': 'bar'} + assert stderr.getvalue() + assert not stdouts.stdout.getvalue() + + def test_dump_second_strategy(self, stdouts): + fh = MyStringIO() + stderr = StringIO() + + def raise_something(*args, **kwargs): + raise KeyError('foo') + + emergency_dump_state( + {'foo': 'bar'}, + open_file=lambda n, m: fh, + dump=raise_something, + stderr=stderr, + ) + assert 'foo' in fh.getvalue() + assert 'bar' in fh.getvalue() + assert stderr.getvalue() + assert not stdouts.stdout.getvalue() diff --git a/t/unit/utils/test_encoding.py b/t/unit/utils/test_encoding.py new file mode 100644 index 00000000..e3d1040a --- /dev/null +++ b/t/unit/utils/test_encoding.py @@ -0,0 +1,105 @@ +# -*- coding: utf-8 -*- +from __future__ import absolute_import, unicode_literals + +import sys + +from contextlib import contextmanager + +from case import patch, skip + +from kombu.five import bytes_t, string_t +from kombu.utils.encoding import ( + get_default_encoding_file, safe_str, + set_default_encoding_file, default_encoding, +) + + +@contextmanager +def clean_encoding(): + old_encoding = sys.modules.pop('kombu.utils.encoding', None) + import kombu.utils.encoding + try: + yield kombu.utils.encoding + finally: + if old_encoding: + sys.modules['kombu.utils.encoding'] = old_encoding + + +class test_default_encoding: + + def test_set_default_file(self): + prev = get_default_encoding_file() + try: + set_default_encoding_file('/foo.txt') + assert get_default_encoding_file() == '/foo.txt' + finally: + set_default_encoding_file(prev) + + @patch('sys.getfilesystemencoding') + def test_default(self, getdefaultencoding): + getdefaultencoding.return_value = 'ascii' + with clean_encoding() as encoding: + enc = encoding.default_encoding() + if sys.platform.startswith('java'): + assert enc == 'utf-8' + else: + assert enc == 'ascii' + getdefaultencoding.assert_called_with() + + +@skip.if_python3() +def test_str_to_bytes(): + with clean_encoding() as e: + assert isinstance(e.str_to_bytes('foobar'), bytes_t) + + +@skip.if_python3() +def test_from_utf8(): + with clean_encoding() as e: + assert isinstance(e.from_utf8('foobar'), bytes_t) + + +@skip.if_python3() +def test_default_encode(): + with clean_encoding() as e: + assert e.default_encode(b'foo') + + +class test_safe_str: + + def setup(self): + self._encoding = self.patching('sys.getfilesystemencoding') + self._encoding.return_value = 'ascii' + + def test_when_bytes(self): + assert safe_str('foo') == 'foo' + + def test_when_unicode(self): + assert isinstance(safe_str('foo'), string_t) + + def test_when_encoding_utf8(self): + self._encoding.return_value = 'utf-8' + assert default_encoding() == 'utf-8' + s = 'The quiæk fåx jømps øver the lazy dåg' + res = safe_str(s) + assert isinstance(res, str) + + def test_when_containing_high_chars(self): + self._encoding.return_value = 'ascii' + s = 'The quiæk fåx jømps øver the lazy dåg' + res = safe_str(s) + assert isinstance(res, str) + assert len(s) == len(res) + + def test_when_not_string(self): + o = object() + assert safe_str(o) == repr(o) + + def test_when_unrepresentable(self): + + class O(object): + + def __repr__(self): + raise KeyError('foo') + + assert '<Unrepresentable' in safe_str(O()) diff --git a/kombu/tests/utils/test_functional.py b/t/unit/utils/test_functional.py index e953979f..7a97b5b3 100644 --- a/kombu/tests/utils/test_functional.py +++ b/t/unit/utils/test_functional.py @@ -1,9 +1,12 @@ from __future__ import absolute_import, unicode_literals import pickle +import pytest from itertools import count +from case import Mock, mock, skip + from kombu.five import items from kombu.utils import functional as utils from kombu.utils.functional import ( @@ -11,21 +14,16 @@ from kombu.utils.functional import ( maybe_evaluate, maybe_list, reprcall, reprkwargs, retry_over_time, ) -from kombu.tests.case import Case, Mock, mock, skip - -class test_ChannelPromise(Case): +class test_ChannelPromise: def test_repr(self): obj = Mock(name='cb') - self.assertIn( - 'promise', - repr(ChannelPromise(obj)), - ) + assert 'promise' in repr(ChannelPromise(obj)) obj.assert_not_called() -class test_shufflecycle(Case): +class test_shufflecycle: def test_shuffles(self): prev_repeat, utils.repeat = utils.repeat, Mock() @@ -37,8 +35,8 @@ class test_shufflecycle(Case): for i in range(10): next(cycle) utils.repeat.assert_called_with(None) - self.assertTrue(seen.issubset(values)) - with self.assertRaises(StopIteration): + assert seen.issubset(values) + with pytest.raises(StopIteration): next(cycle) next(cycle) finally: @@ -49,7 +47,7 @@ def double(x): return x * 2 -class test_LRUCache(Case): +class test_LRUCache: def test_expires(self): limit = 100 @@ -57,16 +55,16 @@ class test_LRUCache(Case): slots = list(range(limit * 2)) for i in slots: x[i] = i - self.assertListEqual(list(x.keys()), list(slots[limit:])) - self.assertTrue(x.items()) - self.assertTrue(x.values()) + assert list(x.keys()) == list(slots[limit:]) + assert x.items() + assert x.values() def test_is_pickleable(self): x = LRUCache(limit=10) x.update(luke=1, leia=2) y = pickle.loads(pickle.dumps(x)) - self.assertEqual(y.limit, y.limit) - self.assertEqual(y, x) + assert y.limit == y.limit + assert y == x def test_update_expires(self): limit = 100 @@ -75,117 +73,111 @@ class test_LRUCache(Case): for i in slots: x.update({i: i}) - self.assertListEqual(list(x.keys()), list(slots[limit:])) + assert list(x.keys()) == list(slots[limit:]) def test_least_recently_used(self): x = LRUCache(3) x[1], x[2], x[3] = 1, 2, 3 - self.assertEqual(list(x.keys()), [1, 2, 3]) + assert list(x.keys()), [1, 2 == 3] x[4], x[5] = 4, 5 - self.assertEqual(list(x.keys()), [3, 4, 5]) + assert list(x.keys()), [3, 4 == 5] # access 3, which makes it the last used key. x[3] x[6] = 6 - self.assertEqual(list(x.keys()), [5, 3, 6]) + assert list(x.keys()), [5, 3 == 6] x[7] = 7 - self.assertEqual(list(x.keys()), [3, 6, 7]) + assert list(x.keys()), [3, 6 == 7] def test_update_larger_than_cache_size(self): x = LRUCache(2) x.update({x: x for x in range(100)}) - self.assertEqual(list(x.keys()), [98, 99]) + assert list(x.keys()), [98 == 99] def test_items(self): c = LRUCache() c.update(a=1, b=2, c=3) - self.assertTrue(list(items(c))) + assert list(items(c)) def test_incr(self): c = LRUCache() c.update(a='1') c.incr('a') - self.assertEqual(c['a'], '2') - + assert c['a'] == '2' -class test_memoize(Case): - def test_memoize(self): - counter = count(1) +def test_memoize(): + counter = count(1) - @memoize(maxsize=2) - def x(i): - return next(counter) + @memoize(maxsize=2) + def x(i): + return next(counter) - self.assertEqual(x(1), 1) - self.assertEqual(x(1), 1) - self.assertEqual(x(2), 2) - self.assertEqual(x(3), 3) - self.assertEqual(x(1), 4) - x.clear() - self.assertEqual(x(3), 5) + assert x(1) == 1 + assert x(1) == 1 + assert x(2) == 2 + assert x(3) == 3 + assert x(1) == 4 + x.clear() + assert x(3) == 5 -class test_lazy(Case): +class test_lazy: def test__str__(self): - self.assertEqual( - str(lazy(lambda: 'the quick brown fox')), - 'the quick brown fox', - ) + assert (str(lazy(lambda: 'the quick brown fox')) == + 'the quick brown fox') def test__repr__(self): - self.assertEqual( - repr(lazy(lambda: 'fi fa fo')).strip('u'), - "'fi fa fo'", - ) + assert repr(lazy(lambda: 'fi fa fo')).strip('u') == "'fi fa fo'" @skip.if_python3() def test__cmp__(self): - self.assertEqual(lazy(lambda: 10).__cmp__(lazy(lambda: 20)), -1) - self.assertEqual(lazy(lambda: 10).__cmp__(5), 1) + assert lazy(lambda: 10).__cmp__(lazy(lambda: 20)) == -1 + assert lazy(lambda: 10).__cmp__(5) == 1 def test_evaluate(self): - self.assertEqual(lazy(lambda: 2 + 2)(), 4) - self.assertEqual(lazy(lambda x: x * 4, 2), 8) - self.assertEqual(lazy(lambda x: x * 8, 2)(), 16) + assert lazy(lambda: 2 + 2)() == 4 + assert lazy(lambda x: x * 4, 2) == 8 + assert lazy(lambda x: x * 8, 2)() == 16 def test_cmp(self): - self.assertEqual(lazy(lambda: 10), lazy(lambda: 10)) - self.assertNotEqual(lazy(lambda: 10), lazy(lambda: 20)) + assert lazy(lambda: 10) == lazy(lambda: 10) + assert lazy(lambda: 10) != lazy(lambda: 20) def test__reduce__(self): x = lazy(double, 4) y = pickle.loads(pickle.dumps(x)) - self.assertEqual(x(), y()) + assert x() == y() def test__deepcopy__(self): from copy import deepcopy x = lazy(double, 4) y = deepcopy(x) - self.assertEqual(x._fun, y._fun) - self.assertEqual(x._args, y._args) - self.assertEqual(x(), y()) + assert x._fun == y._fun + assert x._args == y._args + assert x() == y() -class test_maybe_evaluate(Case): +@pytest.mark.parametrize('obj,expected', [ + (lazy(lambda: 10), 10), + (20, 20), +]) +def test_maybe_evaluate(obj, expected): + assert maybe_evaluate(obj) == expected - def test_evaluates(self): - self.assertEqual(maybe_evaluate(lazy(lambda: 10)), 10) - self.assertEqual(maybe_evaluate(20), 20) +class test_retry_over_time: -class test_retry_over_time(Case): + class Predicate(Exception): + pass def setup(self): self.index = 0 - class Predicate(Exception): - pass - def myfun(self): if self.index < 9: raise self.Predicate() @@ -195,7 +187,7 @@ class test_retry_over_time(Case): interval = next(intervals) sleepvals = (None, 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0, 16.0) self.index += 1 - self.assertEqual(interval, sleepvals[self.index]) + assert interval == sleepvals[self.index] return interval @mock.sleepdeprived(module=utils) @@ -205,28 +197,28 @@ class test_retry_over_time(Case): utils.count.return_value = list(range(1)) x = retry_over_time(self.myfun, self.Predicate, errback=None, interval_max=14) - self.assertIsNone(x) + assert x is None utils.count.return_value = list(range(10)) cb = Mock() x = retry_over_time(self.myfun, self.Predicate, errback=self.errback, callback=cb, interval_max=14) - self.assertEqual(x, 42) - self.assertEqual(self.index, 9) + assert x == 42 + assert self.index == 9 cb.assert_called_with() finally: utils.count = prev_count @mock.sleepdeprived(module=utils) def test_retry_once(self): - with self.assertRaises(self.Predicate): + with pytest.raises(self.Predicate): retry_over_time( self.myfun, self.Predicate, max_retries=1, errback=self.errback, interval_max=14, ) - self.assertEqual(self.index, 1) + assert self.index == 1 # no errback - with self.assertRaises(self.Predicate): + with pytest.raises(self.Predicate): retry_over_time( self.myfun, self.Predicate, max_retries=1, errback=None, interval_max=14, @@ -250,38 +242,39 @@ class test_retry_over_time(Case): self.calls += 1 fun = Fun() - self.assertEqual( - retry_over_time( - fun, self.Predicate, - max_retries=0, errback=None, interval_max=14, - ), - 42, - ) - self.assertEqual(fun.calls, 11) - - -class test_utils(Case): - - def test_maybe_list(self): - self.assertIsNone(maybe_list(None)) - self.assertEqual(maybe_list(1), [1]) - self.assertEqual(maybe_list([1, 2, 3]), [1, 2, 3]) - - def test_fxrange_no_repeatlast(self): - self.assertEqual(list(fxrange(1.0, 3.0, 1.0)), - [1.0, 2.0, 3.0]) - - def test_fxrangemax(self): - self.assertEqual(list(fxrangemax(1.0, 3.0, 1.0, 30.0)), - [1.0, 2.0, 3.0, 3.0, 3.0, 3.0, - 3.0, 3.0, 3.0, 3.0, 3.0]) - self.assertEqual(list(fxrangemax(1.0, None, 1.0, 30.0)), - [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]) - - def test_reprkwargs(self): - self.assertTrue(reprkwargs({'foo': 'bar', 1: 2, 'k': 'v'})) - - def test_reprcall(self): - self.assertTrue( - reprcall('add', (2, 2), {'copy': True}), - ) + assert retry_over_time( + fun, self.Predicate, + max_retries=0, errback=None, interval_max=14) == 42 + assert fun.calls == 11 + + +@pytest.mark.parametrize('obj,expected', [ + (None, None), + (1, [1]), + ([1, 2, 3], [1, 2, 3]), +]) +def test_maybe_list(obj, expected): + assert maybe_list(obj) == expected + + +def test_fxrange__no_repeatlast(): + assert list(fxrange(1.0, 3.0, 1.0)) == [1.0, 2.0, 3.0] + + +@pytest.mark.parametrize('args,expected', [ + ((1.0, 3.0, 1.0, 30.0), + [1.0, 2.0, 3.0, 3.0, 3.0, 3.0, + 3.0, 3.0, 3.0, 3.0, 3.0]), + ((1.0, None, 1.0, 30.0), + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]), +]) +def test_fxrangemax(args, expected): + assert list(fxrangemax(*args)) == expected + + +def test_reprkwargs(): + assert reprkwargs({'foo': 'bar', 1: 2, 'k': 'v'}) + + +def test_reprcall(): + assert reprcall('add', (2, 2), {'copy': True}) diff --git a/kombu/tests/utils/test_imports.py b/t/unit/utils/test_imports.py index 49e2a72f..20b9a2f8 100644 --- a/kombu/tests/utils/test_imports.py +++ b/t/unit/utils/test_imports.py @@ -1,37 +1,34 @@ from __future__ import absolute_import, unicode_literals +import pytest + +from case import Mock + from kombu import Exchange from kombu.utils.imports import symbol_by_name -from kombu.tests.case import Case, Mock - -class test_symbol_by_name(Case): +class test_symbol_by_name: def test_instance_returns_instance(self): instance = object() - self.assertIs(symbol_by_name(instance), instance) + assert symbol_by_name(instance) is instance def test_returns_default(self): default = object() - self.assertIs( - symbol_by_name('xyz.ryx.qedoa.weq:foz', default=default), - default, - ) + assert symbol_by_name( + 'xyz.ryx.qedoa.weq:foz', default=default) is default def test_no_default(self): - with self.assertRaises(ImportError): + with pytest.raises(ImportError): symbol_by_name('xyz.ryx.qedoa.weq:foz') def test_imp_reraises_ValueError(self): imp = Mock() imp.side_effect = ValueError() - with self.assertRaises(ValueError): + with pytest.raises(ValueError): symbol_by_name('kombu.Connection', imp=imp) def test_package(self): - self.assertIs( - symbol_by_name('.entity:Exchange', package='kombu'), - Exchange, - ) - self.assertTrue(symbol_by_name(':Consumer', package='kombu')) + assert symbol_by_name('.entity:Exchange', package='kombu') is Exchange + assert symbol_by_name(':Consumer', package='kombu') diff --git a/kombu/tests/utils/test_json.py b/t/unit/utils/test_json.py index 1ad8d1c9..ae7b374d 100644 --- a/kombu/tests/utils/test_json.py +++ b/t/unit/utils/test_json.py @@ -1,17 +1,18 @@ from __future__ import absolute_import, unicode_literals +import pytest import pytz from datetime import datetime from decimal import Decimal from uuid import uuid4 +from case import MagicMock, Mock, skip + from kombu.five import text_t from kombu.utils.encoding import str_to_bytes from kombu.utils.json import _DecodeError, dumps, loads -from kombu.tests.case import Case, MagicMock, Mock, skip - class Custom(object): @@ -22,7 +23,7 @@ class Custom(object): return self.data -class test_JSONEncoder(Case): +class test_JSONEncoder: def test_datetime(self): now = datetime.utcnow() @@ -34,64 +35,58 @@ class test_JSONEncoder(Case): 'date': now.date(), 'time': now.time()}, )) - self.assertDictEqual(serialized, { + assert serialized == { 'datetime': now.isoformat(), 'tz': '{0}Z'.format(now_utc.isoformat().split('+', 1)[0]), 'time': now.time().isoformat(), 'date': stripped.isoformat(), - }) + } def test_Decimal(self): d = Decimal('3314132.13363235235324234123213213214134') - self.assertDictEqual(loads(dumps({'d': d})), { - 'd': text_t(d), - }) + assert loads(dumps({'d': d})), {'d': text_t(d)} def test_UUID(self): id = uuid4() - self.assertDictEqual(loads(dumps({'u': id})), { - 'u': text_t(id), - }) + assert loads(dumps({'u': id})), {'u': text_t(id)} def test_default(self): - with self.assertRaises(TypeError): + with pytest.raises(TypeError): dumps({'o': object()}) -class test_dumps_loads(Case): +class test_dumps_loads: def test_dumps_custom_object(self): x = {'foo': Custom({'a': 'b'})} - self.assertEqual(loads(dumps(x)), {'foo': x['foo'].__json__()}) + assert loads(dumps(x)) == {'foo': x['foo'].__json__()} def test_dumps_custom_object_no_json(self): x = {'foo': object()} - with self.assertRaises(TypeError): + with pytest.raises(TypeError): dumps(x) def test_loads_memoryview(self): - self.assertEqual( - loads(memoryview(bytearray(dumps({'x': 'z'}), encoding='utf-8'))), - {'x': 'z'}, - ) + assert loads( + memoryview(bytearray(dumps({'x': 'z'}), encoding='utf-8')) + ) == {'x': 'z'} def test_loads_bytearray(self): - self.assertEqual( - loads(bytearray(dumps({'x': 'z'}), encoding='utf-8')), - {'x': 'z'}) + assert loads( + bytearray(dumps({'x': 'z'}), encoding='utf-8') + ) == {'x': 'z'} def test_loads_bytes(self): - self.assertEqual( - loads(str_to_bytes(dumps({'x': 'z'})), decode_bytes=True), - {'x': 'z'}, - ) + assert loads( + str_to_bytes(dumps({'x': 'z'})), + decode_bytes=True) == {'x': 'z'} @skip.if_python3() def test_loads_buffer(self): - self.assertEqual(loads(buffer(dumps({'x': 'z'}))), {'x': 'z'}) + assert loads(buffer(dumps({'x': 'z'}))) == {'x': 'z'} def test_loads_DecodeError(self): _loads = Mock(name='_loads') _loads.side_effect = _DecodeError( MagicMock(), MagicMock(), MagicMock()) - self.assertEqual(loads(dumps({'x': 'z'}), _loads=_loads), {'x': 'z'}) + assert loads(dumps({'x': 'z'}), _loads=_loads) == {'x': 'z'} diff --git a/kombu/tests/utils/test_objects.py b/t/unit/utils/test_objects.py index 22893e48..8b13a405 100644 --- a/kombu/tests/utils/test_objects.py +++ b/t/unit/utils/test_objects.py @@ -2,10 +2,8 @@ from __future__ import absolute_import, unicode_literals from kombu.utils.objects import cached_property -from kombu.tests.case import Case - -class test_cached_property(Case): +class test_cached_property: def test_deleting(self): @@ -22,10 +20,10 @@ class test_cached_property(Case): x = X() del(x.foo) - self.assertFalse(x.xx) + assert not x.xx x.__dict__['foo'] = 'here' del(x.foo) - self.assertEqual(x.xx, 'here') + assert x.xx == 'here' def test_when_access_from_class(self): @@ -41,15 +39,15 @@ class test_cached_property(Case): self.xx = 10 desc = X.__dict__['foo'] - self.assertIs(X.foo, desc) + assert X.foo is desc - self.assertIs(desc.__get__(None), desc) - self.assertIs(desc.__set__(None, 1), desc) - self.assertIs(desc.__delete__(None), desc) - self.assertTrue(desc.setter(1)) + assert desc.__get__(None) is desc + assert desc.__set__(None, 1) is desc + assert desc.__delete__(None) is desc + assert desc.setter(1) x = X() x.foo = 30 - self.assertEqual(x.xx, 10) + assert x.xx == 10 del(x.foo) diff --git a/t/unit/utils/test_scheduling.py b/t/unit/utils/test_scheduling.py new file mode 100644 index 00000000..894286cd --- /dev/null +++ b/t/unit/utils/test_scheduling.py @@ -0,0 +1,102 @@ +from __future__ import absolute_import, unicode_literals + +import pytest + +from kombu.utils.scheduling import FairCycle, cycle_by_name + + +class MyEmpty(Exception): + pass + + +def consume(fun, n): + r = [] + for i in range(n): + r.append(fun()) + return r + + +class test_FairCycle: + + def test_cycle(self): + resources = ['a', 'b', 'c', 'd', 'e'] + + def echo(r, timeout=None): + return r + + # cycle should be ['a', 'b', 'c', 'd', 'e', ... repeat] + cycle = FairCycle(echo, resources, MyEmpty) + for i in range(len(resources)): + assert cycle.get() == (resources[i], resources[i]) + for i in range(len(resources)): + assert cycle.get() == (resources[i], resources[i]) + + def test_cycle_breaks(self): + resources = ['a', 'b', 'c', 'd', 'e'] + + def echo(r): + if r == 'c': + raise MyEmpty(r) + return r + + cycle = FairCycle(echo, resources, MyEmpty) + assert consume(cycle.get, len(resources)) == [ + ('a', 'a'), ('b', 'b'), ('d', 'd'), + ('e', 'e'), ('a', 'a'), + ] + assert consume(cycle.get, len(resources)) == [ + ('b', 'b'), ('d', 'd'), ('e', 'e'), + ('a', 'a'), ('b', 'b'), + ] + cycle2 = FairCycle(echo, ['c', 'c'], MyEmpty) + with pytest.raises(MyEmpty): + consume(cycle2.get, 3) + + def test_cycle_no_resources(self): + cycle = FairCycle(None, [], MyEmpty) + cycle.pos = 10 + + with pytest.raises(MyEmpty): + cycle._next() + + def test__repr__(self): + assert repr(FairCycle(lambda x: x, [1, 2, 3], MyEmpty)) + + +def test_round_robin_cycle(): + it = cycle_by_name('round_robin')(['A', 'B', 'C']) + assert it.consume(3) == ['A', 'B', 'C'] + it.rotate('B') + assert it.consume(3) == ['A', 'C', 'B'] + it.rotate('A') + assert it.consume(3) == ['C', 'B', 'A'] + it.rotate('A') + assert it.consume(3) == ['C', 'B', 'A'] + it.rotate('C') + assert it.consume(3) == ['B', 'A', 'C'] + + +def test_priority_cycle(): + it = cycle_by_name('priority')(['A', 'B', 'C']) + assert it.consume(3) == ['A', 'B', 'C'] + it.rotate('B') + assert it.consume(3) == ['A', 'B', 'C'] + it.rotate('A') + assert it.consume(3) == ['A', 'B', 'C'] + it.rotate('A') + assert it.consume(3) == ['A', 'B', 'C'] + it.rotate('C') + assert it.consume(3) == ['A', 'B', 'C'] + + +def test_sorted_cycle(): + it = cycle_by_name('sorted')(['B', 'C', 'A']) + assert it.consume(3) == ['A', 'B', 'C'] + it.rotate('B') + assert it.consume(3) == ['A', 'B', 'C'] + it.rotate('A') + assert it.consume(3) == ['A', 'B', 'C'] + it.rotate('A') + assert it.consume(3) == ['A', 'B', 'C'] + it.rotate('C') + assert it.consume(3) == ['A', 'B', 'C'] diff --git a/t/unit/utils/test_url.py b/t/unit/utils/test_url.py new file mode 100644 index 00000000..3d2b0ede --- /dev/null +++ b/t/unit/utils/test_url.py @@ -0,0 +1,39 @@ +from __future__ import absolute_import, unicode_literals + +import pytest + +from kombu.utils.url import as_url, parse_url, maybe_sanitize_url + + +def test_parse_url(): + assert parse_url('amqp://user:pass@localhost:5672/my/vhost') == { + 'transport': 'amqp', + 'userid': 'user', + 'password': 'pass', + 'hostname': 'localhost', + 'port': 5672, + 'virtual_host': 'my/vhost', + } + + +@pytest.mark.parametrize('urltuple,expected', [ + (('https',), 'https:///'), + (('https', 'e.com'), 'https://e.com/'), + (('https', 'e.com', 80), 'https://e.com:80/'), + (('https', 'e.com', 80, 'u'), 'https://u@e.com:80/'), + (('https', 'e.com', 80, 'u', 'p'), 'https://u:p@e.com:80/'), + (('https', 'e.com', 80, None, 'p'), 'https://:p@e.com:80/'), + (('https', 'e.com', 80, None, 'p', '/foo'), 'https://:p@e.com:80//foo'), +]) +def test_as_url(urltuple, expected): + assert as_url(*urltuple) == expected + + +@pytest.mark.parametrize('url,expected', [ + ('foo', 'foo'), + ('http://u:p@e.com//foo', 'http://u:**@e.com//foo'), +]) +def test_maybe_sanitize_url(url, expected): + assert maybe_sanitize_url(url) == expected + assert (maybe_sanitize_url('http://u:p@e.com//foo') == + 'http://u:**@e.com//foo') diff --git a/t/unit/utils/test_utils.py b/t/unit/utils/test_utils.py new file mode 100644 index 00000000..f4668b83 --- /dev/null +++ b/t/unit/utils/test_utils.py @@ -0,0 +1,22 @@ +from __future__ import absolute_import, unicode_literals + +import pytest + +from kombu import version_info_t +from kombu.utils.text import version_string_as_tuple + + +def test_dir(): + import kombu + assert dir(kombu) + + +@pytest.mark.parametrize('version,expected', [ + ('3', version_info_t(3, 0, 0, '', '')), + ('3.3', version_info_t(3, 3, 0, '', '')), + ('3.3.1', version_info_t(3, 3, 1, '', '')), + ('3.3.1a3', version_info_t(3, 3, 1, 'a3', '')), + ('3.3.1.a3.40c32', version_info_t(3, 3, 1, 'a3', '40c32')), +]) +def test_version_string_as_tuple(version, expected): + assert version_string_as_tuple(version) == expected diff --git a/kombu/tests/utils/test_uuid.py b/t/unit/utils/test_uuid.py index ac018f09..f6b32c28 100644 --- a/kombu/tests/utils/test_uuid.py +++ b/t/unit/utils/test_uuid.py @@ -2,16 +2,14 @@ from __future__ import absolute_import, unicode_literals from kombu.utils.uuid import uuid -from kombu.tests.case import Case - -class test_UUID(Case): +class test_UUID: def test_uuid4(self): - self.assertNotEqual(uuid(), uuid()) + assert uuid() != uuid() def test_uuid(self): i1 = uuid() i2 = uuid() - self.assertIsInstance(i1, str) - self.assertNotEqual(i1, i2) + assert isinstance(i1, str) + assert i1 != i2 @@ -15,8 +15,7 @@ deps= flake8,flakeplus: -r{toxinidir}/requirements/pkgutils.txt commands = pip install -U -r{toxinidir}/requirements/dev.txt - nosetests -vdsx kombu.tests \ - --with-coverage --cover-inclusive --cover-erase [] + py.test -xv --cov=kombu/ --cov-report=xml --no-cov-on-fail basepython = 2.7,flakeplus,flake8,apicheck,linkcheck: python2.7 |