diff options
author | Ask Solem <askh@opera.com> | 2010-11-09 15:27:43 +0100 |
---|---|---|
committer | Ask Solem <askh@opera.com> | 2010-11-09 15:27:43 +0100 |
commit | 402260bf9f3c39986d2fa611758b7d20ddfef9fb (patch) | |
tree | 4a27adf0f5f64ecbb62cb5a6df0e04e17ae3db11 /kombu | |
parent | 9b103e43b6ad84310a796d28a623f2c62713fd15 (diff) | |
download | kombu-402260bf9f3c39986d2fa611758b7d20ddfef9fb.tar.gz |
94% coverage of the core library, but still missing coverage of compat, pidbox, and functional tests for concrete transports beanstalk, redis, couchdb, pika and mongodb
Diffstat (limited to 'kombu')
-rw-r--r-- | kombu/connection.py | 16 | ||||
-rw-r--r-- | kombu/tests/__init__.py | 32 | ||||
-rw-r--r-- | kombu/tests/test_compression.py | 25 | ||||
-rw-r--r-- | kombu/tests/test_connection.py | 76 | ||||
-rw-r--r-- | kombu/tests/test_entities.py | 18 | ||||
-rw-r--r-- | kombu/tests/test_serialization.py | 43 | ||||
-rw-r--r-- | kombu/tests/test_transport.py | 32 | ||||
-rw-r--r-- | kombu/tests/test_transport_base.py | 12 | ||||
-rw-r--r-- | kombu/tests/test_transport_memory.py | 119 | ||||
-rw-r--r-- | kombu/tests/test_transport_pyamqplib.py | 32 | ||||
-rw-r--r-- | kombu/tests/test_transport_virtual.py (renamed from kombu/tests/test_virtual.py) | 0 | ||||
-rw-r--r-- | kombu/tests/test_virtual_scheduling.py | 2 | ||||
-rw-r--r-- | kombu/tests/utils.py | 62 | ||||
-rw-r--r-- | kombu/transport/base.py | 4 | ||||
-rw-r--r-- | kombu/transport/memory.py | 2 | ||||
-rw-r--r-- | kombu/transport/pyamqplib.py | 18 | ||||
-rw-r--r-- | kombu/transport/virtual/__init__.py | 27 | ||||
-rw-r--r-- | kombu/transport/virtual/scheduling.py | 2 |
18 files changed, 484 insertions, 38 deletions
diff --git a/kombu/connection.py b/kombu/connection.py index 802ad340..ebcf3c99 100644 --- a/kombu/connection.py +++ b/kombu/connection.py @@ -103,15 +103,13 @@ class BrokerConnection(object): return self.transport.drain_events(self.connection, **kwargs) def _close(self): - try: - if self._connection: - try: - self.transport.close_connection(self._connection) - except self.transport.connection_errors + (AttributeError, ): - pass - self._connection = None - except socket.error: - pass + if self._connection: + try: + self.transport.close_connection(self._connection) + except self.transport.connection_errors + (AttributeError, + socket.error): + pass + self._connection = None self._closed = True def release(self): diff --git a/kombu/tests/__init__.py b/kombu/tests/__init__.py index e69de29b..6328c34c 100644 --- a/kombu/tests/__init__.py +++ b/kombu/tests/__init__.py @@ -0,0 +1,32 @@ +moduleindex = ("kombu.abstract", + "kombu.compat", + "kombu.compression", + "kombu.connection", + "kombu.entity", + "kombu.exceptions", + "kombu.messaging", + "kombu.pidbox", + "kombu.serialization", + "kombu.simple", + "kombu.utils", + "kombu.utils.compat", + "kombu.utils.functional", + "kombu.transport", + "kombu.transport.base", + "kombu.transport.beanstalk", + "kombu.transport.memory", + "kombu.transport.mongodb", + "kombu.transport.pyamqplib", + "kombu.transport.pycouchdb", + "kombu.transport.pypika", + "kombu.transport.pyredis", + "kombu.transport.virtual", + "kombu.transport.virtual.exchange", + "kombu.transport.virtual.scheduling") + +def setup(): + # so coverage sees all our modules. + for module in moduleindex: + __import__(module) + + diff --git a/kombu/tests/test_compression.py b/kombu/tests/test_compression.py new file mode 100644 index 00000000..2f509c38 --- /dev/null +++ b/kombu/tests/test_compression.py @@ -0,0 +1,25 @@ +import unittest2 as unittest + +from kombu import compression + + +class test_compression(unittest.TestCase): + + def test_encoders(self): + encoders = compression.encoders() + self.assertIn("application/x-gzip", encoders) + self.assertIn("application/x-bz2", encoders) + + def test_compress__decompress__zlib(self): + text = "The Quick Brown Fox Jumps Over The Lazy Dog" + c, ctype = compression.compress(text, "zlib") + self.assertNotEqual(text, c) + d = compression.decompress(c, ctype) + self.assertEqual(d, text) + + def test_compress__decompress__bzip2(self): + text = "The Brown Quick Fox Over The Lazy Dog Jumps" + c, ctype = compression.compress(text, "bzip2") + self.assertNotEqual(text, c) + d = compression.decompress(c, ctype) + self.assertEqual(d, text) diff --git a/kombu/tests/test_connection.py b/kombu/tests/test_connection.py index 00f62e7f..a6118579 100644 --- a/kombu/tests/test_connection.py +++ b/kombu/tests/test_connection.py @@ -1,14 +1,18 @@ +import pickle import unittest2 as unittest -from kombu.connection import BrokerConnection +from kombu.connection import BrokerConnection, Resource from kombu.tests.mocks import Transport class test_Connection(unittest.TestCase): + def setUp(self): + self.conn = BrokerConnection(port=5672, transport=Transport) + def test_establish_connection(self): - conn = BrokerConnection(port=5672, transport=Transport) + conn = self.conn conn.connect() self.assertTrue(conn.connection.connected) self.assertEqual(conn.host, "localhost:5672") @@ -21,7 +25,7 @@ class test_Connection(unittest.TestCase): self.assertIsInstance(conn.transport, Transport) def test__enter____exit__(self): - conn = BrokerConnection(transport=Transport) + conn = self.conn context = conn.__enter__() self.assertIs(context, conn) conn.connect() @@ -30,6 +34,68 @@ class test_Connection(unittest.TestCase): self.assertIsNone(conn.connection) conn.close() # again + def test_close_survives_connerror(self): + + class _CustomError(Exception): + pass + + class MyTransport(Transport): + connection_errors = (_CustomError, ) + + def close_connection(self, connection): + raise _CustomError("foo") + + conn = BrokerConnection(transport=MyTransport) + conn.connect() + conn.close() + self.assertTrue(conn._closed) + + def test_ensure_connection(self): + self.assertTrue(self.conn.ensure_connection()) + + def test_SimpleQueue(self): + conn = self.conn + q = conn.SimpleQueue("foo") + self.assertTrue(q.channel) + self.assertTrue(q.channel_autoclose) + chan = conn.channel() + q2 = conn.SimpleQueue("foo", channel=chan) + self.assertIs(q2.channel, chan) + self.assertFalse(q2.channel_autoclose) + + def test_SimpleBuffer(self): + conn = self.conn + q = conn.SimpleBuffer("foo") + self.assertTrue(q.channel) + self.assertTrue(q.channel_autoclose) + chan = conn.channel() + q2 = conn.SimpleBuffer("foo", channel=chan) + self.assertIs(q2.channel, chan) + self.assertFalse(q2.channel_autoclose) + + def test__repr__(self): + self.assertTrue(repr(self.conn)) + + def test__reduce__(self): + x = pickle.loads(pickle.dumps(self.conn)) + self.assertDictEqual(x.info(), self.conn.info()) + + def test_channel_errors(self): + + class MyTransport(Transport): + channel_errors = (KeyError, ValueError) + + conn = BrokerConnection(transport=MyTransport) + self.assertTupleEqual(conn.channel_errors, (KeyError, ValueError)) + + def test_connection_errors(self): + + class MyTransport(Transport): + connection_errors = (KeyError, ValueError) + + conn = BrokerConnection(transport=MyTransport) + self.assertTupleEqual(conn.connection_errors, (KeyError, ValueError)) + class ResourceCase(unittest.TestCase): abstract = True @@ -41,6 +107,10 @@ class ResourceCase(unittest.TestCase): self.assertEqual(P._resource.qsize(), avail) self.assertEqual(len(P._dirty), dirty) + def test_setup(self): + if self.abstract: + self.assertRaises(NotImplementedError, Resource) + def test_acquire__release(self): if self.abstract: return diff --git a/kombu/tests/test_entities.py b/kombu/tests/test_entities.py index 473c1d4c..05cadd10 100644 --- a/kombu/tests/test_entities.py +++ b/kombu/tests/test_entities.py @@ -19,6 +19,24 @@ class test_Exchange(unittest.TestCase): self.assertIs(bound.channel, chan) self.assertIn("<bound", repr(bound)) + def test_revive(self): + exchange = Exchange("foo", "direct") + chan = Channel() + + # reviving unbound channel is a noop. + exchange.revive(chan) + self.assertFalse(exchange.is_bound) + self.assertIsNone(exchange._channel) + + bound = exchange.bind(chan) + self.assertTrue(bound.is_bound) + self.assertIs(bound.channel, chan) + + chan2 = Channel() + bound.revive(chan2) + self.assertTrue(bound.is_bound) + self.assertIs(bound._channel, chan2) + def test_assert_is_bound(self): exchange = Exchange("foo", "direct") self.assertRaises(NotBoundError, exchange.declare) diff --git a/kombu/tests/test_serialization.py b/kombu/tests/test_serialization.py index 9ef62fcc..713be0f5 100644 --- a/kombu/tests/test_serialization.py +++ b/kombu/tests/test_serialization.py @@ -1,13 +1,17 @@ #!/usr/bin/python # -*- coding: utf-8 -*- -import cPickle +import cPickle as pickle import sys import unittest2 as unittest from nose import SkipTest -from kombu.serialization import registry +from kombu.serialization import registry, register, SerializerNotInstalled, \ + raw_encode, register_yaml, register_msgpack, \ + decode + +from kombu.tests.utils import mask_modules, module_exists # For content_encoding tests unicode_string = u'abcdé\u8463' @@ -33,7 +37,7 @@ json_data = ('{"int": 10, "float": 3.1415926500000002, ' 'th\\u00e9 lazy dog"}') # Pickle serialization tests -pickle_data = cPickle.dumps(py_data) +pickle_data = pickle.dumps(py_data) # YAML serialization tests yaml_data = ('float: 3.1415926500000002\nint: 10\n' @@ -117,6 +121,7 @@ class test_Serialization(unittest.TestCase): content_encoding='utf-8')) def test_msgpack_decode(self): + register_msgpack() try: import msgpack except ImportError: @@ -130,6 +135,7 @@ class test_Serialization(unittest.TestCase): content_encoding='binary')) def test_msgpack_encode(self): + register_msgpack() try: import msgpack except ImportError: @@ -146,6 +152,7 @@ class test_Serialization(unittest.TestCase): content_encoding='binary')) def test_yaml_decode(self): + register_yaml() try: import yaml except ImportError: @@ -158,6 +165,7 @@ class test_Serialization(unittest.TestCase): content_encoding='utf-8')) def test_yaml_encode(self): + register_yaml() try: import yaml except ImportError: @@ -184,6 +192,35 @@ class test_Serialization(unittest.TestCase): registry.encode(py_data, serializer="pickle")[-1]) + def test_register(self): + register(None, None, None, None) + + def test_set_default_serializer_missing(self): + self.assertRaises(SerializerNotInstalled, + registry._set_default_serializer, "nonexisting") + + def test_encode_missing(self): + self.assertRaises(SerializerNotInstalled, + registry.encode, "foo", serializer="nonexisting") + + def test_raw_encode(self): + self.assertTupleEqual(raw_encode(str("foo")), + ("application/data", "binary", "foo")) + + @mask_modules("yaml") + def test_register_yaml__no_yaml(self): + register_yaml() + self.assertRaises(SerializerNotInstalled, + decode, "foo", "application/x-yaml", "utf-8") + + @mask_modules("msgpack") + def test_register_msgpack__no_msgpack(self): + register_msgpack() + self.assertRaises(SerializerNotInstalled, + decode, "foo", "application/x-msgpack", "utf-8") + + + if __name__ == '__main__': unittest.main() diff --git a/kombu/tests/test_transport.py b/kombu/tests/test_transport.py new file mode 100644 index 00000000..3e1c91e0 --- /dev/null +++ b/kombu/tests/test_transport.py @@ -0,0 +1,32 @@ +import unittest2 as unittest + +from kombu import transport + +from kombu.tests.utils import mask_modules, module_exists + + +class test_transport(unittest.TestCase): + + def test_django_transport(self): + self.assertRaises( + ImportError, + mask_modules("djkombu")(transport.resolve_transport), "django") + + self.assertTupleEqual( + module_exists("djkombu")(transport.resolve_transport)("django"), + ("djkombu.transport", "DatabaseTransport")) + + def test_sqlalchemy_transport(self): + self.assertRaises( + ImportError, + mask_modules("sqlakombu")(transport.resolve_transport), + "sqlalchemy") + + self.assertTupleEqual( + module_exists("sqlakombu")(transport.resolve_transport)( + "sqlalchemy"), + ("sqlakombu.transport", "Transport")) + + def test_resolve_transport__no_class_name(self): + self.assertRaises(KeyError, transport.resolve_transport, + "nonexistant") diff --git a/kombu/tests/test_transport_base.py b/kombu/tests/test_transport_base.py index c89c9e4e..4971a2d5 100644 --- a/kombu/tests/test_transport_base.py +++ b/kombu/tests/test_transport_base.py @@ -12,3 +12,15 @@ class test_interface(unittest.TestCase): def test_close_connection(self): self.assertRaises(NotImplementedError, Transport(None).close_connection, None) + + def test_create_channel(self): + self.assertRaises(NotImplementedError, + Transport(None).create_channel, None) + + def test_close_channel(self): + self.assertRaises(NotImplementedError, + Transport(None).close_channel, None) + + def test_drain_events(self): + self.assertRaises(NotImplementedError, + Transport(None).drain_events, None) diff --git a/kombu/tests/test_transport_memory.py b/kombu/tests/test_transport_memory.py new file mode 100644 index 00000000..64dea766 --- /dev/null +++ b/kombu/tests/test_transport_memory.py @@ -0,0 +1,119 @@ +import socket +import unittest2 as unittest + +from kombu.connection import BrokerConnection +from kombu.entity import Exchange, Queue +from kombu.messaging import Consumer, Producer + + +class test_MemoryTransport(unittest.TestCase): + + def setUp(self): + self.c = BrokerConnection(transport="memory") + self.e = Exchange("test_transport_memory") + self.q = Queue("test_transport_memory", + exchange=self.e, + routing_key="test_transport_memory") + self.q2 = Queue("test_transport_memory2", + exchange=self.e, + routing_key="test_transport_memory2") + + def test_produce_consume_noack(self): + channel = self.c.channel() + producer = Producer(channel, self.e) + consumer = Consumer(channel, self.q, no_ack=True) + + for i in range(10): + producer.publish({"foo": i}, routing_key="test_transport_memory") + + _received = [] + + def callback(message_data, message): + _received.append(message) + + consumer.register_callback(callback) + consumer.consume() + + while 1: + if len(_received) == 10: + break; + self.c.drain_events() + + self.assertEqual(len(_received), 10) + + def test_produce_consume(self): + channel = self.c.channel() + producer = Producer(channel, self.e) + consumer1 = Consumer(channel, self.q) + consumer2 = Consumer(channel, self.q2) + self.q2(channel).declare() + + for i in range(10): + producer.publish({"foo": i}, routing_key="test_transport_memory") + for i in range(10): + producer.publish({"foo": i}, routing_key="test_transport_memory2") + + _received1 = [] + _received2 = [] + + def callback1(message_data, message): + _received1.append(message) + message.ack() + + def callback2(message_data, message): + _received2.append(message) + message.ack() + + consumer1.register_callback(callback1) + consumer2.register_callback(callback2) + + consumer1.consume() + consumer2.consume() + + while 1: + if len(_received1) + len(_received2) == 20: + break; + self.c.drain_events() + + self.assertEqual(len(_received1) + len(_received2), 20) + + # compression + producer.publish({"compressed": True}, + routing_key="test_transport_memory", + compression="zlib") + m = self.q(channel).get() + self.assertDictEqual(m.payload, {"compressed": True}) + + # queue.delete + for i in range(10): + producer.publish({"foo": i}, routing_key="test_transport_memory") + self.assertTrue(self.q(channel).get()) + self.q(channel).delete() + self.q(channel).declare() + self.assertIsNone(self.q(channel).get()) + + # queue.purge + for i in range(10): + producer.publish({"foo": i}, routing_key="test_transport_memory2") + self.assertTrue(self.q2(channel).get()) + self.q2(channel).purge() + self.assertIsNone(self.q2(channel).get()) + + def test_drain_events(self): + self.assertRaises(ValueError, self.c.drain_events, timeout=0.1) + + c1 = self.c.channel() + c2 = self.c.channel() + + self.assertRaises(socket.timeout, self.c.drain_events, timeout=0.1) + + def test_drain_events_unregistered_queue(self): + c1 = self.c.channel() + + class Cycle(object): + + def get(self): + return ("foo", "foo"), c1 + + self.c.transport.cycle = Cycle() + self.assertRaises(KeyError, self.c.drain_events) diff --git a/kombu/tests/test_transport_pyamqplib.py b/kombu/tests/test_transport_pyamqplib.py new file mode 100644 index 00000000..5ec38562 --- /dev/null +++ b/kombu/tests/test_transport_pyamqplib.py @@ -0,0 +1,32 @@ +import unittest2 as unittest + +from kombu.transport import pyamqplib +from kombu.connection import BrokerConnection + + +class test_amqplib(unittest.TestCase): + + def test_conninfo(self): + c = BrokerConnection(userid=None, transport="amqplib") + self.assertRaises(KeyError, c.connect) + c = BrokerConnection(hostname=None, transport="amqplib") + self.assertRaises(KeyError, c.connect) + c = BrokerConnection(password=None, transport="amqplib") + self.assertRaises(KeyError, c.connect) + + def test_default_port(self): + + class Transport(pyamqplib.Transport): + Connection = dict + + c = BrokerConnection(port=None, transport=Transport).connect() + self.assertEqual(c["host"], + "localhost:%s" % (Transport.default_port, )) + + def test_custom_port(self): + + class Transport(pyamqplib.Transport): + Connection = dict + + c = BrokerConnection(port=1337, transport=Transport).connect() + self.assertEqual(c["host"], "localhost:1337") diff --git a/kombu/tests/test_virtual.py b/kombu/tests/test_transport_virtual.py index 44fc4167..44fc4167 100644 --- a/kombu/tests/test_virtual.py +++ b/kombu/tests/test_transport_virtual.py diff --git a/kombu/tests/test_virtual_scheduling.py b/kombu/tests/test_virtual_scheduling.py index 63ecdd1e..cd9706a0 100644 --- a/kombu/tests/test_virtual_scheduling.py +++ b/kombu/tests/test_virtual_scheduling.py @@ -18,7 +18,7 @@ class test_FairCycle(unittest.TestCase): def test_cycle(self): resources = ["a", "b", "c", "d", "e"] - def echo(r): + def echo(r, timeout=None): return r # cycle should be ["a", "b", "c", "d", "e", ... repeat] diff --git a/kombu/tests/utils.py b/kombu/tests/utils.py index ac37891b..df906987 100644 --- a/kombu/tests/utils.py +++ b/kombu/tests/utils.py @@ -1,4 +1,6 @@ +import __builtin__ import sys +import types from StringIO import StringIO @@ -21,3 +23,63 @@ def redirect_stdouts(fun): return _inner +def module_exists(*modules): + + def _inner(fun): + + @wraps(fun) + def __inner(*args, **kwargs): + for module in modules: + if isinstance(module, basestring): + module = types.ModuleType(module) + sys.modules[module.__name__] = module + try: + return fun(*args, **kwargs) + finally: + sys.modules.pop(module.__name__, None) + + return __inner + return _inner + + +# Taken from +# http://bitbucket.org/runeh/snippets/src/tip/missing_modules.py +def mask_modules(*modnames): + """Ban some modules from being importable inside the context + + For example: + + >>> @missing_modules("sys"): + >>> def foo(): + ... try: + ... import sys + ... except ImportError: + ... print "sys not found" + sys not found + + >>> import sys + >>> sys.version + (2, 5, 2, 'final', 0) + + """ + + def _inner(fun): + + @wraps(fun) + def __inner(*args, **kwargs): + realimport = __builtin__.__import__ + + def myimp(name, *args, **kwargs): + if name in modnames: + raise ImportError("No module named %s" % name) + else: + return realimport(name, *args, **kwargs) + + __builtin__.__import__ = myimp + try: + return fun(*args, **kwargs) + finally: + __builtin__.__import__ = realimport + + return __inner + return _inner diff --git a/kombu/transport/base.py b/kombu/transport/base.py index 2fcdf35b..adcca281 100644 --- a/kombu/transport/base.py +++ b/kombu/transport/base.py @@ -48,7 +48,7 @@ class Message(object): def __init__(self, channel, body=None, delivery_tag=None, content_type=None, content_encoding=None, delivery_info={}, - properties=None, headers=None, + properties=None, headers=None, postencode=None, **kwargs): self.channel = channel self.body = body @@ -64,6 +64,8 @@ class Message(object): compression = self.headers.get("compression") if compression: self.body = decompress(self.body, compression) + if postencode: + self.body = self.body.encode(postencode) def ack(self): """Acknowledge this message as being processed., diff --git a/kombu/transport/memory.py b/kombu/transport/memory.py index fc9d2f36..714f46af 100644 --- a/kombu/transport/memory.py +++ b/kombu/transport/memory.py @@ -21,7 +21,7 @@ class Channel(virtual.Channel): if queue not in self.queues: self.queues[queue] = Queue() - def _get(self, queue): + def _get(self, queue, timeout=None): return self.queues[queue].get(block=False) def _put(self, queue, message, **kwargs): diff --git a/kombu/transport/pyamqplib.py b/kombu/transport/pyamqplib.py index b3521af8..36b2b979 100644 --- a/kombu/transport/pyamqplib.py +++ b/kombu/transport/pyamqplib.py @@ -26,7 +26,7 @@ DEFAULT_PORT = 5672 transport.AMQP_PROTOCOL_HEADER = "AMQP\x01\x01\x08\x00" -class Connection(amqp.Connection): +class Connection(amqp.Connection): # pragma: no cover def _dispatch_basic_return(self, channel, args, msg): reply_code = args.read_short() @@ -179,6 +179,8 @@ class Channel(_Channel): class Transport(base.Transport): + Connection = Connection + default_port = DEFAULT_PORT connection_errors = (AMQPConnectionException, socket.error, @@ -207,13 +209,13 @@ class Transport(base.Transport): raise KeyError("Missing password for AMQP connection.") if not conninfo.port: conninfo.port = self.default_port - return Connection(host=conninfo.host, - userid=conninfo.userid, - password=conninfo.password, - virtual_host=conninfo.virtual_host, - insist=conninfo.insist, - ssl=conninfo.ssl, - connect_timeout=conninfo.connect_timeout) + return self.Connection(host=conninfo.host, + userid=conninfo.userid, + password=conninfo.password, + virtual_host=conninfo.virtual_host, + insist=conninfo.insist, + ssl=conninfo.ssl, + connect_timeout=conninfo.connect_timeout) def close_connection(self, connection): """Close the AMQP broker connection.""" diff --git a/kombu/transport/virtual/__init__.py b/kombu/transport/virtual/__init__.py index 78c83589..01c18e0b 100644 --- a/kombu/transport/virtual/__init__.py +++ b/kombu/transport/virtual/__init__.py @@ -14,6 +14,7 @@ import socket from itertools import count from multiprocessing.util import Finalize +from time import sleep from Queue import Empty from kombu.transport import base @@ -104,7 +105,7 @@ class QoS(object): while delivered: try: _, message = delivered.popitem() - except KeyError: + except KeyError: # pragma: no cover break try: @@ -144,13 +145,14 @@ class Message(base.Message): def __init__(self, channel, payload, **kwargs): properties = payload["properties"] - fields = {"body": payload.get("body").encode("utf-8"), + fields = {"body": payload.get("body"), "delivery_tag": properties["delivery_tag"], "content_type": payload.get("content-type"), "content_encoding": payload.get("content-encoding"), "headers": payload.get("headers"), "properties": properties, - "delivery_info": properties.get("delivery_info")} + "delivery_info": properties.get("delivery_info"), + "postencode": "utf-8"} super(Message, self).__init__(channel, **dict(kwargs, **fields)) def serializable(self): @@ -170,7 +172,7 @@ class AbstractChannel(object): """ - def _get(self, queue): + def _get(self, queue, timeout=None): """Get next message from `queue`.""" raise NotImplementedError("Virtual channels must implement _get") @@ -205,7 +207,7 @@ class AbstractChannel(object): """ pass - def _poll(self, cycle): + def _poll(self, cycle, timeout=None): """Poll a list of queues for available messages.""" return cycle.get() @@ -331,8 +333,7 @@ class Channel(AbstractChannel): self._consumers.remove(consumer_tag) self._reset_cycle() queue = self._tag_to_queue.pop(consumer_tag, None) - if queue: - self.connection._callbacks.pop(queue, None) + self.connection._callbacks.pop(queue, None) def basic_get(self, queue, **kwargs): """Get message by direct access (synchronous).""" @@ -400,7 +401,7 @@ class Channel(AbstractChannel): if self._consumers and self.qos.can_consume(): if hasattr(self, "_get_many"): return self._get_many(self._active_queues, timeout=timeout) - return self._poll(self.cycle) + return self._poll(self.cycle, timeout=timeout) raise Empty() def message_to_python(self, raw_message): @@ -531,21 +532,23 @@ class Transport(base.Transport): while self.channels: try: channel = self.channels.pop() - except KeyError: + except KeyError: # pragma: no cover pass else: channel.close() def drain_events(self, connection, timeout=None): - cycle_seconds = len(self.channels) * self.interval + if not self.channels: + raise ValueError("No channels to drain events from.") loop = 0 while 1: try: item, channel = self.cycle.get() except Empty: - if timeout and cycle_seconds * loop >= timeout: + if timeout and loop * 0.1 >= timeout: raise socket.timeout() loop += 1 + sleep(0.1) else: break @@ -559,4 +562,4 @@ class Transport(base.Transport): self._callbacks[queue](message) def _drain_channel(self, channel): - return channel.drain_events(timeout=self.interval) + return channel.drain_events() diff --git a/kombu/transport/virtual/scheduling.py b/kombu/transport/virtual/scheduling.py index ca10a7e5..ceb1e02a 100644 --- a/kombu/transport/virtual/scheduling.py +++ b/kombu/transport/virtual/scheduling.py @@ -29,6 +29,8 @@ class FairCycle(object): return resource except IndexError: self.pos = 0 + if not self.resources: + raise self.predicate() def get(self): for tried in count(0): |