path: root/kombu/messaging.py
author    Ask Solem <ask@celeryproject.org>  2017-02-16 10:44:48 -0800
committer Ask Solem <ask@celeryproject.org>  2017-02-16 10:44:48 -0800
commit    e6fab2f68b562cf1400bd8167e9b755f0482aafe (patch)
tree      b4d6be32dc8c62fa032e3c1a1a74636ac8360a38 /kombu/messaging.py
parent    f2f7c67651106e77fb2db60ded134404ccc0a626 (diff)
download  kombu-e6fab2f68b562cf1400bd8167e9b755f0482aafe.tar.gz
Diffstat (limited to 'kombu/messaging.py')
-rw-r--r--  kombu/messaging.py | 288
1 file changed, 184 insertions, 104 deletions
diff --git a/kombu/messaging.py b/kombu/messaging.py
index 3dd03d14..67ccd15d 100644
--- a/kombu/messaging.py
+++ b/kombu/messaging.py
@@ -1,20 +1,23 @@
"""Sending and receiving messages."""
+import numbers
from itertools import count
-
+from typing import Any, Callable, Mapping, Sequence, Tuple, Union
+from amqp import ChannelT
+from . import types
+from .abstract import Entity
from .common import maybe_declare
from .compression import compress
from .connection import maybe_channel, is_connection
from .entity import Exchange, Queue, maybe_delivery_mode
from .exceptions import ContentDisallowed
from .serialization import dumps, prepare_accept_content
-from .utils import abstract
+from .types import ChannelArgT, ClientT, MessageT, QueueT
from .utils.functional import ChannelPromise, maybe_list
__all__ = ['Exchange', 'Queue', 'Producer', 'Consumer']
-@abstract.Producer.register
-class Producer:
+class Producer(types.ProducerT):
"""Message Producer.
Arguments:
@@ -34,7 +37,7 @@ class Producer:
"""
#: Default exchange
- exchange = None
+ exchange = None # type: Exchange
#: Default routing key.
routing_key = ''
@@ -56,9 +59,16 @@ class Producer:
#: default_channel).
__connection__ = None
- def __init__(self, channel, exchange=None, routing_key=None,
- serializer=None, auto_declare=None, compression=None,
- on_return=None):
+ _channel: ChannelT
+
+ def __init__(self,
+ channel: ChannelArgT,
+ exchange: Exchange = None,
+ routing_key: str = None,
+ serializer: str = None,
+ auto_declare: bool = None,
+ compression: str = None,
+ on_return: Callable = None):
self._channel = channel
self.exchange = exchange
self.routing_key = routing_key or self.routing_key
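
For reference, constructing a producer with the annotated signature is unchanged from the caller's side. A minimal sketch (the broker URL, exchange name and 'json' serializer below are illustrative assumptions, not part of this patch):

    from kombu import Connection, Exchange, Producer

    conn = Connection('amqp://guest:guest@localhost//')   # assumed broker URL
    media = Exchange('media', type='direct')               # assumed exchange

    producer = Producer(
        conn.default_channel,   # or the Connection itself; see revive() below
        exchange=media,
        routing_key='video',
        serializer='json',
    )

Note that the constructor no longer calls revive(); binding the exchange to a live channel is deferred until the channel is resolved (see _resolve_channel() below) or until revive() is awaited explicitly.
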
@@ -71,20 +81,17 @@ class Producer:
if auto_declare is not None:
self.auto_declare = auto_declare
- if self._channel:
- self.revive(self._channel)
-
- def __repr__(self):
+ def __repr__(self) -> str:
return '<Producer: {0._channel}>'.format(self)
- def __reduce__(self):
+ def __reduce__(self) -> Tuple[Any, ...]:
return self.__class__, self.__reduce_args__()
- def __reduce_args__(self):
+ def __reduce_args__(self) -> Tuple[Any, ...]:
return (None, self.exchange, self.routing_key, self.serializer,
self.auto_declare, self.compression)
- def declare(self):
+ async def declare(self) -> None:
"""Declare the exchange.
Note:
@@ -92,16 +99,18 @@ class Producer:
the :attr:`auto_declare` flag is enabled.
"""
if self.exchange.name:
- self.exchange.declare()
+ await self.exchange.declare()
- def maybe_declare(self, entity, retry=False, **retry_policy):
+ async def maybe_declare(self, entity: Entity,
+ retry: bool = False, **retry_policy) -> None:
"""Declare exchange if not already declared during this session."""
if entity:
- return maybe_declare(entity, self.channel, retry, **retry_policy)
+ await maybe_declare(entity, self.channel, retry, **retry_policy)
- def _delivery_details(self, exchange, delivery_mode=None,
- maybe_delivery_mode=maybe_delivery_mode,
- Exchange=Exchange):
+ def _delivery_details(self, exchange: Union[Exchange, str],
+ delivery_mode: Union[int, str]=None,
+ maybe_delivery_mode: Callable = maybe_delivery_mode,
+ Exchange: Any = Exchange) -> Tuple[str, int]:
if isinstance(exchange, Exchange):
return exchange.name, maybe_delivery_mode(
delivery_mode or exchange.delivery_mode,
@@ -112,12 +121,23 @@ class Producer:
delivery_mode or self.exchange.delivery_mode,
)
- def publish(self, body, routing_key=None, delivery_mode=None,
- mandatory=False, immediate=False, priority=0,
- content_type=None, content_encoding=None, serializer=None,
- headers=None, compression=None, exchange=None, retry=False,
- retry_policy=None, declare=None, expiration=None,
- **properties):
+ async def publish(self, body: Any,
+ routing_key: str = None,
+ delivery_mode: Union[int, str] = None,
+ mandatory: bool = False,
+ immediate: bool = False,
+ priority: int = 0,
+ content_type: str = None,
+ content_encoding: str = None,
+ serializer: str = None,
+ headers: Mapping = None,
+ compression: str = None,
+ exchange: Union[Exchange, str] = None,
+ retry: bool = False,
+ retry_policy: Mapping = None,
+ declare: Sequence[Entity] = None,
+ expiration: numbers.Number = None,
+ **properties) -> None:
"""Publish message to the specified exchange.
Arguments:
@@ -172,54 +192,79 @@ class Producer:
declare.append(self.exchange)
if retry:
- _publish = self.connection.ensure(self, _publish, **retry_policy)
- return _publish(
+ _publish = await self.connection.ensure(
+ self, _publish, **retry_policy)
+ await _publish(
body, priority, content_type, content_encoding,
headers, properties, routing_key, mandatory, immediate,
exchange_name, declare,
)
- def _publish(self, body, priority, content_type, content_encoding,
- headers, properties, routing_key, mandatory,
- immediate, exchange, declare):
- channel = self.channel
+ async def _publish(self,
+ body: Any,
+ priority: int = None,
+ content_type: str = None,
+ content_encoding: str = None,
+ headers: Mapping = None,
+ properties: Mapping = None,
+ routing_key: str = None,
+ mandatory: bool = None,
+ immediate: bool = None,
+ exchange: str = None,
+ declare: Sequence = None) -> None:
+ channel = await self._resolve_channel()
message = channel.prepare_message(
body, priority, content_type,
content_encoding, headers, properties,
)
if declare:
maybe_declare = self.maybe_declare
- [maybe_declare(entity) for entity in declare]
+ for entity in declare:
+ await maybe_declare(entity)
# handle autogenerated queue names for reply_to
reply_to = properties.get('reply_to')
if isinstance(reply_to, Queue):
properties['reply_to'] = reply_to.name
- return channel.basic_publish(
+ await channel.basic_publish(
message,
exchange=exchange, routing_key=routing_key,
mandatory=mandatory, immediate=immediate,
)
- def _get_channel(self):
+ async def _resolve_channel(self):
+ channel = self._channel
+ if isinstance(channel, ChannelPromise):
+ channel = self._channel = await channel.resolve()
+ if self.exchange:
+ self.exchange.revive(channel)
+ if self.on_return:
+ channel.events['basic_return'].add(self.on_return)
+ else:
+ channel = maybe_channel(channel)
+ await channel.open()
+ return channel
+
+ def _get_channel(self) -> ChannelT:
channel = self._channel
if isinstance(channel, ChannelPromise):
channel = self._channel = channel()
- self.exchange.revive(channel)
+ if self.exchange:
+ self.exchange.revive(channel)
if self.on_return:
channel.events['basic_return'].add(self.on_return)
return channel
- def _set_channel(self, channel):
+ def _set_channel(self, channel: ChannelT) -> None:
self._channel = channel
channel = property(_get_channel, _set_channel)
- def revive(self, channel):
+ async def revive(self, channel: ChannelT) -> None:
"""Revive the producer after connection loss."""
if is_connection(channel):
connection = channel
self.__connection__ = connection
- channel = ChannelPromise(lambda: connection.default_channel)
+ channel = ChannelPromise(connection)
if isinstance(channel, ChannelPromise):
self._channel = channel
self.exchange = self.exchange(channel)
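
Taken together, revive() and _resolve_channel() let a producer be revived against a bare connection and defer channel creation to the first publish: revive() wraps the connection in a ChannelPromise, and _resolve_channel() awaits its resolution before preparing the message. A rough usage sketch, assuming the coroutine API from this patch and the conn/media objects from the earlier snippet:

    async def send_video_event(conn):
        producer = Producer(conn, exchange=media, routing_key='video')
        await producer.revive(conn)   # wraps conn in ChannelPromise(conn)

        # The first publish resolves the promise, rebinds the exchange to
        # the real channel and declares it (auto_declare is on by default).
        await producer.publish(
            {'title': 'cat.mp4', 'size': 1234},
            retry=True,
            retry_policy={'max_retries': 3, 'interval_start': 0,
                          'interval_step': 1},
        )
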
@@ -230,18 +275,29 @@ class Producer:
self._channel.events['basic_return'].add(self.on_return)
self.exchange = self.exchange(channel)
- def __enter__(self):
+ def __enter__(self) -> 'Producer':
return self
- def __exit__(self, *exc_info):
+ async def __aenter__(self) -> 'Producer':
+ await self.revive(self.channel)
+ return self
+
+ def __exit__(self, *exc_info) -> None:
self.release()
- def release(self):
- pass
+ async def __aexit__(self, *exc_info) -> None:
+ self.release()
+
+ def release(self) -> None:
+ ...
close = release
- def _prepare(self, body, serializer=None, content_type=None,
- content_encoding=None, compression=None, headers=None):
+ def _prepare(self, body: Any,
+ serializer: str = None,
+ content_type: str = None,
+ content_encoding: str = None,
+ compression: str = None,
+ headers: Mapping = None) -> Tuple[Any, str, str]:
# No content_type? Then we're serializing the data internally.
if not content_type:
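
__aenter__ awaits revive(), so the async context manager is the natural way to obtain a revived producer; __aexit__ still only calls the no-op release(), kept for API symmetry. A minimal sketch, with the same assumed objects as above:

    async def send_one(conn):
        async with Producer(conn, exchange=media, serializer='json') as producer:
            await producer.publish({'hello': 'world'}, routing_key='video')
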
@@ -267,15 +323,14 @@ class Producer:
return body, content_type, content_encoding
@property
- def connection(self):
+ def connection(self) -> ClientT:
try:
return self.__connection__ or self.channel.connection.client
except AttributeError:
pass
-@abstract.Consumer.register
-class Consumer:
+class Consumer(types.ConsumerT):
"""Message consumer.
Arguments:
@@ -358,13 +413,23 @@ class Consumer:
prefetch_count = None
#: Mapping of queues we consume from.
- _queues = None
+ _queues: Mapping[str, QueueT] = None
_tags = count(1) # global
-
- def __init__(self, channel, queues=None, no_ack=None, auto_declare=None,
- callbacks=None, on_decode_error=None, on_message=None,
- accept=None, prefetch_count=None, tag_prefix=None):
+ _active_tags: Mapping[str, str]
+
+ def __init__(
+ self,
+ channel: ChannelT,
+ queues: Sequence[QueueT] = None,
+ no_ack: bool = None,
+ auto_declare: bool = None,
+ callbacks: Sequence[Callable] = None,
+ on_decode_error: Callable = None,
+ on_message: Callable = None,
+ accept: Sequence[str] = None,
+ prefetch_count: int = None,
+ tag_prefix: str = None) -> None:
self.channel = channel
self.queues = maybe_list(queues or [])
self.no_ack = self.no_ack if no_ack is None else no_ack
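
As with Producer, the Consumer constructor no longer revives the channel eagerly: queue declaration and the prefetch_count QoS setup are deferred to revive(), which the async context manager awaits. A minimal construction sketch (queue and exchange names are illustrative):

    from kombu import Connection, Exchange, Queue, Consumer

    conn = Connection('amqp://guest:guest@localhost//')
    media = Exchange('media', type='direct')
    video_q = Queue('video', exchange=media, routing_key='video')

    consumer = Consumer(
        conn.default_channel,
        queues=[video_q],
        accept=['json'],
        prefetch_count=10,
    )
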
@@ -380,21 +445,19 @@ class Consumer:
self.accept = prepare_accept_content(accept)
self.prefetch_count = prefetch_count
- if self.channel:
- self.revive(self.channel)
-
@property
- def queues(self):
+ def queues(self) -> Sequence[QueueT]:
return list(self._queues.values())
@queues.setter
- def queues(self, queues):
+ def queues(self, queues: Sequence[QueueT]) -> None:
self._queues = {q.name: q for q in queues}
- def revive(self, channel):
+ async def revive(self, channel: ChannelT) -> None:
"""Revive consumer after connection loss."""
self._active_tags.clear()
channel = self.channel = maybe_channel(channel)
+ await channel.open()
for qname, queue in self._queues.items():
# name may have changed after declare
self._queues.pop(qname, None)
@@ -402,22 +465,21 @@ class Consumer:
queue.revive(channel)
if self.auto_declare:
- self.declare()
+ await self.declare()
if self.prefetch_count is not None:
- self.qos(prefetch_count=self.prefetch_count)
+ await self.qos(prefetch_count=self.prefetch_count)
- def declare(self):
+ async def declare(self) -> None:
"""Declare queues, exchanges and bindings.
Note:
This is done automatically at instantiation
when :attr:`auto_declare` is set.
"""
- for queue in self._queues.values():
- queue.declare()
+ [await queue.declare() for queue in self._queues.values()]
- def register_callback(self, callback):
+ def register_callback(self, callback: Callable) -> None:
"""Register a new callback to be called when a message is received.
Note:
@@ -427,11 +489,26 @@ class Consumer:
"""
self.callbacks.append(callback)
- def __enter__(self):
+ def __enter__(self) -> 'Consumer':
+ self.revive(self.channel)
self.consume()
return self
- def __exit__(self, exc_type, exc_val, exc_tb):
+ async def __aenter__(self) -> 'Consumer':
+ await self.revive(self.channel)
+ await self.consume()
+ return self
+
+ async def __aexit__(self, *exc_info) -> None:
+ if self.channel:
+ conn_errors = self.channel.connection.client.connection_errors
+ if not isinstance(exc_info[1], conn_errors):
+ try:
+ await self.cancel()
+ except Exception:
+ pass
+
+ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
if self.channel:
conn_errors = self.channel.connection.client.connection_errors
if not isinstance(exc_val, conn_errors):
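
The async context manager revives the consumer and starts consuming in one step, and since receive() now awaits each callback, handlers are expected to be coroutine functions. (The synchronous __enter__, by contrast, calls the now-async revive() without awaiting it, so it appears to rely on the channel already being usable.) A usage sketch with the objects assumed above; driving the transport's event loop is outside this file and elided here:

    async def consume_video(conn):
        async def on_body(body, message):
            print('received: %r' % (body,))
            message.ack()   # acknowledging is not touched by this patch

        async with Consumer(conn.default_channel,
                            queues=[video_q],
                            callbacks=[on_body],
                            accept=['json'],
                            prefetch_count=10):
            ...  # drive the connection's event loop until done
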
@@ -440,7 +517,7 @@ class Consumer:
except Exception:
pass
- def add_queue(self, queue):
+ async def add_queue(self, queue: QueueT) -> QueueT:
"""Add a queue to the list of queues to consume from.
Note:
@@ -449,11 +526,11 @@ class Consumer:
"""
queue = queue(self.channel)
if self.auto_declare:
- queue.declare()
+ await queue.declare()
self._queues[queue.name] = queue
return queue
- def consume(self, no_ack=None):
+ async def consume(self, no_ack: bool = None) -> None:
"""Start consuming messages.
Can be called multiple times, but note that while it
@@ -470,10 +547,10 @@ class Consumer:
H, T = queues[:-1], queues[-1]
for queue in H:
- self._basic_consume(queue, no_ack=no_ack, nowait=True)
- self._basic_consume(T, no_ack=no_ack, nowait=False)
+ await self._basic_consume(queue, no_ack=no_ack, nowait=True)
+ await self._basic_consume(T, no_ack=no_ack, nowait=False)
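
consume() can still be called repeatedly, so adding a queue at runtime and picking it up is straightforward (a sketch, assuming the coroutine API above):

    async def also_consume(consumer, queue):
        await consumer.add_queue(queue)   # declares it when auto_declare is set
        await consumer.consume()          # starts a consumer for the new queue
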
- def cancel(self):
+ async def cancel(self) -> None:
"""End all active queue consumers.
Note:
@@ -482,38 +559,36 @@ class Consumer:
"""
cancel = self.channel.basic_cancel
for tag in self._active_tags.values():
- cancel(tag)
+ await cancel(tag)
self._active_tags.clear()
close = cancel
- def cancel_by_queue(self, queue):
+ async def cancel_by_queue(self, queue: Union[QueueT, str]) -> None:
"""Cancel consumer by queue name."""
- qname = queue.name if isinstance(queue, Queue) else queue
+ qname = queue.name if isinstance(queue, QueueT) else queue
try:
tag = self._active_tags.pop(qname)
except KeyError:
pass
else:
- self.channel.basic_cancel(tag)
+ await self.channel.basic_cancel(tag)
finally:
self._queues.pop(qname, None)
- def consuming_from(self, queue):
+ def consuming_from(self, queue: Union[QueueT, str]) -> bool:
"""Return :const:`True` if currently consuming from queue'."""
- name = queue
- if isinstance(queue, Queue):
- name = queue.name
+ name = queue.name if isinstance(queue, QueueT) else queue
return name in self._active_tags
- def purge(self):
+ async def purge(self) -> int:
"""Purge messages from all queues.
Warning:
This will *delete all ready messages*, there is no undo operation.
"""
- return sum(queue.purge() for queue in self._queues.values())
+ return sum(await queue.purge() for queue in self._queues.values())
- def flow(self, active):
+ async def flow(self, active: bool) -> None:
"""Enable/disable flow from peer.
This is a simple flow-control mechanism that a peer can use
@@ -524,9 +599,12 @@ class Consumer:
will finish sending the current content (if any), and then wait
until flow is reactivated.
"""
- self.channel.flow(active)
+ await self.channel.flow(active)
- def qos(self, prefetch_size=0, prefetch_count=0, apply_global=False):
+ async def qos(self,
+ prefetch_size: int = 0,
+ prefetch_count: int = 0,
+ apply_global: bool = False) -> None:
"""Specify quality of service.
The client can request that messages should be sent in
@@ -550,11 +628,10 @@ class Consumer:
apply_global (bool): Apply new settings globally on all channels.
"""
- return self.channel.basic_qos(prefetch_size,
- prefetch_count,
- apply_global)
+ await self.channel.basic_qos(
+ prefetch_size, prefetch_count, apply_global)
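
With basic_qos awaited, adjusting prefetch at runtime becomes (sketch; consumer as in the earlier snippets):

    async def throttle(consumer):
        # Ask the broker to keep at most 10 unacknowledged messages in flight.
        await consumer.qos(prefetch_count=10)
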
- def recover(self, requeue=False):
+ async def recover(self, requeue: bool = False) -> None:
"""Redeliver unacknowledged messages.
Asks the broker to redeliver all unacknowledged messages
@@ -566,9 +643,9 @@ class Consumer:
server will attempt to requeue the message, potentially then
delivering it to an alternative subscriber.
"""
- return self.channel.basic_recover(requeue=requeue)
+ await self.channel.basic_recover(requeue=requeue)
- def receive(self, body, message):
+ async def receive(self, body: Any, message: MessageT) -> None:
"""Method called when a message is received.
This dispatches to the registered :attr:`callbacks`.
@@ -584,24 +661,27 @@ class Consumer:
callbacks = self.callbacks
if not callbacks:
raise NotImplementedError('Consumer does not have any callbacks')
- [callback(body, message) for callback in callbacks]
+ for callback in callbacks:
+ await callback(body, message)
- def _basic_consume(self, queue, consumer_tag=None,
- no_ack=no_ack, nowait=True):
+ async def _basic_consume(self, queue: QueueT,
+ consumer_tag: str = None,
+ no_ack: bool = no_ack,
+ nowait: bool = True) -> str:
tag = self._active_tags.get(queue.name)
if tag is None:
tag = self._add_tag(queue, consumer_tag)
- queue.consume(tag, self._receive_callback,
- no_ack=no_ack, nowait=nowait)
+ await queue.consume(tag, self._receive_callback,
+ no_ack=no_ack, nowait=nowait)
return tag
- def _add_tag(self, queue, consumer_tag=None):
+ def _add_tag(self, queue: QueueT, consumer_tag: str = None) -> str:
tag = consumer_tag or '{0}{1}'.format(
self.tag_prefix, next(self._tags))
self._active_tags[queue.name] = tag
return tag
- def _receive_callback(self, message):
+ async def _receive_callback(self, message: MessageT) -> None:
accept = self.accept
on_m, channel, decoded = self.on_message, self.channel, None
try:
@@ -611,20 +691,20 @@ class Consumer:
if accept is not None:
message.accept = accept
if message.errors:
- return message._reraise_error(self.on_decode_error)
+ message._reraise_error(self.on_decode_error)
decoded = None if on_m else message.decode()
except Exception as exc:
if not self.on_decode_error:
raise
self.on_decode_error(message, exc)
else:
- return on_m(message) if on_m else self.receive(decoded, message)
+ await on_m(message) if on_m else self.receive(decoded, message)
- def __repr__(self):
+ def __repr__(self) -> str:
return '<{name}: {0.queues}>'.format(self, name=type(self).__name__)
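
When on_message is set, _receive_callback skips decoding and awaits the handler with the raw message instead of dispatching to callbacks. (Note that, as written, the dispatch expression only awaits the on_message branch; the plain self.receive(...) call in the else branch appears to produce a coroutine that is never awaited.) A sketch of the on_message style, with the same assumed objects:

    async def on_message(message):
        payload = message.decode()   # decode explicitly when needed
        print('properties: %r body: %r' % (message.properties, payload))
        message.ack()

    consumer = Consumer(conn.default_channel,
                        queues=[video_q],
                        on_message=on_message,
                        accept=['json'])
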
@property
- def connection(self):
+ def connection(self) -> ClientT:
try:
return self.channel.connection.client
except AttributeError: