diff options
Diffstat (limited to 'kombu')
-rw-r--r-- | kombu/__init__.py | 3 | ||||
-rw-r--r-- | kombu/compression.py | 4 | ||||
-rw-r--r-- | kombu/connection.py | 12 | ||||
-rw-r--r-- | kombu/messaging.py | 1 | ||||
-rw-r--r-- | kombu/pidbox.py | 21 | ||||
-rw-r--r-- | kombu/tests/test_compression.py | 4 | ||||
-rw-r--r-- | kombu/tests/test_pools.py | 3 | ||||
-rw-r--r-- | kombu/tests/transport/test_redis.py | 6 | ||||
-rw-r--r-- | kombu/tests/utils/test_utils.py | 31 | ||||
-rw-r--r-- | kombu/transport/base.py | 3 | ||||
-rw-r--r-- | kombu/transport/librabbitmq.py | 16 | ||||
-rw-r--r-- | kombu/transport/mongodb.py | 7 | ||||
-rw-r--r-- | kombu/transport/pyamqp.py | 7 | ||||
-rw-r--r-- | kombu/transport/redis.py | 28 | ||||
-rw-r--r-- | kombu/utils/__init__.py | 13 | ||||
-rw-r--r-- | kombu/utils/eventio.py | 9 | ||||
-rw-r--r-- | kombu/utils/text.py | 27 |
17 files changed, 174 insertions, 21 deletions
diff --git a/kombu/__init__.py b/kombu/__init__.py index 5498e136..f6b73012 100644 --- a/kombu/__init__.py +++ b/kombu/__init__.py @@ -7,7 +7,7 @@ version_info_t = namedtuple( 'version_info_t', ('major', 'minor', 'micro', 'releaselevel', 'serial'), ) -VERSION = version_info_t(3, 0, 14, '', '') +VERSION = version_info_t(3, 0, 15, '', '') __version__ = '{0.major}.{0.minor}.{0.micro}{0.releaselevel}'.format(VERSION) __author__ = 'Ask Solem' __contact__ = 'ask@celeryproject.org' @@ -99,6 +99,7 @@ new_module.__dict__.update({ '__homepage__': __homepage__, '__docformat__': __docformat__, '__package__': package, + 'version_info_t': version_info_t, 'VERSION': VERSION}) if os.environ.get('KOMBU_LOG_DEBUG'): # pragma: no cover diff --git a/kombu/compression.py b/kombu/compression.py index 39622142..866433d8 100644 --- a/kombu/compression.py +++ b/kombu/compression.py @@ -7,7 +7,7 @@ Compression utilities. """ from __future__ import absolute_import -from kombu.utils.encoding import ensure_bytes, bytes_to_str +from kombu.utils.encoding import ensure_bytes import zlib @@ -67,7 +67,7 @@ def decompress(body, content_type): :param content_type: mime-type of compression method used. """ - return bytes_to_str(get_decoder(content_type)(body)) + return get_decoder(content_type)(body) register(zlib.compress, diff --git a/kombu/connection.py b/kombu/connection.py index d289e342..85b8f5e9 100644 --- a/kombu/connection.py +++ b/kombu/connection.py @@ -25,7 +25,7 @@ from kombu import exceptions from .five import Empty, range, string_t, text_t, LifoQueue as _LifoQueue from .log import get_logger from .transport import get_transport_cls, supports_librabbitmq -from .utils import cached_property, retry_over_time, shufflecycle +from .utils import cached_property, retry_over_time, shufflecycle, HashedSeq from .utils.compat import OrderedDict from .utils.functional import lazy from .utils.url import parse_url, urlparse @@ -565,9 +565,9 @@ class Connection(object): return OrderedDict(self._info()) def __eqhash__(self): - return hash('%s|%s|%s|%s|%s|%s' % ( - self.transport_cls, self.hostname, self.userid, - self.password, self.virtual_host, self.port)) + return HashedSeq(self.transport_cls, self.hostname, self.userid, + self.password, self.virtual_host, self.port, + repr(self.transport_options)) def as_uri(self, include_password=False, mask=''): """Convert connection parameters to URL form.""" @@ -732,6 +732,10 @@ class Connection(object): self.release() @property + def qos_semantics_matches_spec(self): + return self.transport.qos_semantics_matches_spec(self.connection) + + @property def connected(self): """Return true if the connection has been established.""" return (not self._closed and diff --git a/kombu/messaging.py b/kombu/messaging.py index a156a619..98d59d45 100644 --- a/kombu/messaging.py +++ b/kombu/messaging.py @@ -521,7 +521,6 @@ class Consumer(object): whole messages. :param apply_global: Apply new settings globally on all channels. - Currently not supported by RabbitMQ. """ return self.channel.basic_qos(prefetch_size, diff --git a/kombu/pidbox.py b/kombu/pidbox.py index cbd68317..5c70a382 100644 --- a/kombu/pidbox.py +++ b/kombu/pidbox.py @@ -135,7 +135,8 @@ class Node(object): def reply(self, data, exchange, routing_key, ticket, **kwargs): self.mailbox._publish_reply(data, exchange, routing_key, ticket, - channel=self.channel) + channel=self.channel, + serializer=self.mailbox.serializer) class Mailbox(object): @@ -161,8 +162,12 @@ class Mailbox(object): #: Only accepts json messages by default. accept = ['json'] + #: Message serializer + serializer = None + def __init__(self, namespace, - type='direct', connection=None, clock=None, accept=None): + type='direct', connection=None, clock=None, + accept=None, serializer=None): self.namespace = namespace self.connection = connection self.type = type @@ -172,6 +177,7 @@ class Mailbox(object): self._tls = local() self.unclaimed = defaultdict(deque) self.accept = self.accept if accept is None else accept + self.serializer = self.serializer if serializer is None else serializer def __call__(self, connection): bound = copy(self) @@ -242,7 +248,8 @@ class Mailbox(object): pass # queue probably deleted and no one is expecting a reply. def _publish(self, type, arguments, destination=None, - reply_ticket=None, channel=None, timeout=None): + reply_ticket=None, channel=None, timeout=None, + serializer=None): message = {'method': type, 'arguments': arguments, 'destination': destination} @@ -253,16 +260,18 @@ class Mailbox(object): message.update(ticket=reply_ticket, reply_to={'exchange': self.reply_exchange.name, 'routing_key': self.oid}) + serializer = serializer or self.serializer producer = Producer(chan, auto_declare=False) producer.publish( message, exchange=exchange.name, declare=[exchange], headers={'clock': self.clock.forward(), 'expires': time() + timeout if timeout else 0}, + serializer=serializer, ) def _broadcast(self, command, arguments=None, destination=None, reply=False, timeout=1, limit=None, - callback=None, channel=None): + callback=None, channel=None, serializer=None): if destination is not None and \ not isinstance(destination, (list, tuple)): raise ValueError( @@ -277,10 +286,12 @@ class Mailbox(object): if limit is None and destination: limit = destination and len(destination) or None + serializer = serializer or self.serializer self._publish(command, arguments, destination=destination, reply_ticket=reply_ticket, channel=chan, - timeout=timeout) + timeout=timeout, + serializer=serializer) if reply_ticket: return self._collect(reply_ticket, limit=limit, diff --git a/kombu/tests/test_compression.py b/kombu/tests/test_compression.py index e0cd4cbb..7d651ee2 100644 --- a/kombu/tests/test_compression.py +++ b/kombu/tests/test_compression.py @@ -34,7 +34,7 @@ class test_compression(Case): self.assertIn('application/x-bz2', encoders) def test_compress__decompress__zlib(self): - text = 'The Quick Brown Fox Jumps Over The Lazy Dog' + text = b'The Quick Brown Fox Jumps Over The Lazy Dog' c, ctype = compression.compress(text, 'zlib') self.assertNotEqual(text, c) d = compression.decompress(c, ctype) @@ -43,7 +43,7 @@ class test_compression(Case): def test_compress__decompress__bzip2(self): if not self.has_bzip2: raise SkipTest('bzip2 not available') - text = 'The Brown Quick Fox Over The Lazy Dog Jumps' + text = b'The Brown Quick Fox Over The Lazy Dog Jumps' c, ctype = compression.compress(text, 'bzip2') self.assertNotEqual(text, c) d = compression.decompress(c, ctype) diff --git a/kombu/tests/test_pools.py b/kombu/tests/test_pools.py index 89a6bd20..920c65a7 100644 --- a/kombu/tests/test_pools.py +++ b/kombu/tests/test_pools.py @@ -220,6 +220,9 @@ class test_fun_PoolGroup(Case): assert eqhash(c1) != eqhash(c2) assert eqhash(c1) == eqhash(c3) + c4 = Connection(c1u, transport_options={'confirm_publish': True}) + self.assertNotEqual(eqhash(c3), eqhash(c4)) + p1 = pools.connections[c1] p2 = pools.connections[c2] p3 = pools.connections[c3] diff --git a/kombu/tests/transport/test_redis.py b/kombu/tests/transport/test_redis.py index 9b5da64a..48f7c6be 100644 --- a/kombu/tests/transport/test_redis.py +++ b/kombu/tests/transport/test_redis.py @@ -776,8 +776,10 @@ class test_Channel(Case): with patch('kombu.transport.redis.Channel._create_client'): with Connection('redis+socket:///tmp/redis.sock') as conn: connparams = conn.default_channel._connparams() - self.assertEqual(connparams['connection_class'], - redis.redis.UnixDomainSocketConnection) + self.assertTrue(issubclass( + connparams['connection_class'], + redis.redis.UnixDomainSocketConnection, + )) self.assertEqual(connparams['path'], '/tmp/redis.sock') diff --git a/kombu/tests/utils/test_utils.py b/kombu/tests/utils/test_utils.py index 0d645d5c..0248a303 100644 --- a/kombu/tests/utils/test_utils.py +++ b/kombu/tests/utils/test_utils.py @@ -11,7 +11,9 @@ if sys.version_info >= (3, 0): else: from StringIO import StringIO, StringIO as BytesIO # noqa +from kombu import version_info_t from kombu import utils +from kombu.utils.text import version_string_as_tuple from kombu.five import string_t from kombu.tests.case import ( @@ -379,3 +381,32 @@ class test_shufflecycle(Case): next(cycle) finally: utils.repeat = prev_repeat + + +class test_version_string_as_tuple(Case): + + def test_versions(self): + self.assertTupleEqual( + version_string_as_tuple('3'), + version_info_t(3, 0, 0, '', ''), + ) + self.assertTupleEqual( + version_string_as_tuple('3.3'), + version_info_t(3, 3, 0, '', ''), + ) + self.assertTupleEqual( + version_string_as_tuple('3.3.1'), + version_info_t(3, 3, 1, '', ''), + ) + self.assertTupleEqual( + version_string_as_tuple('3.3.1a3'), + version_info_t(3, 3, 1, 'a3', ''), + ) + self.assertTupleEqual( + version_string_as_tuple('3.3.1a3-40c32'), + version_info_t(3, 3, 1, 'a3', '40c32'), + ) + self.assertEqual( + version_string_as_tuple('3.3.1.a3.40c32'), + version_info_t(3, 3, 1, 'a3', '40c32'), + ) diff --git a/kombu/transport/base.py b/kombu/transport/base.py index 43429ae6..c226307e 100644 --- a/kombu/transport/base.py +++ b/kombu/transport/base.py @@ -152,6 +152,9 @@ class Transport(object): return _read + def qos_semantics_matches_spec(self, connection): + return True + def on_readable(self, connection, loop): reader = self.__reader if reader is None: diff --git a/kombu/transport/librabbitmq.py b/kombu/transport/librabbitmq.py index 8fe06968..e1d7a999 100644 --- a/kombu/transport/librabbitmq.py +++ b/kombu/transport/librabbitmq.py @@ -11,6 +11,7 @@ from __future__ import absolute_import import os import socket +import warnings try: import librabbitmq as amqp @@ -24,9 +25,14 @@ except ImportError: # pragma: no cover from kombu.five import items, values from kombu.utils.amq_manager import get_manager +from kombu.utils.text import version_string_as_tuple from . import base +W_VERSION = """ + librabbitmq version too old to detect RabbitMQ version information + so make sure you are using librabbitmq 1.5 when using rabbitmq > 3.3 +""" DEFAULT_PORT = 5672 NO_SSL_ERROR = """\ @@ -150,6 +156,16 @@ class Transport(base.Transport): def get_manager(self, *args, **kwargs): return get_manager(self.client, *args, **kwargs) + def qos_semantics_matches_spec(self, connection): + try: + props = connection.server_properties + except AttributeError: + warnings.warn(UserWarning, W_VERSION) + else: + if props.get('product') == 'RabbitMQ': + return version_string_as_tuple(props['version']) < (3, 3) + return True + @property def default_connection_params(self): return {'userid': 'guest', 'password': 'guest', diff --git a/kombu/transport/mongodb.py b/kombu/transport/mongodb.py index e9695e15..78af0f9f 100644 --- a/kombu/transport/mongodb.py +++ b/kombu/transport/mongodb.py @@ -55,14 +55,14 @@ class BroadcastCursor(object): def __iter__(self): return self - def next(self): + def __next__(self): while True: try: msg = next(self._cursor) - except pymongo.errors.OperationFailure, e: + except pymongo.errors.OperationFailure as exc: # In some cases tailed cursor can become invalid # and have to be reinitalized - if 'not valid at server' in e.message: + if 'not valid at server' in exc.message: self.purge() continue @@ -74,6 +74,7 @@ class BroadcastCursor(object): self._offset += 1 return msg + next = __next__ class Channel(virtual.Channel): diff --git a/kombu/transport/pyamqp.py b/kombu/transport/pyamqp.py index 92d8ca03..01844305 100644 --- a/kombu/transport/pyamqp.py +++ b/kombu/transport/pyamqp.py @@ -11,6 +11,7 @@ import amqp from kombu.five import items from kombu.utils.amq_manager import get_manager +from kombu.utils.text import version_string_as_tuple from . import base @@ -129,6 +130,12 @@ class Transport(base.Transport): def heartbeat_check(self, connection, rate=2): return connection.heartbeat_tick(rate=rate) + def qos_semantics_matches_spec(self, connection): + props = connection.server_properties + if props.get('product') == 'RabbitMQ': + return version_string_as_tuple(props['version']) < (3, 3) + return True + @property def default_connection_params(self): return {'userid': 'guest', 'password': 'guest', diff --git a/kombu/transport/redis.py b/kombu/transport/redis.py index b21da368..3ee049cb 100644 --- a/kombu/transport/redis.py +++ b/kombu/transport/redis.py @@ -254,6 +254,12 @@ class MultiChannelPoller(object): def discard(self, channel): self._channels.discard(channel) + def _on_connection_disconnect(self, connection): + try: + self.poller.unregister(connection._sock) + except AttributeError: + pass + def _register(self, channel, client, type): if (channel, client, type) in self._chan_to_sock: self._unregister(channel, client, type) @@ -450,6 +456,10 @@ class Channel(virtual.Channel): if self._pool is not None: self._pool.disconnect() + def _on_connection_disconnect(self, connection): + if self.connection and self.connection.cycle: + self.connection.cycle._on_connection_disconnect(connection) + def _do_restore_message(self, payload, exchange, routing_key, client=None, leftmost=False): with self.conn_or_acquire(client) as client: @@ -778,6 +788,19 @@ class Channel(virtual.Channel): connparams.pop('port', None) connparams['db'] = self._prepare_virtual_host( connparams.pop('virtual_host', None)) + + channel = self + connection_cls = ( + connparams.get('connection_class') or + redis.Connection + ) + + class Connection(connection_cls): + def disconnect(self): + channel._on_connection_disconnect(self) + super(Connection, self).disconnect() + connparams['connection_class'] = Connection + return connparams def _create_client(self): @@ -905,6 +928,11 @@ class Transport(virtual.Transport): add_reader = loop.add_reader on_readable = self.on_readable + def _on_disconnect(connection): + if connection._sock: + loop.remove(connection._sock) + cycle._on_connection_disconnect = _on_disconnect + def on_poll_start(): cycle_poll_start() [add_reader(fd, on_readable, fd) for fd in cycle.fds] diff --git a/kombu/utils/__init__.py b/kombu/utils/__init__.py index aff46323..0745ddfe 100644 --- a/kombu/utils/__init__.py +++ b/kombu/utils/__init__.py @@ -101,6 +101,19 @@ def symbol_by_name(name, aliases={}, imp=None, package=None, return default +class HashedSeq(list): + """type used for hash() to make sure the hash is not generated + multiple times.""" + __slots__ = 'hashvalue' + + def __init__(self, *seq): + self[:] = seq + self.hashvalue = hash(seq) + + def __hash__(self): + return self.hashvalue + + def eqhash(o): try: return o.__eqhash__() diff --git a/kombu/utils/eventio.py b/kombu/utils/eventio.py index fafb3ace..ed9ec06c 100644 --- a/kombu/utils/eventio.py +++ b/kombu/utils/eventio.py @@ -202,7 +202,14 @@ class _select(Poller): self.unregister(fd) def unregister(self, fd): - fd = fileno(fd) + try: + fd = fileno(fd) + except socket.error as exc: + # we don't know the previous fd of this object + # but it will be removed by the next poll iteration. + if get_errno(exc) in SELECT_BAD_FD: + return + raise self._rfd.discard(fd) self._wfd.discard(fd) self._efd.discard(fd) diff --git a/kombu/utils/text.py b/kombu/utils/text.py index 20444710..066b28af 100644 --- a/kombu/utils/text.py +++ b/kombu/utils/text.py @@ -3,6 +3,9 @@ from __future__ import absolute_import from difflib import SequenceMatcher +from kombu import version_info_t +from kombu.five import string_t + def fmatch_iter(needle, haystack, min_ratio=0.6): for key in haystack: @@ -18,3 +21,27 @@ def fmatch_best(needle, haystack, min_ratio=0.6): )[0][1] except IndexError: pass + + +def version_string_as_tuple(s): + v = _unpack_version(*s.split('.')) + # X.Y.3a1 -> (X, Y, 3, 'a1') + if isinstance(v.micro, string_t): + v = version_info_t(v.major, v.minor, *_splitmicro(*v[2:])) + # X.Y.3a1-40 -> (X, Y, 3, 'a1', '40') + if not v.serial and v.releaselevel and '-' in v.releaselevel: + v = version_info_t(*list(v[0:3]) + v.releaselevel.split('-')) + return v + + +def _unpack_version(major, minor=0, micro=0, releaselevel='', serial=''): + return version_info_t(int(major), int(minor), micro, releaselevel, serial) + + +def _splitmicro(micro, releaselevel='', serial=''): + for index, char in enumerate(micro): + if not char.isdigit(): + break + else: + return int(micro or 0), releaselevel, serial + return int(micro[:index]), micro[index:], serial |