diff options
author | Ask Solem <ask@celeryproject.org> | 2017-02-16 10:44:48 -0800 |
---|---|---|
committer | Ask Solem <ask@celeryproject.org> | 2017-02-16 10:44:48 -0800 |
commit | e6fab2f68b562cf1400bd8167e9b755f0482aafe (patch) | |
tree | b4d6be32dc8c62fa032e3c1a1a74636ac8360a38 /kombu/messaging.py | |
parent | f2f7c67651106e77fb2db60ded134404ccc0a626 (diff) | |
download | kombu-e6fab2f68b562cf1400bd8167e9b755f0482aafe.tar.gz |
WIP5.0-devel
Diffstat (limited to 'kombu/messaging.py')
-rw-r--r-- | kombu/messaging.py | 288 |
1 files changed, 184 insertions, 104 deletions
diff --git a/kombu/messaging.py b/kombu/messaging.py index 3dd03d14..67ccd15d 100644 --- a/kombu/messaging.py +++ b/kombu/messaging.py @@ -1,20 +1,23 @@ """Sending and receiving messages.""" +import numbers from itertools import count - +from typing import Any, Callable, Mapping, Sequence, Tuple, Union +from amqp import ChannelT +from . import types +from .abstract import Entity from .common import maybe_declare from .compression import compress from .connection import maybe_channel, is_connection from .entity import Exchange, Queue, maybe_delivery_mode from .exceptions import ContentDisallowed from .serialization import dumps, prepare_accept_content -from .utils import abstract +from .types import ChannelArgT, ClientT, MessageT, QueueT from .utils.functional import ChannelPromise, maybe_list __all__ = ['Exchange', 'Queue', 'Producer', 'Consumer'] -@abstract.Producer.register -class Producer: +class Producer(types.ProducerT): """Message Producer. Arguments: @@ -34,7 +37,7 @@ class Producer: """ #: Default exchange - exchange = None + exchange = None # type: Exchange #: Default routing key. routing_key = '' @@ -56,9 +59,16 @@ class Producer: #: default_channel). __connection__ = None - def __init__(self, channel, exchange=None, routing_key=None, - serializer=None, auto_declare=None, compression=None, - on_return=None): + _channel: ChannelT + + def __init__(self, + channel: ChannelArgT, + exchange: Exchange = None, + routing_key: str = None, + serializer: str = None, + auto_declare: bool = None, + compression: str = None, + on_return: Callable = None): self._channel = channel self.exchange = exchange self.routing_key = routing_key or self.routing_key @@ -71,20 +81,17 @@ class Producer: if auto_declare is not None: self.auto_declare = auto_declare - if self._channel: - self.revive(self._channel) - - def __repr__(self): + def __repr__(self) -> str: return '<Producer: {0._channel}>'.format(self) - def __reduce__(self): + def __reduce__(self) -> Tuple[Any, ...]: return self.__class__, self.__reduce_args__() - def __reduce_args__(self): + def __reduce_args__(self) -> Tuple[Any, ...]: return (None, self.exchange, self.routing_key, self.serializer, self.auto_declare, self.compression) - def declare(self): + async def declare(self) -> None: """Declare the exchange. Note: @@ -92,16 +99,18 @@ class Producer: the :attr:`auto_declare` flag is enabled. """ if self.exchange.name: - self.exchange.declare() + await self.exchange.declare() - def maybe_declare(self, entity, retry=False, **retry_policy): + async def maybe_declare(self, entity: Entity, + retry: bool = False, **retry_policy) -> None: """Declare exchange if not already declared during this session.""" if entity: - return maybe_declare(entity, self.channel, retry, **retry_policy) + await maybe_declare(entity, self.channel, retry, **retry_policy) - def _delivery_details(self, exchange, delivery_mode=None, - maybe_delivery_mode=maybe_delivery_mode, - Exchange=Exchange): + def _delivery_details(self, exchange: Union[Exchange, str], + delivery_mode: Union[int, str]=None, + maybe_delivery_mode: Callable = maybe_delivery_mode, + Exchange: Any = Exchange) -> Tuple[str, int]: if isinstance(exchange, Exchange): return exchange.name, maybe_delivery_mode( delivery_mode or exchange.delivery_mode, @@ -112,12 +121,23 @@ class Producer: delivery_mode or self.exchange.delivery_mode, ) - def publish(self, body, routing_key=None, delivery_mode=None, - mandatory=False, immediate=False, priority=0, - content_type=None, content_encoding=None, serializer=None, - headers=None, compression=None, exchange=None, retry=False, - retry_policy=None, declare=None, expiration=None, - **properties): + async def publish(self, body: Any, + routing_key: str = None, + delivery_mode: Union[int, str] = None, + mandatory: bool = False, + immediate: bool = False, + priority: int = 0, + content_type: str = None, + content_encoding: str = None, + serializer: str = None, + headers: Mapping = None, + compression: str = None, + exchange: Union[Exchange, str] = None, + retry: bool = False, + retry_policy: Mapping = None, + declare: Sequence[Entity] = None, + expiration: numbers.Number = None, + **properties) -> None: """Publish message to the specified exchange. Arguments: @@ -172,54 +192,79 @@ class Producer: declare.append(self.exchange) if retry: - _publish = self.connection.ensure(self, _publish, **retry_policy) - return _publish( + _publish = await self.connection.ensure( + self, _publish, **retry_policy) + await _publish( body, priority, content_type, content_encoding, headers, properties, routing_key, mandatory, immediate, exchange_name, declare, ) - def _publish(self, body, priority, content_type, content_encoding, - headers, properties, routing_key, mandatory, - immediate, exchange, declare): - channel = self.channel + async def _publish(self, + body: Any, + priority: int = None, + content_type: str = None, + content_encoding: str = None, + headers: Mapping = None, + properties: Mapping = None, + routing_key: str = None, + mandatory: bool = None, + immediate: bool = None, + exchange: str = None, + declare: Sequence = None) -> None: + channel = await self._resolve_channel() message = channel.prepare_message( body, priority, content_type, content_encoding, headers, properties, ) if declare: maybe_declare = self.maybe_declare - [maybe_declare(entity) for entity in declare] + for entity in declare: + await maybe_declare(entity) # handle autogenerated queue names for reply_to reply_to = properties.get('reply_to') if isinstance(reply_to, Queue): properties['reply_to'] = reply_to.name - return channel.basic_publish( + await channel.basic_publish( message, exchange=exchange, routing_key=routing_key, mandatory=mandatory, immediate=immediate, ) - def _get_channel(self): + async def _resolve_channel(self): + channel = self._channel + if isinstance(channel, ChannelPromise): + channel = self._channel = await channel.resolve() + if self.exchange: + self.exchange.revive(channel) + if self.on_return: + channel.events['basic_return'].add(self.on_return) + else: + channel = maybe_channel(channel) + await channel.open() + return channel + + def _get_channel(self) -> ChannelT: channel = self._channel if isinstance(channel, ChannelPromise): channel = self._channel = channel() - self.exchange.revive(channel) + if self.exchange: + self.exchange.revive(channel) if self.on_return: channel.events['basic_return'].add(self.on_return) return channel - def _set_channel(self, channel): + def _set_channel(self, channel: ChannelT) -> None: self._channel = channel channel = property(_get_channel, _set_channel) - def revive(self, channel): + async def revive(self, channel: ChannelT) -> None: """Revive the producer after connection loss.""" if is_connection(channel): connection = channel self.__connection__ = connection - channel = ChannelPromise(lambda: connection.default_channel) + channel = ChannelPromise(connection) if isinstance(channel, ChannelPromise): self._channel = channel self.exchange = self.exchange(channel) @@ -230,18 +275,29 @@ class Producer: self._channel.events['basic_return'].add(self.on_return) self.exchange = self.exchange(channel) - def __enter__(self): + def __enter__(self) -> 'Producer': return self - def __exit__(self, *exc_info): + async def __aenter__(self) -> 'Producer': + await self.revive(self.channel) + return self + + def __exit__(self, *exc_info) -> None: self.release() - def release(self): - pass + async def __aexit__(self, *exc_info) -> None: + self.release() + + def release(self) -> None: + ... close = release - def _prepare(self, body, serializer=None, content_type=None, - content_encoding=None, compression=None, headers=None): + def _prepare(self, body: Any, + serializer: str = None, + content_type: str = None, + content_encoding: str = None, + compression: str = None, + headers: Mapping = None) -> Tuple[Any, str, str]: # No content_type? Then we're serializing the data internally. if not content_type: @@ -267,15 +323,14 @@ class Producer: return body, content_type, content_encoding @property - def connection(self): + def connection(self) -> ClientT: try: return self.__connection__ or self.channel.connection.client except AttributeError: pass -@abstract.Consumer.register -class Consumer: +class Consumer(types.ConsumerT): """Message consumer. Arguments: @@ -358,13 +413,23 @@ class Consumer: prefetch_count = None #: Mapping of queues we consume from. - _queues = None + _queues: Mapping[str, QueueT] = None _tags = count(1) # global - - def __init__(self, channel, queues=None, no_ack=None, auto_declare=None, - callbacks=None, on_decode_error=None, on_message=None, - accept=None, prefetch_count=None, tag_prefix=None): + _active_tags: Mapping[str, str] + + def __init__( + self, + channel: ChannelT, + queues: Sequence[QueueT] = None, + no_ack: bool = None, + auto_declare: bool = None, + callbacks: Sequence[Callable] = None, + on_decode_error: Callable = None, + on_message: Callable = None, + accept: Sequence[str] = None, + prefetch_count: int = None, + tag_prefix: str = None) -> None: self.channel = channel self.queues = maybe_list(queues or []) self.no_ack = self.no_ack if no_ack is None else no_ack @@ -380,21 +445,19 @@ class Consumer: self.accept = prepare_accept_content(accept) self.prefetch_count = prefetch_count - if self.channel: - self.revive(self.channel) - @property - def queues(self): + def queues(self) -> Sequence[QueueT]: return list(self._queues.values()) @queues.setter - def queues(self, queues): + def queues(self, queues: Sequence[QueueT]) -> None: self._queues = {q.name: q for q in queues} - def revive(self, channel): + async def revive(self, channel: ChannelT) -> None: """Revive consumer after connection loss.""" self._active_tags.clear() channel = self.channel = maybe_channel(channel) + await channel.open() for qname, queue in self._queues.items(): # name may have changed after declare self._queues.pop(qname, None) @@ -402,22 +465,21 @@ class Consumer: queue.revive(channel) if self.auto_declare: - self.declare() + await self.declare() if self.prefetch_count is not None: - self.qos(prefetch_count=self.prefetch_count) + await self.qos(prefetch_count=self.prefetch_count) - def declare(self): + async def declare(self) -> None: """Declare queues, exchanges and bindings. Note: This is done automatically at instantiation when :attr:`auto_declare` is set. """ - for queue in self._queues.values(): - queue.declare() + [await queue.declare() for queue in self._queues.values()] - def register_callback(self, callback): + def register_callback(self, callback: Callable) -> None: """Register a new callback to be called when a message is received. Note: @@ -427,11 +489,26 @@ class Consumer: """ self.callbacks.append(callback) - def __enter__(self): + def __enter__(self) -> 'Consumer': + self.revive(self.channel) self.consume() return self - def __exit__(self, exc_type, exc_val, exc_tb): + async def __aenter__(self) -> 'Consumer': + await self.revive(self.channel) + await self.consume() + return self + + async def __aexit__(self, *exc_info) -> None: + if self.channel: + conn_errors = self.channel.connection.client.connection_errors + if not isinstance(exc_info[1], conn_errors): + try: + await self.cancel() + except Exception: + pass + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: if self.channel: conn_errors = self.channel.connection.client.connection_errors if not isinstance(exc_val, conn_errors): @@ -440,7 +517,7 @@ class Consumer: except Exception: pass - def add_queue(self, queue): + async def add_queue(self, queue: QueueT) -> QueueT: """Add a queue to the list of queues to consume from. Note: @@ -449,11 +526,11 @@ class Consumer: """ queue = queue(self.channel) if self.auto_declare: - queue.declare() + await queue.declare() self._queues[queue.name] = queue return queue - def consume(self, no_ack=None): + async def consume(self, no_ack: bool = None) -> None: """Start consuming messages. Can be called multiple times, but note that while it @@ -470,10 +547,10 @@ class Consumer: H, T = queues[:-1], queues[-1] for queue in H: - self._basic_consume(queue, no_ack=no_ack, nowait=True) - self._basic_consume(T, no_ack=no_ack, nowait=False) + await self._basic_consume(queue, no_ack=no_ack, nowait=True) + await self._basic_consume(T, no_ack=no_ack, nowait=False) - def cancel(self): + async def cancel(self) -> None: """End all active queue consumers. Note: @@ -482,38 +559,36 @@ class Consumer: """ cancel = self.channel.basic_cancel for tag in self._active_tags.values(): - cancel(tag) + await cancel(tag) self._active_tags.clear() close = cancel - def cancel_by_queue(self, queue): + async def cancel_by_queue(self, queue: Union[QueueT, str]) -> None: """Cancel consumer by queue name.""" - qname = queue.name if isinstance(queue, Queue) else queue + qname = queue.name if isinstance(queue, QueueT) else queue try: tag = self._active_tags.pop(qname) except KeyError: pass else: - self.channel.basic_cancel(tag) + await self.channel.basic_cancel(tag) finally: self._queues.pop(qname, None) - def consuming_from(self, queue): + def consuming_from(self, queue: Union[QueueT, str]) -> bool: """Return :const:`True` if currently consuming from queue'.""" - name = queue - if isinstance(queue, Queue): - name = queue.name + name = queue.name if isinstance(queue, QueueT) else queue return name in self._active_tags - def purge(self): + async def purge(self) -> int: """Purge messages from all queues. Warning: This will *delete all ready messages*, there is no undo operation. """ - return sum(queue.purge() for queue in self._queues.values()) + return sum(await queue.purge() for queue in self._queues.values()) - def flow(self, active): + async def flow(self, active: bool) -> None: """Enable/disable flow from peer. This is a simple flow-control mechanism that a peer can use @@ -524,9 +599,12 @@ class Consumer: will finish sending the current content (if any), and then wait until flow is reactivated. """ - self.channel.flow(active) + await self.channel.flow(active) - def qos(self, prefetch_size=0, prefetch_count=0, apply_global=False): + async def qos(self, + prefetch_size: int = 0, + prefetch_count: int = 0, + apply_global: bool = False) -> None: """Specify quality of service. The client can request that messages should be sent in @@ -550,11 +628,10 @@ class Consumer: apply_global (bool): Apply new settings globally on all channels. """ - return self.channel.basic_qos(prefetch_size, - prefetch_count, - apply_global) + await self.channel.basic_qos( + prefetch_size, prefetch_count, apply_global) - def recover(self, requeue=False): + async def recover(self, requeue: bool = False) -> None: """Redeliver unacknowledged messages. Asks the broker to redeliver all unacknowledged messages @@ -566,9 +643,9 @@ class Consumer: server will attempt to requeue the message, potentially then delivering it to an alternative subscriber. """ - return self.channel.basic_recover(requeue=requeue) + await self.channel.basic_recover(requeue=requeue) - def receive(self, body, message): + async def receive(self, body: Any, message: MessageT) -> None: """Method called when a message is received. This dispatches to the registered :attr:`callbacks`. @@ -584,24 +661,27 @@ class Consumer: callbacks = self.callbacks if not callbacks: raise NotImplementedError('Consumer does not have any callbacks') - [callback(body, message) for callback in callbacks] + for callback in callbacks: + await callback(body, message) - def _basic_consume(self, queue, consumer_tag=None, - no_ack=no_ack, nowait=True): + async def _basic_consume(self, queue: QueueT, + consumer_tag: str = None, + no_ack: bool = no_ack, + nowait: bool = True) -> str: tag = self._active_tags.get(queue.name) if tag is None: tag = self._add_tag(queue, consumer_tag) - queue.consume(tag, self._receive_callback, - no_ack=no_ack, nowait=nowait) + await queue.consume(tag, self._receive_callback, + no_ack=no_ack, nowait=nowait) return tag - def _add_tag(self, queue, consumer_tag=None): + def _add_tag(self, queue: QueueT, consumer_tag: str = None) -> str: tag = consumer_tag or '{0}{1}'.format( self.tag_prefix, next(self._tags)) self._active_tags[queue.name] = tag return tag - def _receive_callback(self, message): + async def _receive_callback(self, message: MessageT) -> None: accept = self.accept on_m, channel, decoded = self.on_message, self.channel, None try: @@ -611,20 +691,20 @@ class Consumer: if accept is not None: message.accept = accept if message.errors: - return message._reraise_error(self.on_decode_error) + message._reraise_error(self.on_decode_error) decoded = None if on_m else message.decode() except Exception as exc: if not self.on_decode_error: raise self.on_decode_error(message, exc) else: - return on_m(message) if on_m else self.receive(decoded, message) + await on_m(message) if on_m else self.receive(decoded, message) - def __repr__(self): + def __repr__(self) -> str: return '<{name}: {0.queues}>'.format(self, name=type(self).__name__) @property - def connection(self): + def connection(self) -> ClientT: try: return self.channel.connection.client except AttributeError: |