summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.gitignore6
-rw-r--r--MANIFEST.in6
-rw-r--r--kombu/tests/__init__.py62
-rw-r--r--kombu/tests/test_clocks.py104
-rw-r--r--kombu/tests/test_exceptions.py11
-rw-r--r--kombu/tests/transport/test_qpid.py1953
-rw-r--r--kombu/tests/transport/virtual/test_exchange.py200
-rw-r--r--kombu/tests/utils/test_div.py51
-rw-r--r--kombu/tests/utils/test_encoding.py109
-rw-r--r--kombu/tests/utils/test_scheduling.py112
-rw-r--r--kombu/tests/utils/test_url.py50
-rw-r--r--kombu/tests/utils/test_utils.py42
-rw-r--r--kombu/transport/redis.py7
-rw-r--r--kombu/transport/virtual/__init__.py1001
-rw-r--r--kombu/transport/virtual/base.py989
-rw-r--r--requirements/test-ci.txt2
-rw-r--r--requirements/test.txt3
-rw-r--r--setup.cfg7
-rw-r--r--setup.py53
-rw-r--r--t/__init__.py (renamed from kombu/tests/async/__init__.py)0
-rw-r--r--t/conftest.py98
-rw-r--r--t/mocks.py (renamed from kombu/tests/mocks.py)24
-rw-r--r--t/unit/__init__.py1
-rw-r--r--t/unit/async/__init__.py (renamed from kombu/tests/async/aws/__init__.py)0
-rw-r--r--t/unit/async/aws/__init__.py (renamed from kombu/tests/async/aws/sqs/__init__.py)0
-rw-r--r--t/unit/async/aws/case.py (renamed from kombu/tests/async/aws/case.py)7
-rw-r--r--t/unit/async/aws/sqs/__init__.py (renamed from kombu/tests/async/http/__init__.py)0
-rw-r--r--t/unit/async/aws/sqs/test_connection.py (renamed from kombu/tests/async/aws/sqs/test_connection.py)17
-rw-r--r--t/unit/async/aws/sqs/test_message.py (renamed from kombu/tests/async/aws/sqs/test_message.py)19
-rw-r--r--t/unit/async/aws/sqs/test_queue.py (renamed from kombu/tests/async/aws/sqs/test_queue.py)41
-rw-r--r--t/unit/async/aws/sqs/test_sqs.py (renamed from kombu/tests/async/aws/sqs/test_sqs.py)21
-rw-r--r--t/unit/async/aws/test_aws.py (renamed from kombu/tests/async/aws/test_aws.py)7
-rw-r--r--t/unit/async/aws/test_connection.py (renamed from kombu/tests/async/aws/test_connection.py)148
-rw-r--r--t/unit/async/http/__init__.py (renamed from kombu/tests/transport/__init__.py)0
-rw-r--r--t/unit/async/http/test_curl.py (renamed from kombu/tests/async/http/test_curl.py)47
-rw-r--r--t/unit/async/http/test_http.py (renamed from kombu/tests/async/http/test_http.py)91
-rw-r--r--t/unit/async/test_hub.py (renamed from kombu/tests/async/test_hub.py)203
-rw-r--r--t/unit/async/test_semaphore.py (renamed from kombu/tests/async/test_semaphore.py)24
-rw-r--r--t/unit/async/test_timer.py (renamed from kombu/tests/async/test_timer.py)73
-rw-r--r--t/unit/case.py (renamed from kombu/tests/case.py)0
-rw-r--r--t/unit/test_clocks.py89
-rw-r--r--t/unit/test_common.py (renamed from kombu/tests/test_common.py)183
-rw-r--r--t/unit/test_compat.py (renamed from kombu/tests/test_compat.py)155
-rw-r--r--t/unit/test_compression.py (renamed from kombu/tests/test_compression.py)20
-rw-r--r--t/unit/test_connection.py (renamed from kombu/tests/test_connection.py)384
-rw-r--r--t/unit/test_entity.py (renamed from kombu/tests/test_entity.py)220
-rw-r--r--t/unit/test_exceptions.py11
-rw-r--r--t/unit/test_log.py (renamed from kombu/tests/test_log.py)62
-rw-r--r--t/unit/test_message.py (renamed from kombu/tests/test_message.py)21
-rw-r--r--t/unit/test_messaging.py (renamed from kombu/tests/test_messaging.py)286
-rw-r--r--t/unit/test_mixins.py (renamed from kombu/tests/test_mixins.py)45
-rw-r--r--t/unit/test_pidbox.py (renamed from kombu/tests/test_pidbox.py)125
-rw-r--r--t/unit/test_pools.py (renamed from kombu/tests/test_pools.py)84
-rw-r--r--t/unit/test_serialization.py (renamed from kombu/tests/test_serialization.py)209
-rw-r--r--t/unit/test_simple.py (renamed from kombu/tests/test_simple.py)69
-rw-r--r--t/unit/test_syn.py (renamed from kombu/tests/test_syn.py)24
-rw-r--r--t/unit/transport/__init__.py (renamed from kombu/tests/transport/virtual/__init__.py)0
-rw-r--r--t/unit/transport/test_SQS.py (renamed from kombu/tests/transport/test_SQS.py)73
-rw-r--r--t/unit/transport/test_base.py (renamed from kombu/tests/transport/test_base.py)52
-rw-r--r--t/unit/transport/test_consul.py (renamed from kombu/tests/transport/test_consul.py)38
-rw-r--r--t/unit/transport/test_filesystem.py (renamed from kombu/tests/transport/test_filesystem.py)50
-rw-r--r--t/unit/transport/test_librabbitmq.py (renamed from kombu/tests/transport/test_librabbitmq.py)66
-rw-r--r--t/unit/transport/test_memory.py (renamed from kombu/tests/transport/test_memory.py)37
-rw-r--r--t/unit/transport/test_mongodb.py (renamed from kombu/tests/transport/test_mongodb.py)52
-rw-r--r--t/unit/transport/test_pyamqp.py (renamed from kombu/tests/transport/test_pyamqp.py)50
-rw-r--r--t/unit/transport/test_redis.py (renamed from kombu/tests/transport/test_redis.py)299
-rw-r--r--t/unit/transport/test_transport.py (renamed from kombu/tests/transport/test_transport.py)19
-rw-r--r--t/unit/transport/virtual/__init__.py (renamed from kombu/tests/utils/__init__.py)0
-rw-r--r--t/unit/transport/virtual/test_base.py (renamed from kombu/tests/transport/virtual/test_base.py)271
-rw-r--r--t/unit/transport/virtual/test_exchange.py169
-rw-r--r--t/unit/utils/__init__.py0
-rw-r--r--t/unit/utils/test_amq_manager.py (renamed from kombu/tests/utils/test_amq_manager.py)14
-rw-r--r--t/unit/utils/test_compat.py (renamed from kombu/tests/utils/test_compat.py)24
-rw-r--r--t/unit/utils/test_debug.py (renamed from kombu/tests/utils/test_debug.py)20
-rw-r--r--t/unit/utils/test_div.py49
-rw-r--r--t/unit/utils/test_encoding.py105
-rw-r--r--t/unit/utils/test_functional.py (renamed from kombu/tests/utils/test_functional.py)209
-rw-r--r--t/unit/utils/test_imports.py (renamed from kombu/tests/utils/test_imports.py)27
-rw-r--r--t/unit/utils/test_json.py (renamed from kombu/tests/utils/test_json.py)51
-rw-r--r--t/unit/utils/test_objects.py (renamed from kombu/tests/utils/test_objects.py)20
-rw-r--r--t/unit/utils/test_scheduling.py102
-rw-r--r--t/unit/utils/test_url.py39
-rw-r--r--t/unit/utils/test_utils.py22
-rw-r--r--t/unit/utils/test_uuid.py (renamed from kombu/tests/utils/test_uuid.py)10
-rw-r--r--tox.ini3
85 files changed, 3674 insertions, 5773 deletions
diff --git a/.gitignore b/.gitignore
index b8e0befb..5d61bcd4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,7 +8,7 @@ settings_local.py
build/
.build/
_build/
-.*.sw[pon]
+.*.sw*
dist/
*.egg-info
pip-log.txt
@@ -26,3 +26,7 @@ kombu/tests/coverage.xml
.coverage
dump.rdb
.idea/
+.cache/
+htmlcov/
+test.db
+coverage.xml
diff --git a/MANIFEST.in b/MANIFEST.in
index 35c0ac82..ff42284e 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -12,6 +12,12 @@ include setup.cfg
recursive-include extra *
recursive-include docs *
recursive-include kombu *.py
+recursive-include t *.py
recursive-include requirements *.txt
recursive-include funtests *.py setup.cfg
recursive-include examples *.py
+
+recursive-exclude docs/_build *
+recursive-exclude * __pycache__
+recursive-exclude * *.py[co]
+recursive-exclude * .*.sw*
diff --git a/kombu/tests/__init__.py b/kombu/tests/__init__.py
deleted file mode 100644
index 17ac49e3..00000000
--- a/kombu/tests/__init__.py
+++ /dev/null
@@ -1,62 +0,0 @@
-from __future__ import absolute_import, unicode_literals
-
-import atexit
-import os
-import sys
-
-from kombu.exceptions import VersionMismatch
-
-
-def teardown():
- # Workaround for multiprocessing bug where logging
- # is attempted after global already collected at shutdown.
- canceled = set()
- try:
- import multiprocessing.util
- canceled.add(multiprocessing.util._exit_function)
- except (AttributeError, ImportError):
- pass
-
- try:
- atexit._exithandlers[:] = [
- e for e in atexit._exithandlers if e[0] not in canceled
- ]
- except AttributeError: # pragma: no cover
- pass # Py3 missing _exithandlers
-
-
-def find_distribution_modules(name=__name__, file=__file__):
- current_dist_depth = len(name.split('.')) - 1
- current_dist = os.path.join(os.path.dirname(file),
- *([os.pardir] * current_dist_depth))
- abs = os.path.abspath(current_dist)
- dist_name = os.path.basename(abs)
-
- for dirpath, dirnames, filenames in os.walk(abs):
- package = (dist_name + dirpath[len(abs):]).replace('/', '.')
- if '__init__.py' in filenames:
- yield package
- for filename in filenames:
- if filename.endswith('.py') and filename != '__init__.py':
- yield '.'.join([package, filename])[:-3]
-
-
-def import_all_modules(name=__name__, file=__file__, skip=[]):
- for module in find_distribution_modules(name, file):
- if module not in skip:
- print('preimporting %r for coverage...' % (module,))
- try:
- __import__(module)
- except (ImportError, VersionMismatch, AttributeError):
- pass
-
-
-def is_in_coverage():
- return (os.environ.get('COVER_ALL_MODULES') or
- '--with-coverage3' in sys.argv)
-
-
-def setup():
- # so coverage sees all our modules.
- if is_in_coverage():
- import_all_modules()
diff --git a/kombu/tests/test_clocks.py b/kombu/tests/test_clocks.py
deleted file mode 100644
index ef4ab87d..00000000
--- a/kombu/tests/test_clocks.py
+++ /dev/null
@@ -1,104 +0,0 @@
-from __future__ import absolute_import, unicode_literals
-
-import pickle
-
-from heapq import heappush
-from time import time
-
-from kombu.clocks import LamportClock, timetuple
-
-from .case import Mock, Case
-
-
-class test_LamportClock(Case):
-
- def test_clocks(self):
- c1 = LamportClock()
- c2 = LamportClock()
-
- c1.forward()
- c2.forward()
- c1.forward()
- c1.forward()
- c2.adjust(c1.value)
- self.assertEqual(c2.value, c1.value + 1)
- self.assertTrue(repr(c1))
-
- c2_val = c2.value
- c2.forward()
- c2.forward()
- c2.adjust(c1.value)
- self.assertEqual(c2.value, c2_val + 2 + 1)
-
- c1.adjust(c2.value)
- self.assertEqual(c1.value, c2.value + 1)
-
- def test_sort(self):
- c = LamportClock()
- pid1 = 'a.example.com:312'
- pid2 = 'b.example.com:311'
-
- events = []
-
- m1 = (c.forward(), pid1)
- heappush(events, m1)
- m2 = (c.forward(), pid2)
- heappush(events, m2)
- m3 = (c.forward(), pid1)
- heappush(events, m3)
- m4 = (30, pid1)
- heappush(events, m4)
- m5 = (30, pid2)
- heappush(events, m5)
-
- self.assertEqual(str(c), str(c.value))
-
- self.assertEqual(c.sort_heap(events), m1)
- self.assertEqual(c.sort_heap([m4, m5]), m4)
- self.assertEqual(c.sort_heap([m4, m5, m1]), m4)
-
-
-class test_timetuple(Case):
-
- def test_repr(self):
- x = timetuple(133, time(), 'id', Mock())
- self.assertTrue(repr(x))
-
- def test_pickleable(self):
- x = timetuple(133, time(), 'id', 'obj')
- self.assertEqual(pickle.loads(pickle.dumps(x)), tuple(x))
-
- def test_order(self):
- t1 = time()
- t2 = time() + 300 # windows clock not reliable
- a = timetuple(133, t1, 'A', 'obj')
- b = timetuple(140, t1, 'A', 'obj')
- self.assertTrue(a.__getnewargs__())
- self.assertEqual(a.clock, 133)
- self.assertEqual(a.timestamp, t1)
- self.assertEqual(a.id, 'A')
- self.assertEqual(a.obj, 'obj')
- self.assertTrue(
- a <= b,
- )
- self.assertTrue(
- b >= a,
- )
-
- self.assertEqual(
- timetuple(134, time(), 'A', 'obj').__lt__(tuple()),
- NotImplemented,
- )
- self.assertGreater(
- timetuple(134, t2, 'A', 'obj'),
- timetuple(133, t1, 'A', 'obj'),
- )
- self.assertGreater(
- timetuple(134, t1, 'B', 'obj'),
- timetuple(134, t1, 'A', 'obj'),
- )
-
- self.assertGreater(
- timetuple(None, t2, 'B', 'obj'),
- timetuple(None, t1, 'A', 'obj'),
- )
diff --git a/kombu/tests/test_exceptions.py b/kombu/tests/test_exceptions.py
deleted file mode 100644
index 1d04b359..00000000
--- a/kombu/tests/test_exceptions.py
+++ /dev/null
@@ -1,11 +0,0 @@
-from __future__ import absolute_import, unicode_literals
-
-from kombu.exceptions import HttpError
-
-from kombu.tests.case import Case, Mock
-
-
-class test_HttpError(Case):
-
- def test_str(self):
- self.assertTrue(str(HttpError(200, 'msg', Mock(name='response'))))
diff --git a/kombu/tests/transport/test_qpid.py b/kombu/tests/transport/test_qpid.py
deleted file mode 100644
index 0e913d84..00000000
--- a/kombu/tests/transport/test_qpid.py
+++ /dev/null
@@ -1,1953 +0,0 @@
-from __future__ import absolute_import, unicode_literals
-
-import select
-import ssl
-import socket
-import sys
-import time
-import uuid
-
-from collections import Callable, OrderedDict
-from itertools import count
-from functools import wraps
-
-from mock import call
-from nose import SkipTest
-
-from kombu.five import Empty, keys, range, monotonic
-from kombu.transport.qpid import (AuthenticationFailure, Channel, Connection,
- ConnectionError, Message, NotFound, QoS,
- Transport)
-from kombu.transport.virtual import Base64
-from kombu.tests.case import Case, Mock
-from kombu.tests.case import patch
-
-
-QPID_MODULE = 'kombu.transport.qpid'
-PY3 = sys.version_info[0] == 3
-
-
-def case_no_pypy(cls):
- setup = cls.setUp
-
- @wraps(setup)
- def around_setup(self):
- if getattr(sys, 'pypy_version_info', None):
- raise SkipTest('pypy incompatible')
- setup(self)
- cls.setUp = around_setup
- return cls
-
-
-def case_no_python3(cls):
- setup = cls.setUp
-
- @wraps(setup)
- def around_setup(self):
- if PY3:
- raise SkipTest('Python3 incompatible')
- setup(self)
- cls.setUp = around_setup
- return cls
-
-
-def disable_runtime_dependency_check(cls):
- """A decorator to disable runtime dependency checking"""
- setup = cls.setUp
- teardown = cls.tearDown
- dependency_is_none_patcher = patch(QPID_MODULE + '.dependency_is_none')
-
- @wraps(setup)
- def around_setup(self):
- mock_dependency_is_none = dependency_is_none_patcher.start()
- mock_dependency_is_none.return_value = False
- setup(self)
-
- @wraps(setup)
- def around_teardown(self):
- dependency_is_none_patcher.stop()
- teardown(self)
-
- cls.setUp = around_setup
- cls.tearDown = around_teardown
- return cls
-
-
-class ExtraAssertionsMixin(object):
- """A mixin class adding assertDictEqual and assertDictContainsSubset"""
-
- def assertDictEqual(self, a, b, msg=None):
- """
- Test that two dictionaries are equal.
-
- Implemented here because this method was not available until Python
- 2.6. This asserts that the unique set of keys are the same in a and b.
- Also asserts that the value of each key is the same in a and b using
- the is operator.
- """
- self.assertEqual(set(keys(a)), set(keys(b)))
- for key in keys(a):
- self.assertEqual(a[key], b[key])
-
- def assertDictContainsSubset(self, a, b, msg=None):
- """
- Assert that all the key/value pairs in a exist in b.
- """
- for key in keys(a):
- self.assertIn(key, b)
- self.assertEqual(a[key], b[key])
-
-
-class QpidException(Exception):
- """
- An object used to mock Exceptions provided by qpid.messaging.exceptions
- """
-
- def __init__(self, code=None, text=None):
- super(Exception, self).__init__(self)
- self.code = code
- self.text = text
-
-
-class BreakOutException(Exception):
- pass
-
-
-@case_no_python3
-@case_no_pypy
-class TestQoS__init__(Case):
-
- def setUp(self):
- self.mock_session = Mock()
- self.qos = QoS(self.mock_session)
-
- def test__init__prefetch_default_set_correct_without_prefetch_value(self):
- self.assertEqual(self.qos.prefetch_count, 1)
-
- def test__init__prefetch_is_hard_set_to_one(self):
- qos_limit_two = QoS(self.mock_session)
- self.assertEqual(qos_limit_two.prefetch_count, 1)
-
- def test__init___not_yet_acked_is_initialized(self):
- self.assertIsInstance(self.qos._not_yet_acked, OrderedDict)
-
-
-@case_no_python3
-@case_no_pypy
-class TestQoSCanConsume(Case):
-
- def setUp(self):
- session = Mock()
- self.qos = QoS(session)
-
- def test_True_when_prefetch_limit_is_zero(self):
- self.qos.prefetch_count = 0
- self.qos._not_yet_acked = []
- self.assertTrue(self.qos.can_consume())
-
- def test_True_when_len_of__not_yet_acked_is_lt_prefetch_count(self):
- self.qos.prefetch_count = 3
- self.qos._not_yet_acked = ['a', 'b']
- self.assertTrue(self.qos.can_consume())
-
- def test_False_when_len_of__not_yet_acked_is_eq_prefetch_count(self):
- self.qos.prefetch_count = 3
- self.qos._not_yet_acked = ['a', 'b', 'c']
- self.assertFalse(self.qos.can_consume())
-
-
-@case_no_python3
-@case_no_pypy
-class TestQoSCanConsumeMaxEstimate(Case):
-
- def setUp(self):
- self.mock_session = Mock()
- self.qos = QoS(self.mock_session)
-
- def test_return_one_when_prefetch_count_eq_zero(self):
- self.qos.prefetch_count = 0
- self.assertEqual(self.qos.can_consume_max_estimate(), 1)
-
- def test_return_prefetch_count_sub_len__not_yet_acked(self):
- self.qos._not_yet_acked = ['a', 'b']
- self.qos.prefetch_count = 4
- self.assertEqual(self.qos.can_consume_max_estimate(), 2)
-
-
-@case_no_python3
-@case_no_pypy
-class TestQoSAck(Case):
-
- def setUp(self):
- self.mock_session = Mock()
- self.qos = QoS(self.mock_session)
-
- def test_ack_pops__not_yet_acked(self):
- message = Mock()
- self.qos.append(message, 1)
- self.assertIn(1, self.qos._not_yet_acked)
- self.qos.ack(1)
- self.assertNotIn(1, self.qos._not_yet_acked)
-
- def test_ack_calls_session_acknowledge_with_message(self):
- message = Mock()
- self.qos.append(message, 1)
- self.qos.ack(1)
- self.qos.session.acknowledge.assert_called_with(message=message)
-
-
-@case_no_python3
-@case_no_pypy
-class TestQoSReject(Case):
-
- def setUp(self):
- self.mock_session = Mock()
- self.mock_message = Mock()
- self.qos = QoS(self.mock_session)
- self.patch_qpid = patch(QPID_MODULE + '.qpid')
- self.mock_qpid = self.patch_qpid.start()
- self.mock_Disposition = self.mock_qpid.messaging.Disposition
- self.mock_RELEASED = self.mock_qpid.messaging.RELEASED
- self.mock_REJECTED = self.mock_qpid.messaging.REJECTED
-
- def tearDown(self):
- self.patch_qpid.stop()
-
- def test_reject_pops__not_yet_acked(self):
- self.qos.append(self.mock_message, 1)
- self.assertIn(1, self.qos._not_yet_acked)
- self.qos.reject(1)
- self.assertNotIn(1, self.qos._not_yet_acked)
-
- def test_reject_requeue_true(self):
- self.qos.append(self.mock_message, 1)
- self.qos.reject(1, requeue=True)
- self.mock_Disposition.assert_called_with(self.mock_RELEASED)
- self.qos.session.acknowledge.assert_called_with(
- message=self.mock_message,
- disposition=self.mock_Disposition.return_value,
- )
-
- def test_reject_requeue_false(self):
- message = Mock()
- self.qos.append(message, 1)
- self.qos.reject(1, requeue=False)
- self.mock_Disposition.assert_called_with(self.mock_REJECTED)
- self.qos.session.acknowledge.assert_called_with(
- message=message, disposition=self.mock_Disposition.return_value,
- )
-
-
-@case_no_python3
-@case_no_pypy
-class TestQoS(Case):
-
- def mock_message_factory(self):
- """Create and return a mock message tag and delivery_tag."""
- m_delivery_tag = self.delivery_tag_generator.next()
- m = 'message %s' % (m_delivery_tag, )
- return m, m_delivery_tag
-
- def add_n_messages_to_qos(self, n, qos):
- """Add N mock messages into the passed in qos object"""
- for i in range(n):
- self.add_message_to_qos(qos)
-
- def add_message_to_qos(self, qos):
- """Add a single mock message into the passed in qos object.
-
- Uses the mock_message_factory() to create the message and
- delivery_tag.
- """
- m, m_delivery_tag = self.mock_message_factory()
- qos.append(m, m_delivery_tag)
-
- def setUp(self):
- self.mock_session = Mock()
- self.qos_no_limit = QoS(self.mock_session)
- self.qos_limit_2 = QoS(self.mock_session, prefetch_count=2)
- self.delivery_tag_generator = count(1)
-
- def test_append(self):
- """Append two messages and check inside the QoS object that they
- were put into the internal data structures correctly
- """
- qos = self.qos_no_limit
- m1, m1_tag = self.mock_message_factory()
- m2, m2_tag = self.mock_message_factory()
- qos.append(m1, m1_tag)
- length_not_yet_acked = len(qos._not_yet_acked)
- self.assertEqual(length_not_yet_acked, 1)
- checked_message1 = qos._not_yet_acked[m1_tag]
- self.assertIs(m1, checked_message1)
- qos.append(m2, m2_tag)
- length_not_yet_acked = len(qos._not_yet_acked)
- self.assertEqual(length_not_yet_acked, 2)
- checked_message2 = qos._not_yet_acked[m2_tag]
- self.assertIs(m2, checked_message2)
-
- def test_get(self):
- """Append two messages, and use get to receive them"""
- qos = self.qos_no_limit
- m1, m1_tag = self.mock_message_factory()
- m2, m2_tag = self.mock_message_factory()
- qos.append(m1, m1_tag)
- qos.append(m2, m2_tag)
- message1 = qos.get(m1_tag)
- message2 = qos.get(m2_tag)
- self.assertIs(m1, message1)
- self.assertIs(m2, message2)
-
-
-@case_no_python3
-@case_no_pypy
-class ConnectionTestBase(Case):
-
- @patch(QPID_MODULE + '.qpid')
- def setUp(self, mock_qpid):
- self.connection_options = {
- 'host': 'localhost',
- 'port': 5672,
- 'transport': 'tcp',
- 'timeout': 10,
- 'sasl_mechanisms': 'ANONYMOUS',
- }
- self.mock_qpid_connection = mock_qpid.messaging.Connection
- self.conn = Connection(**self.connection_options)
-
-
-@case_no_python3
-@case_no_pypy
-class TestConnectionInit(ExtraAssertionsMixin, ConnectionTestBase):
-
- def test_stores_connection_options(self):
- # ensure that only one mech was passed into connection. The other
- # options should all be passed through as-is
- modified_conn_opts = self.connection_options
- self.assertDictEqual(
- modified_conn_opts, self.conn.connection_options,
- )
-
- def test_class_variables(self):
- self.assertIsInstance(self.conn.channels, list)
- self.assertIsInstance(self.conn._callbacks, dict)
-
- def test_establishes_connection(self):
- modified_conn_opts = self.connection_options
- self.mock_qpid_connection.establish.assert_called_with(
- **modified_conn_opts
- )
-
- def test_saves_established_connection(self):
- created_conn = self.mock_qpid_connection.establish.return_value
- self.assertIs(self.conn._qpid_conn, created_conn)
-
- @patch(QPID_MODULE + '.ConnectionError', new=(QpidException, ))
- @patch(QPID_MODULE + '.sys.exc_info')
- @patch(QPID_MODULE + '.qpid')
- def test_mutates_ConnError_by_message(self, mock_qpid, mock_exc_info):
- text = 'connection-forced: Authentication failed(320)'
- my_conn_error = QpidException(text=text)
- mock_qpid.messaging.Connection.establish.side_effect = my_conn_error
- mock_exc_info.return_value = 'a', 'b', None
- try:
- self.conn = Connection(**self.connection_options)
- except AuthenticationFailure as error:
- exc_info = sys.exc_info()
- self.assertNotIsInstance(error, QpidException)
- self.assertIs(exc_info[1], 'b')
- self.assertIsNone(exc_info[2])
- else:
- self.fail('ConnectionError type was not mutated correctly')
-
- @patch(QPID_MODULE + '.ConnectionError', new=(QpidException, ))
- @patch(QPID_MODULE + '.sys.exc_info')
- @patch(QPID_MODULE + '.qpid')
- def test_mutates_ConnError_by_code(self, mock_qpid, mock_exc_info):
- my_conn_error = QpidException(code=320, text='someothertext')
- mock_qpid.messaging.Connection.establish.side_effect = my_conn_error
- mock_exc_info.return_value = 'a', 'b', None
- try:
- self.conn = Connection(**self.connection_options)
- except AuthenticationFailure as error:
- exc_info = sys.exc_info()
- self.assertNotIsInstance(error, QpidException)
- self.assertIs(exc_info[1], 'b')
- self.assertIsNone(exc_info[2])
- else:
- self.fail('ConnectionError type was not mutated correctly')
-
- @patch(QPID_MODULE + '.ConnectionError', new=(QpidException, ))
- @patch(QPID_MODULE + '.sys.exc_info')
- @patch(QPID_MODULE + '.qpid')
- def test_connection__init__mutates_ConnError_by_message2(self, mock_qpid,
- mock_exc_info):
- """
- Test for PLAIN connection via python-saslwrapper, sans cyrus-sasl-plain
-
- This test is specific for what is returned when we attempt to connect
- with PLAIN mech and python-saslwrapper is installed, but
- cyrus-sasl-plain is not installed.
- """
- my_conn_error = QpidException()
- my_conn_error.text = 'Error in sasl_client_start (-4) SASL(-4): no '\
- 'mechanism available'
- mock_qpid.messaging.Connection.establish.side_effect = my_conn_error
- mock_exc_info.return_value = ('a', 'b', None)
- try:
- self.conn = Connection(**self.connection_options)
- except AuthenticationFailure as error:
- exc_info = sys.exc_info()
- self.assertTrue(not isinstance(error, QpidException))
- self.assertTrue(exc_info[1] is 'b')
- self.assertTrue(exc_info[2] is None)
- else:
- self.fail('ConnectionError type was not mutated correctly')
-
- @patch(QPID_MODULE + '.ConnectionError', new=(QpidException, ))
- @patch(QPID_MODULE + '.sys.exc_info')
- @patch(QPID_MODULE + '.qpid')
- def test_unknown_connection_error(self, mock_qpid, mock_exc_info):
- # If we get a connection error that we don't understand,
- # bubble it up as-is
- my_conn_error = QpidException(code=999, text='someothertext')
- mock_qpid.messaging.Connection.establish.side_effect = my_conn_error
- mock_exc_info.return_value = 'a', 'b', None
- try:
- self.conn = Connection(**self.connection_options)
- except Exception as error:
- self.assertTrue(error.code == 999)
- else:
- self.fail('Connection should have thrown an exception')
-
- @patch.object(Transport, 'channel_errors', new=(QpidException, ))
- @patch(QPID_MODULE + '.qpid')
- @patch(QPID_MODULE + '.ConnectionError', new=IOError)
- def test_non_qpid_error_raises(self, mock_qpid):
- mock_Qpid_Connection = mock_qpid.messaging.Connection
- my_conn_error = SyntaxError()
- my_conn_error.text = 'some non auth related error message'
- mock_Qpid_Connection.establish.side_effect = my_conn_error
- with self.assertRaises(SyntaxError):
- Connection(**self.connection_options)
-
- @patch(QPID_MODULE + '.qpid')
- @patch(QPID_MODULE + '.ConnectionError', new=IOError)
- def test_non_auth_conn_error_raises(self, mock_qpid):
- mock_Qpid_Connection = mock_qpid.messaging.Connection
- my_conn_error = IOError()
- my_conn_error.text = 'some non auth related error message'
- mock_Qpid_Connection.establish.side_effect = my_conn_error
- with self.assertRaises(IOError):
- Connection(**self.connection_options)
-
-
-@case_no_python3
-@case_no_pypy
-class TestConnectionClassAttributes(ConnectionTestBase):
-
- def test_connection_verify_class_attributes(self):
- self.assertEqual(Channel, Connection.Channel)
-
-
-@case_no_python3
-@case_no_pypy
-class TestConnectionGetQpidConnection(ConnectionTestBase):
-
- def test_connection_get_qpid_connection(self):
- self.conn._qpid_conn = Mock()
- returned_connection = self.conn.get_qpid_connection()
- self.assertIs(self.conn._qpid_conn, returned_connection)
-
-
-@case_no_python3
-@case_no_pypy
-class TestConnectionClose(ConnectionTestBase):
-
- def test_connection_close(self):
- self.conn._qpid_conn = Mock()
- self.conn.close()
- self.conn._qpid_conn.close.assert_called_once_with()
-
-
-@case_no_python3
-@case_no_pypy
-class TestConnectionCloseChannel(ConnectionTestBase):
-
- def setUp(self):
- super(TestConnectionCloseChannel, self).setUp()
- self.conn.channels = Mock()
-
- def test_connection_close_channel_removes_channel_from_channel_list(self):
- mock_channel = Mock()
- self.conn.close_channel(mock_channel)
- self.conn.channels.remove.assert_called_once_with(mock_channel)
-
- def test_connection_close_channel_handles_ValueError_being_raised(self):
- self.conn.channels.remove = Mock(side_effect=ValueError())
- self.conn.close_channel(Mock())
-
- def test_connection_close_channel_set_channel_connection_to_None(self):
- mock_channel = Mock()
- mock_channel.connection = False
- self.conn.channels.remove = Mock(side_effect=ValueError())
- self.conn.close_channel(mock_channel)
- self.assertIsNone(mock_channel.connection)
-
-
-@case_no_python3
-@case_no_pypy
-class ChannelTestBase(Case):
-
- def setUp(self):
- self.patch_qpidtoollibs = patch(QPID_MODULE + '.qpidtoollibs')
- self.mock_qpidtoollibs = self.patch_qpidtoollibs.start()
- self.mock_broker_agent = self.mock_qpidtoollibs.BrokerAgent
- self.conn = Mock()
- self.transport = Mock()
- self.channel = Channel(self.conn, self.transport)
-
- def tearDown(self):
- self.patch_qpidtoollibs.stop()
-
-
-@case_no_python3
-@case_no_pypy
-class TestChannelPurge(ChannelTestBase):
-
- def setUp(self):
- super(TestChannelPurge, self).setUp()
- self.mock_queue = Mock()
-
- def test_gets_queue(self):
- self.channel._purge(self.mock_queue)
- getQueue = self.mock_broker_agent.return_value.getQueue
- getQueue.assert_called_once_with(self.mock_queue)
-
- def test_does_not_call_purge_if_message_count_is_zero(self):
- values = {'msgDepth': 0}
- queue_obj = self.mock_broker_agent.return_value.getQueue.return_value
- queue_obj.values = values
- self.channel._purge(self.mock_queue)
- self.assertFalse(queue_obj.purge.called)
-
- def test_purges_all_messages_from_queue(self):
- values = {'msgDepth': 5}
- queue_obj = self.mock_broker_agent.return_value.getQueue.return_value
- queue_obj.values = values
- self.channel._purge(self.mock_queue)
- queue_obj.purge.assert_called_with(5)
-
- def test_returns_message_count(self):
- values = {'msgDepth': 5}
- queue_obj = self.mock_broker_agent.return_value.getQueue.return_value
- queue_obj.values = values
- result = self.channel._purge(self.mock_queue)
- self.assertEqual(result, 5)
-
- @patch(QPID_MODULE + '.NotFound', new=QpidException)
- def test_raises_channel_error_if_queue_does_not_exist(self):
- self.mock_broker_agent.return_value.getQueue.return_value = None
- self.assertRaises(QpidException, self.channel._purge, self.mock_queue)
-
-
-@case_no_python3
-@case_no_pypy
-class TestChannelPut(ChannelTestBase):
-
- @patch(QPID_MODULE + '.qpid')
- def test_channel__put_onto_queue(self, mock_qpid):
- routing_key = 'routingkey'
- mock_message = Mock()
- mock_Message_cls = mock_qpid.messaging.Message
-
- self.channel._put(routing_key, mock_message)
-
- address_str = '{0}; {{assert: always, node: {{type: queue}}}}'.format(
- routing_key,
- )
- self.transport.session.sender.assert_called_with(address_str)
- mock_Message_cls.assert_called_with(
- content=mock_message, subject=None,
- )
- mock_sender = self.transport.session.sender.return_value
- mock_sender.send.assert_called_with(
- mock_Message_cls.return_value, sync=True,
- )
- mock_sender.close.assert_called_with()
-
- @patch(QPID_MODULE + '.qpid')
- def test_channel__put_onto_exchange(self, mock_qpid):
- mock_routing_key = 'routingkey'
- mock_exchange_name = 'myexchange'
- mock_message = Mock()
- mock_Message_cls = mock_qpid.messaging.Message
-
- self.channel._put(mock_routing_key, mock_message, mock_exchange_name)
-
- addrstr = '{0}/{1}; {{assert: always, node: {{type: topic}}}}'.format(
- mock_exchange_name, mock_routing_key,
- )
- self.transport.session.sender.assert_called_with(addrstr)
- mock_Message_cls.assert_called_with(
- content=mock_message, subject=mock_routing_key,
- )
- mock_sender = self.transport.session.sender.return_value
- mock_sender.send.assert_called_with(
- mock_Message_cls.return_value, sync=True,
- )
- mock_sender.close.assert_called_with()
-
-
-@case_no_python3
-@case_no_pypy
-class TestChannelGet(ChannelTestBase):
-
- def test_channel__get(self):
- mock_queue = Mock()
-
- result = self.channel._get(mock_queue)
-
- self.transport.session.receiver.assert_called_once_with(mock_queue)
- mock_rx = self.transport.session.receiver.return_value
- mock_rx.fetch.assert_called_once_with(timeout=0)
- mock_rx.close.assert_called_once_with()
- self.assertIs(mock_rx.fetch.return_value, result)
-
-
-@case_no_python3
-@case_no_pypy
-class TestChannelClose(ChannelTestBase):
-
- def setUp(self):
- super(TestChannelClose, self).setUp()
- self.patch_basic_cancel = patch.object(self.channel, 'basic_cancel')
- self.mock_basic_cancel = self.patch_basic_cancel.start()
- self.mock_receiver1 = Mock()
- self.mock_receiver2 = Mock()
- self.channel._receivers = {
- 1: self.mock_receiver1, 2: self.mock_receiver2,
- }
- self.channel.closed = False
-
- def tearDown(self):
- self.patch_basic_cancel.stop()
- super(TestChannelClose, self).tearDown()
-
- def test_channel_close_sets_close_attribute(self):
- self.channel.close()
- self.assertTrue(self.channel.closed)
-
- def test_channel_close_calls_basic_cancel_on_all_receivers(self):
- self.channel.close()
- self.mock_basic_cancel.assert_has_calls([call(1), call(2)])
-
- def test_channel_close_calls_close_channel_on_connection(self):
- self.channel.close()
- self.conn.close_channel.assert_called_once_with(self.channel)
-
- def test_channel_close_calls_close_on_broker_agent(self):
- self.channel.close()
- self.channel._broker.close.assert_called_once_with()
-
- def test_channel_close_does_nothing_if_already_closed(self):
- self.channel.closed = True
- self.channel.close()
- self.assertFalse(self.mock_basic_cancel.called)
-
- def test_channel_close_does_not_call_close_channel_if_conn_is_None(self):
- self.channel.connection = None
- self.channel.close()
- self.assertFalse(self.conn.close_channel.called)
-
-
-@case_no_python3
-@case_no_pypy
-class TestChannelBasicQoS(ChannelTestBase):
-
- def test_channel_basic_qos_always_returns_one(self):
- self.channel.basic_qos(2)
- self.assertEqual(self.channel.qos.prefetch_count, 1)
-
-
-@case_no_python3
-@case_no_pypy
-class TestChannelBasicGet(ChannelTestBase):
-
- def setUp(self):
- super(TestChannelBasicGet, self).setUp()
- self.channel.Message = Mock()
- self.channel._get = Mock()
-
- def test_channel_basic_get_calls__get_with_queue(self):
- mock_queue = Mock()
- self.channel.basic_get(mock_queue)
- self.channel._get.assert_called_once_with(mock_queue)
-
- def test_channel_basic_get_creates_Message_correctly(self):
- mock_queue = Mock()
- self.channel.basic_get(mock_queue)
- mock_raw_message = self.channel._get.return_value.content
- self.channel.Message.assert_called_once_with(
- self.channel, mock_raw_message,
- )
-
- def test_channel_basic_get_acknowledges_message_by_default(self):
- mock_queue = Mock()
- self.channel.basic_get(mock_queue)
- mock_qpid_message = self.channel._get.return_value
- acknowledge = self.transport.session.acknowledge
- acknowledge.assert_called_once_with(message=mock_qpid_message)
-
- def test_channel_basic_get_acknowledges_message_with_no_ack_False(self):
- mock_queue = Mock()
- self.channel.basic_get(mock_queue, no_ack=False)
- mock_qpid_message = self.channel._get.return_value
- acknowledge = self.transport.session.acknowledge
- acknowledge.assert_called_once_with(message=mock_qpid_message)
-
- def test_channel_basic_get_acknowledges_message_with_no_ack_True(self):
- mock_queue = Mock()
- self.channel.basic_get(mock_queue, no_ack=True)
- mock_qpid_message = self.channel._get.return_value
- acknowledge = self.transport.session.acknowledge
- acknowledge.assert_called_once_with(message=mock_qpid_message)
-
- def test_channel_basic_get_returns_correct_message(self):
- mock_queue = Mock()
- basic_get_result = self.channel.basic_get(mock_queue)
- expected_message = self.channel.Message.return_value
- self.assertIs(expected_message, basic_get_result)
-
- def test_basic_get_returns_None_when_channel__get_raises_Empty(self):
- mock_queue = Mock()
- self.channel._get = Mock(side_effect=Empty)
- basic_get_result = self.channel.basic_get(mock_queue)
- self.assertEqual(self.channel.Message.call_count, 0)
- self.assertIsNone(basic_get_result)
-
-
-@case_no_python3
-@case_no_pypy
-class TestChannelBasicCancel(ChannelTestBase):
-
- def setUp(self):
- super(TestChannelBasicCancel, self).setUp()
- self.channel._receivers = {1: Mock()}
-
- def test_channel_basic_cancel_no_error_if_consumer_tag_not_found(self):
- self.channel.basic_cancel(2)
-
- def test_channel_basic_cancel_pops_receiver(self):
- self.channel.basic_cancel(1)
- self.assertNotIn(1, self.channel._receivers)
-
- def test_channel_basic_cancel_closes_receiver(self):
- mock_receiver = self.channel._receivers[1]
- self.channel.basic_cancel(1)
- mock_receiver.close.assert_called_once_with()
-
- def test_channel_basic_cancel_pops__tag_to_queue(self):
- self.channel._tag_to_queue = Mock()
- self.channel.basic_cancel(1)
- self.channel._tag_to_queue.pop.assert_called_once_with(1, None)
-
- def test_channel_basic_cancel_pops_connection__callbacks(self):
- self.channel._tag_to_queue = Mock()
- self.channel.basic_cancel(1)
- mock_queue = self.channel._tag_to_queue.pop.return_value
- self.conn._callbacks.pop.assert_called_once_with(mock_queue, None)
-
-
-@case_no_python3
-@case_no_pypy
-class TestChannelInit(ChannelTestBase, ExtraAssertionsMixin):
-
- def test_channel___init__sets_variables_as_expected(self):
- self.assertIs(self.conn, self.channel.connection)
- self.assertIs(self.transport, self.channel.transport)
- self.assertFalse(self.channel.closed)
- self.conn.get_qpid_connection.assert_called_once_with()
- expected_broker_agent = self.mock_broker_agent.return_value
- self.assertIs(self.channel._broker, expected_broker_agent)
- self.assertDictEqual(self.channel._tag_to_queue, {})
- self.assertDictEqual(self.channel._receivers, {})
- self.assertIs(self.channel._qos, None)
-
-
-@case_no_python3
-@case_no_pypy
-class TestChannelBasicConsume(ChannelTestBase, ExtraAssertionsMixin):
-
- def setUp(self):
- super(TestChannelBasicConsume, self).setUp()
- self.conn._callbacks = {}
-
- def test_channel_basic_consume_adds_queue_to__tag_to_queue(self):
- mock_tag = Mock()
- mock_queue = Mock()
- self.channel.basic_consume(mock_queue, Mock(), Mock(), mock_tag)
- expected_dict = {mock_tag: mock_queue}
- self.assertDictEqual(expected_dict, self.channel._tag_to_queue)
-
- def test_channel_basic_consume_adds_entry_to_connection__callbacks(self):
- mock_queue = Mock()
- self.channel.basic_consume(mock_queue, Mock(), Mock(), Mock())
- self.assertIn(mock_queue, self.conn._callbacks)
- self.assertIsInstance(self.conn._callbacks[mock_queue], Callable)
-
- def test_channel_basic_consume_creates_new_receiver(self):
- mock_queue = Mock()
- self.channel.basic_consume(mock_queue, Mock(), Mock(), Mock())
- self.transport.session.receiver.assert_called_once_with(mock_queue)
-
- def test_channel_basic_consume_saves_new_receiver(self):
- mock_tag = Mock()
- self.channel.basic_consume(Mock(), Mock(), Mock(), mock_tag)
- new_mock_receiver = self.transport.session.receiver.return_value
- expected_dict = {mock_tag: new_mock_receiver}
- self.assertDictEqual(expected_dict, self.channel._receivers)
-
- def test_channel_basic_consume_sets_capacity_on_new_receiver(self):
- mock_prefetch_count = Mock()
- self.channel.qos.prefetch_count = mock_prefetch_count
- self.channel.basic_consume(Mock(), Mock(), Mock(), Mock())
- new_receiver = self.transport.session.receiver.return_value
- self.assertTrue(new_receiver.capacity is mock_prefetch_count)
-
- def get_callback(self, no_ack=Mock(), original_cb=Mock()):
- self.channel.Message = Mock()
- mock_queue = Mock()
- self.channel.basic_consume(mock_queue, no_ack, original_cb, Mock())
- return self.conn._callbacks[mock_queue]
-
- def test_channel_basic_consume_callback_creates_Message_correctly(self):
- callback = self.get_callback()
- mock_qpid_message = Mock()
- callback(mock_qpid_message)
- mock_content = mock_qpid_message.content
- self.channel.Message.assert_called_once_with(
- self.channel, mock_content,
- )
-
- def test_channel_basic_consume_callback_adds_message_to_QoS(self):
- self.channel._qos = Mock()
- callback = self.get_callback()
- mock_qpid_message = Mock()
- callback(mock_qpid_message)
- mock_delivery_tag = self.channel.Message.return_value.delivery_tag
- self.channel._qos.append.assert_called_once_with(
- mock_qpid_message, mock_delivery_tag,
- )
-
- def test_channel_basic_consume_callback_gratuitously_acks(self):
- self.channel.basic_ack = Mock()
- callback = self.get_callback()
- mock_qpid_message = Mock()
- callback(mock_qpid_message)
- mock_delivery_tag = self.channel.Message.return_value.delivery_tag
- self.channel.basic_ack.assert_called_once_with(mock_delivery_tag)
-
- def test_channel_basic_consume_callback_does_not_ack_when_needed(self):
- self.channel.basic_ack = Mock()
- callback = self.get_callback(no_ack=False)
- mock_qpid_message = Mock()
- callback(mock_qpid_message)
- self.assertFalse(self.channel.basic_ack.called)
-
- def test_channel_basic_consume_callback_calls_real_callback(self):
- self.channel.basic_ack = Mock()
- mock_original_callback = Mock()
- callback = self.get_callback(original_cb=mock_original_callback)
- mock_qpid_message = Mock()
- callback(mock_qpid_message)
- expected_message = self.channel.Message.return_value
- mock_original_callback.assert_called_once_with(expected_message)
-
-
-@case_no_python3
-@case_no_pypy
-class TestChannelQueueDelete(ChannelTestBase):
-
- def setUp(self):
- super(TestChannelQueueDelete, self).setUp()
- self.patch__has_queue = patch.object(self.channel, '_has_queue')
- self.mock__has_queue = self.patch__has_queue.start()
- self.patch__size = patch.object(self.channel, '_size')
- self.mock__size = self.patch__size.start()
- self.patch__delete = patch.object(self.channel, '_delete')
- self.mock__delete = self.patch__delete.start()
- self.mock_queue = Mock()
-
- def tearDown(self):
- self.patch__has_queue.stop()
- self.patch__size.stop()
- self.patch__delete.stop()
- super(TestChannelQueueDelete, self).tearDown()
-
- def test_checks_if_queue_exists(self):
- self.channel.queue_delete(self.mock_queue)
- self.mock__has_queue.assert_called_once_with(self.mock_queue)
-
- def test_does_nothing_if_queue_does_not_exist(self):
- self.mock__has_queue.return_value = False
- self.channel.queue_delete(self.mock_queue)
- self.assertFalse(self.mock__delete.called)
-
- def test_not_empty_and_if_empty_True_no_delete(self):
- self.mock__size.return_value = 1
- self.channel.queue_delete(self.mock_queue, if_empty=True)
- mock_broker = self.mock_broker_agent.return_value
- self.assertFalse(mock_broker.getQueue.called)
-
- def test_calls_get_queue(self):
- self.channel.queue_delete(self.mock_queue)
- getQueue = self.mock_broker_agent.return_value.getQueue
- getQueue.assert_called_once_with(self.mock_queue)
-
- def test_gets_queue_attribute(self):
- self.channel.queue_delete(self.mock_queue)
- queue_obj = self.mock_broker_agent.return_value.getQueue.return_value
- queue_obj.getAttributes.assert_called_once_with()
-
- def test_queue_in_use_and_if_unused_no_delete(self):
- queue_obj = self.mock_broker_agent.return_value.getQueue.return_value
- queue_obj.getAttributes.return_value = {'consumerCount': 1}
- self.channel.queue_delete(self.mock_queue, if_unused=True)
- self.assertFalse(self.mock__delete.called)
-
- def test_calls__delete_with_queue(self):
- self.channel.queue_delete(self.mock_queue)
- self.mock__delete.assert_called_once_with(self.mock_queue)
-
-
-@case_no_python3
-@case_no_pypy
-class TestChannel(ExtraAssertionsMixin, Case):
-
- @patch(QPID_MODULE + '.qpidtoollibs')
- def setUp(self, mock_qpidtoollibs):
- self.mock_connection = Mock()
- self.mock_qpid_connection = Mock()
- self.mock_qpid_session = Mock()
- self.mock_qpid_connection.session = Mock(
- return_value=self.mock_qpid_session,
- )
- self.mock_connection.get_qpid_connection = Mock(
- return_value=self.mock_qpid_connection,
- )
- self.mock_transport = Mock()
- self.mock_broker = Mock()
- self.mock_Message = Mock()
- self.mock_BrokerAgent = mock_qpidtoollibs.BrokerAgent
- self.mock_BrokerAgent.return_value = self.mock_broker
- self.my_channel = Channel(
- self.mock_connection, self.mock_transport,
- )
- self.my_channel.Message = self.mock_Message
-
- def test_verify_QoS_class_attribute(self):
- """Verify that the class attribute QoS refers to the QoS object"""
- self.assertIs(QoS, Channel.QoS)
-
- def test_verify_Message_class_attribute(self):
- """Verify that the class attribute Message refers to the Message
- object."""
- self.assertIs(Message, Channel.Message)
-
- def test_body_encoding_class_attribute(self):
- """Verify that the class attribute body_encoding is set to base64"""
- self.assertEqual('base64', Channel.body_encoding)
-
- def test_codecs_class_attribute(self):
- """Verify that the codecs class attribute has a correct key and
- value."""
- self.assertIsInstance(Channel.codecs, dict)
- self.assertIn('base64', Channel.codecs)
- self.assertIsInstance(Channel.codecs['base64'], Base64)
-
- def test_size(self):
- """Test getting the number of messages in a queue specified by
- name and returning them."""
- message_count = 5
- mock_queue = Mock()
- mock_queue_to_check = Mock()
- mock_queue_to_check.values = {'msgDepth': message_count}
- self.mock_broker.getQueue.return_value = mock_queue_to_check
- result = self.my_channel._size(mock_queue)
- self.mock_broker.getQueue.assert_called_with(mock_queue)
- self.assertEqual(message_count, result)
-
- def test_delete(self):
- """Test deleting a queue calls purge and delQueue with queue name."""
- mock_queue = Mock()
- self.my_channel._purge = Mock()
- result = self.my_channel._delete(mock_queue)
- self.my_channel._purge.assert_called_with(mock_queue)
- self.mock_broker.delQueue.assert_called_with(mock_queue)
- self.assertIsNone(result)
-
- def test_has_queue_true(self):
- """Test checking if a queue exists, and it does."""
- mock_queue = Mock()
- self.mock_broker.getQueue.return_value = True
- result = self.my_channel._has_queue(mock_queue)
- self.assertTrue(result)
-
- def test_has_queue_false(self):
- """Test checking if a queue exists, and it does not."""
- mock_queue = Mock()
- self.mock_broker.getQueue.return_value = False
- result = self.my_channel._has_queue(mock_queue)
- self.assertFalse(result)
-
- @patch('amqp.protocol.queue_declare_ok_t')
- def test_queue_declare_with_exception_raised(self,
- mock_queue_declare_ok_t):
- """Test declare_queue, where an exception is raised and silenced."""
- mock_queue = Mock()
- mock_passive = Mock()
- mock_durable = Mock()
- mock_exclusive = Mock()
- mock_auto_delete = Mock()
- mock_nowait = Mock()
- mock_arguments = Mock()
- mock_msg_count = Mock()
- mock_queue.startswith.return_value = False
- mock_queue.endswith.return_value = False
- options = {
- 'passive': mock_passive,
- 'durable': mock_durable,
- 'exclusive': mock_exclusive,
- 'auto-delete': mock_auto_delete,
- 'arguments': mock_arguments,
- }
- mock_consumer_count = Mock()
- mock_return_value = Mock()
- values_dict = {
- 'msgDepth': mock_msg_count,
- 'consumerCount': mock_consumer_count,
- }
- mock_queue_data = Mock()
- mock_queue_data.values = values_dict
- exception_to_raise = Exception('The foo object already exists.')
- self.mock_broker.addQueue.side_effect = exception_to_raise
- self.mock_broker.getQueue.return_value = mock_queue_data
- mock_queue_declare_ok_t.return_value = mock_return_value
- result = self.my_channel.queue_declare(
- mock_queue,
- passive=mock_passive,
- durable=mock_durable,
- exclusive=mock_exclusive,
- auto_delete=mock_auto_delete,
- nowait=mock_nowait,
- arguments=mock_arguments,
- )
- self.mock_broker.addQueue.assert_called_with(
- mock_queue, options=options,
- )
- mock_queue_declare_ok_t.assert_called_with(
- mock_queue, mock_msg_count, mock_consumer_count,
- )
- self.assertIs(mock_return_value, result)
-
- def test_queue_declare_set_ring_policy_for_celeryev(self):
- """Test declare_queue sets ring_policy for celeryev."""
- mock_queue = Mock()
- mock_queue.startswith.return_value = True
- mock_queue.endswith.return_value = False
- expected_default_options = {
- 'passive': False,
- 'durable': False,
- 'exclusive': False,
- 'auto-delete': True,
- 'arguments': None,
- 'qpid.policy_type': 'ring',
- }
- mock_msg_count = Mock()
- mock_consumer_count = Mock()
- values_dict = {
- 'msgDepth': mock_msg_count,
- 'consumerCount': mock_consumer_count,
- }
- mock_queue_data = Mock()
- mock_queue_data.values = values_dict
- self.mock_broker.addQueue.return_value = None
- self.mock_broker.getQueue.return_value = mock_queue_data
- self.my_channel.queue_declare(mock_queue)
- mock_queue.startswith.assert_called_with('celeryev')
- self.mock_broker.addQueue.assert_called_with(
- mock_queue, options=expected_default_options,
- )
-
- def test_queue_declare_set_ring_policy_for_pidbox(self):
- """Test declare_queue sets ring_policy for pidbox."""
- mock_queue = Mock()
- mock_queue.startswith.return_value = False
- mock_queue.endswith.return_value = True
- expected_default_options = {
- 'passive': False,
- 'durable': False,
- 'exclusive': False,
- 'auto-delete': True,
- 'arguments': None,
- 'qpid.policy_type': 'ring',
- }
- mock_msg_count = Mock()
- mock_consumer_count = Mock()
- values_dict = {
- 'msgDepth': mock_msg_count,
- 'consumerCount': mock_consumer_count,
- }
- mock_queue_data = Mock()
- mock_queue_data.values = values_dict
- self.mock_broker.addQueue.return_value = None
- self.mock_broker.getQueue.return_value = mock_queue_data
- self.my_channel.queue_declare(mock_queue)
- mock_queue.endswith.assert_called_with('pidbox')
- self.mock_broker.addQueue.assert_called_with(
- mock_queue, options=expected_default_options,
- )
-
- def test_queue_declare_ring_policy_not_set_as_expected(self):
- """Test declare_queue does not set ring_policy as expected."""
- mock_queue = Mock()
- mock_queue.startswith.return_value = False
- mock_queue.endswith.return_value = False
- expected_default_options = {
- 'passive': False,
- 'durable': False,
- 'exclusive': False,
- 'auto-delete': True,
- 'arguments': None,
- }
- mock_msg_count = Mock()
- mock_consumer_count = Mock()
- values_dict = {
- 'msgDepth': mock_msg_count,
- 'consumerCount': mock_consumer_count,
- }
- mock_queue_data = Mock()
- mock_queue_data.values = values_dict
- self.mock_broker.addQueue.return_value = None
- self.mock_broker.getQueue.return_value = mock_queue_data
- self.my_channel.queue_declare(mock_queue)
- mock_queue.startswith.assert_called_with('celeryev')
- mock_queue.endswith.assert_called_with('pidbox')
- self.mock_broker.addQueue.assert_called_with(
- mock_queue, options=expected_default_options,
- )
-
- def test_queue_declare_test_defaults(self):
- """Test declare_queue defaults."""
- mock_queue = Mock()
- mock_queue.startswith.return_value = False
- mock_queue.endswith.return_value = False
- expected_default_options = {
- 'passive': False,
- 'durable': False,
- 'exclusive': False,
- 'auto-delete': True,
- 'arguments': None,
- }
- mock_msg_count = Mock()
- mock_consumer_count = Mock()
- values_dict = {
- 'msgDepth': mock_msg_count,
- 'consumerCount': mock_consumer_count,
- }
- mock_queue_data = Mock()
- mock_queue_data.values = values_dict
- self.mock_broker.addQueue.return_value = None
- self.mock_broker.getQueue.return_value = mock_queue_data
- self.my_channel.queue_declare(mock_queue)
- self.mock_broker.addQueue.assert_called_with(
- mock_queue,
- options=expected_default_options,
- )
-
- def test_queue_declare_raises_exception_not_silenced(self):
- unique_exception = Exception('This exception should not be silenced')
- mock_queue = Mock()
- self.mock_broker.addQueue.side_effect = unique_exception
- with self.assertRaises(unique_exception.__class__):
- self.my_channel.queue_declare(mock_queue)
- self.mock_broker.addQueue.assert_called_once_with(
- mock_queue,
- options={
- 'exclusive': False,
- 'durable': False,
- 'qpid.policy_type': 'ring',
- 'passive': False,
- 'arguments': None,
- 'auto-delete': True
- })
-
- def test_exchange_declare_raises_exception_and_silenced(self):
- """Create exchange where an exception is raised and then silenced"""
- self.mock_broker.addExchange.side_effect = Exception(
- 'The foo object already exists.',
- )
- self.my_channel.exchange_declare()
-
- def test_exchange_declare_raises_exception_not_silenced(self):
- """Create Exchange where an exception is raised and not silenced."""
- unique_exception = Exception('This exception should not be silenced')
- self.mock_broker.addExchange.side_effect = unique_exception
- with self.assertRaises(unique_exception.__class__):
- self.my_channel.exchange_declare()
-
- def test_exchange_declare(self):
- """Create Exchange where an exception is NOT raised."""
- mock_exchange = Mock()
- mock_type = Mock()
- mock_durable = Mock()
- options = {'durable': mock_durable}
- result = self.my_channel.exchange_declare(
- mock_exchange, mock_type, mock_durable,
- )
- self.mock_broker.addExchange.assert_called_with(
- mock_type, mock_exchange, options,
- )
- self.assertIsNone(result)
-
- def test_exchange_delete(self):
- """Test the deletion of an exchange by name."""
- mock_exchange = Mock()
- result = self.my_channel.exchange_delete(mock_exchange)
- self.mock_broker.delExchange.assert_called_with(mock_exchange)
- self.assertIsNone(result)
-
- def test_queue_bind(self):
- """Test binding a queue to an exchange using a routing key."""
- mock_queue = Mock()
- mock_exchange = Mock()
- mock_routing_key = Mock()
- self.my_channel.queue_bind(
- mock_queue, mock_exchange, mock_routing_key,
- )
- self.mock_broker.bind.assert_called_with(
- mock_exchange, mock_queue, mock_routing_key,
- )
-
- def test_queue_unbind(self):
- """Test unbinding a queue from an exchange using a routing key."""
- mock_queue = Mock()
- mock_exchange = Mock()
- mock_routing_key = Mock()
- self.my_channel.queue_unbind(
- mock_queue, mock_exchange, mock_routing_key,
- )
- self.mock_broker.unbind.assert_called_with(
- mock_exchange, mock_queue, mock_routing_key,
- )
-
- def test_queue_purge(self):
- """Test purging a queue by name."""
- mock_queue = Mock()
- purge_result = Mock()
- self.my_channel._purge = Mock(return_value=purge_result)
- result = self.my_channel.queue_purge(mock_queue)
- self.my_channel._purge.assert_called_with(mock_queue)
- self.assertIs(purge_result, result)
-
- @patch(QPID_MODULE + '.Channel.qos')
- def test_basic_ack(self, mock_qos):
- """Test that basic_ack calls the QoS object properly."""
- mock_delivery_tag = Mock()
- self.my_channel.basic_ack(mock_delivery_tag)
- mock_qos.ack.assert_called_with(mock_delivery_tag)
-
- @patch(QPID_MODULE + '.Channel.qos')
- def test_basic_reject(self, mock_qos):
- """Test that basic_reject calls the QoS object properly."""
- mock_delivery_tag = Mock()
- mock_requeue_value = Mock()
- self.my_channel.basic_reject(mock_delivery_tag, mock_requeue_value)
- mock_qos.reject.assert_called_with(
- mock_delivery_tag, requeue=mock_requeue_value,
- )
-
- def test_qos_manager_is_none(self):
- """Test the qos property if the QoS object did not already exist."""
- self.my_channel._qos = None
- result = self.my_channel.qos
- self.assertIsInstance(result, QoS)
- self.assertEqual(result, self.my_channel._qos)
-
- def test_qos_manager_already_exists(self):
- """Test the qos property if the QoS object already exists."""
- mock_existing_qos = Mock()
- self.my_channel._qos = mock_existing_qos
- result = self.my_channel.qos
- self.assertIs(mock_existing_qos, result)
-
- def test_prepare_message(self):
- """Test that prepare_message() returns the correct result."""
- mock_body = Mock()
- mock_priority = Mock()
- mock_content_encoding = Mock()
- mock_content_type = Mock()
- mock_header1 = Mock()
- mock_header2 = Mock()
- mock_properties1 = Mock()
- mock_properties2 = Mock()
- headers = {'header1': mock_header1, 'header2': mock_header2}
- properties = {'properties1': mock_properties1,
- 'properties2': mock_properties2}
- result = self.my_channel.prepare_message(
- mock_body,
- priority=mock_priority,
- content_type=mock_content_type,
- content_encoding=mock_content_encoding,
- headers=headers,
- properties=properties)
- self.assertIs(mock_body, result['body'])
- self.assertIs(mock_content_encoding, result['content-encoding'])
- self.assertIs(mock_content_type, result['content-type'])
- self.assertDictEqual(headers, result['headers'])
- self.assertDictContainsSubset(properties, result['properties'])
- self.assertIs(
- mock_priority, result['properties']['delivery_info']['priority'],
- )
-
- @patch('__builtin__.buffer')
- @patch(QPID_MODULE + '.Channel.body_encoding')
- @patch(QPID_MODULE + '.Channel.encode_body')
- @patch(QPID_MODULE + '.Channel._put')
- def test_basic_publish(self, mock_put,
- mock_encode_body,
- mock_body_encoding,
- mock_buffer):
- """Test basic_publish()."""
- mock_original_body = Mock()
- mock_encoded_body = 'this is my encoded body'
- mock_message = {'body': mock_original_body,
- 'properties': {'delivery_info': {}}}
- mock_encode_body.return_value = (
- mock_encoded_body, mock_body_encoding,
- )
- mock_exchange = Mock()
- mock_routing_key = Mock()
- mock_encoded_buffered_body = Mock()
- mock_buffer.return_value = mock_encoded_buffered_body
- self.my_channel.basic_publish(
- mock_message, mock_exchange, mock_routing_key,
- )
- mock_encode_body.assert_called_once_with(
- mock_original_body, mock_body_encoding,
- )
- mock_buffer.assert_called_once_with(mock_encoded_body)
- self.assertIs(mock_message['body'], mock_encoded_buffered_body)
- self.assertIs(
- mock_message['properties']['body_encoding'], mock_body_encoding,
- )
- self.assertIsInstance(
- mock_message['properties']['delivery_tag'], uuid.UUID,
- )
- self.assertIs(
- mock_message['properties']['delivery_info']['exchange'],
- mock_exchange,
- )
- self.assertIs(
- mock_message['properties']['delivery_info']['routing_key'],
- mock_routing_key,
- )
- mock_put.assert_called_with(
- mock_routing_key, mock_message, mock_exchange,
- )
-
- @patch(QPID_MODULE + '.Channel.codecs')
- def test_encode_body_expected_encoding(self, mock_codecs):
- """Test if encode_body() works when encoding is set correctly"""
- mock_body = Mock()
- mock_encoder = Mock()
- mock_encoded_result = Mock()
- mock_codecs.get.return_value = mock_encoder
- mock_encoder.encode.return_value = mock_encoded_result
- result = self.my_channel.encode_body(mock_body, encoding='base64')
- expected_result = (mock_encoded_result, 'base64')
- self.assertEqual(expected_result, result)
-
- @patch(QPID_MODULE + '.Channel.codecs')
- def test_encode_body_not_expected_encoding(self, mock_codecs):
- """Test if encode_body() works when encoding is not set correctly."""
- mock_body = Mock()
- result = self.my_channel.encode_body(mock_body, encoding=None)
- expected_result = mock_body, None
- self.assertEqual(expected_result, result)
-
- @patch(QPID_MODULE + '.Channel.codecs')
- def test_decode_body_expected_encoding(self, mock_codecs):
- """Test if decode_body() works when encoding is set correctly."""
- mock_body = Mock()
- mock_decoder = Mock()
- mock_decoded_result = Mock()
- mock_codecs.get.return_value = mock_decoder
- mock_decoder.decode.return_value = mock_decoded_result
- result = self.my_channel.decode_body(mock_body, encoding='base64')
- self.assertEqual(mock_decoded_result, result)
-
- @patch(QPID_MODULE + '.Channel.codecs')
- def test_decode_body_not_expected_encoding(self, mock_codecs):
- """Test if decode_body() works when encoding is not set correctly."""
- mock_body = Mock()
- result = self.my_channel.decode_body(mock_body, encoding=None)
- self.assertEqual(mock_body, result)
-
- def test_typeof_exchange_exists(self):
- """Test that typeof() finds an exchange that already exists."""
- mock_exchange = Mock()
- mock_qpid_exchange = Mock()
- mock_attributes = {}
- mock_type = Mock()
- mock_attributes['type'] = mock_type
- mock_qpid_exchange.getAttributes.return_value = mock_attributes
- self.mock_broker.getExchange.return_value = mock_qpid_exchange
- result = self.my_channel.typeof(mock_exchange)
- self.assertIs(mock_type, result)
-
- def test_typeof_exchange_does_not_exist(self):
- """Test that typeof() finds an exchange that does not exists."""
- mock_exchange = Mock()
- mock_default = Mock()
- self.mock_broker.getExchange.return_value = None
- result = self.my_channel.typeof(mock_exchange, default=mock_default)
- self.assertIs(mock_default, result)
-
-
-@case_no_python3
-@case_no_pypy
-@disable_runtime_dependency_check
-class TestTransportInit(Case):
-
- def setUp(self):
- self.patch_a = patch.object(Transport, 'verify_runtime_environment')
- self.mock_verify_runtime_environment = self.patch_a.start()
-
- self.patch_b = patch(QPID_MODULE + '.base.Transport.__init__')
- self.mock_base_Transport__init__ = self.patch_b.start()
-
- def tearDown(self):
- self.patch_a.stop()
- self.patch_b.stop()
-
- def test_Transport___init___calls_verify_runtime_environment(self):
- Transport(Mock())
- self.mock_verify_runtime_environment.assert_called_once_with()
-
- def test_transport___init___calls_parent_class___init__(self):
- m = Mock()
- Transport(m)
- self.mock_base_Transport__init__.assert_called_once_with(m)
-
- def test_transport___init___sets_use_async_interface_False(self):
- transport = Transport(Mock())
- self.assertFalse(transport.use_async_interface)
-
-
-@case_no_python3
-@case_no_pypy
-@disable_runtime_dependency_check
-class TestTransportDrainEvents(Case):
-
- def setUp(self):
- self.transport = Transport(Mock())
- self.transport.session = Mock()
- self.mock_queue = Mock()
- self.mock_message = Mock()
- self.mock_conn = Mock()
- self.mock_callback = Mock()
- self.mock_conn._callbacks = {self.mock_queue: self.mock_callback}
-
- def mock_next_receiver(self, timeout):
- time.sleep(0.3)
- mock_receiver = Mock()
- mock_receiver.source = self.mock_queue
- mock_receiver.fetch.return_value = self.mock_message
- return mock_receiver
-
- def test_socket_timeout_raised_when_all_receivers_empty(self):
- with patch(QPID_MODULE + '.QpidEmpty', new=QpidException):
- self.transport.session.next_receiver.side_effect = QpidException()
- with self.assertRaises(socket.timeout):
- self.transport.drain_events(Mock())
-
- def test_socket_timeout_raised_when_by_timeout(self):
- self.transport.session.next_receiver = self.mock_next_receiver
- with self.assertRaises(socket.timeout):
- self.transport.drain_events(self.mock_conn, timeout=1)
-
- def test_timeout_returns_no_earlier_then_asked_for(self):
- self.transport.session.next_receiver = self.mock_next_receiver
- start_time = monotonic()
- try:
- self.transport.drain_events(self.mock_conn, timeout=1)
- except socket.timeout:
- pass
- elapsed_time_in_s = monotonic() - start_time
- self.assertGreaterEqual(elapsed_time_in_s, 1.0)
-
- def test_callback_is_called(self):
- self.transport.session.next_receiver = self.mock_next_receiver
- try:
- self.transport.drain_events(self.mock_conn, timeout=1)
- except socket.timeout:
- pass
- self.mock_callback.assert_called_with(self.mock_message)
-
-
-@case_no_python3
-@case_no_pypy
-@disable_runtime_dependency_check
-class TestTransportCreateChannel(Case):
-
- def setUp(self):
- self.transport = Transport(Mock())
- self.mock_conn = Mock()
- self.mock_new_channel = Mock()
- self.mock_conn.Channel.return_value = self.mock_new_channel
- self.returned_channel = self.transport.create_channel(self.mock_conn)
-
- def test_new_channel_created_from_connection(self):
- self.assertIs(self.mock_new_channel, self.returned_channel)
- self.mock_conn.Channel.assert_called_with(
- self.mock_conn, self.transport,
- )
-
- def test_new_channel_added_to_connection_channel_list(self):
- append_method = self.mock_conn.channels.append
- append_method.assert_called_with(self.mock_new_channel)
-
-
-@case_no_python3
-@case_no_pypy
-@disable_runtime_dependency_check
-class TestTransportEstablishConnection(Case):
-
- def setUp(self):
-
- class MockClient(object):
- pass
-
- self.client = MockClient()
- self.client.connect_timeout = 4
- self.client.ssl = False
- self.client.transport_options = {}
- self.client.userid = None
- self.client.password = None
- self.client.login_method = None
- self.transport = Transport(self.client)
- self.mock_conn = Mock()
- self.transport.Connection = self.mock_conn
-
- def test_transport_establish_conn_new_option_overwrites_default(self):
- self.client.userid = 'new-userid'
- self.client.password = 'new-password'
- self.transport.establish_connection()
- self.mock_conn.assert_called_once_with(
- username=self.client.userid,
- password=self.client.password,
- sasl_mechanisms='PLAIN',
- host='localhost',
- timeout=4,
- port=5672,
- transport='tcp',
- )
-
- def test_transport_establish_conn_empty_client_is_default(self):
- self.transport.establish_connection()
- self.mock_conn.assert_called_once_with(
- sasl_mechanisms='ANONYMOUS',
- host='localhost',
- timeout=4,
- port=5672,
- transport='tcp',
- )
-
- def test_transport_establish_conn_additional_transport_option(self):
- new_param_value = 'mynewparam'
- self.client.transport_options['new_param'] = new_param_value
- self.transport.establish_connection()
- self.mock_conn.assert_called_once_with(
- sasl_mechanisms='ANONYMOUS',
- host='localhost',
- timeout=4,
- new_param=new_param_value,
- port=5672,
- transport='tcp',
- )
-
- def test_transport_establish_conn_transform_localhost_to_127_0_0_1(self):
- self.client.hostname = 'localhost'
- self.transport.establish_connection()
- self.mock_conn.assert_called_once_with(
- sasl_mechanisms='ANONYMOUS',
- host='localhost',
- timeout=4,
- port=5672,
- transport='tcp',
- )
-
- def test_transport_password_no_userid_raises_exception(self):
- self.client.password = 'somepass'
- self.assertRaises(Exception, self.transport.establish_connection)
-
- def test_transport_userid_no_password_raises_exception(self):
- self.client.userid = 'someusername'
- self.assertRaises(Exception, self.transport.establish_connection)
-
- def test_transport_overrides_sasl_mech_from_login_method(self):
- self.client.login_method = 'EXTERNAL'
- self.transport.establish_connection()
- self.mock_conn.assert_called_once_with(
- sasl_mechanisms='EXTERNAL',
- host='localhost',
- timeout=4,
- port=5672,
- transport='tcp',
- )
-
- def test_transport_overrides_sasl_mech_has_username(self):
- self.client.userid = 'new-userid'
- self.client.login_method = 'EXTERNAL'
- self.transport.establish_connection()
- self.mock_conn.assert_called_once_with(
- username=self.client.userid,
- sasl_mechanisms='EXTERNAL',
- host='localhost',
- timeout=4,
- port=5672,
- transport='tcp',
- )
-
- def test_transport_establish_conn_set_password(self):
- self.client.userid = 'someuser'
- self.client.password = 'somepass'
- self.transport.establish_connection()
- self.mock_conn.assert_called_once_with(
- username='someuser',
- password='somepass',
- sasl_mechanisms='PLAIN',
- host='localhost',
- timeout=4,
- port=5672,
- transport='tcp',
- )
-
- def test_transport_establish_conn_no_ssl_sets_transport_tcp(self):
- self.client.ssl = False
- self.transport.establish_connection()
- self.mock_conn.assert_called_once_with(
- sasl_mechanisms='ANONYMOUS',
- host='localhost',
- timeout=4,
- port=5672,
- transport='tcp',
- )
-
- def test_transport_establish_conn_with_ssl_with_hostname_check(self):
- self.client.ssl = {
- 'keyfile': 'my_keyfile',
- 'certfile': 'my_certfile',
- 'ca_certs': 'my_cacerts',
- 'cert_reqs': ssl.CERT_REQUIRED,
- }
- self.transport.establish_connection()
- self.mock_conn.assert_called_once_with(
- ssl_certfile='my_certfile',
- ssl_trustfile='my_cacerts',
- timeout=4,
- ssl_skip_hostname_check=False,
- sasl_mechanisms='ANONYMOUS',
- host='localhost',
- ssl_keyfile='my_keyfile',
- port=5672, transport='ssl',
- )
-
- def test_transport_establish_conn_with_ssl_skip_hostname_check(self):
- self.client.ssl = {
- 'keyfile': 'my_keyfile',
- 'certfile': 'my_certfile',
- 'ca_certs': 'my_cacerts',
- 'cert_reqs': ssl.CERT_OPTIONAL,
- }
- self.transport.establish_connection()
- self.mock_conn.assert_called_once_with(
- ssl_certfile='my_certfile',
- ssl_trustfile='my_cacerts',
- timeout=4,
- ssl_skip_hostname_check=True,
- sasl_mechanisms='ANONYMOUS',
- host='localhost',
- ssl_keyfile='my_keyfile',
- port=5672, transport='ssl',
- )
-
- def test_transport_establish_conn_sets_client_on_connection_object(self):
- self.transport.establish_connection()
- self.assertIs(self.mock_conn.return_value.client, self.client)
-
- def test_transport_establish_conn_creates_session_on_transport(self):
- self.transport.establish_connection()
- qpid_conn = self.mock_conn.return_value.get_qpid_connection
- new_mock_session = qpid_conn.return_value.session.return_value
- self.assertIs(self.transport.session, new_mock_session)
-
- def test_transport_establish_conn_returns_new_connection_object(self):
- new_conn = self.transport.establish_connection()
- self.assertIs(new_conn, self.mock_conn.return_value)
-
- def test_transport_establish_conn_uses_hostname_if_not_default(self):
- self.client.hostname = 'some_other_hostname'
- self.transport.establish_connection()
- self.mock_conn.assert_called_once_with(
- sasl_mechanisms='ANONYMOUS',
- host='some_other_hostname',
- timeout=4,
- port=5672,
- transport='tcp',
- )
-
- def test_transport_sets_qpid_message_ready_handler(self):
- self.transport.establish_connection()
- qpid_conn_call = self.mock_conn.return_value.get_qpid_connection
- mock_session = qpid_conn_call.return_value.session.return_value
- mock_set_callback = mock_session.set_message_received_notify_handler
- expected_msg_callback = self.transport._qpid_message_ready_handler
- mock_set_callback.assert_called_once_with(expected_msg_callback)
-
- def test_transport_sets_session_exception_handler(self):
- self.transport.establish_connection()
- qpid_conn_call = self.mock_conn.return_value.get_qpid_connection
- mock_session = qpid_conn_call.return_value.session.return_value
- mock_set_callback = mock_session.set_async_exception_notify_handler
- exc_callback = self.transport._qpid_async_exception_notify_handler
- mock_set_callback.assert_called_once_with(exc_callback)
-
- def test_transport_sets_connection_exception_handler(self):
- self.transport.establish_connection()
- qpid_conn_call = self.mock_conn.return_value.get_qpid_connection
- qpid_conn = qpid_conn_call.return_value
- mock_set_callback = qpid_conn.set_async_exception_notify_handler
- exc_callback = self.transport._qpid_async_exception_notify_handler
- mock_set_callback.assert_called_once_with(exc_callback)
-
-
-@case_no_python3
-@case_no_pypy
-class TestTransportClassAttributes(Case):
-
- def test_verify_Connection_attribute(self):
- self.assertIs(Connection, Transport.Connection)
-
- def test_verify_polling_disabled(self):
- self.assertIsNone(Transport.polling_interval)
-
- def test_transport_verify_supports_asynchronous_events(self):
- self.assertTrue(Transport.supports_ev)
-
- def test_verify_driver_type_and_name(self):
- self.assertEqual('qpid', Transport.driver_type)
- self.assertEqual('qpid', Transport.driver_name)
-
- def test_transport_verify_recoverable_connection_errors(self):
- connection_errors = Transport.recoverable_connection_errors
- self.assertIn(ConnectionError, connection_errors)
- self.assertIn(select.error, connection_errors)
-
- def test_transport_verify_recoverable_channel_errors(self):
- channel_errors = Transport.recoverable_channel_errors
- self.assertIn(NotFound, channel_errors)
-
- def test_transport_verify_pre_kombu_3_0_exception_labels(self):
- self.assertEqual(Transport.recoverable_channel_errors,
- Transport.channel_errors)
- self.assertEqual(Transport.recoverable_connection_errors,
- Transport.connection_errors)
-
-
-@case_no_python3
-@case_no_pypy
-@disable_runtime_dependency_check
-class TestTransportRegisterWithEventLoop(Case):
-
- def test_transport_register_with_event_loop_calls_add_reader(self):
- transport = Transport(Mock())
- mock_connection = Mock()
- mock_loop = Mock()
- transport.register_with_event_loop(mock_connection, mock_loop)
- mock_loop.add_reader.assert_called_with(
- transport.r, transport.on_readable, mock_connection, mock_loop,
- )
-
-
-@case_no_python3
-@case_no_pypy
-@disable_runtime_dependency_check
-class TestTransportQpidCallbackHandlersAsync(Case):
-
- def setUp(self):
- self.patch_a = patch(QPID_MODULE + '.os.write')
- self.mock_os_write = self.patch_a.start()
- self.transport = Transport(Mock())
- self.transport.register_with_event_loop(Mock(), Mock())
-
- def tearDown(self):
- self.patch_a.stop()
-
- def test__qpid_message_ready_handler_writes_symbol_to_fd(self):
- self.transport._qpid_message_ready_handler(Mock())
- self.mock_os_write.assert_called_once_with(self.transport._w, '0')
-
- def test__qpid_async_exception_notify_handler_writes_symbol_to_fd(self):
- self.transport._qpid_async_exception_notify_handler(Mock(), Mock())
- self.mock_os_write.assert_called_once_with(self.transport._w, 'e')
-
-
-@case_no_python3
-@case_no_pypy
-@disable_runtime_dependency_check
-class TestTransportQpidCallbackHandlersSync(Case):
-
- def setUp(self):
- self.patch_a = patch(QPID_MODULE + '.os.write')
- self.mock_os_write = self.patch_a.start()
- self.transport = Transport(Mock())
-
- def tearDown(self):
- self.patch_a.stop()
-
- def test__qpid_message_ready_handler_dows_not_write(self):
- self.transport._qpid_message_ready_handler(Mock())
- self.assertTrue(not self.mock_os_write.called)
-
- def test__qpid_async_exception_notify_handler_does_not_write(self):
- self.transport._qpid_async_exception_notify_handler(Mock(), Mock())
- self.assertTrue(not self.mock_os_write.called)
-
-
-@case_no_python3
-@case_no_pypy
-@disable_runtime_dependency_check
-class TestTransportOnReadable(Case):
-
- def setUp(self):
- self.patch_a = patch(QPID_MODULE + '.os.read')
- self.mock_os_read = self.patch_a.start()
-
- self.patch_b = patch.object(Transport, 'drain_events')
- self.mock_drain_events = self.patch_b.start()
- self.transport = Transport(Mock())
- self.transport.register_with_event_loop(Mock(), Mock())
-
- def tearDown(self):
- self.patch_a.stop()
- self.patch_b.stop()
-
- def test_transport_on_readable_reads_symbol_from_fd(self):
- self.transport.on_readable(Mock(), Mock())
- self.mock_os_read.assert_called_once_with(self.transport.r, 1)
-
- def test_transport_on_readable_calls_drain_events(self):
- mock_connection = Mock()
- self.transport.on_readable(mock_connection, Mock())
- self.mock_drain_events.assert_called_with(mock_connection)
-
- def test_transport_on_readable_catches_socket_timeout(self):
- self.mock_drain_events.side_effect = socket.timeout()
- self.transport.on_readable(Mock(), Mock())
-
- def test_transport_on_readable_ignores_non_socket_timeout_exception(self):
- self.mock_drain_events.side_effect = IOError()
- with self.assertRaises(IOError):
- self.transport.on_readable(Mock(), Mock())
-
-
-@case_no_python3
-@case_no_pypy
-@disable_runtime_dependency_check
-class TestTransportVerifyRuntimeEnvironment(Case):
-
- def setUp(self):
- self.verify_runtime_environment = Transport.verify_runtime_environment
- self.patch_a = patch.object(Transport, 'verify_runtime_environment')
- self.patch_a.start()
- self.transport = Transport(Mock())
-
- def tearDown(self):
- self.patch_a.stop()
-
- @patch(QPID_MODULE + '.PY3', new=True)
- def test_raises_exception_for_Python3(self):
- with self.assertRaises(RuntimeError):
- self.verify_runtime_environment(self.transport)
-
- @patch('__builtin__.getattr')
- def test_raises_exc_for_PyPy(self, mock_getattr):
- mock_getattr.return_value = True
- with self.assertRaises(RuntimeError):
- self.verify_runtime_environment(self.transport)
-
- @patch(QPID_MODULE + '.dependency_is_none')
- def test_raises_exc_dep_missing(self, mock_dep_is_none):
- mock_dep_is_none.return_value = True
- with self.assertRaises(RuntimeError):
- self.verify_runtime_environment(self.transport)
-
- @patch(QPID_MODULE + '.dependency_is_none')
- def test_calls_dependency_is_none(self, mock_dep_is_none):
- mock_dep_is_none.return_value = False
- self.verify_runtime_environment(self.transport)
- self.assertTrue(mock_dep_is_none.called)
-
- def test_raises_no_exception(self):
- self.verify_runtime_environment(self.transport)
-
-
-@case_no_python3
-@case_no_pypy
-@disable_runtime_dependency_check
-class TestTransport(ExtraAssertionsMixin, Case):
-
- def setUp(self):
- """Creates a mock_client to be used in testing."""
- self.mock_client = Mock()
-
- def test_close_connection(self):
- """Test that close_connection calls close on the connection."""
- my_transport = Transport(self.mock_client)
- mock_connection = Mock()
- my_transport.close_connection(mock_connection)
- mock_connection.close.assert_called_once_with()
-
- def test_default_connection_params(self):
- """Test that the default_connection_params are correct"""
- correct_params = {
- 'hostname': 'localhost',
- 'port': 5672,
- }
- my_transport = Transport(self.mock_client)
- result_params = my_transport.default_connection_params
- self.assertDictEqual(correct_params, result_params)
-
- @patch(QPID_MODULE + '.os.close')
- def test_del_sync(self, close):
- my_transport = Transport(self.mock_client)
- my_transport.__del__()
- self.assertFalse(close.called)
-
- @patch(QPID_MODULE + '.os.close')
- def test_del_async(self, close):
- my_transport = Transport(self.mock_client)
- my_transport.register_with_event_loop(Mock(), Mock())
- my_transport.__del__()
- self.assertTrue(close.called)
-
- @patch(QPID_MODULE + '.os.close')
- def test_del_async_failed(self, close):
- close.side_effect = OSError()
- my_transport = Transport(self.mock_client)
- my_transport.register_with_event_loop(Mock(), Mock())
- my_transport.__del__()
- self.assertTrue(close.called)
diff --git a/kombu/tests/transport/virtual/test_exchange.py b/kombu/tests/transport/virtual/test_exchange.py
deleted file mode 100644
index e5f5df39..00000000
--- a/kombu/tests/transport/virtual/test_exchange.py
+++ /dev/null
@@ -1,200 +0,0 @@
-from __future__ import absolute_import, unicode_literals
-
-from kombu import Connection
-from kombu.transport.virtual import exchange
-
-from kombu.tests.case import Case, Mock
-from kombu.tests.mocks import Transport
-
-
-class ExchangeCase(Case):
- type = None
-
- def setup(self):
- if self.type:
- self.e = self.type(Connection(transport=Transport).channel())
-
-
-class test_Direct(ExchangeCase):
- type = exchange.DirectExchange
- table = [('rFoo', None, 'qFoo'),
- ('rFoo', None, 'qFox'),
- ('rBar', None, 'qBar'),
- ('rBaz', None, 'qBaz')]
-
- def test_lookup(self):
- self.assertSetEqual(
- self.e.lookup(self.table, 'eFoo', 'rFoo', None),
- {'qFoo', 'qFox'},
- )
- self.assertSetEqual(
- self.e.lookup(self.table, 'eMoz', 'rMoz', 'DEFAULT'),
- set(),
- )
- self.assertSetEqual(
- self.e.lookup(self.table, 'eBar', 'rBar', None),
- {'qBar'},
- )
-
-
-class test_Fanout(ExchangeCase):
- type = exchange.FanoutExchange
- table = [(None, None, 'qFoo'),
- (None, None, 'qFox'),
- (None, None, 'qBar')]
-
- def test_lookup(self):
- self.assertSetEqual(
- self.e.lookup(self.table, 'eFoo', 'rFoo', None),
- {'qFoo', 'qFox', 'qBar'},
- )
-
- def test_deliver_when_fanout_supported(self):
- self.e.channel = Mock()
- self.e.channel.supports_fanout = True
- message = Mock()
-
- self.e.deliver(message, 'exchange', 'rkey')
- self.e.channel._put_fanout.assert_called_with(
- 'exchange', message, 'rkey',
- )
-
- def test_deliver_when_fanout_unsupported(self):
- self.e.channel = Mock()
- self.e.channel.supports_fanout = False
-
- self.e.deliver(Mock(), 'exchange', None)
- self.e.channel._put_fanout.assert_not_called()
-
-
-class test_Topic(ExchangeCase):
- type = exchange.TopicExchange
- table = [
- ('stock.#', None, 'rFoo'),
- ('stock.us.*', None, 'rBar'),
- ]
-
- def setup(self):
- super(test_Topic, self).setup()
- self.table = [(rkey, self.e.key_to_pattern(rkey), queue)
- for rkey, _, queue in self.table]
-
- def test_prepare_bind(self):
- x = self.e.prepare_bind('qFoo', 'eFoo', 'stock.#', {})
- self.assertTupleEqual(x, ('stock.#', r'^stock\..*?$', 'qFoo'))
-
- def test_lookup(self):
- self.assertSetEqual(
- self.e.lookup(self.table, 'eFoo', 'stock.us.nasdaq', None),
- {'rFoo', 'rBar'},
- )
- self.assertTrue(self.e._compiled)
- self.assertSetEqual(
- self.e.lookup(self.table, 'eFoo', 'stock.europe.OSE', None),
- {'rFoo'},
- )
- self.assertSetEqual(
- self.e.lookup(self.table, 'eFoo', 'stockxeuropexOSE', None),
- set(),
- )
- self.assertSetEqual(
- self.e.lookup(self.table, 'eFoo',
- 'candy.schleckpulver.snap_crackle', None),
- set(),
- )
-
- def test_deliver(self):
- self.e.channel = Mock()
- self.e.channel._lookup.return_value = ('a', 'b')
- message = Mock()
- self.e.deliver(message, 'exchange', 'rkey')
-
- expected = [(('a', message), {}),
- (('b', message), {})]
- self.assertListEqual(self.e.channel._put.call_args_list, expected)
-
-
-class test_TopicMultibind(ExchangeCase):
- # Testing message delivery in case of multiple overlapping
- # bindings for the same queue. As AMQP states, in case of
- # overlapping bindings, a message must be delivered once to
- # each matching queue.
- type = exchange.TopicExchange
- table = [
- ('stock', None, 'rFoo'),
- ('stock.#', None, 'rFoo'),
- ('stock.us.*', None, 'rFoo'),
- ('#', None, 'rFoo'),
- ]
-
- def setup(self):
- super(test_TopicMultibind, self).setup()
- self.table = [(rkey, self.e.key_to_pattern(rkey), queue)
- for rkey, _, queue in self.table]
-
- def test_lookup(self):
- self.assertSetEqual(
- self.e.lookup(self.table, 'eFoo', 'stock.us.nasdaq', None),
- {'rFoo'},
- )
- self.assertTrue(self.e._compiled)
- self.assertSetEqual(
- self.e.lookup(self.table, 'eFoo', 'stock.europe.OSE', None),
- {'rFoo'},
- )
- self.assertSetEqual(
- self.e.lookup(self.table, 'eFoo', 'stockxeuropexOSE', None),
- {'rFoo'},
- )
- self.assertSetEqual(
- self.e.lookup(self.table, 'eFoo',
- 'candy.schleckpulver.snap_crackle', None),
- {'rFoo'},
- )
-
-
-class test_ExchangeType(ExchangeCase):
- type = exchange.ExchangeType
-
- def test_lookup(self):
- with self.assertRaises(NotImplementedError):
- self.e.lookup([], 'eFoo', 'rFoo', None)
-
- def test_prepare_bind(self):
- self.assertTupleEqual(
- self.e.prepare_bind('qFoo', 'eFoo', 'rFoo', {}),
- ('rFoo', None, 'qFoo'),
- )
-
- def test_equivalent(self):
- e1 = dict(
- type='direct',
- durable=True,
- auto_delete=True,
- arguments={},
- )
- self.assertTrue(
- self.e.equivalent(e1, 'eFoo', 'direct', True, True, {}),
- )
- self.assertFalse(
- self.e.equivalent(e1, 'eFoo', 'topic', True, True, {}),
- )
- self.assertFalse(
- self.e.equivalent(e1, 'eFoo', 'direct', False, True, {}),
- )
- self.assertFalse(
- self.e.equivalent(e1, 'eFoo', 'direct', True, False, {}),
- )
- self.assertFalse(
- self.e.equivalent(e1, 'eFoo', 'direct', True, True,
- {'expires': 3000}),
- )
- e2 = dict(e1, arguments={'expires': 3000})
- self.assertTrue(
- self.e.equivalent(e2, 'eFoo', 'direct', True, True,
- {'expires': 3000}),
- )
- self.assertFalse(
- self.e.equivalent(e2, 'eFoo', 'direct', True, True,
- {'expires': 6000}),
- )
diff --git a/kombu/tests/utils/test_div.py b/kombu/tests/utils/test_div.py
deleted file mode 100644
index 513cb1b7..00000000
--- a/kombu/tests/utils/test_div.py
+++ /dev/null
@@ -1,51 +0,0 @@
-from __future__ import absolute_import, unicode_literals
-
-import pickle
-
-from io import StringIO, BytesIO
-
-from kombu.utils.div import emergency_dump_state
-
-from kombu.tests.case import Case, mock
-
-
-class MyStringIO(StringIO):
-
- def close(self):
- pass
-
-
-class MyBytesIO(BytesIO):
-
- def close(self):
- pass
-
-
-class test_emergency_dump_state(Case):
-
- @mock.stdouts
- def test_dump(self, stdout, stderr):
- fh = MyBytesIO()
-
- emergency_dump_state(
- {'foo': 'bar'}, open_file=lambda n, m: fh)
- self.assertDictEqual(
- pickle.loads(fh.getvalue()), {'foo': 'bar'})
- self.assertTrue(stderr.getvalue())
- self.assertFalse(stdout.getvalue())
-
- @mock.stdouts
- def test_dump_second_strategy(self, stdout, stderr):
- fh = MyStringIO()
-
- def raise_something(*args, **kwargs):
- raise KeyError('foo')
-
- emergency_dump_state(
- {'foo': 'bar'},
- open_file=lambda n, m: fh, dump=raise_something
- )
- self.assertIn('foo', fh.getvalue())
- self.assertIn('bar', fh.getvalue())
- self.assertTrue(stderr.getvalue())
- self.assertFalse(stdout.getvalue())
diff --git a/kombu/tests/utils/test_encoding.py b/kombu/tests/utils/test_encoding.py
deleted file mode 100644
index 67bf4ad4..00000000
--- a/kombu/tests/utils/test_encoding.py
+++ /dev/null
@@ -1,109 +0,0 @@
-# -*- coding: utf-8 -*-
-from __future__ import absolute_import, unicode_literals
-
-import sys
-
-from contextlib import contextmanager
-
-from kombu.five import bytes_t, string_t
-from kombu.utils.encoding import (
- get_default_encoding_file, safe_str,
- set_default_encoding_file, default_encoding,
-)
-
-from kombu.tests.case import Case, patch, skip
-
-
-@contextmanager
-def clean_encoding():
- old_encoding = sys.modules.pop('kombu.utils.encoding', None)
- import kombu.utils.encoding
- try:
- yield kombu.utils.encoding
- finally:
- if old_encoding:
- sys.modules['kombu.utils.encoding'] = old_encoding
-
-
-class test_default_encoding(Case):
-
- def test_set_default_file(self):
- prev = get_default_encoding_file()
- try:
- set_default_encoding_file('/foo.txt')
- self.assertEqual(get_default_encoding_file(), '/foo.txt')
- finally:
- set_default_encoding_file(prev)
-
- @patch('sys.getfilesystemencoding')
- def test_default(self, getdefaultencoding):
- getdefaultencoding.return_value = 'ascii'
- with clean_encoding() as encoding:
- enc = encoding.default_encoding()
- if sys.platform.startswith('java'):
- self.assertEqual(enc, 'utf-8')
- else:
- self.assertEqual(enc, 'ascii')
- getdefaultencoding.assert_called_with()
-
-
-@skip.if_python3
-class test_encoding_utils(Case):
-
- def test_str_to_bytes(self):
- with clean_encoding() as e:
- self.assertIsInstance(e.str_to_bytes('foobar'), bytes_t)
-
- def test_from_utf8(self):
- with clean_encoding() as e:
- self.assertIsInstance(e.from_utf8('foobar'), bytes_t)
-
- def test_default_encode(self):
- with clean_encoding() as e:
- self.assertTrue(e.default_encode(b'foo'))
-
-
-class test_safe_str(Case):
-
- def setup(self):
- self._cencoding = patch('sys.getfilesystemencoding')
- self._encoding = self._cencoding.__enter__()
- self._encoding.return_value = 'ascii'
-
- def teardown(self):
- self._cencoding.__exit__()
-
- def test_when_bytes(self):
- self.assertEqual(safe_str('foo'), 'foo')
-
- def test_when_unicode(self):
- self.assertIsInstance(safe_str('foo'), string_t)
-
- def test_when_encoding_utf8(self):
- with patch('sys.getfilesystemencoding') as encoding:
- encoding.return_value = 'utf-8'
- self.assertEqual(default_encoding(), 'utf-8')
- s = 'The quiæk fåx jømps øver the lazy dåg'
- res = safe_str(s)
- self.assertIsInstance(res, str)
-
- def test_when_containing_high_chars(self):
- with patch('sys.getfilesystemencoding') as encoding:
- encoding.return_value = 'ascii'
- s = 'The quiæk fåx jømps øver the lazy dåg'
- res = safe_str(s)
- self.assertIsInstance(res, str)
- self.assertEqual(len(s), len(res))
-
- def test_when_not_string(self):
- o = object()
- self.assertEqual(safe_str(o), repr(o))
-
- def test_when_unrepresentable(self):
-
- class O(object):
-
- def __repr__(self):
- raise KeyError('foo')
-
- self.assertIn('<Unrepresentable', safe_str(O()))
diff --git a/kombu/tests/utils/test_scheduling.py b/kombu/tests/utils/test_scheduling.py
deleted file mode 100644
index ed668714..00000000
--- a/kombu/tests/utils/test_scheduling.py
+++ /dev/null
@@ -1,112 +0,0 @@
-from __future__ import absolute_import, unicode_literals
-
-from kombu.utils.scheduling import FairCycle, cycle_by_name
-
-from kombu.tests.case import Case
-
-
-class MyEmpty(Exception):
- pass
-
-
-def consume(fun, n):
- r = []
- for i in range(n):
- r.append(fun())
- return r
-
-
-class test_FairCycle(Case):
-
- def test_cycle(self):
- resources = ['a', 'b', 'c', 'd', 'e']
-
- def echo(r, timeout=None):
- return r
-
- # cycle should be ['a', 'b', 'c', 'd', 'e', ... repeat]
- cycle = FairCycle(echo, resources, MyEmpty)
- for i in range(len(resources)):
- self.assertEqual(cycle.get(), (resources[i],
- resources[i]))
- for i in range(len(resources)):
- self.assertEqual(cycle.get(), (resources[i],
- resources[i]))
-
- def test_cycle_breaks(self):
- resources = ['a', 'b', 'c', 'd', 'e']
-
- def echo(r):
- if r == 'c':
- raise MyEmpty(r)
- return r
-
- cycle = FairCycle(echo, resources, MyEmpty)
- self.assertEqual(
- consume(cycle.get, len(resources)),
- [('a', 'a'), ('b', 'b'), ('d', 'd'),
- ('e', 'e'), ('a', 'a')],
- )
- self.assertEqual(
- consume(cycle.get, len(resources)),
- [('b', 'b'), ('d', 'd'), ('e', 'e'),
- ('a', 'a'), ('b', 'b')],
- )
- cycle2 = FairCycle(echo, ['c', 'c'], MyEmpty)
- with self.assertRaises(MyEmpty):
- consume(cycle2.get, 3)
-
- def test_cycle_no_resources(self):
- cycle = FairCycle(None, [], MyEmpty)
- cycle.pos = 10
-
- with self.assertRaises(MyEmpty):
- cycle._next()
-
- def test__repr__(self):
- self.assertTrue(repr(FairCycle(lambda x: x, [1, 2, 3], MyEmpty)))
-
-
-class test_round_robin_cycle(Case):
-
- def test_round_robin_cycle(self):
- it = cycle_by_name('round_robin')(['A', 'B', 'C'])
- self.assertListEqual(it.consume(3), ['A', 'B', 'C'])
- it.rotate('B')
- self.assertListEqual(it.consume(3), ['A', 'C', 'B'])
- it.rotate('A')
- self.assertListEqual(it.consume(3), ['C', 'B', 'A'])
- it.rotate('A')
- self.assertListEqual(it.consume(3), ['C', 'B', 'A'])
- it.rotate('C')
- self.assertListEqual(it.consume(3), ['B', 'A', 'C'])
-
-
-class test_priority_cycle(Case):
-
- def test_priority_cycle(self):
- it = cycle_by_name('priority')(['A', 'B', 'C'])
- self.assertListEqual(it.consume(3), ['A', 'B', 'C'])
- it.rotate('B')
- self.assertListEqual(it.consume(3), ['A', 'B', 'C'])
- it.rotate('A')
- self.assertListEqual(it.consume(3), ['A', 'B', 'C'])
- it.rotate('A')
- self.assertListEqual(it.consume(3), ['A', 'B', 'C'])
- it.rotate('C')
- self.assertListEqual(it.consume(3), ['A', 'B', 'C'])
-
-
-class test_sorted_cycle(Case):
-
- def test_sorted_cycle(self):
- it = cycle_by_name('sorted')(['B', 'C', 'A'])
- self.assertListEqual(it.consume(3), ['A', 'B', 'C'])
- it.rotate('B')
- self.assertListEqual(it.consume(3), ['A', 'B', 'C'])
- it.rotate('A')
- self.assertListEqual(it.consume(3), ['A', 'B', 'C'])
- it.rotate('A')
- self.assertListEqual(it.consume(3), ['A', 'B', 'C'])
- it.rotate('C')
- self.assertListEqual(it.consume(3), ['A', 'B', 'C'])
diff --git a/kombu/tests/utils/test_url.py b/kombu/tests/utils/test_url.py
deleted file mode 100644
index 67c7efce..00000000
--- a/kombu/tests/utils/test_url.py
+++ /dev/null
@@ -1,50 +0,0 @@
-from __future__ import absolute_import, unicode_literals
-
-from kombu.utils.url import as_url, parse_url, maybe_sanitize_url
-
-from kombu.tests.case import Case
-
-
-class test_parse_url(Case):
-
- def test_parse_url(self):
- result = parse_url('amqp://user:pass@localhost:5672/my/vhost')
- self.assertDictEqual(result, {
- 'transport': 'amqp',
- 'userid': 'user',
- 'password': 'pass',
- 'hostname': 'localhost',
- 'port': 5672,
- 'virtual_host': 'my/vhost',
- })
-
-
-class test_as_url(Case):
-
- def test_as_url(self):
- self.assertEqual(as_url('https'), 'https:///')
- self.assertEqual(as_url('https', 'e.com'), 'https://e.com/')
- self.assertEqual(as_url('https', 'e.com', 80), 'https://e.com:80/')
- self.assertEqual(
- as_url('https', 'e.com', 80, 'u'), 'https://u@e.com:80/',
- )
- self.assertEqual(
- as_url('https', 'e.com', 80, 'u', 'p'), 'https://u:p@e.com:80/',
- )
- self.assertEqual(
- as_url('https', 'e.com', 80, None, 'p'), 'https://:p@e.com:80/',
- )
- self.assertEqual(
- as_url('https', 'e.com', 80, None, 'p', '/foo'),
- 'https://:p@e.com:80//foo',
- )
-
-
-class test_maybe_sanitize_url(Case):
-
- def test_maybe_sanitize_url(self):
- self.assertEqual(maybe_sanitize_url('foo'), 'foo')
- self.assertEqual(
- maybe_sanitize_url('http://u:p@e.com//foo'),
- 'http://u:**@e.com//foo',
- )
diff --git a/kombu/tests/utils/test_utils.py b/kombu/tests/utils/test_utils.py
deleted file mode 100644
index 84adf6c9..00000000
--- a/kombu/tests/utils/test_utils.py
+++ /dev/null
@@ -1,42 +0,0 @@
-from __future__ import absolute_import, unicode_literals
-
-from kombu import version_info_t
-from kombu.utils.text import version_string_as_tuple
-
-from kombu.tests.case import Case
-
-
-class test_kombu_module(Case):
-
- def test_dir(self):
- import kombu
- self.assertTrue(dir(kombu))
-
-
-class test_version_string_as_tuple(Case):
-
- def test_versions(self):
- self.assertTupleEqual(
- version_string_as_tuple('3'),
- version_info_t(3, 0, 0, '', ''),
- )
- self.assertTupleEqual(
- version_string_as_tuple('3.3'),
- version_info_t(3, 3, 0, '', ''),
- )
- self.assertTupleEqual(
- version_string_as_tuple('3.3.1'),
- version_info_t(3, 3, 1, '', ''),
- )
- self.assertTupleEqual(
- version_string_as_tuple('3.3.1a3'),
- version_info_t(3, 3, 1, 'a3', ''),
- )
- self.assertTupleEqual(
- version_string_as_tuple('3.3.1a3-40c32'),
- version_info_t(3, 3, 1, 'a3', '40c32'),
- )
- self.assertEqual(
- version_string_as_tuple('3.3.1.a3.40c32'),
- version_info_t(3, 3, 1, 'a3', '40c32'),
- )
diff --git a/kombu/transport/redis.py b/kombu/transport/redis.py
index b6a818fe..d7546a60 100644
--- a/kombu/transport/redis.py
+++ b/kombu/transport/redis.py
@@ -1073,15 +1073,14 @@ class SentinelChannel(Channel):
additional_params = connparams.copy()
- del additional_params['host']
- del additional_params['port']
+ additional_params.pop('host', None)
+ additional_params.pop('port', None)
sentinel_inst = sentinel.Sentinel(
[(connparams['host'], connparams['port'])],
min_other_sentinels=getattr(self, 'min_other_sentinels', 0),
sentinel_kwargs=getattr(self, 'sentinel_kwargs', {}),
- **additional_params
- )
+ **additional_params)
master_name = getattr(self, 'master_name', None)
diff --git a/kombu/transport/virtual/__init__.py b/kombu/transport/virtual/__init__.py
index ccfab350..6960104c 100644
--- a/kombu/transport/virtual/__init__.py
+++ b/kombu/transport/virtual/__init__.py
@@ -1,988 +1,13 @@
-"""Virtual transport implementation.
-
-Emulates the AMQ API for non-AMQ transports.
-"""
-from __future__ import absolute_import, print_function, unicode_literals
-
-import base64
-import socket
-import sys
-import warnings
-
-from array import array
-from collections import OrderedDict, defaultdict, namedtuple
-from itertools import count
-from multiprocessing.util import Finalize
-from time import sleep
-
-from amqp.protocol import queue_declare_ok_t
-
-from kombu.exceptions import ResourceError, ChannelError
-from kombu.five import Empty, items, monotonic
-from kombu.log import get_logger
-from kombu.utils.encoding import str_to_bytes, bytes_to_str
-from kombu.utils.div import emergency_dump_state
-from kombu.utils.scheduling import FairCycle
-from kombu.utils.uuid import uuid
-
-from kombu.transport import base
-
-from .exchange import STANDARD_EXCHANGE_TYPES
-
-ARRAY_TYPE_H = 'H' if sys.version_info[0] == 3 else b'H'
-
-UNDELIVERABLE_FMT = """\
-Message could not be delivered: No queues bound to exchange {exchange!r} \
-using binding key {routing_key!r}.
-"""
-
-NOT_EQUIVALENT_FMT = """\
-Cannot redeclare exchange {0!r} in vhost {1!r} with \
-different type, durable, autodelete or arguments value.\
-"""
-
-W_NO_CONSUMERS = """\
-Requeuing undeliverable message for queue %r: No consumers.\
-"""
-
-RESTORING_FMT = 'Restoring {0!r} unacknowledged message(s)'
-RESTORE_PANIC_FMT = 'UNABLE TO RESTORE {0} MESSAGES: {1}'
-
-logger = get_logger(__name__)
-
-#: Key format used for queue argument lookups in BrokerState.bindings.
-binding_key_t = namedtuple('binding_key_t', (
- 'queue', 'exchange', 'routing_key',
-))
-
-#: BrokerState.queue_bindings generates tuples in this format.
-queue_binding_t = namedtuple('queue_binding_t', (
- 'exchange', 'routing_key', 'arguments',
-))
-
-
-class Base64(object):
-
- def encode(self, s):
- return bytes_to_str(base64.b64encode(str_to_bytes(s)))
-
- def decode(self, s):
- return base64.b64decode(str_to_bytes(s))
-
-
-class NotEquivalentError(Exception):
- """Entity declaration is not equivalent to the previous declaration."""
- pass
-
-
-class UndeliverableWarning(UserWarning):
- """The message could not be delivered to a queue."""
- pass
-
-
-class BrokerState(object):
-
- #: Mapping of exchange name to
- #: :class:`kombu.transport.virtual.exchange.ExchangeType`
- exchanges = None
-
- #: This is the actual bindings registry, used to store bindings and to
- #: test 'in' relationships in constant time. It has the following
- #: structure::
- #:
- #: {
- #: (queue, exchange, routing_key): arguments,
- #: # ...,
- #: }
- bindings = None
-
- #: The queue index is used to access directly (constant time)
- #: all the bindings of a certain queue. It has the following structure::
- #: {
- #: queue: {
- #: (queue, exchange, routing_key),
- #: # ...,
- #: },
- #: # ...,
- #: }
- queue_index = None
-
- def __init__(self, exchanges=None):
- self.exchanges = {} if exchanges is None else exchanges
- self.bindings = {}
- self.queue_index = defaultdict(set)
-
- def clear(self):
- self.exchanges.clear()
- self.bindings.clear()
- self.queue_index.clear()
-
- def has_binding(self, queue, exchange, routing_key):
- return (queue, exchange, routing_key) in self.bindings
-
- def binding_declare(self, queue, exchange, routing_key, arguments):
- key = binding_key_t(queue, exchange, routing_key)
- self.bindings.setdefault(key, arguments)
- self.queue_index[queue].add(key)
-
- def binding_delete(self, queue, exchange, routing_key):
- key = binding_key_t(queue, exchange, routing_key)
- try:
- del self.bindings[key]
- except KeyError:
- pass
- else:
- self.queue_index[queue].remove(key)
-
- def queue_bindings_delete(self, queue):
- try:
- bindings = self.queue_index.pop(queue)
- except KeyError:
- pass
- else:
- [self.bindings.pop(binding, None) for binding in bindings]
-
- def queue_bindings(self, queue):
- return (
- queue_binding_t(key.exchange, key.routing_key, self.bindings[key])
- for key in self.queue_index[queue]
- )
-
-
-class QoS(object):
- """Quality of Service guarantees.
-
- Only supports `prefetch_count` at this point.
-
- Arguments:
- channel (ChannelT): Connection channel.
- prefetch_count (int): Initial prefetch count (defaults to 0).
- """
-
- #: current prefetch count value
- prefetch_count = 0
-
- #: :class:`~collections.OrderedDict` of active messages.
- #: *NOTE*: Can only be modified by the consuming thread.
- _delivered = None
-
- #: acks can be done by other threads than the consuming thread.
- #: Instead of a mutex, which doesn't perform well here, we mark
- #: the delivery tags as dirty, so subsequent calls to append() can remove
- #: them.
- _dirty = None
-
- #: If disabled, unacked messages won't be restored at shutdown.
- restore_at_shutdown = True
-
- def __init__(self, channel, prefetch_count=0):
- self.channel = channel
- self.prefetch_count = prefetch_count or 0
-
- self._delivered = OrderedDict()
- self._delivered.restored = False
- self._dirty = set()
- self._quick_ack = self._dirty.add
- self._quick_append = self._delivered.__setitem__
- self._on_collect = Finalize(
- self, self.restore_unacked_once, exitpriority=1,
- )
-
- def can_consume(self):
- """Return true if the channel can be consumed from.
-
- Used to ensure the client adhers to currently active
- prefetch limits.
- """
- pcount = self.prefetch_count
- return not pcount or len(self._delivered) - len(self._dirty) < pcount
-
- def can_consume_max_estimate(self):
- """Returns the maximum number of messages allowed to be returned.
-
- Returns an estimated number of messages that a consumer may be allowed
- to consume at once from the broker. This is used for services where
- bulk 'get message' calls are preferred to many individual 'get message'
- calls - like SQS.
-
- Returns:
- int: greater than zero.
- """
- pcount = self.prefetch_count
- if pcount:
- return max(pcount - (len(self._delivered) - len(self._dirty)), 0)
-
- def append(self, message, delivery_tag):
- """Append message to transactional state."""
- if self._dirty:
- self._flush()
- self._quick_append(delivery_tag, message)
-
- def get(self, delivery_tag):
- return self._delivered[delivery_tag]
-
- def _flush(self):
- """Flush dirty (acked/rejected) tags from."""
- dirty = self._dirty
- delivered = self._delivered
- while 1:
- try:
- dirty_tag = dirty.pop()
- except KeyError:
- break
- delivered.pop(dirty_tag, None)
-
- def ack(self, delivery_tag):
- """Acknowledge message and remove from transactional state."""
- self._quick_ack(delivery_tag)
-
- def reject(self, delivery_tag, requeue=False):
- """Remove from transactional state and requeue message."""
- if requeue:
- self.channel._restore_at_beginning(self._delivered[delivery_tag])
- self._quick_ack(delivery_tag)
-
- def restore_unacked(self):
- """Restore all unacknowledged messages."""
- self._flush()
- delivered = self._delivered
- errors = []
- restore = self.channel._restore
- pop_message = delivered.popitem
-
- while delivered:
- try:
- _, message = pop_message()
- except KeyError: # pragma: no cover
- break
-
- try:
- restore(message)
- except BaseException as exc:
- errors.append((exc, message))
- delivered.clear()
- return errors
-
- def restore_unacked_once(self, stderr=None):
- """Restores all unacknowledged messages at shutdown/gc collect.
-
- Note:
- Can only be called once for each instance, subsequent
- calls will be ignored.
- """
- self._on_collect.cancel()
- self._flush()
- stderr = sys.stderr if stderr is None else stderr
- state = self._delivered
-
- if not self.restore_at_shutdown or not self.channel.do_restore:
- return
- if getattr(state, 'restored', None):
- assert not state
- return
- try:
- if state:
- print(RESTORING_FMT.format(len(self._delivered)),
- file=stderr)
- unrestored = self.restore_unacked()
-
- if unrestored:
- errors, messages = list(zip(*unrestored))
- print(RESTORE_PANIC_FMT.format(len(errors), errors),
- file=stderr)
- emergency_dump_state(messages, stderr=stderr)
- finally:
- state.restored = True
-
- def restore_visible(self, *args, **kwargs):
- """Restore any pending unackwnowledged messages for visibility_timeout
- style implementations.
-
- Note:
- This is implementation optional, and currently only
- used by the Redis transport.
- """
- pass
-
-
-class Message(base.Message):
-
- def __init__(self, channel, payload, **kwargs):
- self._raw = payload
- properties = payload['properties']
- body = payload.get('body')
- if body:
- body = channel.decode_body(body, properties.get('body_encoding'))
- kwargs.update({
- 'body': body,
- 'delivery_tag': properties['delivery_tag'],
- 'content_type': payload.get('content-type'),
- 'content_encoding': payload.get('content-encoding'),
- 'headers': payload.get('headers'),
- 'properties': properties,
- 'delivery_info': properties.get('delivery_info'),
- 'postencode': 'utf-8',
- })
- super(Message, self).__init__(channel, **kwargs)
-
- def serializable(self):
- props = self.properties
- body, _ = self.channel.encode_body(self.body,
- props.get('body_encoding'))
- headers = dict(self.headers)
- # remove compression header
- headers.pop('compression', None)
- return {
- 'body': body,
- 'properties': props,
- 'content-type': self.content_type,
- 'content-encoding': self.content_encoding,
- 'headers': headers,
- }
-
-
-class AbstractChannel(object):
- """This is an abstract class defining the channel methods
- you'd usually want to implement in a virtual channel.
-
- Note:
- Do not subclass directly, but rather inherit
- from :class:`Channel`.
- """
-
- def _get(self, queue, timeout=None):
- """Get next message from `queue`."""
- raise NotImplementedError('Virtual channels must implement _get')
-
- def _put(self, queue, message):
- """Put `message` onto `queue`."""
- raise NotImplementedError('Virtual channels must implement _put')
-
- def _purge(self, queue):
- """Remove all messages from `queue`."""
- raise NotImplementedError('Virtual channels must implement _purge')
-
- def _size(self, queue):
- """Return the number of messages in `queue` as an :class:`int`."""
- return 0
-
- def _delete(self, queue, *args, **kwargs):
- """Delete `queue`.
-
- Note:
- This just purges the queue, if you need to do more you can
- override this method.
- """
- self._purge(queue)
-
- def _new_queue(self, queue, **kwargs):
- """Create new queue.
-
- Note:
- Your transport can override this method if it needs
- to do something whenever a new queue is declared.
- """
- pass
-
- def _has_queue(self, queue, **kwargs):
- """Verify that queue exists.
-
- Returns:
- bool: Should return :const:`True` if the queue exists
- or :const:`False` otherwise.
- """
- return True
-
- def _poll(self, cycle, timeout=None):
- """Poll a list of queues for available messages."""
- return cycle.get()
-
-
-class Channel(AbstractChannel, base.StdChannel):
- """Virtual channel.
-
- Arguments:
- connection (ConnectionT): The transport instance this
- channel is part of.
- """
- #: message class used.
- Message = Message
-
- #: QoS class used.
- QoS = QoS
-
- #: flag to restore unacked messages when channel
- #: goes out of scope.
- do_restore = True
-
- #: mapping of exchange types and corresponding classes.
- exchange_types = dict(STANDARD_EXCHANGE_TYPES)
-
- #: flag set if the channel supports fanout exchanges.
- supports_fanout = False
-
- #: Binary <-> ASCII codecs.
- codecs = {'base64': Base64()}
-
- #: Default body encoding.
- #: NOTE: ``transport_options['body_encoding']`` will override this value.
- body_encoding = 'base64'
-
- #: counter used to generate delivery tags for this channel.
- _delivery_tags = count(1)
-
- #: Optional queue where messages with no route is delivered.
- #: Set by ``transport_options['deadletter_queue']``.
- deadletter_queue = None
-
- # List of options to transfer from :attr:`transport_options`.
- from_transport_options = ('body_encoding', 'deadletter_queue')
-
- # Priority defaults
- default_priority = 0
- min_priority = 0
- max_priority = 9
-
- def __init__(self, connection, **kwargs):
- self.connection = connection
- self._consumers = set()
- self._cycle = None
- self._tag_to_queue = {}
- self._active_queues = []
- self._qos = None
- self.closed = False
-
- # instantiate exchange types
- self.exchange_types = dict(
- (typ, cls(self)) for typ, cls in items(self.exchange_types)
- )
-
- try:
- self.channel_id = self.connection._avail_channel_ids.pop()
- except IndexError:
- raise ResourceError(
- 'No free channel ids, current={0}, channel_max={1}'.format(
- len(self.connection.channels),
- self.connection.channel_max), (20, 10),
- )
-
- topts = self.connection.client.transport_options
- for opt_name in self.from_transport_options:
- try:
- setattr(self, opt_name, topts[opt_name])
- except KeyError:
- pass
-
- def exchange_declare(self, exchange=None, type='direct', durable=False,
- auto_delete=False, arguments=None,
- nowait=False, passive=False):
- """Declare exchange."""
- type = type or 'direct'
- exchange = exchange or 'amq.%s' % type
- if passive:
- if exchange not in self.state.exchanges:
- raise ChannelError(
- 'NOT_FOUND - no exchange {0!r} in vhost {1!r}'.format(
- exchange, self.connection.client.virtual_host or '/'),
- (50, 10), 'Channel.exchange_declare', '404',
- )
- return
- try:
- prev = self.state.exchanges[exchange]
- if not self.typeof(exchange).equivalent(prev, exchange, type,
- durable, auto_delete,
- arguments):
- raise NotEquivalentError(NOT_EQUIVALENT_FMT.format(
- exchange, self.connection.client.virtual_host or '/'))
- except KeyError:
- self.state.exchanges[exchange] = {
- 'type': type,
- 'durable': durable,
- 'auto_delete': auto_delete,
- 'arguments': arguments or {},
- 'table': [],
- }
-
- def exchange_delete(self, exchange, if_unused=False, nowait=False):
- """Delete `exchange` and all its bindings."""
- for rkey, _, queue in self.get_table(exchange):
- self.queue_delete(queue, if_unused=True, if_empty=True)
- self.state.exchanges.pop(exchange, None)
-
- def queue_declare(self, queue=None, passive=False, **kwargs):
- """Declare queue."""
- queue = queue or 'amq.gen-%s' % uuid()
- if passive and not self._has_queue(queue, **kwargs):
- raise ChannelError(
- 'NOT_FOUND - no queue {0!r} in vhost {1!r}'.format(
- queue, self.connection.client.virtual_host or '/'),
- (50, 10), 'Channel.queue_declare', '404',
- )
- else:
- self._new_queue(queue, **kwargs)
- return queue_declare_ok_t(queue, self._size(queue), 0)
-
- def queue_delete(self, queue, if_unused=False, if_empty=False, **kwargs):
- """Delete queue."""
- if if_empty and self._size(queue):
- return
- for exchange, routing_key, args in self.state.queue_bindings(queue):
- meta = self.typeof(exchange).prepare_bind(
- queue, exchange, routing_key, args,
- )
- self._delete(queue, exchange, *meta, **kwargs)
- self.state.queue_bindings_delete(queue)
-
- def after_reply_message_received(self, queue):
- self.queue_delete(queue)
-
- def exchange_bind(self, destination, source='', routing_key='',
- nowait=False, arguments=None):
- raise NotImplementedError('transport does not support exchange_bind')
-
- def exchange_unbind(self, destination, source='', routing_key='',
- nowait=False, arguments=None):
- raise NotImplementedError('transport does not support exchange_unbind')
-
- def queue_bind(self, queue, exchange=None, routing_key='',
- arguments=None, **kwargs):
- """Bind `queue` to `exchange` with `routing key`."""
- exchange = exchange or 'amq.direct'
- if self.state.has_binding(queue, exchange, routing_key):
- return
- # Add binding:
- self.state.binding_declare(queue, exchange, routing_key, arguments)
- # Update exchange's routing table:
- table = self.state.exchanges[exchange].setdefault('table', [])
- meta = self.typeof(exchange).prepare_bind(
- queue, exchange, routing_key, arguments,
- )
- table.append(meta)
- if self.supports_fanout:
- self._queue_bind(exchange, *meta)
-
- def queue_unbind(self, queue, exchange=None, routing_key='',
- arguments=None, **kwargs):
- # Remove queue binding:
- self.state.binding_delete(queue, exchange, routing_key)
- try:
- table = self.get_table(exchange)
- except KeyError:
- return
- binding_meta = self.typeof(exchange).prepare_bind(
- queue, exchange, routing_key, arguments,
- )
- # TODO: the complexity of this operation is O(number of bindings).
- # Should be optimized. Modifying table in place.
- table[:] = [meta for meta in table if meta != binding_meta]
-
- def list_bindings(self):
- return ((queue, exchange, rkey)
- for exchange in self.state.exchanges
- for rkey, pattern, queue in self.get_table(exchange))
-
- def queue_purge(self, queue, **kwargs):
- """Remove all ready messages from queue."""
- return self._purge(queue)
-
- def _next_delivery_tag(self):
- return uuid()
-
- def basic_publish(self, message, exchange, routing_key, **kwargs):
- """Publish message."""
- message['body'], body_encoding = self.encode_body(
- message['body'], self.body_encoding,
- )
- props = message['properties']
- props.update(
- body_encoding=body_encoding,
- delivery_tag=self._next_delivery_tag(),
- )
- props['delivery_info'].update(
- exchange=exchange,
- routing_key=routing_key,
- )
- if exchange:
- return self.typeof(exchange).deliver(
- message, exchange, routing_key, **kwargs
- )
- # anon exchange: routing_key is the destination queue
- return self._put(routing_key, message, **kwargs)
-
- def basic_consume(self, queue, no_ack, callback, consumer_tag, **kwargs):
- """Consume from `queue`"""
- self._tag_to_queue[consumer_tag] = queue
- self._active_queues.append(queue)
-
- def _callback(raw_message):
- message = self.Message(self, raw_message)
- if not no_ack:
- self.qos.append(message, message.delivery_tag)
- return callback(message)
-
- self.connection._callbacks[queue] = _callback
- self._consumers.add(consumer_tag)
-
- self._reset_cycle()
-
- def basic_cancel(self, consumer_tag):
- """Cancel consumer by consumer tag."""
- if consumer_tag in self._consumers:
- self._consumers.remove(consumer_tag)
- self._reset_cycle()
- queue = self._tag_to_queue.pop(consumer_tag, None)
- try:
- self._active_queues.remove(queue)
- except ValueError:
- pass
- self.connection._callbacks.pop(queue, None)
-
- def basic_get(self, queue, no_ack=False, **kwargs):
- """Get message by direct access (synchronous)."""
- try:
- message = self.Message(self, self._get(queue))
- if not no_ack:
- self.qos.append(message, message.delivery_tag)
- return message
- except Empty:
- pass
-
- def basic_ack(self, delivery_tag, multiple=False):
- """Acknowledge message."""
- self.qos.ack(delivery_tag)
-
- def basic_recover(self, requeue=False):
- """Recover unacked messages."""
- if requeue:
- return self.qos.restore_unacked()
- raise NotImplementedError('Does not support recover(requeue=False)')
-
- def basic_reject(self, delivery_tag, requeue=False):
- """Reject message."""
- self.qos.reject(delivery_tag, requeue=requeue)
-
- def basic_qos(self, prefetch_size=0, prefetch_count=0,
- apply_global=False):
- """Change QoS settings for this channel.
-
- Note:
- Only `prefetch_count` is supported.
- """
- self.qos.prefetch_count = prefetch_count
-
- def get_exchanges(self):
- return list(self.state.exchanges)
-
- def get_table(self, exchange):
- """Get table of bindings for `exchange`."""
- return self.state.exchanges[exchange]['table']
-
- def typeof(self, exchange, default='direct'):
- """Get the exchange type instance for `exchange`."""
- try:
- type = self.state.exchanges[exchange]['type']
- except KeyError:
- type = default
- return self.exchange_types[type]
-
- def _lookup(self, exchange, routing_key, default=None):
- """Find all queues matching `routing_key` for the given `exchange`.
-
- Returns:
- str: queue name -- must return the string `default`
- if no queues matched.
- """
- if default is None:
- default = self.deadletter_queue
- if not exchange: # anon exchange
- return [routing_key or default]
-
- try:
- R = self.typeof(exchange).lookup(
- self.get_table(exchange),
- exchange, routing_key, default,
- )
- except KeyError:
- R = []
-
- if not R and default is not None:
- warnings.warn(UndeliverableWarning(UNDELIVERABLE_FMT.format(
- exchange=exchange, routing_key=routing_key)),
- )
- self._new_queue(default)
- R = [default]
- return R
-
- def _restore(self, message):
- """Redeliver message to its original destination."""
- delivery_info = message.delivery_info
- message = message.serializable()
- message['redelivered'] = True
- for queue in self._lookup(
- delivery_info['exchange'], delivery_info['routing_key']):
- self._put(queue, message)
-
- def _restore_at_beginning(self, message):
- return self._restore(message)
-
- def drain_events(self, timeout=None):
- if self._consumers and self.qos.can_consume():
- if hasattr(self, '_get_many'):
- return self._get_many(self._active_queues, timeout=timeout)
- return self._poll(self.cycle, timeout=timeout)
- raise Empty()
-
- def message_to_python(self, raw_message):
- """Convert raw message to :class:`Message` instance."""
- if not isinstance(raw_message, self.Message):
- return self.Message(self, payload=raw_message)
- return raw_message
-
- def prepare_message(self, body, priority=None, content_type=None,
- content_encoding=None, headers=None, properties=None):
- """Prepare message data."""
- properties = properties or {}
- properties.setdefault('delivery_info', {})
- properties.setdefault('priority', priority or self.default_priority)
-
- return {'body': body,
- 'content-encoding': content_encoding,
- 'content-type': content_type,
- 'headers': headers or {},
- 'properties': properties or {}}
-
- def flow(self, active=True):
- """Enable/disable message flow.
-
- Raises:
- NotImplementedError: as flow
- is not implemented by the base virtual implementation.
- """
- raise NotImplementedError('virtual channels do not support flow.')
-
- def close(self):
- """Close channel, cancel all consumers, and requeue unacked
- messages."""
- if not self.closed:
- self.closed = True
- for consumer in list(self._consumers):
- self.basic_cancel(consumer)
- if self._qos:
- self._qos.restore_unacked_once()
- if self._cycle is not None:
- self._cycle.close()
- self._cycle = None
- if self.connection is not None:
- self.connection.close_channel(self)
- self.exchange_types = None
-
- def encode_body(self, body, encoding=None):
- if encoding:
- return self.codecs.get(encoding).encode(body), encoding
- return body, encoding
-
- def decode_body(self, body, encoding=None):
- if encoding:
- return self.codecs.get(encoding).decode(body)
- return body
-
- def _reset_cycle(self):
- self._cycle = FairCycle(self._get, self._active_queues, Empty)
-
- def __enter__(self):
- return self
-
- def __exit__(self, *exc_info):
- self.close()
-
- @property
- def state(self):
- """Broker state containing exchanges and bindings."""
- return self.connection.state
-
- @property
- def qos(self):
- """:class:`QoS` manager for this channel."""
- if self._qos is None:
- self._qos = self.QoS(self)
- return self._qos
-
- @property
- def cycle(self):
- if self._cycle is None:
- self._reset_cycle()
- return self._cycle
-
- def _get_message_priority(self, message, reverse=False):
- """Get priority from message and limit the value within a
- boundary of 0 to 9.
-
- Note:
- Higher value has more priority.
- """
- try:
- priority = max(
- min(int(message['properties']['priority']),
- self.max_priority),
- self.min_priority,
- )
- except (TypeError, ValueError, KeyError):
- priority = self.default_priority
-
- return (self.max_priority - priority) if reverse else priority
-
-
-class Management(base.Management):
-
- def __init__(self, transport):
- super(Management, self).__init__(transport)
- self.channel = transport.client.channel()
-
- def get_bindings(self):
- return [dict(destination=q, source=e, routing_key=r)
- for q, e, r in self.channel.list_bindings()]
-
- def close(self):
- self.channel.close()
-
-
-class Transport(base.Transport):
- """Virtual transport.
-
- Arguments:
- client (kombu.Connection): The client this is a transport for.
- """
- Channel = Channel
- Cycle = FairCycle
- Management = Management
-
- #: :class:`BrokerState` containing declared exchanges and
- #: bindings (set by constructor).
- state = BrokerState()
-
- #: :class:`~kombu.utils.scheduling.FairCycle` instance
- #: used to fairly drain events from channels (set by constructor).
- cycle = None
-
- #: port number used when no port is specified.
- default_port = None
-
- #: active channels.
- channels = None
-
- #: queue/callback map.
- _callbacks = None
-
- #: Time to sleep between unsuccessful polls.
- polling_interval = 1.0
-
- #: Max number of channels
- channel_max = 65535
-
- implements = base.Transport.implements.extend(
- async=False,
- exchange_type=frozenset(['direct', 'topic']),
- heartbeats=False,
- )
-
- def __init__(self, client, **kwargs):
- self.client = client
- self.channels = []
- self._avail_channels = []
- self._callbacks = {}
- self.cycle = self.Cycle(self._drain_channel, self.channels, Empty)
- polling_interval = client.transport_options.get('polling_interval')
- if polling_interval is not None:
- self.polling_interval = polling_interval
- self._avail_channel_ids = array(
- ARRAY_TYPE_H, range(self.channel_max, 0, -1),
- )
-
- def create_channel(self, connection):
- try:
- return self._avail_channels.pop()
- except IndexError:
- channel = self.Channel(connection)
- self.channels.append(channel)
- return channel
-
- def close_channel(self, channel):
- try:
- self._avail_channel_ids.append(channel.channel_id)
- try:
- self.channels.remove(channel)
- except ValueError:
- pass
- finally:
- channel.connection = None
-
- def establish_connection(self):
- # creates channel to verify connection.
- # this channel is then used as the next requested channel.
- # (returned by ``create_channel``).
- self._avail_channels.append(self.create_channel(self))
- return self # for drain events
-
- def close_connection(self, connection):
- self.cycle.close()
- for l in self._avail_channels, self.channels:
- while l:
- try:
- channel = l.pop()
- except LookupError: # pragma: no cover
- pass
- else:
- channel.close()
-
- def drain_events(self, connection, timeout=None):
- loop = 0
- time_start = monotonic()
- get = self.cycle.get
- polling_interval = self.polling_interval
- while 1:
- try:
- item, channel = get(timeout=timeout)
- except Empty:
- if timeout and monotonic() - time_start >= timeout:
- raise socket.timeout()
- loop += 1
- if polling_interval is not None:
- sleep(polling_interval)
- else:
- break
- self._deliver(*item)
-
- def _deliver(self, message, queue):
- if not queue:
- raise KeyError(
- 'Received message without destination queue: {0}'.format(
- message))
- try:
- callback = self._callbacks[queue]
- except KeyError:
- logger.warn(W_NO_CONSUMERS, queue)
- self._reject_inbound_message(message)
- else:
- callback(message)
-
- def _reject_inbound_message(self, raw_message):
- for channel in self.channels:
- if channel:
- message = channel.Message(channel, raw_message)
- channel.qos.append(message, message.delivery_tag)
- channel.basic_reject(message.delivery_tag, requeue=True)
- break
-
- def on_message_ready(self, channel, message, queue):
- if not queue or queue not in self._callbacks:
- raise KeyError(
- 'Message for queue {0!r} without consumers: {1}'.format(
- queue, message))
- self._callbacks[queue](message)
-
- def _drain_channel(self, channel, timeout=None):
- return channel.drain_events(timeout=timeout)
-
- @property
- def default_connection_params(self):
- return {'port': self.default_port, 'hostname': 'localhost'}
+from __future__ import absolute_import, unicode_literals
+
+from .base import (
+ Base64, NotEquivalentError, UndeliverableWarning, BrokerState,
+ QoS, Message, AbstractChannel, Channel, Management, Transport,
+ Empty, binding_key_t, queue_binding_t,
+)
+
+__all__ = [
+ 'Base64', 'NotEquivalentError', 'UndeliverableWarning', 'BrokerState',
+ 'QoS', 'Message', 'AbstractChannel', 'Channel', 'Management', 'Transport',
+ 'Empty', 'binding_key_t', 'queue_binding_t',
+]
diff --git a/kombu/transport/virtual/base.py b/kombu/transport/virtual/base.py
new file mode 100644
index 00000000..3811ddb5
--- /dev/null
+++ b/kombu/transport/virtual/base.py
@@ -0,0 +1,989 @@
+"""Virtual transport implementation.
+
+Emulates the AMQ API for non-AMQ transports.
+"""
+from __future__ import absolute_import, print_function, unicode_literals
+
+import base64
+import socket
+import sys
+import warnings
+
+from array import array
+from collections import OrderedDict, defaultdict, namedtuple
+from itertools import count
+from multiprocessing.util import Finalize
+from time import sleep
+
+from amqp.protocol import queue_declare_ok_t
+
+from kombu.exceptions import ResourceError, ChannelError
+from kombu.five import Empty, items, monotonic
+from kombu.log import get_logger
+from kombu.utils.encoding import str_to_bytes, bytes_to_str
+from kombu.utils.div import emergency_dump_state
+from kombu.utils.scheduling import FairCycle
+from kombu.utils.uuid import uuid
+
+from kombu.transport import base
+
+from .exchange import STANDARD_EXCHANGE_TYPES
+
+ARRAY_TYPE_H = 'H' if sys.version_info[0] == 3 else b'H'
+
+UNDELIVERABLE_FMT = """\
+Message could not be delivered: No queues bound to exchange {exchange!r} \
+using binding key {routing_key!r}.
+"""
+
+NOT_EQUIVALENT_FMT = """\
+Cannot redeclare exchange {0!r} in vhost {1!r} with \
+different type, durable, autodelete or arguments value.\
+"""
+
+W_NO_CONSUMERS = """\
+Requeuing undeliverable message for queue %r: No consumers.\
+"""
+
+RESTORING_FMT = 'Restoring {0!r} unacknowledged message(s)'
+RESTORE_PANIC_FMT = 'UNABLE TO RESTORE {0} MESSAGES: {1}'
+
+logger = get_logger(__name__)
+
+#: Key format used for queue argument lookups in BrokerState.bindings.
+binding_key_t = namedtuple('binding_key_t', (
+ 'queue', 'exchange', 'routing_key',
+))
+
+#: BrokerState.queue_bindings generates tuples in this format.
+queue_binding_t = namedtuple('queue_binding_t', (
+ 'exchange', 'routing_key', 'arguments',
+))
+
+
+class Base64(object):
+
+ def encode(self, s):
+ return bytes_to_str(base64.b64encode(str_to_bytes(s)))
+
+ def decode(self, s):
+ return base64.b64decode(str_to_bytes(s))
+
+
+class NotEquivalentError(Exception):
+ """Entity declaration is not equivalent to the previous declaration."""
+ pass
+
+
+class UndeliverableWarning(UserWarning):
+ """The message could not be delivered to a queue."""
+ pass
+
+
+class BrokerState(object):
+
+ #: Mapping of exchange name to
+ #: :class:`kombu.transport.virtual.exchange.ExchangeType`
+ exchanges = None
+
+ #: This is the actual bindings registry, used to store bindings and to
+ #: test 'in' relationships in constant time. It has the following
+ #: structure::
+ #:
+ #: {
+ #: (queue, exchange, routing_key): arguments,
+ #: # ...,
+ #: }
+ bindings = None
+
+ #: The queue index is used to access directly (constant time)
+ #: all the bindings of a certain queue. It has the following structure::
+ #: {
+ #: queue: {
+ #: (queue, exchange, routing_key),
+ #: # ...,
+ #: },
+ #: # ...,
+ #: }
+ queue_index = None
+
+ def __init__(self, exchanges=None):
+ self.exchanges = {} if exchanges is None else exchanges
+ self.bindings = {}
+ self.queue_index = defaultdict(set)
+
+ def clear(self):
+ self.exchanges.clear()
+ self.bindings.clear()
+ self.queue_index.clear()
+
+ def has_binding(self, queue, exchange, routing_key):
+ return (queue, exchange, routing_key) in self.bindings
+
+ def binding_declare(self, queue, exchange, routing_key, arguments):
+ key = binding_key_t(queue, exchange, routing_key)
+ self.bindings.setdefault(key, arguments)
+ self.queue_index[queue].add(key)
+
+ def binding_delete(self, queue, exchange, routing_key):
+ key = binding_key_t(queue, exchange, routing_key)
+ try:
+ del self.bindings[key]
+ except KeyError:
+ pass
+ else:
+ self.queue_index[queue].remove(key)
+
+ def queue_bindings_delete(self, queue):
+ try:
+ bindings = self.queue_index.pop(queue)
+ except KeyError:
+ pass
+ else:
+ [self.bindings.pop(binding, None) for binding in bindings]
+
+ def queue_bindings(self, queue):
+ return (
+ queue_binding_t(key.exchange, key.routing_key, self.bindings[key])
+ for key in self.queue_index[queue]
+ )
+
+
+class QoS(object):
+ """Quality of Service guarantees.
+
+ Only supports `prefetch_count` at this point.
+
+ Arguments:
+ channel (ChannelT): Connection channel.
+ prefetch_count (int): Initial prefetch count (defaults to 0).
+ """
+
+ #: current prefetch count value
+ prefetch_count = 0
+
+ #: :class:`~collections.OrderedDict` of active messages.
+ #: *NOTE*: Can only be modified by the consuming thread.
+ _delivered = None
+
+ #: acks can be done by other threads than the consuming thread.
+ #: Instead of a mutex, which doesn't perform well here, we mark
+ #: the delivery tags as dirty, so subsequent calls to append() can remove
+ #: them.
+ _dirty = None
+
+ #: If disabled, unacked messages won't be restored at shutdown.
+ restore_at_shutdown = True
+
+ def __init__(self, channel, prefetch_count=0):
+ self.channel = channel
+ self.prefetch_count = prefetch_count or 0
+
+ self._delivered = OrderedDict()
+ self._delivered.restored = False
+ self._dirty = set()
+ self._quick_ack = self._dirty.add
+ self._quick_append = self._delivered.__setitem__
+ self._on_collect = Finalize(
+ self, self.restore_unacked_once, exitpriority=1,
+ )
+
+ def can_consume(self):
+ """Return true if the channel can be consumed from.
+
+ Used to ensure the client adhers to currently active
+ prefetch limits.
+ """
+ pcount = self.prefetch_count
+ return not pcount or len(self._delivered) - len(self._dirty) < pcount
+
+ def can_consume_max_estimate(self):
+ """Returns the maximum number of messages allowed to be returned.
+
+ Returns an estimated number of messages that a consumer may be allowed
+ to consume at once from the broker. This is used for services where
+ bulk 'get message' calls are preferred to many individual 'get message'
+ calls - like SQS.
+
+ Returns:
+ int: greater than zero.
+ """
+ pcount = self.prefetch_count
+ if pcount:
+ return max(pcount - (len(self._delivered) - len(self._dirty)), 0)
+
+ def append(self, message, delivery_tag):
+ """Append message to transactional state."""
+ if self._dirty:
+ self._flush()
+ self._quick_append(delivery_tag, message)
+
+ def get(self, delivery_tag):
+ return self._delivered[delivery_tag]
+
+ def _flush(self):
+ """Flush dirty (acked/rejected) tags from."""
+ dirty = self._dirty
+ delivered = self._delivered
+ while 1:
+ try:
+ dirty_tag = dirty.pop()
+ except KeyError:
+ break
+ delivered.pop(dirty_tag, None)
+
+ def ack(self, delivery_tag):
+ """Acknowledge message and remove from transactional state."""
+ self._quick_ack(delivery_tag)
+
+ def reject(self, delivery_tag, requeue=False):
+ """Remove from transactional state and requeue message."""
+ if requeue:
+ self.channel._restore_at_beginning(self._delivered[delivery_tag])
+ self._quick_ack(delivery_tag)
+
+ def restore_unacked(self):
+ """Restore all unacknowledged messages."""
+ self._flush()
+ delivered = self._delivered
+ errors = []
+ restore = self.channel._restore
+ pop_message = delivered.popitem
+
+ while delivered:
+ try:
+ _, message = pop_message()
+ except KeyError: # pragma: no cover
+ break
+
+ try:
+ restore(message)
+ except BaseException as exc:
+ errors.append((exc, message))
+ delivered.clear()
+ return errors
+
+ def restore_unacked_once(self, stderr=None):
+ """Restores all unacknowledged messages at shutdown/gc collect.
+
+ Note:
+ Can only be called once for each instance, subsequent
+ calls will be ignored.
+ """
+ self._on_collect.cancel()
+ self._flush()
+ stderr = sys.stderr if stderr is None else stderr
+ state = self._delivered
+
+ if not self.restore_at_shutdown or not self.channel.do_restore:
+ return
+ if getattr(state, 'restored', None):
+ assert not state
+ return
+ try:
+ if state:
+ print('GOING TO RESTORE')
+ print(RESTORING_FMT.format(len(self._delivered)),
+ file=stderr)
+ unrestored = self.restore_unacked()
+
+ if unrestored:
+ errors, messages = list(zip(*unrestored))
+ print(RESTORE_PANIC_FMT.format(len(errors), errors),
+ file=stderr)
+ emergency_dump_state(messages, stderr=stderr)
+ finally:
+ state.restored = True
+
+ def restore_visible(self, *args, **kwargs):
+ """Restore any pending unackwnowledged messages for visibility_timeout
+ style implementations.
+
+ Note:
+ This is implementation optional, and currently only
+ used by the Redis transport.
+ """
+ pass
+
+
+class Message(base.Message):
+
+ def __init__(self, channel, payload, **kwargs):
+ self._raw = payload
+ properties = payload['properties']
+ body = payload.get('body')
+ if body:
+ body = channel.decode_body(body, properties.get('body_encoding'))
+ kwargs.update({
+ 'body': body,
+ 'delivery_tag': properties['delivery_tag'],
+ 'content_type': payload.get('content-type'),
+ 'content_encoding': payload.get('content-encoding'),
+ 'headers': payload.get('headers'),
+ 'properties': properties,
+ 'delivery_info': properties.get('delivery_info'),
+ 'postencode': 'utf-8',
+ })
+ super(Message, self).__init__(channel, **kwargs)
+
+ def serializable(self):
+ props = self.properties
+ body, _ = self.channel.encode_body(self.body,
+ props.get('body_encoding'))
+ headers = dict(self.headers)
+ # remove compression header
+ headers.pop('compression', None)
+ return {
+ 'body': body,
+ 'properties': props,
+ 'content-type': self.content_type,
+ 'content-encoding': self.content_encoding,
+ 'headers': headers,
+ }
+
+
+class AbstractChannel(object):
+ """This is an abstract class defining the channel methods
+ you'd usually want to implement in a virtual channel.
+
+ Note:
+ Do not subclass directly, but rather inherit
+ from :class:`Channel`.
+ """
+
+ def _get(self, queue, timeout=None):
+ """Get next message from `queue`."""
+ raise NotImplementedError('Virtual channels must implement _get')
+
+ def _put(self, queue, message):
+ """Put `message` onto `queue`."""
+ raise NotImplementedError('Virtual channels must implement _put')
+
+ def _purge(self, queue):
+ """Remove all messages from `queue`."""
+ raise NotImplementedError('Virtual channels must implement _purge')
+
+ def _size(self, queue):
+ """Return the number of messages in `queue` as an :class:`int`."""
+ return 0
+
+ def _delete(self, queue, *args, **kwargs):
+ """Delete `queue`.
+
+ Note:
+ This just purges the queue, if you need to do more you can
+ override this method.
+ """
+ self._purge(queue)
+
+ def _new_queue(self, queue, **kwargs):
+ """Create new queue.
+
+ Note:
+ Your transport can override this method if it needs
+ to do something whenever a new queue is declared.
+ """
+ pass
+
+ def _has_queue(self, queue, **kwargs):
+ """Verify that queue exists.
+
+ Returns:
+ bool: Should return :const:`True` if the queue exists
+ or :const:`False` otherwise.
+ """
+ return True
+
+ def _poll(self, cycle, timeout=None):
+ """Poll a list of queues for available messages."""
+ return cycle.get()
+
+
+class Channel(AbstractChannel, base.StdChannel):
+ """Virtual channel.
+
+ Arguments:
+ connection (ConnectionT): The transport instance this
+ channel is part of.
+ """
+ #: message class used.
+ Message = Message
+
+ #: QoS class used.
+ QoS = QoS
+
+ #: flag to restore unacked messages when channel
+ #: goes out of scope.
+ do_restore = True
+
+ #: mapping of exchange types and corresponding classes.
+ exchange_types = dict(STANDARD_EXCHANGE_TYPES)
+
+ #: flag set if the channel supports fanout exchanges.
+ supports_fanout = False
+
+ #: Binary <-> ASCII codecs.
+ codecs = {'base64': Base64()}
+
+ #: Default body encoding.
+ #: NOTE: ``transport_options['body_encoding']`` will override this value.
+ body_encoding = 'base64'
+
+ #: counter used to generate delivery tags for this channel.
+ _delivery_tags = count(1)
+
+ #: Optional queue where messages with no route is delivered.
+ #: Set by ``transport_options['deadletter_queue']``.
+ deadletter_queue = None
+
+ # List of options to transfer from :attr:`transport_options`.
+ from_transport_options = ('body_encoding', 'deadletter_queue')
+
+ # Priority defaults
+ default_priority = 0
+ min_priority = 0
+ max_priority = 9
+
+ def __init__(self, connection, **kwargs):
+ self.connection = connection
+ self._consumers = set()
+ self._cycle = None
+ self._tag_to_queue = {}
+ self._active_queues = []
+ self._qos = None
+ self.closed = False
+
+ # instantiate exchange types
+ self.exchange_types = dict(
+ (typ, cls(self)) for typ, cls in items(self.exchange_types)
+ )
+
+ try:
+ self.channel_id = self.connection._avail_channel_ids.pop()
+ except IndexError:
+ raise ResourceError(
+ 'No free channel ids, current={0}, channel_max={1}'.format(
+ len(self.connection.channels),
+ self.connection.channel_max), (20, 10),
+ )
+
+ topts = self.connection.client.transport_options
+ for opt_name in self.from_transport_options:
+ try:
+ setattr(self, opt_name, topts[opt_name])
+ except KeyError:
+ pass
+
+ def exchange_declare(self, exchange=None, type='direct', durable=False,
+ auto_delete=False, arguments=None,
+ nowait=False, passive=False):
+ """Declare exchange."""
+ type = type or 'direct'
+ exchange = exchange or 'amq.%s' % type
+ if passive:
+ if exchange not in self.state.exchanges:
+ raise ChannelError(
+ 'NOT_FOUND - no exchange {0!r} in vhost {1!r}'.format(
+ exchange, self.connection.client.virtual_host or '/'),
+ (50, 10), 'Channel.exchange_declare', '404',
+ )
+ return
+ try:
+ prev = self.state.exchanges[exchange]
+ if not self.typeof(exchange).equivalent(prev, exchange, type,
+ durable, auto_delete,
+ arguments):
+ raise NotEquivalentError(NOT_EQUIVALENT_FMT.format(
+ exchange, self.connection.client.virtual_host or '/'))
+ except KeyError:
+ self.state.exchanges[exchange] = {
+ 'type': type,
+ 'durable': durable,
+ 'auto_delete': auto_delete,
+ 'arguments': arguments or {},
+ 'table': [],
+ }
+
+ def exchange_delete(self, exchange, if_unused=False, nowait=False):
+ """Delete `exchange` and all its bindings."""
+ for rkey, _, queue in self.get_table(exchange):
+ self.queue_delete(queue, if_unused=True, if_empty=True)
+ self.state.exchanges.pop(exchange, None)
+
+ def queue_declare(self, queue=None, passive=False, **kwargs):
+ """Declare queue."""
+ queue = queue or 'amq.gen-%s' % uuid()
+ if passive and not self._has_queue(queue, **kwargs):
+ raise ChannelError(
+ 'NOT_FOUND - no queue {0!r} in vhost {1!r}'.format(
+ queue, self.connection.client.virtual_host or '/'),
+ (50, 10), 'Channel.queue_declare', '404',
+ )
+ else:
+ self._new_queue(queue, **kwargs)
+ return queue_declare_ok_t(queue, self._size(queue), 0)
+
+ def queue_delete(self, queue, if_unused=False, if_empty=False, **kwargs):
+ """Delete queue."""
+ if if_empty and self._size(queue):
+ return
+ for exchange, routing_key, args in self.state.queue_bindings(queue):
+ meta = self.typeof(exchange).prepare_bind(
+ queue, exchange, routing_key, args,
+ )
+ self._delete(queue, exchange, *meta, **kwargs)
+ self.state.queue_bindings_delete(queue)
+
+ def after_reply_message_received(self, queue):
+ self.queue_delete(queue)
+
+ def exchange_bind(self, destination, source='', routing_key='',
+ nowait=False, arguments=None):
+ raise NotImplementedError('transport does not support exchange_bind')
+
+ def exchange_unbind(self, destination, source='', routing_key='',
+ nowait=False, arguments=None):
+ raise NotImplementedError('transport does not support exchange_unbind')
+
+ def queue_bind(self, queue, exchange=None, routing_key='',
+ arguments=None, **kwargs):
+ """Bind `queue` to `exchange` with `routing key`."""
+ exchange = exchange or 'amq.direct'
+ if self.state.has_binding(queue, exchange, routing_key):
+ return
+ # Add binding:
+ self.state.binding_declare(queue, exchange, routing_key, arguments)
+ # Update exchange's routing table:
+ table = self.state.exchanges[exchange].setdefault('table', [])
+ meta = self.typeof(exchange).prepare_bind(
+ queue, exchange, routing_key, arguments,
+ )
+ table.append(meta)
+ if self.supports_fanout:
+ self._queue_bind(exchange, *meta)
+
+ def queue_unbind(self, queue, exchange=None, routing_key='',
+ arguments=None, **kwargs):
+ # Remove queue binding:
+ self.state.binding_delete(queue, exchange, routing_key)
+ try:
+ table = self.get_table(exchange)
+ except KeyError:
+ return
+ binding_meta = self.typeof(exchange).prepare_bind(
+ queue, exchange, routing_key, arguments,
+ )
+ # TODO: the complexity of this operation is O(number of bindings).
+ # Should be optimized. Modifying table in place.
+ table[:] = [meta for meta in table if meta != binding_meta]
+
+ def list_bindings(self):
+ return ((queue, exchange, rkey)
+ for exchange in self.state.exchanges
+ for rkey, pattern, queue in self.get_table(exchange))
+
+ def queue_purge(self, queue, **kwargs):
+ """Remove all ready messages from queue."""
+ return self._purge(queue)
+
+ def _next_delivery_tag(self):
+ return uuid()
+
+ def basic_publish(self, message, exchange, routing_key, **kwargs):
+ """Publish message."""
+ message['body'], body_encoding = self.encode_body(
+ message['body'], self.body_encoding,
+ )
+ props = message['properties']
+ props.update(
+ body_encoding=body_encoding,
+ delivery_tag=self._next_delivery_tag(),
+ )
+ props['delivery_info'].update(
+ exchange=exchange,
+ routing_key=routing_key,
+ )
+ if exchange:
+ return self.typeof(exchange).deliver(
+ message, exchange, routing_key, **kwargs
+ )
+ # anon exchange: routing_key is the destination queue
+ return self._put(routing_key, message, **kwargs)
+
+ def basic_consume(self, queue, no_ack, callback, consumer_tag, **kwargs):
+ """Consume from `queue`"""
+ self._tag_to_queue[consumer_tag] = queue
+ self._active_queues.append(queue)
+
+ def _callback(raw_message):
+ message = self.Message(self, raw_message)
+ if not no_ack:
+ self.qos.append(message, message.delivery_tag)
+ return callback(message)
+
+ self.connection._callbacks[queue] = _callback
+ self._consumers.add(consumer_tag)
+
+ self._reset_cycle()
+
+ def basic_cancel(self, consumer_tag):
+ """Cancel consumer by consumer tag."""
+ if consumer_tag in self._consumers:
+ self._consumers.remove(consumer_tag)
+ self._reset_cycle()
+ queue = self._tag_to_queue.pop(consumer_tag, None)
+ try:
+ self._active_queues.remove(queue)
+ except ValueError:
+ pass
+ self.connection._callbacks.pop(queue, None)
+
+ def basic_get(self, queue, no_ack=False, **kwargs):
+ """Get message by direct access (synchronous)."""
+ try:
+ message = self.Message(self, self._get(queue))
+ if not no_ack:
+ self.qos.append(message, message.delivery_tag)
+ return message
+ except Empty:
+ pass
+
+ def basic_ack(self, delivery_tag, multiple=False):
+ """Acknowledge message."""
+ self.qos.ack(delivery_tag)
+
+ def basic_recover(self, requeue=False):
+ """Recover unacked messages."""
+ if requeue:
+ return self.qos.restore_unacked()
+ raise NotImplementedError('Does not support recover(requeue=False)')
+
+ def basic_reject(self, delivery_tag, requeue=False):
+ """Reject message."""
+ self.qos.reject(delivery_tag, requeue=requeue)
+
+ def basic_qos(self, prefetch_size=0, prefetch_count=0,
+ apply_global=False):
+ """Change QoS settings for this channel.
+
+ Note:
+ Only `prefetch_count` is supported.
+ """
+ self.qos.prefetch_count = prefetch_count
+
+ def get_exchanges(self):
+ return list(self.state.exchanges)
+
+ def get_table(self, exchange):
+ """Get table of bindings for `exchange`."""
+ return self.state.exchanges[exchange]['table']
+
+ def typeof(self, exchange, default='direct'):
+ """Get the exchange type instance for `exchange`."""
+ try:
+ type = self.state.exchanges[exchange]['type']
+ except KeyError:
+ type = default
+ return self.exchange_types[type]
+
+ def _lookup(self, exchange, routing_key, default=None):
+ """Find all queues matching `routing_key` for the given `exchange`.
+
+ Returns:
+ str: queue name -- must return the string `default`
+ if no queues matched.
+ """
+ if default is None:
+ default = self.deadletter_queue
+ if not exchange: # anon exchange
+ return [routing_key or default]
+
+ try:
+ R = self.typeof(exchange).lookup(
+ self.get_table(exchange),
+ exchange, routing_key, default,
+ )
+ except KeyError:
+ R = []
+
+ if not R and default is not None:
+ warnings.warn(UndeliverableWarning(UNDELIVERABLE_FMT.format(
+ exchange=exchange, routing_key=routing_key)),
+ )
+ self._new_queue(default)
+ R = [default]
+ return R
+
+ def _restore(self, message):
+ """Redeliver message to its original destination."""
+ delivery_info = message.delivery_info
+ message = message.serializable()
+ message['redelivered'] = True
+ for queue in self._lookup(
+ delivery_info['exchange'], delivery_info['routing_key']):
+ self._put(queue, message)
+
+ def _restore_at_beginning(self, message):
+ return self._restore(message)
+
+ def drain_events(self, timeout=None):
+ if self._consumers and self.qos.can_consume():
+ if hasattr(self, '_get_many'):
+ return self._get_many(self._active_queues, timeout=timeout)
+ return self._poll(self.cycle, timeout=timeout)
+ raise Empty()
+
+ def message_to_python(self, raw_message):
+ """Convert raw message to :class:`Message` instance."""
+ if not isinstance(raw_message, self.Message):
+ return self.Message(self, payload=raw_message)
+ return raw_message
+
+ def prepare_message(self, body, priority=None, content_type=None,
+ content_encoding=None, headers=None, properties=None):
+ """Prepare message data."""
+ properties = properties or {}
+ properties.setdefault('delivery_info', {})
+ properties.setdefault('priority', priority or self.default_priority)
+
+ return {'body': body,
+ 'content-encoding': content_encoding,
+ 'content-type': content_type,
+ 'headers': headers or {},
+ 'properties': properties or {}}
+
+ def flow(self, active=True):
+ """Enable/disable message flow.
+
+ Raises:
+ NotImplementedError: as flow
+ is not implemented by the base virtual implementation.
+ """
+ raise NotImplementedError('virtual channels do not support flow.')
+
+ def close(self):
+ """Close channel, cancel all consumers, and requeue unacked
+ messages."""
+ if not self.closed:
+ self.closed = True
+ for consumer in list(self._consumers):
+ self.basic_cancel(consumer)
+ if self._qos:
+ self._qos.restore_unacked_once()
+ if self._cycle is not None:
+ self._cycle.close()
+ self._cycle = None
+ if self.connection is not None:
+ self.connection.close_channel(self)
+ self.exchange_types = None
+
+ def encode_body(self, body, encoding=None):
+ if encoding:
+ return self.codecs.get(encoding).encode(body), encoding
+ return body, encoding
+
+ def decode_body(self, body, encoding=None):
+ if encoding:
+ return self.codecs.get(encoding).decode(body)
+ return body
+
+ def _reset_cycle(self):
+ self._cycle = FairCycle(self._get, self._active_queues, Empty)
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, *exc_info):
+ self.close()
+
+ @property
+ def state(self):
+ """Broker state containing exchanges and bindings."""
+ return self.connection.state
+
+ @property
+ def qos(self):
+ """:class:`QoS` manager for this channel."""
+ if self._qos is None:
+ self._qos = self.QoS(self)
+ return self._qos
+
+ @property
+ def cycle(self):
+ if self._cycle is None:
+ self._reset_cycle()
+ return self._cycle
+
+ def _get_message_priority(self, message, reverse=False):
+ """Get priority from message and limit the value within a
+ boundary of 0 to 9.
+
+ Note:
+ Higher value has more priority.
+ """
+ try:
+ priority = max(
+ min(int(message['properties']['priority']),
+ self.max_priority),
+ self.min_priority,
+ )
+ except (TypeError, ValueError, KeyError):
+ priority = self.default_priority
+
+ return (self.max_priority - priority) if reverse else priority
+
+
+class Management(base.Management):
+
+ def __init__(self, transport):
+ super(Management, self).__init__(transport)
+ self.channel = transport.client.channel()
+
+ def get_bindings(self):
+ return [dict(destination=q, source=e, routing_key=r)
+ for q, e, r in self.channel.list_bindings()]
+
+ def close(self):
+ self.channel.close()
+
+
+class Transport(base.Transport):
+ """Virtual transport.
+
+ Arguments:
+ client (kombu.Connection): The client this is a transport for.
+ """
+ Channel = Channel
+ Cycle = FairCycle
+ Management = Management
+
+ #: :class:`BrokerState` containing declared exchanges and
+ #: bindings (set by constructor).
+ state = BrokerState()
+
+ #: :class:`~kombu.utils.scheduling.FairCycle` instance
+ #: used to fairly drain events from channels (set by constructor).
+ cycle = None
+
+ #: port number used when no port is specified.
+ default_port = None
+
+ #: active channels.
+ channels = None
+
+ #: queue/callback map.
+ _callbacks = None
+
+ #: Time to sleep between unsuccessful polls.
+ polling_interval = 1.0
+
+ #: Max number of channels
+ channel_max = 65535
+
+ implements = base.Transport.implements.extend(
+ async=False,
+ exchange_type=frozenset(['direct', 'topic']),
+ heartbeats=False,
+ )
+
+ def __init__(self, client, **kwargs):
+ self.client = client
+ self.channels = []
+ self._avail_channels = []
+ self._callbacks = {}
+ self.cycle = self.Cycle(self._drain_channel, self.channels, Empty)
+ polling_interval = client.transport_options.get('polling_interval')
+ if polling_interval is not None:
+ self.polling_interval = polling_interval
+ self._avail_channel_ids = array(
+ ARRAY_TYPE_H, range(self.channel_max, 0, -1),
+ )
+
+ def create_channel(self, connection):
+ try:
+ return self._avail_channels.pop()
+ except IndexError:
+ channel = self.Channel(connection)
+ self.channels.append(channel)
+ return channel
+
+ def close_channel(self, channel):
+ try:
+ self._avail_channel_ids.append(channel.channel_id)
+ try:
+ self.channels.remove(channel)
+ except ValueError:
+ pass
+ finally:
+ channel.connection = None
+
+ def establish_connection(self):
+ # creates channel to verify connection.
+ # this channel is then used as the next requested channel.
+ # (returned by ``create_channel``).
+ self._avail_channels.append(self.create_channel(self))
+ return self # for drain events
+
+ def close_connection(self, connection):
+ self.cycle.close()
+ for l in self._avail_channels, self.channels:
+ while l:
+ try:
+ channel = l.pop()
+ except LookupError: # pragma: no cover
+ pass
+ else:
+ channel.close()
+
+ def drain_events(self, connection, timeout=None):
+ loop = 0
+ time_start = monotonic()
+ get = self.cycle.get
+ polling_interval = self.polling_interval
+ while 1:
+ try:
+ item, channel = get(timeout=timeout)
+ except Empty:
+ if timeout and monotonic() - time_start >= timeout:
+ raise socket.timeout()
+ loop += 1
+ if polling_interval is not None:
+ sleep(polling_interval)
+ else:
+ break
+ self._deliver(*item)
+
+ def _deliver(self, message, queue):
+ if not queue:
+ raise KeyError(
+ 'Received message without destination queue: {0}'.format(
+ message))
+ try:
+ callback = self._callbacks[queue]
+ except KeyError:
+ logger.warn(W_NO_CONSUMERS, queue)
+ self._reject_inbound_message(message)
+ else:
+ callback(message)
+
+ def _reject_inbound_message(self, raw_message):
+ for channel in self.channels:
+ if channel:
+ message = channel.Message(channel, raw_message)
+ channel.qos.append(message, message.delivery_tag)
+ channel.basic_reject(message.delivery_tag, requeue=True)
+ break
+
+ def on_message_ready(self, channel, message, queue):
+ if not queue or queue not in self._callbacks:
+ raise KeyError(
+ 'Message for queue {0!r} without consumers: {1}'.format(
+ queue, message))
+ self._callbacks[queue](message)
+
+ def _drain_channel(self, channel, timeout=None):
+ return channel.drain_events(timeout=timeout)
+
+ @property
+ def default_connection_params(self):
+ return {'port': self.default_port, 'hostname': 'localhost'}
diff --git a/requirements/test-ci.txt b/requirements/test-ci.txt
index 485e8a79..0684adfc 100644
--- a/requirements/test-ci.txt
+++ b/requirements/test-ci.txt
@@ -1,4 +1,4 @@
-coverage>=3.0
+pytest-cov
codecov
redis
PyYAML
diff --git a/requirements/test.txt b/requirements/test.txt
index cb284d73..b093eab7 100644
--- a/requirements/test.txt
+++ b/requirements/test.txt
@@ -1,2 +1,3 @@
pytz>dev
-case>=1.2.2
+case>=1.3.1
+pytest
diff --git a/setup.cfg b/setup.cfg
index 322bd056..de74a8f5 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,7 +1,6 @@
-[nosetests]
-verbosity = 1
-detailed-errors = 1
-where = kombu/tests
+[tool:pytest]
+testpaths = t/
+python_classes = test_*
[build_sphinx]
source-dir = docs/
diff --git a/setup.py b/setup.py
index d2d284a7..6458ef99 100644
--- a/setup.py
+++ b/setup.py
@@ -5,6 +5,9 @@ import re
import sys
import codecs
+import setuptools
+import setuptools.command.test
+
from distutils.command.install import INSTALL_SCHEMES
if sys.version_info < (2, 7):
@@ -44,7 +47,7 @@ finally:
meta_fh.close()
# --
-packages, data_files = [], []
+data_files = []
root_dir = os.path.dirname(__file__)
if root_dir != '':
os.chdir(root_dir)
@@ -71,9 +74,7 @@ for dirpath, dirnames, filenames in os.walk(src_dir):
if dirname.startswith('.'):
del dirnames[i]
for filename in filenames:
- if filename.endswith('.py'):
- packages.append('.'.join(fullsplit(dirpath)))
- else:
+ if not filename.endswith('.py'):
data_files.append(
[dirpath, [os.path.join(dirpath, f) for f in filenames]],
)
@@ -104,20 +105,46 @@ def reqs(*f):
def extras(*p):
return reqs('extras', *p)
+
+class pytest(setuptools.command.test.test):
+ user_options = [('pytest-args=', 'a', 'Arguments to pass to py.test')]
+
+ def initialize_options(self):
+ setuptools.command.test.test.initialize_options(self)
+ self.pytest_args = []
+
+ def run_tests(self):
+ import pytest
+ sys.exit(pytest.main(self.pytest_args))
+
setup(
name='kombu',
+ packages=setuptools.find_packages(exclude=['t', 't.*']),
version=meta['version'],
description=meta['doc'],
+ long_description=long_description,
author=meta['author'],
author_email=meta['contact'],
url=meta['homepage'],
platforms=['any'],
- packages=packages,
data_files=data_files,
zip_safe=False,
- test_suite='nose.collector',
+ cmdclass={'test': pytest},
install_requires=reqs('default.txt'),
tests_require=reqs('test.txt'),
+ extras_require={
+ 'msgpack': extras('msgpack.txt'),
+ 'yaml': extras('yaml.txt'),
+ 'redis': extras('redis.txt'),
+ 'mongodb': extras('mongodb.txt'),
+ 'sqs': extras('sqs.txt'),
+ 'zookeeper': extras('zookeeper.txt'),
+ 'librabbitmq': extras('librabbitmq.txt'),
+ 'pyro': extras('pyro.txt'),
+ 'slmq': extras('slmq.txt'),
+ 'qpid': extras('qpid.txt'),
+ 'consul': extras('consul.txt'),
+ },
classifiers=[
'Development Status :: 5 - Production/Stable',
'License :: OSI Approved :: BSD License',
@@ -137,18 +164,4 @@ setup(
'Topic :: System :: Networking',
'Topic :: Software Development :: Libraries :: Python Modules',
],
- long_description=long_description,
- extras_require={
- 'msgpack': extras('msgpack.txt'),
- 'yaml': extras('yaml.txt'),
- 'redis': extras('redis.txt'),
- 'mongodb': extras('mongodb.txt'),
- 'sqs': extras('sqs.txt'),
- 'zookeeper': extras('zookeeper.txt'),
- 'librabbitmq': extras('librabbitmq.txt'),
- 'pyro': extras('pyro.txt'),
- 'slmq': extras('slmq.txt'),
- 'qpid': extras('qpid.txt'),
- 'consul': extras('consul.txt'),
- },
)
diff --git a/kombu/tests/async/__init__.py b/t/__init__.py
index e69de29b..e69de29b 100644
--- a/kombu/tests/async/__init__.py
+++ b/t/__init__.py
diff --git a/t/conftest.py b/t/conftest.py
new file mode 100644
index 00000000..9e8e959f
--- /dev/null
+++ b/t/conftest.py
@@ -0,0 +1,98 @@
+from __future__ import absolute_import, unicode_literals
+
+import atexit
+import os
+import pytest
+import sys
+
+from kombu.exceptions import VersionMismatch
+
+
+@pytest.fixture(scope='session')
+def multiprocessing_workaround(request):
+ def fin():
+ # Workaround for multiprocessing bug where logging
+ # is attempted after global already collected at shutdown.
+ canceled = set()
+ try:
+ import multiprocessing.util
+ canceled.add(multiprocessing.util._exit_function)
+ except (AttributeError, ImportError):
+ pass
+
+ try:
+ atexit._exithandlers[:] = [
+ e for e in atexit._exithandlers if e[0] not in canceled
+ ]
+ except AttributeError: # pragma: no cover
+ pass # Py3 missing _exithandlers
+ request.addfinalizer(fin)
+
+
+@pytest.fixture(autouse=True)
+def zzzz_test_cases_calls_setup_teardown(request):
+ if request.instance:
+ # we set the .patching attribute for every test class.
+ setup = getattr(request.instance, 'setup', None)
+ # we also call .setup() and .teardown() after every test method.
+ teardown = getattr(request.instance, 'teardown', None)
+ setup and setup()
+ teardown and request.addfinalizer(teardown)
+
+
+@pytest.fixture(autouse=True)
+def test_cases_has_patching(request, patching):
+ if request.instance:
+ request.instance.patching = patching
+
+
+@pytest.fixture
+def hub(request):
+ from kombu.async import Hub, get_event_loop, set_event_loop
+ _prev_hub = get_event_loop()
+ hub = Hub()
+ set_event_loop(hub)
+
+ def fin():
+ if _prev_hub is not None:
+ set_event_loop(_prev_hub)
+ request.addfinalizer(fin)
+ return hub
+
+
+def find_distribution_modules(name=__name__, file=__file__):
+ current_dist_depth = len(name.split('.')) - 1
+ current_dist = os.path.join(os.path.dirname(file),
+ *([os.pardir] * current_dist_depth))
+ abs = os.path.abspath(current_dist)
+ dist_name = os.path.basename(abs)
+
+ for dirpath, dirnames, filenames in os.walk(abs):
+ package = (dist_name + dirpath[len(abs):]).replace('/', '.')
+ if '__init__.py' in filenames:
+ yield package
+ for filename in filenames:
+ if filename.endswith('.py') and filename != '__init__.py':
+ yield '.'.join([package, filename])[:-3]
+
+
+def import_all_modules(name=__name__, file=__file__, skip=[]):
+ for module in find_distribution_modules(name, file):
+ if module not in skip:
+ print('preimporting %r for coverage...' % (module,))
+ try:
+ __import__(module)
+ except (ImportError, VersionMismatch, AttributeError):
+ pass
+
+
+def is_in_coverage():
+ return (os.environ.get('COVER_ALL_MODULES') or
+ any('--cov' in arg for arg in sys.argv))
+
+
+@pytest.fixture(scope='session')
+def cover_all_modules():
+ # so coverage sees all our modules.
+ if is_in_coverage():
+ import_all_modules()
diff --git a/kombu/tests/mocks.py b/t/mocks.py
index dac2b761..ea58eb22 100644
--- a/kombu/tests/mocks.py
+++ b/t/mocks.py
@@ -2,10 +2,34 @@ from __future__ import absolute_import, unicode_literals
from itertools import count
+from case import ContextMock, Mock
+
from kombu.transport import base
from kombu.utils import json
+def PromiseMock(*args, **kwargs):
+ m = Mock(*args, **kwargs)
+
+ def on_throw(exc=None, *args, **kwargs):
+ if exc:
+ raise exc
+ raise
+ m.throw.side_effect = on_throw
+ m.set_error_state.side_effect = on_throw
+ m.throw1.side_effect = on_throw
+ return m
+
+
+class MockPool(object):
+
+ def __init__(self, value=None):
+ self.value = value or ContextMock()
+
+ def acquire(self, **kwargs):
+ return self.value
+
+
class Message(base.Message):
def __init__(self, *args, **kwargs):
diff --git a/t/unit/__init__.py b/t/unit/__init__.py
new file mode 100644
index 00000000..01e6d4f4
--- /dev/null
+++ b/t/unit/__init__.py
@@ -0,0 +1 @@
+from __future__ import absolute_import, unicode_literals
diff --git a/kombu/tests/async/aws/__init__.py b/t/unit/async/__init__.py
index e69de29b..e69de29b 100644
--- a/kombu/tests/async/aws/__init__.py
+++ b/t/unit/async/__init__.py
diff --git a/kombu/tests/async/aws/sqs/__init__.py b/t/unit/async/aws/__init__.py
index e69de29b..e69de29b 100644
--- a/kombu/tests/async/aws/sqs/__init__.py
+++ b/t/unit/async/aws/__init__.py
diff --git a/kombu/tests/async/aws/case.py b/t/unit/async/aws/case.py
index e7e1a20c..d0af4faf 100644
--- a/kombu/tests/async/aws/case.py
+++ b/t/unit/async/aws/case.py
@@ -1,11 +1,14 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, unicode_literals
-from kombu.tests.case import HubCase, skip
+import pytest
+
+from case import skip
@skip.if_pypy()
@skip.unless_module('boto')
@skip.unless_module('pycurl')
-class AWSCase(HubCase):
+@pytest.mark.usefixtures('hub')
+class AWSCase:
pass
diff --git a/kombu/tests/async/http/__init__.py b/t/unit/async/aws/sqs/__init__.py
index e69de29b..e69de29b 100644
--- a/kombu/tests/async/http/__init__.py
+++ b/t/unit/async/aws/sqs/__init__.py
diff --git a/kombu/tests/async/aws/sqs/test_connection.py b/t/unit/async/aws/sqs/test_connection.py
index 9a9dcdf0..ecc22623 100644
--- a/kombu/tests/async/aws/sqs/test_connection.py
+++ b/t/unit/async/aws/sqs/test_connection.py
@@ -1,6 +1,10 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, unicode_literals
+import pytest
+
+from case import Mock
+
from kombu.async.aws.sqs.connection import (
AsyncSQSConnection, Attributes, BatchResults,
)
@@ -8,8 +12,9 @@ from kombu.async.aws.sqs.message import AsyncMessage
from kombu.async.aws.sqs.queue import AsyncQueue
from kombu.utils.uuid import uuid
-from kombu.tests.async.aws.case import AWSCase
-from kombu.tests.case import PromiseMock, Mock
+from t.mocks import PromiseMock
+
+from ..case import AWSCase
class test_AsyncSQSConnection(AWSCase):
@@ -25,16 +30,14 @@ class test_AsyncSQSConnection(AWSCase):
from kombu.async.aws.sqs import connection
prev, connection.boto = connection.boto, None
try:
- with self.assertRaises(ImportError):
+ with pytest.raises(ImportError):
AsyncSQSConnection('ak', 'sk', http_client=Mock())
finally:
connection.boto = prev
def test_default_region(self):
- self.assertTrue(self.x.region)
- self.assertTrue(issubclass(
- self.x.region.connection_cls, AsyncSQSConnection,
- ))
+ assert self.x.region
+ assert issubclass(self.x.region.connection_cls, AsyncSQSConnection)
def test_create_queue(self):
self.x.create_queue('foo', callback=self.callback)
diff --git a/kombu/tests/async/aws/sqs/test_message.py b/t/unit/async/aws/sqs/test_message.py
index 0f1a033b..44a0ac32 100644
--- a/kombu/tests/async/aws/sqs/test_message.py
+++ b/t/unit/async/aws/sqs/test_message.py
@@ -1,12 +1,15 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, unicode_literals
-from kombu.async.aws.sqs.message import AsyncMessage
+from case import Mock
-from kombu.tests.async.aws.case import AWSCase
-from kombu.tests.case import PromiseMock, Mock
+from kombu.async.aws.sqs.message import AsyncMessage
from kombu.utils.uuid import uuid
+from t.mocks import PromiseMock
+
+from ..case import AWSCase
+
class test_AsyncMessage(AWSCase):
@@ -17,20 +20,18 @@ class test_AsyncMessage(AWSCase):
self.x.receipt_handle = uuid()
def test_delete(self):
- self.assertTrue(self.x.delete(callback=self.callback))
+ assert self.x.delete(callback=self.callback)
self.x.queue.delete_message.assert_called_with(
self.x, self.callback,
)
self.x.queue = None
- self.assertIsNone(self.x.delete(callback=self.callback))
+ assert self.x.delete(callback=self.callback) is None
def test_change_visibility(self):
- self.assertTrue(self.x.change_visibility(303, callback=self.callback))
+ assert self.x.change_visibility(303, callback=self.callback)
self.x.queue.connection.change_message_visibility.assert_called_with(
self.x.queue, self.x.receipt_handle, 303, self.callback,
)
self.x.queue = None
- self.assertIsNone(self.x.change_visibility(
- 303, callback=self.callback,
- ))
+ assert self.x.change_visibility(303, callback=self.callback) is None
diff --git a/kombu/tests/async/aws/sqs/test_queue.py b/t/unit/async/aws/sqs/test_queue.py
index d34eb2e8..635016a6 100644
--- a/kombu/tests/async/aws/sqs/test_queue.py
+++ b/t/unit/async/aws/sqs/test_queue.py
@@ -1,11 +1,16 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, unicode_literals
+import pytest
+
+from case import Mock
+
from kombu.async.aws.sqs.message import AsyncMessage
from kombu.async.aws.sqs.queue import AsyncQueue
-from kombu.tests.async.aws.case import AWSCase
-from kombu.tests.case import PromiseMock, Mock
+from t.mocks import PromiseMock
+
+from ..case import AWSCase
class test_AsyncQueue(AWSCase):
@@ -16,7 +21,7 @@ class test_AsyncQueue(AWSCase):
self.callback = PromiseMock(name='callback')
def test_message_class(self):
- self.assertTrue(issubclass(self.x.message_class, AsyncMessage))
+ assert issubclass(self.x.message_class, AsyncMessage)
def test_get_attributes(self):
self.x.get_attributes(attributes='QueueSize', callback=self.callback)
@@ -50,10 +55,10 @@ class test_AsyncQueue(AWSCase):
)
on_ready(808)
self.callback.assert_called_with(808)
- self.assertEqual(self.x.visibility_timeout, 808)
+ assert self.x.visibility_timeout == 808
on_ready(None)
- self.assertEqual(self.x.visibility_timeout, 808)
+ assert self.x.visibility_timeout == 808
def test_add_permission(self):
self.x.add_permission(
@@ -101,8 +106,8 @@ class test_AsyncQueue(AWSCase):
new_message = self.MockMessage('id2', 'digest2')
on_ready(new_message)
- self.assertEqual(message.id, 'id2')
- self.assertEqual(message.md5, 'digest2')
+ assert message.id == 'id2'
+ assert message.md5 == 'digest2'
def test_write_batch(self):
messages = [('id1', 'A', 0), ('id2', 'B', 303)]
@@ -158,45 +163,45 @@ class test_AsyncQueue(AWSCase):
self.callback.assert_called_with(909)
def test_interface__count_slow(self):
- with self.assertRaises(NotImplementedError):
+ with pytest.raises(NotImplementedError):
self.x.count_slow()
def test_interface__dump(self):
- with self.assertRaises(NotImplementedError):
+ with pytest.raises(NotImplementedError):
self.x.dump()
def test_interface__save_to_file(self):
- with self.assertRaises(NotImplementedError):
+ with pytest.raises(NotImplementedError):
self.x.save_to_file()
def test_interface__save_to_filename(self):
- with self.assertRaises(NotImplementedError):
+ with pytest.raises(NotImplementedError):
self.x.save_to_filename()
def test_interface__save(self):
- with self.assertRaises(NotImplementedError):
+ with pytest.raises(NotImplementedError):
self.x.save()
def test_interface__save_to_s3(self):
- with self.assertRaises(NotImplementedError):
+ with pytest.raises(NotImplementedError):
self.x.save_to_s3()
def test_interface__load_from_s3(self):
- with self.assertRaises(NotImplementedError):
+ with pytest.raises(NotImplementedError):
self.x.load_from_s3()
def test_interface__load_from_file(self):
- with self.assertRaises(NotImplementedError):
+ with pytest.raises(NotImplementedError):
self.x.load_from_file()
def test_interface__load_from_filename(self):
- with self.assertRaises(NotImplementedError):
+ with pytest.raises(NotImplementedError):
self.x.load_from_filename()
def test_interface__load(self):
- with self.assertRaises(NotImplementedError):
+ with pytest.raises(NotImplementedError):
self.x.load()
def test_interface__clear(self):
- with self.assertRaises(NotImplementedError):
+ with pytest.raises(NotImplementedError):
self.x.clear()
diff --git a/kombu/tests/async/aws/sqs/test_sqs.py b/t/unit/async/aws/sqs/test_sqs.py
index 433ffdad..ea58596d 100644
--- a/kombu/tests/async/aws/sqs/test_sqs.py
+++ b/t/unit/async/aws/sqs/test_sqs.py
@@ -1,23 +1,26 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, unicode_literals
+import pytest
+
+from case import Mock, patch
+
from kombu.async.aws.sqs import regions, connect_to_region
from kombu.async.aws.sqs.connection import AsyncSQSConnection
-from kombu.tests.async.aws.case import AWSCase
-from kombu.tests.case import Mock, patch, set_module_symbol
+from ..case import AWSCase
class test_connect_to_region(AWSCase):
- def test_when_no_boto_installed(self):
- with set_module_symbol('kombu.async.aws.sqs', 'boto', None):
- with self.assertRaises(ImportError):
- regions()
+ def test_when_no_boto_installed(self, patching):
+ patching('kombu.async.aws.sqs.boto', None)
+ with pytest.raises(ImportError):
+ regions()
def test_using_async_connection(self):
for region in regions():
- self.assertIs(region.connection_cls, AsyncSQSConnection)
+ assert region.connection_cls is AsyncSQSConnection
def test_connect_to_region(self):
with patch('kombu.async.aws.sqs.regions') as regions:
@@ -25,7 +28,7 @@ class test_connect_to_region(AWSCase):
region.name = 'us-west-1'
regions.return_value = [region]
conn = connect_to_region('us-west-1', kw=3.33)
- self.assertIs(conn, region.connect.return_value)
+ assert conn is region.connect.return_value
region.connect.assert_called_with(kw=3.33)
- self.assertIsNone(connect_to_region('foo'))
+ assert connect_to_region('foo') is None
diff --git a/kombu/tests/async/aws/test_aws.py b/t/unit/async/aws/test_aws.py
index b16b6bb8..f5ed9aef 100644
--- a/kombu/tests/async/aws/test_aws.py
+++ b/t/unit/async/aws/test_aws.py
@@ -1,8 +1,9 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, unicode_literals
+from case import Mock
+
from kombu.async.aws import connect_sqs
-from kombu.tests.case import Mock
from .case import AWSCase
@@ -11,5 +12,5 @@ class test_connect_sqs(AWSCase):
def test_connection(self):
x = connect_sqs('AAKI', 'ASAK', http_client=Mock())
- self.assertTrue(x)
- self.assertTrue(x.connection)
+ assert x
+ assert x.connection
diff --git a/kombu/tests/async/aws/test_connection.py b/t/unit/async/aws/test_connection.py
index 8304af05..5de76aa5 100644
--- a/kombu/tests/async/aws/test_connection.py
+++ b/t/unit/async/aws/test_connection.py
@@ -1,8 +1,11 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, unicode_literals
+import pytest
+
from contextlib import contextmanager
+from case import Mock, patch
from vine.abstract import Thenable
from kombu.exceptions import HttpError
@@ -18,10 +21,10 @@ from kombu.async.aws.connection import (
AsyncAWSQueryConnection,
)
-from kombu.tests.case import PromiseMock, Mock, patch, set_module_symbol
-
from .case import AWSCase
+from t.mocks import PromiseMock
+
# Not currently working
VALIDATES_CERT = False
@@ -39,56 +42,56 @@ class test_AsyncHTTPConnection(AWSCase):
def test_AsyncHTTPSConnection(self):
x = AsyncHTTPSConnection('aws.vandelay.com')
- self.assertEqual(x.scheme, 'https')
+ assert x.scheme == 'https'
def test_http_client(self):
x = AsyncHTTPConnection('aws.vandelay.com')
- self.assertIs(x.http_client, http.get_client())
+ assert x.http_client is http.get_client()
client = Mock(name='http_client')
y = AsyncHTTPConnection('aws.vandelay.com', http_client=client)
- self.assertIs(y.http_client, client)
+ assert y.http_client is client
def test_args(self):
x = AsyncHTTPConnection(
'aws.vandelay.com', 8083, strict=True, timeout=33.3,
)
- self.assertEqual(x.host, 'aws.vandelay.com')
- self.assertEqual(x.port, 8083)
- self.assertTrue(x.strict)
- self.assertEqual(x.timeout, 33.3)
- self.assertEqual(x.scheme, 'http')
+ assert x.host == 'aws.vandelay.com'
+ assert x.port == 8083
+ assert x.strict
+ assert x.timeout == 33.3
+ assert x.scheme == 'http'
def test_request(self):
x = AsyncHTTPConnection('aws.vandelay.com')
x.request('PUT', '/importer-exporter')
- self.assertEqual(x.path, '/importer-exporter')
- self.assertEqual(x.method, 'PUT')
+ assert x.path == '/importer-exporter'
+ assert x.method == 'PUT'
def test_request_with_body_buffer(self):
x = AsyncHTTPConnection('aws.vandelay.com')
body = Mock(name='body')
body.read.return_value = 'Vandelay Industries'
x.request('PUT', '/importer-exporter', body)
- self.assertEqual(x.method, 'PUT')
- self.assertEqual(x.path, '/importer-exporter')
- self.assertEqual(x.body, 'Vandelay Industries')
+ assert x.method == 'PUT'
+ assert x.path == '/importer-exporter'
+ assert x.body == 'Vandelay Industries'
body.read.assert_called_with()
def test_request_with_body_text(self):
x = AsyncHTTPConnection('aws.vandelay.com')
x.request('PUT', '/importer-exporter', 'Vandelay Industries')
- self.assertEqual(x.method, 'PUT')
- self.assertEqual(x.path, '/importer-exporter')
- self.assertEqual(x.body, 'Vandelay Industries')
+ assert x.method == 'PUT'
+ assert x.path == '/importer-exporter'
+ assert x.body == 'Vandelay Industries'
def test_request_with_headers(self):
x = AsyncHTTPConnection('aws.vandelay.com')
headers = {'Proxy': 'proxy.vandelay.com'}
x.request('PUT', '/importer-exporter', None, headers)
- self.assertIn('Proxy', dict(x.headers))
- self.assertEqual(dict(x.headers)['Proxy'], 'proxy.vandelay.com')
+ assert 'Proxy' in dict(x.headers)
+ assert dict(x.headers)['Proxy'] == 'proxy.vandelay.com'
- def assertRequestCreatedWith(self, url, conn):
+ def assert_request_created_with(self, url, conn):
conn.Request.assert_called_with(
url, method=conn.method,
headers=http.Headers(conn.headers), body=conn.body,
@@ -100,18 +103,18 @@ class test_AsyncHTTPConnection(AWSCase):
x = AsyncHTTPSConnection('aws.vandelay.com')
x.Request = Mock(name='Request')
x.getrequest()
- self.assertRequestCreatedWith('https://aws.vandelay.com/', x)
+ self.assert_request_created_with('https://aws.vandelay.com/', x)
def test_getrequest_nondefault_port(self):
x = AsyncHTTPConnection('aws.vandelay.com', port=8080)
x.Request = Mock(name='Request')
x.getrequest()
- self.assertRequestCreatedWith('http://aws.vandelay.com:8080/', x)
+ self.assert_request_created_with('http://aws.vandelay.com:8080/', x)
y = AsyncHTTPSConnection('aws.vandelay.com', port=8443)
y.Request = Mock(name='Request')
y.getrequest()
- self.assertRequestCreatedWith('https://aws.vandelay.com:8443/', y)
+ self.assert_request_created_with('https://aws.vandelay.com:8443/', y)
def test_getresponse(self):
client = Mock(name='client')
@@ -120,8 +123,8 @@ class test_AsyncHTTPConnection(AWSCase):
x.Response = Mock(name='x.Response')
request = x.getresponse()
x.http_client.add_request.assert_called_with(request)
- self.assertIsInstance(request, Thenable)
- self.assertIsInstance(request.on_ready, Thenable)
+ assert isinstance(request, Thenable)
+ assert isinstance(request.on_ready, Thenable)
response = Mock(name='Response')
request.on_ready(response)
@@ -145,46 +148,46 @@ class test_AsyncHTTPConnection(AWSCase):
callback.assert_called()
wresponse = callback.call_args[0][0]
- self.assertEqual(wresponse.read(), 'The quick brown fox jumps')
- self.assertEqual(wresponse.status, 200)
- self.assertEqual(wresponse.getheader('X-Foo'), 'Hello')
- self.assertDictEqual(dict(wresponse.getheaders()), headers)
- self.assertTrue(wresponse.msg)
- self.assertTrue(wresponse.msg)
- self.assertTrue(repr(wresponse))
+ assert wresponse.read() == 'The quick brown fox jumps'
+ assert wresponse.status == 200
+ assert wresponse.getheader('X-Foo') == 'Hello'
+ assert dict(wresponse.getheaders()) == headers
+ assert wresponse.msg
+ assert wresponse.msg
+ assert repr(wresponse)
def test_repr(self):
- self.assertTrue(repr(AsyncHTTPConnection('aws.vandelay.com')))
+ assert repr(AsyncHTTPConnection('aws.vandelay.com'))
def test_putrequest(self):
x = AsyncHTTPConnection('aws.vandelay.com')
x.putrequest('UPLOAD', '/new')
- self.assertEqual(x.method, 'UPLOAD')
- self.assertEqual(x.path, '/new')
+ assert x.method == 'UPLOAD'
+ assert x.path == '/new'
def test_putheader(self):
x = AsyncHTTPConnection('aws.vandelay.com')
x.putheader('X-Foo', 'bar')
- self.assertListEqual(x.headers, [('X-Foo', 'bar')])
+ assert x.headers == [('X-Foo', 'bar')]
x.putheader('X-Bar', 'baz')
- self.assertListEqual(x.headers, [
+ assert x.headers == [
('X-Foo', 'bar'),
('X-Bar', 'baz'),
- ])
+ ]
def test_send(self):
x = AsyncHTTPConnection('aws.vandelay.com')
x.send('foo')
- self.assertEqual(x.body, 'foo')
+ assert x.body == 'foo'
x.send('bar')
- self.assertEqual(x.body, 'foobar')
+ assert x.body == 'foobar'
def test_interface(self):
x = AsyncHTTPConnection('aws.vandelay.com')
- self.assertIsNone(x.set_debuglevel(3))
- self.assertIsNone(x.connect())
- self.assertIsNone(x.close())
- self.assertIsNone(x.endheaders())
+ assert x.set_debuglevel(3) is None
+ assert x.connect() is None
+ assert x.close() is None
+ assert x.endheaders() is None
class test_AsyncHTTPResponse(AWSCase):
@@ -193,41 +196,41 @@ class test_AsyncHTTPResponse(AWSCase):
r = Mock(name='response')
r.error = HttpError(404, 'NotFound')
x = AsyncHTTPResponse(r)
- self.assertEqual(x.reason, 'NotFound')
+ assert x.reason == 'NotFound'
r.error = None
- self.assertFalse(x.reason)
+ assert not x.reason
class test_AsyncConnection(AWSCase):
- def test_when_boto_missing(self):
- with set_module_symbol('kombu.async.aws.connection', 'boto', None):
- with self.assertRaises(ImportError):
- AsyncConnection(Mock(name='client'))
+ def test_when_boto_missing(self, patching):
+ patching('kombu.async.aws.connection.boto', None)
+ with pytest.raises(ImportError):
+ AsyncConnection(Mock(name='client'))
def test_client(self):
x = AsyncConnection()
- self.assertIs(x._httpclient, http.get_client())
+ assert x._httpclient is http.get_client()
client = Mock(name='client')
y = AsyncConnection(http_client=client)
- self.assertIs(y._httpclient, client)
+ assert y._httpclient is client
def test_get_http_connection(self):
x = AsyncConnection(client=Mock(name='client'))
- self.assertIsInstance(
+ assert isinstance(
x.get_http_connection('aws.vandelay.com', 80, False),
AsyncHTTPConnection,
)
- self.assertIsInstance(
+ assert isinstance(
x.get_http_connection('aws.vandelay.com', 443, True),
AsyncHTTPSConnection,
)
conn = x.get_http_connection('aws.vandelay.com', 80, False)
- self.assertIs(conn.http_client, x._httpclient)
- self.assertEqual(conn.host, 'aws.vandelay.com')
- self.assertEqual(conn.port, 80)
+ assert conn.http_client is x._httpclient
+ assert conn.host == 'aws.vandelay.com'
+ assert conn.port == 80
class test_AsyncAWSAuthConnection(AWSCase):
@@ -239,7 +242,7 @@ class test_AsyncAWSAuthConnection(AWSCase):
Conn = x.get_http_connection = Mock(name='get_http_connection')
callback = PromiseMock(name='callback')
ret = x.make_request('GET', '/foo', callback=callback)
- self.assertIs(ret, callback)
+ assert ret is callback
Conn.return_value.request.assert_called()
Conn.return_value.getresponse.assert_called_with(
callback=callback,
@@ -261,9 +264,8 @@ class test_AsyncAWSAuthConnection(AWSCase):
)
no_callback_ret = x._mexe(request)
- self.assertIsInstance(
- no_callback_ret, Thenable, '_mexe always returns promise',
- )
+ # _mexe always returns promise
+ assert isinstance(no_callback_ret, Thenable)
@patch('boto.log', create=True)
def test_mexe__with_sender(self, _):
@@ -296,11 +298,11 @@ class test_AsyncAWSQueryConnection(AWSCase):
)
self.x._mexe.assert_called()
request = self.x._mexe.call_args[0][0]
- self.assertEqual(request.params['Action'], 'action')
- self.assertEqual(request.params['Version'], self.x.APIVersion)
+ assert request.params['Action'] == 'action'
+ assert request.params['Version'] == self.x.APIVersion
ret = _mexe(request, callback=callback)
- self.assertIs(ret, callback)
+ assert ret is callback
Conn.return_value.request.assert_called()
Conn.return_value.getresponse.assert_called_with(
callback=callback,
@@ -316,8 +318,8 @@ class test_AsyncAWSQueryConnection(AWSCase):
)
self.x._mexe.assert_called()
request = self.x._mexe.call_args[0][0]
- self.assertNotIn('Action', request.params)
- self.assertEqual(request.params['Version'], self.x.APIVersion)
+ assert 'Action' not in request.params
+ assert request.params['Version'] == self.x.APIVersion
@contextmanager
def mock_sax_parse(self, parser):
@@ -364,7 +366,7 @@ class test_AsyncAWSQueryConnection(AWSCase):
self.x.get_list('action', {'p': 3.3}, ['m'], callback=callback)
on_ready = self.assert_make_request_called()
- with self.assertRaises(self.x.ResponseError):
+ with pytest.raises(self.x.ResponseError):
on_ready(self.Response(404, 'Not found'))
def test_get_object(self):
@@ -388,15 +390,15 @@ class test_AsyncAWSQueryConnection(AWSCase):
callback.assert_called()
result = callback.call_args[0][0]
- self.assertEqual(result.value, 42)
- self.assertTrue(result.parent)
+ assert result.value == 42
+ assert result.parent
def test_get_object_error(self):
with self.mock_make_request() as callback:
self.x.get_object('action', {'p': 3.3}, object, callback=callback)
on_ready = self.assert_make_request_called()
- with self.assertRaises(self.x.ResponseError):
+ with pytest.raises(self.x.ResponseError):
on_ready(self.Response(404, 'Not found'))
def test_get_status(self):
@@ -422,7 +424,7 @@ class test_AsyncAWSQueryConnection(AWSCase):
self.x.get_status('action', {'p': 3.3}, callback=callback)
on_ready = self.assert_make_request_called()
- with self.assertRaises(self.x.ResponseError):
+ with pytest.raises(self.x.ResponseError):
on_ready(self.Response(404, 'Not found'))
def test_get_status_error_empty_body(self):
@@ -430,5 +432,5 @@ class test_AsyncAWSQueryConnection(AWSCase):
self.x.get_status('action', {'p': 3.3}, callback=callback)
on_ready = self.assert_make_request_called()
- with self.assertRaises(self.x.ResponseError):
+ with pytest.raises(self.x.ResponseError):
on_ready(self.Response(200, ''))
diff --git a/kombu/tests/transport/__init__.py b/t/unit/async/http/__init__.py
index e69de29b..e69de29b 100644
--- a/kombu/tests/transport/__init__.py
+++ b/t/unit/async/http/__init__.py
diff --git a/kombu/tests/async/http/test_curl.py b/t/unit/async/http/test_curl.py
index 7bfc312e..70f38cfa 100644
--- a/kombu/tests/async/http/test_curl.py
+++ b/t/unit/async/http/test_curl.py
@@ -1,41 +1,40 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, unicode_literals
-from kombu.async.http.curl import READ, WRITE, CurlClient
+import pytest
+
+from case import Mock, call, patch, skip
-from kombu.tests.case import (
- HubCase, Mock, call, patch, set_module_symbol, skip,
-)
+from kombu.async.http.curl import READ, WRITE, CurlClient
@skip.if_pypy()
@skip.unless_module('pycurl')
-class test_CurlClient(HubCase):
+@pytest.mark.usefixtures('hub')
+class test_CurlClient:
class Client(CurlClient):
Curl = Mock(name='Curl')
- def test_when_pycurl_missing(self):
- with set_module_symbol('kombu.async.http.curl', 'pycurl', None):
- with self.assertRaises(ImportError):
- self.Client()
+ def test_when_pycurl_missing(self, patching):
+ patching('kombu.async.http.curl.pycurl', None)
+ with pytest.raises(ImportError):
+ self.Client()
def test_max_clients_set(self):
x = self.Client(max_clients=303)
- self.assertEqual(x.max_clients, 303)
+ assert x.max_clients == 303
def test_init(self):
with patch('kombu.async.http.curl.pycurl') as _pycurl:
x = self.Client()
- self.assertIsNotNone(x._multi)
- self.assertIsNotNone(x._pending)
- self.assertIsNotNone(x._free_list)
- self.assertIsNotNone(x._fds)
- self.assertEqual(
- x._socket_action, x._multi.socket_action,
- )
- self.assertEqual(len(x._curls), x.max_clients)
- self.assertTrue(x._timeout_check_tref)
+ assert x._multi is not None
+ assert x._pending is not None
+ assert x._free_list is not None
+ assert x._fds is not None
+ assert x._socket_action == x._multi.socket_action
+ assert len(x._curls) == x.max_clients
+ assert x._timeout_check_tref
x._multi.setopt.assert_has_calls([
call(_pycurl.M_TIMERFUNCTION, x._set_timeout),
@@ -59,7 +58,7 @@ class test_CurlClient(HubCase):
x._set_timeout = Mock(name='_set_timeout')
request = Mock(name='request')
x.add_request(request)
- self.assertIn(request, x._pending)
+ assert request in x._pending
x._process_queue.assert_called_with()
x._set_timeout.assert_called_with(0)
@@ -73,7 +72,7 @@ class test_CurlClient(HubCase):
x._fds[fd] = fd
x._handle_socket(_pycurl.POLL_REMOVE, fd, x._multi, None, _pycurl)
hub.remove.assert_called_with(fd)
- self.assertNotIn(fd, x._fds)
+ assert fd not in x._fds
x._handle_socket(_pycurl.POLL_REMOVE, fd, x._multi, None, _pycurl)
# POLL_IN
@@ -83,20 +82,20 @@ class test_CurlClient(HubCase):
x._handle_socket(_pycurl.POLL_IN, fd, x._multi, None, _pycurl)
hub.remove.assert_has_calls([call(fd)])
hub.add_reader.assert_called_with(fd, x.on_readable, fd)
- self.assertEqual(x._fds[fd], READ)
+ assert x._fds[fd] == READ
# POLL_OUT
hub = x.hub = Mock(name='hub')
x._handle_socket(_pycurl.POLL_OUT, fd, x._multi, None, _pycurl)
hub.add_writer.assert_called_with(fd, x.on_writable, fd)
- self.assertEqual(x._fds[fd], WRITE)
+ assert x._fds[fd] == WRITE
# POLL_INOUT
hub = x.hub = Mock(name='hub')
x._handle_socket(_pycurl.POLL_INOUT, fd, x._multi, None, _pycurl)
hub.add_reader.assert_called_with(fd, x.on_readable, fd)
hub.add_writer.assert_called_with(fd, x.on_writable, fd)
- self.assertEqual(x._fds[fd], READ | WRITE)
+ assert x._fds[fd] == READ | WRITE
# UNKNOWN EVENT
hub = x.hub = Mock(name='hub')
diff --git a/kombu/tests/async/http/test_http.py b/t/unit/async/http/test_http.py
index 6177a05c..9e279142 100644
--- a/kombu/tests/async/http/test_http.py
+++ b/t/unit/async/http/test_http.py
@@ -1,38 +1,42 @@
from __future__ import absolute_import, unicode_literals
+import pytest
+
from io import BytesIO
from vine import promise
+from case import Mock, skip
+
from kombu.async import http
from kombu.async.http.base import BaseClient, normalize_header
from kombu.exceptions import HttpError
-from kombu.tests.case import HubCase, Mock, PromiseMock, skip
+from t.mocks import PromiseMock
-class test_Headers(HubCase):
+class test_Headers:
def test_normalize(self):
- self.assertEqual(normalize_header('accept-encoding'),
- 'Accept-Encoding')
+ assert normalize_header('accept-encoding') == 'Accept-Encoding'
-class test_Request(HubCase):
+@pytest.mark.usefixtures('hub')
+class test_Request:
def test_init(self):
x = http.Request('http://foo', method='POST')
- self.assertEqual(x.url, 'http://foo')
- self.assertEqual(x.method, 'POST')
+ assert x.url == 'http://foo'
+ assert x.method == 'POST'
x = http.Request('x', max_redirects=100)
- self.assertEqual(x.max_redirects, 100)
+ assert x.max_redirects == 100
- self.assertIsInstance(x.headers, http.Headers)
+ assert isinstance(x.headers, http.Headers)
h = http.Headers()
x = http.Request('x', headers=h)
- self.assertIs(x.headers, h)
- self.assertIsInstance(x.on_ready, promise)
+ assert x.headers is h
+ assert isinstance(x.on_ready, promise)
def test_then(self):
callback = PromiseMock(name='callback')
@@ -43,23 +47,24 @@ class test_Request(HubCase):
callback.assert_called_with(1)
-class test_Response(HubCase):
+@pytest.mark.usefixtures('hub')
+class test_Response:
def test_init(self):
req = http.Request('http://foo')
r = http.Response(req, 200)
- self.assertEqual(r.status, 'OK')
- self.assertEqual(r.effective_url, 'http://foo')
+ assert r.status == 'OK'
+ assert r.effective_url == 'http://foo'
r.raise_for_error()
def test_raise_for_error(self):
req = http.Request('http://foo')
r = http.Response(req, 404)
- self.assertEqual(r.status, 'Not Found')
- self.assertTrue(r.error)
+ assert r.status == 'Not Found'
+ assert r.error
- with self.assertRaises(HttpError):
+ with pytest.raises(HttpError):
r.raise_for_error()
def test_get_body(self):
@@ -68,21 +73,25 @@ class test_Response(HubCase):
req.buffer.write(b'hello')
rn = http.Response(req, 200, buffer=None)
- self.assertIsNone(rn.body)
+ assert rn.body is None
r = http.Response(req, 200, buffer=req.buffer)
- self.assertIsNone(r._body)
- self.assertEqual(r.body, b'hello')
- self.assertEqual(r._body, b'hello')
- self.assertEqual(r.body, b'hello')
+ assert r._body is None
+ assert r.body == b'hello'
+ assert r._body == b'hello'
+ assert r.body == b'hello'
+
+class test_BaseClient:
-class test_BaseClient(HubCase):
+ @pytest.fixture(autouse=True)
+ def setup_hub(self, hub):
+ self.hub = hub
def test_init(self):
c = BaseClient(Mock(name='hub'))
- self.assertTrue(c.hub)
- self.assertTrue(c._header_parser)
+ assert c.hub
+ assert c._header_parser
def test_perform(self):
c = BaseClient(Mock(name='hub'))
@@ -90,7 +99,7 @@ class test_BaseClient(HubCase):
c.perform('http://foo')
c.add_request.assert_called()
- self.assertIsInstance(c.add_request.call_args[0][0], http.Request)
+ assert isinstance(c.add_request.call_args[0][0], http.Request)
req = http.Request('http://bar')
c.perform(req)
@@ -98,7 +107,7 @@ class test_BaseClient(HubCase):
def test_add_request(self):
c = BaseClient(Mock(name='hub'))
- with self.assertRaises(NotImplementedError):
+ with pytest.raises(NotImplementedError):
c.add_request(Mock(name='request'))
def test_header_parser(self):
@@ -109,23 +118,21 @@ class test_BaseClient(HubCase):
c.on_header(headers, 'HTTP/1.1')
c.on_header(headers, 'x-foo-bar: 123')
c.on_header(headers, 'People: George Costanza')
- self.assertEqual(headers._prev_key, 'People')
+ assert headers._prev_key == 'People'
c.on_header(headers, ' Jerry Seinfeld')
c.on_header(headers, ' Elaine Benes')
c.on_header(headers, ' Cosmo Kramer')
- self.assertFalse(headers.complete)
+ assert not headers.complete
c.on_header(headers, '')
- self.assertTrue(headers.complete)
+ assert headers.complete
- with self.assertRaises(KeyError):
+ with pytest.raises(KeyError):
parser.throw(KeyError('foo'))
c.on_header(headers, '')
- self.assertEqual(headers['X-Foo-Bar'], '123')
- self.assertEqual(
- headers['People'],
- 'George Costanza Jerry Seinfeld Elaine Benes Cosmo Kramer',
- )
+ assert headers['X-Foo-Bar'] == '123'
+ assert (headers['People'] ==
+ 'George Costanza Jerry Seinfeld Elaine Benes Cosmo Kramer')
def test_close(self):
BaseClient(Mock(name='hub')).close()
@@ -140,11 +147,11 @@ class test_BaseClient(HubCase):
@skip.if_pypy()
@skip.unless_module('pycurl')
-class test_Client(HubCase):
+class test_Client:
- def test_get_client(self):
+ def test_get_client(self, hub):
client = http.get_client()
- self.assertIs(client.hub, self.hub)
- client2 = http.get_client(self.hub)
- self.assertIs(client2, client)
- self.assertIs(client2.hub, self.hub)
+ assert client.hub is hub
+ client2 = http.get_client(hub)
+ assert client2 is client
+ assert client2.hub is hub
diff --git a/kombu/tests/async/test_hub.py b/t/unit/async/test_hub.py
index 94cf6b04..961de8fe 100644
--- a/kombu/tests/async/test_hub.py
+++ b/t/unit/async/test_hub.py
@@ -1,7 +1,9 @@
from __future__ import absolute_import, unicode_literals
import errno
+import pytest
+from case import Mock, call, patch
from vine import promise
from kombu.async import hub as _hub
@@ -13,8 +15,6 @@ from kombu.async.hub import (
)
from kombu.async.semaphore import DummyLock, LaxBoundedSemaphore
-from kombu.tests.case import Case, Mock, call, patch
-
class File(object):
@@ -33,63 +33,60 @@ class File(object):
return hash(self.fd)
-class test_DummyLock(Case):
-
- def test_context(self):
- mutex = DummyLock()
- with mutex:
- pass
+def test_DummyLock():
+ with DummyLock():
+ pass
-class test_LaxBoundedSemaphore(Case):
+class test_LaxBoundedSemaphore:
def test_acquire_release(self):
x = LaxBoundedSemaphore(2)
c1 = Mock()
x.acquire(c1, 1)
- self.assertEqual(x.value, 1)
+ assert x.value == 1
c1.assert_called_with(1)
c2 = Mock()
x.acquire(c2, 2)
- self.assertEqual(x.value, 0)
+ assert x.value == 0
c2.assert_called_with(2)
c3 = Mock()
x.acquire(c3, 3)
- self.assertEqual(x.value, 0)
+ assert x.value == 0
c3.assert_not_called()
x.release()
- self.assertEqual(x.value, 0)
+ assert x.value == 0
x.release()
- self.assertEqual(x.value, 1)
+ assert x.value == 1
x.release()
- self.assertEqual(x.value, 2)
+ assert x.value == 2
c3.assert_called_with(3)
def test_repr(self):
- self.assertTrue(repr(LaxBoundedSemaphore(2)))
+ assert repr(LaxBoundedSemaphore(2))
def test_bounded(self):
x = LaxBoundedSemaphore(2)
for i in range(100):
x.release()
- self.assertEqual(x.value, 2)
+ assert x.value == 2
def test_grow_shrink(self):
x = LaxBoundedSemaphore(1)
- self.assertEqual(x.initial_value, 1)
+ assert x.initial_value == 1
cb1 = Mock()
x.acquire(cb1, 1)
cb1.assert_called_with(1)
- self.assertEqual(x.value, 0)
+ assert x.value == 0
cb2 = Mock()
x.acquire(cb2, 2)
cb2.assert_not_called()
- self.assertEqual(x.value, 0)
+ assert x.value == 0
cb3 = Mock()
x.acquire(cb3, 3)
@@ -98,39 +95,39 @@ class test_LaxBoundedSemaphore(Case):
x.grow(2)
cb2.assert_called_with(2)
cb3.assert_called_with(3)
- self.assertEqual(x.value, 2)
- self.assertEqual(x.initial_value, 3)
+ assert x.value == 2
+ assert x.initial_value == 3
- self.assertFalse(x._waiting)
+ assert not x._waiting
x.grow(3)
for i in range(x.initial_value):
- self.assertTrue(x.acquire(Mock()))
- self.assertFalse(x.acquire(Mock()))
+ assert x.acquire(Mock())
+ assert not x.acquire(Mock())
x.clear()
x.shrink(3)
for i in range(x.initial_value):
- self.assertTrue(x.acquire(Mock()))
- self.assertFalse(x.acquire(Mock()))
- self.assertEqual(x.value, 0)
+ assert x.acquire(Mock())
+ assert not x.acquire(Mock())
+ assert x.value == 0
for i in range(100):
x.release()
- self.assertEqual(x.value, x.initial_value)
+ assert x.value == x.initial_value
def test_clear(self):
x = LaxBoundedSemaphore(10)
for i in range(11):
x.acquire(Mock())
- self.assertTrue(x._waiting)
- self.assertEqual(x.value, 0)
+ assert x._waiting
+ assert x.value == 0
x.clear()
- self.assertFalse(x._waiting)
- self.assertEqual(x.value, x.initial_value)
+ assert not x._waiting
+ assert x.value == x.initial_value
-class test_Utils(Case):
+class test_Utils:
def setup(self):
self._prev_loop = get_event_loop()
@@ -140,23 +137,23 @@ class test_Utils(Case):
def test_get_set_event_loop(self):
set_event_loop(None)
- self.assertIsNone(_hub._current_loop)
- self.assertIsNone(get_event_loop())
+ assert _hub._current_loop is None
+ assert get_event_loop() is None
hub = Hub()
set_event_loop(hub)
- self.assertIs(_hub._current_loop, hub)
- self.assertIs(get_event_loop(), hub)
+ assert _hub._current_loop is hub
+ assert get_event_loop() is hub
def test_dummy_context(self):
with _dummy_context():
pass
def test_raise_stop_error(self):
- with self.assertRaises(Stop):
+ with pytest.raises(Stop):
_raise_stop_error()
-class test_Hub(Case):
+class test_Hub:
def setup(self):
self.hub = Hub()
@@ -179,7 +176,7 @@ class test_Hub(Case):
poller = self.hub.poller = Mock(name='poller')
self.hub._close_poller()
poller.close.assert_called_with()
- self.assertIsNone(self.hub.poller)
+ assert self.hub.poller is None
def test_stop(self):
self.hub.call_soon = Mock(name='call_soon')
@@ -191,14 +188,14 @@ class test_Hub(Case):
callback = Mock(name='callback')
ret = self.hub.call_soon(callback, 1, 2, 3)
promise.assert_called_with(callback, (1, 2, 3))
- self.assertIn(promise(), self.hub._ready)
- self.assertIs(ret, promise())
+ assert promise() in self.hub._ready
+ assert ret is promise()
def test_call_soon__promise_argument(self):
callback = promise(Mock(name='callback'), (1, 2, 3))
ret = self.hub.call_soon(callback)
- self.assertIs(ret, callback)
- self.assertIn(ret, self.hub._ready)
+ assert ret is callback
+ assert ret in self.hub._ready
def test_call_later(self):
callback = Mock(name='callback')
@@ -213,24 +210,24 @@ class test_Hub(Case):
self.hub.timer.call_at.assert_called_with(21231122, callback, (1, 2))
def test_repr(self):
- self.assertTrue(repr(self.hub))
+ assert repr(self.hub)
def test_repr_flag(self):
- self.assertEqual(repr_flag(READ), 'R')
- self.assertEqual(repr_flag(WRITE), 'W')
- self.assertEqual(repr_flag(ERR), '!')
- self.assertEqual(repr_flag(READ | WRITE), 'RW')
- self.assertEqual(repr_flag(READ | ERR), 'R!')
- self.assertEqual(repr_flag(WRITE | ERR), 'W!')
- self.assertEqual(repr_flag(READ | WRITE | ERR), 'RW!')
+ assert repr_flag(READ) == 'R'
+ assert repr_flag(WRITE) == 'W'
+ assert repr_flag(ERR) == '!'
+ assert repr_flag(READ | WRITE) == 'RW'
+ assert repr_flag(READ | ERR) == 'R!'
+ assert repr_flag(WRITE | ERR) == 'W!'
+ assert repr_flag(READ | WRITE | ERR) == 'RW!'
def test_repr_callback_rcb(self):
def f():
pass
- self.assertEqual(_rcb(f), f.__name__)
- self.assertEqual(_rcb('foo'), 'foo')
+ assert _rcb(f) == f.__name__
+ assert _rcb('foo') == 'foo'
@patch('kombu.async.hub.poll')
def test_start_stop(self, poll):
@@ -245,14 +242,12 @@ class test_Hub(Case):
def test_fire_timers(self):
self.hub.timer = Mock()
self.hub.timer._queue = []
- self.assertEqual(
- self.hub.fire_timers(min_delay=42.324, max_delay=32.321),
- 32.321,
- )
+ assert self.hub.fire_timers(
+ min_delay=42.324, max_delay=32.321) == 32.321
self.hub.timer._queue = [1]
self.hub.scheduler = iter([(3.743, None)])
- self.assertEqual(self.hub.fire_timers(), 3.743)
+ assert self.hub.fire_timers() == 3.743
e1, e2, e3 = Mock(), Mock(), Mock()
entries = [e1, e2, e3]
@@ -267,21 +262,19 @@ class test_Hub(Case):
yield 3.982, None
self.hub.scheduler = se()
- self.assertEqual(self.hub.fire_timers(max_timers=10), 3.982)
+ assert self.hub.fire_timers(max_timers=10) == 3.982
for E in [e3, e2, e1]:
E.assert_called_with()
reset()
entries[:] = [Mock() for _ in range(11)]
keep = list(entries)
- self.assertEqual(
- self.hub.fire_timers(max_timers=10, min_delay=1.13),
- 1.13,
- )
+ assert self.hub.fire_timers(
+ max_timers=10, min_delay=1.13) == 1.13
for E in reversed(keep[1:]):
E.assert_called_with()
reset()
- self.assertEqual(self.hub.fire_timers(max_timers=10), 3.982)
+ assert self.hub.fire_timers(max_timers=10) == 3.982
keep[0].assert_called_with()
def test_fire_timers_raises(self):
@@ -289,32 +282,32 @@ class test_Hub(Case):
eback.side_effect = KeyError('foo')
self.hub.timer = Mock()
self.hub.scheduler = iter([(0, eback)])
- with self.assertRaises(KeyError):
+ with pytest.raises(KeyError):
self.hub.fire_timers(propagate=(KeyError,))
eback.side_effect = ValueError('foo')
self.hub.scheduler = iter([(0, eback)])
with patch('kombu.async.hub.logger') as logger:
- with self.assertRaises(StopIteration):
+ with pytest.raises(StopIteration):
self.hub.fire_timers()
logger.error.assert_called()
eback.side_effect = MemoryError('foo')
self.hub.scheduler = iter([(0, eback)])
- with self.assertRaises(MemoryError):
+ with pytest.raises(MemoryError):
self.hub.fire_timers()
eback.side_effect = OSError()
eback.side_effect.errno = errno.ENOMEM
self.hub.scheduler = iter([(0, eback)])
- with self.assertRaises(OSError):
+ with pytest.raises(OSError):
self.hub.fire_timers()
eback.side_effect = OSError()
eback.side_effect.errno = errno.ENOENT
self.hub.scheduler = iter([(0, eback)])
with patch('kombu.async.hub.logger') as logger:
- with self.assertRaises(StopIteration):
+ with pytest.raises(StopIteration):
self.hub.fire_timers()
logger.error.assert_called()
@@ -322,7 +315,7 @@ class test_Hub(Case):
self.hub.poller = Mock(name='hub.poller')
self.hub.poller.register.side_effect = ValueError()
self.hub._discard = Mock(name='hub.discard')
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
self.hub.add(2, Mock(), READ)
self.hub._discard.assert_called_with(2)
@@ -331,34 +324,34 @@ class test_Hub(Case):
self.hub.add(2, Mock(), READ)
self.hub.add(2, Mock(), WRITE)
self.hub.remove_reader(2)
- self.assertNotIn(2, self.hub.readers)
- self.assertIn(2, self.hub.writers)
+ assert 2 not in self.hub.readers
+ assert 2 in self.hub.writers
def test_remove_reader__not_writeable(self):
self.hub.poller = Mock(name='hub.poller')
self.hub.add(2, Mock(), READ)
self.hub.remove_reader(2)
- self.assertNotIn(2, self.hub.readers)
+ assert 2 not in self.hub.readers
def test_remove_writer(self):
self.hub.poller = Mock(name='hub.poller')
self.hub.add(2, Mock(), READ)
self.hub.add(2, Mock(), WRITE)
self.hub.remove_writer(2)
- self.assertIn(2, self.hub.readers)
- self.assertNotIn(2, self.hub.writers)
+ assert 2 in self.hub.readers
+ assert 2 not in self.hub.writers
def test_remove_writer__not_readable(self):
self.hub.poller = Mock(name='hub.poller')
self.hub.add(2, Mock(), WRITE)
self.hub.remove_writer(2)
- self.assertNotIn(2, self.hub.writers)
+ assert 2 not in self.hub.writers
def test_add__consolidate(self):
self.hub.poller = Mock(name='hub.poller')
self.hub.add(2, Mock(), WRITE, consolidate=True)
- self.assertIn(2, self.hub.consolidate)
- self.assertIsNone(self.hub.writers[2])
+ assert 2 in self.hub.consolidate
+ assert self.hub.writers[2] is None
@patch('kombu.async.hub.logger')
def test_on_callback_error(self, logger):
@@ -368,8 +361,8 @@ class test_Hub(Case):
def test_loop_property(self):
self.hub._loop = None
self.hub.create_loop = Mock(name='hub.create_loop')
- self.assertIs(self.hub.loop, self.hub.create_loop())
- self.assertIs(self.hub._loop, self.hub.create_loop())
+ assert self.hub.loop is self.hub.create_loop()
+ assert self.hub._loop is self.hub.create_loop()
def test_run_forever(self):
self.hub.run_once = Mock(name='hub.run_once')
@@ -380,7 +373,7 @@ class test_Hub(Case):
self.hub._loop = iter([1])
self.hub.run_once()
self.hub.run_once()
- self.assertIsNone(self.hub._loop)
+ assert self.hub._loop is None
def test_repr_active(self):
self.hub.readers = {1: Mock(), 2: Mock()}
@@ -388,7 +381,7 @@ class test_Hub(Case):
for value in list(
self.hub.readers.values()) + list(self.hub.writers.values()):
value.__name__ = 'mock'
- self.assertTrue(self.hub.repr_active())
+ assert self.hub.repr_active()
def test_repr_events(self):
self.hub.readers = {6: Mock(), 7: Mock(), 8: Mock()}
@@ -396,24 +389,24 @@ class test_Hub(Case):
for value in list(
self.hub.readers.values()) + list(self.hub.writers.values()):
value.__name__ = 'mock'
- self.assertTrue(self.hub.repr_events([
+ assert self.hub.repr_events([
(6, READ),
(7, ERR),
(8, READ | ERR),
(9, WRITE),
(10, 13213),
- ]))
+ ])
def test_callback_for(self):
reader, writer = Mock(), Mock()
self.hub.readers = {6: reader}
self.hub.writers = {7: writer}
- self.assertEqual(callback_for(self.hub, 6, READ), reader)
- self.assertEqual(callback_for(self.hub, 7, WRITE), writer)
- with self.assertRaises(KeyError):
+ assert callback_for(self.hub, 6, READ) == reader
+ assert callback_for(self.hub, 7, WRITE) == writer
+ with pytest.raises(KeyError):
callback_for(self.hub, 6, WRITE)
- self.assertEqual(callback_for(self.hub, 6, WRITE, 'foo'), 'foo')
+ assert callback_for(self.hub, 6, WRITE, 'foo') == 'foo'
def test_add_remove_readers(self):
P = self.hub.poller = Mock()
@@ -428,13 +421,13 @@ class test_Hub(Case):
call(11, self.hub.READ | self.hub.ERR),
], any_order=True)
- self.assertEqual(self.hub.readers[10], (read_A, (10,)))
- self.assertEqual(self.hub.readers[11], (read_B, (11,)))
+ assert self.hub.readers[10] == (read_A, (10,))
+ assert self.hub.readers[11] == (read_B, (11,))
self.hub.remove(10)
- self.assertNotIn(10, self.hub.readers)
+ assert 10 not in self.hub.readers
self.hub.remove(File(11))
- self.assertNotIn(11, self.hub.readers)
+ assert 11 not in self.hub.readers
P.unregister.assert_has_calls([
call(10), call(11),
])
@@ -463,13 +456,13 @@ class test_Hub(Case):
call(21, self.hub.WRITE),
], any_order=True)
- self.assertEqual(self.hub.writers[20], (write_A, ()))
- self.assertEqual(self.hub.writers[21], (write_B, ()))
+ assert self.hub.writers[20], (write_A == ())
+ assert self.hub.writers[21], (write_B == ())
self.hub.remove(20)
- self.assertNotIn(20, self.hub.writers)
+ assert 20 not in self.hub.writers
self.hub.remove(File(21))
- self.assertNotIn(21, self.hub.writers)
+ assert 21 not in self.hub.writers
P.unregister.assert_has_calls([
call(20), call(21),
])
@@ -488,13 +481,13 @@ class test_Hub(Case):
write_B = Mock()
self.hub.add_writer(20, write_A)
self.hub.add_writer(File(21), write_B)
- self.assertTrue(self.hub.readers)
- self.assertTrue(self.hub.writers)
+ assert self.hub.readers
+ assert self.hub.writers
finally:
assert self.hub.poller
self.hub.close()
- self.assertFalse(self.hub.readers)
- self.assertFalse(self.hub.writers)
+ assert not self.hub.readers
+ assert not self.hub.writers
P.unregister.assert_has_calls([
call(10), call(11), call(20), call(21),
@@ -504,7 +497,7 @@ class test_Hub(Case):
def test_scheduler_property(self):
hub = Hub(timer=[1, 2, 3])
- self.assertEqual(list(hub.scheduler), [1, 2, 3])
+ assert list(hub.scheduler), [1, 2 == 3]
def test_loop__tick_callbacks(self):
self.hub._ready = Mock(name='_ready')
@@ -512,7 +505,7 @@ class test_Hub(Case):
ticks = [Mock(name='cb1'), Mock(name='cb2')]
self.hub.on_tick = list(ticks)
- with self.assertRaises(RuntimeError):
+ with pytest.raises(RuntimeError):
next(self.hub.loop)
ticks[0].assert_called_once_with()
@@ -528,7 +521,7 @@ class test_Hub(Case):
self.hub.call_soon(cb)
self.hub._ready.add(None)
- with self.assertRaises(RuntimeError):
+ with pytest.raises(RuntimeError):
next(self.hub.loop)
callbacks[0].assert_called_once_with()
diff --git a/kombu/tests/async/test_semaphore.py b/t/unit/async/test_semaphore.py
index 703acf61..3db792fa 100644
--- a/kombu/tests/async/test_semaphore.py
+++ b/t/unit/async/test_semaphore.py
@@ -2,10 +2,8 @@ from __future__ import absolute_import, unicode_literals
from kombu.async.semaphore import LaxBoundedSemaphore
-from kombu.tests.case import Case
-
-class test_LaxBoundedSemaphore(Case):
+class test_LaxBoundedSemaphore:
def test_over_release(self):
x = LaxBoundedSemaphore(2)
@@ -17,29 +15,29 @@ class test_LaxBoundedSemaphore(Case):
x.release()
x.acquire(calls.append, 'y')
- self.assertEqual(calls, [1, 2, 3, 4])
+ assert calls, [1, 2, 3 == 4]
for i in range(30):
x.release()
- self.assertEqual(calls, list(range(1, 21)) + ['x', 'y'])
- self.assertEqual(x.value, x.initial_value)
+ assert calls, list(range(1, 21)) + ['x' == 'y']
+ assert x.value == x.initial_value
calls[:] = []
for i in range(1, 11):
x.acquire(calls.append, i)
for i in range(1, 11):
x.release()
- self.assertEqual(calls, list(range(1, 11)))
+ assert calls, list(range(1 == 11))
calls[:] = []
- self.assertEqual(x.value, x.initial_value)
+ assert x.value == x.initial_value
x.acquire(calls.append, 'x')
- self.assertEqual(x.value, 1)
+ assert x.value == 1
x.acquire(calls.append, 'y')
- self.assertEqual(x.value, 0)
+ assert x.value == 0
x.release()
- self.assertEqual(x.value, 1)
+ assert x.value == 1
x.release()
- self.assertEqual(x.value, 2)
+ assert x.value == 2
x.release()
- self.assertEqual(x.value, 2)
+ assert x.value == 2
diff --git a/kombu/tests/async/test_timer.py b/t/unit/async/test_timer.py
index f52fc93e..589ac54b 100644
--- a/kombu/tests/async/test_timer.py
+++ b/t/unit/async/test_timer.py
@@ -1,50 +1,46 @@
from __future__ import absolute_import, unicode_literals
+import pytest
+
from datetime import datetime
-from kombu.five import bytes_if_py2
+from case import Mock, patch
from kombu.async.timer import Entry, Timer, to_timestamp
-
-from kombu.tests.case import Case, Mock, mock, patch
+from kombu.five import bytes_if_py2
-class test_to_timestamp(Case):
+class test_to_timestamp:
def test_timestamp(self):
- self.assertIs(to_timestamp(3.13), 3.13)
+ assert to_timestamp(3.13) is 3.13
def test_datetime(self):
- self.assertTrue(to_timestamp(datetime.utcnow()))
+ assert to_timestamp(datetime.utcnow())
-class test_Entry(Case):
+class test_Entry:
def test_call(self):
- scratch = [None]
-
- def timed(x, y, moo='foo'):
- scratch[0] = (x, y, moo)
-
- tref = Entry(timed, (4, 4), {'moo': 'baz'})
+ fun = Mock(name='fun')
+ tref = Entry(fun, (4, 4), {'moo': 'baz'})
tref()
-
- self.assertTupleEqual(scratch[0], (4, 4, 'baz'))
+ fun.assert_called_with(4, 4, moo='baz')
def test_cancel(self):
tref = Entry(lambda x: x, (1,), {})
- self.assertFalse(tref.canceled)
- self.assertFalse(tref.cancelled)
+ assert not tref.canceled
+ assert not tref.cancelled
tref.cancel()
- self.assertTrue(tref.canceled)
- self.assertTrue(tref.cancelled)
+ assert tref.canceled
+ assert tref.cancelled
def test_repr(self):
tref = Entry(lambda x: x(1,), {})
- self.assertTrue(repr(tref))
+ assert repr(tref)
def test_hash(self):
- self.assertTrue(hash(Entry(lambda: None)))
+ assert hash(Entry(lambda: None))
def test_ordering(self):
# we don't care about results, just that it's possible
@@ -56,11 +52,11 @@ class test_Entry(Case):
def test_eq(self):
x = Entry(lambda x: 1)
y = Entry(lambda x: 1)
- self.assertEqual(x, x)
- self.assertNotEqual(x, y)
+ assert x == x
+ assert x != y
-class test_Timer(Case):
+class test_Timer:
def test_enter_exit(self):
x = Timer()
@@ -77,14 +73,11 @@ class test_Timer(Case):
x.cancel(tref)
tref.cancel.assert_called_with()
- self.assertIs(x.schedule, x)
+ assert x.schedule is x
def test_handle_error(self):
from datetime import datetime
- scratch = [None]
-
- def on_error(exc_info):
- scratch[0] = exc_info
+ on_error = Mock(name='on_error')
s = Timer(on_error=on_error)
@@ -94,11 +87,12 @@ class test_Timer(Case):
eta=datetime.now())
s.enter_at(Entry(lambda: None, (), {}), eta=None)
s.on_error = None
- with self.assertRaises(OverflowError):
+ with pytest.raises(OverflowError):
s.enter_at(Entry(lambda: None, (), {}),
eta=datetime.now())
- exc = scratch[0]
- self.assertIsInstance(exc, OverflowError)
+ on_error.assert_called_once()
+ exc = on_error.call_args[0][0]
+ assert isinstance(exc, OverflowError)
def test_call_repeatedly(self):
t = Timer()
@@ -109,20 +103,20 @@ class test_Timer(Case):
myfun.__name__ = bytes_if_py2('myfun')
t.call_repeatedly(0.03, myfun)
- self.assertEqual(t.schedule.enter_after.call_count, 1)
+ assert t.schedule.enter_after.call_count == 1
args1, _ = t.schedule.enter_after.call_args_list[0]
sec1, tref1, _ = args1
- self.assertEqual(sec1, 0.03)
+ assert sec1 == 0.03
tref1()
- self.assertEqual(t.schedule.enter_after.call_count, 2)
+ assert t.schedule.enter_after.call_count == 2
args2, _ = t.schedule.enter_after.call_args_list[1]
sec2, tref2, _ = args2
- self.assertEqual(sec2, 0.03)
+ assert sec2 == 0.03
tref2.canceled = True
tref2()
- self.assertEqual(t.schedule.enter_after.call_count, 2)
+ assert t.schedule.enter_after.call_count == 2
finally:
t.stop()
@@ -137,8 +131,7 @@ class test_Timer(Case):
t.schedule.apply_entry(fun)
logger.error.assert_called()
- @mock.stdouts
- def test_apply_entry_error_not_handled(self, stdout, stderr):
+ def test_apply_entry_error_not_handled(self, stdouts):
t = Timer()
t.schedule.on_error = Mock()
@@ -146,7 +139,7 @@ class test_Timer(Case):
fun.side_effect = ValueError()
t.schedule.apply_entry(fun)
fun.assert_called_with()
- self.assertFalse(stderr.getvalue())
+ assert not stdouts.stderr.getvalue()
def test_enter_after(self):
t = Timer()
diff --git a/kombu/tests/case.py b/t/unit/case.py
index 057229e1..057229e1 100644
--- a/kombu/tests/case.py
+++ b/t/unit/case.py
diff --git a/t/unit/test_clocks.py b/t/unit/test_clocks.py
new file mode 100644
index 00000000..5ed30bf1
--- /dev/null
+++ b/t/unit/test_clocks.py
@@ -0,0 +1,89 @@
+from __future__ import absolute_import, unicode_literals
+
+import pickle
+
+from heapq import heappush
+from time import time
+
+from case import Mock
+
+from kombu.clocks import LamportClock, timetuple
+
+
+class test_LamportClock:
+
+ def test_clocks(self):
+ c1 = LamportClock()
+ c2 = LamportClock()
+
+ c1.forward()
+ c2.forward()
+ c1.forward()
+ c1.forward()
+ c2.adjust(c1.value)
+ assert c2.value == c1.value + 1
+ assert repr(c1)
+
+ c2_val = c2.value
+ c2.forward()
+ c2.forward()
+ c2.adjust(c1.value)
+ assert c2.value == c2_val + 2 + 1
+
+ c1.adjust(c2.value)
+ assert c1.value == c2.value + 1
+
+ def test_sort(self):
+ c = LamportClock()
+ pid1 = 'a.example.com:312'
+ pid2 = 'b.example.com:311'
+
+ events = []
+
+ m1 = (c.forward(), pid1)
+ heappush(events, m1)
+ m2 = (c.forward(), pid2)
+ heappush(events, m2)
+ m3 = (c.forward(), pid1)
+ heappush(events, m3)
+ m4 = (30, pid1)
+ heappush(events, m4)
+ m5 = (30, pid2)
+ heappush(events, m5)
+
+ assert str(c) == str(c.value)
+
+ assert c.sort_heap(events) == m1
+ assert c.sort_heap([m4, m5]) == m4
+ assert c.sort_heap([m4, m5, m1]) == m4
+
+
+class test_timetuple:
+
+ def test_repr(self):
+ x = timetuple(133, time(), 'id', Mock())
+ assert repr(x)
+
+ def test_pickleable(self):
+ x = timetuple(133, time(), 'id', 'obj')
+ assert pickle.loads(pickle.dumps(x)) == tuple(x)
+
+ def test_order(self):
+ t1 = time()
+ t2 = time() + 300 # windows clock not reliable
+ a = timetuple(133, t1, 'A', 'obj')
+ b = timetuple(140, t1, 'A', 'obj')
+ assert a.__getnewargs__()
+ assert a.clock == 133
+ assert a.timestamp == t1
+ assert a.id == 'A'
+ assert a.obj == 'obj'
+ assert a <= b
+ assert b >= a
+
+ assert (timetuple(134, time(), 'A', 'obj').__lt__(tuple()) is
+ NotImplemented)
+ assert timetuple(134, t2, 'A', 'obj') > timetuple(133, t1, 'A', 'obj')
+ assert timetuple(134, t1, 'B', 'obj') > timetuple(134, t1, 'A', 'obj')
+ assert (timetuple(None, t2, 'B', 'obj') >
+ timetuple(None, t1, 'A', 'obj'))
diff --git a/kombu/tests/test_common.py b/t/unit/test_common.py
index acbba0eb..8aee3d9e 100644
--- a/kombu/tests/test_common.py
+++ b/t/unit/test_common.py
@@ -1,8 +1,10 @@
from __future__ import absolute_import, unicode_literals
+import pytest
import socket
from amqp import RecoverableConnectionError
+from case import ContextMock, Mock, patch
from kombu import common
from kombu.common import (
@@ -12,64 +14,61 @@ from kombu.common import (
QoS, PREFETCH_COUNT_MAX,
)
-from .case import Case, ContextMock, Mock, MockPool, patch
+from t.mocks import MockPool
-class test_ignore_errors(Case):
+def test_ignore_errors():
+ connection = Mock()
+ connection.channel_errors = (KeyError,)
+ connection.connection_errors = (KeyError,)
- def test_ignored(self):
- connection = Mock()
- connection.channel_errors = (KeyError,)
- connection.connection_errors = (KeyError,)
+ with ignore_errors(connection):
+ raise KeyError()
- with ignore_errors(connection):
- raise KeyError()
-
- def raising():
- raise KeyError()
+ def raising():
+ raise KeyError()
- ignore_errors(connection, raising)
+ ignore_errors(connection, raising)
- connection.channel_errors = connection.connection_errors = \
- ()
+ connection.channel_errors = connection.connection_errors = ()
- with self.assertRaises(KeyError):
- with ignore_errors(connection):
- raise KeyError()
+ with pytest.raises(KeyError):
+ with ignore_errors(connection):
+ raise KeyError()
-class test_declaration_cached(Case):
+class test_declaration_cached:
def test_when_cached(self):
chan = Mock()
chan.connection.client.declared_entities = ['foo']
- self.assertTrue(declaration_cached('foo', chan))
+ assert declaration_cached('foo', chan)
def test_when_not_cached(self):
chan = Mock()
chan.connection.client.declared_entities = ['bar']
- self.assertFalse(declaration_cached('foo', chan))
+ assert not declaration_cached('foo', chan)
-class test_Broadcast(Case):
+class test_Broadcast:
def test_arguments(self):
q = Broadcast(name='test_Broadcast')
- self.assertTrue(q.name.startswith('bcast.'))
- self.assertEqual(q.alias, 'test_Broadcast')
- self.assertTrue(q.auto_delete)
- self.assertEqual(q.exchange.name, 'test_Broadcast')
- self.assertEqual(q.exchange.type, 'fanout')
+ assert q.name.startswith('bcast.')
+ assert q.alias == 'test_Broadcast'
+ assert q.auto_delete
+ assert q.exchange.name == 'test_Broadcast'
+ assert q.exchange.type == 'fanout'
q = Broadcast('test_Broadcast', 'explicit_queue_name')
- self.assertEqual(q.name, 'explicit_queue_name')
- self.assertEqual(q.exchange.name, 'test_Broadcast')
+ assert q.name == 'explicit_queue_name'
+ assert q.exchange.name == 'test_Broadcast'
q2 = q(Mock())
- self.assertEqual(q2.name, q.name)
+ assert q2.name == q.name
-class test_maybe_declare(Case):
+class test_maybe_declare:
def test_cacheable(self):
channel = Mock()
@@ -82,16 +81,14 @@ class test_maybe_declare(Case):
entity.channel = channel
maybe_declare(entity, channel)
- self.assertEqual(entity.declare.call_count, 1)
- self.assertIn(
- hash(entity), channel.connection.client.declared_entities,
- )
+ assert entity.declare.call_count == 1
+ assert hash(entity) in channel.connection.client.declared_entities
maybe_declare(entity, channel)
- self.assertEqual(entity.declare.call_count, 1)
+ assert entity.declare.call_count == 1
entity.channel.connection = None
- with self.assertRaises(RecoverableConnectionError):
+ with pytest.raises(RecoverableConnectionError):
maybe_declare(entity)
def test_binds_entities(self):
@@ -116,10 +113,10 @@ class test_maybe_declare(Case):
entity.channel = channel
maybe_declare(entity, channel, retry=True)
- self.assertTrue(channel.connection.client.ensure.call_count)
+ assert channel.connection.client.ensure.call_count
-class test_replies(Case):
+class test_replies:
def test_send_reply(self):
req = Mock()
@@ -136,16 +133,18 @@ class test_replies(Case):
producer.channel.connection.client.declared_entities = set()
send_reply(exchange, req, {'hello': 'world'}, producer)
- self.assertTrue(producer.publish.call_count)
+ assert producer.publish.call_count
args = producer.publish.call_args
- self.assertDictEqual(args[0][0], {'hello': 'world'})
- self.assertDictEqual(args[1], {'exchange': exchange,
- 'routing_key': 'hello',
- 'correlation_id': 'world',
- 'serializer': 'json',
- 'retry': False,
- 'retry_policy': None,
- 'content_encoding': 'binary'})
+ assert args[0][0] == {'hello': 'world'}
+ assert args[1] == {
+ 'exchange': exchange,
+ 'routing_key': 'hello',
+ 'correlation_id': 'world',
+ 'serializer': 'json',
+ 'retry': False,
+ 'retry_policy': None,
+ 'content_encoding': 'binary',
+ }
@patch('kombu.common.itermessages')
def test_collect_replies_with_ack(self, itermessages):
@@ -154,11 +153,11 @@ class test_replies(Case):
itermessages.return_value = [(body, message)]
it = collect_replies(conn, channel, queue, no_ack=False)
m = next(it)
- self.assertIs(m, body)
+ assert m is body
itermessages.assert_called_with(conn, channel, queue, no_ack=False)
message.ack.assert_called_with()
- with self.assertRaises(StopIteration):
+ with pytest.raises(StopIteration):
next(it)
channel.after_reply_message_received.assert_called_with(queue.name)
@@ -170,7 +169,7 @@ class test_replies(Case):
itermessages.return_value = [(body, message)]
it = collect_replies(conn, channel, queue)
m = next(it)
- self.assertIs(m, body)
+ assert m is body
itermessages.assert_called_with(conn, channel, queue, no_ack=True)
message.ack.assert_not_called()
@@ -179,13 +178,12 @@ class test_replies(Case):
conn, channel, queue = Mock(), Mock(), Mock()
itermessages.return_value = []
it = collect_replies(conn, channel, queue)
- with self.assertRaises(StopIteration):
+ with pytest.raises(StopIteration):
next(it)
-
channel.after_reply_message_received.assert_not_called()
-class test_insured(Case):
+class test_insured:
@patch('kombu.common.logger')
def test_ensure_errback(self, logger):
@@ -212,22 +210,21 @@ class test_insured(Case):
conn, pool, fun, insured = self.get_insured_mocks()
ret = common.insured(pool, fun, (2, 2), {'foo': 'bar'})
- self.assertEqual(ret, 'works')
+ assert ret == 'works'
conn.ensure_connection.assert_called_with(
errback=common._ensure_errback,
)
insured.assert_called()
i_args, i_kwargs = insured.call_args
- self.assertTupleEqual(i_args, (2, 2))
- self.assertDictEqual(i_kwargs, {'foo': 'bar',
- 'connection': conn})
+ assert i_args == (2, 2)
+ assert i_kwargs == {'foo': 'bar', 'connection': conn}
conn.autoretry.assert_called()
ar_args, ar_kwargs = conn.autoretry.call_args
- self.assertTupleEqual(ar_args, (fun, conn.default_channel))
- self.assertTrue(ar_kwargs.get('on_revive'))
- self.assertTrue(ar_kwargs.get('errback'))
+ assert ar_args == (fun, conn.default_channel)
+ assert ar_kwargs.get('on_revive')
+ assert ar_kwargs.get('errback')
def test_insured_custom_errback(self):
conn, pool, fun, insured = self.get_insured_mocks()
@@ -254,7 +251,7 @@ class MockConsumer(object):
self.consumers.discard(self)
-class test_itermessages(Case):
+class test_itermessages:
class MockConnection(object):
should_raise_timeout = False
@@ -274,9 +271,9 @@ class test_itermessages(Case):
it = common.itermessages(conn, channel, 'q', limit=1)
ret = next(it)
- self.assertTupleEqual(ret, ('body', 'message'))
+ assert ret == ('body', 'message')
- with self.assertRaises(StopIteration):
+ with pytest.raises(StopIteration):
next(it)
def test_when_raises_socket_timeout(self):
@@ -287,7 +284,7 @@ class test_itermessages(Case):
conn.Consumer = MockConsumer
it = common.itermessages(conn, channel, 'q', limit=1)
- with self.assertRaises(StopIteration):
+ with pytest.raises(StopIteration):
next(it)
@patch('kombu.common.deque')
@@ -299,11 +296,11 @@ class test_itermessages(Case):
conn.Consumer = MockConsumer
it = common.itermessages(conn, channel, 'q', limit=1)
- with self.assertRaises(StopIteration):
+ with pytest.raises(StopIteration):
next(it)
-class test_QoS(Case):
+class test_QoS:
class _QoS(QoS):
def __init__(self, value):
@@ -326,20 +323,20 @@ class test_QoS(Case):
def test_qos_increment_decrement(self):
qos = self._QoS(10)
- self.assertEqual(qos.increment_eventually(), 11)
- self.assertEqual(qos.increment_eventually(3), 14)
- self.assertEqual(qos.increment_eventually(-30), 14)
- self.assertEqual(qos.decrement_eventually(7), 7)
- self.assertEqual(qos.decrement_eventually(), 6)
+ assert qos.increment_eventually() == 11
+ assert qos.increment_eventually(3) == 14
+ assert qos.increment_eventually(-30) == 14
+ assert qos.decrement_eventually(7) == 7
+ assert qos.decrement_eventually() == 6
def test_qos_disabled_increment_decrement(self):
qos = self._QoS(0)
- self.assertEqual(qos.increment_eventually(), 0)
- self.assertEqual(qos.increment_eventually(3), 0)
- self.assertEqual(qos.increment_eventually(-30), 0)
- self.assertEqual(qos.decrement_eventually(7), 0)
- self.assertEqual(qos.decrement_eventually(), 0)
- self.assertEqual(qos.decrement_eventually(10), 0)
+ assert qos.increment_eventually() == 0
+ assert qos.increment_eventually(3) == 0
+ assert qos.increment_eventually(-30) == 0
+ assert qos.decrement_eventually(7) == 0
+ assert qos.decrement_eventually() == 0
+ assert qos.decrement_eventually(10) == 0
def test_qos_thread_safe(self):
qos = self._QoS(10)
@@ -361,59 +358,59 @@ class test_QoS(Case):
thread.join()
threaded([add, add])
- self.assertEqual(qos.value, 2010)
+ assert qos.value == 2010
qos.value = 1000
threaded([add, sub]) # n = 2
- self.assertEqual(qos.value, 1000)
+ assert qos.value == 1000
def test_exceeds_short(self):
qos = QoS(Mock(), PREFETCH_COUNT_MAX - 1)
qos.update()
- self.assertEqual(qos.value, PREFETCH_COUNT_MAX - 1)
+ assert qos.value == PREFETCH_COUNT_MAX - 1
qos.increment_eventually()
- self.assertEqual(qos.value, PREFETCH_COUNT_MAX)
+ assert qos.value == PREFETCH_COUNT_MAX
qos.increment_eventually()
- self.assertEqual(qos.value, PREFETCH_COUNT_MAX + 1)
+ assert qos.value == PREFETCH_COUNT_MAX + 1
qos.decrement_eventually()
- self.assertEqual(qos.value, PREFETCH_COUNT_MAX)
+ assert qos.value == PREFETCH_COUNT_MAX
qos.decrement_eventually()
- self.assertEqual(qos.value, PREFETCH_COUNT_MAX - 1)
+ assert qos.value == PREFETCH_COUNT_MAX - 1
def test_consumer_increment_decrement(self):
mconsumer = Mock()
qos = QoS(mconsumer.qos, 10)
qos.update()
- self.assertEqual(qos.value, 10)
+ assert qos.value == 10
mconsumer.qos.assert_called_with(prefetch_count=10)
qos.decrement_eventually()
qos.update()
- self.assertEqual(qos.value, 9)
+ assert qos.value == 9
mconsumer.qos.assert_called_with(prefetch_count=9)
qos.decrement_eventually()
- self.assertEqual(qos.value, 8)
+ assert qos.value == 8
mconsumer.qos.assert_called_with(prefetch_count=9)
- self.assertIn({'prefetch_count': 9}, mconsumer.qos.call_args)
+ assert {'prefetch_count': 9} in mconsumer.qos.call_args
# Does not decrement 0 value
qos.value = 0
qos.decrement_eventually()
- self.assertEqual(qos.value, 0)
+ assert qos.value == 0
qos.increment_eventually()
- self.assertEqual(qos.value, 0)
+ assert qos.value == 0
def test_consumer_decrement_eventually(self):
mconsumer = Mock()
qos = QoS(mconsumer.qos, 10)
qos.decrement_eventually()
- self.assertEqual(qos.value, 9)
+ assert qos.value == 9
qos.value = 0
qos.decrement_eventually()
- self.assertEqual(qos.value, 0)
+ assert qos.value == 0
def test_set(self):
mconsumer = Mock()
qos = QoS(mconsumer.qos, 10)
qos.set(12)
- self.assertEqual(qos.prev, 12)
+ assert qos.prev == 12
qos.set(qos.prev)
diff --git a/kombu/tests/test_compat.py b/t/unit/test_compat.py
index c8abc85a..485625d0 100644
--- a/kombu/tests/test_compat.py
+++ b/t/unit/test_compat.py
@@ -1,13 +1,16 @@
from __future__ import absolute_import, unicode_literals
+import pytest
+
+from case import Mock, patch
+
from kombu import Connection, Exchange, Queue
from kombu import compat
-from .case import Case, Mock, patch
-from .mocks import Transport, Channel
+from t.mocks import Transport, Channel
-class test_misc(Case):
+class test_misc:
def test_iterconsume(self):
@@ -27,11 +30,11 @@ class test_misc(Case):
conn = MyConnection()
consumer = Consumer()
it = compat._iterconsume(conn, consumer)
- self.assertEqual(next(it), 1)
- self.assertTrue(consumer.active)
+ assert next(it) == 1
+ assert consumer.active
it2 = compat._iterconsume(conn, consumer, limit=10)
- self.assertEqual(list(it2), [2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
+ assert list(it2), [2, 3, 4, 5, 6, 7, 8, 9, 10 == 11]
def test_Queue_from_dict(self):
defs = {'binding_key': 'foo.#',
@@ -41,40 +44,40 @@ class test_misc(Case):
'auto_delete': False}
q1 = Queue.from_dict('foo', **dict(defs))
- self.assertEqual(q1.name, 'foo')
- self.assertEqual(q1.routing_key, 'foo.#')
- self.assertEqual(q1.exchange.name, 'fooex')
- self.assertEqual(q1.exchange.type, 'topic')
- self.assertTrue(q1.durable)
- self.assertTrue(q1.exchange.durable)
- self.assertFalse(q1.auto_delete)
- self.assertFalse(q1.exchange.auto_delete)
+ assert q1.name == 'foo'
+ assert q1.routing_key == 'foo.#'
+ assert q1.exchange.name == 'fooex'
+ assert q1.exchange.type == 'topic'
+ assert q1.durable
+ assert q1.exchange.durable
+ assert not q1.auto_delete
+ assert not q1.exchange.auto_delete
q2 = Queue.from_dict('foo', **dict(defs,
exchange_durable=False))
- self.assertTrue(q2.durable)
- self.assertFalse(q2.exchange.durable)
+ assert q2.durable
+ assert not q2.exchange.durable
q3 = Queue.from_dict('foo', **dict(defs,
exchange_auto_delete=True))
- self.assertFalse(q3.auto_delete)
- self.assertTrue(q3.exchange.auto_delete)
+ assert not q3.auto_delete
+ assert q3.exchange.auto_delete
q4 = Queue.from_dict('foo', **dict(defs,
queue_durable=False))
- self.assertFalse(q4.durable)
- self.assertTrue(q4.exchange.durable)
+ assert not q4.durable
+ assert q4.exchange.durable
q5 = Queue.from_dict('foo', **dict(defs,
queue_auto_delete=True))
- self.assertTrue(q5.auto_delete)
- self.assertFalse(q5.exchange.auto_delete)
+ assert q5.auto_delete
+ assert not q5.exchange.auto_delete
- self.assertEqual(Queue.from_dict('foo', **dict(defs)),
- Queue.from_dict('foo', **dict(defs)))
+ assert (Queue.from_dict('foo', **dict(defs)) ==
+ Queue.from_dict('foo', **dict(defs)))
-class test_Publisher(Case):
+class test_Publisher:
def setup(self):
self.connection = Connection(transport=Transport)
@@ -83,25 +86,25 @@ class test_Publisher(Case):
pub = compat.Publisher(self.connection,
exchange='test_Publisher_constructor',
routing_key='rkey')
- self.assertIsInstance(pub.backend, Channel)
- self.assertEqual(pub.exchange.name, 'test_Publisher_constructor')
- self.assertTrue(pub.exchange.durable)
- self.assertFalse(pub.exchange.auto_delete)
- self.assertEqual(pub.exchange.type, 'direct')
+ assert isinstance(pub.backend, Channel)
+ assert pub.exchange.name == 'test_Publisher_constructor'
+ assert pub.exchange.durable
+ assert not pub.exchange.auto_delete
+ assert pub.exchange.type == 'direct'
pub2 = compat.Publisher(self.connection,
exchange='test_Publisher_constructor2',
routing_key='rkey',
auto_delete=True,
durable=False)
- self.assertTrue(pub2.exchange.auto_delete)
- self.assertFalse(pub2.exchange.durable)
+ assert pub2.exchange.auto_delete
+ assert not pub2.exchange.durable
explicit = Exchange('test_Publisher_constructor_explicit',
type='topic')
pub3 = compat.Publisher(self.connection,
exchange=explicit)
- self.assertEqual(pub3.exchange, explicit)
+ assert pub3.exchange == explicit
compat.Publisher(self.connection,
exchange='test_Publisher_constructor3',
@@ -112,7 +115,7 @@ class test_Publisher(Case):
exchange='test_Publisher_send',
routing_key='rkey')
pub.send({'foo': 'bar'})
- self.assertIn('basic_publish', pub.backend)
+ assert 'basic_publish' in pub.backend
pub.close()
def test__enter__exit__(self):
@@ -120,12 +123,12 @@ class test_Publisher(Case):
exchange='test_Publisher_send',
routing_key='rkey')
x = pub.__enter__()
- self.assertIs(x, pub)
+ assert x is pub
x.__exit__()
- self.assertTrue(pub._closed)
+ assert pub._closed
-class test_Consumer(Case):
+class test_Consumer:
def setup(self):
self.connection = Connection(transport=Transport)
@@ -139,39 +142,39 @@ class test_Consumer(Case):
def test_constructor(self, n='test_Consumer_constructor'):
c = compat.Consumer(self.connection, queue=n, exchange=n,
routing_key='rkey')
- self.assertIsInstance(c.backend, Channel)
+ assert isinstance(c.backend, Channel)
q = c.queues[0]
- self.assertTrue(q.durable)
- self.assertTrue(q.exchange.durable)
- self.assertFalse(q.auto_delete)
- self.assertFalse(q.exchange.auto_delete)
- self.assertEqual(q.name, n)
- self.assertEqual(q.exchange.name, n)
+ assert q.durable
+ assert q.exchange.durable
+ assert not q.auto_delete
+ assert not q.exchange.auto_delete
+ assert q.name == n
+ assert q.exchange.name == n
c2 = compat.Consumer(self.connection, queue=n + '2',
exchange=n + '2',
routing_key='rkey', durable=False,
auto_delete=True, exclusive=True)
q2 = c2.queues[0]
- self.assertFalse(q2.durable)
- self.assertFalse(q2.exchange.durable)
- self.assertTrue(q2.auto_delete)
- self.assertTrue(q2.exchange.auto_delete)
+ assert not q2.durable
+ assert not q2.exchange.durable
+ assert q2.auto_delete
+ assert q2.exchange.auto_delete
def test__enter__exit__(self, n='test__enter__exit__'):
c = compat.Consumer(self.connection, queue=n, exchange=n,
routing_key='rkey')
x = c.__enter__()
- self.assertIs(x, c)
+ assert x is c
x.__exit__()
- self.assertTrue(c._closed)
+ assert c._closed
def test_revive(self, n='test_revive'):
c = compat.Consumer(self.connection, queue=n, exchange=n)
with self.connection.channel() as c2:
c.revive(c2)
- self.assertIs(c.backend, c2)
+ assert c.backend is c2
def test__iter__(self, n='test__iter__'):
c = compat.Consumer(self.connection, queue=n, exchange=n)
@@ -188,7 +191,7 @@ class test_Consumer(Case):
def test_process_next(self, n='test_process_next'):
c = compat.Consumer(self.connection, queue=n, exchange=n,
routing_key='rkey')
- with self.assertRaises(NotImplementedError):
+ with pytest.raises(NotImplementedError):
c.process_next()
c.close()
@@ -201,14 +204,14 @@ class test_Consumer(Case):
c = compat.Consumer(self.connection, queue=n, exchange=n,
routing_key='rkey')
c.discard_all()
- self.assertIn('queue_purge', c.backend)
+ assert 'queue_purge' in c.backend
def test_fetch(self, n='test_fetch'):
c = compat.Consumer(self.connection, queue=n, exchange=n,
routing_key='rkey')
- self.assertIsNone(c.fetch())
- self.assertIsNone(c.fetch(no_ack=True))
- self.assertIn('basic_get', c.backend)
+ assert c.fetch() is None
+ assert c.fetch(no_ack=True) is None
+ assert 'basic_get' in c.backend
callback_called = [False]
@@ -217,16 +220,16 @@ class test_Consumer(Case):
c.backend.to_deliver.append('42')
payload = c.fetch().payload
- self.assertEqual(payload, '42')
+ assert payload == '42'
c.backend.to_deliver.append('46')
c.register_callback(receive)
- self.assertEqual(c.fetch(enable_callbacks=True).payload, '46')
- self.assertTrue(callback_called[0])
+ assert c.fetch(enable_callbacks=True).payload == '46'
+ assert callback_called[0]
def test_discard_all_filterfunc_not_supported(self, n='xjf21j21'):
c = compat.Consumer(self.connection, queue=n, exchange=n,
routing_key='rkey')
- with self.assertRaises(NotImplementedError):
+ with pytest.raises(NotImplementedError):
c.discard_all(filterfunc=lambda x: x)
c.close()
@@ -240,7 +243,7 @@ class test_Consumer(Case):
c = C(self.connection,
queue=n, exchange=n, routing_key='rkey')
- self.assertEqual(c.wait(10), list(range(10)))
+ assert c.wait(10) == list(range(10))
c.close()
def test_iterqueue(self, n='test_iterqueue'):
@@ -255,11 +258,11 @@ class test_Consumer(Case):
c = C(self.connection,
queue=n, exchange=n, routing_key='rkey')
- self.assertEqual(list(c.iterqueue(limit=10)), list(range(10)))
+ assert list(c.iterqueue(limit=10)) == list(range(10))
c.close()
-class test_ConsumerSet(Case):
+class test_ConsumerSet:
def setup(self):
self.connection = Connection(transport=Transport)
@@ -267,8 +270,8 @@ class test_ConsumerSet(Case):
def test_providing_channel(self):
chan = Mock(name='channel')
cs = compat.ConsumerSet(self.connection, channel=chan)
- self.assertTrue(cs._provided_channel)
- self.assertIs(cs.backend, chan)
+ assert cs._provided_channel
+ assert cs.backend is chan
cs.cancel = Mock(name='cancel')
cs.close()
@@ -287,7 +290,7 @@ class test_ConsumerSet(Case):
with self.connection.channel() as c2:
cs.revive(c2)
- self.assertIs(cs.backend, c2)
+ assert cs.backend is c2
def test_constructor(self, prefix='0daf8h21'):
dcon = {'%s.xyx' % prefix: {'exchange': '%s.xyx' % prefix,
@@ -300,31 +303,31 @@ class test_ConsumerSet(Case):
c = compat.ConsumerSet(self.connection, consumers=consumers)
c2 = compat.ConsumerSet(self.connection, from_dict=dcon)
- self.assertEqual(len(c.queues), 3)
- self.assertEqual(len(c2.queues), 2)
+ assert len(c.queues) == 3
+ assert len(c2.queues) == 2
c.add_consumer(compat.Consumer(self.connection,
queue=prefix + 'xaxxxa',
exchange=prefix + 'xaxxxa'))
- self.assertEqual(len(c.queues), 4)
+ assert len(c.queues) == 4
for cq in c.queues:
- self.assertIs(cq.channel, c.channel)
+ assert cq.channel is c.channel
c2.add_consumer_from_dict(
'%s.xxx' % prefix,
exchange='%s.xxx' % prefix,
routing_key='xxx',
)
- self.assertEqual(len(c2.queues), 3)
+ assert len(c2.queues) == 3
for c2q in c2.queues:
- self.assertIs(c2q.channel, c2.channel)
+ assert c2q.channel is c2.channel
c.discard_all()
- self.assertEqual(c.channel.called.count('queue_purge'), 4)
+ assert c.channel.called.count('queue_purge') == 4
c.consume()
c.close()
c2.close()
- self.assertIn('basic_cancel', c.channel)
- self.assertIn('close', c.channel)
- self.assertIn('close', c2.channel)
+ assert 'basic_cancel' in c.channel
+ assert 'close' in c.channel
+ assert 'close' in c2.channel
diff --git a/kombu/tests/test_compression.py b/t/unit/test_compression.py
index d98da690..f3fc625e 100644
--- a/kombu/tests/test_compression.py
+++ b/t/unit/test_compression.py
@@ -2,41 +2,41 @@ from __future__ import absolute_import, unicode_literals
import sys
-from kombu import compression
+from case import mock, skip
-from .case import Case, mock, skip
+from kombu import compression
-class test_compression(Case):
+class test_compression:
@mock.mask_modules('bz2')
def test_no_bz2(self):
c = sys.modules.pop('kombu.compression')
try:
import kombu.compression
- self.assertFalse(hasattr(kombu.compression, 'bz2'))
+ assert not hasattr(kombu.compression, 'bz2')
finally:
if c is not None:
sys.modules['kombu.compression'] = c
def test_encoders__gzip(self):
- self.assertIn('application/x-gzip', compression.encoders())
+ assert 'application/x-gzip' in compression.encoders()
@skip.unless_module('bz2')
def test_encoders__bz2(self):
- self.assertIn('application/x-bz2', compression.encoders())
+ assert 'application/x-bz2' in compression.encoders()
def test_compress__decompress__zlib(self):
text = b'The Quick Brown Fox Jumps Over The Lazy Dog'
c, ctype = compression.compress(text, 'zlib')
- self.assertNotEqual(text, c)
+ assert text != c
d = compression.decompress(c, ctype)
- self.assertEqual(d, text)
+ assert d == text
@skip.unless_module('bz2')
def test_compress__decompress__bzip2(self):
text = b'The Brown Quick Fox Over The Lazy Dog Jumps'
c, ctype = compression.compress(text, 'bzip2')
- self.assertNotEqual(text, c)
+ assert text != c
d = compression.decompress(c, ctype)
- self.assertEqual(d, text)
+ assert d == text
diff --git a/kombu/tests/test_connection.py b/t/unit/test_connection.py
index 548263d8..576181cf 100644
--- a/kombu/tests/test_connection.py
+++ b/t/unit/test_connection.py
@@ -1,21 +1,23 @@
from __future__ import absolute_import, unicode_literals
import pickle
+import pytest
import socket
from copy import copy, deepcopy
+from case import Mock, patch, skip
+
from kombu import Connection, Consumer, Producer, parse_url
from kombu.connection import Resource
from kombu.exceptions import OperationalError
from kombu.five import items, range
from kombu.utils.functional import lazy
-from .case import Case, Mock, patch, skip
-from .mocks import Transport
+from t.mocks import Transport
-class test_connection_utils(Case):
+class test_connection_utils:
def setup(self):
self.url = 'amqp://user:pass@localhost:5672/my/vhost'
@@ -31,106 +33,74 @@ class test_connection_utils(Case):
def test_parse_url(self):
result = parse_url(self.url)
- self.assertDictEqual(result, self.expected)
+ assert result == self.expected
def test_parse_generated_as_uri(self):
conn = Connection(self.url)
info = conn.info()
for k, v in self.expected.items():
- self.assertEqual(info[k], v)
+ assert info[k] == v
# by default almost the same- no password
- self.assertEqual(conn.as_uri(), self.nopass)
- self.assertEqual(conn.as_uri(include_password=True), self.url)
+ assert conn.as_uri() == self.nopass
+ assert conn.as_uri(include_password=True) == self.url
@skip.unless_module('redis')
def test_as_uri_when_prefix(self):
conn = Connection('redis+socket:///var/spool/x/y/z/redis.sock')
- self.assertEqual(
- conn.as_uri(), 'redis+socket:///var/spool/x/y/z/redis.sock',
- )
+ assert conn.as_uri() == 'redis+socket:///var/spool/x/y/z/redis.sock'
@skip.unless_module('pymongo')
def test_as_uri_when_mongodb(self):
x = Connection('mongodb://localhost')
- self.assertTrue(x.as_uri())
+ assert x.as_uri()
def test_bogus_scheme(self):
- with self.assertRaises(KeyError):
+ with pytest.raises(KeyError):
Connection('bogus://localhost:7421').transport
def assert_info(self, conn, **fields):
info = conn.info()
for field, expected in items(fields):
- self.assertEqual(info[field], expected)
-
- def test_rabbitmq_example_urls(self):
+ assert info[field] == expected
+
+ @pytest.mark.parametrize('url,expected', [
+ ('amqp://user:pass@host:10000/vhost',
+ dict(userid='user', password='pass', hostname='host',
+ port=10000, virtual_host='vhost')),
+ ('amqp://user%61:%61pass@ho%61st:10000/v%2fhost',
+ dict(userid='usera', password='apass', hostname='hoast',
+ port=10000, virtual_host='v/host')),
+ ('amqp://',
+ dict(userid='guest', password='guest', hostname='localhost',
+ port=5672, virtual_host='/')),
+ ('amqp://:@/',
+ dict(userid='guest', password='guest', hostname='localhost',
+ port=5672, virtual_host='/')),
+ ('amqp://user@/',
+ dict(userid='user', password='guest', hostname='localhost',
+ port=5672, virtual_host='/')),
+ ('amqp://user:pass@/',
+ dict(userid='user', password='pass', hostname='localhost',
+ port=5672, virtual_host='/')),
+ ('amqp://host',
+ dict(userid='guest', password='guest', hostname='host',
+ port=5672, virtual_host='/')),
+ ('amqp://:10000',
+ dict(userid='guest', password='guest', hostname='localhost',
+ port=10000, virtual_host='/')),
+ ('amqp:///vhost',
+ dict(userid='guest', password='guest', hostname='localhost',
+ port=5672, virtual_host='vhost')),
+ ('amqp://host/',
+ dict(userid='guest', password='guest', hostname='host',
+ port=5672, virtual_host='/')),
+ ('amqp://host/%2f',
+ dict(userid='guest', password='guest', hostname='host',
+ port=5672, virtual_host='/')),
+ ])
+ def test_rabbitmq_example_urls(self, url, expected):
# see Appendix A of http://www.rabbitmq.com/uri-spec.html
-
- self.assert_info(
- Connection('amqp://user:pass@host:10000/vhost'),
- userid='user', password='pass', hostname='host',
- port=10000, virtual_host='vhost',
- )
-
- self.assert_info(
- Connection('amqp://user%61:%61pass@ho%61st:10000/v%2fhost'),
- userid='usera', password='apass', hostname='hoast',
- port=10000, virtual_host='v/host',
- )
-
- self.assert_info(
- Connection('amqp://'),
- userid='guest', password='guest', hostname='localhost',
- port=5672, virtual_host='/',
- )
-
- self.assert_info(
- Connection('amqp://:@/'),
- userid='guest', password='guest', hostname='localhost',
- port=5672, virtual_host='/',
- )
-
- self.assert_info(
- Connection('amqp://user@/'),
- userid='user', password='guest', hostname='localhost',
- port=5672, virtual_host='/',
- )
-
- self.assert_info(
- Connection('amqp://user:pass@/'),
- userid='user', password='pass', hostname='localhost',
- port=5672, virtual_host='/',
- )
-
- self.assert_info(
- Connection('amqp://host'),
- userid='guest', password='guest', hostname='host',
- port=5672, virtual_host='/',
- )
-
- self.assert_info(
- Connection('amqp://:10000'),
- userid='guest', password='guest', hostname='localhost',
- port=10000, virtual_host='/',
- )
-
- self.assert_info(
- Connection('amqp:///vhost'),
- userid='guest', password='guest', hostname='localhost',
- port=5672, virtual_host='vhost',
- )
-
- self.assert_info(
- Connection('amqp://host/'),
- userid='guest', password='guest', hostname='host',
- port=5672, virtual_host='/',
- )
-
- self.assert_info(
- Connection('amqp://host/%2f'),
- userid='guest', password='guest', hostname='host',
- port=5672, virtual_host='/',
- )
+ self.assert_info(Connection(url), **expected)
@skip.todo('urllib cannot parse ipv6 urls')
def test_url_IPV6(self):
@@ -143,10 +113,10 @@ class test_connection_utils(Case):
def test_connection_copy(self):
conn = Connection(self.url, alternates=['amqp://host'])
clone = deepcopy(conn)
- self.assertEqual(clone.alt, ['amqp://host'])
+ assert clone.alt == ['amqp://host']
-class test_Connection(Case):
+class test_Connection:
def setup(self):
self.conn = Connection(port=5672, transport=Transport)
@@ -154,24 +124,24 @@ class test_Connection(Case):
def test_establish_connection(self):
conn = self.conn
conn.connect()
- self.assertTrue(conn.connection.connected)
- self.assertEqual(conn.host, 'localhost:5672')
+ assert conn.connection.connected
+ assert conn.host == 'localhost:5672'
channel = conn.channel()
- self.assertTrue(channel.open)
- self.assertEqual(conn.drain_events(), 'event')
+ assert channel.open
+ assert conn.drain_events() == 'event'
_connection = conn.connection
conn.close()
- self.assertFalse(_connection.connected)
- self.assertIsInstance(conn.transport, Transport)
+ assert not _connection.connected
+ assert isinstance(conn.transport, Transport)
def test_multiple_urls(self):
conn1 = Connection('amqp://foo;amqp://bar')
- self.assertEqual(conn1.hostname, 'foo')
- self.assertListEqual(conn1.alt, ['amqp://foo', 'amqp://bar'])
+ assert conn1.hostname == 'foo'
+ assert conn1.alt == ['amqp://foo', 'amqp://bar']
conn2 = Connection(['amqp://foo', 'amqp://bar'])
- self.assertEqual(conn2.hostname, 'foo')
- self.assertListEqual(conn2.alt, ['amqp://foo', 'amqp://bar'])
+ assert conn2.hostname == 'foo'
+ assert conn2.alt == ['amqp://foo', 'amqp://bar']
def test_collect(self):
connection = Connection('memory://')
@@ -185,9 +155,9 @@ class test_Connection(Case):
_close.assert_not_called()
_collect.assert_called_with(uconn)
connection.declared_entities.clear.assert_called_with()
- self.assertIsNone(trans.client)
- self.assertIsNone(connection._transport)
- self.assertIsNone(connection._connection)
+ assert trans.client is None
+ assert connection._transport is None
+ assert connection._connection is None
def test_collect_no_transport(self):
connection = Connection('memory://')
@@ -213,7 +183,7 @@ class test_Connection(Case):
connection.collect()
collect.assert_called_with(uconn)
- self.assertIsNone(connection._transport)
+ assert connection._transport is None
def test_uri_passthrough(self):
transport = Mock(name='transport')
@@ -222,17 +192,17 @@ class test_Connection(Case):
transport.can_parse_url = True
with patch('kombu.connection.parse_url') as parse_url:
c = Connection('foo+mysql://some_host')
- self.assertEqual(c.transport_cls, 'foo')
+ assert c.transport_cls == 'foo'
parse_url.assert_not_called()
- self.assertEqual(c.hostname, 'mysql://some_host')
- self.assertTrue(c.as_uri().startswith('foo+'))
+ assert c.hostname == 'mysql://some_host'
+ assert c.as_uri().startswith('foo+')
with patch('kombu.connection.parse_url') as parse_url:
c = Connection('mysql://some_host', transport='foo')
- self.assertEqual(c.transport_cls, 'foo')
+ assert c.transport_cls == 'foo'
parse_url.assert_not_called()
- self.assertEqual(c.hostname, 'mysql://some_host')
+ assert c.hostname == 'mysql://some_host'
c = Connection('pyamqp+sqlite://some_host')
- self.assertTrue(c.as_uri().startswith('pyamqp+'))
+ assert c.as_uri().startswith('pyamqp+')
def test_default_ensure_callback(self):
with patch('kombu.connection.logger') as logger:
@@ -249,31 +219,31 @@ class test_Connection(Case):
args = rot.call_args[0]
cb = args[4]
intervals = iter([1, 2, 3, 4, 5])
- self.assertEqual(cb(KeyError(), intervals, 0), 0)
- self.assertEqual(cb(KeyError(), intervals, 1), 1)
- self.assertEqual(cb(KeyError(), intervals, 2), 0)
- self.assertEqual(cb(KeyError(), intervals, 3), 2)
- self.assertEqual(cb(KeyError(), intervals, 4), 0)
- self.assertEqual(cb(KeyError(), intervals, 5), 3)
- self.assertEqual(cb(KeyError(), intervals, 6), 0)
- self.assertEqual(cb(KeyError(), intervals, 7), 4)
+ assert cb(KeyError(), intervals, 0) == 0
+ assert cb(KeyError(), intervals, 1) == 1
+ assert cb(KeyError(), intervals, 2) == 0
+ assert cb(KeyError(), intervals, 3) == 2
+ assert cb(KeyError(), intervals, 4) == 0
+ assert cb(KeyError(), intervals, 5) == 3
+ assert cb(KeyError(), intervals, 6) == 0
+ assert cb(KeyError(), intervals, 7) == 4
errback = Mock()
c.ensure_connection(errback=errback)
args = rot.call_args[0]
cb = args[4]
- self.assertEqual(cb(KeyError(), intervals, 0), 0)
+ assert cb(KeyError(), intervals, 0) == 0
errback.assert_called()
def test_supports_heartbeats(self):
c = Connection(transport=Mock)
c.transport.implements.heartbeats = False
- self.assertFalse(c.supports_heartbeats)
+ assert not c.supports_heartbeats
def test_is_evented(self):
c = Connection(transport=Mock)
c.transport.implements.async = False
- self.assertFalse(c.is_evented)
+ assert not c.is_evented
def test_register_with_event_loop(self):
c = Connection(transport=Mock)
@@ -285,41 +255,41 @@ class test_Connection(Case):
def test_manager(self):
c = Connection(transport=Mock)
- self.assertIs(c.manager, c.transport.manager)
+ assert c.manager is c.transport.manager
def test_copy(self):
c = Connection('amqp://example.com')
- self.assertEqual(copy(c).info(), c.info())
+ assert copy(c).info() == c.info()
def test_copy_multiples(self):
c = Connection('amqp://A.example.com;amqp://B.example.com')
- self.assertTrue(c.alt)
+ assert c.alt
d = copy(c)
- self.assertEqual(d.alt, c.alt)
+ assert d.alt == c.alt
def test_switch(self):
c = Connection('amqp://foo')
c._closed = True
c.switch('redis://example.com//3')
- self.assertFalse(c._closed)
- self.assertEqual(c.hostname, 'example.com')
- self.assertEqual(c.transport_cls, 'redis')
- self.assertEqual(c.virtual_host, '/3')
+ assert not c._closed
+ assert c.hostname == 'example.com'
+ assert c.transport_cls == 'redis'
+ assert c.virtual_host == '/3'
def test_maybe_switch_next(self):
c = Connection('amqp://foo;redis://example.com//3')
c.maybe_switch_next()
- self.assertFalse(c._closed)
- self.assertEqual(c.hostname, 'example.com')
- self.assertEqual(c.transport_cls, 'redis')
- self.assertEqual(c.virtual_host, '/3')
+ assert not c._closed
+ assert c.hostname == 'example.com'
+ assert c.transport_cls == 'redis'
+ assert c.virtual_host == '/3'
def test_maybe_switch_next_no_cycle(self):
c = Connection('amqp://foo')
c.maybe_switch_next()
- self.assertFalse(c._closed)
- self.assertEqual(c.hostname, 'foo')
- self.assertIn(c.transport_cls, ('librabbitmq', 'pyamqp', 'amqp'))
+ assert not c._closed
+ assert c.hostname == 'foo'
+ assert c.transport_cls, ('librabbitmq', 'pyamqp' in 'amqp')
def test_heartbeat_check(self):
c = Connection(transport=Transport)
@@ -329,45 +299,40 @@ class test_Connection(Case):
def test_completes_cycle_no_cycle(self):
c = Connection('amqp://')
- self.assertTrue(c.completes_cycle(0))
- self.assertTrue(c.completes_cycle(1))
+ assert c.completes_cycle(0)
+ assert c.completes_cycle(1)
def test_completes_cycle(self):
c = Connection('amqp://a;amqp://b;amqp://c')
- self.assertFalse(c.completes_cycle(0))
- self.assertFalse(c.completes_cycle(1))
- self.assertTrue(c.completes_cycle(2))
+ assert not c.completes_cycle(0)
+ assert not c.completes_cycle(1)
+ assert c.completes_cycle(2)
def test_get_heartbeat_interval(self):
self.conn.transport.get_heartbeat_interval = Mock(name='ghi')
- self.assertIs(
- self.conn.get_heartbeat_interval(),
- self.conn.transport.get_heartbeat_interval.return_value,
- )
+ assert (self.conn.get_heartbeat_interval() is
+ self.conn.transport.get_heartbeat_interval.return_value)
self.conn.transport.get_heartbeat_interval.assert_called_with(
self.conn.connection)
def test_supports_exchange_type(self):
self.conn.transport.implements.exchange_type = {'topic'}
- self.assertTrue(self.conn.supports_exchange_type('topic'))
- self.assertFalse(self.conn.supports_exchange_type('fanout'))
+ assert self.conn.supports_exchange_type('topic')
+ assert not self.conn.supports_exchange_type('fanout')
def test_qos_semantics_matches_spec(self):
qsms = self.conn.transport.qos_semantics_matches_spec = Mock()
- self.assertIs(
- self.conn.qos_semantics_matches_spec,
- qsms.return_value,
- )
+ assert self.conn.qos_semantics_matches_spec is qsms.return_value
qsms.assert_called_with(self.conn.connection)
def test__enter____exit__(self):
conn = self.conn
context = conn.__enter__()
- self.assertIs(context, conn)
+ assert context is conn
conn.connect()
- self.assertTrue(conn.connection.connected)
+ assert conn.connection.connected
conn.__exit__()
- self.assertIsNone(conn.connection)
+ assert conn.connection is None
conn.close() # again
def test_close_survives_connerror(self):
@@ -384,7 +349,7 @@ class test_Connection(Case):
conn = Connection(transport=MyTransport)
conn.connect()
conn.close()
- self.assertTrue(conn._closed)
+ assert conn._closed
def test_close_when_default_channel(self):
conn = self.conn
@@ -413,17 +378,17 @@ class test_Connection(Case):
conn.revive(Mock())
defchan.close.assert_called_with()
- self.assertIsNone(conn._default_channel)
+ assert conn._default_channel is None
def test_ensure_connection(self):
- self.assertTrue(self.conn.ensure_connection())
+ assert self.conn.ensure_connection()
def test_ensure_success(self):
def publish():
return 'foobar'
ensured = self.conn.ensure(None, publish)
- self.assertEqual(ensured(), 'foobar')
+ assert ensured() == 'foobar'
def test_ensure_failure(self):
class _CustomError(Exception):
@@ -433,7 +398,7 @@ class test_Connection(Case):
raise _CustomError('bar')
ensured = self.conn.ensure(None, publish)
- with self.assertRaises(_CustomError):
+ with pytest.raises(_CustomError):
ensured()
def test_ensure_connection_failure(self):
@@ -445,7 +410,7 @@ class test_Connection(Case):
self.conn.transport.connection_errors = (_ConnectionError,)
ensured = self.conn.ensure(self.conn, publish)
- with self.assertRaises(OperationalError):
+ with pytest.raises(OperationalError):
ensured()
def test_autoretry(self):
@@ -466,36 +431,37 @@ class test_Connection(Case):
def test_SimpleQueue(self):
conn = self.conn
q = conn.SimpleQueue('foo')
- self.assertIs(q.channel, conn.default_channel)
+ assert q.channel is conn.default_channel
chan = conn.channel()
q2 = conn.SimpleQueue('foo', channel=chan)
- self.assertIs(q2.channel, chan)
+ assert q2.channel is chan
def test_SimpleBuffer(self):
conn = self.conn
q = conn.SimpleBuffer('foo')
- self.assertIs(q.channel, conn.default_channel)
+ assert q.channel is conn.default_channel
chan = conn.channel()
q2 = conn.SimpleBuffer('foo', channel=chan)
- self.assertIs(q2.channel, chan)
+ assert q2.channel is chan
def test_Producer(self):
conn = self.conn
- self.assertIsInstance(conn.Producer(), Producer)
- self.assertIsInstance(conn.Producer(conn.default_channel), Producer)
+ assert isinstance(conn.Producer(), Producer)
+ assert isinstance(conn.Producer(conn.default_channel), Producer)
def test_Consumer(self):
conn = self.conn
- self.assertIsInstance(conn.Consumer(queues=[]), Consumer)
- self.assertIsInstance(conn.Consumer(queues=[],
- channel=conn.default_channel), Consumer)
+ assert isinstance(conn.Consumer(queues=[]), Consumer)
+ assert isinstance(
+ conn.Consumer(queues=[], channel=conn.default_channel),
+ Consumer)
def test__repr__(self):
- self.assertTrue(repr(self.conn))
+ assert repr(self.conn)
def test__reduce__(self):
x = pickle.loads(pickle.dumps(self.conn))
- self.assertDictEqual(x.info(), self.conn.info())
+ assert x.info() == self.conn.info()
def test_channel_errors(self):
@@ -503,7 +469,7 @@ class test_Connection(Case):
channel_errors = (KeyError, ValueError)
conn = Connection(transport=MyTransport)
- self.assertTupleEqual(conn.channel_errors, (KeyError, ValueError))
+ assert conn.channel_errors == (KeyError, ValueError)
def test_connection_errors(self):
@@ -511,10 +477,10 @@ class test_Connection(Case):
connection_errors = (KeyError, ValueError)
conn = Connection(transport=MyTransport)
- self.assertTupleEqual(conn.connection_errors, (KeyError, ValueError))
+ assert conn.connection_errors == (KeyError, ValueError)
-class test_Connection_with_transport_options(Case):
+class test_Connection_with_transport_options:
transport_options = {'pool_recycler': 3600, 'echo': True}
@@ -524,7 +490,7 @@ class test_Connection_with_transport_options(Case):
def test_establish_connection(self):
conn = self.conn
- self.assertEqual(conn.transport_options, self.transport_options)
+ assert conn.transport_options == self.transport_options
class xResource(Resource):
@@ -533,56 +499,46 @@ class xResource(Resource):
pass
-class ResourceCase(Case):
- abstract = True
+class ResourceCase:
def create_resource(self, limit):
raise NotImplementedError('subclass responsibility')
- def assertState(self, P, avail, dirty):
- self.assertEqual(P._resource.qsize(), avail)
- self.assertEqual(len(P._dirty), dirty)
+ def assert_state(self, P, avail, dirty):
+ assert P._resource.qsize() == avail
+ assert len(P._dirty) == dirty
def test_setup(self):
- if self.abstract:
- with self.assertRaises(NotImplementedError):
- Resource()
+ with pytest.raises(NotImplementedError):
+ Resource()
def test_acquire__release(self):
- if self.abstract:
- return
P = self.create_resource(10)
- self.assertState(P, 10, 0)
+ self.assert_state(P, 10, 0)
chans = [P.acquire() for _ in range(10)]
- self.assertState(P, 0, 10)
- with self.assertRaises(P.LimitExceeded):
+ self.assert_state(P, 0, 10)
+ with pytest.raises(P.LimitExceeded):
P.acquire()
chans.pop().release()
- self.assertState(P, 1, 9)
+ self.assert_state(P, 1, 9)
[chan.release() for chan in chans]
- self.assertState(P, 10, 0)
+ self.assert_state(P, 10, 0)
def test_acquire_prepare_raises(self):
- if self.abstract:
- return
P = self.create_resource(10)
- self.assertEqual(len(P._resource.queue), 10)
+ assert len(P._resource.queue) == 10
P.prepare = Mock()
P.prepare.side_effect = IOError()
- with self.assertRaises(IOError):
+ with pytest.raises(IOError):
P.acquire(block=True)
- self.assertEqual(len(P._resource.queue), 10)
+ assert len(P._resource.queue) == 10
def test_acquire_no_limit(self):
- if self.abstract:
- return
P = self.create_resource(None)
P.acquire().release()
def test_replace_when_limit(self):
- if self.abstract:
- return
P = self.create_resource(10)
r = P.acquire()
P._dirty = Mock()
@@ -593,8 +549,6 @@ class ResourceCase(Case):
P.close_resource.assert_called_with(r)
def test_replace_no_limit(self):
- if self.abstract:
- return
P = self.create_resource(None)
r = P.acquire()
P._dirty = Mock()
@@ -605,26 +559,20 @@ class ResourceCase(Case):
P.close_resource.assert_called_with(r)
def test_interface_prepare(self):
- if not self.abstract:
- return
x = xResource()
- self.assertEqual(x.prepare(10), 10)
+ assert x.prepare(10) == 10
def test_force_close_all_handles_AttributeError(self):
- if self.abstract:
- return
P = self.create_resource(10)
cr = P.collect_resource = Mock()
cr.side_effect = AttributeError('x')
P.acquire()
- self.assertTrue(P._dirty)
+ assert P._dirty
P.force_close_all()
def test_force_close_all_no_mutex(self):
- if self.abstract:
- return
P = self.create_resource(10)
P.close_resource = Mock()
@@ -635,17 +583,14 @@ class ResourceCase(Case):
P.force_close_all()
def test_add_when_empty(self):
- if self.abstract:
- return
P = self.create_resource(None)
P._resource.queue.clear()
- self.assertFalse(P._resource.queue)
+ assert not P._resource.queue
P._add_when_empty()
- self.assertTrue(P._resource.queue)
+ assert P._resource.queue
class test_ConnectionPool(ResourceCase):
- abstract = False
def create_resource(self, limit):
return Connection(port=5672, transport=Transport).Pool(limit)
@@ -666,9 +611,9 @@ class test_ConnectionPool(ResourceCase):
def test_setup(self):
P = self.create_resource(10)
q = P._resource.queue
- self.assertIsNone(q[0]()._connection)
- self.assertIsNone(q[1]()._connection)
- self.assertIsNone(q[2]()._connection)
+ assert q[0]()._connection is None
+ assert q[1]()._connection is None
+ assert q[2]()._connection is None
def test_acquire_raises_evaluated(self):
P = self.create_resource(1)
@@ -678,7 +623,7 @@ class test_ConnectionPool(ResourceCase):
P.prepare = Mock()
P.prepare.side_effect = MemoryError()
P.release = Mock()
- with self.assertRaises(MemoryError):
+ with pytest.raises(MemoryError):
with P.acquire():
assert False
P.release.assert_called_with(r)
@@ -691,22 +636,21 @@ class test_ConnectionPool(ResourceCase):
def test_setup_no_limit(self):
P = self.create_resource(None)
- self.assertFalse(P._resource.queue)
- self.assertIsNone(P.limit)
+ assert not P._resource.queue
+ assert P.limit is None
def test_prepare_not_callable(self):
P = self.create_resource(None)
conn = Connection('memory://')
- self.assertIs(P.prepare(conn), conn)
+ assert P.prepare(conn) is conn
def test_acquire_channel(self):
P = self.create_resource(10)
with P.acquire_channel() as (conn, channel):
- self.assertIs(channel, conn.default_channel)
+ assert channel is conn.default_channel
class test_ChannelPool(ResourceCase):
- abstract = False
def create_resource(self, limit):
return Connection(port=5672, transport=Transport).ChannelPool(limit)
@@ -714,16 +658,16 @@ class test_ChannelPool(ResourceCase):
def test_setup(self):
P = self.create_resource(10)
q = P._resource.queue
- with self.assertRaises(AttributeError):
+ with pytest.raises(AttributeError):
q[0].basic_consume
def test_setup_no_limit(self):
P = self.create_resource(None)
- self.assertFalse(P._resource.queue)
- self.assertIsNone(P.limit)
+ assert not P._resource.queue
+ assert P.limit is None
def test_prepare_not_callable(self):
P = self.create_resource(10)
conn = Connection('memory://')
chan = conn.default_channel
- self.assertIs(P.prepare(chan), chan)
+ assert P.prepare(chan) is chan
diff --git a/kombu/tests/test_entity.py b/t/unit/test_entity.py
index 8a15f5b8..e2312c8c 100644
--- a/kombu/tests/test_entity.py
+++ b/t/unit/test_entity.py
@@ -1,21 +1,23 @@
from __future__ import absolute_import, unicode_literals
import pickle
+import pytest
+
+from case import Mock, call
from kombu import Connection, Exchange, Producer, Queue, binding
from kombu.abstract import MaybeChannelBound
from kombu.exceptions import NotBoundError
from kombu.serialization import registry
-from .case import Case, Mock, call
-from .mocks import Transport
+from t.mocks import Transport
def get_conn():
return Connection(transport=Transport)
-class test_binding(Case):
+class test_binding:
def test_constructor(self):
x = binding(
@@ -23,83 +25,80 @@ class test_binding(Case):
arguments={'barg': 'bval'},
unbind_arguments={'uarg': 'uval'},
)
- self.assertEqual(x.exchange, Exchange('foo'))
- self.assertEqual(x.routing_key, 'rkey')
- self.assertDictEqual(x.arguments, {'barg': 'bval'})
- self.assertDictEqual(x.unbind_arguments, {'uarg': 'uval'})
+ assert x.exchange == Exchange('foo')
+ assert x.routing_key == 'rkey'
+ assert x.arguments == {'barg': 'bval'}
+ assert x.unbind_arguments == {'uarg': 'uval'}
def test_declare(self):
chan = get_conn().channel()
x = binding(Exchange('foo'), 'rkey')
x.declare(chan)
- self.assertIn('exchange_declare', chan)
+ assert 'exchange_declare' in chan
def test_declare_no_exchange(self):
chan = get_conn().channel()
x = binding()
x.declare(chan)
- self.assertNotIn('exchange_declare', chan)
+ assert 'exchange_declare' not in chan
def test_bind(self):
chan = get_conn().channel()
x = binding(Exchange('foo'))
x.bind(Exchange('bar')(chan))
- self.assertIn('exchange_bind', chan)
+ assert 'exchange_bind' in chan
def test_unbind(self):
chan = get_conn().channel()
x = binding(Exchange('foo'))
x.unbind(Exchange('bar')(chan))
- self.assertIn('exchange_unbind', chan)
+ assert 'exchange_unbind' in chan
def test_repr(self):
b = binding(Exchange('foo'), 'rkey')
- self.assertIn('foo', repr(b))
- self.assertIn('rkey', repr(b))
+ assert 'foo' in repr(b)
+ assert 'rkey' in repr(b)
-class test_Exchange(Case):
+class test_Exchange:
def test_bound(self):
exchange = Exchange('foo', 'direct')
- self.assertFalse(exchange.is_bound)
- self.assertIn('<unbound', repr(exchange))
+ assert not exchange.is_bound
+ assert '<unbound' in repr(exchange)
chan = get_conn().channel()
bound = exchange.bind(chan)
- self.assertTrue(bound.is_bound)
- self.assertIs(bound.channel, chan)
- self.assertIn('bound to chan:%r' % (chan.channel_id,),
- repr(bound))
+ assert bound.is_bound
+ assert bound.channel is chan
+ assert 'bound to chan:%r' % (chan.channel_id,) in repr(bound)
def test_hash(self):
- self.assertEqual(hash(Exchange('a')), hash(Exchange('a')))
- self.assertNotEqual(hash(Exchange('a')), hash(Exchange('b')))
+ assert hash(Exchange('a')) == hash(Exchange('a'))
+ assert hash(Exchange('a')) != hash(Exchange('b'))
def test_can_cache_declaration(self):
- self.assertTrue(Exchange('a', durable=True).can_cache_declaration)
- self.assertTrue(Exchange('a', durable=False).can_cache_declaration)
- self.assertFalse(Exchange('a', auto_delete=True).can_cache_declaration)
- self.assertFalse(
- Exchange(
- 'a', durable=True, auto_delete=True
- ).can_cache_declaration,
- )
+ assert Exchange('a', durable=True).can_cache_declaration
+ assert Exchange('a', durable=False).can_cache_declaration
+ assert not Exchange('a', auto_delete=True).can_cache_declaration
+ assert not Exchange(
+ 'a', durable=True, auto_delete=True,
+ ).can_cache_declaration
def test_pickle(self):
e1 = Exchange('foo', 'direct')
e2 = pickle.loads(pickle.dumps(e1))
- self.assertEqual(e1, e2)
+ assert e1 == e2
def test_eq(self):
e1 = Exchange('foo', 'direct')
e2 = Exchange('foo', 'direct')
- self.assertEqual(e1, e2)
+ assert e1 == e2
e3 = Exchange('foo', 'topic')
- self.assertNotEqual(e1, e3)
+ assert e1 != e3
- self.assertEqual(e1.__eq__(True), NotImplemented)
+ assert e1.__eq__(True) == NotImplemented
def test_revive(self):
exchange = Exchange('foo', 'direct')
@@ -108,121 +107,121 @@ class test_Exchange(Case):
# reviving unbound channel is a noop.
exchange.revive(chan)
- self.assertFalse(exchange.is_bound)
- self.assertIsNone(exchange._channel)
+ assert not exchange.is_bound
+ assert exchange._channel is None
bound = exchange.bind(chan)
- self.assertTrue(bound.is_bound)
- self.assertIs(bound.channel, chan)
+ assert bound.is_bound
+ assert bound.channel is chan
chan2 = conn.channel()
bound.revive(chan2)
- self.assertTrue(bound.is_bound)
- self.assertIs(bound._channel, chan2)
+ assert bound.is_bound
+ assert bound._channel is chan2
def test_assert_is_bound(self):
exchange = Exchange('foo', 'direct')
- with self.assertRaises(NotBoundError):
+ with pytest.raises(NotBoundError):
exchange.declare()
conn = get_conn()
chan = conn.channel()
exchange.bind(chan).declare()
- self.assertIn('exchange_declare', chan)
+ assert 'exchange_declare' in chan
def test_set_transient_delivery_mode(self):
exc = Exchange('foo', 'direct', delivery_mode='transient')
- self.assertEqual(exc.delivery_mode, Exchange.TRANSIENT_DELIVERY_MODE)
+ assert exc.delivery_mode == Exchange.TRANSIENT_DELIVERY_MODE
def test_set_passive_mode(self):
exc = Exchange('foo', 'direct', passive=True)
- self.assertTrue(exc.passive)
+ assert exc.passive
def test_set_persistent_delivery_mode(self):
exc = Exchange('foo', 'direct', delivery_mode='persistent')
- self.assertEqual(exc.delivery_mode, Exchange.PERSISTENT_DELIVERY_MODE)
+ assert exc.delivery_mode == Exchange.PERSISTENT_DELIVERY_MODE
def test_bind_at_instantiation(self):
- self.assertTrue(Exchange('foo', channel=get_conn().channel()).is_bound)
+ assert Exchange('foo', channel=get_conn().channel()).is_bound
def test_create_message(self):
chan = get_conn().channel()
Exchange('foo', channel=chan).Message({'foo': 'bar'})
- self.assertIn('prepare_message', chan)
+ assert 'prepare_message' in chan
def test_publish(self):
chan = get_conn().channel()
Exchange('foo', channel=chan).publish('the quick brown fox')
- self.assertIn('basic_publish', chan)
+ assert 'basic_publish' in chan
def test_delete(self):
chan = get_conn().channel()
Exchange('foo', channel=chan).delete()
- self.assertIn('exchange_delete', chan)
+ assert 'exchange_delete' in chan
def test__repr__(self):
b = Exchange('foo', 'topic')
- self.assertIn('foo(topic)', repr(b))
- self.assertIn('Exchange', repr(b))
+ assert 'foo(topic)' in repr(b)
+ assert 'Exchange' in repr(b)
def test_bind_to(self):
chan = get_conn().channel()
foo = Exchange('foo', 'topic')
bar = Exchange('bar', 'topic')
foo(chan).bind_to(bar)
- self.assertIn('exchange_bind', chan)
+ assert 'exchange_bind' in chan
def test_bind_to_by_name(self):
chan = get_conn().channel()
foo = Exchange('foo', 'topic')
foo(chan).bind_to('bar')
- self.assertIn('exchange_bind', chan)
+ assert 'exchange_bind' in chan
def test_unbind_from(self):
chan = get_conn().channel()
foo = Exchange('foo', 'topic')
bar = Exchange('bar', 'topic')
foo(chan).unbind_from(bar)
- self.assertIn('exchange_unbind', chan)
+ assert 'exchange_unbind' in chan
def test_unbind_from_by_name(self):
chan = get_conn().channel()
foo = Exchange('foo', 'topic')
foo(chan).unbind_from('bar')
- self.assertIn('exchange_unbind', chan)
+ assert 'exchange_unbind' in chan
def test_declare__no_declare(self):
chan = get_conn().channel()
foo = Exchange('foo', 'topic', no_declare=True)
foo(chan).declare()
- self.assertNotIn('exchange_declare', chan)
+ assert 'exchange_declare' not in chan
def test_declare__internal_exchange(self):
chan = get_conn().channel()
foo = Exchange('amq.rabbitmq.trace', 'topic')
foo(chan).declare()
- self.assertNotIn('exchange_declare', chan)
+ assert 'exchange_declare' not in chan
def test_declare(self):
chan = get_conn().channel()
foo = Exchange('foo', 'topic', no_declare=False)
foo(chan).declare()
- self.assertIn('exchange_declare', chan)
+ assert 'exchange_declare' in chan
-class test_Queue(Case):
+class test_Queue:
def setup(self):
self.exchange = Exchange('foo', 'direct')
def test_hash(self):
- self.assertEqual(hash(Queue('a')), hash(Queue('a')))
- self.assertNotEqual(hash(Queue('a')), hash(Queue('b')))
+ assert hash(Queue('a')) == hash(Queue('a'))
+ assert hash(Queue('a')) != hash(Queue('b'))
def test_repr_with_bindings(self):
ex = Exchange('foo')
x = Queue('foo', bindings=[ex.binding('A'), ex.binding('B')])
- self.assertTrue(repr(x))
+ assert repr(x)
def test_anonymous(self):
chan = Mock()
@@ -230,7 +229,7 @@ class test_Queue(Case):
chan.queue_declare.return_value = 'generated', 0, 0
xx = x(chan)
xx.declare()
- self.assertEqual(xx.name, 'generated')
+ assert xx.name == 'generated'
def test_basic_get__accept_disallowed(self):
conn = Connection('memory://')
@@ -242,9 +241,9 @@ class test_Queue(Case):
)
message = q(conn).get(no_ack=True)
- self.assertIsNotNone(message)
+ assert message is not None
- with self.assertRaises(q.ContentDisallowed):
+ with pytest.raises(q.ContentDisallowed):
message.decode()
def test_basic_get__accept_allowed(self):
@@ -257,15 +256,15 @@ class test_Queue(Case):
)
message = q(conn).get(accept=['pickle'], no_ack=True)
- self.assertIsNotNone(message)
+ assert message is not None
payload = message.decode()
- self.assertTrue(payload['complex'])
+ assert payload['complex']
def test_when_bound_but_no_exchange(self):
q = Queue('a')
q.exchange = None
- self.assertIsNone(q.when_bound())
+ assert q.when_bound() is None
def test_declare_but_no_exchange(self):
q = Queue('a')
@@ -296,7 +295,7 @@ class test_Queue(Case):
chan = Mock()
q = Queue('a')(chan)
chan.message_to_python = None
- self.assertTrue(q.get())
+ assert q.get()
def test_multiple_bindings(self):
chan = Mock()
@@ -306,110 +305,105 @@ class test_Queue(Case):
binding(Exchange('mul3'), 'rkey3'),
])
q(chan).declare()
- self.assertIn(
- call(
- nowait=False,
- exchange='mul1',
- auto_delete=False,
- passive=False,
- arguments=None,
- type='direct',
- durable=True,
- ),
- chan.exchange_declare.call_args_list,
- )
+ assert call(
+ nowait=False,
+ exchange='mul1',
+ auto_delete=False,
+ passive=False,
+ arguments=None,
+ type='direct',
+ durable=True,
+ ) in chan.exchange_declare.call_args_list
def test_can_cache_declaration(self):
- self.assertTrue(Queue('a', durable=True).can_cache_declaration)
- self.assertTrue(Queue('a', durable=False).can_cache_declaration)
+ assert Queue('a', durable=True).can_cache_declaration
+ assert Queue('a', durable=False).can_cache_declaration
def test_eq(self):
q1 = Queue('xxx', Exchange('xxx', 'direct'), 'xxx')
q2 = Queue('xxx', Exchange('xxx', 'direct'), 'xxx')
- self.assertEqual(q1, q2)
- self.assertEqual(q1.__eq__(True), NotImplemented)
+ assert q1 == q2
+ assert q1.__eq__(True) == NotImplemented
q3 = Queue('yyy', Exchange('xxx', 'direct'), 'xxx')
- self.assertNotEqual(q1, q3)
+ assert q1 != q3
def test_exclusive_implies_auto_delete(self):
- self.assertTrue(
- Queue('foo', self.exchange, exclusive=True).auto_delete,
- )
+ assert Queue('foo', self.exchange, exclusive=True).auto_delete
def test_binds_at_instantiation(self):
- self.assertTrue(Queue('foo', self.exchange,
- channel=get_conn().channel()).is_bound)
+ assert Queue('foo', self.exchange,
+ channel=get_conn().channel()).is_bound
def test_also_binds_exchange(self):
chan = get_conn().channel()
b = Queue('foo', self.exchange)
- self.assertFalse(b.is_bound)
- self.assertFalse(b.exchange.is_bound)
+ assert not b.is_bound
+ assert not b.exchange.is_bound
b = b.bind(chan)
- self.assertTrue(b.is_bound)
- self.assertTrue(b.exchange.is_bound)
- self.assertIs(b.channel, b.exchange.channel)
- self.assertIsNot(b.exchange, self.exchange)
+ assert b.is_bound
+ assert b.exchange.is_bound
+ assert b.channel is b.exchange.channel
+ assert b.exchange is not self.exchange
def test_declare(self):
chan = get_conn().channel()
b = Queue('foo', self.exchange, 'foo', channel=chan)
- self.assertTrue(b.is_bound)
+ assert b.is_bound
b.declare()
- self.assertIn('exchange_declare', chan)
- self.assertIn('queue_declare', chan)
- self.assertIn('queue_bind', chan)
+ assert 'exchange_declare' in chan
+ assert 'queue_declare' in chan
+ assert 'queue_bind' in chan
def test_get(self):
b = Queue('foo', self.exchange, 'foo', channel=get_conn().channel())
b.get()
- self.assertIn('basic_get', b.channel)
+ assert 'basic_get' in b.channel
def test_purge(self):
b = Queue('foo', self.exchange, 'foo', channel=get_conn().channel())
b.purge()
- self.assertIn('queue_purge', b.channel)
+ assert 'queue_purge' in b.channel
def test_consume(self):
b = Queue('foo', self.exchange, 'foo', channel=get_conn().channel())
b.consume('fifafo', None)
- self.assertIn('basic_consume', b.channel)
+ assert 'basic_consume' in b.channel
def test_cancel(self):
b = Queue('foo', self.exchange, 'foo', channel=get_conn().channel())
b.cancel('fifafo')
- self.assertIn('basic_cancel', b.channel)
+ assert 'basic_cancel' in b.channel
def test_delete(self):
b = Queue('foo', self.exchange, 'foo', channel=get_conn().channel())
b.delete()
- self.assertIn('queue_delete', b.channel)
+ assert 'queue_delete' in b.channel
def test_queue_unbind(self):
b = Queue('foo', self.exchange, 'foo', channel=get_conn().channel())
b.queue_unbind()
- self.assertIn('queue_unbind', b.channel)
+ assert 'queue_unbind' in b.channel
def test_as_dict(self):
q = Queue('foo', self.exchange, 'rk')
d = q.as_dict(recurse=True)
- self.assertEqual(d['exchange']['name'], self.exchange.name)
+ assert d['exchange']['name'] == self.exchange.name
def test_queue_dump(self):
b = binding(self.exchange, 'rk')
q = Queue('foo', self.exchange, 'rk', bindings=[b])
d = q.as_dict(recurse=True)
- self.assertEqual(d['bindings'][0]['routing_key'], 'rk')
+ assert d['bindings'][0]['routing_key'] == 'rk'
registry.dumps(d)
def test__repr__(self):
b = Queue('foo', self.exchange, 'foo')
- self.assertIn('foo', repr(b))
- self.assertIn('Queue', repr(b))
+ assert 'foo' in repr(b)
+ assert 'Queue' in repr(b)
-class test_MaybeChannelBound(Case):
+class test_MaybeChannelBound:
def test_repr(self):
- self.assertTrue(repr(MaybeChannelBound()))
+ assert repr(MaybeChannelBound())
diff --git a/t/unit/test_exceptions.py b/t/unit/test_exceptions.py
new file mode 100644
index 00000000..f72f3d6d
--- /dev/null
+++ b/t/unit/test_exceptions.py
@@ -0,0 +1,11 @@
+from __future__ import absolute_import, unicode_literals
+
+from case import Mock
+
+from kombu.exceptions import HttpError
+
+
+class test_HttpError:
+
+ def test_str(self):
+ assert str(HttpError(200, 'msg', Mock(name='response')))
diff --git a/kombu/tests/test_log.py b/t/unit/test_log.py
index 08dc299e..5946755b 100644
--- a/kombu/tests/test_log.py
+++ b/t/unit/test_log.py
@@ -3,6 +3,8 @@ from __future__ import absolute_import, unicode_literals
import logging
import sys
+from case import ANY, Mock, patch
+
from kombu.log import (
get_logger,
get_loglevel,
@@ -12,22 +14,20 @@ from kombu.log import (
setup_logging,
)
-from .case import ANY, Case, Mock, patch
-
-class test_get_logger(Case):
+class test_get_logger:
def test_when_string(self):
l = get_logger('foo')
- self.assertIs(l, logging.getLogger('foo'))
+ assert l is logging.getLogger('foo')
h1 = l.handlers[0]
- self.assertIsInstance(h1, logging.NullHandler)
+ assert isinstance(h1, logging.NullHandler)
def test_when_logger(self):
l = get_logger(logging.getLogger('foo'))
h1 = l.handlers[0]
- self.assertIsInstance(h1, logging.NullHandler)
+ assert isinstance(h1, logging.NullHandler)
def test_with_custom_handler(self):
l = logging.getLogger('bar')
@@ -35,28 +35,23 @@ class test_get_logger(Case):
l.addHandler(handler)
l = get_logger('bar')
- self.assertIs(l.handlers[0], handler)
+ assert l.handlers[0] is handler
def test_get_loglevel(self):
- self.assertEqual(get_loglevel('DEBUG'), logging.DEBUG)
- self.assertEqual(get_loglevel('ERROR'), logging.ERROR)
- self.assertEqual(get_loglevel(logging.INFO), logging.INFO)
+ assert get_loglevel('DEBUG') == logging.DEBUG
+ assert get_loglevel('ERROR') == logging.ERROR
+ assert get_loglevel(logging.INFO) == logging.INFO
-class test_safe_format(Case):
+def test_safe_format():
+ fmt = 'The %r jumped %x over the %s'
+ args = ['frog', 'foo', 'elephant']
- def test_formatting(self):
- fmt = 'The %r jumped %x over the %s'
- args = ['frog', 'foo', 'elephant']
+ res = list(safeify_format(fmt, args))
+ assert [x.strip('u') for x in res] == ["'frog'", 'foo', 'elephant']
- res = list(safeify_format(fmt, args))
- self.assertListEqual(
- [x.strip('u') for x in res],
- ["'frog'", 'foo', 'elephant'],
- )
-
-class test_LogMixin(Case):
+class test_LogMixin:
def setup(self):
self.log = Log('Log', Mock())
@@ -96,22 +91,20 @@ class test_LogMixin(Case):
log.DISABLE_TRACEBACKS = False
def test_get_loglevel(self):
- self.assertEqual(self.log.get_loglevel('DEBUG'), logging.DEBUG)
- self.assertEqual(self.log.get_loglevel('ERROR'), logging.ERROR)
- self.assertEqual(self.log.get_loglevel(logging.INFO), logging.INFO)
+ assert self.log.get_loglevel('DEBUG') == logging.DEBUG
+ assert self.log.get_loglevel('ERROR') == logging.ERROR
+ assert self.log.get_loglevel(logging.INFO) == logging.INFO
def test_is_enabled_for(self):
self.logger.isEnabledFor.return_value = True
- self.assertTrue(self.log.is_enabled_for('DEBUG'))
+ assert self.log.is_enabled_for('DEBUG')
self.logger.isEnabledFor.assert_called_with(logging.DEBUG)
def test_LogMixin_get_logger(self):
- self.assertIs(LogMixin().get_logger(),
- logging.getLogger('LogMixin'))
+ assert LogMixin().get_logger() is logging.getLogger('LogMixin')
def test_Log_get_logger(self):
- self.assertIs(Log('test_Log').get_logger(),
- logging.getLogger('test_Log'))
+ assert Log('test_Log').get_logger() is logging.getLogger('test_Log')
def test_log_when_not_enabled(self):
self.logger.isEnabledFor.return_value = False
@@ -123,13 +116,10 @@ class test_LogMixin(Case):
self.logger.log.assert_called_with(
logging.DEBUG, 'Log - Host %s removed', ANY,
)
- self.assertEqual(
- self.logger.log.call_args[0][2].strip('u'),
- "'example.com'",
- )
+ assert self.logger.log.call_args[0][2].strip('u') == "'example.com'"
-class test_setup_logging(Case):
+class test_setup_logging:
@patch('logging.getLogger')
def test_set_up_default_values(self, getLogger):
@@ -141,8 +131,8 @@ class test_setup_logging(Case):
logger.addHandler.assert_called()
ah_args, _ = logger.addHandler.call_args
handler = ah_args[0]
- self.assertIsInstance(handler, logging.StreamHandler)
- self.assertIs(handler.stream, sys.__stderr__)
+ assert isinstance(handler, logging.StreamHandler)
+ assert handler.stream is sys.__stderr__
@patch('logging.getLogger')
@patch('kombu.log.WatchedFileHandler')
diff --git a/kombu/tests/test_message.py b/t/unit/test_message.py
index 320e2421..70d35b00 100644
--- a/kombu/tests/test_message.py
+++ b/t/unit/test_message.py
@@ -1,26 +1,27 @@
from __future__ import absolute_import, unicode_literals
+import pytest
import sys
-from kombu.message import Message
+from case import Mock, patch
-from .case import Case, Mock, patch
+from kombu.message import Message
-class test_Message(Case):
+class test_Message:
def test_repr(self):
- self.assertTrue(repr(Message(Mock(), 'b')))
+ assert repr(Message(Mock(), 'b'))
def test_decode(self):
m = Message(Mock(), 'body')
decode = m._decode = Mock()
- self.assertIsNone(m._decoded_cache)
- self.assertIs(m.decode(), m._decode.return_value)
- self.assertIs(m._decoded_cache, m._decode.return_value)
+ assert m._decoded_cache is None
+ assert m.decode() is m._decode.return_value
+ assert m._decoded_cache is m._decode.return_value
m._decode.assert_called_with()
m._decode = Mock()
- self.assertIs(m.decode(), decode.return_value)
+ assert m.decode() is decode.return_value
def test_reraise_error(self):
m = Message(Mock(), 'body')
@@ -32,12 +33,12 @@ class test_Message(Case):
m._reraise_error(callback)
callback.assert_called()
- with self.assertRaises(KeyError):
+ with pytest.raises(KeyError):
m._reraise_error(None)
@patch('kombu.message.decompress')
def test_decompression_stores_error(self, decompress):
decompress.side_effect = RuntimeError()
m = Message(Mock(), 'body', headers={'compression': 'zlib'})
- with self.assertRaises(RuntimeError):
+ with pytest.raises(RuntimeError):
m._reraise_error(None)
diff --git a/kombu/tests/test_messaging.py b/t/unit/test_messaging.py
index 8f48f7cd..fbc8d72b 100644
--- a/kombu/tests/test_messaging.py
+++ b/t/unit/test_messaging.py
@@ -1,41 +1,43 @@
from __future__ import absolute_import, unicode_literals
import pickle
+import pytest
import sys
from collections import defaultdict
+from case import Mock, patch
+
from kombu import Connection, Consumer, Producer, Exchange, Queue
from kombu.exceptions import MessageStateError
from kombu.utils import json
from kombu.utils.functional import ChannelPromise
-from .case import Case, Mock, patch
-from .mocks import Transport
+from t.mocks import Transport
-class test_Producer(Case):
+class test_Producer:
def setup(self):
self.exchange = Exchange('foo', 'direct')
self.connection = Connection(transport=Transport)
self.connection.connect()
- self.assertTrue(self.connection.connection.connected)
- self.assertFalse(self.exchange.is_bound)
+ assert self.connection.connection.connected
+ assert not self.exchange.is_bound
def test_repr(self):
p = Producer(self.connection)
- self.assertTrue(repr(p))
+ assert repr(p)
def test_pickle(self):
chan = Mock()
producer = Producer(chan, serializer='pickle')
p2 = pickle.loads(pickle.dumps(producer))
- self.assertEqual(p2.serializer, producer.serializer)
+ assert p2.serializer == producer.serializer
def test_no_channel(self):
p = Producer(None)
- self.assertFalse(p._channel)
+ assert not p._channel
@patch('kombu.messaging.maybe_declare')
def test_maybe_declare(self, maybe_declare):
@@ -53,30 +55,30 @@ class test_Producer(Case):
def test_auto_declare(self):
channel = self.connection.channel()
p = Producer(channel, self.exchange, auto_declare=True)
- self.assertIsNot(p.exchange, self.exchange,
- 'creates Exchange clone at bind')
- self.assertTrue(p.exchange.is_bound)
- self.assertIn('exchange_declare', channel,
- 'auto_declare declares exchange')
+ # creates Exchange clone at bind
+ assert p.exchange is not self.exchange
+ assert p.exchange.is_bound
+ # auto_declare declares exchange'
+ assert 'exchange_declare' in channel
def test_manual_declare(self):
channel = self.connection.channel()
p = Producer(channel, self.exchange, auto_declare=False)
- self.assertTrue(p.exchange.is_bound)
- self.assertNotIn('exchange_declare', channel,
- 'auto_declare=False does not declare exchange')
+ assert p.exchange.is_bound
+ # auto_declare=False does not declare exchange
+ assert 'exchange_declare' not in channel
+ # p.declare() declares exchange')
p.declare()
- self.assertIn('exchange_declare', channel,
- 'p.declare() declares exchange')
+ assert 'exchange_declare' in channel
def test_prepare(self):
message = {'the quick brown fox': 'jumps over the lazy dog'}
channel = self.connection.channel()
p = Producer(channel, self.exchange, serializer='json')
m, ctype, cencoding = p._prepare(message, headers={})
- self.assertDictEqual(message, json.loads(m))
- self.assertEqual(ctype, 'application/json')
- self.assertEqual(cencoding, 'utf-8')
+ assert json.loads(m) == message
+ assert ctype == 'application/json'
+ assert cencoding == 'utf-8'
def test_prepare_compression(self):
message = {'the quick brown fox': 'jumps over the lazy dog'}
@@ -85,64 +87,59 @@ class test_Producer(Case):
headers = {}
m, ctype, cencoding = p._prepare(message, compression='zlib',
headers=headers)
- self.assertEqual(ctype, 'application/json')
- self.assertEqual(cencoding, 'utf-8')
- self.assertEqual(headers['compression'], 'application/x-gzip')
+ assert ctype == 'application/json'
+ assert cencoding == 'utf-8'
+ assert headers['compression'] == 'application/x-gzip'
import zlib
- self.assertEqual(
- json.loads(zlib.decompress(m).decode('utf-8')),
- message,
- )
+ assert json.loads(zlib.decompress(m).decode('utf-8')) == message
def test_prepare_custom_content_type(self):
message = 'the quick brown fox'.encode('utf-8')
channel = self.connection.channel()
p = Producer(channel, self.exchange, serializer='json')
m, ctype, cencoding = p._prepare(message, content_type='custom')
- self.assertEqual(m, message)
- self.assertEqual(ctype, 'custom')
- self.assertEqual(cencoding, 'binary')
+ assert m == message
+ assert ctype == 'custom'
+ assert cencoding == 'binary'
m, ctype, cencoding = p._prepare(message, content_type='custom',
content_encoding='alien')
- self.assertEqual(m, message)
- self.assertEqual(ctype, 'custom')
- self.assertEqual(cencoding, 'alien')
+ assert m == message
+ assert ctype == 'custom'
+ assert cencoding == 'alien'
def test_prepare_is_already_unicode(self):
message = 'the quick brown fox'
channel = self.connection.channel()
p = Producer(channel, self.exchange, serializer='json')
m, ctype, cencoding = p._prepare(message, content_type='text/plain')
- self.assertEqual(m, message.encode('utf-8'))
- self.assertEqual(ctype, 'text/plain')
- self.assertEqual(cencoding, 'utf-8')
+ assert m == message.encode('utf-8')
+ assert ctype == 'text/plain'
+ assert cencoding == 'utf-8'
m, ctype, cencoding = p._prepare(message, content_type='text/plain',
content_encoding='utf-8')
- self.assertEqual(m, message.encode('utf-8'))
- self.assertEqual(ctype, 'text/plain')
- self.assertEqual(cencoding, 'utf-8')
+ assert m == message.encode('utf-8')
+ assert ctype == 'text/plain'
+ assert cencoding == 'utf-8'
def test_publish_with_Exchange_instance(self):
p = self.connection.Producer()
p.channel = Mock()
p.publish('hello', exchange=Exchange('foo'), delivery_mode='transient')
- self.assertEqual(
- p._channel.basic_publish.call_args[1]['exchange'], 'foo',
- )
+ assert p._channel.basic_publish.call_args[1]['exchange'] == 'foo'
def test_publish_with_expiration(self):
p = self.connection.Producer()
p.channel = Mock()
p.publish('hello', exchange=Exchange('foo'), expiration=10)
properties = p._channel.prepare_message.call_args[0][5]
- self.assertEqual(properties['expiration'], '10000')
+ assert properties['expiration'] == '10000'
def test_publish_with_reply_to(self):
p = self.connection.Producer()
p.channel = Mock()
p.publish('hello', exchange=Exchange('foo'), reply_to=Queue('foo'))
properties = p._channel.prepare_message.call_args[0][5]
- self.assertEqual(properties['reply_to'], 'foo')
+ assert properties['reply_to'] == 'foo'
def test_set_on_return(self):
chan = Mock()
@@ -173,14 +170,14 @@ class test_Producer(Case):
defchan = new_conn.default_channel
p.revive(new_conn)
- self.assertIs(p.channel, defchan)
+ assert p.channel is defchan
p.exchange.revive.assert_called_with(defchan)
def test_enter_exit(self):
p = self.connection.Producer()
p.release = Mock()
- self.assertIs(p.__enter__(), p)
+ assert p.__enter__() is p
p.__exit__()
p.release.assert_called_with()
@@ -188,37 +185,37 @@ class test_Producer(Case):
p = self.connection.Producer()
p.channel = object()
p.__connection__ = None
- self.assertIsNone(p.connection)
+ assert p.connection is None
def test_publish(self):
channel = self.connection.channel()
p = Producer(channel, self.exchange, serializer='json')
message = {'the quick brown fox': 'jumps over the lazy dog'}
ret = p.publish(message, routing_key='process')
- self.assertIn('prepare_message', channel)
- self.assertIn('basic_publish', channel)
+ assert 'prepare_message' in channel
+ assert 'basic_publish' in channel
m, exc, rkey = ret
- self.assertDictEqual(message, json.loads(m['body']))
- self.assertDictContainsSubset({'content_type': 'application/json',
- 'content_encoding': 'utf-8',
- 'priority': 0}, m)
- self.assertDictContainsSubset({'delivery_mode': 2}, m['properties'])
- self.assertEqual(exc, p.exchange.name)
- self.assertEqual(rkey, 'process')
+ assert json.loads(m['body']) == message
+ assert m['content_type'] == 'application/json'
+ assert m['content_encoding'] == 'utf-8'
+ assert m['priority'] == 0
+ assert m['properties']['delivery_mode'] == 2
+ assert exc == p.exchange.name
+ assert rkey == 'process'
def test_no_exchange(self):
chan = self.connection.channel()
p = Producer(chan)
- self.assertFalse(p.exchange.name)
+ assert not p.exchange.name
def test_revive(self):
chan = self.connection.channel()
p = Producer(chan)
chan2 = self.connection.channel()
p.revive(chan2)
- self.assertIs(p.channel, chan2)
- self.assertIs(p.exchange.channel, chan2)
+ assert p.channel is chan2
+ assert p.exchange.channel is chan2
def test_on_return(self):
chan = self.connection.channel()
@@ -227,28 +224,27 @@ class test_Producer(Case):
pass
p = Producer(chan, on_return=on_return)
- self.assertIn(on_return, chan.events['basic_return'])
- self.assertTrue(p.on_return)
+ assert on_return in chan.events['basic_return']
+ assert p.on_return
-class test_Consumer(Case):
+class test_Consumer:
def setup(self):
self.connection = Connection(transport=Transport)
self.connection.connect()
- self.assertTrue(self.connection.connection.connected)
+ assert self.connection.connection.connected
self.exchange = Exchange('foo', 'direct')
def test_accept(self):
a = Consumer(self.connection)
- self.assertIsNone(a.accept)
+ assert a.accept is None
b = Consumer(self.connection, accept=['json', 'pickle'])
- self.assertSetEqual(
- b.accept,
- {'application/json', 'application/x-python-serialize'},
- )
+ assert b.accept == {
+ 'application/json', 'application/x-python-serialize',
+ }
c = Consumer(self.connection, accept=b.accept)
- self.assertSetEqual(b.accept, c.accept)
+ assert b.accept == c.accept
def test_enter_exit_cancel_raises(self):
c = Consumer(self.connection)
@@ -269,7 +265,7 @@ class test_Consumer(Case):
c._receive_callback(message)
callback.assert_called_with(message)
- self.assertSetEqual(message.accept, c.accept)
+ assert message.accept == c.accept
def test_accept__content_disallowed(self):
conn = Connection('memory://')
@@ -282,7 +278,7 @@ class test_Consumer(Case):
callback = Mock(name='callback')
with conn.Consumer(queues=[q], callbacks=[callback]) as consumer:
- with self.assertRaises(consumer.ContentDisallowed):
+ with pytest.raises(consumer.ContentDisallowed):
conn.drain_events(timeout=1)
callback.assert_not_called()
@@ -301,26 +297,26 @@ class test_Consumer(Case):
conn.drain_events(timeout=1)
callback.assert_called()
body, message = callback.call_args[0]
- self.assertTrue(body['complex'])
+ assert body['complex']
def test_set_no_channel(self):
c = Consumer(None)
- self.assertIsNone(c.channel)
+ assert c.channel is None
c.revive(Mock())
- self.assertTrue(c.channel)
+ assert c.channel
def test_set_no_ack(self):
channel = self.connection.channel()
queue = Queue('qname', self.exchange, 'rkey')
consumer = Consumer(channel, queue, auto_declare=True, no_ack=True)
- self.assertTrue(consumer.no_ack)
+ assert consumer.no_ack
def test_add_queue_when_auto_declare(self):
consumer = self.connection.Consumer(auto_declare=True)
q = Mock()
q.return_value = q
consumer.add_queue(q)
- self.assertIn(q, consumer.queues)
+ assert q in consumer.queues
q.declare.assert_called_with()
def test_add_queue_when_not_auto_declare(self):
@@ -328,26 +324,26 @@ class test_Consumer(Case):
q = Mock()
q.return_value = q
consumer.add_queue(q)
- self.assertIn(q, consumer.queues)
- self.assertFalse(q.declare.call_count)
+ assert q in consumer.queues
+ assert not q.declare.call_count
def test_consume_without_queues_returns(self):
consumer = self.connection.Consumer()
consumer.queues[:] = []
- self.assertIsNone(consumer.consume())
+ assert consumer.consume() is None
def test_consuming_from(self):
consumer = self.connection.Consumer()
consumer.queues[:] = [Queue('a'), Queue('b'), Queue('d')]
consumer._active_tags = {'a': 1, 'b': 2}
- self.assertFalse(consumer.consuming_from(Queue('c')))
- self.assertFalse(consumer.consuming_from('c'))
- self.assertFalse(consumer.consuming_from(Queue('d')))
- self.assertFalse(consumer.consuming_from('d'))
- self.assertTrue(consumer.consuming_from(Queue('a')))
- self.assertTrue(consumer.consuming_from(Queue('b')))
- self.assertTrue(consumer.consuming_from('b'))
+ assert not consumer.consuming_from(Queue('c'))
+ assert not consumer.consuming_from('c')
+ assert not consumer.consuming_from(Queue('d'))
+ assert not consumer.consuming_from('d')
+ assert consumer.consuming_from(Queue('a'))
+ assert consumer.consuming_from(Queue('b'))
+ assert consumer.consuming_from('b')
def test_receive_callback_without_m2p(self):
channel = self.connection.channel()
@@ -374,7 +370,7 @@ class test_Consumer(Case):
except KeyError:
message.errors = [sys.exc_info()]
message._reraise_error.side_effect = KeyError()
- with self.assertRaises(KeyError):
+ with pytest.raises(KeyError):
c._receive_callback(message)
def test_set_callbacks(self):
@@ -384,7 +380,7 @@ class test_Consumer(Case):
lambda x, y: x]
consumer = Consumer(channel, queue, auto_declare=True,
callbacks=callbacks)
- self.assertEqual(consumer.callbacks, callbacks)
+ assert consumer.callbacks == callbacks
def test_auto_declare(self):
channel = self.connection.channel()
@@ -392,22 +388,22 @@ class test_Consumer(Case):
consumer = Consumer(channel, queue, auto_declare=True)
consumer.consume()
consumer.consume() # twice is a noop
- self.assertIsNot(consumer.queues[0], queue)
- self.assertTrue(consumer.queues[0].is_bound)
- self.assertTrue(consumer.queues[0].exchange.is_bound)
- self.assertIsNot(consumer.queues[0].exchange, self.exchange)
+ assert consumer.queues[0] is not queue
+ assert consumer.queues[0].is_bound
+ assert consumer.queues[0].exchange.is_bound
+ assert consumer.queues[0].exchange is not self.exchange
for meth in ('exchange_declare',
'queue_declare',
'queue_bind',
'basic_consume'):
- self.assertIn(meth, channel)
- self.assertEqual(channel.called.count('basic_consume'), 1)
- self.assertTrue(consumer._active_tags)
+ assert meth in channel
+ assert channel.called.count('basic_consume') == 1
+ assert consumer._active_tags
consumer.cancel_by_queue(queue.name)
consumer.cancel_by_queue(queue.name)
- self.assertFalse(consumer._active_tags)
+ assert not consumer._active_tags
def test_consumer_tag_prefix(self):
channel = self.connection.channel()
@@ -415,33 +411,31 @@ class test_Consumer(Case):
consumer = Consumer(channel, queue, tag_prefix='consumer_')
consumer.consume()
- self.assertTrue(
- consumer._active_tags[queue.name].startswith('consumer_'),
- )
+ assert consumer._active_tags[queue.name].startswith('consumer_')
def test_manual_declare(self):
channel = self.connection.channel()
queue = Queue('qname', self.exchange, 'rkey')
consumer = Consumer(channel, queue, auto_declare=False)
- self.assertIsNot(consumer.queues[0], queue)
- self.assertTrue(consumer.queues[0].is_bound)
- self.assertTrue(consumer.queues[0].exchange.is_bound)
- self.assertIsNot(consumer.queues[0].exchange, self.exchange)
+ assert consumer.queues[0] is not queue
+ assert consumer.queues[0].is_bound
+ assert consumer.queues[0].exchange.is_bound
+ assert consumer.queues[0].exchange is not self.exchange
for meth in ('exchange_declare',
'queue_declare',
'basic_consume'):
- self.assertNotIn(meth, channel)
+ assert meth not in channel
consumer.declare()
for meth in ('exchange_declare',
'queue_declare',
'queue_bind'):
- self.assertIn(meth, channel)
- self.assertNotIn('basic_consume', channel)
+ assert meth in channel
+ assert 'basic_consume' not in channel
consumer.consume()
- self.assertIn('basic_consume', channel)
+ assert 'basic_consume' in channel
def test_consume__cancel(self):
channel = self.connection.channel()
@@ -449,34 +443,34 @@ class test_Consumer(Case):
consumer = Consumer(channel, queue, auto_declare=True)
consumer.consume()
consumer.cancel()
- self.assertIn('basic_cancel', channel)
- self.assertFalse(consumer._active_tags)
+ assert 'basic_cancel' in channel
+ assert not consumer._active_tags
def test___enter____exit__(self):
channel = self.connection.channel()
queue = Queue('qname', self.exchange, 'rkey')
consumer = Consumer(channel, queue, auto_declare=True)
context = consumer.__enter__()
- self.assertIs(context, consumer)
- self.assertTrue(consumer._active_tags)
+ assert context is consumer
+ assert consumer._active_tags
res = consumer.__exit__(None, None, None)
- self.assertFalse(res)
- self.assertIn('basic_cancel', channel)
- self.assertFalse(consumer._active_tags)
+ assert not res
+ assert 'basic_cancel' in channel
+ assert not consumer._active_tags
def test_flow(self):
channel = self.connection.channel()
queue = Queue('qname', self.exchange, 'rkey')
consumer = Consumer(channel, queue, auto_declare=True)
consumer.flow(False)
- self.assertIn('flow', channel)
+ assert 'flow' in channel
def test_qos(self):
channel = self.connection.channel()
queue = Queue('qname', self.exchange, 'rkey')
consumer = Consumer(channel, queue, auto_declare=True)
consumer.qos(30, 10, False)
- self.assertIn('basic_qos', channel)
+ assert 'basic_qos' in channel
def test_purge(self):
channel = self.connection.channel()
@@ -486,7 +480,7 @@ class test_Consumer(Case):
b4 = Queue('qname4', self.exchange, 'rkey')
consumer = Consumer(channel, [b1, b2, b3, b4], auto_declare=True)
consumer.purge()
- self.assertEqual(channel.called.count('queue_purge'), 4)
+ assert channel.called.count('queue_purge') == 4
def test_multiple_queues(self):
channel = self.connection.channel()
@@ -496,14 +490,14 @@ class test_Consumer(Case):
b4 = Queue('qname4', self.exchange, 'rkey')
consumer = Consumer(channel, [b1, b2, b3, b4])
consumer.consume()
- self.assertEqual(channel.called.count('exchange_declare'), 4)
- self.assertEqual(channel.called.count('queue_declare'), 4)
- self.assertEqual(channel.called.count('queue_bind'), 4)
- self.assertEqual(channel.called.count('basic_consume'), 4)
- self.assertEqual(len(consumer._active_tags), 4)
+ assert channel.called.count('exchange_declare') == 4
+ assert channel.called.count('queue_declare') == 4
+ assert channel.called.count('queue_bind') == 4
+ assert channel.called.count('basic_consume') == 4
+ assert len(consumer._active_tags) == 4
consumer.cancel()
- self.assertEqual(channel.called.count('basic_cancel'), 4)
- self.assertFalse(len(consumer._active_tags))
+ assert channel.called.count('basic_cancel') == 4
+ assert not len(consumer._active_tags)
def test_receive_callback(self):
channel = self.connection.channel()
@@ -519,9 +513,9 @@ class test_Consumer(Case):
consumer.register_callback(callback)
consumer._receive_callback({'foo': 'bar'})
- self.assertIn('basic_ack', channel)
- self.assertIn('message_to_python', channel)
- self.assertEqual(received[0], {'foo': 'bar'})
+ assert 'basic_ack' in channel
+ assert 'message_to_python' in channel
+ assert received[0] == {'foo': 'bar'}
def test_basic_ack_twice(self):
channel = self.connection.channel()
@@ -533,7 +527,7 @@ class test_Consumer(Case):
message.ack()
consumer.register_callback(callback)
- with self.assertRaises(MessageStateError):
+ with pytest.raises(MessageStateError):
consumer._receive_callback({'foo': 'bar'})
def test_basic_reject(self):
@@ -546,7 +540,7 @@ class test_Consumer(Case):
consumer.register_callback(callback)
consumer._receive_callback({'foo': 'bar'})
- self.assertIn('basic_reject', channel)
+ assert 'basic_reject' in channel
def test_basic_reject_twice(self):
channel = self.connection.channel()
@@ -558,9 +552,9 @@ class test_Consumer(Case):
message.reject()
consumer.register_callback(callback)
- with self.assertRaises(MessageStateError):
+ with pytest.raises(MessageStateError):
consumer._receive_callback({'foo': 'bar'})
- self.assertIn('basic_reject', channel)
+ assert 'basic_reject' in channel
def test_basic_reject__requeue(self):
channel = self.connection.channel()
@@ -572,7 +566,7 @@ class test_Consumer(Case):
consumer.register_callback(callback)
consumer._receive_callback({'foo': 'bar'})
- self.assertIn('basic_reject:requeue', channel)
+ assert 'basic_reject:requeue' in channel
def test_basic_reject__requeue_twice(self):
channel = self.connection.channel()
@@ -584,15 +578,15 @@ class test_Consumer(Case):
message.requeue()
consumer.register_callback(callback)
- with self.assertRaises(MessageStateError):
+ with pytest.raises(MessageStateError):
consumer._receive_callback({'foo': 'bar'})
- self.assertIn('basic_reject:requeue', channel)
+ assert 'basic_reject:requeue' in channel
def test_receive_without_callbacks_raises(self):
channel = self.connection.channel()
b1 = Queue('qname1', self.exchange, 'rkey')
consumer = Consumer(channel, [b1])
- with self.assertRaises(NotImplementedError):
+ with pytest.raises(NotImplementedError):
consumer.receive(1, 2)
def test_decode_error(self):
@@ -601,7 +595,7 @@ class test_Consumer(Case):
consumer = Consumer(channel, [b1])
consumer.channel.throw_decode_error = True
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
consumer._receive_callback({'foo': 'bar'})
def test_on_decode_error_callback(self):
@@ -616,17 +610,17 @@ class test_Consumer(Case):
consumer.channel.throw_decode_error = True
consumer._receive_callback({'foo': 'bar'})
- self.assertTrue(thrown)
+ assert thrown
m, exc = thrown[0]
- self.assertEqual(json.loads(m), {'foo': 'bar'})
- self.assertIsInstance(exc, ValueError)
+ assert json.loads(m) == {'foo': 'bar'}
+ assert isinstance(exc, ValueError)
def test_recover(self):
channel = self.connection.channel()
b1 = Queue('qname1', self.exchange, 'rkey')
consumer = Consumer(channel, [b1])
consumer.recover()
- self.assertIn('basic_recover', channel)
+ assert 'basic_recover' in channel
def test_revive(self):
channel = self.connection.channel()
@@ -634,9 +628,9 @@ class test_Consumer(Case):
consumer = Consumer(channel, [b1])
channel2 = self.connection.channel()
consumer.revive(channel2)
- self.assertIs(consumer.channel, channel2)
- self.assertIs(consumer.queues[0].channel, channel2)
- self.assertIs(consumer.queues[0].exchange.channel, channel2)
+ assert consumer.channel is channel2
+ assert consumer.queues[0].channel is channel2
+ assert consumer.queues[0].exchange.channel is channel2
def test_revive__with_prefetch_count(self):
channel = Mock(name='channel')
@@ -647,9 +641,9 @@ class test_Consumer(Case):
def test__repr__(self):
channel = self.connection.channel()
b1 = Queue('qname1', self.exchange, 'rkey')
- self.assertTrue(repr(Consumer(channel, [b1])))
+ assert repr(Consumer(channel, [b1]))
def test_connection_property_handles_AttributeError(self):
p = self.connection.Consumer()
p.channel = object()
- self.assertIsNone(p.connection)
+ assert p.connection is None
diff --git a/kombu/tests/test_mixins.py b/t/unit/test_mixins.py
index e3690806..372399e4 100644
--- a/kombu/tests/test_mixins.py
+++ b/t/unit/test_mixins.py
@@ -1,10 +1,11 @@
from __future__ import absolute_import, unicode_literals
+import pytest
import socket
-from kombu.mixins import ConsumerMixin
+from case import ContextMock, Mock, patch
-from .case import Case, Mock, ContextMock, patch
+from kombu.mixins import ConsumerMixin
def Message(body, content_type='text/plain', content_encoding='utf-8'):
@@ -31,7 +32,7 @@ class Cons(ConsumerMixin):
self.extra_context.return_value = self.extra_context
-class test_ConsumerMixin(Case):
+class test_ConsumerMixin:
def _context(self):
Acons = ContextMock(name='consumerA')
@@ -57,7 +58,7 @@ class test_ConsumerMixin(Case):
next(it)
next(it)
c.should_stop = True
- with self.assertRaises(StopIteration):
+ with pytest.raises(StopIteration):
next(it)
def test_consume_drain_raises_socket_error(self):
@@ -65,7 +66,7 @@ class test_ConsumerMixin(Case):
c.should_stop = False
it = c.consume(no_ack=True)
c.connection.drain_events.side_effect = socket.error
- with self.assertRaises(socket.error):
+ with pytest.raises(socket.error):
next(it)
def se2(*args, **kwargs):
@@ -73,7 +74,7 @@ class test_ConsumerMixin(Case):
raise socket.error()
c.connection.drain_events.side_effect = se2
it = c.consume(no_ack=True)
- with self.assertRaises(StopIteration):
+ with pytest.raises(StopIteration):
next(it)
def test_consume_drain_raises_socket_timeout(self):
@@ -85,49 +86,47 @@ class test_ConsumerMixin(Case):
c.should_stop = True
raise socket.timeout()
c.connection.drain_events.side_effect = se
- with self.assertRaises(socket.error):
+ with pytest.raises(socket.error):
next(it)
def test_Consumer_context(self):
c, Acons, Bcons = self._context()
with c.Consumer() as (conn, channel, consumer):
- self.assertIs(conn, c.connection)
- self.assertIs(channel, conn.default_channel)
+ assert conn is c.connection
+ assert channel is conn.default_channel
c.on_connection_revived.assert_called_with()
c.get_consumers.assert_called()
cls = c.get_consumers.call_args[0][0]
subcons = cls()
- self.assertIs(subcons.on_decode_error, c.on_decode_error)
- self.assertIs(subcons.channel, conn.default_channel)
+ assert subcons.on_decode_error is c.on_decode_error
+ assert subcons.channel is conn.default_channel
Acons.__enter__.assert_called_with()
Bcons.__enter__.assert_called_with()
c.on_consume_end.assert_called_with(conn, channel)
-class test_ConsumerMixin_interface(Case):
+class test_ConsumerMixin_interface:
def setup(self):
self.c = ConsumerMixin()
def test_get_consumers(self):
- with self.assertRaises(NotImplementedError):
+ with pytest.raises(NotImplementedError):
self.c.get_consumers(Mock(), Mock())
def test_on_connection_revived(self):
- self.assertIsNone(self.c.on_connection_revived())
+ assert self.c.on_connection_revived() is None
def test_on_consume_ready(self):
- self.assertIsNone(self.c.on_consume_ready(
- Mock(), Mock(), [],
- ))
+ assert self.c.on_consume_ready(Mock(), Mock(), []) is None
def test_on_consume_end(self):
- self.assertIsNone(self.c.on_consume_end(Mock(), Mock()))
+ assert self.c.on_consume_end(Mock(), Mock()) is None
def test_on_iteration(self):
- self.assertIsNone(self.c.on_iteration())
+ assert self.c.on_iteration() is None
def test_on_decode_error(self):
message = Message('foo')
@@ -146,15 +145,15 @@ class test_ConsumerMixin_interface(Case):
pass
def test_restart_limit(self):
- self.assertTrue(self.c.restart_limit)
+ assert self.c.restart_limit
def test_connection_errors(self):
conn = Mock(name='connection')
self.c.connection = conn
conn.connection_errors = (KeyError,)
- self.assertTupleEqual(self.c.connection_errors, conn.connection_errors)
+ assert self.c.connection_errors == conn.connection_errors
conn.channel_errors = (ValueError,)
- self.assertTupleEqual(self.c.channel_errors, conn.channel_errors)
+ assert self.c.channel_errors == conn.channel_errors
def test__consume_from(self):
a = ContextMock(name='A')
@@ -173,7 +172,7 @@ class test_ConsumerMixin_interface(Case):
self.c.connect_max_retries = 3
with self.c.establish_connection() as conn:
- self.assertTrue(conn)
+ assert conn
conn.ensure_connection.assert_called_with(
self.c.on_connection_error, 3,
)
diff --git a/kombu/tests/test_pidbox.py b/t/unit/test_pidbox.py
index 852dbbf5..c2dfccb0 100644
--- a/kombu/tests/test_pidbox.py
+++ b/t/unit/test_pidbox.py
@@ -1,29 +1,34 @@
from __future__ import absolute_import, unicode_literals
+import pytest
import socket
import warnings
+from case import Mock, patch
+
from kombu import Connection
from kombu import pidbox
from kombu.exceptions import ContentDisallowed, InconsistencyError
from kombu.utils.uuid import uuid
-from .case import Case, Mock, patch
+def is_cast(message):
+ return message['method']
-class test_Mailbox(Case):
- def _handler(self, state):
- return self.stats['var']
+def is_call(message):
+ return message['method'] and message['reply_to']
- def setup(self):
- class Mailbox(pidbox.Mailbox):
+class test_Mailbox:
+
+ class Mailbox(pidbox.Mailbox):
- def _collect(self, *args, **kwargs):
- return 'COLLECTED'
+ def _collect(self, *args, **kwargs):
+ return 'COLLECTED'
- self.mailbox = Mailbox('test_pidbox')
+ def setup(self):
+ self.mailbox = self.Mailbox('test_pidbox')
self.connection = Connection(transport='memory')
self.state = {'var': 1}
self.handlers = {'mymethod': self._handler}
@@ -35,6 +40,9 @@ class test_Mailbox(Case):
channel=self.default_chan,
)
+ def _handler(self, state):
+ return self.stats['var']
+
def test_publish_reply_ignores_InconsistencyError(self):
mailbox = pidbox.Mailbox('test_reply__collect')(self.connection)
with patch('kombu.pidbox.Producer') as Producer:
@@ -60,17 +68,17 @@ class test_Mailbox(Case):
reply = mailbox._collect(ticket, limit=1,
callback=callback, channel=channel)
- self.assertEqual(reply, [{'foo': 'bar'}])
- self.assertTrue(_callback_called[0])
+ assert reply == [{'foo': 'bar'}]
+ assert _callback_called[0]
ticket = uuid()
mailbox._publish_reply({'biz': 'boz'}, exchange, mailbox.oid, ticket)
reply = mailbox._collect(ticket, limit=1, channel=channel)
- self.assertEqual(reply, [{'biz': 'boz'}])
+ assert reply == [{'biz': 'boz'}]
mailbox._publish_reply({'foo': 'BAM'}, exchange, mailbox.oid, 'doom',
serializer='pickle')
- with self.assertRaises(ContentDisallowed):
+ with pytest.raises(ContentDisallowed):
reply = mailbox._collect('doom', limit=1, channel=channel)
mailbox._publish_reply(
{'foo': 'BAMBAM'}, exchange, mailbox.oid, 'doom',
@@ -78,40 +86,40 @@ class test_Mailbox(Case):
)
reply = mailbox._collect('doom', limit=1, channel=channel,
accept=['pickle'])
- self.assertEqual(reply[0]['foo'], 'BAMBAM')
+ assert reply[0]['foo'] == 'BAMBAM'
de = mailbox.connection.drain_events = Mock()
de.side_effect = socket.timeout
mailbox._collect(ticket, limit=1, channel=channel)
def test_constructor(self):
- self.assertIsNone(self.mailbox.connection)
- self.assertTrue(self.mailbox.exchange.name)
- self.assertTrue(self.mailbox.reply_exchange.name)
+ assert self.mailbox.connection is None
+ assert self.mailbox.exchange.name
+ assert self.mailbox.reply_exchange.name
def test_bound(self):
bound = self.mailbox(self.connection)
- self.assertIs(bound.connection, self.connection)
+ assert bound.connection is self.connection
def test_Node(self):
- self.assertTrue(self.node.hostname)
- self.assertTrue(self.node.state)
- self.assertIs(self.node.mailbox, self.bound)
- self.assertTrue(self.handlers)
+ assert self.node.hostname
+ assert self.node.state
+ assert self.node.mailbox is self.bound
+ assert self.handlers
# No initial handlers
node2 = self.bound.Node('test_pidbox2', state=self.state)
- self.assertDictEqual(node2.handlers, {})
+ assert node2.handlers == {}
def test_Node_consumer(self):
consumer1 = self.node.Consumer()
- self.assertIs(consumer1.channel, self.default_chan)
- self.assertTrue(consumer1.no_ack)
+ assert consumer1.channel is self.default_chan
+ assert consumer1.no_ack
chan2 = self.connection.channel()
consumer2 = self.node.Consumer(channel=chan2, no_ack=False)
- self.assertIs(consumer2.channel, chan2)
- self.assertFalse(consumer2.no_ack)
+ assert consumer2.channel is chan2
+ assert not consumer2.no_ack
def test_Node_consumer_multiple_listeners(self):
warnings.resetwarnings()
@@ -119,12 +127,12 @@ class test_Mailbox(Case):
q = consumer.queues[0]
with warnings.catch_warnings(record=True) as log:
q.on_declared('foo', 1, 1)
- self.assertTrue(log)
- self.assertIn('already using this', log[0].message.args[0])
+ assert log
+ assert 'already using this' in log[0].message.args[0]
with warnings.catch_warnings(record=True) as log:
q.on_declared('foo', 1, 0)
- self.assertFalse(log)
+ assert not log
def test_handler(self):
node = self.bound.Node('test_handler', state=self.state)
@@ -133,7 +141,7 @@ class test_Mailbox(Case):
def my_handler_name(state):
return 42
- self.assertIn('my_handler_name', node.handlers)
+ assert 'my_handler_name' in node.handlers
def test_dispatch(self):
node = self.bound.Node('test_dispatch', state=self.state)
@@ -142,8 +150,8 @@ class test_Mailbox(Case):
def my_handler_name(state, x=None, y=None):
return x + y
- self.assertEqual(node.dispatch('my_handler_name',
- arguments={'x': 10, 'y': 10}), 20)
+ assert node.dispatch('my_handler_name',
+ arguments={'x': 10, 'y': 10}) == 20
def test_dispatch_raising_SystemExit(self):
node = self.bound.Node('test_dispatch_raising_SystemExit',
@@ -153,7 +161,7 @@ class test_Mailbox(Case):
def my_handler_name(state):
raise SystemExit
- with self.assertRaises(SystemExit):
+ with pytest.raises(SystemExit):
node.dispatch('my_handler_name')
def test_dispatch_raising(self):
@@ -164,8 +172,8 @@ class test_Mailbox(Case):
raise KeyError('foo')
res = node.dispatch('my_handler_name')
- self.assertIn('error', res)
- self.assertIn('KeyError', res['error'])
+ assert 'error' in res
+ assert 'KeyError' in res['error']
def test_dispatch_replies(self):
_replied = [False]
@@ -183,7 +191,7 @@ class test_Mailbox(Case):
node.dispatch('my_handler_name',
arguments={'x': 10, 'y': 10},
reply_to={'exchange': 'foo', 'routing_key': 'bar'})
- self.assertTrue(_replied[0])
+ assert _replied[0]
def test_reply(self):
_replied = [(None, None, None)]
@@ -204,10 +212,10 @@ class test_Mailbox(Case):
'routing_key': 'rkey'},
ticket='TICKET')
data, exchange, routing_key, ticket = _replied[0]
- self.assertEqual(data, {'test_reply': 42})
- self.assertEqual(exchange, 'exchange')
- self.assertEqual(routing_key, 'rkey')
- self.assertEqual(ticket, 'TICKET')
+ assert data == {'test_reply': 42}
+ assert exchange == 'exchange'
+ assert routing_key == 'rkey'
+ assert ticket == 'TICKET'
def test_handle_message(self):
node = self.bound.Node('test_dispatch_from_message')
@@ -219,11 +227,11 @@ class test_Mailbox(Case):
body = {'method': 'my_handler_name',
'arguments': {'x': 64, 'y': 64}}
- self.assertEqual(node.handle_message(body, None), 64 * 64)
+ assert node.handle_message(body, None) == 64 * 64
# message not for me should not be processed.
body['destination'] = ['some_other_node']
- self.assertIsNone(node.handle_message(body, None))
+ assert node.handle_message(body, None) is None
def test_handle_message_adjusts_clock(self):
node = self.bound.Node('test_adjusts_clock')
@@ -239,49 +247,38 @@ class test_Mailbox(Case):
node.adjust_clock = Mock(name='adjust_clock')
res = node.handle_message(body, message)
node.adjust_clock.assert_called_with(313)
- self.assertEqual(res, 10)
+ assert res == 10
def test_listen(self):
consumer = self.node.listen()
- self.assertEqual(consumer.callbacks[0],
- self.node.handle_message)
- self.assertEqual(consumer.channel, self.default_chan)
+ assert consumer.callbacks[0] == self.node.handle_message
+ assert consumer.channel == self.default_chan
def test_cast(self):
self.bound.cast(['somenode'], 'mymethod')
consumer = self.node.Consumer()
- self.assertIsCast(self.get_next(consumer))
+ assert is_cast(self.get_next(consumer))
def test_abcast(self):
self.bound.abcast('mymethod')
consumer = self.node.Consumer()
- self.assertIsCast(self.get_next(consumer))
+ assert is_cast(self.get_next(consumer))
def test_call_destination_must_be_sequence(self):
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
self.bound.call('some_node', 'mymethod')
def test_call(self):
- self.assertEqual(
- self.bound.call(['some_node'], 'mymethod'),
- 'COLLECTED',
- )
+ assert self.bound.call(['some_node'], 'mymethod') == 'COLLECTED'
consumer = self.node.Consumer()
- self.assertIsCall(self.get_next(consumer))
+ assert is_call(self.get_next(consumer))
def test_multi_call(self):
- self.assertEqual(self.bound.multi_call('mymethod'), 'COLLECTED')
+ assert self.bound.multi_call('mymethod') == 'COLLECTED'
consumer = self.node.Consumer()
- self.assertIsCall(self.get_next(consumer))
+ assert is_call(self.get_next(consumer))
def get_next(self, consumer):
m = consumer.queues[0].get()
if m:
return m.payload
-
- def assertIsCast(self, message):
- self.assertTrue(message['method'])
-
- def assertIsCall(self, message):
- self.assertTrue(message['method'])
- self.assertTrue(message['reply_to'])
diff --git a/kombu/tests/test_pools.py b/t/unit/test_pools.py
index 7a632eee..c70cc89d 100644
--- a/kombu/tests/test_pools.py
+++ b/t/unit/test_pools.py
@@ -1,14 +1,16 @@
from __future__ import absolute_import, unicode_literals
+import pytest
+
+from case import Mock
+
from kombu import Connection, Producer
from kombu import pools
from kombu.connection import ConnectionPool
from kombu.utils.collections import eqhash
-from .case import Case, Mock
-
-class test_ProducerPool(Case):
+class test_ProducerPool:
Pool = pools.ProducerPool
class MyPool(pools.ProducerPool):
@@ -32,7 +34,7 @@ class test_ProducerPool(Case):
self.pool.Producer.side_effect = IOError()
acq = self.pool._acquire_connection = Mock()
conn = acq.return_value = Mock()
- with self.assertRaises(IOError):
+ with pytest.raises(IOError):
self.pool.create_producer()
conn.release.assert_called_with()
@@ -43,7 +45,7 @@ class test_ProducerPool(Case):
acq = self.pool._acquire_connection = Mock()
conn = acq.return_value = Mock()
p._channel = None
- with self.assertRaises(IOError):
+ with pytest.raises(IOError):
self.pool.prepare(pp)
conn.release.assert_called_with()
@@ -56,10 +58,10 @@ class test_ProducerPool(Case):
self.pool.release(p)
def test_init(self):
- self.assertIs(self.pool.connections, self.connections)
+ assert self.pool.connections is self.connections
def test_Producer(self):
- self.assertIsInstance(self.pool.Producer(Mock()), Producer)
+ assert isinstance(self.pool.Producer(Mock()), Producer)
def test_acquire_connection(self):
self.pool._acquire_connection()
@@ -68,20 +70,20 @@ class test_ProducerPool(Case):
def test_new(self):
promise = self.pool.new()
producer = promise()
- self.assertIsInstance(producer, Producer)
+ assert isinstance(producer, Producer)
self.connections.acquire.assert_called_with(block=True)
def test_setup_unlimited(self):
pool = self.Pool(self.connections, limit=None)
pool.setup()
- self.assertFalse(pool._resource.queue)
+ assert not pool._resource.queue
def test_setup(self):
- self.assertEqual(len(self.pool._resource.queue), self.pool.limit)
+ assert len(self.pool._resource.queue) == self.pool.limit
first = self.pool._resource.get_nowait()
producer = first()
- self.assertIsInstance(producer, Producer)
+ assert isinstance(producer, Producer)
def test_prepare(self):
connection = self.connections.acquire.return_value = Mock()
@@ -111,10 +113,10 @@ class test_ProducerPool(Case):
p.__connection__ = Mock()
self.pool.release(p)
p.__connection__.release.assert_called_with()
- self.assertIsNone(p.channel)
+ assert p.channel is None
-class test_PoolGroup(Case):
+class test_PoolGroup:
Group = pools.PoolGroup
class MyGroup(pools.PoolGroup):
@@ -124,47 +126,47 @@ class test_PoolGroup(Case):
def test_interface_create(self):
g = self.Group()
- with self.assertRaises(NotImplementedError):
+ with pytest.raises(NotImplementedError):
g.create(Mock(), 10)
def test_getitem_using_global_limit(self):
g = self.MyGroup(limit=pools.use_global_limit)
res = g['foo']
- self.assertTupleEqual(res, ('foo', pools.get_limit()))
+ assert res == ('foo', pools.get_limit())
def test_getitem_using_custom_limit(self):
g = self.MyGroup(limit=102456)
res = g['foo']
- self.assertTupleEqual(res, ('foo', 102456))
+ assert res == ('foo', 102456)
def test_delitem(self):
g = self.MyGroup()
g['foo']
del(g['foo'])
- self.assertNotIn('foo', g)
+ assert 'foo' not in g
def test_Connections(self):
conn = Connection('memory://')
p = pools.connections[conn]
- self.assertTrue(p)
- self.assertIsInstance(p, ConnectionPool)
- self.assertIs(p.connection, conn)
- self.assertEqual(p.limit, pools.get_limit())
+ assert p
+ assert isinstance(p, ConnectionPool)
+ assert p.connection is conn
+ assert p.limit == pools.get_limit()
def test_Producers(self):
conn = Connection('memory://')
p = pools.producers[conn]
- self.assertTrue(p)
- self.assertIsInstance(p, pools.ProducerPool)
- self.assertIs(p.connections, pools.connections[conn])
- self.assertEqual(p.limit, p.connections.limit)
- self.assertEqual(p.limit, pools.get_limit())
+ assert p
+ assert isinstance(p, pools.ProducerPool)
+ assert p.connections is pools.connections[conn]
+ assert p.limit == p.connections.limit
+ assert p.limit == pools.get_limit()
def test_all_groups(self):
conn = Connection('memory://')
pools.connections[conn]
- self.assertTrue(list(pools._all_pools()))
+ assert list(pools._all_pools())
def test_reset(self):
pools.reset()
@@ -181,7 +183,7 @@ class test_PoolGroup(Case):
pools.reset()
p1.force_close_all.assert_called_with()
- self.assertTrue(g1.clear_called)
+ assert g1.clear_called
p1 = pools.connections['foo'] = Mock()
p1.force_close_all.side_effect = KeyError()
@@ -191,23 +193,23 @@ class test_PoolGroup(Case):
pools.reset()
pools.set_limit(34576)
limit = pools.get_limit()
- self.assertEqual(limit, 34576)
+ assert limit == 34576
conn = Connection('memory://')
pool = pools.connections[conn]
with pool.acquire():
pools.set_limit(limit + 1)
- self.assertEqual(pools.get_limit(), limit + 1)
+ assert pools.get_limit() == limit + 1
limit = pools.get_limit()
- with self.assertRaises(RuntimeError):
+ with pytest.raises(RuntimeError):
pools.set_limit(limit - 1)
pools.set_limit(limit - 1, force=True)
- self.assertEqual(pools.get_limit(), limit - 1)
+ assert pools.get_limit() == limit - 1
pools.set_limit(pools.get_limit())
-class test_fun_PoolGroup(Case):
+class test_fun_PoolGroup:
def test_connections_behavior(self):
c1u = 'memory://localhost:123'
@@ -220,19 +222,19 @@ class test_fun_PoolGroup(Case):
assert eqhash(c1) == eqhash(c3)
c4 = Connection(c1u, transport_options={'confirm_publish': True})
- self.assertNotEqual(eqhash(c3), eqhash(c4))
+ assert eqhash(c3) != eqhash(c4)
p1 = pools.connections[c1]
p2 = pools.connections[c2]
p3 = pools.connections[c3]
- self.assertIsNot(p1, p2)
- self.assertIs(p1, p3)
+ assert p1 is not p2
+ assert p1 is p3
r1 = p1.acquire()
- self.assertTrue(p1._dirty)
- self.assertTrue(p3._dirty)
- self.assertFalse(p2._dirty)
+ assert p1._dirty
+ assert p3._dirty
+ assert not p2._dirty
r1.release()
- self.assertFalse(p1._dirty)
- self.assertFalse(p3._dirty)
+ assert not p1._dirty
+ assert not p3._dirty
diff --git a/kombu/tests/test_serialization.py b/t/unit/test_serialization.py
index d98b007e..88af0860 100644
--- a/kombu/tests/test_serialization.py
+++ b/t/unit/test_serialization.py
@@ -2,10 +2,13 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, unicode_literals
+import pytest
import sys
from base64 import b64decode
+from case import call, mock, patch, skip
+
from kombu.exceptions import ContentDisallowed, EncodeError, DecodeError
from kombu.five import text_t, bytes_t
from kombu.serialization import (
@@ -17,8 +20,6 @@ from kombu.serialization import (
)
from kombu.utils.encoding import str_to_bytes
-from .case import Case, call, mock, patch, skip
-
# For content_encoding tests
unicode_string = 'abcdé\u8463'
unicode_string_as_utf8 = unicode_string.encode('utf-8')
@@ -72,38 +73,38 @@ registry.register('testS', lambda s: s, lambda s: 'decoded',
'application/testS', 'utf-8')
-class test_Serialization(Case):
+class test_Serialization:
def test_disable(self):
disabled = registry._disabled_content_types
try:
registry.disable('testS')
- self.assertIn('application/testS', disabled)
+ assert 'application/testS' in disabled
disabled.clear()
registry.disable('application/testS')
- self.assertIn('application/testS', disabled)
+ assert 'application/testS' in disabled
finally:
disabled.clear()
def test_enable(self):
registry._disabled_content_types.add('application/json')
registry.enable('json')
- self.assertNotIn('application/json', registry._disabled_content_types)
+ assert 'application/json' not in registry._disabled_content_types
registry._disabled_content_types.add('application/json')
registry.enable('application/json')
- self.assertNotIn('application/json', registry._disabled_content_types)
+ assert 'application/json' not in registry._disabled_content_types
def test_loads_when_disabled(self):
disabled = registry._disabled_content_types
try:
registry.disable('testS')
- with self.assertRaises(SerializerNotInstalled):
+ with pytest.raises(SerializerNotInstalled):
loads('xxd', 'application/testS', 'utf-8', force=False)
ret = loads('xxd', 'application/testS', 'utf-8', force=True)
- self.assertEqual(ret, 'decoded')
+ assert ret == 'decoded'
finally:
disabled.clear()
@@ -111,50 +112,36 @@ class test_Serialization(Case):
loads(None, 'application/testS', 'utf-8')
def test_content_type_decoding(self):
- self.assertEqual(
- unicode_string,
- loads(unicode_string_as_utf8,
- content_type='plain/text', content_encoding='utf-8'),
- )
- self.assertEqual(
- latin_string,
- loads(latin_string_as_latin1,
- content_type='application/data', content_encoding='latin-1'),
- )
+ assert loads(
+ unicode_string_as_utf8,
+ content_type='plain/text',
+ content_encoding='utf-8') == unicode_string
+ assert loads(
+ latin_string_as_latin1,
+ content_type='application/data',
+ content_encoding='latin-1') == latin_string
def test_content_type_binary(self):
- self.assertIsInstance(
+ assert isinstance(
loads(unicode_string_as_utf8,
content_type='application/data', content_encoding='binary'),
- bytes_t,
- )
+ bytes_t)
- self.assertEqual(
+ assert loads(
unicode_string_as_utf8,
- loads(unicode_string_as_utf8,
- content_type='application/data', content_encoding='binary'),
- )
+ content_type='application/data',
+ content_encoding='binary') == unicode_string_as_utf8
def test_content_type_encoding(self):
# Using the 'raw' serializer
- self.assertEqual(
- unicode_string_as_utf8,
- dumps(unicode_string, serializer='raw')[-1],
- )
- self.assertEqual(
- latin_string_as_utf8,
- dumps(latin_string, serializer='raw')[-1],
- )
+ assert (dumps(unicode_string, serializer='raw')[-1] ==
+ unicode_string_as_utf8)
+ assert (dumps(latin_string, serializer='raw')[-1] ==
+ latin_string_as_utf8)
# And again w/o a specific serializer to check the
# code where we force unicode objects into a string.
- self.assertEqual(
- unicode_string_as_utf8,
- dumps(unicode_string)[-1],
- )
- self.assertEqual(
- latin_string_as_utf8,
- dumps(latin_string)[-1],
- )
+ assert dumps(unicode_string)[-1] == unicode_string_as_utf8
+ assert dumps(latin_string)[-1] == latin_string_as_utf8
def test_enable_insecure_serializers(self):
with patch('kombu.serialization.registry') as registry:
@@ -182,34 +169,31 @@ class test_Serialization(Case):
])
def test_reraises_EncodeError(self):
- with self.assertRaises(EncodeError):
+ with pytest.raises(EncodeError):
dumps([object()], serializer='json')
def test_reraises_DecodeError(self):
- with self.assertRaises(DecodeError):
+ with pytest.raises(DecodeError):
loads(object(), content_type='application/json',
content_encoding='utf-8')
def test_json_loads(self):
- self.assertEqual(
- py_data,
- loads(json_data,
- content_type='application/json', content_encoding='utf-8'),
- )
+ assert loads(json_data,
+ content_type='application/json',
+ content_encoding='utf-8') == py_data
def test_json_dumps(self):
- self.assertEqual(
- loads(
- dumps(py_data, serializer='json')[-1],
- content_type='application/json',
- content_encoding='utf-8',
- ),
- loads(
- json_data,
- content_type='application/json',
- content_encoding='utf-8',
- ),
+ a = loads(
+ dumps(py_data, serializer='json')[-1],
+ content_type='application/json',
+ content_encoding='utf-8',
+ )
+ b = loads(
+ json_data,
+ content_type='application/json',
+ content_encoding='utf-8',
)
+ assert a == b
@skip.if_pypy()
@skip.unless_module('msgpack', (ImportError, ValueError))
@@ -224,122 +208,109 @@ class test_Serialization(Case):
res[k] = v.encode()
if isinstance(v, (list, tuple)):
res[k] = [i.encode() for i in v]
- self.assertEqual(
- msgpack_py_data,
- res,
- )
+ assert res == msgpack_py_data
@skip.if_pypy()
@skip.unless_module('msgpack', (ImportError, ValueError))
def test_msgpack_dumps(self):
register_msgpack()
- self.assertEqual(
- loads(
- dumps(msgpack_py_data, serializer='msgpack')[-1],
- content_type='application/x-msgpack',
- content_encoding='binary',
- ),
- loads(
- msgpack_data,
- content_type='application/x-msgpack',
- content_encoding='binary',
- ),
+ a = loads(
+ dumps(msgpack_py_data, serializer='msgpack')[-1],
+ content_type='application/x-msgpack',
+ content_encoding='binary',
)
+ b = loads(
+ msgpack_data,
+ content_type='application/x-msgpack',
+ content_encoding='binary',
+ )
+ assert a == b
@skip.unless_module('yaml')
def test_yaml_loads(self):
register_yaml()
- self.assertEqual(
- py_data,
- loads(yaml_data,
- content_type='application/x-yaml',
- content_encoding='utf-8'),
- )
+ assert loads(
+ yaml_data,
+ content_type='application/x-yaml',
+ content_encoding='utf-8') == py_data
@skip.unless_module('yaml')
def test_yaml_dumps(self):
register_yaml()
- self.assertEqual(
- loads(
- dumps(py_data, serializer='yaml')[-1],
- content_type='application/x-yaml',
- content_encoding='utf-8',
- ),
- loads(
- yaml_data,
- content_type='application/x-yaml',
- content_encoding='utf-8',
- ),
+ a = loads(
+ dumps(py_data, serializer='yaml')[-1],
+ content_type='application/x-yaml',
+ content_encoding='utf-8',
+ )
+ b = loads(
+ yaml_data,
+ content_type='application/x-yaml',
+ content_encoding='utf-8',
)
+ assert a == b
def test_pickle_loads(self):
- self.assertEqual(
- py_data,
- loads(pickle_data,
- content_type='application/x-python-serialize',
- content_encoding='binary'),
- )
+ assert loads(
+ pickle_data,
+ content_type='application/x-python-serialize',
+ content_encoding='binary') == py_data
def test_pickle_dumps(self):
- self.assertEqual(
- pickle.loads(pickle_data),
- pickle.loads(dumps(py_data, serializer='pickle')[-1]),
- )
+ a = pickle.loads(pickle_data),
+ b = pickle.loads(dumps(py_data, serializer='pickle')[-1]),
+ assert a == b
def test_register(self):
register(None, None, None, None)
def test_unregister(self):
- with self.assertRaises(SerializerNotInstalled):
+ with pytest.raises(SerializerNotInstalled):
unregister('nonexisting')
dumps('foo', serializer='pickle')
unregister('pickle')
- with self.assertRaises(SerializerNotInstalled):
+ with pytest.raises(SerializerNotInstalled):
dumps('foo', serializer='pickle')
register_pickle()
def test_set_default_serializer_missing(self):
- with self.assertRaises(SerializerNotInstalled):
+ with pytest.raises(SerializerNotInstalled):
registry._set_default_serializer('nonexisting')
def test_dumps_missing(self):
- with self.assertRaises(SerializerNotInstalled):
+ with pytest.raises(SerializerNotInstalled):
dumps('foo', serializer='nonexisting')
def test_dumps__no_serializer(self):
ctyp, cenc, data = dumps(str_to_bytes('foo'))
- self.assertEqual(ctyp, 'application/data')
- self.assertEqual(cenc, 'binary')
+ assert ctyp == 'application/data'
+ assert cenc == 'binary'
def test_loads__trusted_content(self):
loads('tainted', 'application/data', 'binary', accept=[])
loads('tainted', 'application/text', 'utf-8', accept=[])
def test_loads__not_accepted(self):
- with self.assertRaises(ContentDisallowed):
+ with pytest.raises(ContentDisallowed):
loads('tainted', 'application/x-evil', 'binary', accept=[])
- with self.assertRaises(ContentDisallowed):
+ with pytest.raises(ContentDisallowed):
loads('tainted', 'application/x-evil', 'binary',
accept=['application/x-json'])
- self.assertTrue(
- loads('tainted', 'application/x-doomsday', 'binary',
- accept=['application/x-doomsday'])
- )
+ assert loads('tainted', 'application/x-doomsday', 'binary',
+ accept=['application/x-doomsday'])
def test_raw_encode(self):
- self.assertTupleEqual(
- raw_encode('foo'.encode('utf-8')),
- ('application/data', 'binary', 'foo'.encode('utf-8')),
+ assert raw_encode('foo'.encode('utf-8')) == (
+ 'application/data', 'binary', 'foo'.encode('utf-8'),
)
@mock.mask_modules('yaml')
def test_register_yaml__no_yaml(self):
register_yaml()
- with self.assertRaises(SerializerNotInstalled):
+ with pytest.raises(SerializerNotInstalled):
loads('foo', 'application/x-yaml', 'utf-8')
@mock.mask_modules('msgpack')
def test_register_msgpack__no_msgpack(self):
register_msgpack()
- with self.assertRaises(SerializerNotInstalled):
+ with pytest.raises(SerializerNotInstalled):
loads('foo', 'application/x-msgpack', 'utf-8')
diff --git a/kombu/tests/test_simple.py b/t/unit/test_simple.py
index a0d68eff..e1f7cb22 100644
--- a/kombu/tests/test_simple.py
+++ b/t/unit/test_simple.py
@@ -1,12 +1,13 @@
from __future__ import absolute_import, unicode_literals
-from kombu import Connection, Exchange, Queue
+import pytest
+
+from case import Mock
-from .case import Case, Mock
+from kombu import Connection, Exchange, Queue
-class SimpleBase(Case):
- abstract = True
+class SimpleBase:
def Queue(self, name, *args, **kwargs):
q = name
@@ -20,83 +21,67 @@ class SimpleBase(Case):
raise NotImplementedError()
def setup(self):
- if not self.abstract:
- self.connection = Connection(transport='memory')
- with self.connection.channel() as channel:
- channel.exchange_declare('amq.direct')
- self.q = self.Queue(None, no_ack=True)
+ self.connection = Connection(transport='memory')
+ with self.connection.channel() as channel:
+ channel.exchange_declare('amq.direct')
+ self.q = self.Queue(None, no_ack=True)
def teardown(self):
- if not self.abstract:
- self.q.close()
- self.connection.close()
+ self.q.close()
+ self.connection.close()
def test_produce__consume(self):
- if self.abstract:
- return
q = self.Queue('test_produce__consume', no_ack=True)
q.put({'hello': 'Simple'})
- self.assertEqual(q.get(timeout=1).payload, {'hello': 'Simple'})
- with self.assertRaises(q.Empty):
+ assert q.get(timeout=1).payload == {'hello': 'Simple'}
+ with pytest.raises(q.Empty):
q.get(timeout=0.1)
def test_produce__basic_get(self):
- if self.abstract:
- return
q = self.Queue('test_produce__basic_get', no_ack=True)
q.put({'hello': 'SimpleSync'})
- self.assertEqual(q.get_nowait().payload, {'hello': 'SimpleSync'})
- with self.assertRaises(q.Empty):
+ assert q.get_nowait().payload == {'hello': 'SimpleSync'}
+ with pytest.raises(q.Empty):
q.get_nowait()
q.put({'hello': 'SimpleSync'})
- self.assertEqual(q.get(block=False).payload, {'hello': 'SimpleSync'})
- with self.assertRaises(q.Empty):
+ assert q.get(block=False).payload == {'hello': 'SimpleSync'}
+ with pytest.raises(q.Empty):
q.get(block=False)
def test_clear(self):
- if self.abstract:
- return
q = self.Queue('test_clear', no_ack=True)
for i in range(10):
q.put({'hello': 'SimplePurge%d' % (i,)})
- self.assertEqual(q.clear(), 10)
+ assert q.clear() == 10
def test_enter_exit(self):
- if self.abstract:
- return
q = self.Queue('test_enter_exit')
q.close = Mock()
- self.assertIs(q.__enter__(), q)
+ assert q.__enter__() is q
q.__exit__()
q.close.assert_called_with()
def test_qsize(self):
- if self.abstract:
- return
q = self.Queue('test_clear', no_ack=True)
for i in range(10):
q.put({'hello': 'SimplePurge%d' % (i,)})
- self.assertEqual(q.qsize(), 10)
- self.assertEqual(len(q), 10)
+ assert q.qsize() == 10
+ assert len(q) == 10
def test_autoclose(self):
- if self.abstract:
- return
channel = self.connection.channel()
q = self.Queue('test_autoclose', no_ack=True, channel=channel)
q.close()
def test_custom_Queue(self):
- if self.abstract:
- return
n = self.__class__.__name__
exchange = Exchange('%s-test.custom.Queue' % (n,))
queue = Queue('%s-test.custom.Queue' % (n,),
@@ -104,33 +89,29 @@ class SimpleBase(Case):
'my.routing.key')
q = self.Queue(queue)
- self.assertEqual(q.consumer.queues[0], queue)
+ assert q.consumer.queues[0] == queue
q.close()
def test_bool(self):
- if self.abstract:
- return
q = self.Queue('test_nonzero')
- self.assertTrue(q)
+ assert q
class test_SimpleQueue(SimpleBase):
- abstract = False
def _Queue(self, *args, **kwargs):
return self.connection.SimpleQueue(*args, **kwargs)
def test_is_ack(self):
q = self.Queue('test_is_no_ack')
- self.assertFalse(q.no_ack)
+ assert not q.no_ack
class test_SimpleBuffer(SimpleBase):
- abstract = False
def Queue(self, *args, **kwargs):
return self.connection.SimpleBuffer(*args, **kwargs)
def test_is_no_ack(self):
q = self.Queue('test_is_no_ack')
- self.assertTrue(q.no_ack)
+ assert q.no_ack
diff --git a/kombu/tests/test_syn.py b/t/unit/test_syn.py
index bf5fd972..eaa80f90 100644
--- a/kombu/tests/test_syn.py
+++ b/t/unit/test_syn.py
@@ -4,45 +4,45 @@ import socket
import sys
import types
+from case import mock, patch
+
from kombu import syn
from kombu.five import bytes_if_py2
-from kombu.tests.case import Case, mock, patch
-
-class test_syn(Case):
+class test_syn:
def test_compat(self):
- self.assertEqual(syn.blocking(lambda: 10), 10)
+ assert syn.blocking(lambda: 10) == 10
syn.select_blocking_method('foo')
def test_detect_environment(self):
try:
syn._environment = None
X = syn.detect_environment()
- self.assertEqual(syn._environment, X)
+ assert syn._environment == X
Y = syn.detect_environment()
- self.assertEqual(Y, X)
+ assert Y == X
finally:
syn._environment = None
@mock.module_exists('eventlet', 'eventlet.patcher')
def test_detect_environment_eventlet(self):
with patch('eventlet.patcher.is_monkey_patched', create=True) as m:
- self.assertTrue(sys.modules['eventlet'])
+ assert sys.modules['eventlet']
m.return_value = True
env = syn._detect_environment()
m.assert_called_with(socket)
- self.assertEqual(env, 'eventlet')
+ assert env == 'eventlet'
@mock.module_exists('gevent')
def test_detect_environment_gevent(self):
with patch('gevent.socket', create=True) as m:
prev, socket.socket = socket.socket, m.socket
try:
- self.assertTrue(sys.modules['gevent'])
+ assert sys.modules['gevent']
env = syn._detect_environment()
- self.assertEqual(env, 'gevent')
+ assert env == 'gevent'
finally:
socket.socket = prev
@@ -52,14 +52,14 @@ class test_syn(Case):
bytes_if_py2('eventlet'))
sys.modules['eventlet.patcher'] = types.ModuleType(
bytes_if_py2('patcher'))
- self.assertEqual(syn._detect_environment(), 'default')
+ assert syn._detect_environment() == 'default'
finally:
sys.modules.pop('eventlet.patcher', None)
sys.modules.pop('eventlet', None)
syn._detect_environment()
try:
sys.modules['gevent'] = types.ModuleType(bytes_if_py2('gevent'))
- self.assertEqual(syn._detect_environment(), 'default')
+ assert syn._detect_environment() == 'default'
finally:
sys.modules.pop('gevent', None)
syn._detect_environment()
diff --git a/kombu/tests/transport/virtual/__init__.py b/t/unit/transport/__init__.py
index e69de29b..e69de29b 100644
--- a/kombu/tests/transport/virtual/__init__.py
+++ b/t/unit/transport/__init__.py
diff --git a/kombu/tests/transport/test_SQS.py b/t/unit/transport/test_SQS.py
index 98ea3401..5032d43c 100644
--- a/kombu/tests/transport/test_SQS.py
+++ b/t/unit/transport/test_SQS.py
@@ -7,6 +7,10 @@ slightly.
from __future__ import absolute_import, unicode_literals
+import pytest
+
+from case import skip
+
from kombu import five
from kombu import messaging
from kombu import Connection, Exchange, Queue
@@ -14,8 +18,6 @@ from kombu import Connection, Exchange, Queue
from kombu.transport import SQS
from kombu.async.aws.ext import exception
-from kombu.tests.case import Case, skip
-
class SQSQueueMock(object):
@@ -95,7 +97,7 @@ class SQSConnectionMock(object):
@skip.unless_module('boto')
-class test_Channel(Case):
+class test_Channel:
def handleMessageCallback(self, message):
self.callback_message = message
@@ -141,9 +143,20 @@ class test_Channel(Case):
callback=self.handleMessageCallback,
consumer_tag='unittest')
+ def teardown(self):
+ # Removes QoS reserved messages so we don't restore msgs on shutdown.
+ try:
+ qos = self.channel._qos
+ except AttributeError:
+ pass
+ else:
+ if qos:
+ qos._dirty.clear()
+ qos._delivered.clear()
+
def test_init(self):
"""kombu.SQS.Channel instantiates correctly with mocked queues"""
- self.assertIn(self.queue_name, self.channel._queue_cache)
+ assert self.queue_name in self.channel._queue_cache
def test_auth_fail(self):
normal_func = SQS.Channel.sqs.get_all_queues
@@ -159,11 +172,11 @@ class test_Channel(Case):
try:
SQS.Channel.sqs.access_key = '1234'
SQS.Channel.sqs.get_all_queues = get_all_queues_fail_403
- with self.assertRaises(RuntimeError) as context:
+ with pytest.raises(RuntimeError) as excinfo:
self.channel = self.connection.channel()
- self.assertIn('access_key=1234', str(context.exception))
+ assert 'access_key=1234' in str(excinfo.value)
SQS.Channel.sqs.get_all_queues = get_all_queues_fail_not_403
- with self.assertRaises(exception.SQSError) as context:
+ with pytest.raises(exception.SQSError):
self.channel = self.connection.channel()
finally:
SQS.Channel.sqs.get_all_queues = normal_func
@@ -171,7 +184,7 @@ class test_Channel(Case):
def test_new_queue(self):
queue_name = 'new_unittest_queue'
self.channel._new_queue(queue_name)
- self.assertIn(queue_name, self.sqs_conn_mock.queues)
+ assert queue_name in self.sqs_conn_mock.queues
# For cleanup purposes, delete the queue and the queue file
self.channel._delete(queue_name)
@@ -181,16 +194,16 @@ class test_Channel(Case):
# first 1000 queues sorted by name.
queue_name = 'unittest_queue'
self.channel._new_queue(queue_name)
- self.assertIn(queue_name, self.sqs_conn_mock.queues)
+ assert queue_name in self.sqs_conn_mock.queues
q = self.sqs_conn_mock.get_queue(queue_name)
- self.assertEqual(1, q.count())
- self.assertEqual('hello', q.read())
+ assert 1 == q.count()
+ assert 'hello' == q.read()
def test_delete(self):
queue_name = 'new_unittest_queue'
self.channel._new_queue(queue_name)
self.channel._delete(queue_name)
- self.assertNotIn(queue_name, self.channel._queue_cache)
+ assert queue_name not in self.channel._queue_cache
def test_get_from_sqs(self):
# Test getting a single message
@@ -198,7 +211,7 @@ class test_Channel(Case):
self.producer.publish(message)
q = self.channel._new_queue(self.queue_name)
results = q.get_messages()
- self.assertEqual(len(results), 1)
+ assert len(results) == 1
# Now test getting many messages
for i in range(3):
@@ -206,14 +219,14 @@ class test_Channel(Case):
self.producer.publish(message)
results = q.get_messages(num_messages=3)
- self.assertEqual(len(results), 3)
+ assert len(results) == 3
def test_get_with_empty_list(self):
- with self.assertRaises(five.Empty):
+ with pytest.raises(five.Empty):
self.channel._get(self.queue_name)
def test_get_bulk_raises_empty(self):
- with self.assertRaises(five.Empty):
+ with pytest.raises(five.Empty):
self.channel._get_bulk(self.queue_name)
def test_messages_to_python(self):
@@ -243,27 +256,27 @@ class test_Channel(Case):
)
# We got the same number of payloads back, right?
- self.assertEqual(len(kombu_payloads), kombu_message_count)
- self.assertEqual(len(json_payloads), json_message_count)
+ assert len(kombu_payloads) == kombu_message_count
+ assert len(json_payloads) == json_message_count
# Make sure they're payload-style objects
for p in kombu_payloads:
- self.assertIn('properties', p)
+ assert 'properties' in p
for p in json_payloads:
- self.assertIn('properties', p)
+ assert 'properties' in p
def test_put_and_get(self):
message = 'my test message'
self.producer.publish(message)
results = self.queue(self.channel).get().payload
- self.assertEqual(message, results)
+ assert message == results
def test_put_and_get_bulk(self):
# With QoS.prefetch_count = 0
message = 'my test message'
self.producer.publish(message)
results = self.channel._get_bulk(self.queue_name)
- self.assertEqual(1, len(results))
+ assert 1 == len(results)
def test_puts_and_get_bulk(self):
# Generate 8 messages
@@ -280,19 +293,19 @@ class test_Channel(Case):
# Count how many messages are retrieved the first time. Should
# be 5 (message_count).
results = self.channel._get_bulk(self.queue_name)
- self.assertEqual(5, len(results))
+ assert 5 == len(results)
for i, r in enumerate(results):
self.channel.qos.append(r, i)
# Now, do the get again, the number of messages returned should be 1.
results = self.channel._get_bulk(self.queue_name)
- self.assertEqual(len(results), 1)
+ assert len(results) == 1
def test_drain_events_with_empty_list(self):
def mock_can_consume():
return False
self.channel.qos.can_consume = mock_can_consume
- with self.assertRaises(five.Empty):
+ with pytest.raises(five.Empty):
self.channel.drain_events()
def test_drain_events_with_prefetch_5(self):
@@ -312,9 +325,8 @@ class test_Channel(Case):
self.channel.drain_events()
# How many times was the SQSConnectionMock get_message method called?
- self.assertEqual(
- expected_get_message_count,
- self.channel._queue_cache[self.queue_name]._get_message_calls)
+ assert (expected_get_message_count ==
+ self.channel._queue_cache[self.queue_name]._get_message_calls)
def test_drain_events_with_prefetch_none(self):
# Generate 20 messages
@@ -333,6 +345,5 @@ class test_Channel(Case):
self.channel.drain_events()
# How many times was the SQSConnectionMock get_message method called?
- self.assertEqual(
- expected_get_message_count,
- self.channel._queue_cache[self.queue_name]._get_message_calls)
+ assert (expected_get_message_count ==
+ self.channel._queue_cache[self.queue_name]._get_message_calls)
diff --git a/kombu/tests/transport/test_base.py b/t/unit/transport/test_base.py
index 77d2b144..bf1d551c 100644
--- a/kombu/tests/transport/test_base.py
+++ b/t/unit/transport/test_base.py
@@ -1,14 +1,16 @@
from __future__ import absolute_import, unicode_literals
+import pytest
+
+from case import Mock
+
from kombu import Connection, Consumer, Exchange, Producer, Queue
from kombu.five import text_t
from kombu.message import Message
from kombu.transport.base import StdChannel, Transport, Management
-from kombu.tests.case import Case, Mock
-
-class test_StdChannel(Case):
+class test_StdChannel:
def setup(self):
self.conn = Connection('memory://')
@@ -20,25 +22,23 @@ class test_StdChannel(Case):
q = Queue('foo', Exchange('foo'))
print(self.channel.queues)
cons = self.channel.Consumer(q)
- self.assertIsInstance(cons, Consumer)
- self.assertIs(cons.channel, self.channel)
+ assert isinstance(cons, Consumer)
+ assert cons.channel is self.channel
def test_Producer(self):
prod = self.channel.Producer()
- self.assertIsInstance(prod, Producer)
- self.assertIs(prod.channel, self.channel)
+ assert isinstance(prod, Producer)
+ assert prod.channel is self.channel
def test_interface_get_bindings(self):
- with self.assertRaises(NotImplementedError):
+ with pytest.raises(NotImplementedError):
StdChannel().get_bindings()
def test_interface_after_reply_message_received(self):
- self.assertIsNone(
- StdChannel().after_reply_message_received(Queue('foo')),
- )
+ assert StdChannel().after_reply_message_received(Queue('foo')) is None
-class test_Message(Case):
+class test_Message:
def setup(self):
self.conn = Connection('memory://')
@@ -47,7 +47,7 @@ class test_Message(Case):
def test_postencode(self):
m = Message(self.channel, text_t('FOO'), postencode='ccyzz')
- with self.assertRaises(LookupError):
+ with pytest.raises(LookupError):
m._reraise_error()
m.ack()
@@ -57,7 +57,7 @@ class test_Message(Case):
ack = self.channel.basic_ack = Mock()
self.message.ack()
- self.assertNotEqual(self.message._state, 'ACK')
+ assert self.message._state != 'ACK'
ack.assert_not_called()
def test_ack_missing_consumer_tag(self):
@@ -88,7 +88,7 @@ class test_Message(Case):
self.message.ack_log_error(logger, KeyError)
ack.assert_called_with(multiple=False)
logger.critical.assert_called()
- self.assertIn("Couldn't ack", logger.critical.call_args[0][0])
+ assert "Couldn't ack" in logger.critical.call_args[0][0]
def test_reject_log_error_when_no_error(self):
reject = self.message.reject = Mock()
@@ -102,36 +102,36 @@ class test_Message(Case):
self.message.reject_log_error(logger, KeyError)
reject.assert_called_with(requeue=False)
logger.critical.assert_called()
- self.assertIn("Couldn't reject", logger.critical.call_args[0][0])
+ assert "Couldn't reject" in logger.critical.call_args[0][0]
-class test_interface(Case):
+class test_interface:
def test_establish_connection(self):
- with self.assertRaises(NotImplementedError):
+ with pytest.raises(NotImplementedError):
Transport(None).establish_connection()
def test_close_connection(self):
- with self.assertRaises(NotImplementedError):
+ with pytest.raises(NotImplementedError):
Transport(None).close_connection(None)
def test_create_channel(self):
- with self.assertRaises(NotImplementedError):
+ with pytest.raises(NotImplementedError):
Transport(None).create_channel(None)
def test_close_channel(self):
- with self.assertRaises(NotImplementedError):
+ with pytest.raises(NotImplementedError):
Transport(None).close_channel(None)
def test_drain_events(self):
- with self.assertRaises(NotImplementedError):
+ with pytest.raises(NotImplementedError):
Transport(None).drain_events(None)
def test_heartbeat_check(self):
Transport(None).heartbeat_check(Mock(name='connection'))
def test_driver_version(self):
- self.assertTrue(Transport(None).driver_version())
+ assert Transport(None).driver_version()
def test_register_with_event_loop(self):
Transport(None).register_with_event_loop(
@@ -144,12 +144,12 @@ class test_interface(Case):
)
def test_manager(self):
- self.assertTrue(Transport(None).manager)
+ assert Transport(None).manager
-class test_Management(Case):
+class test_Management:
def test_get_bindings(self):
m = Management(Mock(name='transport'))
- with self.assertRaises(NotImplementedError):
+ with pytest.raises(NotImplementedError):
m.get_bindings()
diff --git a/kombu/tests/transport/test_consul.py b/t/unit/transport/test_consul.py
index ca2bb4bc..43b2b165 100644
--- a/kombu/tests/transport/test_consul.py
+++ b/t/unit/transport/test_consul.py
@@ -1,62 +1,60 @@
from __future__ import absolute_import, unicode_literals
-from kombu.five import Empty
+import pytest
-from kombu.transport.consul import Channel, Transport
+from case import Mock, skip
-from kombu.tests.case import Case, Mock, skip
+from kombu.five import Empty
+from kombu.transport.consul import Channel, Transport
@skip.unless_module('consul')
-class test_Consul(Case):
+class test_Consul:
def setup(self):
self.connection = Mock()
self.connection.client.transport_options = {}
self.connection.client.port = 303
- self.consul = self.patch('consul.Consul').return_value
+ self.consul = self.patching('consul.Consul').return_value
self.channel = Channel(connection=self.connection)
def test_driver_version(self):
- self.assertTrue(Transport(self.connection.client).driver_version())
+ assert Transport(self.connection.client).driver_version()
def test_failed_get(self):
self.channel._acquire_lock = Mock(return_value=False)
self.channel.client.kv.get.return_value = (1, None)
- with self.assertRaises(Empty):
+ with pytest.raises(Empty):
self.channel._get('empty')()
def test_test_purge(self):
self.channel._destroy_session = Mock(return_value=True)
self.consul.kv.delete = Mock(return_value=True)
- self.assertTrue(self.channel._purge('foo'))
+ assert self.channel._purge('foo')
def test_variables(self):
- self.assertEqual(self.channel.session_ttl, 30)
- self.assertEqual(self.channel.timeout, '10s')
+ assert self.channel.session_ttl == 30
+ assert self.channel.timeout == '10s'
def test_lock_key(self):
key = self.channel._lock_key('myqueue')
- self.assertEqual(key, 'kombu/myqueue.lock')
+ assert key == 'kombu/myqueue.lock'
def test_key_prefix(self):
key = self.channel._key_prefix('myqueue')
- self.assertEqual(key, 'kombu/myqueue')
+ assert key == 'kombu/myqueue'
def test_get_or_create_session(self):
queue = 'myqueue'
session_id = '123456'
self.consul.session.create.return_value = session_id
- self.assertEqual(
- self.channel._get_or_create_session(queue),
- session_id,
- )
+ assert self.channel._get_or_create_session(queue) == session_id
def test_create_delete_queue(self):
queue = 'mynewqueue'
self.consul.kv.put.return_value = True
- self.assertTrue(self.channel._new_queue(queue))
+ assert self.channel._new_queue(queue)
self.consul.kv.delete.return_value = True
self.channel._destroy_session = Mock()
@@ -64,7 +62,7 @@ class test_Consul(Case):
def test_size(self):
self.consul.kv.get.return_value = [(1, {}), (2, {})]
- self.assertEqual(self.channel._size('q'), 2)
+ assert self.channel._size('q') == 2
def test_get(self):
self.channel._obtain_lock = Mock(return_value=True)
@@ -76,8 +74,8 @@ class test_Consul(Case):
self.consul.kv.delete.return_value = True
- self.assertIsNotNone(self.channel._get('myqueue'))
+ assert self.channel._get('myqueue') is not None
def test_put(self):
self.consul.kv.put.return_value = True
- self.assertIsNone(self.channel._put('myqueue', 'mydata'))
+ assert self.channel._put('myqueue', 'mydata') is None
diff --git a/kombu/tests/transport/test_filesystem.py b/t/unit/transport/test_filesystem.py
index b2ab368d..52a925f0 100644
--- a/kombu/tests/transport/test_filesystem.py
+++ b/t/unit/transport/test_filesystem.py
@@ -2,16 +2,17 @@ from __future__ import absolute_import, unicode_literals
import tempfile
-from kombu import Connection, Exchange, Queue, Consumer, Producer
-
+from case import skip
from case.skip import SkipTest
-from kombu.tests.case import Case, skip
+
+from kombu import Connection, Exchange, Queue, Consumer, Producer
@skip.if_win32()
-class test_FilesystemTransport(Case):
+class test_FilesystemTransport:
def setup(self):
+ self.channels = set()
try:
data_folder_in = tempfile.mkdtemp()
data_folder_out = tempfile.mkdtemp()
@@ -22,11 +23,13 @@ class test_FilesystemTransport(Case):
'data_folder_in': data_folder_in,
'data_folder_out': data_folder_out,
})
+ self.channels.add(self.c.default_channel)
self.p = Connection(transport='filesystem',
transport_options={
'data_folder_in': data_folder_out,
'data_folder_out': data_folder_in,
})
+ self.channels.add(self.p.default_channel)
self.e = Exchange('test_transport_filesystem')
self.q = Queue('test_transport_filesystem',
exchange=self.e,
@@ -35,9 +38,26 @@ class test_FilesystemTransport(Case):
exchange=self.e,
routing_key='test_transport_filesystem2')
+ def teardown(self):
+ # make sure we don't attempt to restore messages at shutdown.
+ for channel in self.channels:
+ try:
+ channel._qos._dirty.clear()
+ except AttributeError:
+ pass
+ try:
+ channel._qos._delivered.clear()
+ except AttributeError:
+ pass
+
+ def _add_channel(self, channel):
+ self.channels.add(channel)
+ return channel
+
def test_produce_consume_noack(self):
- producer = Producer(self.p.channel(), self.e)
- consumer = Consumer(self.c.channel(), self.q, no_ack=True)
+ producer = Producer(self._add_channel(self.p.channel()), self.e)
+ consumer = Consumer(self._add_channel(self.c.channel()), self.q,
+ no_ack=True)
for i in range(10):
producer.publish({'foo': i},
@@ -56,11 +76,11 @@ class test_FilesystemTransport(Case):
break
self.c.drain_events()
- self.assertEqual(len(_received), 10)
+ assert len(_received) == 10
def test_produce_consume(self):
- producer_channel = self.p.channel()
- consumer_channel = self.c.channel()
+ producer_channel = self._add_channel(self.p.channel())
+ consumer_channel = self._add_channel(self.c.channel())
producer = Producer(producer_channel, self.e)
consumer1 = Consumer(consumer_channel, self.q)
consumer2 = Consumer(consumer_channel, self.q2)
@@ -95,28 +115,28 @@ class test_FilesystemTransport(Case):
break
self.c.drain_events()
- self.assertEqual(len(_received1) + len(_received2), 20)
+ assert len(_received1) + len(_received2) == 20
# compression
producer.publish({'compressed': True},
routing_key='test_transport_filesystem',
compression='zlib')
m = self.q(consumer_channel).get()
- self.assertDictEqual(m.payload, {'compressed': True})
+ assert m.payload == {'compressed': True}
# queue.delete
for i in range(10):
producer.publish({'foo': i},
routing_key='test_transport_filesystem')
- self.assertTrue(self.q(consumer_channel).get())
+ assert self.q(consumer_channel).get()
self.q(consumer_channel).delete()
self.q(consumer_channel).declare()
- self.assertIsNone(self.q(consumer_channel).get())
+ assert self.q(consumer_channel).get() is None
# queue.purge
for i in range(10):
producer.publish({'foo': i},
routing_key='test_transport_filesystem2')
- self.assertTrue(self.q2(consumer_channel).get())
+ assert self.q2(consumer_channel).get()
self.q2(consumer_channel).purge()
- self.assertIsNone(self.q2(consumer_channel).get())
+ assert self.q2(consumer_channel).get() is None
diff --git a/kombu/tests/transport/test_librabbitmq.py b/t/unit/transport/test_librabbitmq.py
index bbfbe961..26dbfbcd 100644
--- a/kombu/tests/transport/test_librabbitmq.py
+++ b/t/unit/transport/test_librabbitmq.py
@@ -1,5 +1,9 @@
from __future__ import absolute_import, unicode_literals
+import pytest
+
+from case import Mock, patch, skip
+
try:
import librabbitmq
except ImportError:
@@ -7,11 +11,9 @@ except ImportError:
else:
from kombu.transport import librabbitmq # noqa
-from kombu.tests.case import Case, Mock, patch, skip
-
@skip.unless_module('librabbitmq')
-class lrmqCase(Case):
+class lrmqCase:
pass
@@ -22,9 +24,9 @@ class test_Message(lrmqCase):
message = librabbitmq.Message(
chan, {'prop': 42}, {'delivery_tag': 337}, 'body',
)
- self.assertEqual(message.body, 'body')
- self.assertEqual(message.delivery_tag, 337)
- self.assertEqual(message.properties['prop'], 42)
+ assert message.body == 'body'
+ assert message.delivery_tag == 337
+ assert message.properties['prop'] == 42
class test_Channel(lrmqCase):
@@ -32,7 +34,7 @@ class test_Channel(lrmqCase):
def test_prepare_message(self):
conn = Mock(name='connection')
chan = librabbitmq.Channel(conn, 1)
- self.assertTrue(chan)
+ assert chan
body = 'the quick brown fox...'
properties = {'name': 'Elaine M.'}
@@ -45,32 +47,31 @@ class test_Channel(lrmqCase):
headers={'H': 2},
)
- self.assertEqual(props2['name'], 'Elaine M.')
- self.assertEqual(props2['priority'], 999)
- self.assertEqual(props2['content_type'], 'ctype')
- self.assertEqual(props2['content_encoding'], 'cenc')
- self.assertEqual(props2['headers'], {'H': 2})
- self.assertEqual(body2, body)
+ assert props2['name'] == 'Elaine M.'
+ assert props2['priority'] == 999
+ assert props2['content_type'] == 'ctype'
+ assert props2['content_encoding'] == 'cenc'
+ assert props2['headers'] == {'H': 2}
+ assert body2 == body
body3, props3 = chan.prepare_message(body, priority=777)
- self.assertEqual(props3['priority'], 777)
- self.assertEqual(body3, body)
+ assert props3['priority'] == 777
+ assert body3 == body
class test_Transport(lrmqCase):
def setup(self):
- super(test_Transport, self).setup()
self.client = Mock(name='client')
self.T = librabbitmq.Transport(self.client)
def test_driver_version(self):
- self.assertTrue(self.T.driver_version())
+ assert self.T.driver_version()
def test_create_channel(self):
conn = Mock(name='connection')
chan = self.T.create_channel(conn)
- self.assertTrue(chan)
+ assert chan
conn.channel.assert_called_with()
def test_drain_events(self):
@@ -80,7 +81,7 @@ class test_Transport(lrmqCase):
def test_establish_connection_SSL_not_supported(self):
self.client.ssl = True
- with self.assertRaises(NotImplementedError):
+ with pytest.raises(NotImplementedError):
self.T.establish_connection()
def test_establish_connection(self):
@@ -90,18 +91,15 @@ class test_Transport(lrmqCase):
self.T.client.transport_options = {}
conn = self.T.establish_connection()
- self.assertEqual(
- self.T.client.port,
- self.T.default_connection_params['port'],
- )
- self.assertEqual(conn.client, self.T.client)
- self.assertEqual(self.T.client.drain_events, conn.drain_events)
+ assert self.T.client.port == self.T.default_connection_params['port']
+ assert conn.client == self.T.client
+ assert self.T.client.drain_events == conn.drain_events
def test_collect__no_conn(self):
self.T.client.drain_events = 1234
self.T._collect(None)
- self.assertIsNone(self.client.drain_events)
- self.assertIsNone(self.T.client)
+ assert self.client.drain_events is None
+ assert self.T.client is None
def test_collect__with_conn(self):
self.T.client.drain_events = 1234
@@ -114,12 +112,12 @@ class test_Transport(lrmqCase):
with patch('os.close') as close:
self.T._collect(conn)
close.assert_called_with(conn.fileno())
- self.assertFalse(conn.channels)
- self.assertFalse(conn.callbacks)
+ assert not conn.channels
+ assert not conn.callbacks
for chan in chans.values():
- self.assertIsNone(chan.connection)
- self.assertIsNone(self.client.drain_events)
- self.assertIsNone(self.T.client)
+ assert chan.connection is None
+ assert self.client.drain_events is None
+ assert self.T.client is None
with patch('os.close') as close:
self.T.client = self.client
@@ -138,11 +136,11 @@ class test_Transport(lrmqCase):
def test_verify_connection(self):
conn = Mock(name='connection')
conn.connected = True
- self.assertTrue(self.T.verify_connection(conn))
+ assert self.T.verify_connection(conn)
def test_close_connection(self):
conn = Mock(name='connection')
self.client.drain_events = 1234
self.T.close_connection(conn)
- self.assertIsNone(self.client.drain_events)
+ assert self.client.drain_events is None
conn.close.assert_called_with()
diff --git a/kombu/tests/transport/test_memory.py b/t/unit/transport/test_memory.py
index 1ca83b9c..d4dc3390 100644
--- a/kombu/tests/transport/test_memory.py
+++ b/t/unit/transport/test_memory.py
@@ -1,13 +1,12 @@
from __future__ import absolute_import, unicode_literals
+import pytest
import socket
from kombu import Connection, Exchange, Queue, Consumer, Producer
-from kombu.tests.case import Case
-
-class test_MemoryTransport(Case):
+class test_MemoryTransport:
def setup(self):
self.c = Connection(transport='memory')
@@ -25,7 +24,7 @@ class test_MemoryTransport(Case):
exchange=self.fanout)
def test_driver_version(self):
- self.assertTrue(self.c.transport.driver_version())
+ assert self.c.transport.driver_version()
def test_produce_consume_noack(self):
channel = self.c.channel()
@@ -48,7 +47,7 @@ class test_MemoryTransport(Case):
break
self.c.drain_events()
- self.assertEqual(len(_received), 10)
+ assert len(_received) == 10
def test_produce_consume_fanout(self):
producer = self.c.Producer()
@@ -60,10 +59,10 @@ class test_MemoryTransport(Case):
exchange=self.fanout,
)
- self.assertEqual(self.q3(self.c).get().payload, {'hello': 'world'})
- self.assertEqual(self.q4(self.c).get().payload, {'hello': 'world'})
- self.assertIsNone(self.q3(self.c).get())
- self.assertIsNone(self.q4(self.c).get())
+ assert self.q3(self.c).get().payload == {'hello': 'world'}
+ assert self.q4(self.c).get().payload == {'hello': 'world'}
+ assert self.q3(self.c).get() is None
+ assert self.q4(self.c).get() is None
def test_produce_consume(self):
channel = self.c.channel()
@@ -99,38 +98,38 @@ class test_MemoryTransport(Case):
break
self.c.drain_events()
- self.assertEqual(len(_received1) + len(_received2), 20)
+ assert len(_received1) + len(_received2) == 20
# compression
producer.publish({'compressed': True},
routing_key='test_transport_memory',
compression='zlib')
m = self.q(channel).get()
- self.assertDictEqual(m.payload, {'compressed': True})
+ assert m.payload == {'compressed': True}
# queue.delete
for i in range(10):
producer.publish({'foo': i}, routing_key='test_transport_memory')
- self.assertTrue(self.q(channel).get())
+ assert self.q(channel).get()
self.q(channel).delete()
self.q(channel).declare()
- self.assertIsNone(self.q(channel).get())
+ assert self.q(channel).get() is None
# queue.purge
for i in range(10):
producer.publish({'foo': i}, routing_key='test_transport_memory2')
- self.assertTrue(self.q2(channel).get())
+ assert self.q2(channel).get()
self.q2(channel).purge()
- self.assertIsNone(self.q2(channel).get())
+ assert self.q2(channel).get() is None
def test_drain_events(self):
- with self.assertRaises(socket.timeout):
+ with pytest.raises(socket.timeout):
self.c.drain_events(timeout=0.1)
c1 = self.c.channel()
c2 = self.c.channel()
- with self.assertRaises(socket.timeout):
+ with pytest.raises(socket.timeout):
self.c.drain_events(timeout=0.1)
del(c1) # so pyflakes doesn't complain.
@@ -162,5 +161,5 @@ class test_MemoryTransport(Case):
chan.queues.clear()
x = chan._queue_for('foo')
- self.assertTrue(x)
- self.assertIs(chan._queue_for('foo'), x)
+ assert x
+ assert chan._queue_for('foo') is x
diff --git a/kombu/tests/transport/test_mongodb.py b/t/unit/transport/test_mongodb.py
index cbe8301a..b21be7f3 100644
--- a/kombu/tests/transport/test_mongodb.py
+++ b/t/unit/transport/test_mongodb.py
@@ -1,10 +1,12 @@
from __future__ import absolute_import, unicode_literals
import datetime
+import pytest
+
+from case import MagicMock, call, patch, skip
from kombu import Connection
from kombu.five import Empty
-from kombu.tests.case import Case, MagicMock, call, patch, skip
def _create_mock_connection(url='', **kwargs):
@@ -44,7 +46,7 @@ def _create_mock_connection(url='', **kwargs):
@skip.unless_module('pymongo')
-class test_mongodb_uri_parsing(Case):
+class test_mongodb_uri_parsing:
def test_defaults(self):
url = 'mongodb://'
@@ -53,22 +55,22 @@ class test_mongodb_uri_parsing(Case):
hostname, dbname, options = channel._parse_uri()
- self.assertEqual(dbname, 'kombu_default')
- self.assertEqual(hostname, 'mongodb://127.0.0.1')
+ assert dbname == 'kombu_default'
+ assert hostname == 'mongodb://127.0.0.1'
def test_custom_host(self):
url = 'mongodb://localhost'
channel = _create_mock_connection(url).default_channel
hostname, dbname, options = channel._parse_uri()
- self.assertEqual(dbname, 'kombu_default')
+ assert dbname == 'kombu_default'
def test_custom_database(self):
url = 'mongodb://localhost/dbname'
channel = _create_mock_connection(url).default_channel
hostname, dbname, options = channel._parse_uri()
- self.assertEqual(dbname, 'dbname')
+ assert dbname == 'dbname'
def test_custom_credentials(self):
url = 'mongodb://localhost/dbname'
@@ -76,11 +78,11 @@ class test_mongodb_uri_parsing(Case):
url, userid='foo', password='bar').default_channel
hostname, dbname, options = channel._parse_uri()
- self.assertEqual(hostname, 'mongodb://foo:bar@localhost/dbname')
- self.assertEqual(dbname, 'dbname')
+ assert hostname == 'mongodb://foo:bar@localhost/dbname'
+ assert dbname == 'dbname'
-class BaseMongoDBChannelCase(Case):
+class BaseMongoDBChannelCase:
def _get_method(self, cname, mname):
collection = getattr(self.channel, cname)
@@ -104,7 +106,7 @@ class BaseMongoDBChannelCase(Case):
self.channel._queue_bind('fanout_exchange', 'foo', '*', queue)
- self.assertIn(queue, self.channel._broadcast_cursors)
+ assert queue in self.channel._broadcast_cursors
def get_broadcast(self, queue):
return self.channel._broadcast_cursors[queue]
@@ -162,10 +164,10 @@ class test_mongodb_channel(BaseMongoDBChannelCase):
],
)
- self.assertDictEqual(event, {'some': 'data'})
+ assert event == {'some': 'data'}
self.set_operation_return_value('messages', 'find_and_modify', None)
- with self.assertRaises(Empty):
+ with pytest.raises(Empty):
self.channel._get('foobar')
def test_get_fanout(self):
@@ -175,9 +177,9 @@ class test_mongodb_channel(BaseMongoDBChannelCase):
event = self.channel._get('foobar')
self.assert_collection_accessed('messages.broadcast')
- self.assertDictEqual(event, {'some': 'data'})
+ assert event == {'some': 'data'}
- with self.assertRaises(Empty):
+ with pytest.raises(Empty):
self.channel._get('foobar')
def test_put(self):
@@ -209,7 +211,7 @@ class test_mongodb_channel(BaseMongoDBChannelCase):
'messages', 'find', {'queue': 'foobar'},
)
- self.assertEqual(result, 77)
+ assert result == 77
def test_size_fanout(self):
self.declare_droadcast_queue('foobar')
@@ -220,7 +222,7 @@ class test_mongodb_channel(BaseMongoDBChannelCase):
result = self.channel._size('foobar')
- self.assertEqual(result, 77)
+ assert result == 77
def test_purge(self):
self.set_operation_return_value('messages', 'find.count', 77)
@@ -231,7 +233,7 @@ class test_mongodb_channel(BaseMongoDBChannelCase):
'messages', 'remove', {'queue': 'foobar'},
)
- self.assertEqual(result, 77)
+ assert result == 77
def test_purge_fanout(self):
self.declare_droadcast_queue('foobar')
@@ -244,7 +246,7 @@ class test_mongodb_channel(BaseMongoDBChannelCase):
cursor.purge.assert_any_call()
- self.assertEqual(result, 77)
+ assert result == 77
def test_get_table(self):
state_table = [('foo', '*', 'foo')]
@@ -266,9 +268,7 @@ class test_mongodb_channel(BaseMongoDBChannelCase):
'routing', 'find', {'exchange': 'test_exchange'},
)
- self.assertSetEqual(
- set(result), frozenset(state_table) | frozenset(stored_table),
- )
+ assert set(result) == frozenset(state_table) | frozenset(stored_table)
def test_queue_bind(self):
self.channel._queue_bind('test_exchange', 'foo', '*', 'foo')
@@ -299,8 +299,8 @@ class test_mongodb_channel(BaseMongoDBChannelCase):
cursor.close.assert_any_call()
- self.assertNotIn('foobar', self.channel._broadcast_cursors)
- self.assertNotIn('foobar', self.channel._fanout_queues)
+ assert 'foobar' not in self.channel._broadcast_cursors
+ assert 'foobar' not in self.channel._fanout_queues
# Tests for channel internals
@@ -468,14 +468,14 @@ class test_mongodb_channel_ttl(BaseMongoDBChannelCase):
self.channel.client.assert_not_called()
- self.assertEqual(result, self.expire_at)
+ assert result == self.expire_at
self.set_operation_return_value('queues', 'find_one', {
'_id': 'docId', 'options': {'arguments': {'x-expires': 777}},
})
result = self.channel._get_expire('foobar', 'x-expires')
- self.assertEqual(result, self.expire_at)
+ assert result == self.expire_at
def test_update_queues_expire(self):
self.set_operation_return_value('queues', 'find_one', {
@@ -518,4 +518,4 @@ class test_mongodb_channel_calc_queue_size(BaseMongoDBChannelCase):
self.assert_operation_has_calls('messages', 'find', [])
- self.assertEqual(result, 0)
+ assert result == 0
diff --git a/kombu/tests/transport/test_pyamqp.py b/t/unit/transport/test_pyamqp.py
index 5bb8975f..1c8812b8 100644
--- a/kombu/tests/transport/test_pyamqp.py
+++ b/t/unit/transport/test_pyamqp.py
@@ -4,16 +4,11 @@ import sys
from itertools import count
-try:
- import amqp # noqa
-except ImportError:
- pyamqp = None # noqa
-else:
- from kombu.transport import pyamqp
+from case import Mock, mock, patch
+
from kombu import Connection
from kombu.five import nextfun
-
-from kombu.tests.case import Case, Mock, mock, patch
+from kombu.transport import pyamqp
class MockConnection(dict):
@@ -25,7 +20,7 @@ class MockConnection(dict):
pass
-class test_Channel(Case):
+class test_Channel:
def setup(self):
@@ -47,39 +42,39 @@ class test_Channel(Case):
self.channel = Channel(self.conn, 0)
def test_init(self):
- self.assertFalse(self.channel.no_ack_consumers)
+ assert not self.channel.no_ack_consumers
def test_prepare_message(self):
- self.assertTrue(self.channel.prepare_message(
+ assert self.channel.prepare_message(
'foobar', 10, 'application/data', 'utf-8',
properties={},
- ))
+ )
def test_message_to_python(self):
message = Mock()
message.headers = {}
message.properties = {}
- self.assertTrue(self.channel.message_to_python(message))
+ assert self.channel.message_to_python(message)
def test_close_resolves_connection_cycle(self):
- self.assertIsNotNone(self.channel.connection)
+ assert self.channel.connection is not None
self.channel.close()
- self.assertIsNone(self.channel.connection)
+ assert self.channel.connection is None
def test_basic_consume_registers_ack_status(self):
self.channel.wait_returns = 'my-consumer-tag'
self.channel.basic_consume('foo', no_ack=True)
- self.assertIn('my-consumer-tag', self.channel.no_ack_consumers)
+ assert 'my-consumer-tag' in self.channel.no_ack_consumers
self.channel.wait_returns = 'other-consumer-tag'
self.channel.basic_consume('bar', no_ack=False)
- self.assertNotIn('other-consumer-tag', self.channel.no_ack_consumers)
+ assert 'other-consumer-tag' not in self.channel.no_ack_consumers
self.channel.basic_cancel('my-consumer-tag')
- self.assertNotIn('my-consumer-tag', self.channel.no_ack_consumers)
+ assert 'my-consumer-tag' not in self.channel.no_ack_consumers
-class test_Transport(Case):
+class test_Transport:
def setup(self):
self.connection = Connection('pyamqp://')
@@ -91,7 +86,7 @@ class test_Transport(Case):
connection.channel.assert_called_with()
def test_driver_version(self):
- self.assertTrue(self.transport.driver_version())
+ assert self.transport.driver_version()
def test_drain_events(self):
connection = Mock()
@@ -111,18 +106,18 @@ class test_Transport(Case):
self.transport.Connection = Conn
self.transport.client.hostname = 'localhost'
conn1 = self.transport.establish_connection()
- self.assertEqual(conn1.host, '127.0.0.1:5672')
+ assert conn1.host == '127.0.0.1:5672'
self.transport.client.hostname = 'example.com'
conn2 = self.transport.establish_connection()
- self.assertEqual(conn2.host, 'example.com:5672')
+ assert conn2.host == 'example.com:5672'
def test_close_connection(self):
connection = Mock()
connection.client = Mock()
self.transport.close_connection(connection)
- self.assertIsNone(connection.client)
+ assert connection.client is None
connection.close.assert_called_with()
@mock.mask_modules('ssl')
@@ -130,13 +125,13 @@ class test_Transport(Case):
pm = sys.modules.pop('amqp.connection')
try:
from amqp.connection import SSLError
- self.assertEqual(SSLError.__module__, 'amqp.connection')
+ assert SSLError.__module__ == 'amqp.connection'
finally:
if pm is not None:
sys.modules['amqp.connection'] = pm
-class test_pyamqp(Case):
+class test_pyamqp:
def test_default_port(self):
@@ -144,8 +139,7 @@ class test_pyamqp(Case):
Connection = MockConnection
c = Connection(port=None, transport=Transport).connect()
- self.assertEqual(c['host'],
- '127.0.0.1:%s' % (Transport.default_port,))
+ assert c['host'] == '127.0.0.1:%s' % (Transport.default_port,)
def test_custom_port(self):
@@ -153,7 +147,7 @@ class test_pyamqp(Case):
Connection = MockConnection
c = Connection(port=1337, transport=Transport).connect()
- self.assertEqual(c['host'], '127.0.0.1:1337')
+ assert c['host'] == '127.0.0.1:1337'
def test_register_with_event_loop(self):
t = pyamqp.Transport(Mock())
diff --git a/kombu/tests/transport/test_redis.py b/t/unit/transport/test_redis.py
index c660badf..fd47317f 100644
--- a/kombu/tests/transport/test_redis.py
+++ b/t/unit/transport/test_redis.py
@@ -1,11 +1,14 @@
from __future__ import absolute_import, unicode_literals
+import pytest
import socket
import types
from collections import defaultdict
from itertools import count
+from case import ANY, ContextMock, Mock, call, mock, skip, patch
+
from kombu import Connection, Exchange, Queue, Consumer, Producer
from kombu.exceptions import InconsistencyError, VersionMismatch
from kombu.five import Empty, Queue as _Queue, bytes_if_py2
@@ -13,10 +16,6 @@ from kombu.transport import virtual
from kombu.utils import eventio # patch poll
from kombu.utils.json import dumps
-from kombu.tests.case import (
- Case, ContextMock, Mock, call, mock, skip, patch, ANY,
-)
-
class _poll(eventio._select):
@@ -234,7 +233,7 @@ class Transport(redis.Transport):
@skip.unless_module('redis')
-class test_Channel(Case):
+class test_Channel:
def setup(self):
self.connection = self.create_connection()
@@ -253,7 +252,7 @@ class test_Channel(Case):
msg = chan.prepare_message('quick brown fox')
chan.basic_publish(msg, n, n)
payload = chan._get(n)
- self.assertTrue(payload)
+ assert payload
pymsg = chan.message_to_python(payload)
return pymsg.delivery_tag
@@ -261,11 +260,11 @@ class test_Channel(Case):
seen = set()
for i in range(100):
tag = self._get_one_delivery_tag()
- self.assertNotIn(tag, seen)
+ assert tag not in seen
seen.add(tag)
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
int(tag)
- self.assertEqual(len(tag), 36)
+ assert len(tag) == 36
def test_disable_ack_emulation(self):
conn = Connection(transport=Transport, transport_options={
@@ -273,8 +272,8 @@ class test_Channel(Case):
})
chan = conn.channel()
- self.assertFalse(chan.ack_emulation)
- self.assertEqual(chan.QoS, virtual.QoS)
+ assert not chan.ack_emulation
+ assert chan.QoS == virtual.QoS
def test_redis_ping_raises(self):
pool = Mock(name='pool')
@@ -295,13 +294,13 @@ class test_Channel(Case):
conn = Connection(transport=XTransport)
client.ping.side_effect = RuntimeError()
- with self.assertRaises(RuntimeError):
+ with pytest.raises(RuntimeError):
conn.channel()
pool.disconnect.assert_called_with()
pool.disconnect.reset_mock()
pool_at_init = [None]
- with self.assertRaises(RuntimeError):
+ with pytest.raises(RuntimeError):
conn.channel()
pool.disconnect.assert_not_called()
@@ -314,10 +313,8 @@ class test_Channel(Case):
pool.disconnect.assert_called_with()
def test_next_delivery_tag(self):
- self.assertNotEqual(
- self.channel._next_delivery_tag(),
- self.channel._next_delivery_tag(),
- )
+ assert (self.channel._next_delivery_tag() !=
+ self.channel._next_delivery_tag())
def test_do_restore_message(self):
client = Mock(name='client')
@@ -399,21 +396,21 @@ class test_Channel(Case):
qos._vrestore_count = 1
qos.restore_visible()
client.zrevrangebyscore.assert_not_called()
- self.assertEqual(qos._vrestore_count, 2)
+ assert qos._vrestore_count == 2
qos._vrestore_count = 0
qos.restore_visible()
restore.assert_has_calls([
call(1, client), call(2, client), call(3, client),
])
- self.assertEqual(qos._vrestore_count, 1)
+ assert qos._vrestore_count == 1
qos._vrestore_count = 0
restore.reset_mock()
client.zrevrangebyscore.return_value = []
qos.restore_visible()
restore.assert_not_called()
- self.assertEqual(qos._vrestore_count, 1)
+ assert qos._vrestore_count == 1
qos._vrestore_count = 0
client.setnx.side_effect = redis.MutexHeld()
@@ -424,14 +421,13 @@ class test_Channel(Case):
self.channel.queue_declare(queue='txconfanq')
self.channel.queue_bind(queue='txconfanq', exchange='txconfan')
- self.assertIn('txconfanq', self.channel._fanout_queues)
+ assert 'txconfanq' in self.channel._fanout_queues
self.channel.basic_consume('txconfanq', False, None, 1)
- self.assertIn('txconfanq', self.channel.active_fanout_queues)
- self.assertEqual(self.channel._fanout_to_queue.get('txconfan'),
- 'txconfanq')
+ assert 'txconfanq' in self.channel.active_fanout_queues
+ assert self.channel._fanout_to_queue.get('txconfan') == 'txconfanq'
def test_basic_cancel_unknown_delivery_tag(self):
- self.assertIsNone(self.channel.basic_cancel('txaseqwewq'))
+ assert self.channel.basic_cancel('txaseqwewq') is None
def test_subscribe_no_queues(self):
self.channel.subclient = Mock()
@@ -448,7 +444,7 @@ class test_Channel(Case):
self.channel._subscribe()
self.channel.subclient.psubscribe.assert_called()
s_args, _ = self.channel.subclient.psubscribe.call_args
- self.assertItemsEqual(s_args[0], ['/{db}.a', '/{db}.b'])
+ assert sorted(s_args[0]) == ['/{db}.a', '/{db}.b']
self.channel.subclient.connection._sock = None
self.channel._subscribe()
@@ -458,38 +454,34 @@ class test_Channel(Case):
s = self.channel.subclient
s.subscribed = True
self.channel._handle_message(s, ['unsubscribe', 'a', 0])
- self.assertFalse(s.subscribed)
+ assert not s.subscribed
def test_handle_pmessage_message(self):
- self.assertDictEqual(
- self.channel._handle_message(
- self.channel.subclient,
- ['pmessage', 'pattern', 'channel', 'data'],
- ),
- {
- 'type': 'pmessage',
- 'pattern': 'pattern',
- 'channel': 'channel',
- 'data': 'data',
- },
+ res = self.channel._handle_message(
+ self.channel.subclient,
+ ['pmessage', 'pattern', 'channel', 'data'],
)
+ assert res == {
+ 'type': 'pmessage',
+ 'pattern': 'pattern',
+ 'channel': 'channel',
+ 'data': 'data',
+ }
def test_handle_message(self):
- self.assertDictEqual(
- self.channel._handle_message(
- self.channel.subclient,
- ['type', 'channel', 'data'],
- ),
- {
- 'type': 'type',
- 'pattern': None,
- 'channel': 'channel',
- 'data': 'data',
- },
+ res = self.channel._handle_message(
+ self.channel.subclient,
+ ['type', 'channel', 'data'],
)
+ assert res == {
+ 'type': 'type',
+ 'pattern': None,
+ 'channel': 'channel',
+ 'data': 'data',
+ }
def test_brpop_start_but_no_queues(self):
- self.assertIsNone(self.channel._brpop_start())
+ assert self.channel._brpop_start() is None
def test_receive(self):
s = self.channel.subclient = Mock()
@@ -497,37 +489,37 @@ class test_Channel(Case):
s.parse_response.return_value = ['message', 'a',
dumps({'hello': 'world'})]
payload, queue = self.channel._receive()
- self.assertDictEqual(payload, {'hello': 'world'})
- self.assertEqual(queue, 'b')
+ assert payload == {'hello': 'world'}
+ assert queue == 'b'
def test_receive_raises_for_connection_error(self):
self.channel._in_listen = True
s = self.channel.subclient = Mock()
s.parse_response.side_effect = KeyError('foo')
- with self.assertRaises(KeyError):
+ with pytest.raises(KeyError):
self.channel._receive()
- self.assertFalse(self.channel._in_listen)
+ assert not self.channel._in_listen
def test_receive_empty(self):
s = self.channel.subclient = Mock()
s.parse_response.return_value = None
- with self.assertRaises(redis.Empty):
+ with pytest.raises(redis.Empty):
self.channel._receive()
def test_receive_different_message_Type(self):
s = self.channel.subclient = Mock()
s.parse_response.return_value = ['message', '/foo/', 0, 'data']
- with self.assertRaises(redis.Empty):
+ with pytest.raises(redis.Empty):
self.channel._receive()
def test_brpop_read_raises(self):
c = self.channel.client = Mock()
c.parse_response.side_effect = KeyError('foo')
- with self.assertRaises(KeyError):
+ with pytest.raises(KeyError):
self.channel._brpop_read()
c.connection.disconnect.assert_called_with()
@@ -536,7 +528,7 @@ class test_Channel(Case):
c = self.channel.client = Mock()
c.parse_response.return_value = None
- with self.assertRaises(redis.Empty):
+ with pytest.raises(redis.Empty):
self.channel._brpop_read()
def test_poll_error(self):
@@ -547,7 +539,7 @@ class test_Channel(Case):
c.parse_response.assert_called_with(c.connection, 'BRPOP')
c.parse_response.side_effect = KeyError('foo')
- with self.assertRaises(KeyError):
+ with pytest.raises(KeyError):
self.channel._poll_error('BRPOP')
def test_poll_error_on_type_LISTEN(self):
@@ -558,7 +550,7 @@ class test_Channel(Case):
c.parse_response.assert_called_with()
c.parse_response.side_effect = KeyError('foo')
- with self.assertRaises(KeyError):
+ with pytest.raises(KeyError):
self.channel._poll_error('LISTEN')
def test_put_fanout(self):
@@ -609,14 +601,14 @@ class test_Channel(Case):
self.channel._create_client.return_value = self.channel.client
exists = self.channel.client.exists = Mock()
exists.return_value = True
- self.assertTrue(self.channel._has_queue('foo'))
+ assert self.channel._has_queue('foo')
exists.assert_has_calls([
call(self.channel._q_for_pri('foo', pri))
for pri in redis.PRIORITY_STEPS
])
exists.return_value = False
- self.assertFalse(self.channel._has_queue('foo'))
+ assert not self.channel._has_queue('foo')
def test_close_when_closed(self):
self.channel.closed = True
@@ -642,47 +634,44 @@ class test_Channel(Case):
def test_invalid_database_raises_ValueError(self):
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
self.channel.connection.client.virtual_host = 'dwqeq'
self.channel._connparams()
def test_connparams_allows_slash_in_db(self):
self.channel.connection.client.virtual_host = '/123'
- self.assertEqual(self.channel._connparams()['db'], 123)
+ assert self.channel._connparams()['db'] == 123
def test_connparams_db_can_be_int(self):
self.channel.connection.client.virtual_host = 124
- self.assertEqual(self.channel._connparams()['db'], 124)
+ assert self.channel._connparams()['db'] == 124
def test_new_queue_with_auto_delete(self):
redis.Channel._new_queue(self.channel, 'george', auto_delete=False)
- self.assertNotIn('george', self.channel.auto_delete_queues)
+ assert 'george' not in self.channel.auto_delete_queues
redis.Channel._new_queue(self.channel, 'elaine', auto_delete=True)
- self.assertIn('elaine', self.channel.auto_delete_queues)
+ assert 'elaine' in self.channel.auto_delete_queues
def test_connparams_regular_hostname(self):
self.channel.connection.client.hostname = 'george.vandelay.com'
- self.assertEqual(
- self.channel._connparams()['host'],
- 'george.vandelay.com',
- )
+ assert self.channel._connparams()['host'] == 'george.vandelay.com'
def test_rotate_cycle_ValueError(self):
cycle = self.channel._queue_cycle
cycle.update(['kramer', 'jerry'])
cycle.rotate('kramer')
- self.assertEqual(cycle.items, ['jerry', 'kramer'])
+ assert cycle.items, ['jerry' == 'kramer']
cycle.rotate('elaine')
def test_get_client(self):
import redis as R
KombuRedis = redis.Channel._get_client(self.channel)
- self.assertTrue(KombuRedis)
+ assert KombuRedis
Rv = getattr(R, 'VERSION', None)
try:
R.VERSION = (2, 4, 0)
- with self.assertRaises(VersionMismatch):
+ with pytest.raises(VersionMismatch):
redis.Channel._get_client(self.channel)
finally:
if Rv is not None:
@@ -690,8 +679,7 @@ class test_Channel(Case):
def test_get_response_error(self):
from redis.exceptions import ResponseError
- self.assertIs(redis.Channel._get_response_error(self.channel),
- ResponseError)
+ assert redis.Channel._get_response_error(self.channel) is ResponseError
def test_avail_client(self):
self.channel._pool = Mock()
@@ -745,12 +733,10 @@ class test_Channel(Case):
cb.assert_called_with(ret[0])
def test_transport_get_errors(self):
- self.assertTrue(redis.Transport._get_errors(self.connection.transport))
+ assert redis.Transport._get_errors(self.connection.transport)
def test_transport_driver_version(self):
- self.assertTrue(
- redis.Transport.driver_version(self.connection.transport),
- )
+ assert redis.Transport.driver_version(self.connection.transport)
def test_transport_get_errors_when_InvalidData_used(self):
from redis import exceptions
@@ -764,8 +750,8 @@ class test_Channel(Case):
exceptions.DataError = None
try:
errors = redis.Transport._get_errors(self.connection.transport)
- self.assertTrue(errors)
- self.assertIn(ID, errors[1])
+ assert errors
+ assert ID in errors[1]
finally:
if DataError is not None:
exceptions.DataError = DataError
@@ -779,43 +765,44 @@ class test_Channel(Case):
# Everything is fine, there is a list of queues.
channel.client.sadd(key, 'celery\x06\x16\x06\x16celery')
- self.assertListEqual(channel.get_table('celery'),
- [('celery', '', 'celery')])
+ assert channel.get_table('celery') == [
+ ('celery', '', 'celery'),
+ ]
# ... then for some reason, the _kombu.binding.celery key gets lost
channel.client.srem(key)
# which raises a channel error so that the consumer/publisher
# can recover by redeclaring the required entities.
- with self.assertRaises(InconsistencyError):
+ with pytest.raises(InconsistencyError):
self.channel.get_table('celery')
def test_socket_connection(self):
with patch('kombu.transport.redis.Channel._create_client'):
with Connection('redis+socket:///tmp/redis.sock') as conn:
connparams = conn.default_channel._connparams()
- self.assertTrue(issubclass(
+ assert issubclass(
connparams['connection_class'],
redis.redis.UnixDomainSocketConnection,
- ))
- self.assertEqual(connparams['path'], '/tmp/redis.sock')
+ )
+ assert connparams['path'] == '/tmp/redis.sock'
def test_ssl_argument__dict(self):
with patch('kombu.transport.redis.Channel._create_client'):
with Connection('redis://', ssl={'ca_cert': '/foo'}) as conn:
connparams = conn.default_channel._connparams()
- self.assertTrue(connparams['ssl'])
- self.assertEqual(connparams['ca_cert'], '/foo')
+ assert connparams['ssl']
+ assert connparams['ca_cert'] == '/foo'
def test_ssl_argument__bool(self):
with patch('kombu.transport.redis.Channel._create_client'):
with Connection('redis://', ssl=True) as conn:
connparams = conn.default_channel._connparams()
- self.assertTrue(connparams['ssl'])
+ assert connparams['ssl']
@skip.unless_module('redis')
-class test_Redis(Case):
+class test_Redis:
def setup(self):
self.connection = Connection(transport=Transport)
@@ -832,11 +819,10 @@ class test_Redis(Case):
producer.publish({'hello': 'world'})
- self.assertDictEqual(self.queue(channel).get().payload,
- {'hello': 'world'})
- self.assertIsNone(self.queue(channel).get())
- self.assertIsNone(self.queue(channel).get())
- self.assertIsNone(self.queue(channel).get())
+ assert self.queue(channel).get().payload == {'hello': 'world'}
+ assert self.queue(channel).get() is None
+ assert self.queue(channel).get() is None
+ assert self.queue(channel).get() is None
def test_publish__consume(self):
connection = Connection(transport=Transport)
@@ -854,11 +840,11 @@ class test_Redis(Case):
consumer.register_callback(callback)
consumer.consume()
- self.assertIn(channel, channel.connection.cycle._channels)
+ assert channel in channel.connection.cycle._channels
try:
connection.drain_events(timeout=1)
- self.assertTrue(_received)
- with self.assertRaises(socket.timeout):
+ assert _received
+ with pytest.raises(socket.timeout):
connection.drain_events(timeout=0.01)
finally:
channel.close()
@@ -871,8 +857,8 @@ class test_Redis(Case):
for i in range(10):
producer.publish({'hello': 'world-%s' % (i,)})
- self.assertEqual(channel._size('test_Redis'), 10)
- self.assertEqual(self.queue(channel).purge(), 10)
+ assert channel._size('test_Redis') == 10
+ assert self.queue(channel).purge() == 10
channel.close()
def test_db_values(self):
@@ -885,7 +871,7 @@ class test_Redis(Case):
Connection(virtual_host='/1',
transport=Transport).channel()
- with self.assertRaises(Exception):
+ with pytest.raises(Exception):
Connection('redis:///foo').channel()
def test_db_port(self):
@@ -900,7 +886,7 @@ class test_Redis(Case):
cycle = c.connection.cycle
c.client.connection
c.close()
- self.assertNotIn(c, cycle._channels)
+ assert c not in cycle._channels
def test_close_ResponseError(self):
c = Connection(transport=Transport).channel()
@@ -912,12 +898,12 @@ class test_Redis(Case):
conn1 = c.client.connection
conn2 = c.subclient.connection
c.close()
- self.assertTrue(conn1.disconnected)
- self.assertTrue(conn2.disconnected)
+ assert conn1.disconnected
+ assert conn2.disconnected
def test_get__Empty(self):
channel = self.connection.channel()
- with self.assertRaises(Empty):
+ with pytest.raises(Empty):
channel._get('does-not-exist')
channel.close()
@@ -925,16 +911,16 @@ class test_Redis(Case):
with mock.module_exists(*_redis_modules()):
conn = Connection(transport=Transport)
chan = conn.channel()
- self.assertTrue(chan.Client)
- self.assertTrue(chan.ResponseError)
- self.assertTrue(conn.transport.connection_errors)
- self.assertTrue(conn.transport.channel_errors)
+ assert chan.Client
+ assert chan.ResponseError
+ assert conn.transport.connection_errors
+ assert conn.transport.channel_errors
def test_check_at_least_we_try_to_connect_and_fail(self):
import redis
connection = Connection('redis://localhost:65534/')
- with self.assertRaises(redis.exceptions.ConnectionError):
+ with pytest.raises(redis.exceptions.ConnectionError):
chan = connection.channel()
chan._size('some_queue')
@@ -974,7 +960,7 @@ def _redis_modules():
@skip.unless_module('redis')
-class test_MultiChannelPoller(Case):
+class test_MultiChannelPoller:
def setup(self):
self.Poller = redis.MultiChannelPoller
@@ -1013,7 +999,7 @@ class test_MultiChannelPoller(Case):
p._channels = []
poller = Mock(name='poller')
p.on_poll_init(poller)
- self.assertIs(p.poller, poller)
+ assert p.poller is poller
p._channels = [chan1]
p.on_poll_init(poller)
@@ -1043,7 +1029,7 @@ class test_MultiChannelPoller(Case):
def test_fds(self):
p = self.Poller()
p._fd_to_chan = {1: 2}
- self.assertDictEqual(p.fds, p._fd_to_chan)
+ assert p.fds == p._fd_to_chan
def test_close_unregisters_fds(self):
p = self.Poller()
@@ -1052,12 +1038,14 @@ class test_MultiChannelPoller(Case):
p.close()
- self.assertEqual(poller.unregister.call_count, 3)
+ assert poller.unregister.call_count == 3
u_args = poller.unregister.call_args_list
- self.assertItemsEqual(u_args, [((1,), {}),
- ((2,), {}),
- ((3,), {})])
+ assert sorted(u_args) == [
+ ((1,), {}),
+ ((2,), {}),
+ ((3,), {}),
+ ]
def test_close_when_unregister_raises_KeyError(self):
p = self.Poller()
@@ -1091,8 +1079,8 @@ class test_MultiChannelPoller(Case):
p._chan_to_sock = {(channel, client, type): 6}
p._register(channel, client, type)
p.poller.unregister.assert_called_with(6)
- self.assertTupleEqual(p._fd_to_chan[10], (channel, type))
- self.assertEqual(p._chan_to_sock[(channel, client, type)], sock)
+ assert p._fd_to_chan[10] == (channel, type)
+ assert p._chan_to_sock[(channel, client, type)] == sock
p.poller.register.assert_called_with(sock, p.eventflags)
# when client not connected yet
@@ -1113,15 +1101,15 @@ class test_MultiChannelPoller(Case):
channel._in_poll = False
p._register_BRPOP(channel)
- self.assertEqual(channel._brpop_start.call_count, 1)
- self.assertEqual(p._register.call_count, 1)
+ assert channel._brpop_start.call_count == 1
+ assert p._register.call_count == 1
channel.client.connection._sock = Mock()
p._chan_to_sock[(channel, channel.client, 'BRPOP')] = True
channel._in_poll = True
p._register_BRPOP(channel)
- self.assertEqual(channel._brpop_start.call_count, 1)
- self.assertEqual(p._register.call_count, 1)
+ assert channel._brpop_start.call_count == 1
+ assert p._register.call_count == 1
def test_register_LISTEN(self):
p = self.Poller()
@@ -1132,15 +1120,15 @@ class test_MultiChannelPoller(Case):
p._register_LISTEN(channel)
p._register.assert_called_with(channel, channel.subclient, 'LISTEN')
- self.assertEqual(p._register.call_count, 1)
- self.assertEqual(channel._subscribe.call_count, 1)
+ assert p._register.call_count == 1
+ assert channel._subscribe.call_count == 1
channel._in_listen = True
p._chan_to_sock[(channel, channel.subclient, 'LISTEN')] = 3
channel.subclient.connection._sock = Mock()
p._register_LISTEN(channel)
- self.assertEqual(p._register.call_count, 1)
- self.assertEqual(channel._subscribe.call_count, 1)
+ assert p._register.call_count == 1
+ assert channel._subscribe.call_count == 1
def create_get(self, events=None, queues=None, fanouts=None):
_pr = [] if events is None else events
@@ -1163,7 +1151,7 @@ class test_MultiChannelPoller(Case):
def test_get_no_actions(self):
p, channel = self.create_get()
- with self.assertRaises(redis.Empty):
+ with pytest.raises(redis.Empty):
p.get()
def test_qos_reject(self):
@@ -1177,7 +1165,7 @@ class test_MultiChannelPoller(Case):
p, channel = self.create_get(queues=['a_queue'])
channel.qos.can_consume.return_value = True
- with self.assertRaises(redis.Empty):
+ with pytest.raises(redis.Empty):
p.get()
p._register_BRPOP.assert_called_with(channel)
@@ -1186,7 +1174,7 @@ class test_MultiChannelPoller(Case):
p, channel = self.create_get(queues=['a_queue'])
channel.qos.can_consume.return_value = False
- with self.assertRaises(redis.Empty):
+ with pytest.raises(redis.Empty):
p.get()
p._register_BRPOP.assert_not_called()
@@ -1194,7 +1182,7 @@ class test_MultiChannelPoller(Case):
def test_get_listen(self):
p, channel = self.create_get(fanouts=['f_queue'])
- with self.assertRaises(redis.Empty):
+ with pytest.raises(redis.Empty):
p.get()
p._register_LISTEN.assert_called_with(channel)
@@ -1203,7 +1191,7 @@ class test_MultiChannelPoller(Case):
p, channel = self.create_get(events=[(1, eventio.ERR)])
p._fd_to_chan[1] = (channel, 'BRPOP')
- with self.assertRaises(redis.Empty):
+ with pytest.raises(redis.Empty):
p.get()
channel._poll_error.assert_called_with('BRPOP')
@@ -1213,14 +1201,14 @@ class test_MultiChannelPoller(Case):
(1, eventio.ERR)])
p._fd_to_chan[1] = (channel, 'BRPOP')
- with self.assertRaises(redis.Empty):
+ with pytest.raises(redis.Empty):
p.get()
channel._poll_error.assert_called_with('BRPOP')
@skip.unless_module('redis')
-class test_Mutex(Case):
+class test_Mutex:
def test_mutex(self, lock_id='xxx'):
client = Mock(name='client')
@@ -1234,29 +1222,29 @@ class test_Mutex(Case):
held = False
with redis.Mutex(client, 'foo1', 100):
held = True
- self.assertTrue(held)
+ assert held
client.setnx.assert_called_with('foo1', lock_id)
pipe.get.return_value = 'yyy'
held = False
with redis.Mutex(client, 'foo1', 100):
held = True
- self.assertTrue(held)
+ assert held
# Did not win
client.expire.reset_mock()
pipe.get.return_value = lock_id
client.setnx.return_value = False
- with self.assertRaises(redis.MutexHeld):
+ with pytest.raises(redis.MutexHeld):
held = False
with redis.Mutex(client, 'foo1', '100'):
held = True
- self.assertFalse(held)
+ assert not held
client.ttl.return_value = 0
- with self.assertRaises(redis.MutexHeld):
+ with pytest.raises(redis.MutexHeld):
held = False
with redis.Mutex(client, 'foo1', '100'):
held = True
- self.assertFalse(held)
+ assert not held
client.expire.assert_called()
# Wins but raises WatchError (and that is ignored)
@@ -1265,14 +1253,11 @@ class test_Mutex(Case):
held = False
with redis.Mutex(client, 'foo1', 100):
held = True
- self.assertTrue(held)
+ assert held
@skip.unless_module('redis.sentinel')
-class test_RedisSentinel(Case):
-
- def setup(self):
- pass
+class test_RedisSentinel:
def test_method_called(self):
from kombu.transport.redis import SentinelChannel
@@ -1289,9 +1274,7 @@ class test_RedisSentinel(Case):
p.assert_called()
def test_getting_master_from_sentinel(self):
- from redis.sentinel import Sentinel
-
- with patch.object(Sentinel, '__new__') as patched:
+ with patch('redis.sentinel.Sentinel') as patched:
connection = Connection(
'sentinel://localhost:65534/',
transport_options={
@@ -1300,7 +1283,7 @@ class test_RedisSentinel(Case):
)
connection.channel()
- self.assertTrue(patched)
+ assert patched
master_for = patched.return_value.master_for
master_for.assert_called()
@@ -1310,11 +1293,11 @@ class test_RedisSentinel(Case):
def test_can_create_connection(self):
from redis.exceptions import ConnectionError
- with self.assertRaises(ConnectionError):
- connection = Connection(
- 'sentinel://localhost:65534/',
- transport_options={
- 'master_name': 'not_important',
- },
- )
+ connection = Connection(
+ 'sentinel://localhost:65534/',
+ transport_options={
+ 'master_name': 'not_important',
+ },
+ )
+ with pytest.raises(ConnectionError):
connection.channel()
diff --git a/kombu/tests/transport/test_transport.py b/t/unit/transport/test_transport.py
index 2f7d2ede..26df9e19 100644
--- a/kombu/tests/transport/test_transport.py
+++ b/t/unit/transport/test_transport.py
@@ -1,26 +1,25 @@
from __future__ import absolute_import, unicode_literals
-from kombu import transport
+from case import Mock, patch
-from kombu.tests.case import Case, Mock, patch
+from kombu import transport
-class test_supports_librabbitmq(Case):
+class test_supports_librabbitmq:
def test_eventlet(self):
with patch('kombu.transport._detect_environment') as de:
de.return_value = 'eventlet'
- self.assertFalse(transport.supports_librabbitmq())
+ assert not transport.supports_librabbitmq()
-class test_transport(Case):
+class test_transport:
def test_resolve_transport(self):
from kombu.transport.memory import Transport
- self.assertIs(transport.resolve_transport(
- 'kombu.transport.memory:Transport'),
- Transport)
- self.assertIs(transport.resolve_transport(Transport), Transport)
+ assert transport.resolve_transport(
+ 'kombu.transport.memory:Transport') is Transport
+ assert transport.resolve_transport(Transport) is Transport
def test_resolve_transport_alias_callable(self):
m = transport.TRANSPORT_ALIASES['George'] = Mock(name='lazyalias')
@@ -31,4 +30,4 @@ class test_transport(Case):
transport.TRANSPORT_ALIASES.pop('George')
def test_resolve_transport_alias(self):
- self.assertTrue(transport.resolve_transport('pyamqp'))
+ assert transport.resolve_transport('pyamqp')
diff --git a/kombu/tests/utils/__init__.py b/t/unit/transport/virtual/__init__.py
index e69de29b..e69de29b 100644
--- a/kombu/tests/utils/__init__.py
+++ b/t/unit/transport/virtual/__init__.py
diff --git a/kombu/tests/transport/virtual/test_base.py b/t/unit/transport/virtual/test_base.py
index 8b73a7fa..5eb23a08 100644
--- a/kombu/tests/transport/virtual/test_base.py
+++ b/t/unit/transport/virtual/test_base.py
@@ -1,16 +1,18 @@
from __future__ import absolute_import, unicode_literals
+import io
+import pytest
import sys
import warnings
+from case import MagicMock, Mock, patch
+
from kombu import Connection
from kombu.compression import compress
from kombu.exceptions import ResourceError, ChannelError
from kombu.transport import virtual
from kombu.utils.uuid import uuid
-from kombu.tests.case import Case, MagicMock, Mock, mock, patch
-
PY3 = sys.version_info[0] == 3
PRINT_FQDN = 'builtins.print' if PY3 else '__builtin__.print'
@@ -23,17 +25,15 @@ def memory_client():
return Connection(transport='memory')
-class test_BrokerState(Case):
-
- def test_constructor(self):
- s = virtual.BrokerState()
- self.assertTrue(hasattr(s, 'exchanges'))
+def test_BrokerState():
+ s = virtual.BrokerState()
+ assert hasattr(s, 'exchanges')
- t = virtual.BrokerState(exchanges=16)
- self.assertEqual(t.exchanges, 16)
+ t = virtual.BrokerState(exchanges=16)
+ assert t.exchanges == 16
-class test_QoS(Case):
+class test_QoS:
def setup(self):
self.q = virtual.QoS(client().channel(), prefetch_count=10)
@@ -42,17 +42,17 @@ class test_QoS(Case):
self.q._on_collect.cancel()
def test_constructor(self):
- self.assertTrue(self.q.channel)
- self.assertTrue(self.q.prefetch_count)
- self.assertFalse(self.q._delivered.restored)
- self.assertTrue(self.q._on_collect)
+ assert self.q.channel
+ assert self.q.prefetch_count
+ assert not self.q._delivered.restored
+ assert self.q._on_collect
def test_restore_visible__interface(self):
qos = virtual.QoS(client().channel())
qos.restore_visible()
- @mock.stdouts
- def test_can_consume(self, stdout, stderr):
+ def test_can_consume(self, stdouts):
+ stderr = io.StringIO()
_restored = []
class RestoreChannel(virtual.Channel):
@@ -61,63 +61,64 @@ class test_QoS(Case):
def _restore(self, message):
_restored.append(message)
- self.assertTrue(self.q.can_consume())
+ assert self.q.can_consume()
for i in range(self.q.prefetch_count - 1):
self.q.append(i, uuid())
- self.assertTrue(self.q.can_consume())
+ assert self.q.can_consume()
self.q.append(i + 1, uuid())
- self.assertFalse(self.q.can_consume())
+ assert not self.q.can_consume()
tag1 = next(iter(self.q._delivered))
self.q.ack(tag1)
- self.assertTrue(self.q.can_consume())
+ assert self.q.can_consume()
tag2 = uuid()
self.q.append(i + 2, tag2)
- self.assertFalse(self.q.can_consume())
+ assert not self.q.can_consume()
self.q.reject(tag2)
- self.assertTrue(self.q.can_consume())
+ assert self.q.can_consume()
self.q.channel = RestoreChannel(self.q.channel.connection)
tag3 = uuid()
self.q.append(i + 3, tag3)
self.q.reject(tag3, requeue=True)
self.q._flush()
- self.q.restore_unacked_once()
- self.assertListEqual(_restored, [11, 9, 8, 7, 6, 5, 4, 3, 2, 1])
- self.assertTrue(self.q._delivered.restored)
- self.assertFalse(self.q._delivered)
-
- self.q.restore_unacked_once()
+ assert self.q._delivered
+ assert not self.q._delivered.restored
+ self.q.restore_unacked_once(stderr=stderr)
+ assert _restored == [11, 9, 8, 7, 6, 5, 4, 3, 2, 1]
+ assert self.q._delivered.restored
+ assert not self.q._delivered
+
+ self.q.restore_unacked_once(stderr=stderr)
self.q._delivered.restored = False
- self.q.restore_unacked_once()
+ self.q.restore_unacked_once(stderr=stderr)
- self.assertTrue(stderr.getvalue())
- self.assertFalse(stdout.getvalue())
+ assert stderr.getvalue()
+ assert not stdouts.stdout.getvalue()
self.q.restore_at_shutdown = False
self.q.restore_unacked_once()
def test_get(self):
self.q._delivered['foo'] = 1
- self.assertEqual(self.q.get('foo'), 1)
+ assert self.q.get('foo') == 1
-class test_Message(Case):
+class test_Message:
def test_create(self):
c = client().channel()
data = c.prepare_message('the quick brown fox...')
tag = data['properties']['delivery_tag'] = uuid()
message = c.message_to_python(data)
- self.assertIsInstance(message, virtual.Message)
- self.assertIs(message, c.message_to_python(message))
+ assert isinstance(message, virtual.Message)
+ assert message is c.message_to_python(message)
if message.errors:
message._reraise_error()
- self.assertEqual(message.body,
- 'the quick brown fox...'.encode('utf-8'))
- self.assertTrue(message.delivery_tag, tag)
+ assert message.body == 'the quick brown fox...'.encode('utf-8')
+ assert message.delivery_tag, tag
def test_create_no_body(self):
virtual.Message(Mock(), {
@@ -131,46 +132,45 @@ class test_Message(Case):
tag = data['properties']['delivery_tag'] = uuid()
message = c.message_to_python(data)
dict_ = message.serializable()
- self.assertEqual(dict_['body'],
- 'the quick brown fox...'.encode('utf-8'))
- self.assertEqual(dict_['properties']['delivery_tag'], tag)
- self.assertNotIn('compression', dict_['headers'])
+ assert dict_['body'] == 'the quick brown fox...'.encode('utf-8')
+ assert dict_['properties']['delivery_tag'] == tag
+ assert 'compression' not in dict_['headers']
-class test_AbstractChannel(Case):
+class test_AbstractChannel:
def test_get(self):
- with self.assertRaises(NotImplementedError):
+ with pytest.raises(NotImplementedError):
virtual.AbstractChannel()._get('queue')
def test_put(self):
- with self.assertRaises(NotImplementedError):
+ with pytest.raises(NotImplementedError):
virtual.AbstractChannel()._put('queue', 'm')
def test_size(self):
- self.assertEqual(virtual.AbstractChannel()._size('queue'), 0)
+ assert virtual.AbstractChannel()._size('queue') == 0
def test_purge(self):
- with self.assertRaises(NotImplementedError):
+ with pytest.raises(NotImplementedError):
virtual.AbstractChannel()._purge('queue')
def test_delete(self):
- with self.assertRaises(NotImplementedError):
+ with pytest.raises(NotImplementedError):
virtual.AbstractChannel()._delete('queue')
def test_new_queue(self):
- self.assertIsNone(virtual.AbstractChannel()._new_queue('queue'))
+ assert virtual.AbstractChannel()._new_queue('queue') is None
def test_has_queue(self):
- self.assertTrue(virtual.AbstractChannel()._has_queue('queue'))
+ assert virtual.AbstractChannel()._has_queue('queue')
def test_poll(self):
cycle = Mock(name='cycle')
- self.assertTrue(virtual.AbstractChannel()._poll(cycle))
+ assert virtual.AbstractChannel()._poll(cycle)
cycle.get.assert_called()
-class test_Channel(Case):
+class test_Channel:
def setup(self):
self.channel = client().channel()
@@ -184,15 +184,15 @@ class test_Channel(Case):
t = c.transport
avail = t._avail_channel_ids = Mock(name='_avail_channel_ids')
avail.pop.side_effect = IndexError()
- with self.assertRaises(ResourceError):
+ with pytest.raises(ResourceError):
virtual.Channel(t)
def test_exchange_bind_interface(self):
- with self.assertRaises(NotImplementedError):
+ with pytest.raises(NotImplementedError):
self.channel.exchange_bind('dest', 'src', 'key')
def test_exchange_unbind_interface(self):
- with self.assertRaises(NotImplementedError):
+ with pytest.raises(NotImplementedError):
self.channel.exchange_unbind('dest', 'src', 'key')
def test_queue_unbind_interface(self):
@@ -200,28 +200,28 @@ class test_Channel(Case):
def test_management(self):
m = self.channel.connection.client.get_manager()
- self.assertTrue(m)
+ assert m
m.get_bindings()
m.close()
def test_exchange_declare(self):
c = self.channel
- with self.assertRaises(ChannelError):
+ with pytest.raises(ChannelError):
c.exchange_declare('test_exchange_declare', 'direct',
durable=True, auto_delete=True, passive=True)
c.exchange_declare('test_exchange_declare', 'direct',
durable=True, auto_delete=True)
c.exchange_declare('test_exchange_declare', 'direct',
durable=True, auto_delete=True, passive=True)
- self.assertIn('test_exchange_declare', c.state.exchanges)
+ assert 'test_exchange_declare' in c.state.exchanges
# can declare again with same values
c.exchange_declare('test_exchange_declare', 'direct',
durable=True, auto_delete=True)
- self.assertIn('test_exchange_declare', c.state.exchanges)
+ assert 'test_exchange_declare' in c.state.exchanges
# using different values raises NotEquivalentError
- with self.assertRaises(virtual.NotEquivalentError):
+ with pytest.raises(virtual.NotEquivalentError):
c.exchange_declare('test_exchange_declare', 'direct',
durable=False, auto_delete=True)
@@ -236,18 +236,18 @@ class test_Channel(Case):
c = PurgeChannel(self.channel.connection)
c.exchange_declare(ex, 'direct', durable=True, auto_delete=True)
- self.assertIn(ex, c.state.exchanges)
- self.assertFalse(c.state.has_binding(ex, ex, ex)) # no bindings yet
+ assert ex in c.state.exchanges
+ assert not c.state.has_binding(ex, ex, ex) # no bindings yet
c.exchange_delete(ex)
- self.assertNotIn(ex, c.state.exchanges)
+ assert ex not in c.state.exchanges
c.exchange_declare(ex, 'direct', durable=True, auto_delete=True)
c.queue_declare(ex)
c.queue_bind(ex, ex, ex)
- self.assertTrue(c.state.has_binding(ex, ex, ex))
+ assert c.state.has_binding(ex, ex, ex)
c.exchange_delete(ex)
- self.assertFalse(c.state.has_binding(ex, ex, ex))
- self.assertIn(ex, c.purged)
+ assert not c.state.has_binding(ex, ex, ex)
+ assert ex in c.purged
def test_queue_delete__if_empty(self, n='test_queue_delete__if_empty'):
class PurgeChannel(virtual.Channel):
@@ -268,12 +268,12 @@ class test_Channel(Case):
c.queue_bind(n, n, n)
c.queue_delete(n, if_empty=True)
- self.assertTrue(c.state.has_binding(n, n, n))
+ assert c.state.has_binding(n, n, n)
c.size = 0
c.queue_delete(n, if_empty=True)
- self.assertFalse(c.state.has_binding(n, n, n))
- self.assertIn(n, c.purged)
+ assert not c.state.has_binding(n, n, n)
+ assert n in c.purged
def test_queue_purge(self, n='test_queue_purge'):
@@ -288,7 +288,7 @@ class test_Channel(Case):
c.queue_declare(n)
c.queue_bind(n, n, n)
c.queue_purge(n)
- self.assertIn(n, c.purged)
+ assert n in c.purged
def test_basic_publish__anon_exchange(self):
c = memory_client().channel()
@@ -315,10 +315,10 @@ class test_Channel(Case):
r1 = c1.message_to_python(c1.basic_get(n))
r2 = c2.message_to_python(c2.basic_get(n))
- self.assertNotEqual(r1.delivery_tag, r2.delivery_tag)
- with self.assertRaises(ValueError):
+ assert r1.delivery_tag != r2.delivery_tag
+ with pytest.raises(ValueError):
int(r1.delivery_tag)
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
int(r2.delivery_tag)
def test_basic_publish__get__consume__restore(self,
@@ -335,31 +335,29 @@ class test_Channel(Case):
c.basic_publish(m, n, n)
r1 = c.message_to_python(c.basic_get(n))
- self.assertTrue(r1)
- self.assertEqual(r1.body,
- 'nthex quick brown fox...'.encode('utf-8'))
- self.assertIsNone(c.basic_get(n))
+ assert r1
+ assert r1.body == 'nthex quick brown fox...'.encode('utf-8')
+ assert c.basic_get(n) is None
consumer_tag = uuid()
c.basic_consume(n + '2', False,
consumer_tag=consumer_tag, callback=lambda *a: None)
- self.assertIn(n + '2', c._active_queues)
+ assert n + '2' in c._active_queues
r2, _ = c.drain_events()
r2 = c.message_to_python(r2)
- self.assertEqual(r2.body,
- 'nthex quick brown fox...'.encode('utf-8'))
- self.assertEqual(r2.delivery_info['exchange'], n)
- self.assertEqual(r2.delivery_info['routing_key'], n)
- with self.assertRaises(virtual.Empty):
+ assert r2.body == 'nthex quick brown fox...'.encode('utf-8')
+ assert r2.delivery_info['exchange'] == n
+ assert r2.delivery_info['routing_key'] == n
+ with pytest.raises(virtual.Empty):
c.drain_events()
c.basic_cancel(consumer_tag)
c._restore(r2)
r3 = c.message_to_python(c.basic_get(n))
- self.assertTrue(r3)
- self.assertEqual(r3.body, 'nthex quick brown fox...'.encode('utf-8'))
- self.assertIsNone(c.basic_get(n))
+ assert r3
+ assert r3.body == 'nthex quick brown fox...'.encode('utf-8')
+ assert c.basic_get(n) is None
def test_basic_ack(self):
@@ -371,7 +369,7 @@ class test_Channel(Case):
self.channel._qos = MockQoS(self.channel)
self.channel.basic_ack('foo')
- self.assertTrue(self.channel._qos.was_acked)
+ assert self.channel._qos.was_acked
def test_basic_recover__requeue(self):
@@ -383,7 +381,7 @@ class test_Channel(Case):
self.channel._qos = MockQoS(self.channel)
self.channel.basic_recover(requeue=True)
- self.assertTrue(self.channel._qos.was_restored)
+ assert self.channel._qos.was_restored
def test_restore_unacked_raises_BaseException(self):
q = self.channel.qos
@@ -394,11 +392,11 @@ class test_Channel(Case):
q.channel._restore.side_effect = SystemExit
errors = q.restore_unacked()
- self.assertIsInstance(errors[0][0], SystemExit)
- self.assertEqual(errors[0][1], 1)
- self.assertFalse(q._delivered)
+ assert isinstance(errors[0][0], SystemExit)
+ assert errors[0][1] == 1
+ assert not q._delivered
- @patch('kombu.transport.virtual.emergency_dump_state')
+ @patch('kombu.transport.virtual.base.emergency_dump_state')
@patch(PRINT_FQDN)
def test_restore_unacked_once_when_unrestored(self, print_,
emergency_dump_state):
@@ -423,7 +421,7 @@ class test_Channel(Case):
emergency_dump_state.assert_called()
def test_basic_recover(self):
- with self.assertRaises(NotImplementedError):
+ with pytest.raises(NotImplementedError):
self.channel.basic_recover(requeue=False)
def test_basic_reject(self):
@@ -436,39 +434,38 @@ class test_Channel(Case):
self.channel._qos = MockQoS(self.channel)
self.channel.basic_reject('foo')
- self.assertTrue(self.channel._qos.was_rejected)
+ assert self.channel._qos.was_rejected
def test_basic_qos(self):
self.channel.basic_qos(prefetch_count=128)
- self.assertEqual(self.channel._qos.prefetch_count, 128)
+ assert self.channel._qos.prefetch_count == 128
def test_lookup__undeliverable(self, n='test_lookup__undeliverable'):
warnings.resetwarnings()
with warnings.catch_warnings(record=True) as log:
- self.assertListEqual(
- self.channel._lookup(n, n, 'ae.undeliver'),
- ['ae.undeliver'],
- )
- self.assertTrue(log)
- self.assertIn('could not be delivered', log[0].message.args[0])
+ assert self.channel._lookup(n, n, 'ae.undeliver') == [
+ 'ae.undeliver',
+ ]
+ assert log
+ assert 'could not be delivered' in log[0].message.args[0]
def test_context(self):
x = self.channel.__enter__()
- self.assertIs(x, self.channel)
+ assert x is self.channel
x.__exit__()
- self.assertTrue(x.closed)
+ assert x.closed
def test_cycle_property(self):
- self.assertTrue(self.channel.cycle)
+ assert self.channel.cycle
def test_flow(self):
- with self.assertRaises(NotImplementedError):
+ with pytest.raises(NotImplementedError):
self.channel.flow(False)
def test_close_when_no_connection(self):
self.channel.connection = None
self.channel.close()
- self.assertTrue(self.channel.closed)
+ assert self.channel.closed
def test_drain_events_has_get_many(self):
c = self.channel
@@ -483,7 +480,7 @@ class test_Channel(Case):
def test_get_exchanges(self):
self.channel.exchange_declare(exchange='foo')
- self.assertTrue(self.channel.get_exchanges())
+ assert self.channel.get_exchanges()
def test_basic_cancel_not_in_active_queues(self):
c = self.channel
@@ -496,7 +493,7 @@ class test_Channel(Case):
c._active_queues.remove.assert_called_with('foo')
def test_basic_cancel_unknown_ctag(self):
- self.assertIsNone(self.channel.basic_cancel('unknown-tag'))
+ assert self.channel.basic_cancel('unknown-tag') is None
def test_list_bindings(self):
c = self.channel
@@ -504,7 +501,7 @@ class test_Channel(Case):
c.queue_declare(queue='q')
c.queue_bind(queue='q', exchange='foo', routing_key='rk')
- self.assertIn(('q', 'foo', 'rk'), list(c.list_bindings()))
+ assert ('q', 'foo', 'rk') in list(c.list_bindings())
def test_after_reply_message_received(self):
c = self.channel
@@ -513,12 +510,12 @@ class test_Channel(Case):
c.queue_delete.assert_called_with('foo')
def test_queue_delete_unknown_queue(self):
- self.assertIsNone(self.channel.queue_delete('xiwjqjwel'))
+ assert self.channel.queue_delete('xiwjqjwel') is None
def test_queue_declare_passive(self):
has_queue = self.channel._has_queue = Mock()
has_queue.return_value = False
- with self.assertRaises(ChannelError):
+ with pytest.raises(ChannelError):
self.channel.queue_declare(queue='21wisdjwqe', passive=True)
def test_get_message_priority(self):
@@ -528,56 +525,46 @@ class test_Channel(Case):
'the message with priority', priority=priority,
)
- self.assertEqual(
- self.channel._get_message_priority(_message(5)), 5,
- )
- self.assertEqual(
- self.channel._get_message_priority(
- _message(self.channel.min_priority - 10),
- ),
- self.channel.min_priority,
- )
- self.assertEqual(
- self.channel._get_message_priority(
- _message(self.channel.max_priority + 10),
- ),
- self.channel.max_priority,
- )
- self.assertEqual(
- self.channel._get_message_priority(_message('foobar')),
- self.channel.default_priority,
- )
- self.assertEqual(
- self.channel._get_message_priority(_message(2), reverse=True),
- self.channel.max_priority - 2,
- )
+ assert self.channel._get_message_priority(_message(5)) == 5
+ assert self.channel._get_message_priority(
+ _message(self.channel.min_priority - 10)
+ ) == self.channel.min_priority
+ assert self.channel._get_message_priority(
+ _message(self.channel.max_priority + 10),
+ ) == self.channel.max_priority
+ assert self.channel._get_message_priority(
+ _message('foobar'),
+ ) == self.channel.default_priority
+ assert self.channel._get_message_priority(
+ _message(2), reverse=True,
+ ) == self.channel.max_priority - 2
-class test_Transport(Case):
+class test_Transport:
def setup(self):
self.transport = client().transport
def test_custom_polling_interval(self):
x = client(transport_options=dict(polling_interval=32.3))
- self.assertEqual(x.transport.polling_interval, 32.3)
+ assert x.transport.polling_interval == 32.3
def test_close_connection(self):
c1 = self.transport.create_channel(self.transport)
c2 = self.transport.create_channel(self.transport)
- self.assertEqual(len(self.transport.channels), 2)
+ assert len(self.transport.channels) == 2
self.transport.close_connection(self.transport)
- self.assertFalse(self.transport.channels)
+ assert not self.transport.channels
del(c1) # so pyflakes doesn't complain
del(c2)
def test_drain_channel(self):
channel = self.transport.create_channel(self.transport)
- with self.assertRaises(virtual.Empty):
+ with pytest.raises(virtual.Empty):
self.transport._drain_channel(channel)
def test__deliver__no_queue(self):
- with self.assertRaises(KeyError):
+ with pytest.raises(KeyError):
self.transport._deliver(Mock(name='msg'), queue=None)
def test__reject_inbound_message(self):
@@ -601,12 +588,12 @@ class test_Transport(Case):
callback.assert_called_with(msg)
def test_on_message_ready__no_queue(self):
- with self.assertRaises(KeyError):
+ with pytest.raises(KeyError):
self.transport.on_message_ready(
Mock(name='channel'), Mock(name='msg'), queue=None)
def test_on_message_ready__no_callback(self):
self.transport._callbacks = {}
- with self.assertRaises(KeyError):
+ with pytest.raises(KeyError):
self.transport.on_message_ready(
Mock(name='channel'), Mock(name='msg'), queue='q1')
diff --git a/t/unit/transport/virtual/test_exchange.py b/t/unit/transport/virtual/test_exchange.py
new file mode 100644
index 00000000..93d48228
--- /dev/null
+++ b/t/unit/transport/virtual/test_exchange.py
@@ -0,0 +1,169 @@
+from __future__ import absolute_import, unicode_literals
+
+import pytest
+
+from case import Mock
+
+from kombu import Connection
+from kombu.transport.virtual import exchange
+
+from t.mocks import Transport
+
+
+class ExchangeCase:
+ type = None
+
+ def setup(self):
+ if self.type:
+ self.e = self.type(Connection(transport=Transport).channel())
+
+
+class test_Direct(ExchangeCase):
+ type = exchange.DirectExchange
+ table = [('rFoo', None, 'qFoo'),
+ ('rFoo', None, 'qFox'),
+ ('rBar', None, 'qBar'),
+ ('rBaz', None, 'qBaz')]
+
+ @pytest.mark.parametrize('exchange,routing_key,default,expected', [
+ ('eFoo', 'rFoo', None, {'qFoo', 'qFox'}),
+ ('eMoz', 'rMoz', 'DEFAULT', set()),
+ ('eBar', 'rBar', None, {'qBar'}),
+ ])
+ def test_lookup(self, exchange, routing_key, default, expected):
+ assert self.e.lookup(
+ self.table, exchange, routing_key, default) == expected
+
+
+class test_Fanout(ExchangeCase):
+ type = exchange.FanoutExchange
+ table = [(None, None, 'qFoo'),
+ (None, None, 'qFox'),
+ (None, None, 'qBar')]
+
+ def test_lookup(self):
+ assert self.e.lookup(self.table, 'eFoo', 'rFoo', None) == {
+ 'qFoo', 'qFox', 'qBar',
+ }
+
+ def test_deliver_when_fanout_supported(self):
+ self.e.channel = Mock()
+ self.e.channel.supports_fanout = True
+ message = Mock()
+
+ self.e.deliver(message, 'exchange', 'rkey')
+ self.e.channel._put_fanout.assert_called_with(
+ 'exchange', message, 'rkey',
+ )
+
+ def test_deliver_when_fanout_unsupported(self):
+ self.e.channel = Mock()
+ self.e.channel.supports_fanout = False
+
+ self.e.deliver(Mock(), 'exchange', None)
+ self.e.channel._put_fanout.assert_not_called()
+
+
+class test_Topic(ExchangeCase):
+ type = exchange.TopicExchange
+ table = [
+ ('stock.#', None, 'rFoo'),
+ ('stock.us.*', None, 'rBar'),
+ ]
+
+ def setup(self):
+ ExchangeCase.setup(self)
+ self.table = [(rkey, self.e.key_to_pattern(rkey), queue)
+ for rkey, _, queue in self.table]
+
+ def test_prepare_bind(self):
+ x = self.e.prepare_bind('qFoo', 'eFoo', 'stock.#', {})
+ assert x == ('stock.#', r'^stock\..*?$', 'qFoo')
+
+ @pytest.mark.parametrize('exchange,routing_key,default,expected', [
+ ('eFoo', 'stock.us.nasdaq', None, {'rFoo', 'rBar'}),
+ ('eFoo', 'stock.europe.OSE', None, {'rFoo'}),
+ ('eFoo', 'stockxeuropexOSE', None, set()),
+ ('eFoo', 'candy.schleckpulver.snap_crackle', None, set()),
+ ])
+ def test_lookup(self, exchange, routing_key, default, expected):
+ assert self.e.lookup(
+ self.table, exchange, routing_key, default) == expected
+ assert self.e._compiled
+
+ def test_deliver(self):
+ self.e.channel = Mock()
+ self.e.channel._lookup.return_value = ('a', 'b')
+ message = Mock()
+ self.e.deliver(message, 'exchange', 'rkey')
+
+ assert self.e.channel._put.call_args_list == [
+ (('a', message), {}),
+ (('b', message), {}),
+ ]
+
+
+class test_TopicMultibind(ExchangeCase):
+ # Testing message delivery in case of multiple overlapping
+ # bindings for the same queue. As AMQP states, in case of
+ # overlapping bindings, a message must be delivered once to
+ # each matching queue.
+ type = exchange.TopicExchange
+ table = [
+ ('stock', None, 'rFoo'),
+ ('stock.#', None, 'rFoo'),
+ ('stock.us.*', None, 'rFoo'),
+ ('#', None, 'rFoo'),
+ ]
+
+ def setup(self):
+ ExchangeCase.setup(self)
+ self.table = [(rkey, self.e.key_to_pattern(rkey), queue)
+ for rkey, _, queue in self.table]
+
+ @pytest.mark.parametrize('exchange,routing_key,default,expected', [
+ ('eFoo', 'stock.us.nasdaq', None, {'rFoo'}),
+ ('eFoo', 'stock.europe.OSE', None, {'rFoo'}),
+ ('eFoo', 'stockxeuropexOSE', None, {'rFoo'}),
+ ('eFoo', 'candy.schleckpulver.snap_crackle', None, {'rFoo'}),
+ ])
+ def test_lookup(self, exchange, routing_key, default, expected):
+ assert self.e._compiled
+ assert self.e.lookup(
+ self.table, exchange, routing_key, default) == expected
+
+
+class test_ExchangeType(ExchangeCase):
+ type = exchange.ExchangeType
+
+ def test_lookup(self):
+ with pytest.raises(NotImplementedError):
+ self.e.lookup([], 'eFoo', 'rFoo', None)
+
+ def test_prepare_bind(self):
+ assert self.e.prepare_bind('qFoo', 'eFoo', 'rFoo', {}) == (
+ 'rFoo', None, 'qFoo',
+ )
+
+ e1 = dict(
+ type='direct',
+ durable=True,
+ auto_delete=True,
+ arguments={},
+ )
+ e2 = dict(e1, arguments={'expires': 3000})
+
+ @pytest.mark.parametrize('ex,eq,name,type,durable,auto_delete,arguments', [
+ (e1, True, 'eFoo', 'direct', True, True, {}),
+ (e1, False, 'eFoo', 'topic', True, True, {}),
+ (e1, False, 'eFoo', 'direct', False, True, {}),
+ (e1, False, 'eFoo', 'direct', True, False, {}),
+ (e1, False, 'eFoo', 'direct', True, True, {'expires': 3000}),
+ (e2, True, 'eFoo', 'direct', True, True, {'expires': 3000}),
+ (e2, False, 'eFoo', 'direct', True, True, {'expires': 6000}),
+ ])
+ def test_equivalent(
+ self, ex, eq, name, type, durable, auto_delete, arguments):
+ is_eq = self.e.equivalent(
+ ex, name, type, durable, auto_delete, arguments)
+ assert is_eq if eq else not is_eq
diff --git a/t/unit/utils/__init__.py b/t/unit/utils/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/t/unit/utils/__init__.py
diff --git a/kombu/tests/utils/test_amq_manager.py b/t/unit/utils/test_amq_manager.py
index 03eb634e..3ea1e41a 100644
--- a/kombu/tests/utils/test_amq_manager.py
+++ b/t/unit/utils/test_amq_manager.py
@@ -1,22 +1,24 @@
from __future__ import absolute_import, unicode_literals
-from kombu import Connection
+import pytest
+
+from case import mock, patch
-from kombu.tests.case import Case, mock, patch
+from kombu import Connection
-class test_get_manager(Case):
+class test_get_manager:
@mock.mask_modules('pyrabbit')
def test_without_pyrabbit(self):
- with self.assertRaises(ImportError):
+ with pytest.raises(ImportError):
Connection('amqp://').get_manager()
@mock.module_exists('pyrabbit')
def test_with_pyrabbit(self):
with patch('pyrabbit.Client', create=True) as Client:
manager = Connection('amqp://').get_manager()
- self.assertIsNotNone(manager)
+ assert manager is not None
Client.assert_called_with(
'localhost:15672', 'guest', 'guest',
)
@@ -30,7 +32,7 @@ class test_get_manager(Case):
'manager_userid': 'george',
'manager_password': 'bosco',
}).get_manager()
- self.assertIsNotNone(manager)
+ assert manager is not None
Client.assert_called_with(
'admin.mq.vandelay.com:808', 'george', 'bosco',
)
diff --git a/kombu/tests/utils/test_compat.py b/t/unit/utils/test_compat.py
index e4bd2c7a..963c50b6 100644
--- a/kombu/tests/utils/test_compat.py
+++ b/t/unit/utils/test_compat.py
@@ -1,32 +1,30 @@
from __future__ import absolute_import, unicode_literals
-from kombu.utils.compat import entrypoints, maybe_fileno
+from case import Mock, mock, patch
-from kombu.tests.case import Case, Mock, mock, patch
+from kombu.utils.compat import entrypoints, maybe_fileno
-class test_entrypoints(Case):
+class test_entrypoints:
@mock.mask_modules('pkg_resources')
def test_without_pkg_resources(self):
- self.assertListEqual(list(entrypoints('kombu.test')), [])
+ assert list(entrypoints('kombu.test')) == []
@mock.module_exists('pkg_resources')
def test_with_pkg_resources(self):
with patch('pkg_resources.iter_entry_points', create=True) as iterep:
eps = iterep.return_value = [Mock(), Mock()]
- self.assertTrue(list(entrypoints('kombu.test')))
+ assert list(entrypoints('kombu.test'))
iterep.assert_called_with('kombu.test')
eps[0].load.assert_called_with()
eps[1].load.assert_called_with()
-class test_maybe_fileno(Case):
-
- def test_maybe_fileno(self):
- self.assertEqual(maybe_fileno(3), 3)
- f = Mock(name='file')
- self.assertIs(maybe_fileno(f), f.fileno())
- f.fileno.side_effect = ValueError()
- self.assertIsNone(maybe_fileno(f))
+def test_maybe_fileno():
+ assert maybe_fileno(3) == 3
+ f = Mock(name='file')
+ assert maybe_fileno(f) is f.fileno()
+ f.fileno.side_effect = ValueError()
+ assert maybe_fileno(f) is None
diff --git a/kombu/tests/utils/test_debug.py b/t/unit/utils/test_debug.py
index 9f10cc76..732638cb 100644
--- a/kombu/tests/utils/test_debug.py
+++ b/t/unit/utils/test_debug.py
@@ -2,13 +2,13 @@ from __future__ import absolute_import, unicode_literals
import logging
+from case import Mock, patch
+
from kombu.five import bytes_if_py2
from kombu.utils.debug import Logwrapped, setup_logging
-from kombu.tests.case import Case, Mock, patch
-
-class test_setup_logging(Case):
+class test_setup_logging:
def test_adds_handlers_sets_level(self):
with patch('kombu.utils.debug.get_logger') as get_logger:
@@ -21,7 +21,7 @@ class test_setup_logging(Case):
logger.setLevel.assert_called_with(logging.DEBUG)
-class test_Logwrapped(Case):
+class test_Logwrapped:
def test_wraps(self):
with patch('kombu.utils.debug.get_logger') as get_logger:
@@ -29,13 +29,13 @@ class test_Logwrapped(Case):
W = Logwrapped(Mock(), 'kombu.test')
get_logger.assert_called_with('kombu.test')
- self.assertIsNotNone(W.instance)
- self.assertIs(W.logger, logger)
+ assert W.instance is not None
+ assert W.logger is logger
W.instance.__repr__ = lambda s: bytes_if_py2('foo')
- self.assertEqual(repr(W), 'foo')
+ assert repr(W) == 'foo'
W.instance.some_attr = 303
- self.assertEqual(W.some_attr, 303)
+ assert W.some_attr == 303
W.instance.some_method.__name__ = bytes_if_py2('some_method')
W.some_method(1, 2, kw=1)
@@ -50,6 +50,6 @@ class test_Logwrapped(Case):
W.ident = 'ident'
W.some_method(kw=1)
logger.debug.assert_called()
- self.assertIn('ident', logger.debug.call_args[0][0])
+ assert 'ident' in logger.debug.call_args[0][0]
- self.assertEqual(dir(W), dir(W.instance))
+ assert dir(W) == dir(W.instance)
diff --git a/t/unit/utils/test_div.py b/t/unit/utils/test_div.py
new file mode 100644
index 00000000..f0b1a058
--- /dev/null
+++ b/t/unit/utils/test_div.py
@@ -0,0 +1,49 @@
+from __future__ import absolute_import, unicode_literals
+
+import pickle
+
+from io import StringIO, BytesIO
+
+from kombu.utils.div import emergency_dump_state
+
+
+class MyStringIO(StringIO):
+
+ def close(self):
+ pass
+
+
+class MyBytesIO(BytesIO):
+
+ def close(self):
+ pass
+
+
+class test_emergency_dump_state:
+
+ def test_dump(self, stdouts):
+ fh = MyBytesIO()
+ stderr = StringIO()
+ emergency_dump_state(
+ {'foo': 'bar'}, open_file=lambda n, m: fh, stderr=stderr)
+ assert pickle.loads(fh.getvalue()) == {'foo': 'bar'}
+ assert stderr.getvalue()
+ assert not stdouts.stdout.getvalue()
+
+ def test_dump_second_strategy(self, stdouts):
+ fh = MyStringIO()
+ stderr = StringIO()
+
+ def raise_something(*args, **kwargs):
+ raise KeyError('foo')
+
+ emergency_dump_state(
+ {'foo': 'bar'},
+ open_file=lambda n, m: fh,
+ dump=raise_something,
+ stderr=stderr,
+ )
+ assert 'foo' in fh.getvalue()
+ assert 'bar' in fh.getvalue()
+ assert stderr.getvalue()
+ assert not stdouts.stdout.getvalue()
diff --git a/t/unit/utils/test_encoding.py b/t/unit/utils/test_encoding.py
new file mode 100644
index 00000000..e3d1040a
--- /dev/null
+++ b/t/unit/utils/test_encoding.py
@@ -0,0 +1,105 @@
+# -*- coding: utf-8 -*-
+from __future__ import absolute_import, unicode_literals
+
+import sys
+
+from contextlib import contextmanager
+
+from case import patch, skip
+
+from kombu.five import bytes_t, string_t
+from kombu.utils.encoding import (
+ get_default_encoding_file, safe_str,
+ set_default_encoding_file, default_encoding,
+)
+
+
+@contextmanager
+def clean_encoding():
+ old_encoding = sys.modules.pop('kombu.utils.encoding', None)
+ import kombu.utils.encoding
+ try:
+ yield kombu.utils.encoding
+ finally:
+ if old_encoding:
+ sys.modules['kombu.utils.encoding'] = old_encoding
+
+
+class test_default_encoding:
+
+ def test_set_default_file(self):
+ prev = get_default_encoding_file()
+ try:
+ set_default_encoding_file('/foo.txt')
+ assert get_default_encoding_file() == '/foo.txt'
+ finally:
+ set_default_encoding_file(prev)
+
+ @patch('sys.getfilesystemencoding')
+ def test_default(self, getdefaultencoding):
+ getdefaultencoding.return_value = 'ascii'
+ with clean_encoding() as encoding:
+ enc = encoding.default_encoding()
+ if sys.platform.startswith('java'):
+ assert enc == 'utf-8'
+ else:
+ assert enc == 'ascii'
+ getdefaultencoding.assert_called_with()
+
+
+@skip.if_python3()
+def test_str_to_bytes():
+ with clean_encoding() as e:
+ assert isinstance(e.str_to_bytes('foobar'), bytes_t)
+
+
+@skip.if_python3()
+def test_from_utf8():
+ with clean_encoding() as e:
+ assert isinstance(e.from_utf8('foobar'), bytes_t)
+
+
+@skip.if_python3()
+def test_default_encode():
+ with clean_encoding() as e:
+ assert e.default_encode(b'foo')
+
+
+class test_safe_str:
+
+ def setup(self):
+ self._encoding = self.patching('sys.getfilesystemencoding')
+ self._encoding.return_value = 'ascii'
+
+ def test_when_bytes(self):
+ assert safe_str('foo') == 'foo'
+
+ def test_when_unicode(self):
+ assert isinstance(safe_str('foo'), string_t)
+
+ def test_when_encoding_utf8(self):
+ self._encoding.return_value = 'utf-8'
+ assert default_encoding() == 'utf-8'
+ s = 'The quiæk fåx jømps øver the lazy dåg'
+ res = safe_str(s)
+ assert isinstance(res, str)
+
+ def test_when_containing_high_chars(self):
+ self._encoding.return_value = 'ascii'
+ s = 'The quiæk fåx jømps øver the lazy dåg'
+ res = safe_str(s)
+ assert isinstance(res, str)
+ assert len(s) == len(res)
+
+ def test_when_not_string(self):
+ o = object()
+ assert safe_str(o) == repr(o)
+
+ def test_when_unrepresentable(self):
+
+ class O(object):
+
+ def __repr__(self):
+ raise KeyError('foo')
+
+ assert '<Unrepresentable' in safe_str(O())
diff --git a/kombu/tests/utils/test_functional.py b/t/unit/utils/test_functional.py
index e953979f..7a97b5b3 100644
--- a/kombu/tests/utils/test_functional.py
+++ b/t/unit/utils/test_functional.py
@@ -1,9 +1,12 @@
from __future__ import absolute_import, unicode_literals
import pickle
+import pytest
from itertools import count
+from case import Mock, mock, skip
+
from kombu.five import items
from kombu.utils import functional as utils
from kombu.utils.functional import (
@@ -11,21 +14,16 @@ from kombu.utils.functional import (
maybe_evaluate, maybe_list, reprcall, reprkwargs, retry_over_time,
)
-from kombu.tests.case import Case, Mock, mock, skip
-
-class test_ChannelPromise(Case):
+class test_ChannelPromise:
def test_repr(self):
obj = Mock(name='cb')
- self.assertIn(
- 'promise',
- repr(ChannelPromise(obj)),
- )
+ assert 'promise' in repr(ChannelPromise(obj))
obj.assert_not_called()
-class test_shufflecycle(Case):
+class test_shufflecycle:
def test_shuffles(self):
prev_repeat, utils.repeat = utils.repeat, Mock()
@@ -37,8 +35,8 @@ class test_shufflecycle(Case):
for i in range(10):
next(cycle)
utils.repeat.assert_called_with(None)
- self.assertTrue(seen.issubset(values))
- with self.assertRaises(StopIteration):
+ assert seen.issubset(values)
+ with pytest.raises(StopIteration):
next(cycle)
next(cycle)
finally:
@@ -49,7 +47,7 @@ def double(x):
return x * 2
-class test_LRUCache(Case):
+class test_LRUCache:
def test_expires(self):
limit = 100
@@ -57,16 +55,16 @@ class test_LRUCache(Case):
slots = list(range(limit * 2))
for i in slots:
x[i] = i
- self.assertListEqual(list(x.keys()), list(slots[limit:]))
- self.assertTrue(x.items())
- self.assertTrue(x.values())
+ assert list(x.keys()) == list(slots[limit:])
+ assert x.items()
+ assert x.values()
def test_is_pickleable(self):
x = LRUCache(limit=10)
x.update(luke=1, leia=2)
y = pickle.loads(pickle.dumps(x))
- self.assertEqual(y.limit, y.limit)
- self.assertEqual(y, x)
+ assert y.limit == y.limit
+ assert y == x
def test_update_expires(self):
limit = 100
@@ -75,117 +73,111 @@ class test_LRUCache(Case):
for i in slots:
x.update({i: i})
- self.assertListEqual(list(x.keys()), list(slots[limit:]))
+ assert list(x.keys()) == list(slots[limit:])
def test_least_recently_used(self):
x = LRUCache(3)
x[1], x[2], x[3] = 1, 2, 3
- self.assertEqual(list(x.keys()), [1, 2, 3])
+ assert list(x.keys()), [1, 2 == 3]
x[4], x[5] = 4, 5
- self.assertEqual(list(x.keys()), [3, 4, 5])
+ assert list(x.keys()), [3, 4 == 5]
# access 3, which makes it the last used key.
x[3]
x[6] = 6
- self.assertEqual(list(x.keys()), [5, 3, 6])
+ assert list(x.keys()), [5, 3 == 6]
x[7] = 7
- self.assertEqual(list(x.keys()), [3, 6, 7])
+ assert list(x.keys()), [3, 6 == 7]
def test_update_larger_than_cache_size(self):
x = LRUCache(2)
x.update({x: x for x in range(100)})
- self.assertEqual(list(x.keys()), [98, 99])
+ assert list(x.keys()), [98 == 99]
def test_items(self):
c = LRUCache()
c.update(a=1, b=2, c=3)
- self.assertTrue(list(items(c)))
+ assert list(items(c))
def test_incr(self):
c = LRUCache()
c.update(a='1')
c.incr('a')
- self.assertEqual(c['a'], '2')
-
+ assert c['a'] == '2'
-class test_memoize(Case):
- def test_memoize(self):
- counter = count(1)
+def test_memoize():
+ counter = count(1)
- @memoize(maxsize=2)
- def x(i):
- return next(counter)
+ @memoize(maxsize=2)
+ def x(i):
+ return next(counter)
- self.assertEqual(x(1), 1)
- self.assertEqual(x(1), 1)
- self.assertEqual(x(2), 2)
- self.assertEqual(x(3), 3)
- self.assertEqual(x(1), 4)
- x.clear()
- self.assertEqual(x(3), 5)
+ assert x(1) == 1
+ assert x(1) == 1
+ assert x(2) == 2
+ assert x(3) == 3
+ assert x(1) == 4
+ x.clear()
+ assert x(3) == 5
-class test_lazy(Case):
+class test_lazy:
def test__str__(self):
- self.assertEqual(
- str(lazy(lambda: 'the quick brown fox')),
- 'the quick brown fox',
- )
+ assert (str(lazy(lambda: 'the quick brown fox')) ==
+ 'the quick brown fox')
def test__repr__(self):
- self.assertEqual(
- repr(lazy(lambda: 'fi fa fo')).strip('u'),
- "'fi fa fo'",
- )
+ assert repr(lazy(lambda: 'fi fa fo')).strip('u') == "'fi fa fo'"
@skip.if_python3()
def test__cmp__(self):
- self.assertEqual(lazy(lambda: 10).__cmp__(lazy(lambda: 20)), -1)
- self.assertEqual(lazy(lambda: 10).__cmp__(5), 1)
+ assert lazy(lambda: 10).__cmp__(lazy(lambda: 20)) == -1
+ assert lazy(lambda: 10).__cmp__(5) == 1
def test_evaluate(self):
- self.assertEqual(lazy(lambda: 2 + 2)(), 4)
- self.assertEqual(lazy(lambda x: x * 4, 2), 8)
- self.assertEqual(lazy(lambda x: x * 8, 2)(), 16)
+ assert lazy(lambda: 2 + 2)() == 4
+ assert lazy(lambda x: x * 4, 2) == 8
+ assert lazy(lambda x: x * 8, 2)() == 16
def test_cmp(self):
- self.assertEqual(lazy(lambda: 10), lazy(lambda: 10))
- self.assertNotEqual(lazy(lambda: 10), lazy(lambda: 20))
+ assert lazy(lambda: 10) == lazy(lambda: 10)
+ assert lazy(lambda: 10) != lazy(lambda: 20)
def test__reduce__(self):
x = lazy(double, 4)
y = pickle.loads(pickle.dumps(x))
- self.assertEqual(x(), y())
+ assert x() == y()
def test__deepcopy__(self):
from copy import deepcopy
x = lazy(double, 4)
y = deepcopy(x)
- self.assertEqual(x._fun, y._fun)
- self.assertEqual(x._args, y._args)
- self.assertEqual(x(), y())
+ assert x._fun == y._fun
+ assert x._args == y._args
+ assert x() == y()
-class test_maybe_evaluate(Case):
+@pytest.mark.parametrize('obj,expected', [
+ (lazy(lambda: 10), 10),
+ (20, 20),
+])
+def test_maybe_evaluate(obj, expected):
+ assert maybe_evaluate(obj) == expected
- def test_evaluates(self):
- self.assertEqual(maybe_evaluate(lazy(lambda: 10)), 10)
- self.assertEqual(maybe_evaluate(20), 20)
+class test_retry_over_time:
-class test_retry_over_time(Case):
+ class Predicate(Exception):
+ pass
def setup(self):
self.index = 0
- class Predicate(Exception):
- pass
-
def myfun(self):
if self.index < 9:
raise self.Predicate()
@@ -195,7 +187,7 @@ class test_retry_over_time(Case):
interval = next(intervals)
sleepvals = (None, 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0, 16.0)
self.index += 1
- self.assertEqual(interval, sleepvals[self.index])
+ assert interval == sleepvals[self.index]
return interval
@mock.sleepdeprived(module=utils)
@@ -205,28 +197,28 @@ class test_retry_over_time(Case):
utils.count.return_value = list(range(1))
x = retry_over_time(self.myfun, self.Predicate,
errback=None, interval_max=14)
- self.assertIsNone(x)
+ assert x is None
utils.count.return_value = list(range(10))
cb = Mock()
x = retry_over_time(self.myfun, self.Predicate,
errback=self.errback, callback=cb,
interval_max=14)
- self.assertEqual(x, 42)
- self.assertEqual(self.index, 9)
+ assert x == 42
+ assert self.index == 9
cb.assert_called_with()
finally:
utils.count = prev_count
@mock.sleepdeprived(module=utils)
def test_retry_once(self):
- with self.assertRaises(self.Predicate):
+ with pytest.raises(self.Predicate):
retry_over_time(
self.myfun, self.Predicate,
max_retries=1, errback=self.errback, interval_max=14,
)
- self.assertEqual(self.index, 1)
+ assert self.index == 1
# no errback
- with self.assertRaises(self.Predicate):
+ with pytest.raises(self.Predicate):
retry_over_time(
self.myfun, self.Predicate,
max_retries=1, errback=None, interval_max=14,
@@ -250,38 +242,39 @@ class test_retry_over_time(Case):
self.calls += 1
fun = Fun()
- self.assertEqual(
- retry_over_time(
- fun, self.Predicate,
- max_retries=0, errback=None, interval_max=14,
- ),
- 42,
- )
- self.assertEqual(fun.calls, 11)
-
-
-class test_utils(Case):
-
- def test_maybe_list(self):
- self.assertIsNone(maybe_list(None))
- self.assertEqual(maybe_list(1), [1])
- self.assertEqual(maybe_list([1, 2, 3]), [1, 2, 3])
-
- def test_fxrange_no_repeatlast(self):
- self.assertEqual(list(fxrange(1.0, 3.0, 1.0)),
- [1.0, 2.0, 3.0])
-
- def test_fxrangemax(self):
- self.assertEqual(list(fxrangemax(1.0, 3.0, 1.0, 30.0)),
- [1.0, 2.0, 3.0, 3.0, 3.0, 3.0,
- 3.0, 3.0, 3.0, 3.0, 3.0])
- self.assertEqual(list(fxrangemax(1.0, None, 1.0, 30.0)),
- [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0])
-
- def test_reprkwargs(self):
- self.assertTrue(reprkwargs({'foo': 'bar', 1: 2, 'k': 'v'}))
-
- def test_reprcall(self):
- self.assertTrue(
- reprcall('add', (2, 2), {'copy': True}),
- )
+ assert retry_over_time(
+ fun, self.Predicate,
+ max_retries=0, errback=None, interval_max=14) == 42
+ assert fun.calls == 11
+
+
+@pytest.mark.parametrize('obj,expected', [
+ (None, None),
+ (1, [1]),
+ ([1, 2, 3], [1, 2, 3]),
+])
+def test_maybe_list(obj, expected):
+ assert maybe_list(obj) == expected
+
+
+def test_fxrange__no_repeatlast():
+ assert list(fxrange(1.0, 3.0, 1.0)) == [1.0, 2.0, 3.0]
+
+
+@pytest.mark.parametrize('args,expected', [
+ ((1.0, 3.0, 1.0, 30.0),
+ [1.0, 2.0, 3.0, 3.0, 3.0, 3.0,
+ 3.0, 3.0, 3.0, 3.0, 3.0]),
+ ((1.0, None, 1.0, 30.0),
+ [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]),
+])
+def test_fxrangemax(args, expected):
+ assert list(fxrangemax(*args)) == expected
+
+
+def test_reprkwargs():
+ assert reprkwargs({'foo': 'bar', 1: 2, 'k': 'v'})
+
+
+def test_reprcall():
+ assert reprcall('add', (2, 2), {'copy': True})
diff --git a/kombu/tests/utils/test_imports.py b/t/unit/utils/test_imports.py
index 49e2a72f..20b9a2f8 100644
--- a/kombu/tests/utils/test_imports.py
+++ b/t/unit/utils/test_imports.py
@@ -1,37 +1,34 @@
from __future__ import absolute_import, unicode_literals
+import pytest
+
+from case import Mock
+
from kombu import Exchange
from kombu.utils.imports import symbol_by_name
-from kombu.tests.case import Case, Mock
-
-class test_symbol_by_name(Case):
+class test_symbol_by_name:
def test_instance_returns_instance(self):
instance = object()
- self.assertIs(symbol_by_name(instance), instance)
+ assert symbol_by_name(instance) is instance
def test_returns_default(self):
default = object()
- self.assertIs(
- symbol_by_name('xyz.ryx.qedoa.weq:foz', default=default),
- default,
- )
+ assert symbol_by_name(
+ 'xyz.ryx.qedoa.weq:foz', default=default) is default
def test_no_default(self):
- with self.assertRaises(ImportError):
+ with pytest.raises(ImportError):
symbol_by_name('xyz.ryx.qedoa.weq:foz')
def test_imp_reraises_ValueError(self):
imp = Mock()
imp.side_effect = ValueError()
- with self.assertRaises(ValueError):
+ with pytest.raises(ValueError):
symbol_by_name('kombu.Connection', imp=imp)
def test_package(self):
- self.assertIs(
- symbol_by_name('.entity:Exchange', package='kombu'),
- Exchange,
- )
- self.assertTrue(symbol_by_name(':Consumer', package='kombu'))
+ assert symbol_by_name('.entity:Exchange', package='kombu') is Exchange
+ assert symbol_by_name(':Consumer', package='kombu')
diff --git a/kombu/tests/utils/test_json.py b/t/unit/utils/test_json.py
index 1ad8d1c9..ae7b374d 100644
--- a/kombu/tests/utils/test_json.py
+++ b/t/unit/utils/test_json.py
@@ -1,17 +1,18 @@
from __future__ import absolute_import, unicode_literals
+import pytest
import pytz
from datetime import datetime
from decimal import Decimal
from uuid import uuid4
+from case import MagicMock, Mock, skip
+
from kombu.five import text_t
from kombu.utils.encoding import str_to_bytes
from kombu.utils.json import _DecodeError, dumps, loads
-from kombu.tests.case import Case, MagicMock, Mock, skip
-
class Custom(object):
@@ -22,7 +23,7 @@ class Custom(object):
return self.data
-class test_JSONEncoder(Case):
+class test_JSONEncoder:
def test_datetime(self):
now = datetime.utcnow()
@@ -34,64 +35,58 @@ class test_JSONEncoder(Case):
'date': now.date(),
'time': now.time()},
))
- self.assertDictEqual(serialized, {
+ assert serialized == {
'datetime': now.isoformat(),
'tz': '{0}Z'.format(now_utc.isoformat().split('+', 1)[0]),
'time': now.time().isoformat(),
'date': stripped.isoformat(),
- })
+ }
def test_Decimal(self):
d = Decimal('3314132.13363235235324234123213213214134')
- self.assertDictEqual(loads(dumps({'d': d})), {
- 'd': text_t(d),
- })
+ assert loads(dumps({'d': d})), {'d': text_t(d)}
def test_UUID(self):
id = uuid4()
- self.assertDictEqual(loads(dumps({'u': id})), {
- 'u': text_t(id),
- })
+ assert loads(dumps({'u': id})), {'u': text_t(id)}
def test_default(self):
- with self.assertRaises(TypeError):
+ with pytest.raises(TypeError):
dumps({'o': object()})
-class test_dumps_loads(Case):
+class test_dumps_loads:
def test_dumps_custom_object(self):
x = {'foo': Custom({'a': 'b'})}
- self.assertEqual(loads(dumps(x)), {'foo': x['foo'].__json__()})
+ assert loads(dumps(x)) == {'foo': x['foo'].__json__()}
def test_dumps_custom_object_no_json(self):
x = {'foo': object()}
- with self.assertRaises(TypeError):
+ with pytest.raises(TypeError):
dumps(x)
def test_loads_memoryview(self):
- self.assertEqual(
- loads(memoryview(bytearray(dumps({'x': 'z'}), encoding='utf-8'))),
- {'x': 'z'},
- )
+ assert loads(
+ memoryview(bytearray(dumps({'x': 'z'}), encoding='utf-8'))
+ ) == {'x': 'z'}
def test_loads_bytearray(self):
- self.assertEqual(
- loads(bytearray(dumps({'x': 'z'}), encoding='utf-8')),
- {'x': 'z'})
+ assert loads(
+ bytearray(dumps({'x': 'z'}), encoding='utf-8')
+ ) == {'x': 'z'}
def test_loads_bytes(self):
- self.assertEqual(
- loads(str_to_bytes(dumps({'x': 'z'})), decode_bytes=True),
- {'x': 'z'},
- )
+ assert loads(
+ str_to_bytes(dumps({'x': 'z'})),
+ decode_bytes=True) == {'x': 'z'}
@skip.if_python3()
def test_loads_buffer(self):
- self.assertEqual(loads(buffer(dumps({'x': 'z'}))), {'x': 'z'})
+ assert loads(buffer(dumps({'x': 'z'}))) == {'x': 'z'}
def test_loads_DecodeError(self):
_loads = Mock(name='_loads')
_loads.side_effect = _DecodeError(
MagicMock(), MagicMock(), MagicMock())
- self.assertEqual(loads(dumps({'x': 'z'}), _loads=_loads), {'x': 'z'})
+ assert loads(dumps({'x': 'z'}), _loads=_loads) == {'x': 'z'}
diff --git a/kombu/tests/utils/test_objects.py b/t/unit/utils/test_objects.py
index 22893e48..8b13a405 100644
--- a/kombu/tests/utils/test_objects.py
+++ b/t/unit/utils/test_objects.py
@@ -2,10 +2,8 @@ from __future__ import absolute_import, unicode_literals
from kombu.utils.objects import cached_property
-from kombu.tests.case import Case
-
-class test_cached_property(Case):
+class test_cached_property:
def test_deleting(self):
@@ -22,10 +20,10 @@ class test_cached_property(Case):
x = X()
del(x.foo)
- self.assertFalse(x.xx)
+ assert not x.xx
x.__dict__['foo'] = 'here'
del(x.foo)
- self.assertEqual(x.xx, 'here')
+ assert x.xx == 'here'
def test_when_access_from_class(self):
@@ -41,15 +39,15 @@ class test_cached_property(Case):
self.xx = 10
desc = X.__dict__['foo']
- self.assertIs(X.foo, desc)
+ assert X.foo is desc
- self.assertIs(desc.__get__(None), desc)
- self.assertIs(desc.__set__(None, 1), desc)
- self.assertIs(desc.__delete__(None), desc)
- self.assertTrue(desc.setter(1))
+ assert desc.__get__(None) is desc
+ assert desc.__set__(None, 1) is desc
+ assert desc.__delete__(None) is desc
+ assert desc.setter(1)
x = X()
x.foo = 30
- self.assertEqual(x.xx, 10)
+ assert x.xx == 10
del(x.foo)
diff --git a/t/unit/utils/test_scheduling.py b/t/unit/utils/test_scheduling.py
new file mode 100644
index 00000000..894286cd
--- /dev/null
+++ b/t/unit/utils/test_scheduling.py
@@ -0,0 +1,102 @@
+from __future__ import absolute_import, unicode_literals
+
+import pytest
+
+from kombu.utils.scheduling import FairCycle, cycle_by_name
+
+
+class MyEmpty(Exception):
+ pass
+
+
+def consume(fun, n):
+ r = []
+ for i in range(n):
+ r.append(fun())
+ return r
+
+
+class test_FairCycle:
+
+ def test_cycle(self):
+ resources = ['a', 'b', 'c', 'd', 'e']
+
+ def echo(r, timeout=None):
+ return r
+
+ # cycle should be ['a', 'b', 'c', 'd', 'e', ... repeat]
+ cycle = FairCycle(echo, resources, MyEmpty)
+ for i in range(len(resources)):
+ assert cycle.get() == (resources[i], resources[i])
+ for i in range(len(resources)):
+ assert cycle.get() == (resources[i], resources[i])
+
+ def test_cycle_breaks(self):
+ resources = ['a', 'b', 'c', 'd', 'e']
+
+ def echo(r):
+ if r == 'c':
+ raise MyEmpty(r)
+ return r
+
+ cycle = FairCycle(echo, resources, MyEmpty)
+ assert consume(cycle.get, len(resources)) == [
+ ('a', 'a'), ('b', 'b'), ('d', 'd'),
+ ('e', 'e'), ('a', 'a'),
+ ]
+ assert consume(cycle.get, len(resources)) == [
+ ('b', 'b'), ('d', 'd'), ('e', 'e'),
+ ('a', 'a'), ('b', 'b'),
+ ]
+ cycle2 = FairCycle(echo, ['c', 'c'], MyEmpty)
+ with pytest.raises(MyEmpty):
+ consume(cycle2.get, 3)
+
+ def test_cycle_no_resources(self):
+ cycle = FairCycle(None, [], MyEmpty)
+ cycle.pos = 10
+
+ with pytest.raises(MyEmpty):
+ cycle._next()
+
+ def test__repr__(self):
+ assert repr(FairCycle(lambda x: x, [1, 2, 3], MyEmpty))
+
+
+def test_round_robin_cycle():
+ it = cycle_by_name('round_robin')(['A', 'B', 'C'])
+ assert it.consume(3) == ['A', 'B', 'C']
+ it.rotate('B')
+ assert it.consume(3) == ['A', 'C', 'B']
+ it.rotate('A')
+ assert it.consume(3) == ['C', 'B', 'A']
+ it.rotate('A')
+ assert it.consume(3) == ['C', 'B', 'A']
+ it.rotate('C')
+ assert it.consume(3) == ['B', 'A', 'C']
+
+
+def test_priority_cycle():
+ it = cycle_by_name('priority')(['A', 'B', 'C'])
+ assert it.consume(3) == ['A', 'B', 'C']
+ it.rotate('B')
+ assert it.consume(3) == ['A', 'B', 'C']
+ it.rotate('A')
+ assert it.consume(3) == ['A', 'B', 'C']
+ it.rotate('A')
+ assert it.consume(3) == ['A', 'B', 'C']
+ it.rotate('C')
+ assert it.consume(3) == ['A', 'B', 'C']
+
+
+def test_sorted_cycle():
+ it = cycle_by_name('sorted')(['B', 'C', 'A'])
+ assert it.consume(3) == ['A', 'B', 'C']
+ it.rotate('B')
+ assert it.consume(3) == ['A', 'B', 'C']
+ it.rotate('A')
+ assert it.consume(3) == ['A', 'B', 'C']
+ it.rotate('A')
+ assert it.consume(3) == ['A', 'B', 'C']
+ it.rotate('C')
+ assert it.consume(3) == ['A', 'B', 'C']
diff --git a/t/unit/utils/test_url.py b/t/unit/utils/test_url.py
new file mode 100644
index 00000000..3d2b0ede
--- /dev/null
+++ b/t/unit/utils/test_url.py
@@ -0,0 +1,39 @@
+from __future__ import absolute_import, unicode_literals
+
+import pytest
+
+from kombu.utils.url import as_url, parse_url, maybe_sanitize_url
+
+
+def test_parse_url():
+ assert parse_url('amqp://user:pass@localhost:5672/my/vhost') == {
+ 'transport': 'amqp',
+ 'userid': 'user',
+ 'password': 'pass',
+ 'hostname': 'localhost',
+ 'port': 5672,
+ 'virtual_host': 'my/vhost',
+ }
+
+
+@pytest.mark.parametrize('urltuple,expected', [
+ (('https',), 'https:///'),
+ (('https', 'e.com'), 'https://e.com/'),
+ (('https', 'e.com', 80), 'https://e.com:80/'),
+ (('https', 'e.com', 80, 'u'), 'https://u@e.com:80/'),
+ (('https', 'e.com', 80, 'u', 'p'), 'https://u:p@e.com:80/'),
+ (('https', 'e.com', 80, None, 'p'), 'https://:p@e.com:80/'),
+ (('https', 'e.com', 80, None, 'p', '/foo'), 'https://:p@e.com:80//foo'),
+])
+def test_as_url(urltuple, expected):
+ assert as_url(*urltuple) == expected
+
+
+@pytest.mark.parametrize('url,expected', [
+ ('foo', 'foo'),
+ ('http://u:p@e.com//foo', 'http://u:**@e.com//foo'),
+])
+def test_maybe_sanitize_url(url, expected):
+ assert maybe_sanitize_url(url) == expected
+ assert (maybe_sanitize_url('http://u:p@e.com//foo') ==
+ 'http://u:**@e.com//foo')
diff --git a/t/unit/utils/test_utils.py b/t/unit/utils/test_utils.py
new file mode 100644
index 00000000..f4668b83
--- /dev/null
+++ b/t/unit/utils/test_utils.py
@@ -0,0 +1,22 @@
+from __future__ import absolute_import, unicode_literals
+
+import pytest
+
+from kombu import version_info_t
+from kombu.utils.text import version_string_as_tuple
+
+
+def test_dir():
+ import kombu
+ assert dir(kombu)
+
+
+@pytest.mark.parametrize('version,expected', [
+ ('3', version_info_t(3, 0, 0, '', '')),
+ ('3.3', version_info_t(3, 3, 0, '', '')),
+ ('3.3.1', version_info_t(3, 3, 1, '', '')),
+ ('3.3.1a3', version_info_t(3, 3, 1, 'a3', '')),
+ ('3.3.1.a3.40c32', version_info_t(3, 3, 1, 'a3', '40c32')),
+])
+def test_version_string_as_tuple(version, expected):
+ assert version_string_as_tuple(version) == expected
diff --git a/kombu/tests/utils/test_uuid.py b/t/unit/utils/test_uuid.py
index ac018f09..f6b32c28 100644
--- a/kombu/tests/utils/test_uuid.py
+++ b/t/unit/utils/test_uuid.py
@@ -2,16 +2,14 @@ from __future__ import absolute_import, unicode_literals
from kombu.utils.uuid import uuid
-from kombu.tests.case import Case
-
-class test_UUID(Case):
+class test_UUID:
def test_uuid4(self):
- self.assertNotEqual(uuid(), uuid())
+ assert uuid() != uuid()
def test_uuid(self):
i1 = uuid()
i2 = uuid()
- self.assertIsInstance(i1, str)
- self.assertNotEqual(i1, i2)
+ assert isinstance(i1, str)
+ assert i1 != i2
diff --git a/tox.ini b/tox.ini
index ebb79951..a6da75c3 100644
--- a/tox.ini
+++ b/tox.ini
@@ -15,8 +15,7 @@ deps=
flake8,flakeplus: -r{toxinidir}/requirements/pkgutils.txt
commands = pip install -U -r{toxinidir}/requirements/dev.txt
- nosetests -vdsx kombu.tests \
- --with-coverage --cover-inclusive --cover-erase []
+ py.test -xv --cov=kombu/ --cov-report=xml --no-cov-on-fail
basepython =
2.7,flakeplus,flake8,apicheck,linkcheck: python2.7