diff options
author | Ask Solem <ask@celeryproject.org> | 2017-02-16 10:44:48 -0800 |
---|---|---|
committer | Ask Solem <ask@celeryproject.org> | 2017-02-16 10:44:48 -0800 |
commit | e6fab2f68b562cf1400bd8167e9b755f0482aafe (patch) | |
tree | b4d6be32dc8c62fa032e3c1a1a74636ac8360a38 | |
parent | f2f7c67651106e77fb2db60ded134404ccc0a626 (diff) | |
download | kombu-5.0-devel.tar.gz |
WIP5.0-devel
-rw-r--r-- | Makefile | 3 | ||||
-rw-r--r-- | docs/reference/kombu.abstract.rst | 2 | ||||
-rw-r--r-- | kombu/__init__.py | 7 | ||||
-rw-r--r-- | kombu/abstract.py | 75 | ||||
-rw-r--r-- | kombu/async/timer.py | 10 | ||||
-rw-r--r-- | kombu/clocks.py | 93 | ||||
-rw-r--r-- | kombu/common.py | 93 | ||||
-rw-r--r-- | kombu/compression.py | 16 | ||||
-rw-r--r-- | kombu/connection.py | 213 | ||||
-rw-r--r-- | kombu/entity.py | 465 | ||||
-rw-r--r-- | kombu/message.py | 63 | ||||
-rw-r--r-- | kombu/messaging.py | 288 | ||||
-rw-r--r-- | kombu/pidbox.py | 191 | ||||
-rw-r--r-- | kombu/pools.py | 60 | ||||
-rw-r--r-- | kombu/resource.py | 55 | ||||
-rw-r--r-- | kombu/serialization.py | 10 | ||||
-rw-r--r-- | kombu/simple.py | 81 | ||||
-rw-r--r-- | kombu/transport/base.py | 101 | ||||
-rw-r--r-- | kombu/transport/pyamqp.py | 16 | ||||
-rw-r--r-- | kombu/transport/virtual/__init__.py | 2 | ||||
-rw-r--r-- | kombu/transport/virtual/base.py | 459 | ||||
-rw-r--r-- | kombu/transport/virtual/exchange.py | 54 | ||||
-rw-r--r-- | kombu/types.py | 893 | ||||
-rw-r--r-- | kombu/utils/abstract.py | 49 | ||||
-rw-r--r-- | kombu/utils/amq_manager.py | 4 | ||||
-rw-r--r-- | kombu/utils/collections.py | 20 | ||||
-rw-r--r-- | kombu/utils/compat.py | 17 | ||||
-rw-r--r-- | kombu/utils/debug.py | 9 | ||||
-rw-r--r-- | kombu/utils/div.py | 8 | ||||
-rw-r--r-- | kombu/utils/encoding.py | 7 | ||||
-rw-r--r-- | kombu/utils/eventio.py | 10 | ||||
-rw-r--r-- | kombu/utils/functional.py | 102 | ||||
-rw-r--r-- | kombu/utils/imports.py | 12 | ||||
-rw-r--r-- | kombu/utils/json.py | 14 | ||||
-rw-r--r-- | kombu/utils/limits.py | 17 | ||||
-rw-r--r-- | kombu/utils/scheduling.py | 11 | ||||
-rw-r--r-- | kombu/utils/text.py | 21 | ||||
-rw-r--r-- | kombu/utils/time.py | 6 | ||||
-rw-r--r-- | kombu/utils/typing.py | 21 | ||||
-rw-r--r-- | kombu/utils/url.py | 30 | ||||
-rw-r--r-- | kombu/utils/uuid.py | 3 | ||||
-rw-r--r-- | moved/transport/SLMQ.py (renamed from kombu/transport/SLMQ.py) | 0 | ||||
-rw-r--r-- | moved/transport/SQS.py (renamed from kombu/transport/SQS.py) | 0 | ||||
-rw-r--r-- | moved/transport/consul.py (renamed from kombu/transport/consul.py) | 0 | ||||
-rw-r--r-- | moved/transport/etcd.py (renamed from kombu/transport/etcd.py) | 8 | ||||
-rw-r--r-- | moved/transport/filesystem.py (renamed from kombu/transport/filesystem.py) | 0 | ||||
-rw-r--r-- | moved/transport/librabbitmq.py (renamed from kombu/transport/librabbitmq.py) | 0 | ||||
-rw-r--r-- | moved/transport/memory.py (renamed from kombu/transport/memory.py) | 0 | ||||
-rw-r--r-- | moved/transport/mongodb.py (renamed from kombu/transport/mongodb.py) | 0 | ||||
-rw-r--r-- | moved/transport/pyro.py (renamed from kombu/transport/pyro.py) | 0 | ||||
-rw-r--r-- | moved/transport/qpid.py (renamed from kombu/transport/qpid.py) | 7 | ||||
-rw-r--r-- | moved/transport/redis.py (renamed from kombu/transport/redis.py) | 18 | ||||
-rw-r--r-- | moved/transport/zookeeper.py (renamed from kombu/transport/zookeeper.py) | 0 | ||||
-rw-r--r-- | requirements/test.txt | 1 | ||||
-rw-r--r-- | t/integration/test_async.py | 51 | ||||
-rw-r--r-- | t/oldint/__init__.py (renamed from t/integration/__init__.py) | 0 | ||||
-rw-r--r-- | t/oldint/tests/__init__.py (renamed from t/integration/tests/__init__.py) | 0 | ||||
-rw-r--r-- | t/oldint/tests/test_SLMQ.py (renamed from t/integration/tests/test_SLMQ.py) | 0 | ||||
-rw-r--r-- | t/oldint/tests/test_SQS.py (renamed from t/integration/tests/test_SQS.py) | 0 | ||||
-rw-r--r-- | t/oldint/tests/test_amqp.py (renamed from t/integration/tests/test_amqp.py) | 0 | ||||
-rw-r--r-- | t/oldint/tests/test_librabbitmq.py (renamed from t/integration/tests/test_librabbitmq.py) | 0 | ||||
-rw-r--r-- | t/oldint/tests/test_mongodb.py (renamed from t/integration/tests/test_mongodb.py) | 0 | ||||
-rw-r--r-- | t/oldint/tests/test_pyamqp.py (renamed from t/integration/tests/test_pyamqp.py) | 0 | ||||
-rw-r--r-- | t/oldint/tests/test_qpid.py (renamed from t/integration/tests/test_qpid.py) | 0 | ||||
-rw-r--r-- | t/oldint/tests/test_redis.py (renamed from t/integration/tests/test_redis.py) | 0 | ||||
-rw-r--r-- | t/oldint/tests/test_zookeeper.py (renamed from t/integration/tests/test_zookeeper.py) | 0 | ||||
-rw-r--r-- | t/oldint/transport.py (renamed from t/integration/transport.py) | 0 | ||||
-rw-r--r-- | t/unit/__init__.py | 1 | ||||
-rw-r--r-- | t/unit/test_clocks.py | 5 | ||||
-rw-r--r-- | t/unit/test_entity.py | 6 | ||||
-rw-r--r-- | t/unit/test_exceptions.py | 3 | ||||
-rw-r--r-- | t/unit/transport/test_base.py | 1 | ||||
-rw-r--r-- | t/unit/transport/test_etcd.py | 7 | ||||
-rw-r--r-- | t/unit/transport/test_qpid.py | 21 | ||||
-rw-r--r-- | t/unit/transport/virtual/test_exchange.py | 5 | ||||
-rw-r--r-- | t/unit/utils/test_compat.py | 13 | ||||
-rw-r--r-- | t/unit/utils/test_div.py | 4 | ||||
-rw-r--r-- | t/unit/utils/test_encoding.py | 12 | ||||
-rw-r--r-- | t/unit/utils/test_scheduling.py | 4 | ||||
-rw-r--r-- | t/unit/utils/test_time.py | 3 | ||||
-rw-r--r-- | t/unit/utils/test_url.py | 3 | ||||
-rw-r--r-- | t/unit/utils/test_utils.py | 3 |
82 files changed, 2571 insertions, 1216 deletions
@@ -151,3 +151,6 @@ build: distcheck: lint test clean dist: readme contrib clean-dist build + +typecheck: + $(PYTHON) -m mypy --fast-parser --python-version=3.6 --ignore-missing-imports $(PROJ) diff --git a/docs/reference/kombu.abstract.rst b/docs/reference/kombu.abstract.rst index 436bc535..c51f5b6a 100644 --- a/docs/reference/kombu.abstract.rst +++ b/docs/reference/kombu.abstract.rst @@ -9,6 +9,6 @@ .. contents:: :local: - .. autoclass:: MaybeChannelBound + .. autoclass:: Entity :members: :undoc-members: diff --git a/kombu/__init__.py b/kombu/__init__.py index 848d2b99..6e0211e1 100644 --- a/kombu/__init__.py +++ b/kombu/__init__.py @@ -71,7 +71,7 @@ for module, items in all_by_module.items(): object_origins[item] = module -class module(ModuleType): +class _module(ModuleType): """Customized Python module.""" def __getattr__(self, name): @@ -100,7 +100,7 @@ except NameError: # pragma: no cover # keep a reference to this module so that it's not garbage collected old_module = sys.modules[__name__] -new_module = sys.modules[__name__] = module(__name__) +new_module = sys.modules[__name__] = _module(__name__) # type: ignore new_module.__dict__.update({ '__file__': __file__, '__path__': __path__, @@ -118,6 +118,7 @@ new_module.__dict__.update({ }) if os.environ.get('KOMBU_LOG_DEBUG'): # pragma: no cover - os.environ.update(KOMBU_LOG_CHANNEL='1', KOMBU_LOG_CONNECTION='1') + os.environ['KOMBU_LOG_CHANNEL'] = '1' + os.environ['KOMBU_LOG_CONNECTION'] = '1' from .utils import debug debug.setup_logging() diff --git a/kombu/abstract.py b/kombu/abstract.py index 916b1eea..4f824415 100644 --- a/kombu/abstract.py +++ b/kombu/abstract.py @@ -1,14 +1,13 @@ -import amqp.abstract - from copy import copy -from typing import Any, Dict +from typing import Any, Dict, Optional from typing import Sequence, Tuple # noqa - +from amqp.types import ChannelT from .connection import maybe_channel from .exceptions import NotBoundError +from .types import EntityT from .utils.functional import 
ChannelPromise -__all__ = ['Object', 'MaybeChannelBound'] +__all__ = ['Entity'] def unpickle_dict(cls: Any, kwargs: Dict) -> Any: @@ -19,13 +18,16 @@ def _any(v: Any) -> Any: return v -class Object: - """Common base class. - - Supports automatic kwargs->attributes handling, and cloning. - """ +class Entity(EntityT): + """Mixin for classes that can be bound to an AMQP channel.""" attrs = () # type: Sequence[Tuple[str, Any]] + #: Defines whether maybe_declare can skip declaring this entity twice. + can_cache_declaration = False + + _channel: ChannelT + _is_bound: bool = False + def __init__(self, *args, **kwargs) -> None: for name, type_ in self.attrs: value = kwargs.get(name) @@ -37,41 +39,15 @@ class Object: except AttributeError: setattr(self, name, None) - def as_dict(self, recurse: bool=False) -> Dict: - def f(obj: Any, type: Any) -> Any: - if recurse and isinstance(obj, Object): - return obj.as_dict(recurse=True) - return type(obj) if type and obj is not None else obj - return { - attr: f(getattr(self, attr), type) for attr, type in self.attrs - } - - def __reduce__(self) -> Any: - return unpickle_dict, (self.__class__, self.as_dict()) - - def __copy__(self) -> Any: - return self.__class__(**self.as_dict()) - - -class MaybeChannelBound(Object): - """Mixin for classes that can be bound to an AMQP channel.""" - - _channel = None # type: amqp.abstract.Channel - _is_bound = False # type: bool - - #: Defines whether maybe_declare can skip declaring this entity twice. 
- can_cache_declaration = False - - def __call__(self, channel: amqp.abstract.Channel) -> 'MaybeChannelBound': + def __call__(self, channel: ChannelT) -> EntityT: """`self(channel) -> self.bind(channel)`""" return self.bind(channel) - def bind(self, channel: amqp.abstract.Channel) -> 'MaybeChannelBound': + def bind(self, channel: ChannelT) -> EntityT: """Create copy of the instance that is bound to a channel.""" return copy(self).maybe_bind(channel) - def maybe_bind(self, - channel: amqp.abstract.Channel) -> 'MaybeChannelBound': + def maybe_bind(self, channel: Optional[ChannelT]) -> EntityT: """Bind instance to channel if not already bound.""" if not self.is_bound and channel: self._channel = maybe_channel(channel) @@ -79,7 +55,7 @@ class MaybeChannelBound(Object): self._is_bound = True return self - def revive(self, channel: amqp.abstract.Channel) -> None: + def revive(self, channel: ChannelT) -> None: """Revive channel after the connection has been re-established. Used by :meth:`~kombu.Connection.ensure`. 
@@ -96,20 +72,35 @@ class MaybeChannelBound(Object): def __repr__(self) -> str: return self._repr_entity(type(self).__name__) - def _repr_entity(self, item: str='') -> str: + def _repr_entity(self, item: str = '') -> str: item = item or type(self).__name__ if self.is_bound: return '<{0} bound to chan:{1}>'.format( item or type(self).__name__, self.channel.channel_id) return '<unbound {0}>'.format(item) + def as_dict(self, recurse: bool = False) -> Dict: + def f(obj: Any, type: Any) -> Any: + if recurse and isinstance(obj, Entity): + return obj.as_dict(recurse=True) + return type(obj) if type and obj is not None else obj + return { + attr: f(getattr(self, attr), type) for attr, type in self.attrs + } + + def __reduce__(self) -> Any: + return unpickle_dict, (self.__class__, self.as_dict()) + + def __copy__(self) -> Any: + return self.__class__(**self.as_dict()) + @property def is_bound(self) -> bool: """Flag set if the channel is bound.""" return self._is_bound and self._channel is not None @property - def channel(self) -> amqp.abstract.Channel: + def channel(self) -> ChannelT: """Current channel if the object is bound.""" channel = self._channel if channel is None: diff --git a/kombu/async/timer.py b/kombu/async/timer.py index 69d93a36..b67f4ddf 100644 --- a/kombu/async/timer.py +++ b/kombu/async/timer.py @@ -3,10 +3,10 @@ import heapq import sys -from collections import namedtuple from datetime import datetime from functools import total_ordering from time import monotonic +from typing import NamedTuple from weakref import proxy as weakrefproxy from vine.utils import wraps @@ -27,7 +27,13 @@ DEFAULT_MAX_INTERVAL = 2 EPOCH = datetime.utcfromtimestamp(0).replace(tzinfo=utc) IS_PYPY = hasattr(sys, 'pypy_version_info') -scheduled = namedtuple('scheduled', ('eta', 'priority', 'entry')) + +class scheduled(NamedTuple): + """Information about scheduled item.""" + + eta: float + priority: int + entry: 'Entry' def to_timestamp(d, default_timezone=utc, time=monotonic): diff 
--git a/kombu/clocks.py b/kombu/clocks.py index cef4570b..ff82ffb0 100644 --- a/kombu/clocks.py +++ b/kombu/clocks.py @@ -2,9 +2,9 @@ from threading import Lock from itertools import islice from operator import itemgetter -from typing import Any, List, Sequence +from typing import Any, Callable, List, Sequence -__all__ = ['LamportClock', 'timetuple'] +__all__ = ['Clock', 'LamportClock', 'timetuple'] R_CLOCK = '_lamport(clock={0}, timestamp={1}, id={2} {3!r})' @@ -24,7 +24,7 @@ class timetuple(tuple): __slots__ = () def __new__(cls, clock: int, timestamp: float, - id: str, obj: Any=None) -> 'timetuple': + id: str, obj: Any = None) -> 'timetuple': return tuple.__new__(cls, (clock, timestamp, id, obj)) def __repr__(self) -> str: @@ -61,7 +61,54 @@ class timetuple(tuple): obj = property(itemgetter(3)) -class LamportClock: +class Clock: + + value: int = 0 + + def __init__(self, initial_value: int = 0, **kwargs) -> None: + self.value = initial_value + + def adjust(self, other: int) -> int: + raise NotImplementedError() + + def forward(self) -> int: + raise NotImplementedError() + + def __str__(self) -> str: + return str(self.value) + + def __repr__(self) -> str: + return '<{name}: {0.value}>'.format(self, name=type(self).__name__) + + def sort_heap(self, h: List[Sequence]) -> Any: + """Sort heap of events. + + List of tuples containing at least two elements, representing + an event, where the first element is the event's scalar clock value, + and the second element is the id of the process (usually + ``"hostname:pid"``): ``sh([(clock, processid, ...?), (...)])`` + + The list must already be sorted, which is why we refer to it as a + heap. + + The tuple will not be unpacked, so more than two elements can be + present. + + Will return the latest event. 
+ """ + if h[0][0] == h[1][0]: + same = [] + for PN in zip(h, islice(h, 1, None)): + if PN[0][0] != PN[1][0]: + break # Prev and Next's clocks differ + same.append(PN[0]) + # return first item sorted by process id + return sorted(same, key=lambda event: event[1])[0] + # clock values unique, return first item + return h[0] + + +class LamportClock(Clock): """Lamport's logical clock. From Wikipedia: @@ -100,9 +147,10 @@ class LamportClock: #: The clocks current value. value = 0 - def __init__(self, initial_value: int=0, Lock: Any=Lock) -> None: - self.value = initial_value + def __init__(self, initial_value: int = 0, + Lock: Callable = Lock, **kwargs) -> None: self.mutex = Lock() + super().__init__(initial_value) def adjust(self, other: int) -> int: with self.mutex: @@ -113,36 +161,3 @@ class LamportClock: with self.mutex: self.value += 1 return self.value - - def sort_heap(self, h: List[Sequence]) -> Any: - """Sort heap of events. - - List of tuples containing at least two elements, representing - an event, where the first element is the event's scalar clock value, - and the second element is the id of the process (usually - ``"hostname:pid"``): ``sh([(clock, processid, ...?), (...)])`` - - The list must already be sorted, which is why we refer to it as a - heap. - - The tuple will not be unpacked, so more than two elements can be - present. - - Will return the latest event. 
- """ - if h[0][0] == h[1][0]: - same = [] - for PN in zip(h, islice(h, 1, None)): - if PN[0][0] != PN[1][0]: - break # Prev and Next's clocks differ - same.append(PN[0]) - # return first item sorted by process id - return sorted(same, key=lambda event: event[1])[0] - # clock values unique, return first item - return h[0] - - def __str__(self) -> str: - return str(self.value) - - def __repr__(self) -> str: - return '<LamportClock: {0.value}>'.format(self) diff --git a/kombu/common.py b/kombu/common.py index 0b1c8bbb..85134482 100644 --- a/kombu/common.py +++ b/kombu/common.py @@ -1,5 +1,4 @@ """Common Utilities.""" -import amqp.abstract import os import socket import threading @@ -15,13 +14,15 @@ from typing import ( ) from uuid import uuid4, uuid3, NAMESPACE_OID -from amqp import RecoverableConnectionError +from amqp import ChannelT, RecoverableConnectionError from .entity import Exchange, Queue from .log import get_logger from .serialization import registry as serializers +from .types import ( + ClientT, ConsumerT, EntityT, MessageT, ProducerT, ResourceT, +) from .utils import abstract -from .utils.typing import Timeout from .utils.uuid import uuid try: @@ -89,10 +90,10 @@ class Broadcast(Queue): attrs = Queue.attrs + (('queue', None),) - def __init__(self, name: Optional[str]=None, - queue: Optional[abstract.Entity]=None, - auto_delete: bool=True, - exchange: Optional[Union[Exchange, str]]=None, + def __init__(self, name: str = None, + queue: EntityT = None, + auto_delete: bool = True, + exchange: Union[Exchange, str] = None, alias: Optional[str]=None, **kwargs) -> None: queue = queue or 'bcast.{0}'.format(uuid()) return super().__init__( @@ -106,15 +107,15 @@ class Broadcast(Queue): ) -def declaration_cached(entity: abstract.Entity, - channel: amqp.abstract.Channel) -> bool: +def declaration_cached(entity: EntityT, + channel: ChannelT) -> bool: return entity in channel.connection.client.declared_entities -def maybe_declare(entity: abstract.Entity, - channel: 
amqp.abstract.Channel = None, - retry: bool = False, - **retry_policy) -> bool: +async def maybe_declare(entity: EntityT, + channel: ChannelT = None, + retry: bool = False, + **retry_policy) -> bool: """Declare entity (cached).""" is_bound = entity.is_bound orig = entity @@ -135,18 +136,18 @@ def maybe_declare(entity: abstract.Entity, return False if retry: - return _imaybe_declare(entity, declared, ident, - channel, orig, **retry_policy) - return _maybe_declare(entity, declared, ident, channel, orig) + return await _imaybe_declare(entity, declared, ident, + channel, orig, **retry_policy) + return await _maybe_declare(entity, declared, ident, channel, orig) -def _maybe_declare(entity: abstract.Entity, declared: MutableSet, ident: int, - channel: Optional[amqp.abstract.Channel], - orig: abstract.Entity = None) -> bool: +async def _maybe_declare(entity: EntityT, declared: MutableSet, ident: int, + channel: Optional[ChannelT], + orig: EntityT = None) -> bool: channel = channel or entity.channel if not channel.connection: raise RecoverableConnectionError('channel disconnected') - entity.declare(channel=channel) + await entity.declare(channel=channel) if declared is not None and ident: declared.add(ident) if orig is not None: @@ -154,22 +155,22 @@ def _maybe_declare(entity: abstract.Entity, declared: MutableSet, ident: int, return True -def _imaybe_declare(entity: abstract.Entity, - declared: MutableSet, - ident: int, - channel: Optional[amqp.abstract.Channel], - orig: Optional[abstract.Entity]=None, - **retry_policy) -> bool: - return entity.channel.connection.client.ensure( +async def _imaybe_declare(entity: EntityT, + declared: MutableSet, + ident: int, + channel: Optional[ChannelT], + orig: EntityT = None, + **retry_policy) -> bool: + return await entity.channel.connection.client.ensure( entity, _maybe_declare, **retry_policy)( entity, declared, ident, channel, orig) def drain_consumer( - consumer: abstract.Consumer, + consumer: ConsumerT, limit: int = 1, - timeout: 
Timeout = None, - callbacks: Sequence[Callable] = None) -> Iterator[abstract.Message]: + timeout: float = None, + callbacks: Sequence[Callable] = None) -> Iterator[MessageT]: """Drain messages from consumer instance.""" acc = deque() @@ -189,12 +190,12 @@ def drain_consumer( def itermessages( conn: abstract.Connection, - channel: Optional[amqp.abstract.Channel], - queue: abstract.Entity, + channel: Optional[ChannelT], + queue: EntityT, limit: int = 1, - timeout: Timeout = None, + timeout: float = None, callbacks: Sequence[Callable] = None, - **kwargs) -> Iterator[abstract.Message]: + **kwargs) -> Iterator[MessageT]: """Iterator over messages.""" return drain_consumer( conn.Consumer(queues=[queue], channel=channel, **kwargs), @@ -202,9 +203,9 @@ def itermessages( ) -def eventloop(conn: abstract.Connection, +def eventloop(conn: ClientT, limit: int = None, - timeout: Timeout = None, + timeout: float = None, ignore_timeouts: bool = False) -> Iterator[Any]: """Best practice generator wrapper around ``Connection.drain_events``. @@ -243,9 +244,9 @@ def eventloop(conn: abstract.Connection, raise -def send_reply(exchange: Union[abstract.Exchange, str], - req: abstract.Message, msg: Any, - producer: abstract.Producer = None, +def send_reply(exchange: Union[Exchange, str], + req: MessageT, msg: Any, + producer: ProducerT = None, retry: bool = False, retry_policy: Mapping = None, **props) -> None: """Send reply for request. 
@@ -271,9 +272,9 @@ def send_reply(exchange: Union[abstract.Exchange, str], def collect_replies( - conn: abstract.Connection, - channel: Optional[amqp.abstract.Channel], - queue: abstract.Entity, + conn: ClientT, + channel: Optional[ChannelT], + queue: EntityT, *args, **kwargs) -> Iterator[Any]: """Generator collecting replies from ``queue``.""" no_ack = kwargs.setdefault('no_ack', True) @@ -305,7 +306,7 @@ def _ignore_errors(conn) -> Iterator: pass -def ignore_errors(conn: abstract.Connection, +def ignore_errors(conn: ClientT, fun: Callable = None, *args, **kwargs) -> Any: """Ignore connection and channel errors. @@ -339,14 +340,14 @@ def ignore_errors(conn: abstract.Connection, return _ignore_errors(conn) -def revive_connection(connection: abstract.Connection, - channel: amqp.abstract.Channel, +def revive_connection(connection: ClientT, + channel: ChannelT, on_revive: Callable = None) -> None: if on_revive: on_revive(channel) -def insured(pool: abstract.Resource, +def insured(pool: ResourceT, fun: Callable, args: Sequence, kwargs: Dict, errback: Callable = None, diff --git a/kombu/compression.py b/kombu/compression.py index 8e5ad99a..cec0282c 100644 --- a/kombu/compression.py +++ b/kombu/compression.py @@ -11,15 +11,15 @@ __all__ = [ 'get_decoder', 'compress', 'decompress', ] -TEncoder = Callable[[bytes], bytes] -TDecoder = Callable[[bytes], bytes] +EncoderT = Callable[[bytes], bytes] +DecoderT = Callable[[bytes], bytes] -_aliases = {} # type: MutableMapping[str, str] -_encoders = {} # type: MutableMapping[str, TEncoder] -_decoders = {} # type: MutableMapping[str, TDecoder] +_aliases: MutableMapping[str, str] = {} +_encoders: MutableMapping[str, EncoderT] = {} +_decoders: MutableMapping[str, DecoderT] = {} -def register(encoder: TEncoder, decoder: TDecoder, content_type: str, +def register(encoder: EncoderT, decoder: DecoderT, content_type: str, aliases: Sequence[str]=[]) -> None: """Register new compression method. 
@@ -42,13 +42,13 @@ def encoders() -> Sequence[str]: return list(_encoders) -def get_encoder(t: str) -> Tuple[TEncoder, str]: +def get_encoder(t: str) -> Tuple[EncoderT, str]: """Get encoder by alias name.""" t = _aliases.get(t, t) return _encoders[t], t -def get_decoder(t: str) -> TDecoder: +def get_decoder(t: str) -> DecoderT: """Get decoder by alias name.""" return _decoders[_aliases.get(t, t)] diff --git a/kombu/connection.py b/kombu/connection.py index c4c6ff7d..11970ea1 100644 --- a/kombu/connection.py +++ b/kombu/connection.py @@ -7,6 +7,9 @@ from collections import OrderedDict from contextlib import contextmanager from itertools import count, cycle from operator import itemgetter +from typing import Any, Callable, Iterable, Mapping, Set, Sequence, Union + +from amqp.types import ChannelT, ConnectionT # jython breaks on relative import for .exceptions for some reason # (Issue #112) @@ -15,7 +18,7 @@ from kombu import exceptions from .log import get_logger from .resource import Resource from .transport import get_transport_cls, supports_librabbitmq -from .utils import abstract +from .types import ClientT, EntityT, ResourceT, TransportT from .utils.collections import HashedSeq from .utils.functional import dictfilter, lazy, retry_over_time, shufflecycle from .utils.objects import cached_property @@ -27,22 +30,21 @@ logger = get_logger(__name__) roundrobin_failover = cycle -resolve_aliases = { +resolve_aliases: Mapping[str, str] = { 'pyamqp': 'amqp', 'librabbitmq': 'amqp', } -failover_strategies = { +failover_strategies: Mapping[str, Callable] = { 'round-robin': roundrobin_failover, 'shuffle': shufflecycle, } -_log_connection = os.environ.get('KOMBU_LOG_CONNECTION', False) -_log_channel = os.environ.get('KOMBU_LOG_CHANNEL', False) +_log_connection = bool(os.environ.get('KOMBU_LOG_CONNECTION', False)) +_log_channel = bool(os.environ.get('KOMBU_LOG_CHANNEL', False)) -@abstract.Connection.register -class Connection: +class Connection(ClientT): """A connection to 
the broker. Example: @@ -106,48 +108,65 @@ class Connection: :keyword port: Default port if not provided in the URL. """ - port = None - virtual_host = '/' - connect_timeout = 5 + hostname: str + userid: str + password: str + ssl: Any + login_method: str + port: int = None + virtual_host: str = '/' + connect_timeout: float = 5.0 - _closed = None - _connection = None - _default_channel = None - _transport = None - _logger = False - uri_prefix = None + uri_prefix: str = None #: The cache of declared entities is per connection, #: in case the server loses data. - declared_entities = None + declared_entities: Set[EntityT] = None #: Iterator returning the next broker URL to try in the event #: of connection failure (initialized by :attr:`failover_strategy`). - cycle = None + cycle: Iterable = None #: Additional transport specific options, #: passed on to the transport instance. - transport_options = None + transport_options: Mapping = None #: Strategy used to select new hosts when reconnecting after connection #: failure. One of "round-robin", "shuffle" or any custom iterator #: constantly yielding new URLs to try. - failover_strategy = 'round-robin' + failover_strategy: str = 'round-robin' #: Heartbeat value, currently only supported by the py-amqp transport. 
- heartbeat = None + heartbeat: float = None - resolve_aliases = resolve_aliases - failover_strategies = failover_strategies + resolve_aliases: Mapping[str, str] = resolve_aliases + failover_strategies: Mapping[str, Callable] = failover_strategies - hostname = userid = password = ssl = login_method = None + _closed: bool = None + _connection: ConnectionT = None + _default_channel: ChannelT = None + _transport: TransportT = None + _logger: bool = False + _initial_params: Mapping - def __init__(self, hostname='localhost', userid=None, - password=None, virtual_host=None, port=None, insist=False, - ssl=False, transport=None, connect_timeout=5, - transport_options=None, login_method=None, uri_prefix=None, - heartbeat=0, failover_strategy='round-robin', - alternates=None, **kwargs): + def __init__( + self, + hostname: str = 'localhost', + userid: str = None, + password: str = None, + virtual_host: str = None, + port: int = None, + insist: bool = False, + ssl: Any = None, + transport: Union[type, str] = None, + connect_timeout: float = 5.0, + transport_options: Mapping = None, + login_method: str = None, + uri_prefix: str = None, + heartbeat: float = None, + failover_strategy: str = 'round-robin', + alternates: Sequence[str] = None, + **kwargs): alt = [] if alternates is None else alternates # have to spell the args out, just to get nice docstrings :( params = self._initial_params = { @@ -208,7 +227,7 @@ class Connection: self.declared_entities = set() - def switch(self, url): + def switch(self, url: str) -> None: """Switch connection parameters to use a new URL. 
Note: @@ -219,14 +238,15 @@ class Connection: self._closed = False self._init_params(**dict(self._initial_params, **parse_url(url))) - def maybe_switch_next(self): + def maybe_switch_next(self) -> None: """Switch to next URL given by the current failover strategy.""" if self.cycle: self.switch(next(self.cycle)) - def _init_params(self, hostname, userid, password, virtual_host, port, - insist, ssl, transport, connect_timeout, - login_method, heartbeat): + def _init_params(self, hostname: str, userid: str, password: str, + virtual_host: str, port: int, insist: bool, ssl: Any, + transport: Union[type, str], connect_timeout: float, + login_method: str, heartbeat: float) -> None: transport = transport or 'amqp' if transport == 'amqp' and supports_librabbitmq(): transport = 'librabbitmq' @@ -245,18 +265,20 @@ class Connection: def register_with_event_loop(self, loop): self.transport.register_with_event_loop(self.connection, loop) - def _debug(self, msg, *args, **kwargs): + def _debug(self, msg: str, *args, **kwargs) -> None: if self._logger: # pragma: no cover fmt = '[Kombu connection:{id:#x}] {msg}' logger.debug(fmt.format(id=id(self), msg=str(msg)), *args, **kwargs) - def connect(self): + async def connect(self) -> None: """Establish connection to server immediately.""" self._closed = False - return self.connection + if self._connection is None: + await self._start_connection() + return self._connection - def channel(self): + def channel(self) -> ChannelT: """Create and return a new channel.""" self._debug('create channel') chan = self.transport.create_channel(self.connection) @@ -266,7 +288,7 @@ class Connection: '[Kombu channel:{0.channel_id}] ') return chan - def heartbeat_check(self, rate=2): + async def heartbeat_check(self, rate: int = 2) -> None: """Check heartbeats. Allow the transport to perform any periodic tasks @@ -283,9 +305,9 @@ class Connection: is called every 3 / 2 seconds, then the rate is 2. This value is currently unused by any transports. 
""" - return self.transport.heartbeat_check(self.connection, rate=rate) + await self.transport.heartbeat_check(self.connection, rate=rate) - def drain_events(self, **kwargs): + async def drain_events(self, timeout: float = None, **kwargs) -> None: """Wait for a single event from the server. Arguments: @@ -294,30 +316,31 @@ class Connection: Raises: socket.timeout: if the timeout is exceeded. """ - return self.transport.drain_events(self.connection, **kwargs) + await self.transport.drain_events( + self.connection, timeout=timeout, **kwargs) - def maybe_close_channel(self, channel): + async def maybe_close_channel(self, channel: ChannelT) -> None: """Close given channel, but ignore connection and channel errors.""" try: - channel.close() + await channel.close() except (self.connection_errors + self.channel_errors): pass - def _do_close_self(self): + async def _do_close_self(self): # Close only connection and channel(s), but not transport. self.declared_entities.clear() if self._default_channel: - self.maybe_close_channel(self._default_channel) + await self.maybe_close_channel(self._default_channel) if self._connection: try: - self.transport.close_connection(self._connection) + await self.transport.close_connection(self._connection) except self.connection_errors + (AttributeError, socket.error): pass self._connection = None - def _close(self): + async def _close(self): """Really close connection, even if part of a connection pool.""" - self._do_close_self() + await self._do_close_self() self._do_close_transport() self._debug('closed') self._closed = True @@ -327,7 +350,7 @@ class Connection: self._transport.client = None self._transport = None - def collect(self, socket_timeout=None): + async def collect(self, socket_timeout=None): # amqp requires communication to close, we don't need that just # to clear out references, Transport._collect can also be implemented # by other transports that want fast after fork @@ -337,7 +360,7 @@ class Connection: _timeo = 
socket.getdefaulttimeout() socket.setdefaulttimeout(socket_timeout) try: - self._do_close_self() + await self._do_close_self() except socket.timeout: pass finally: @@ -349,14 +372,22 @@ class Connection: self.declared_entities.clear() self._connection = None - def release(self): + async def release(self): """Close the connection (if open).""" - self._close() - close = release + await self._close() + + async def close(self): + await self._close() - def ensure_connection(self, errback=None, max_retries=None, - interval_start=2, interval_step=2, interval_max=30, - callback=None, reraise_as_library_errors=True): + async def ensure_connection( + self, + errback=None, + max_retries=None, + interval_start=2, + interval_step=2, + interval_max=30, + callback=None, + reraise_as_library_errors=True): """Ensure we have a connection to the server. If not retry establishing the connection with the settings @@ -395,10 +426,12 @@ class Connection: if not reraise_as_library_errors: ctx = self._dummy_context with ctx(): - retry_over_time(self.connect, self.recoverable_connection_errors, - (), {}, on_error, max_retries, - interval_start, interval_step, interval_max, - callback) + await retry_over_time( + self.connect, self.recoverable_connection_errors, + (), {}, on_error, max_retries, + interval_start, interval_step, interval_max, + callback, + ) return self @contextmanager @@ -423,19 +456,23 @@ class Connection: """Return true if the cycle is complete after number of `retries`.""" return not (retries + 1) % len(self.alt) if self.alt else True - def revive(self, new_channel): + async def revive(self, new_channel): """Revive connection after connection re-established.""" if self._default_channel and new_channel is not self._default_channel: - self.maybe_close_channel(self._default_channel) + await self.maybe_close_channel(self._default_channel) self._default_channel = None def _default_ensure_callback(self, exc, interval): logger.error("Ensure: Operation error: %r. 
Retry in %ss", exc, interval, exc_info=True) - def ensure(self, obj, fun, errback=None, max_retries=None, - interval_start=1, interval_step=1, interval_max=1, - on_revive=None): + async def ensure(self, obj, fun, + errback=None, + max_retries=None, + interval_start=1, + interval_step=1, + interval_max=1, + on_revive=None): """Ensure operation completes. Regardless of any channel/connection errors occurring. @@ -475,7 +512,7 @@ class Connection: ... errback=errback, max_retries=3) >>> publish({'hello': 'world'}, routing_key='dest') """ - def _ensured(*args, **kwargs): + async def _ensured(*args, **kwargs): got_connection = 0 conn_errors = self.recoverable_connection_errors chan_errors = self.recoverable_channel_errors @@ -485,7 +522,7 @@ class Connection: with self._reraise_as_library_errors(): for retries in count(0): # for infinity try: - return fun(*args, **kwargs) + return await fun(*args, **kwargs) except conn_errors as exc: if got_connection and not has_modern_errors: # transport can not distinguish between @@ -497,21 +534,21 @@ class Connection: raise self._debug('ensure connection error: %r', exc, exc_info=1) - self.collect() + await self.collect() errback and errback(exc, 0) remaining_retries = None if max_retries is not None: remaining_retries = max(max_retries - retries, 1) - self.ensure_connection( + await self.ensure_connection( errback, remaining_retries, interval_start, interval_step, interval_max, reraise_as_library_errors=False, ) channel = self.default_channel - obj.revive(channel) + await obj.revive(channel) if on_revive: - on_revive(channel) + await on_revive(channel) got_connection += 1 except chan_errors as exc: if max_retries is not None and retries > max_retries: @@ -641,7 +678,7 @@ class Connection: sanitize=not include_password, mask=mask, ) - def Pool(self, limit=None, **kwargs): + def Pool(self, limit=None, **kwargs) -> ResourceT: """Pool of connections. 
See Also: @@ -667,7 +704,7 @@ class Connection: """ return ConnectionPool(self, limit, **kwargs) - def ChannelPool(self, limit=None, **kwargs): + def ChannelPool(self, limit=None, **kwargs) -> ResourceT: """Pool of channels. See Also: @@ -746,9 +783,9 @@ class Connection: return SimpleBuffer(channel or self, name, no_ack, queue_opts, exchange_opts, **kwargs) - def _establish_connection(self): + async def _establish_connection(self): self._debug('establishing connection...') - conn = self.transport.establish_connection() + conn = await self.transport.establish_connection() self._debug('connection established: %r', self) return conn @@ -765,11 +802,19 @@ class Connection: return self.__class__, tuple(self.info().values()), None def __enter__(self): + self.connect() return self def __exit__(self, *args): self.release() + async def __aenter__(self): + await self.connect() + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.release() + @property def qos_semantics_matches_spec(self): return self.transport.qos_semantics_matches_spec(self.connection) @@ -781,6 +826,12 @@ class Connection: self._connection is not None and self.transport.verify_connection(self._connection)) + async def _start_connection(self): + self.declared_entities.clear() + self._default_channel = None + self._connection = await self._establish_connection() + self._closed = False + @property def connection(self): """The underlying connection object. @@ -789,13 +840,7 @@ class Connection: This instance is transport specific, so do not depend on the interface of this object. 
""" - if not self._closed: - if not self.connected: - self.declared_entities.clear() - self._default_channel = None - self._connection = self._establish_connection() - self._closed = False - return self._connection + return self._connection @property def default_channel(self): diff --git a/kombu/entity.py b/kombu/entity.py index 52ae605e..80522b55 100644 --- a/kombu/entity.py +++ b/kombu/entity.py @@ -1,42 +1,46 @@ """Exchange and Queue declarations.""" import numbers - -from .abstract import MaybeChannelBound +from typing import ( + Any, Callable, Dict, Mapping, Set, Sequence, Tuple, Union, +) +from amqp.protocol import queue_declare_ok_t +from amqp.types import ChannelT +from .types import EntityT from .exceptions import ContentDisallowed from .serialization import prepare_accept_content -from .utils import abstract - -TRANSIENT_DELIVERY_MODE = 1 -PERSISTENT_DELIVERY_MODE = 2 -DELIVERY_MODES = {'transient': TRANSIENT_DELIVERY_MODE, - 'persistent': PERSISTENT_DELIVERY_MODE} +from .types import BindingT, ExchangeT, MessageT, QueueT __all__ = ['Exchange', 'Queue', 'binding', 'maybe_delivery_mode'] -INTERNAL_EXCHANGE_PREFIX = ('amq.',) +TRANSIENT_DELIVERY_MODE = 1 +PERSISTENT_DELIVERY_MODE = 2 +DELIVERY_MODES: Mapping[str, int] = { + 'transient': TRANSIENT_DELIVERY_MODE, + 'persistent': PERSISTENT_DELIVERY_MODE, +} +INTERNAL_EXCHANGE_PREFIX: Tuple[str] = ('amq.',) -def _reprstr(s): - s = repr(s) - if isinstance(s, str) and s.startswith("u'"): - return s[2:-1] - return s[1:-1] +def _reprstr(s: Any) -> str: + return repr(s).strip("'") -def pretty_bindings(bindings): +def pretty_bindings(bindings: Sequence) -> str: return '[{0}]'.format(', '.join(map(str, bindings))) def maybe_delivery_mode( - v, modes=DELIVERY_MODES, default=PERSISTENT_DELIVERY_MODE): + v: Union[numbers.Integral, str], + *, + modes: Mapping[str, int] = DELIVERY_MODES, + default: int = PERSISTENT_DELIVERY_MODE) -> int: """Get delivery mode by name (or none if undefined).""" if v: return v if 
isinstance(v, numbers.Integral) else modes[v] return default -@abstract.Entity.register -class Exchange(MaybeChannelBound): +class Exchange(ExchangeT): """An Exchange declaration. Arguments: @@ -137,7 +141,7 @@ class Exchange(MaybeChannelBound): durable = True auto_delete = False passive = False - delivery_mode = None + delivery_mode = None # type: int no_declare = False attrs = ( @@ -151,21 +155,29 @@ class Exchange(MaybeChannelBound): ('no_declare', bool), ) - def __init__(self, name='', type='', channel=None, **kwargs): + def __init__(self, + name: str = '', + type: str = '', + channel: ChannelT = None, + **kwargs) -> None: super().__init__(**kwargs) self.name = name or self.name self.type = type or self.type self.maybe_bind(channel) - def __hash__(self): + def __hash__(self) -> int: return hash('E|%s' % (self.name,)) - def _can_declare(self): + def _can_declare(self) -> bool: return not self.no_declare and ( self.name and not self.name.startswith( INTERNAL_EXCHANGE_PREFIX)) - def declare(self, nowait=False, passive=None, channel=None): + async def declare( + self, + nowait: bool = False, + passive: bool = None, + channel: ChannelT = None) -> None: """Declare the exchange. Creates the exchange on the broker, unless passive is set @@ -177,14 +189,20 @@ class Exchange(MaybeChannelBound): """ if self._can_declare(): passive = self.passive if passive is None else passive - return (channel or self.channel).exchange_declare( + await (channel or self.channel).exchange_declare( exchange=self.name, type=self.type, durable=self.durable, auto_delete=self.auto_delete, arguments=self.arguments, nowait=nowait, passive=passive, ) - def bind_to(self, exchange='', routing_key='', - arguments=None, nowait=False, channel=None, **kwargs): + async def bind_to( + self, + exchange: Union[str, ExchangeT] = '', + routing_key: str = '', + arguments: Mapping = None, + nowait: bool = False, + channel: ChannelT = None, + **kwargs) -> None: """Bind the exchange to another exchange. 
Arguments: @@ -192,9 +210,9 @@ class Exchange(MaybeChannelBound): will not block waiting for a response. Default is :const:`False`. """ - if isinstance(exchange, Exchange): + if isinstance(exchange, ExchangeT): exchange = exchange.name - return (channel or self.channel).exchange_bind( + await (channel or self.channel).exchange_bind( destination=self.name, source=exchange, routing_key=routing_key, @@ -202,12 +220,17 @@ class Exchange(MaybeChannelBound): arguments=arguments, ) - def unbind_from(self, source='', routing_key='', - nowait=False, arguments=None, channel=None): + async def unbind_from( + self, + source: Union[str, ExchangeT] = '', + routing_key: str = '', + nowait: bool = False, + arguments: Mapping = None, + channel: ChannelT = None) -> None: """Delete previously created exchange binding from the server.""" - if isinstance(source, Exchange): + if isinstance(source, ExchangeT): source = source.name - return (channel or self.channel).exchange_unbind( + await (channel or self.channel).exchange_unbind( destination=self.name, source=source, routing_key=routing_key, @@ -215,43 +238,11 @@ class Exchange(MaybeChannelBound): arguments=arguments, ) - def Message(self, body, delivery_mode=None, properties=None, **kwargs): - """Create message instance to be sent with :meth:`publish`. - - Arguments: - body (Any): Message body. - - delivery_mode (bool): Set custom delivery mode. - Defaults to :attr:`delivery_mode`. - - priority (int): Message priority, 0 to broker configured - max priority, where higher is better. - - content_type (str): The messages content_type. If content_type - is set, no serialization occurs as it is assumed this is either - a binary object, or you've done your own serialization. - Leave blank if using built-in serialization as our library - properly sets content_type. - - content_encoding (str): The character set in which this object - is encoded. Use "binary" if sending in raw binary objects. 
- Leave blank if using built-in serialization as our library - properly sets content_encoding. - - properties (Dict): Message properties. - - headers (Dict): Message headers. - """ - # XXX This method is unused by kombu itself AFAICT [ask]. - properties = {} if properties is None else properties - properties['delivery_mode'] = maybe_delivery_mode(self.delivery_mode) - return self.channel.prepare_message( - body, - properties=properties, - **kwargs) - - def publish(self, message, routing_key=None, mandatory=False, - immediate=False, exchange=None): + async def publish(self, message: Union[MessageT, str, bytes], + routing_key: str = None, + mandatory: bool = False, + immediate: bool = False, + exchange: str = None) -> None: """Publish message. Arguments: @@ -264,7 +255,7 @@ class Exchange(MaybeChannelBound): if isinstance(message, str): message = self.Message(message) exchange = exchange or self.name - return self.channel.basic_publish( + await self.channel.basic_publish( message, exchange=exchange, routing_key=routing_key, @@ -272,7 +263,10 @@ class Exchange(MaybeChannelBound): immediate=immediate, ) - def delete(self, if_unused=False, nowait=False): + async def delete( + self, + if_unused: bool = False, + nowait: bool = False) -> None: """Delete the exchange declaration on server. Arguments: @@ -281,14 +275,20 @@ class Exchange(MaybeChannelBound): nowait (bool): If set the server will not respond, and a response will not be waited for. Default is :const:`False`. 
""" - return self.channel.exchange_delete(exchange=self.name, - if_unused=if_unused, - nowait=nowait) + await self.channel.exchange_delete( + exchange=self.name, + if_unused=if_unused, + nowait=nowait, + ) - def binding(self, routing_key='', arguments=None, unbind_arguments=None): + def binding( + self, + routing_key: str = '', + arguments: Mapping = None, + unbind_arguments: Mapping = None) -> 'binding': return binding(self, routing_key, arguments, unbind_arguments) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if isinstance(other, Exchange): return (self.name == other.name and self.type == other.type and @@ -298,28 +298,70 @@ class Exchange(MaybeChannelBound): self.delivery_mode == other.delivery_mode) return NotImplemented - def __ne__(self, other): - return not self.__eq__(other) + def __ne__(self, other: Any) -> bool: + ret = self.__eq__(other) + if ret is NotImplemented: + return ret + return not ret - def __repr__(self): + def __repr__(self) -> str: return self._repr_entity(self) - def __str__(self): - return 'Exchange {0}({1})'.format( + def __str__(self) -> str: + return '{name} {0}({1})'.format( _reprstr(self.name) or repr(''), self.type, + name=type(self).__name__, ) + def Message( + self, body: Any, + delivery_mode: Union[str, int] = None, + properties: Mapping = None, + **kwargs) -> Any: + """Create message instance to be sent with :meth:`publish`. + + Arguments: + body (Any): Message body. + + delivery_mode (Union[str, int]): Set custom delivery mode. + Defaults to :attr:`delivery_mode`. + + priority (int): Message priority, 0 to broker configured + max priority, where higher is better. + + content_type (str): The messages content_type. If content_type + is set, no serialization occurs as it is assumed this is either + a binary object, or you've done your own serialization. + Leave blank if using built-in serialization as our library + properly sets content_type. 
+ + content_encoding (str): The character set in which this object + is encoded. Use "binary" if sending in raw binary objects. + Leave blank if using built-in serialization as our library + properly sets content_encoding. + + properties (Dict): Message properties. + + headers (Dict): Message headers. + """ + # XXX This method is unused by kombu itself AFAICT [ask]. + properties = {} if properties is None else properties + properties['delivery_mode'] = maybe_delivery_mode(self.delivery_mode) + return self.channel.prepare_message( + body, + properties=properties, + **kwargs) + @property - def can_cache_declaration(self): + def can_cache_declaration(self) -> bool: return not self.auto_delete -@abstract.Entity.register -class binding: +class binding(BindingT): """Represents a queue or exchange binding. Arguments: - exchange (Exchange): Exchange to bind to. + exchange (ExchangeT): Exchange to bind to. routing_key (str): Routing key used as binding key. arguments (Dict): Arguments for bind operation. unbind_arguments (Dict): Arguments for unbind operation. 
@@ -332,45 +374,58 @@ class binding: ('unbind_arguments', None) ) - def __init__(self, exchange=None, routing_key='', - arguments=None, unbind_arguments=None): + def __init__( + self, + exchange: ExchangeT = None, + routing_key: str = '', + arguments: Mapping = None, + unbind_arguments: Mapping = None) -> None: self.exchange = exchange self.routing_key = routing_key self.arguments = arguments self.unbind_arguments = unbind_arguments - def declare(self, channel, nowait=False): + def declare(self, channel: ChannelT, nowait: bool = False) -> None: """Declare destination exchange.""" if self.exchange and self.exchange.name: self.exchange.declare(channel=channel, nowait=nowait) - def bind(self, entity, nowait=False, channel=None): + def bind( + self, entity: EntityT, + nowait: bool = False, + channel: ChannelT = None) -> None: """Bind entity to this binding.""" - entity.bind_to(exchange=self.exchange, - routing_key=self.routing_key, - arguments=self.arguments, - nowait=nowait, - channel=channel) + entity.bind_to( + exchange=self.exchange, + routing_key=self.routing_key, + arguments=self.arguments, + nowait=nowait, + channel=channel, + ) - def unbind(self, entity, nowait=False, channel=None): + def unbind( + self, entity: EntityT, + nowait: bool = False, + channel: ChannelT = None) -> None: """Unbind entity from this binding.""" - entity.unbind_from(self.exchange, - routing_key=self.routing_key, - arguments=self.unbind_arguments, - nowait=nowait, - channel=channel) + entity.unbind_from( + self.exchange, + routing_key=self.routing_key, + arguments=self.unbind_arguments, + nowait=nowait, + channel=channel + ) - def __repr__(self): + def __repr__(self) -> str: return '<binding: {0}>'.format(self) - def __str__(self): + def __str__(self) -> str: return '{0}->{1}'.format( _reprstr(self.exchange.name), _reprstr(self.routing_key), ) -@abstract.Entity.register -class Queue(MaybeChannelBound): +class Queue(QueueT): """A Queue declaration. 
Arguments: @@ -561,9 +616,15 @@ class Queue(MaybeChannelBound): ('max_priority', int) ) - def __init__(self, name='', exchange=None, routing_key='', - channel=None, bindings=None, on_declared=None, - **kwargs): + def __init__( + self, + name: str = '', + exchange: ExchangeT = None, + routing_key: str = '', + channel: ChannelT = None, + bindings: Sequence[BindingT] = None, + on_declared: Callable = None, + **kwargs) -> None: super().__init__(**kwargs) self.name = name or self.name self.exchange = exchange or self.exchange @@ -582,44 +643,56 @@ class Queue(MaybeChannelBound): self.auto_delete = True self.maybe_bind(channel) - def bind(self, channel): + def bind(self, channel: ChannelT) -> EntityT: on_declared = self.on_declared bound = super().bind(channel) bound.on_declared = on_declared return bound - def __hash__(self): + def __hash__(self) -> int: return hash('Q|%s' % (self.name,)) - def when_bound(self): + def when_bound(self) -> None: if self.exchange: self.exchange = self.exchange(self.channel) - def declare(self, nowait=False, channel=None): + async def declare(self, + nowait: bool = False, + channel: ChannelT = None) -> str: """Declare queue and exchange then binds queue to exchange.""" if not self.no_declare: # - declare main binding. 
- self._create_exchange(nowait=nowait, channel=channel) - self._create_queue(nowait=nowait, channel=channel) - self._create_bindings(nowait=nowait, channel=channel) + await self._create_exchange(nowait=nowait, channel=channel) + await self._create_queue(nowait=nowait, channel=channel) + await self._create_bindings(nowait=nowait, channel=channel) return self.name - def _create_exchange(self, nowait=False, channel=None): + async def _create_exchange(self, + nowait: bool = False, + channel: ChannelT = None) -> None: if self.exchange: - self.exchange.declare(nowait=nowait, channel=channel) + await self.exchange.declare(nowait=nowait, channel=channel) - def _create_queue(self, nowait=False, channel=None): - self.queue_declare(nowait=nowait, passive=False, channel=channel) + async def _create_queue(self, + nowait: bool = False, + channel: ChannelT = None) -> None: + await self.queue_declare(nowait=nowait, passive=False, channel=channel) if self.exchange and self.exchange.name: - self.queue_bind(nowait=nowait, channel=channel) + await self.queue_bind(nowait=nowait, channel=channel) - def _create_bindings(self, nowait=False, channel=None): + async def _create_bindings(self, + nowait: bool = False, + channel: ChannelT = None) -> None: for B in self.bindings: channel = channel or self.channel - B.declare(channel) - B.bind(self, nowait=nowait, channel=channel) - - def queue_declare(self, nowait=False, passive=False, channel=None): + await B.declare(channel) + await B.bind(self, nowait=nowait, channel=channel) + + async def queue_declare( + self, + nowait: bool = False, + passive: bool = False, + channel: ChannelT = None) -> queue_declare_ok_t: """Declare queue on the server. 
Arguments: @@ -637,7 +710,7 @@ class Queue(MaybeChannelBound): max_length_bytes=self.max_length_bytes, max_priority=self.max_priority, ) - ret = channel.queue_declare( + ret = await channel.queue_declare( queue=self.name, passive=passive, durable=self.durable, @@ -652,18 +725,26 @@ class Queue(MaybeChannelBound): self.on_declared(*ret) return ret - def queue_bind(self, nowait=False, channel=None): + async def queue_bind( + self, + nowait: bool = False, + channel: ChannelT = None) -> None: """Create the queue binding on the server.""" - return self.bind_to(self.exchange, self.routing_key, - self.binding_arguments, - channel=channel, nowait=nowait) - - def bind_to(self, exchange='', routing_key='', - arguments=None, nowait=False, channel=None): + await self.bind_to( + self.exchange, self.routing_key, self.binding_arguments, + channel=channel, nowait=nowait) + + async def bind_to( + self, + exchange: Union[str, ExchangeT] = '', + routing_key: str = '', + arguments: Mapping = None, + nowait: bool = False, + channel: ChannelT = None) -> None: if isinstance(exchange, Exchange): exchange = exchange.name - return (channel or self.channel).queue_bind( + await (channel or self.channel).queue_bind( queue=self.name, exchange=exchange, routing_key=routing_key, @@ -671,7 +752,10 @@ class Queue(MaybeChannelBound): nowait=nowait, ) - def get(self, no_ack=None, accept=None): + async def get( + self, + no_ack: bool = None, + accept: Set[str] = None) -> MessageT: """Poll the server for a new message. This method provides direct access to the messages in a @@ -686,10 +770,10 @@ class Queue(MaybeChannelBound): Arguments: no_ack (bool): If enabled the broker will automatically ack messages. - accept (Set[str]): Custom list of accepted content types. + accept (Container): Custom list of accepted content types. 
""" no_ack = self.no_ack if no_ack is None else no_ack - message = self.channel.basic_get(queue=self.name, no_ack=no_ack) + message = await self.channel.basic_get(queue=self.name, no_ack=no_ack) if message is not None: m2p = getattr(self.channel, 'message_to_python', None) if m2p: @@ -699,13 +783,17 @@ class Queue(MaybeChannelBound): message.accept = prepare_accept_content(accept) return message - def purge(self, nowait=False): + async def purge(self, nowait: bool = False) -> int: """Remove all ready messages from the queue.""" return self.channel.queue_purge(queue=self.name, nowait=nowait) or 0 - def consume(self, consumer_tag='', callback=None, - no_ack=None, nowait=False): + async def consume( + self, + consumer_tag: str = '', + callback: Callable = None, + no_ack: bool = None, + nowait: bool = False) -> None: """Start a queue consumer. Consumers last as long as the channel they were created on, or @@ -726,7 +814,7 @@ class Queue(MaybeChannelBound): """ if no_ack is None: no_ack = self.no_ack - return self.channel.basic_consume( + await self.channel.basic_consume( queue=self.name, no_ack=no_ack, consumer_tag=consumer_tag or '', @@ -734,11 +822,15 @@ class Queue(MaybeChannelBound): nowait=nowait, arguments=self.consumer_arguments) - def cancel(self, consumer_tag): + async def cancel(self, consumer_tag: str) -> None: """Cancel a consumer by consumer tag.""" - return self.channel.basic_cancel(consumer_tag) + await self.channel.basic_cancel(consumer_tag) - def delete(self, if_unused=False, if_empty=False, nowait=False): + async def delete( + self, + if_unused: bool = False, + if_empty: bool = False, + nowait: bool = False) -> None: """Delete the queue. Arguments: @@ -751,19 +843,30 @@ class Queue(MaybeChannelBound): nowait (bool): Do not wait for a reply. 
""" - return self.channel.queue_delete(queue=self.name, - if_unused=if_unused, - if_empty=if_empty, - nowait=nowait) - - def queue_unbind(self, arguments=None, nowait=False, channel=None): - return self.unbind_from(self.exchange, self.routing_key, - arguments, nowait, channel) + await self.channel.queue_delete( + queue=self.name, + if_unused=if_unused, + if_empty=if_empty, + nowait=nowait, + ) - def unbind_from(self, exchange='', routing_key='', - arguments=None, nowait=False, channel=None): + async def queue_unbind( + self, + arguments: Mapping = None, + nowait: bool = False, + channel: ChannelT = None) -> None: + await self.unbind_from(self.exchange, self.routing_key, + arguments, nowait, channel) + + async def unbind_from( + self, + exchange: str = '', + routing_key: str = '', + arguments: Mapping = None, + nowait: bool = False, + channel: ChannelT = None) -> None: """Unbind queue by deleting the binding from the server.""" - return (channel or self.channel).queue_unbind( + await (channel or self.channel).queue_unbind( queue=self.name, exchange=exchange.name, routing_key=routing_key, @@ -771,7 +874,7 @@ class Queue(MaybeChannelBound): nowait=nowait, ) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if isinstance(other, Queue): return (self.name == other.name and self.exchange == other.exchange and @@ -785,9 +888,12 @@ class Queue(MaybeChannelBound): return NotImplemented def __ne__(self, other): - return not self.__eq__(other) + ret = self.__eq__(other) + if ret is NotImplemented: + return ret + return not ret - def __repr__(self): + def __repr__(self) -> str: if self.bindings: return self._repr_entity('Queue {name} -> {bindings}'.format( name=_reprstr(self.name), @@ -801,30 +907,31 @@ class Queue(MaybeChannelBound): ) @property - def can_cache_declaration(self): + def can_cache_declaration(self) -> bool: return not self.auto_delete @classmethod - def from_dict(self, queue, - exchange=None, - exchange_type=None, - binding_key=None, - 
routing_key=None, - delivery_mode=None, - bindings=None, - durable=None, - queue_durable=None, - exchange_durable=None, - auto_delete=None, - queue_auto_delete=None, - exchange_auto_delete=None, - exchange_arguments=None, - queue_arguments=None, - binding_arguments=None, - consumer_arguments=None, - exclusive=None, - no_ack=None, - **options): + def from_dict( + self, queue: str, + exchange: str = None, + exchange_type: str = None, + binding_key: str = None, + routing_key: str = None, + delivery_mode: Union[int, str] = None, + bindings: Sequence = None, + durable: bool = None, + queue_durable: bool = None, + exchange_durable: bool = None, + auto_delete: bool = None, + queue_auto_delete: bool = None, + exchange_auto_delete: bool = None, + exchange_arguments: Mapping = None, + queue_arguments: Mapping = None, + binding_arguments: Mapping = None, + consumer_arguments: Mapping = None, + exclusive: bool = None, + no_ack: bool = None, + **options) -> QueueT: return Queue( queue, exchange=Exchange( @@ -850,7 +957,7 @@ class Queue(MaybeChannelBound): bindings=bindings, ) - def as_dict(self, recurse=False): + def as_dict(self, recurse: bool = False) -> Dict: res = super().as_dict(recurse) if not recurse: return res diff --git a/kombu/message.py b/kombu/message.py index dd6f0b4e..9d74571b 100644 --- a/kombu/message.py +++ b/kombu/message.py @@ -1,10 +1,12 @@ """Message class.""" +import logging import sys - +from typing import Any, Callable, List, Mapping, Set, Tuple +from amqp.types import ChannelT +from . import types from .compression import decompress from .exceptions import MessageStateError from .serialization import loads -from .utils import abstract from .utils.functional import dictfilter __all__ = ['Message'] @@ -13,7 +15,7 @@ ACK_STATES = {'ACK', 'REJECTED', 'REQUEUED'} IS_PYPY = hasattr(sys, 'pypy_version_info') -@abstract.Message.register +@types.MessageT.register class Message: """Base class for received messages. 
@@ -48,7 +50,7 @@ class Message: MessageStateError = MessageStateError - errors = None + errors: List[Any] = None if not IS_PYPY: # pragma: no cover __slots__ = ( @@ -58,10 +60,19 @@ class Message: 'body', '_decoded_cache', 'accept', '__dict__', ) - def __init__(self, body=None, delivery_tag=None, - content_type=None, content_encoding=None, delivery_info={}, - properties=None, headers=None, postencode=None, - accept=None, channel=None, **kwargs): + def __init__( + self, + body: Any = None, + delivery_tag: str = None, + content_type: str = None, + content_encoding: str = None, + delivery_info: Mapping = None, + properties: Mapping = None, + headers: Mapping = None, + postencode: str = None, + accept: Set[str] = None, + channel: ChannelT = None, + **kwargs) -> None: self.errors = [] if self.errors is None else self.errors self.channel = channel self.delivery_tag = delivery_tag @@ -88,7 +99,7 @@ class Message: self.errors.append(sys.exc_info()) self.body = body - def _reraise_error(self, callback=None): + def _reraise_error(self, callback: Callable = None) -> None: try: raise self.errors[0][1].with_traceback(self.errors[0][2]) except Exception as exc: @@ -96,7 +107,7 @@ class Message: raise callback(self, exc) - def ack(self, multiple=False): + async def ack(self, multiple: bool = False) -> None: """Acknowledge this message as being processed. This will remove the message from the queue. 
@@ -120,24 +131,28 @@ class Message: raise self.MessageStateError( 'Message already acknowledged with state: {0._state}'.format( self)) - self.channel.basic_ack(self.delivery_tag, multiple=multiple) + await self.channel.basic_ack(self.delivery_tag, multiple=multiple) self._state = 'ACK' - def ack_log_error(self, logger, errors, multiple=False): + async def ack_log_error( + self, logger: logging.Logger, errors: Tuple[type, ...], + multiple: bool = False) -> None: try: - self.ack(multiple=multiple) + await self.ack(multiple=multiple) except errors as exc: logger.critical("Couldn't ack %r, reason:%r", self.delivery_tag, exc, exc_info=True) - def reject_log_error(self, logger, errors, requeue=False): + async def reject_log_error( + self, logger: logging.Logger, errors: Tuple[type, ...], + requeue: bool = False) -> None: try: - self.reject(requeue=requeue) + await self.reject(requeue=requeue) except errors as exc: logger.critical("Couldn't reject %r, reason: %r", self.delivery_tag, exc, exc_info=True) - def reject(self, requeue=False): + async def reject(self, requeue: bool = False) -> None: """Reject this message. The message will be discarded by the server. @@ -153,10 +168,10 @@ class Message: raise self.MessageStateError( 'Message already acknowledged with state: {0._state}'.format( self)) - self.channel.basic_reject(self.delivery_tag, requeue=requeue) + await self.channel.basic_reject(self.delivery_tag, requeue=requeue) self._state = 'REJECTED' - def requeue(self): + async def requeue(self) -> None: """Reject this message and put it back on the queue. Warning: @@ -174,10 +189,10 @@ class Message: raise self.MessageStateError( 'Message already acknowledged with state: {0._state}'.format( self)) - self.channel.basic_reject(self.delivery_tag, requeue=True) + await self.channel.basic_reject(self.delivery_tag, requeue=True) self._state = 'REQUEUED' - def decode(self): + def decode(self) -> Any: """Deserialize the message body. 
Returning the original python structure sent by the publisher. @@ -190,21 +205,21 @@ class Message: self._decoded_cache = self._decode() return self._decoded_cache - def _decode(self): + def _decode(self) -> Any: return loads(self.body, self.content_type, self.content_encoding, accept=self.accept) @property - def acknowledged(self): + def acknowledged(self) -> bool: """Set to true if the message has been acknowledged.""" return self._state in ACK_STATES @property - def payload(self): + def payload(self) -> Any: """The decoded message body.""" return self._decoded_cache if self._decoded_cache else self.decode() - def __repr__(self): + def __repr__(self) -> str: return '<{0} object at {1:#x} with details {2!r}>'.format( type(self).__name__, id(self), dictfilter( state=self._state, diff --git a/kombu/messaging.py b/kombu/messaging.py index 3dd03d14..67ccd15d 100644 --- a/kombu/messaging.py +++ b/kombu/messaging.py @@ -1,20 +1,23 @@ """Sending and receiving messages.""" +import numbers from itertools import count - +from typing import Any, Callable, Mapping, Sequence, Tuple, Union +from amqp import ChannelT +from . import types +from .abstract import Entity from .common import maybe_declare from .compression import compress from .connection import maybe_channel, is_connection from .entity import Exchange, Queue, maybe_delivery_mode from .exceptions import ContentDisallowed from .serialization import dumps, prepare_accept_content -from .utils import abstract +from .types import ChannelArgT, ClientT, MessageT, QueueT from .utils.functional import ChannelPromise, maybe_list __all__ = ['Exchange', 'Queue', 'Producer', 'Consumer'] -@abstract.Producer.register -class Producer: +class Producer(types.ProducerT): """Message Producer. Arguments: @@ -34,7 +37,7 @@ class Producer: """ #: Default exchange - exchange = None + exchange = None # type: Exchange #: Default routing key. routing_key = '' @@ -56,9 +59,16 @@ class Producer: #: default_channel). 
__connection__ = None - def __init__(self, channel, exchange=None, routing_key=None, - serializer=None, auto_declare=None, compression=None, - on_return=None): + _channel: ChannelT + + def __init__(self, + channel: ChannelArgT, + exchange: Exchange = None, + routing_key: str = None, + serializer: str = None, + auto_declare: bool = None, + compression: str = None, + on_return: Callable = None): self._channel = channel self.exchange = exchange self.routing_key = routing_key or self.routing_key @@ -71,20 +81,17 @@ class Producer: if auto_declare is not None: self.auto_declare = auto_declare - if self._channel: - self.revive(self._channel) - - def __repr__(self): + def __repr__(self) -> str: return '<Producer: {0._channel}>'.format(self) - def __reduce__(self): + def __reduce__(self) -> Tuple[Any, ...]: return self.__class__, self.__reduce_args__() - def __reduce_args__(self): + def __reduce_args__(self) -> Tuple[Any, ...]: return (None, self.exchange, self.routing_key, self.serializer, self.auto_declare, self.compression) - def declare(self): + async def declare(self) -> None: """Declare the exchange. Note: @@ -92,16 +99,18 @@ class Producer: the :attr:`auto_declare` flag is enabled. 
""" if self.exchange.name: - self.exchange.declare() + await self.exchange.declare() - def maybe_declare(self, entity, retry=False, **retry_policy): + async def maybe_declare(self, entity: Entity, + retry: bool = False, **retry_policy) -> None: """Declare exchange if not already declared during this session.""" if entity: - return maybe_declare(entity, self.channel, retry, **retry_policy) + await maybe_declare(entity, self.channel, retry, **retry_policy) - def _delivery_details(self, exchange, delivery_mode=None, - maybe_delivery_mode=maybe_delivery_mode, - Exchange=Exchange): + def _delivery_details(self, exchange: Union[Exchange, str], + delivery_mode: Union[int, str]=None, + maybe_delivery_mode: Callable = maybe_delivery_mode, + Exchange: Any = Exchange) -> Tuple[str, int]: if isinstance(exchange, Exchange): return exchange.name, maybe_delivery_mode( delivery_mode or exchange.delivery_mode, @@ -112,12 +121,23 @@ class Producer: delivery_mode or self.exchange.delivery_mode, ) - def publish(self, body, routing_key=None, delivery_mode=None, - mandatory=False, immediate=False, priority=0, - content_type=None, content_encoding=None, serializer=None, - headers=None, compression=None, exchange=None, retry=False, - retry_policy=None, declare=None, expiration=None, - **properties): + async def publish(self, body: Any, + routing_key: str = None, + delivery_mode: Union[int, str] = None, + mandatory: bool = False, + immediate: bool = False, + priority: int = 0, + content_type: str = None, + content_encoding: str = None, + serializer: str = None, + headers: Mapping = None, + compression: str = None, + exchange: Union[Exchange, str] = None, + retry: bool = False, + retry_policy: Mapping = None, + declare: Sequence[Entity] = None, + expiration: numbers.Number = None, + **properties) -> None: """Publish message to the specified exchange. 
Arguments: @@ -172,54 +192,79 @@ class Producer: declare.append(self.exchange) if retry: - _publish = self.connection.ensure(self, _publish, **retry_policy) - return _publish( + _publish = await self.connection.ensure( + self, _publish, **retry_policy) + await _publish( body, priority, content_type, content_encoding, headers, properties, routing_key, mandatory, immediate, exchange_name, declare, ) - def _publish(self, body, priority, content_type, content_encoding, - headers, properties, routing_key, mandatory, - immediate, exchange, declare): - channel = self.channel + async def _publish(self, + body: Any, + priority: int = None, + content_type: str = None, + content_encoding: str = None, + headers: Mapping = None, + properties: Mapping = None, + routing_key: str = None, + mandatory: bool = None, + immediate: bool = None, + exchange: str = None, + declare: Sequence = None) -> None: + channel = await self._resolve_channel() message = channel.prepare_message( body, priority, content_type, content_encoding, headers, properties, ) if declare: maybe_declare = self.maybe_declare - [maybe_declare(entity) for entity in declare] + for entity in declare: + await maybe_declare(entity) # handle autogenerated queue names for reply_to reply_to = properties.get('reply_to') if isinstance(reply_to, Queue): properties['reply_to'] = reply_to.name - return channel.basic_publish( + await channel.basic_publish( message, exchange=exchange, routing_key=routing_key, mandatory=mandatory, immediate=immediate, ) - def _get_channel(self): + async def _resolve_channel(self): + channel = self._channel + if isinstance(channel, ChannelPromise): + channel = self._channel = await channel.resolve() + if self.exchange: + self.exchange.revive(channel) + if self.on_return: + channel.events['basic_return'].add(self.on_return) + else: + channel = maybe_channel(channel) + await channel.open() + return channel + + def _get_channel(self) -> ChannelT: channel = self._channel if isinstance(channel, 
ChannelPromise): channel = self._channel = channel() - self.exchange.revive(channel) + if self.exchange: + self.exchange.revive(channel) if self.on_return: channel.events['basic_return'].add(self.on_return) return channel - def _set_channel(self, channel): + def _set_channel(self, channel: ChannelT) -> None: self._channel = channel channel = property(_get_channel, _set_channel) - def revive(self, channel): + async def revive(self, channel: ChannelT) -> None: """Revive the producer after connection loss.""" if is_connection(channel): connection = channel self.__connection__ = connection - channel = ChannelPromise(lambda: connection.default_channel) + channel = ChannelPromise(connection) if isinstance(channel, ChannelPromise): self._channel = channel self.exchange = self.exchange(channel) @@ -230,18 +275,29 @@ class Producer: self._channel.events['basic_return'].add(self.on_return) self.exchange = self.exchange(channel) - def __enter__(self): + def __enter__(self) -> 'Producer': return self - def __exit__(self, *exc_info): + async def __aenter__(self) -> 'Producer': + await self.revive(self.channel) + return self + + def __exit__(self, *exc_info) -> None: self.release() - def release(self): - pass + async def __aexit__(self, *exc_info) -> None: + self.release() + + def release(self) -> None: + ... close = release - def _prepare(self, body, serializer=None, content_type=None, - content_encoding=None, compression=None, headers=None): + def _prepare(self, body: Any, + serializer: str = None, + content_type: str = None, + content_encoding: str = None, + compression: str = None, + headers: Mapping = None) -> Tuple[Any, str, str]: # No content_type? Then we're serializing the data internally. 
if not content_type: @@ -267,15 +323,14 @@ class Producer: return body, content_type, content_encoding @property - def connection(self): + def connection(self) -> ClientT: try: return self.__connection__ or self.channel.connection.client except AttributeError: pass -@abstract.Consumer.register -class Consumer: +class Consumer(types.ConsumerT): """Message consumer. Arguments: @@ -358,13 +413,23 @@ class Consumer: prefetch_count = None #: Mapping of queues we consume from. - _queues = None + _queues: Mapping[str, QueueT] = None _tags = count(1) # global - - def __init__(self, channel, queues=None, no_ack=None, auto_declare=None, - callbacks=None, on_decode_error=None, on_message=None, - accept=None, prefetch_count=None, tag_prefix=None): + _active_tags: Mapping[str, str] + + def __init__( + self, + channel: ChannelT, + queues: Sequence[QueueT] = None, + no_ack: bool = None, + auto_declare: bool = None, + callbacks: Sequence[Callable] = None, + on_decode_error: Callable = None, + on_message: Callable = None, + accept: Sequence[str] = None, + prefetch_count: int = None, + tag_prefix: str = None) -> None: self.channel = channel self.queues = maybe_list(queues or []) self.no_ack = self.no_ack if no_ack is None else no_ack @@ -380,21 +445,19 @@ class Consumer: self.accept = prepare_accept_content(accept) self.prefetch_count = prefetch_count - if self.channel: - self.revive(self.channel) - @property - def queues(self): + def queues(self) -> Sequence[QueueT]: return list(self._queues.values()) @queues.setter - def queues(self, queues): + def queues(self, queues: Sequence[QueueT]) -> None: self._queues = {q.name: q for q in queues} - def revive(self, channel): + async def revive(self, channel: ChannelT) -> None: """Revive consumer after connection loss.""" self._active_tags.clear() channel = self.channel = maybe_channel(channel) + await channel.open() for qname, queue in self._queues.items(): # name may have changed after declare self._queues.pop(qname, None) @@ -402,22 
+465,21 @@ class Consumer: queue.revive(channel) if self.auto_declare: - self.declare() + await self.declare() if self.prefetch_count is not None: - self.qos(prefetch_count=self.prefetch_count) + await self.qos(prefetch_count=self.prefetch_count) - def declare(self): + async def declare(self) -> None: """Declare queues, exchanges and bindings. Note: This is done automatically at instantiation when :attr:`auto_declare` is set. """ - for queue in self._queues.values(): - queue.declare() + [await queue.declare() for queue in self._queues.values()] - def register_callback(self, callback): + def register_callback(self, callback: Callable) -> None: """Register a new callback to be called when a message is received. Note: @@ -427,11 +489,26 @@ class Consumer: """ self.callbacks.append(callback) - def __enter__(self): + def __enter__(self) -> 'Consumer': + self.revive(self.channel) self.consume() return self - def __exit__(self, exc_type, exc_val, exc_tb): + async def __aenter__(self) -> 'Consumer': + await self.revive(self.channel) + await self.consume() + return self + + async def __aexit__(self, *exc_info) -> None: + if self.channel: + conn_errors = self.channel.connection.client.connection_errors + if not isinstance(exc_info[1], conn_errors): + try: + await self.cancel() + except Exception: + pass + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: if self.channel: conn_errors = self.channel.connection.client.connection_errors if not isinstance(exc_val, conn_errors): @@ -440,7 +517,7 @@ class Consumer: except Exception: pass - def add_queue(self, queue): + async def add_queue(self, queue: QueueT) -> QueueT: """Add a queue to the list of queues to consume from. 
Note: @@ -449,11 +526,11 @@ class Consumer: """ queue = queue(self.channel) if self.auto_declare: - queue.declare() + await queue.declare() self._queues[queue.name] = queue return queue - def consume(self, no_ack=None): + async def consume(self, no_ack: bool = None) -> None: """Start consuming messages. Can be called multiple times, but note that while it @@ -470,10 +547,10 @@ class Consumer: H, T = queues[:-1], queues[-1] for queue in H: - self._basic_consume(queue, no_ack=no_ack, nowait=True) - self._basic_consume(T, no_ack=no_ack, nowait=False) + await self._basic_consume(queue, no_ack=no_ack, nowait=True) + await self._basic_consume(T, no_ack=no_ack, nowait=False) - def cancel(self): + async def cancel(self) -> None: """End all active queue consumers. Note: @@ -482,38 +559,36 @@ class Consumer: """ cancel = self.channel.basic_cancel for tag in self._active_tags.values(): - cancel(tag) + await cancel(tag) self._active_tags.clear() close = cancel - def cancel_by_queue(self, queue): + async def cancel_by_queue(self, queue: Union[QueueT, str]) -> None: """Cancel consumer by queue name.""" - qname = queue.name if isinstance(queue, Queue) else queue + qname = queue.name if isinstance(queue, QueueT) else queue try: tag = self._active_tags.pop(qname) except KeyError: pass else: - self.channel.basic_cancel(tag) + await self.channel.basic_cancel(tag) finally: self._queues.pop(qname, None) - def consuming_from(self, queue): + def consuming_from(self, queue: Union[QueueT, str]) -> bool: """Return :const:`True` if currently consuming from queue'.""" - name = queue - if isinstance(queue, Queue): - name = queue.name + name = queue.name if isinstance(queue, QueueT) else queue return name in self._active_tags - def purge(self): + async def purge(self) -> int: """Purge messages from all queues. Warning: This will *delete all ready messages*, there is no undo operation. 
""" - return sum(queue.purge() for queue in self._queues.values()) + return sum(await queue.purge() for queue in self._queues.values()) - def flow(self, active): + async def flow(self, active: bool) -> None: """Enable/disable flow from peer. This is a simple flow-control mechanism that a peer can use @@ -524,9 +599,12 @@ class Consumer: will finish sending the current content (if any), and then wait until flow is reactivated. """ - self.channel.flow(active) + await self.channel.flow(active) - def qos(self, prefetch_size=0, prefetch_count=0, apply_global=False): + async def qos(self, + prefetch_size: int = 0, + prefetch_count: int = 0, + apply_global: bool = False) -> None: """Specify quality of service. The client can request that messages should be sent in @@ -550,11 +628,10 @@ class Consumer: apply_global (bool): Apply new settings globally on all channels. """ - return self.channel.basic_qos(prefetch_size, - prefetch_count, - apply_global) + await self.channel.basic_qos( + prefetch_size, prefetch_count, apply_global) - def recover(self, requeue=False): + async def recover(self, requeue: bool = False) -> None: """Redeliver unacknowledged messages. Asks the broker to redeliver all unacknowledged messages @@ -566,9 +643,9 @@ class Consumer: server will attempt to requeue the message, potentially then delivering it to an alternative subscriber. """ - return self.channel.basic_recover(requeue=requeue) + await self.channel.basic_recover(requeue=requeue) - def receive(self, body, message): + async def receive(self, body: Any, message: MessageT) -> None: """Method called when a message is received. This dispatches to the registered :attr:`callbacks`. 
@@ -584,24 +661,27 @@ class Consumer: callbacks = self.callbacks if not callbacks: raise NotImplementedError('Consumer does not have any callbacks') - [callback(body, message) for callback in callbacks] + for callback in callbacks: + await callback(body, message) - def _basic_consume(self, queue, consumer_tag=None, - no_ack=no_ack, nowait=True): + async def _basic_consume(self, queue: QueueT, + consumer_tag: str = None, + no_ack: bool = no_ack, + nowait: bool = True) -> str: tag = self._active_tags.get(queue.name) if tag is None: tag = self._add_tag(queue, consumer_tag) - queue.consume(tag, self._receive_callback, - no_ack=no_ack, nowait=nowait) + await queue.consume(tag, self._receive_callback, + no_ack=no_ack, nowait=nowait) return tag - def _add_tag(self, queue, consumer_tag=None): + def _add_tag(self, queue: QueueT, consumer_tag: str = None) -> str: tag = consumer_tag or '{0}{1}'.format( self.tag_prefix, next(self._tags)) self._active_tags[queue.name] = tag return tag - def _receive_callback(self, message): + async def _receive_callback(self, message: MessageT) -> None: accept = self.accept on_m, channel, decoded = self.on_message, self.channel, None try: @@ -611,20 +691,20 @@ class Consumer: if accept is not None: message.accept = accept if message.errors: - return message._reraise_error(self.on_decode_error) + message._reraise_error(self.on_decode_error) decoded = None if on_m else message.decode() except Exception as exc: if not self.on_decode_error: raise self.on_decode_error(message, exc) else: - return on_m(message) if on_m else self.receive(decoded, message) + await on_m(message) if on_m else self.receive(decoded, message) - def __repr__(self): + def __repr__(self) -> str: return '<{name}: {0.queues}>'.format(self, name=type(self).__name__) @property - def connection(self): + def connection(self) -> ClientT: try: return self.channel.connection.client except AttributeError: diff --git a/kombu/pidbox.py b/kombu/pidbox.py index 0c6196fc..4775a59b 100644 --- 
a/kombu/pidbox.py +++ b/kombu/pidbox.py @@ -8,12 +8,19 @@ from copy import copy from itertools import count from threading import local from time import time +from typing import Any, Callable, Mapping, Set, Sequence, cast -from . import Exchange, Queue, Consumer, Producer -from .clocks import LamportClock +from amqp.types import ChannelT + +from .clocks import Clock, LamportClock from .common import maybe_declare, oid_from +from .entity import Exchange, Queue from .exceptions import InconsistencyError from .log import get_logger +from .messaging import Consumer, Producer +from .typing import ( + ClientT, ConsumerT, ExchangeT, ProducerT, MessageT, ResourceT, +) from .utils.functional import maybe_evaluate, reprcall from .utils.objects import cached_property from .utils.uuid import uuid @@ -35,22 +42,26 @@ class Node: """Mailbox node.""" #: hostname of the node. - hostname = None + hostname: str = None #: the :class:`Mailbox` this is a node for. - mailbox = None + mailbox: 'Mailbox' = None #: map of method name/handlers. - handlers = None + handlers: Mapping[str, Callable] = None #: current context (passed on to handlers) - state = None + state: Any = None #: current channel. 
- channel = None - - def __init__(self, hostname, state=None, channel=None, - handlers=None, mailbox=None): + channel: ChannelT = None + + def __init__(self, + hostname: str, + state: Any = None, + channel: ChannelT = None, + handlers: Mapping[str, Callable] = None, + mailbox: 'Mailbox' = None) -> None: self.channel = channel self.mailbox = mailbox self.hostname = hostname @@ -60,10 +71,14 @@ class Node: handlers = {} self.handlers = handlers - def Consumer(self, channel=None, no_ack=True, accept=None, **options): + def Consumer(self, + channel: ChannelT = None, + no_ack: bool = True, + accept: Set[str] = None, + **options) -> ConsumerT: queue = self.mailbox.get_queue(self.hostname) - def verify_exclusive(name, messages, consumers): + def verify_exclusive(name: str, messages: int, consumers: int) -> None: if consumers: warnings.warn(W_PIDBOX_IN_USE.format(node=self)) queue.on_declared = verify_exclusive @@ -74,22 +89,24 @@ class Node: **options ) - def handler(self, fun): + def handler(self, fun: Callable) -> Callable: self.handlers[fun.__name__] = fun return fun - def on_decode_error(self, message, exc): + def on_decode_error(self, message: str, exc: Exception) -> None: error('Cannot decode message: %r', exc, exc_info=1) - def listen(self, channel=None, callback=None): + def listen(self, + channel: ChannelT = None, + callback: Callable = None) -> ConsumerT: consumer = self.Consumer(channel=channel, callbacks=[callback or self.handle_message], on_decode_error=self.on_decode_error) consumer.consume() return consumer - def dispatch(self, method, arguments=None, - reply_to=None, ticket=None, **kwargs): + def dispatch(self, method: str, arguments: Mapping = None, + reply_to: str = None, ticket: str = None, **kwargs) -> Any: arguments = arguments or {} debug('pidbox received method %s [reply_to:%s ticket:%s]', reprcall(method, (), kwargs=arguments), reply_to, ticket) @@ -109,16 +126,17 @@ class Node: ticket=ticket) return reply - def handle(self, method, arguments={}): 
+ def handle(self, method: str, arguments: Mapping = {}) -> Any: return self.handlers[method](self.state, **arguments) - def handle_call(self, method, arguments): + def handle_call(self, method: str, arguments: Mapping) -> Any: return self.handle(method, arguments) - def handle_cast(self, method, arguments): + def handle_cast(self, method: str, arguments: Mapping) -> Any: return self.handle(method, arguments) - def handle_message(self, body, message=None): + def handle_message(self, body: Any, message: MessageT = None) -> None: + body = cast(Mapping, body) destination = body.get('destination') if message: self.adjust_clock(message.headers.get('clock') or 0) @@ -126,7 +144,8 @@ class Node: return self.dispatch(**body) dispatch_from_message = handle_message - def reply(self, data, exchange, routing_key, ticket, **kwargs): + def reply(self, data: Any, exchange: str, routing_key: str, ticket: str, + **kwargs) -> None: self.mailbox._publish_reply(data, exchange, routing_key, ticket, channel=self.channel, serializer=self.mailbox.serializer) @@ -135,36 +154,42 @@ class Node: class Mailbox: """Process Mailbox.""" - node_cls = Node - exchange_fmt = '%s.pidbox' - reply_exchange_fmt = 'reply.%s.pidbox' + node_cls: type = Node + exchange_fmt: str = '%s.pidbox' + reply_exchange_fmt: str = 'reply.%s.pidbox' #: Name of application. - namespace = None + namespace: str = None #: Connection (if bound). - connection = None + connection: ClientT = None #: Exchange type (usually direct, or fanout for broadcast). - type = 'direct' + type: str = 'direct' #: mailbox exchange (init by constructor). - exchange = None + exchange: ExchangeT = None #: exchange to send replies to. - reply_exchange = None + reply_exchange: ExchangeT = None #: Only accepts json messages by default. 
- accept = ['json'] + accept: Set[str] = {'json'} #: Message serializer - serializer = None - - def __init__(self, namespace, - type='direct', connection=None, clock=None, - accept=None, serializer=None, producer_pool=None, - queue_ttl=None, queue_expires=None, - reply_queue_ttl=None, reply_queue_expires=10.0): + serializer: str = None + + def __init__(self, namespace: str, + type: str = 'direct', + connection: ClientT = None, + clock: Clock = None, + accept: Set[str] = None, + serializer: str = None, + producer_pool: ResourceT = None, + queue_ttl: float = None, + queue_expires: float = None, + reply_queue_ttl: float = None, + reply_queue_expires: float = 10.0) -> None: self.namespace = namespace self.connection = connection self.type = type @@ -181,36 +206,48 @@ class Mailbox: self.reply_queue_expires = reply_queue_expires self._producer_pool = producer_pool - def __call__(self, connection): + def __call__(self, connection: ClientT) -> 'Mailbox': bound = copy(self) bound.connection = connection return bound - def Node(self, hostname=None, state=None, channel=None, handlers=None): + def Node(self, + hostname: str = None, + state: Any = None, + channel: ChannelT = None, + handlers: Mapping[str, Callable] = None) -> Node: hostname = hostname or socket.gethostname() return self.node_cls(hostname, state, channel, handlers, mailbox=self) - def call(self, destination, command, kwargs={}, - timeout=None, callback=None, channel=None): + def call(self, destination: str, command: str, + kwargs: Mapping[str, Any] = {}, + timeout: float = None, + callback: Callable = None, + channel: ChannelT = None) -> Sequence[Mapping]: return self._broadcast(command, kwargs, destination, reply=True, timeout=timeout, callback=callback, channel=channel) - def cast(self, destination, command, kwargs={}): - return self._broadcast(command, kwargs, destination, reply=False) + def cast(self, destination: str, command: str, + kwargs: Mapping[str, Any] = {}) -> None: + self._broadcast(command, 
kwargs, destination, reply=False) - def abcast(self, command, kwargs={}): - return self._broadcast(command, kwargs, reply=False) + def abcast(self, command: str, kwargs: Mapping[str, Any] = {}) -> None: + self._broadcast(command, kwargs, reply=False) - def multi_call(self, command, kwargs={}, timeout=1, - limit=None, callback=None, channel=None): + def multi_call(self, command: str, + kwargs: Mapping[str, Any] = {}, + timeout: str = 1, + limit: int = None, + callback: Callable = None, + channel: ChannelT = None) -> Sequence[Mapping]: return self._broadcast(command, kwargs, reply=True, timeout=timeout, limit=limit, callback=callback, channel=channel) - def get_reply_queue(self): + def get_reply_queue(self) -> Queue: oid = self.oid return Queue( '%s.%s' % (oid, self.reply_exchange.name), @@ -223,10 +260,10 @@ class Mailbox: ) @cached_property - def reply_queue(self): + def reply_queue(self) -> Queue: return self.get_reply_queue() - def get_queue(self, hostname): + def get_queue(self, hostname: str) -> Queue: return Queue( '%s.%s.pidbox' % (hostname, self.namespace), exchange=self.exchange, @@ -237,7 +274,9 @@ class Mailbox: ) @contextmanager - def producer_or_acquire(self, producer=None, channel=None): + def producer_or_acquire(self, + producer: ProducerT = None, + channel: ChannelT = None) -> ProducerT: if producer: yield producer elif self.producer_pool: @@ -246,8 +285,11 @@ class Mailbox: else: yield Producer(channel, auto_declare=False) - def _publish_reply(self, reply, exchange, routing_key, ticket, - channel=None, producer=None, **opts): + def _publish_reply(self, reply: Any, + exchange: str, routing_key: str, ticket: str, + channel: ChannelT = None, + producer: ProducerT = None, + **opts) -> None: chan = channel or self.connection.default_channel exchange = Exchange(exchange, exchange_type='direct', delivery_mode='transient', @@ -265,9 +307,13 @@ class Mailbox: # queue probably deleted and no one is expecting a reply. 
pass - def _publish(self, type, arguments, destination=None, - reply_ticket=None, channel=None, timeout=None, - serializer=None, producer=None): + def _publish(self, type: str, arguments: Mapping[str, Any], + destination: str = None, + reply_ticket: str = None, + channel: ChannelT = None, + timeout: float = None, + serializer: str = None, + producer: ProducerT = None) -> None: message = {'method': type, 'arguments': arguments, 'destination': destination} @@ -287,9 +333,15 @@ class Mailbox: serializer=serializer, ) - def _broadcast(self, command, arguments=None, destination=None, - reply=False, timeout=1, limit=None, - callback=None, channel=None, serializer=None): + def _broadcast(self, command: str, + arguments: Mapping[str, Any] = None, + destination: str = None, + reply: bool = False, + timeout: float = 1.0, + limit: int = None, + callback: Callable = None, + channel: ChannelT = None, + serializer: str = None) -> Sequence[Mapping]: if destination is not None and \ not isinstance(destination, (list, tuple)): raise ValueError( @@ -317,9 +369,12 @@ class Mailbox: callback=callback, channel=chan) - def _collect(self, ticket, - limit=None, timeout=1, callback=None, - channel=None, accept=None): + def _collect(self, ticket: str, + limit: str = None, + timeout: float = 1.0, + callback: Callable = None, + channel: ChannelT = None, + accept: Set[str] = None) -> Sequence[Mapping]: if accept is None: accept = self.accept chan = channel or self.connection.default_channel @@ -334,7 +389,7 @@ class Mailbox: except KeyError: pass - def on_message(body, message): + def on_message(body: Any, message: MessageT) -> None: # ticket header added in kombu 2.5 header = message.headers.get adjust_clock(header('clock') or 0) @@ -361,20 +416,20 @@ class Mailbox: finally: chan.after_reply_message_received(queue.name) - def _get_exchange(self, namespace, type): + def _get_exchange(self, namespace: str, type: str) -> Exchange: return Exchange(self.exchange_fmt % namespace, type=type, 
durable=False, delivery_mode='transient') - def _get_reply_exchange(self, namespace): + def _get_reply_exchange(self, namespace: str) -> Exchange: return Exchange(self.reply_exchange_fmt % namespace, type='direct', durable=False, delivery_mode='transient') @cached_property - def oid(self): + def oid(self) -> str: try: return self._tls.OID except AttributeError: @@ -382,5 +437,5 @@ class Mailbox: return oid @cached_property - def producer_pool(self): + def producer_pool(self) -> ResourceT: return maybe_evaluate(self._producer_pool) diff --git a/kombu/pools.py b/kombu/pools.py index f9168886..88770212 100644 --- a/kombu/pools.py +++ b/kombu/pools.py @@ -1,23 +1,25 @@ """Public resource pools.""" import os - from itertools import chain - +from typing import Any from .connection import Resource from .messaging import Producer +from .types import ClientT, ProducerT, ResourceT from .utils.collections import EqualityDict from .utils.compat import register_after_fork from .utils.functional import lazy -__all__ = ['ProducerPool', 'PoolGroup', 'register_group', - 'connections', 'producers', 'get_limit', 'set_limit', 'reset'] +__all__ = [ + 'ProducerPool', 'PoolGroup', 'register_group', + 'connections', 'producers', 'get_limit', 'set_limit', 'reset', +] _limit = [10] _groups = [] -use_global_limit = object() +use_global_limit = 44444444444444444444444444 disable_limit_protection = os.environ.get('KOMBU_DISABLE_LIMIT_PROTECTION') -def _after_fork_cleanup_group(group): +def _after_fork_cleanup_group(group: 'PoolGroup') -> None: group.clear() @@ -27,15 +29,19 @@ class ProducerPool(Resource): Producer = Producer close_after_fork = True - def __init__(self, connections, *args, Producer=None, **kwargs): + def __init__(self, + connections: ResourceT, + *args, + Producer: type = None, + **kwargs) -> None: self.connections = connections self.Producer = Producer or self.Producer super().__init__(*args, **kwargs) - def _acquire_connection(self): + def _acquire_connection(self) -> 
ClientT: return self.connections.acquire(block=True) - def create_producer(self): + def create_producer(self) -> ProducerT: conn = self._acquire_connection() try: return self.Producer(conn) @@ -43,18 +49,18 @@ class ProducerPool(Resource): conn.release() raise - def new(self): + def new(self) -> lazy: return lazy(self.create_producer) - def setup(self): + def setup(self) -> None: if self.limit: for _ in range(self.limit): self._resource.put_nowait(self.new()) - def close_resource(self, resource): + def close_resource(self, resource: ProducerT) -> None: ... - def prepare(self, p): + def prepare(self, p: Any) -> None: if callable(p): p = p() if p._channel is None: @@ -66,7 +72,7 @@ class ProducerPool(Resource): raise return p - def release(self, resource): + def release(self, resource: ProducerT) -> None: if resource.__connection__: resource.__connection__.release() resource.channel = None @@ -76,24 +82,25 @@ class ProducerPool(Resource): class PoolGroup(EqualityDict): """Collection of resource pools.""" - def __init__(self, limit=None, close_after_fork=True): + def __init__(self, limit: int = None, + close_after_fork: bool = True) -> None: self.limit = limit self.close_after_fork = close_after_fork if self.close_after_fork and register_after_fork is not None: register_after_fork(self, _after_fork_cleanup_group) - def create(self, resource, limit): + def create(self, connection: ClientT, limit: int) -> Any: raise NotImplementedError('PoolGroups must define ``create``') - def __missing__(self, resource): + def __missing__(self, resource: Any) -> Any: limit = self.limit - if limit is use_global_limit: + if limit == use_global_limit: limit = get_limit() k = self[resource] = self.create(resource, limit) return k -def register_group(group): +def register_group(group: PoolGroup) -> PoolGroup: """Register group (can be used as decorator).""" _groups.append(group) return group @@ -102,7 +109,7 @@ def register_group(group): class Connections(PoolGroup): """Collection of 
connection pools.""" - def create(self, connection, limit): + def create(self, connection: ClientT, limit: int) -> Any: return connection.Pool(limit=limit) connections = register_group(Connections(limit=use_global_limit)) # noqa: E305 @@ -110,21 +117,24 @@ connections = register_group(Connections(limit=use_global_limit)) # noqa: E305 class Producers(PoolGroup): """Collection of producer pools.""" - def create(self, connection, limit): + def create(self, connection: ClientT, limit: int) -> Any: return ProducerPool(connections[connection], limit=limit) producers = register_group(Producers(limit=use_global_limit)) # noqa: E305 -def _all_pools(): +def _all_pools() -> chain[ResourceT]: return chain(*[(g.values() if g else iter([])) for g in _groups]) -def get_limit(): +def get_limit() -> int: """Get current connection pool limit.""" return _limit[0] -def set_limit(limit, force=False, reset_after=False, ignore_errors=False): +def set_limit(limit: int, + force: bool = False, + reset_after: bool = False, + ignore_errors: bool = False) -> int: """Set new connection pool limit.""" limit = limit or 0 glimit = _limit[0] or 0 @@ -135,7 +145,7 @@ def set_limit(limit, force=False, reset_after=False, ignore_errors=False): return limit -def reset(*args, **kwargs): +def reset(*args, **kwargs) -> None: """Reset all pools by closing open resources.""" for pool in _all_pools(): try: diff --git a/kombu/resource.py b/kombu/resource.py index d98f49e7..1829b466 100644 --- a/kombu/resource.py +++ b/kombu/resource.py @@ -1,16 +1,15 @@ """Generic resource pool implementation.""" import os - from collections import deque from queue import Empty, LifoQueue as _LifoQueue - +from typing import Any from . 
import exceptions -from .utils import abstract +from .types import ResourceT from .utils.compat import register_after_fork from .utils.functional import lazy -def _after_fork_cleanup_resource(resource): +def _after_fork_cleanup_resource(resource: ResourceT) -> None: try: resource.force_close_all() except Exception: @@ -20,19 +19,21 @@ def _after_fork_cleanup_resource(resource): class LifoQueue(_LifoQueue): """Last in first out version of Queue.""" - def _init(self, maxsize): + def _init(self, maxsize: int) -> None: self.queue = deque() -@abstract.Resource.register -class Resource: +class Resource(ResourceT): """Pool of resources.""" LimitExceeded = exceptions.LimitExceeded close_after_fork = False - def __init__(self, limit=None, preload=None, close_after_fork=None): + def __init__(self, + limit: int = None, + preload: int = None, + close_after_fork: bool = None): self._limit = limit self.preload = preload or 0 self._closed = False @@ -47,10 +48,10 @@ class Resource: register_after_fork(self, _after_fork_cleanup_resource) self.setup() - def setup(self): + def setup(self) -> None: raise NotImplementedError('subclass responsibility') - def _add_when_empty(self): + def _add_when_empty(self) -> None: if self.limit and len(self._dirty) >= self.limit: raise self.LimitExceeded(self.limit) # All taken, put new on the queue and @@ -58,7 +59,7 @@ class Resource: # will get the resource. self._resource.put_nowait(self.new()) - def acquire(self, block=False, timeout=None): + def acquire(self, block: bool = False, timeout: int = None): """Acquire resource. Arguments: @@ -94,7 +95,7 @@ class Resource: else: R = self.prepare(self.new()) - def release(): + def release() -> None: """Release resource so it can be used by another thread. 
Warnings: @@ -107,16 +108,16 @@ class Resource: return R - def prepare(self, resource): + def prepare(self, resource: Any) -> Any: return resource - def close_resource(self, resource): + def close_resource(self, resource: Any) -> None: resource.close() - def release_resource(self, resource): + def release_resource(self, resource: Any) -> None: ... - def replace(self, resource): + def replace(self, resource: Any) -> None: """Replace existing resource with a new instance. This can be used in case of defective resources. @@ -125,7 +126,7 @@ class Resource: self._dirty.discard(resource) self.close_resource(resource) - def release(self, resource): + def release(self, resource: Any) -> None: if self.limit: self._dirty.discard(resource) self._resource.put_nowait(resource) @@ -133,10 +134,10 @@ class Resource: else: self.close_resource(resource) - def collect_resource(self, resource): + def collect_resource(self, resource: Any) -> None: ... - def force_close_all(self): + def force_close_all(self) -> None: """Close and remove all resources in the pool (also those in use). Used to close resources from parent processes after fork @@ -169,7 +170,11 @@ class Resource: except AttributeError: pass # Issue #78 - def resize(self, limit, force=False, ignore_errors=False, reset=False): + def resize( + self, limit: int, + force: bool = False, + ignore_errors: bool = False, + reset: bool = False) -> None: prev_limit = self._limit if (self._dirty and limit < self._limit) and not ignore_errors: if not force: @@ -187,7 +192,7 @@ class Resource: if limit < prev_limit: self._shrink_down() - def _shrink_down(self): + def _shrink_down(self) -> None: resource = self._resource # Items to the left are last recently used, so we remove those first. 
with resource.mutex: @@ -195,11 +200,11 @@ class Resource: self.collect_resource(resource.queue.popleft()) @property - def limit(self): + def limit(self) -> int: return self._limit @limit.setter - def limit(self, limit): + def limit(self, limit: int) -> None: self.resize(limit) if os.environ.get('KOMBU_DEBUG_POOL'): # pragma: no cover @@ -208,7 +213,7 @@ class Resource: _next_resource_id = 0 - def acquire(self, *args, **kwargs): # noqa + def acquire(self, *args, **kwargs) -> Any: # noqa import traceback id = self._next_resource_id = self._next_resource_id + 1 print('+{0} ACQUIRE {1}'.format(id, self.__class__.__name__)) @@ -220,7 +225,7 @@ class Resource: r.acquired_by.append(traceback.format_stack()) return r - def release(self, resource): # noqa + def release(self, resource: Any) -> None: # noqa id = resource._resource_id print('+{0} RELEASE {1}'.format(id, self.__class__.__name__)) r = self._orig_release(resource) diff --git a/kombu/serialization.py b/kombu/serialization.py index 7f885d70..a8a9cc48 100644 --- a/kombu/serialization.py +++ b/kombu/serialization.py @@ -9,9 +9,9 @@ try: except ImportError: # pragma: no cover cpickle = None # noqa -from collections import namedtuple from contextlib import contextmanager from io import BytesIO +from typing import Callable, NamedTuple from .exceptions import ( ContentDisallowed, DecodeError, EncodeError, SerializerNotInstalled @@ -37,7 +37,13 @@ pickle_load = pickle.load #: There's a new protocol (3) but this is only supported by Python 3. 
pickle_protocol = int(os.environ.get('PICKLE_PROTOCOL', 2)) -codec = namedtuple('codec', ('content_type', 'content_encoding', 'encoder')) + +class codec(NamedTuple): + """Codec registration triple.""" + + content_type: str + content_encoding: str + encoder: Callable @contextmanager diff --git a/kombu/simple.py b/kombu/simple.py index 4facedf7..2436a28e 100644 --- a/kombu/simple.py +++ b/kombu/simple.py @@ -1,28 +1,32 @@ """Simple messaging interface.""" import socket - from collections import deque from time import monotonic from queue import Empty - +from typing import Any, Mapping from . import entity from . import messaging from .connection import maybe_channel +from .types import ChannelArgT, ConsumerT, ProducerT, MessageT, SimpleQueueT __all__ = ['SimpleQueue', 'SimpleBuffer'] -class SimpleBase: +class SimpleBase(SimpleQueueT): Empty = Empty - _consuming = False + _consuming: bool = False - def __enter__(self): + def __enter__(self) -> SimpleQueueT: return self - def __exit__(self, *exc_info): + def __exit__(self, *exc_info) -> None: self.close() - def __init__(self, channel, producer, consumer, no_ack=False): + def __init__(self, + channel: ChannelArgT, + producer: ProducerT, + consumer: ConsumerT, + no_ack: bool = False) -> None: self.channel = maybe_channel(channel) self.producer = producer self.consumer = consumer @@ -31,7 +35,7 @@ class SimpleBase: self.buffer = deque() self.consumer.register_callback(self._receive) - def get(self, block=True, timeout=None): + def get(self, block: bool = True, timeout: float = None) -> MessageT: if not block: return self.get_nowait() self._consume() @@ -49,14 +53,18 @@ class SimpleBase: elapsed += monotonic() - time_start remaining = timeout and timeout - elapsed or None - def get_nowait(self): + def get_nowait(self) -> MessageT: m = self.queue.get(no_ack=self.no_ack) if not m: raise self.Empty() return m - def put(self, message, serializer=None, headers=None, compression=None, - routing_key=None, **kwargs): + def 
put(self, message: Any, + serializer: str = None, + headers: Mapping = None, + compression: str = None, + routing_key: str = None, + **kwargs) -> None: self.producer.publish(message, serializer=serializer, routing_key=routing_key, @@ -64,61 +72,72 @@ class SimpleBase: compression=compression, **kwargs) - def clear(self): + def clear(self) -> int: return self.consumer.purge() - def qsize(self): + def qsize(self) -> int: _, size, _ = self.queue.queue_declare(passive=True) return size - def close(self): + def close(self) -> None: self.consumer.cancel() - def _receive(self, message_data, message): + def _receive(self, message_data: Any, message: MessageT) -> None: self.buffer.append(message) - def _consume(self): + def _consume(self) -> None: if not self._consuming: self.consumer.consume(no_ack=self.no_ack) self._consuming = True - def __len__(self): + def __len__(self) -> int: """`len(self) -> self.qsize()`.""" return self.qsize() - def __bool__(self): + def __bool__(self) -> bool: return True class SimpleQueue(SimpleBase): """Simple API for persistent queues.""" - no_ack = False - queue_opts = {} - exchange_opts = {'type': 'direct'} - - def __init__(self, channel, name, no_ack=None, queue_opts=None, - exchange_opts=None, serializer=None, - compression=None, **kwargs): + no_ack: bool = False + queue_opts: Mapping = {} + exchange_opts: Mapping = {'type': 'direct'} + + def __init__( + self, + channel: ChannelArgT, + name: str, + no_ack: bool = None, + queue_opts: Mapping = None, + exchange_opts: Mapping = None, + serializer: str = None, + compression: str = None, + **kwargs) -> None: queue = name - queue_opts = dict(self.queue_opts, **queue_opts or {}) - exchange_opts = dict(self.exchange_opts, **exchange_opts or {}) + all_queue_opts = dict(self.queue_opts) + if queue_opts: + all_queue_opts.update(queue_opts) + all_exchange_opts = dict(self.exchange_opts) + if exchange_opts: + all_exchange_opts.update(exchange_opts) if no_ack is None: no_ack = self.no_ack if not 
isinstance(queue, entity.Queue): - exchange = entity.Exchange(name, **exchange_opts) - queue = entity.Queue(name, exchange, name, **queue_opts) + exchange = entity.Exchange(name, **all_exchange_opts) + queue = entity.Queue(name, exchange, name, **all_queue_opts) routing_key = name else: name = queue.name exchange = queue.exchange routing_key = queue.routing_key - consumer = messaging.Consumer(channel, queue) + consumer = messaging.Consumer(channel, [queue]) producer = messaging.Producer(channel, exchange, serializer=serializer, routing_key=routing_key, compression=compression) - super().__init__(channel, producer, consumer, no_ack, **kwargs) + super().__init__(channel, producer, consumer, no_ack) class SimpleBuffer(SimpleQueue): diff --git a/kombu/transport/base.py b/kombu/transport/base.py index 0b1f3f35..059de82f 100644 --- a/kombu/transport/base.py +++ b/kombu/transport/base.py @@ -1,20 +1,19 @@ """Base transport interface.""" -import amqp.abstract +import amqp.types import errno import socket - +from typing import Any, Callable, ChannelT, Dict, Mapping, Sequence, Tuple from amqp.exceptions import RecoverableConnectionError - from kombu.exceptions import ChannelError, ConnectionError -from kombu.five import items from kombu.message import Message +from kombu.types import ClientT, ConnectionT, ConsumerT, ProducerT, TransportT from kombu.utils.functional import dictfilter from kombu.utils.objects import cached_property from kombu.utils.time import maybe_s_to_ms __all__ = ['Message', 'StdChannel', 'Management', 'Transport'] -RABBITMQ_QUEUE_ARGUMENTS = { # type: Mapping[str, Tuple[str, Callable]] +RABBITMQ_QUEUE_ARGUMENTS: Mapping[str, Tuple[str, Callable]] = { 'expires': ('x-expires', maybe_s_to_ms), 'message_ttl': ('x-message-ttl', maybe_s_to_ms), 'max_length': ('x-max-length', int), @@ -23,8 +22,7 @@ RABBITMQ_QUEUE_ARGUMENTS = { # type: Mapping[str, Tuple[str, Callable]] } -def to_rabbitmq_queue_arguments(arguments, **options): - # type: (Mapping, **Any) -> Dict 
+def to_rabbitmq_queue_arguments(arguments: Mapping, **options) -> Dict: """Convert queue arguments to RabbitMQ queue arguments. This is the implementation for Channel.prepare_queue_arguments @@ -52,40 +50,39 @@ def to_rabbitmq_queue_arguments(arguments, **options): """ prepared = dictfilter(dict( _to_rabbitmq_queue_argument(key, value) - for key, value in items(options) + for key, value in options.items() )) return dict(arguments, **prepared) if prepared else arguments -def _to_rabbitmq_queue_argument(key, value): - # type: (str, Any) -> Tuple[str, Any] +def _to_rabbitmq_queue_argument(key: str, value: Any) -> Tuple[str, Any]: opt, typ = RABBITMQ_QUEUE_ARGUMENTS[key] return opt, typ(value) if value is not None else value -def _LeftBlank(obj, method): +def _LeftBlank(obj: Any, method: str) -> Exception: return NotImplementedError( 'Transport {0.__module__}.{0.__name__} does not implement {1}'.format( obj.__class__, method)) -class StdChannel: +class StdChannel(amqp.types.ChannelT): """Standard channel base class.""" no_ack_consumers = None - def Consumer(self, *args, **kwargs): + def Consumer(self, *args, **kwargs) -> ConsumerT: from kombu.messaging import Consumer return Consumer(self, *args, **kwargs) - def Producer(self, *args, **kwargs): + def Producer(self, *args, **kwargs) -> ProducerT: from kombu.messaging import Producer return Producer(self, *args, **kwargs) - def get_bindings(self): + async def get_bindings(self) -> Sequence[Mapping]: raise _LeftBlank(self, 'get_bindings') - def after_reply_message_received(self, queue): + async def after_reply_message_received(self, queue: str) -> None: """Callback called after RPC reply received. Notes: @@ -94,42 +91,39 @@ class StdChannel: """ ... 
- def prepare_queue_arguments(self, arguments, **kwargs): + def prepare_queue_arguments(self, arguments: Mapping, **kwargs) -> Mapping: return arguments - def __enter__(self): + def __enter__(self) -> ChannelT: return self - def __exit__(self, *exc_info): + def __exit__(self, *exc_info) -> None: self.close() -amqp.abstract.Channel.register(StdChannel) - - class Management: """AMQP Management API (incomplete).""" - def __init__(self, transport): + def __init__(self, transport: TransportT): self.transport = transport - def get_bindings(self): + async def get_bindings(self) -> Sequence[Mapping]: raise _LeftBlank(self, 'get_bindings') class Implements(dict): """Helper class used to define transport features.""" - def __getattr__(self, key): + def __getattr__(self, key: str) -> bool: try: return self[key] except KeyError: raise AttributeError(key) - def __setattr__(self, key, value): + def __setattr__(self, key: str, value: bool) -> None: self[key] = value - def extend(self, **kwargs): + def extend(self, **kwargs) -> 'Implements': return self.__class__(self, **kwargs) @@ -140,65 +134,66 @@ default_transport_capabilities = Implements( ) -class Transport: +class Transport(amqp.types.ConnectionT): """Base class for transports.""" - Management = Management + Management: type = Management #: The :class:`~kombu.Connection` owning this instance. - client = None + client: ClientT = None #: Set to True if :class:`~kombu.Connection` should pass the URL #: unmodified. - can_parse_url = False + can_parse_url: bool = False #: Default port used when no port has been specified. - default_port = None + default_port: int = None #: Tuple of errors that can happen due to connection failure. - connection_errors = (ConnectionError,) + connection_errors: Tuple[type, ...] = (ConnectionError,) #: Tuple of errors that can happen due to channel/method failure. - channel_errors = (ChannelError,) + channel_errors: Tuple[type, ...] 
= (ChannelError,) #: Type of driver, can be used to separate transports #: using the AMQP protocol (driver_type: 'amqp'), #: Redis (driver_type: 'redis'), etc... - driver_type = 'N/A' + driver_type: str = 'N/A' #: Name of driver library (e.g. 'py-amqp', 'redis'). - driver_name = 'N/A' + driver_name: str = 'N/A' __reader = None implements = default_transport_capabilities.extend() - def __init__(self, client, **kwargs): + def __init__(self, client: ClientT, **kwargs) -> None: self.client = client - def establish_connection(self): + async def establish_connection(self) -> ConnectionT: raise _LeftBlank(self, 'establish_connection') - def close_connection(self, connection): + async def close_connection(self, connection: ConnectionT) -> None: raise _LeftBlank(self, 'close_connection') - def create_channel(self, connection): + def create_channel(self, connection: ConnectionT) -> ChannelT: raise _LeftBlank(self, 'create_channel') - def close_channel(self, connection): + async def close_channel(self, connection: ConnectionT) -> None: raise _LeftBlank(self, 'close_channel') - def drain_events(self, connection, **kwargs): + async def drain_events(self, connection: ConnectionT, **kwargs) -> None: raise _LeftBlank(self, 'drain_events') - def heartbeat_check(self, connection, rate=2): + async def heartbeat_check(self, connection: ConnectionT, + rate: int = 2) -> None: ... - def driver_version(self): + def driver_version(self) -> str: return 'N/A' - def get_heartbeat_interval(self, connection): - return 0 + def get_heartbeat_interval(self, connection: ConnectionT) -> float: + return 0.0 def register_with_event_loop(self, connection, loop): ... @@ -206,7 +201,7 @@ class Transport: def unregister_from_event_loop(self, connection, loop): ... 
- def verify_connection(self, connection): + def verify_connection(self, connection: ConnectionT) -> bool: return True def _make_reader(self, connection, timeout=socket.timeout, @@ -228,7 +223,7 @@ class Transport: return _read - def qos_semantics_matches_spec(self, connection): + def qos_semantics_matches_spec(self, connection: ConnectionT) -> bool: return True def on_readable(self, connection, loop): @@ -238,20 +233,20 @@ class Transport: reader(loop) @property - def default_connection_params(self): + def default_connection_params(self) -> Mapping: return {} - def get_manager(self, *args, **kwargs): + def get_manager(self, *args, **kwargs) -> Management: return self.Management(self) @cached_property - def manager(self): + def manager(self) -> Management: return self.get_manager() @property - def supports_heartbeats(self): + def supports_heartbeats(self) -> bool: return self.implements.heartbeats @property - def supports_ev(self): + def supports_ev(self) -> bool: return self.implements.async diff --git a/kombu/transport/pyamqp.py b/kombu/transport/pyamqp.py index 9ba7d8b9..b00c63a9 100644 --- a/kombu/transport/pyamqp.py +++ b/kombu/transport/pyamqp.py @@ -96,14 +96,14 @@ class Transport(base.Transport): def create_channel(self, connection): return connection.channel() - def drain_events(self, connection, **kwargs): - return connection.drain_events(**kwargs) + async def drain_events(self, connection, **kwargs): + await connection.drain_events(**kwargs) def _collect(self, connection): if connection is not None: connection.collect() - def establish_connection(self): + async def establish_connection(self): """Establish connection to the AMQP broker.""" conninfo = self.client for name, default_value in self.default_connection_params.items(): @@ -124,16 +124,16 @@ class Transport(base.Transport): }, **conninfo.transport_options or {}) conn = self.Connection(**opts) conn.client = self.client - conn.connect() + await conn.connect() return conn def verify_connection(self, 
connection): return connection.connected - def close_connection(self, connection): + async def close_connection(self, connection): """Close the AMQP broker connection.""" connection.client = None - connection.close() + await connection.close() def get_heartbeat_interval(self, connection): return connection.heartbeat @@ -142,8 +142,8 @@ class Transport(base.Transport): connection.transport.raise_on_initial_eintr = True loop.add_reader(connection.sock, self.on_readable, connection, loop) - def heartbeat_check(self, connection, rate=2): - return connection.heartbeat_tick(rate=rate) + async def heartbeat_check(self, connection, rate=2): + await connection.heartbeat_tick(rate=rate) def qos_semantics_matches_spec(self, connection): props = connection.server_properties diff --git a/kombu/transport/virtual/__init__.py b/kombu/transport/virtual/__init__.py index 6960104c..b3bf63f3 100644 --- a/kombu/transport/virtual/__init__.py +++ b/kombu/transport/virtual/__init__.py @@ -1,5 +1,3 @@ -from __future__ import absolute_import, unicode_literals - from .base import ( Base64, NotEquivalentError, UndeliverableWarning, BrokerState, QoS, Message, AbstractChannel, Channel, Management, Transport, diff --git a/kombu/transport/virtual/base.py b/kombu/transport/virtual/base.py index 75bca055..052a620d 100644 --- a/kombu/transport/virtual/base.py +++ b/kombu/transport/virtual/base.py @@ -2,24 +2,30 @@ Emulates the AMQ API for non-AMQ transports. 
""" -from __future__ import absolute_import, print_function, unicode_literals - +import abc import base64 import socket import sys import warnings from array import array -from collections import OrderedDict, defaultdict, namedtuple +from asyncio import sleep +from collections import OrderedDict, defaultdict from itertools import count from multiprocessing.util import Finalize -from time import sleep +from queue import Empty +from time import monotonic +from typing import ( + Any, AnyStr, Callable, IO, Iterable, List, Mapping, MutableMapping, + NamedTuple, Optional, Set, Sequence, Tuple, +) from amqp.protocol import queue_declare_ok_t +from amqp.types import ChannelT, ConnectionT from kombu.exceptions import ResourceError, ChannelError -from kombu.five import Empty, items, monotonic from kombu.log import get_logger +from kombu.types import ClientT, MessageT, TransportT from kombu.utils.encoding import str_to_bytes, bytes_to_str from kombu.utils.div import emergency_dump_state from kombu.utils.scheduling import FairCycle @@ -27,7 +33,7 @@ from kombu.utils.uuid import uuid from kombu.transport import base -from .exchange import STANDARD_EXCHANGE_TYPES +from .exchange import STANDARD_EXCHANGE_TYPES, ExchangeType ARRAY_TYPE_H = 'H' if sys.version_info[0] == 3 else b'H' @@ -50,24 +56,41 @@ RESTORE_PANIC_FMT = 'UNABLE TO RESTORE {0} MESSAGES: {1}' logger = get_logger(__name__) -#: Key format used for queue argument lookups in BrokerState.bindings. -binding_key_t = namedtuple('binding_key_t', ( - 'queue', 'exchange', 'routing_key', -)) -#: BrokerState.queue_bindings generates tuples in this format. 
-queue_binding_t = namedtuple('queue_binding_t', ( - 'exchange', 'routing_key', 'arguments', -)) +class binding_key_t(NamedTuple): + """Key format used for queue argument lookups in BrokerState.bindings.""" + + queue: str + exchange: str + routing_key: str + + +class queue_binding_t(NamedTuple): + """BrokerState.queue_bindings generates tuples in this format.""" + + exchange: str + routing_key: str + arguments: Mapping -class Base64(object): +class Codec(metaclass=abc.ABCMeta): + + @abc.abstractmethod + def encode(self, s: AnyStr) -> str: + ... + + @abc.abstractmethod + def decode(self, s: AnyStr) -> str: + ... + + +class Base64(Codec): """Base64 codec.""" - def encode(self, s): + def encode(self, s: AnyStr) -> str: return bytes_to_str(base64.b64encode(str_to_bytes(s))) - def decode(self, s): + def decode(self, s: AnyStr) -> str: return base64.b64decode(str_to_bytes(s)) @@ -84,7 +107,7 @@ class BrokerState(object): #: Mapping of exchange name to #: :class:`kombu.transport.virtual.exchange.ExchangeType` - exchanges = None + exchanges: Mapping[str, ExchangeType] = None #: This is the actual bindings registry, used to store bindings and to #: test 'in' relationships in constant time. It has the following @@ -94,7 +117,7 @@ class BrokerState(object): #: (queue, exchange, routing_key): arguments, #: # ..., #: } - bindings = None + bindings: Mapping[binding_key_t, Mapping] = None #: The queue index is used to access directly (constant time) #: all the bindings of a certain queue. 
It has the following structure:: @@ -105,27 +128,32 @@ class BrokerState(object): #: }, #: # ..., #: } - queue_index = None + queue_index: Mapping[str, Set[binding_key_t]] = None - def __init__(self, exchanges=None): + def __init__(self, exchanges: Mapping[str, ExchangeType] = None) -> None: self.exchanges = {} if exchanges is None else exchanges self.bindings = {} self.queue_index = defaultdict(set) - def clear(self): + def clear(self) -> None: self.exchanges.clear() self.bindings.clear() self.queue_index.clear() - def has_binding(self, queue, exchange, routing_key): + def has_binding(self, queue: str, exchange: str, routing_key: str) -> bool: return (queue, exchange, routing_key) in self.bindings - def binding_declare(self, queue, exchange, routing_key, arguments): + def binding_declare(self, + queue: str, + exchange: str, + routing_key: str, + arguments: Mapping) -> None: key = binding_key_t(queue, exchange, routing_key) self.bindings.setdefault(key, arguments) self.queue_index[queue].add(key) - def binding_delete(self, queue, exchange, routing_key): + def binding_delete( + self, queue: str, exchange: str, routing_key: str) -> None: key = binding_key_t(queue, exchange, routing_key) try: del self.bindings[key] @@ -134,7 +162,7 @@ class BrokerState(object): else: self.queue_index[queue].remove(key) - def queue_bindings_delete(self, queue): + def queue_bindings_delete(self, queue: str) -> None: try: bindings = self.queue_index.pop(queue) except KeyError: @@ -142,7 +170,7 @@ class BrokerState(object): else: [self.bindings.pop(binding, None) for binding in bindings] - def queue_bindings(self, queue): + def queue_bindings(self, queue: str) -> Iterable[queue_binding_t]: return ( queue_binding_t(key.exchange, key.routing_key, self.bindings[key]) for key in self.queue_index[queue] @@ -160,22 +188,22 @@ class QoS(object): """ #: current prefetch count value - prefetch_count = 0 + prefetch_count: int = 0 #: :class:`~collections.OrderedDict` of active messages. 
#: *NOTE*: Can only be modified by the consuming thread. - _delivered = None + _delivered: OrderedDict = None #: acks can be done by other threads than the consuming thread. #: Instead of a mutex, which doesn't perform well here, we mark #: the delivery tags as dirty, so subsequent calls to append() can remove #: them. - _dirty = None + _dirty: Set[str] = None #: If disabled, unacked messages won't be restored at shutdown. - restore_at_shutdown = True + restore_at_shutdown: bool = True - def __init__(self, channel, prefetch_count=0): + def __init__(self, channel: ChannelT, prefetch_count: int = 0) -> None: self.channel = channel self.prefetch_count = prefetch_count or 0 @@ -188,7 +216,7 @@ class QoS(object): self, self.restore_unacked_once, exitpriority=1, ) - def can_consume(self): + def can_consume(self) -> bool: """Return true if the channel can be consumed from. Used to ensure the client adhers to currently active @@ -197,7 +225,7 @@ class QoS(object): pcount = self.prefetch_count return not pcount or len(self._delivered) - len(self._dirty) < pcount - def can_consume_max_estimate(self): + def can_consume_max_estimate(self) -> Optional[int]: """Return the maximum number of messages allowed to be returned. 
Returns an estimated number of messages that a consumer may be allowed @@ -212,16 +240,16 @@ class QoS(object): if pcount: return max(pcount - (len(self._delivered) - len(self._dirty)), 0) - def append(self, message, delivery_tag): + def append(self, message: MessageT, delivery_tag: str) -> None: """Append message to transactional state.""" if self._dirty: self._flush() self._quick_append(delivery_tag, message) - def get(self, delivery_tag): + def get(self, delivery_tag: str) -> MessageT: return self._delivered[delivery_tag] - def _flush(self): + def _flush(self) -> None: """Flush dirty (acked/rejected) tags from.""" dirty = self._dirty delivered = self._delivered @@ -232,17 +260,17 @@ class QoS(object): break delivered.pop(dirty_tag, None) - def ack(self, delivery_tag): + def ack(self, delivery_tag: str) -> None: """Acknowledge message and remove from transactional state.""" self._quick_ack(delivery_tag) - def reject(self, delivery_tag, requeue=False): + def reject(self, delivery_tag: str, requeue: bool = False) -> None: """Remove from transactional state and requeue message.""" if requeue: self.channel._restore_at_beginning(self._delivered[delivery_tag]) self._quick_ack(delivery_tag) - def restore_unacked(self): + async def restore_unacked(self) -> List[Tuple[BaseException, MessageT]]: """Restore all unacknowledged messages.""" self._flush() delivered = self._delivered @@ -257,13 +285,13 @@ class QoS(object): break try: - restore(message) + await restore(message) except BaseException as exc: errors.append((exc, message)) delivered.clear() return errors - def restore_unacked_once(self, stderr=None): + async def restore_unacked_once(self, stderr: IO = None) -> None: """Restore all unacknowledged messages at shutdown/gc collect. 
Note: @@ -284,7 +312,7 @@ class QoS(object): if state: print(RESTORING_FMT.format(len(self._delivered)), file=stderr) - unrestored = self.restore_unacked() + unrestored = await self.restore_unacked() if unrestored: errors, messages = list(zip(*unrestored)) @@ -294,7 +322,7 @@ class QoS(object): finally: state.restored = True - def restore_visible(self, *args, **kwargs): + async def restore_visible(self, *args, **kwargs) -> None: """Restore any pending unackwnowledged messages. To be filled in for visibility_timeout style implementations. @@ -303,13 +331,14 @@ class QoS(object): This is implementation optional, and currently only used by the Redis transport. """ - pass + ... class Message(base.Message): """Message object.""" - def __init__(self, payload, channel=None, **kwargs): + def __init__(self, payload: Any, + channel: ChannelT = None, **kwargs) -> None: self._raw = payload properties = payload['properties'] body = payload.get('body') @@ -327,7 +356,7 @@ class Message(base.Message): postencode='utf-8', **kwargs) - def serializable(self): + def serializable(self) -> Mapping: props = self.properties body, _ = self.channel.encode_body(self.body, props.get('body_encoding')) @@ -354,23 +383,23 @@ class AbstractChannel(object): from :class:`Channel`. 
""" - def _get(self, queue, timeout=None): + async def _get(self, queue: str, timeout: float = None) -> MessageT: """Get next message from `queue`.""" raise NotImplementedError('Virtual channels must implement _get') - def _put(self, queue, message): + async def _put(self, queue: str, message: MessageT) -> None: """Put `message` onto `queue`.""" raise NotImplementedError('Virtual channels must implement _put') - def _purge(self, queue): + async def _purge(self, queue: str) -> int: """Remove all messages from `queue`.""" raise NotImplementedError('Virtual channels must implement _purge') - def _size(self, queue): + async def _size(self, queue: str) -> int: """Return the number of messages in `queue` as an :class:`int`.""" return 0 - def _delete(self, queue, *args, **kwargs): + async def _delete(self, queue: str, *args, **kwargs) -> None: """Delete `queue`. Note: @@ -379,16 +408,16 @@ class AbstractChannel(object): """ self._purge(queue) - def _new_queue(self, queue, **kwargs): + async def _new_queue(self, queue: str, **kwargs) -> None: """Create new queue. Note: Your transport can override this method if it needs to do something whenever a new queue is declared. """ - pass + ... - def _has_queue(self, queue, **kwargs): + async def _has_queue(self, queue: str, **kwargs) -> bool: """Verify that queue exists. 
Returns: @@ -397,13 +426,14 @@ class AbstractChannel(object): """ return True - def _poll(self, cycle, callback, timeout=None): + async def _poll(self, cycle: Any, callback: Callable, + timeout: float = None) -> Any: """Poll a list of queues for available messages.""" - return cycle.get(callback) + return await cycle.get(callback) - def _get_and_deliver(self, queue, callback): - message = self._get(queue) - callback(message, queue) + async def _get_and_deliver(self, queue: str, callback: Callable) -> None: + message = await self._get(queue) + await callback(message, queue) class Channel(AbstractChannel, base.StdChannel): @@ -415,44 +445,56 @@ class Channel(AbstractChannel, base.StdChannel): """ #: message class used. - Message = Message + Message: type = Message #: QoS class used. - QoS = QoS + QoS: type = QoS #: flag to restore unacked messages when channel #: goes out of scope. - do_restore = True + do_restore: bool = True #: mapping of exchange types and corresponding classes. - exchange_types = dict(STANDARD_EXCHANGE_TYPES) + exchange_type_classes: Mapping[str, type] = dict( + STANDARD_EXCHANGE_TYPES) + + exchange_types: Mapping[str, ExchangeType] = None #: flag set if the channel supports fanout exchanges. - supports_fanout = False + supports_fanout: bool = False #: Binary <-> ASCII codecs. - codecs = {'base64': Base64()} + codecs: Mapping[str, Codec] = {'base64': Base64()} #: Default body encoding. #: NOTE: ``transport_options['body_encoding']`` will override this value. - body_encoding = 'base64' + body_encoding: str = 'base64' #: counter used to generate delivery tags for this channel. _delivery_tags = count(1) #: Optional queue where messages with no route is delivered. #: Set by ``transport_options['deadletter_queue']``. - deadletter_queue = None + deadletter_queue: str = None # List of options to transfer from :attr:`transport_options`. - from_transport_options = ('body_encoding', 'deadletter_queue') + from_transport_options: Tuple[str, ...] 
= ( + 'body_encoding', + 'deadletter_queue', + ) # Priority defaults default_priority = 0 min_priority = 0 max_priority = 9 - def __init__(self, connection, **kwargs): + _consumers: Set[str] + _tag_to_queue: Mapping[str, str] + _active_queues: List[str] + closed: bool = False + _qos: QoS = None + + def __init__(self, connection: ConnectionT, **kwargs) -> None: self.connection = connection self._consumers = set() self._cycle = None @@ -462,9 +504,9 @@ class Channel(AbstractChannel, base.StdChannel): self.closed = False # instantiate exchange types - self.exchange_types = dict( - (typ, cls(self)) for typ, cls in items(self.exchange_types) - ) + self.exchange_types = { + typ: cls(self) for typ, cls in self.exchange_type_classes.items() + } try: self.channel_id = self.connection._avail_channel_ids.pop() @@ -482,9 +524,15 @@ class Channel(AbstractChannel, base.StdChannel): except KeyError: pass - def exchange_declare(self, exchange=None, type='direct', durable=False, - auto_delete=False, arguments=None, - nowait=False, passive=False): + async def exchange_declare( + self, + exchange: str = None, + type: str = 'direct', + durable: bool = False, + auto_delete: bool = False, + arguments: Mapping = None, + nowait: bool = False, + passive: bool = False) -> None: """Declare exchange.""" type = type or 'direct' exchange = exchange or 'amq.%s' % type @@ -512,49 +560,74 @@ class Channel(AbstractChannel, base.StdChannel): 'table': [], } - def exchange_delete(self, exchange, if_unused=False, nowait=False): + async def exchange_delete( + self, + exchange: str, + if_unused: bool = False, + nowait: bool = False) -> None: """Delete `exchange` and all its bindings.""" for rkey, _, queue in self.get_table(exchange): - self.queue_delete(queue, if_unused=True, if_empty=True) + await self.queue_delete(queue, if_unused=True, if_empty=True) self.state.exchanges.pop(exchange, None) - def queue_declare(self, queue=None, passive=False, **kwargs): + async def queue_declare( + self, + queue: str 
= None, + passive: bool = False, + **kwargs) -> queue_declare_ok_t: """Declare queue.""" queue = queue or 'amq.gen-%s' % uuid() - if passive and not self._has_queue(queue, **kwargs): + if passive and not await self._has_queue(queue, **kwargs): raise ChannelError( 'NOT_FOUND - no queue {0!r} in vhost {1!r}'.format( queue, self.connection.client.virtual_host or '/'), (50, 10), 'Channel.queue_declare', '404', ) else: - self._new_queue(queue, **kwargs) - return queue_declare_ok_t(queue, self._size(queue), 0) - - def queue_delete(self, queue, if_unused=False, if_empty=False, **kwargs): + await self._new_queue(queue, **kwargs) + return queue_declare_ok_t(queue, await self._size(queue), 0) + + async def queue_delete( + self, + queue: str, + if_unused: bool = False, + if_empty: bool = False, + **kwargs) -> None: """Delete queue.""" - if if_empty and self._size(queue): + if if_empty and await self._size(queue): return for exchange, routing_key, args in self.state.queue_bindings(queue): meta = self.typeof(exchange).prepare_bind( queue, exchange, routing_key, args, ) - self._delete(queue, exchange, *meta, **kwargs) + await self._delete(queue, exchange, *meta, **kwargs) self.state.queue_bindings_delete(queue) - def after_reply_message_received(self, queue): - self.queue_delete(queue) + async def after_reply_message_received(self, queue: str) -> None: + await self.queue_delete(queue) - def exchange_bind(self, destination, source='', routing_key='', - nowait=False, arguments=None): + async def exchange_bind( + self, destination: str, + source: str = '', + routing_key: str = '', + nowait: bool = False, + arguments: Mapping = None) -> None: raise NotImplementedError('transport does not support exchange_bind') - def exchange_unbind(self, destination, source='', routing_key='', - nowait=False, arguments=None): + async def exchange_unbind( + self, destination: str, + source: str = '', + routing_key: str = '', + nowait: bool = False, + arguments: Mapping = None) -> None: raise 
NotImplementedError('transport does not support exchange_unbind') - def queue_bind(self, queue, exchange=None, routing_key='', - arguments=None, **kwargs): + async def queue_bind( + self, queue: str, + exchange: str = None, + routing_key: str = '', + arguments: Mapping = None, + **kwargs) -> None: """Bind `queue` to `exchange` with `routing key`.""" exchange = exchange or 'amq.direct' if self.state.has_binding(queue, exchange, routing_key): @@ -568,10 +641,14 @@ class Channel(AbstractChannel, base.StdChannel): ) table.append(meta) if self.supports_fanout: - self._queue_bind(exchange, *meta) - - def queue_unbind(self, queue, exchange=None, routing_key='', - arguments=None, **kwargs): + await self._queue_bind(exchange, *meta) + + async def queue_unbind( + self, queue: str, + exchange: str = None, + routing_key: str = '', + arguments: Mapping = None, + **kwargs) -> None: # Remove queue binding: self.state.binding_delete(queue, exchange, routing_key) try: @@ -585,19 +662,21 @@ class Channel(AbstractChannel, base.StdChannel): # Should be optimized. Modifying table in place. 
table[:] = [meta for meta in table if meta != binding_meta] - def list_bindings(self): - return ((queue, exchange, rkey) + async def list_bindings(self) -> Iterable[queue_binding_t]: + return (queue_binding_t(queue, exchange, rkey) for exchange in self.state.exchanges for rkey, pattern, queue in self.get_table(exchange)) - def queue_purge(self, queue, **kwargs): + async def queue_purge(self, queue: str, **kwargs) -> int: """Remove all ready messages from queue.""" - return self._purge(queue) + return await self._purge(queue) - def _next_delivery_tag(self): + def _next_delivery_tag(self) -> str: return uuid() - def basic_publish(self, message, exchange, routing_key, **kwargs): + async def basic_publish( + self, message: MessageT, exchange: str, routing_key: str, + **kwargs) -> None: """Publish message.""" self._inplace_augment_message(message, exchange, routing_key) if exchange: @@ -605,9 +684,10 @@ class Channel(AbstractChannel, base.StdChannel): message, exchange, routing_key, **kwargs ) # anon exchange: routing_key is the destination queue - return self._put(routing_key, message, **kwargs) + return await self._put(routing_key, message, **kwargs) - def _inplace_augment_message(self, message, exchange, routing_key): + def _inplace_augment_message( + self, message: MessageT, exchange: str, routing_key: str) -> None: message['body'], body_encoding = self.encode_body( message['body'], self.body_encoding, ) @@ -621,23 +701,28 @@ class Channel(AbstractChannel, base.StdChannel): routing_key=routing_key, ) - def basic_consume(self, queue, no_ack, callback, consumer_tag, **kwargs): + async def basic_consume( + self, queue: str, + no_ack: bool = False, + callback: Callable = None, + consumer_tag: str = None, + **kwargs) -> None: """Consume from `queue`.""" self._tag_to_queue[consumer_tag] = queue self._active_queues.append(queue) - def _callback(raw_message): + async def _callback(raw_message): message = self.Message(raw_message, channel=self) if not no_ack: 
self.qos.append(message, message.delivery_tag) - return callback(message) + return await callback(message) self.connection._callbacks[queue] = _callback self._consumers.add(consumer_tag) self._reset_cycle() - def basic_cancel(self, consumer_tag): + async def basic_cancel(self, consumer_tag: str) -> None: """Cancel consumer by consumer tag.""" if consumer_tag in self._consumers: self._consumers.remove(consumer_tag) @@ -649,32 +734,38 @@ class Channel(AbstractChannel, base.StdChannel): pass self.connection._callbacks.pop(queue, None) - def basic_get(self, queue, no_ack=False, **kwargs): + async def basic_get( + self, queue: str, + no_ack: bool = False, + **kwargs) -> Optional[MessageT]: """Get message by direct access (synchronous).""" try: - message = self.Message(self._get(queue), channel=self) + message = self.Message(await self._get(queue), channel=self) if not no_ack: self.qos.append(message, message.delivery_tag) return message except Empty: pass - def basic_ack(self, delivery_tag, multiple=False): + async def basic_ack(self, delivery_tag: str, multiple: bool = False) -> None: """Acknowledge message.""" self.qos.ack(delivery_tag) - def basic_recover(self, requeue=False): + async def basic_recover(self, requeue: bool = False) -> None: """Recover unacked messages.""" if requeue: - return self.qos.restore_unacked() + return await self.qos.restore_unacked() raise NotImplementedError('Does not support recover(requeue=False)') - def basic_reject(self, delivery_tag, requeue=False): + async def basic_reject(self, delivery_tag: str, requeue: bool = False) -> None: """Reject message.""" self.qos.reject(delivery_tag, requeue=requeue) - def basic_qos(self, prefetch_size=0, prefetch_count=0, - apply_global=False): + async def basic_qos( + self, + prefetch_size: int = 0, + prefetch_count: int = 0, + apply_global: bool = False) -> None: """Change QoS settings for this channel. 
Note: @@ -682,14 +773,14 @@ class Channel(AbstractChannel, base.StdChannel): """ self.qos.prefetch_count = prefetch_count - def get_exchanges(self): + def get_exchanges(self) -> Sequence[str]: return list(self.state.exchanges) - def get_table(self, exchange): + def get_table(self, exchange: str) -> Mapping: """Get table of bindings for `exchange`.""" return self.state.exchanges[exchange]['table'] - def typeof(self, exchange, default='direct'): + def typeof(self, exchange: str, default: str = 'direct') -> ExchangeType: """Get the exchange type instance for `exchange`.""" try: type = self.state.exchanges[exchange]['type'] @@ -697,7 +788,9 @@ class Channel(AbstractChannel, base.StdChannel): type = default return self.exchange_types[type] - def _lookup(self, exchange, routing_key, default=None): + def _lookup( + self, exchange: str, routing_key: str, + default: str = None) -> Sequence[str]: """Find all queues matching `routing_key` for the given `exchange`. Returns: @@ -725,34 +818,42 @@ class Channel(AbstractChannel, base.StdChannel): R = [default] return R - def _restore(self, message): + async def _restore(self, message: MessageT) -> None: """Redeliver message to its original destination.""" delivery_info = message.delivery_info message = message.serializable() message['redelivered'] = True for queue in self._lookup( delivery_info['exchange'], delivery_info['routing_key']): - self._put(queue, message) + await self._put(queue, message) - def _restore_at_beginning(self, message): - return self._restore(message) + async def _restore_at_beginning(self, message: MessageT) -> None: + await self._restore(message) - def drain_events(self, timeout=None, callback=None): + async def drain_events( + self, + timeout: float = None, + callback: Callable = None) -> None: callback = callback or self.connection._deliver if self._consumers and self.qos.can_consume(): if hasattr(self, '_get_many'): - return self._get_many(self._active_queues, timeout=timeout) - return 
self._poll(self.cycle, callback, timeout=timeout) + return await self._get_many(self._active_queues, timeout=timeout) + else: + return await self._poll(self.cycle, callback, timeout=timeout) raise Empty() - def message_to_python(self, raw_message): + def message_to_python(self, raw_message: Any) -> MessageT: """Convert raw message to :class:`Message` instance.""" if not isinstance(raw_message, self.Message): return self.Message(payload=raw_message, channel=self) return raw_message - def prepare_message(self, body, priority=None, content_type=None, - content_encoding=None, headers=None, properties=None): + def prepare_message(self, body: Any, + priority: int = None, + content_type: str = None, + content_encoding: str = None, + headers: Mapping = None, + properties: Mapping = None) -> Mapping: """Prepare message data.""" properties = properties or {} properties.setdefault('delivery_info', {}) @@ -764,7 +865,7 @@ class Channel(AbstractChannel, base.StdChannel): 'headers': headers or {}, 'properties': properties or {}} - def flow(self, active=True): + async def flow(self, active: bool = True) -> None: """Enable/disable message flow. Raises: @@ -773,7 +874,7 @@ class Channel(AbstractChannel, base.StdChannel): """ raise NotImplementedError('virtual channels do not support flow.') - def close(self): + async def close(self) -> None: """Close channel. Cancel all consumers, and requeue unacked messages. 
@@ -781,55 +882,62 @@ class Channel(AbstractChannel, base.StdChannel): if not self.closed: self.closed = True for consumer in list(self._consumers): - self.basic_cancel(consumer) + await self.basic_cancel(consumer) if self._qos: - self._qos.restore_unacked_once() + await self._qos.restore_unacked_once() if self._cycle is not None: self._cycle.close() self._cycle = None if self.connection is not None: - self.connection.close_channel(self) + await self.connection.close_channel(self) self.exchange_types = None - def encode_body(self, body, encoding=None): + def encode_body(self, body: Any, encoding: str = None) -> Tuple[Any, str]: if encoding: return self.codecs.get(encoding).encode(body), encoding return body, encoding - def decode_body(self, body, encoding=None): + def decode_body(self, body: Any, encoding: str = None) -> Any: if encoding: return self.codecs.get(encoding).decode(body) return body - def _reset_cycle(self): + def _reset_cycle(self) -> None: self._cycle = FairCycle( self._get_and_deliver, self._active_queues, Empty) - def __enter__(self): + def __enter__(self) -> 'Channel': return self - def __exit__(self, *exc_info): + def __exit__(self, *exc_info) -> None: self.close() + async def __aenter__(self) -> 'Channel': + return self + + async def __aexit__(self, *exc_info) -> None: + ... + @property - def state(self): + def state(self) -> BrokerState: """Broker state containing exchanges and bindings.""" return self.connection.state @property - def qos(self): + def qos(self) -> QoS: """:class:`QoS` manager for this channel.""" if self._qos is None: self._qos = self.QoS(self) return self._qos @property - def cycle(self): + def cycle(self) -> FairCycle: if self._cycle is None: self._reset_cycle() return self._cycle - def _get_message_priority(self, message, reverse=False): + def _get_message_priority(self, message: MessageT, + *, reverse: bool = False) -> int: """Get priority from message. The value is limited to within a boundary of 0 to 9. 
@@ -852,15 +960,15 @@ class Channel(AbstractChannel, base.StdChannel): class Management(base.Management): """Base class for the AMQP management API.""" - def __init__(self, transport): + def __init__(self, transport: TransportT) -> None: super(Management, self).__init__(transport) self.channel = transport.client.channel() - def get_bindings(self): + def get_bindings(self) -> Sequence[Mapping]: return [dict(destination=q, source=e, routing_key=r) for q, e, r in self.channel.list_bindings()] - def close(self): + def close(self) -> None: self.channel.close() @@ -871,31 +979,31 @@ class Transport(base.Transport): client (kombu.Connection): The client this is a transport for. """ - Channel = Channel - Cycle = FairCycle - Management = Management + Channel: type = Channel + Cycle: type = FairCycle + Management: type = Management #: Global :class:`BrokerState` containing declared exchanges and bindings. - state = BrokerState() + state: BrokerState = BrokerState() #: :class:`~kombu.utils.scheduling.FairCycle` instance #: used to fairly drain events from channels (set by constructor). - cycle = None + cycle: FairCycle = None #: port number used when no port is specified. - default_port = None + default_port: int = None #: active channels. - channels = None + channels: Sequence[ChannelT] = None #: queue/callback map. - _callbacks = None + _callbacks: MutableMapping[str, Callable] = None #: Time to sleep between unsuccessful polls. 
- polling_interval = 1.0 + polling_interval: float = 1.0 #: Max number of channels - channel_max = 65535 + channel_max: int = 65535 implements = base.Transport.implements.extend( async=False, @@ -903,7 +1011,7 @@ class Transport(base.Transport): heartbeats=False, ) - def __init__(self, client, **kwargs): + def __init__(self, client: ClientT, **kwargs): self.client = client self.channels = [] self._avail_channels = [] @@ -916,7 +1024,7 @@ class Transport(base.Transport): ARRAY_TYPE_H, range(self.channel_max, 0, -1), ) - def create_channel(self, connection): + def create_channel(self, connection: ConnectionT) -> ChannelT: try: return self._avail_channels.pop() except IndexError: @@ -924,7 +1032,7 @@ class Transport(base.Transport): self.channels.append(channel) return channel - def close_channel(self, channel): + async def close_channel(self, channel: ChannelT) -> None: try: self._avail_channel_ids.append(channel.channel_id) try: @@ -934,14 +1042,14 @@ class Transport(base.Transport): finally: channel.connection = None - def establish_connection(self): + async def establish_connection(self) -> 'Transport': # creates channel to verify connection. # this channel is then used as the next requested channel. # (returned by ``create_channel``). 
self._avail_channels.append(self.create_channel(self)) return self # for drain events - def close_connection(self, connection): + async def close_connection(self, connection: ConnectionT) -> None: self.cycle.close() for l in self._avail_channels, self.channels: while l: @@ -950,24 +1058,25 @@ class Transport(base.Transport): except LookupError: # pragma: no cover pass else: - channel.close() + await channel.close() - def drain_events(self, connection, timeout=None): + async def drain_events(self, connection: ConnectionT, + timeout: float = None) -> None: time_start = monotonic() get = self.cycle.get polling_interval = self.polling_interval while 1: try: - get(self._deliver, timeout=timeout) + await get(self._deliver, timeout=timeout) except Empty: if timeout is not None and monotonic() - time_start >= timeout: raise socket.timeout() if polling_interval is not None: - sleep(polling_interval) + await sleep(polling_interval) else: break - def _deliver(self, message, queue): + async def _deliver(self, message: MessageT, queue: str) -> None: if not queue: raise KeyError( 'Received message without destination queue: {0}'.format( @@ -980,7 +1089,7 @@ class Transport(base.Transport): else: callback(message) - def _reject_inbound_message(self, raw_message): + def _reject_inbound_message(self, raw_message: Any) -> None: for channel in self.channels: if channel: message = channel.Message(raw_message, channel=channel) @@ -988,16 +1097,18 @@ class Transport(base.Transport): channel.basic_reject(message.delivery_tag, requeue=True) break - def on_message_ready(self, channel, message, queue): + def on_message_ready( + self, channel: ChannelT, message: MessageT, queue: str) -> None: if not queue or queue not in self._callbacks: raise KeyError( 'Message for queue {0!r} without consumers: {1}'.format( queue, message)) self._callbacks[queue](message) - def _drain_channel(self, channel, callback, timeout=None): + def _drain_channel(self, channel: ChannelT, callback: Callable, + 
timeout: float = None) -> None: return channel.drain_events(callback=callback, timeout=timeout) @property - def default_connection_params(self): - return {'port': self.default_port, 'hostname': 'localhost'} + def default_connection_params(self) -> Mapping: + return {'port': self.default_port, 'hostname': 'localhost'} diff --git a/kombu/transport/virtual/exchange.py b/kombu/transport/virtual/exchange.py index 33d63066..98c0e6db 100644 --- a/kombu/transport/virtual/exchange.py +++ b/kombu/transport/virtual/exchange.py @@ -3,10 +3,11 @@ Implementations of the standard exchanges defined by the AMQ protocol (excluding the `headers` exchange). """ - import re - +from typing import Mapping, Match, Pattern, Set, Tuple +from amqp.types import ChannelT from kombu.utils.text import escape_regex +from kombu.types import MessageT class ExchangeType: @@ -18,9 +19,10 @@ class ExchangeType: channel (ChannelT): AMQ Channel. """ - type = None + type: str + channel: ChannelT - def __init__(self, channel): + def __init__(self, channel: ChannelT) -> None: self.channel = channel def lookup(self, table, exchange, routing_key, default): @@ -57,13 +59,18 @@ class DirectExchange(ExchangeType): type = 'direct' - def lookup(self, table, exchange, routing_key, default): + def lookup(self, + table: Mapping, + exchange: str, + routing_key: str, + default: str) -> Set[str]: return { queue for rkey, _, queue in table if rkey == routing_key } - def deliver(self, message, exchange, routing_key, **kwargs): + def deliver(self, message: MessageT, + exchange: str, routing_key: str, **kwargs) -> None: _lookup = self.channel._lookup _put = self.channel._put for queue in _lookup(exchange, routing_key): @@ -81,19 +88,26 @@ class TopicExchange(ExchangeType): type = 'topic' #: map of wildcard to regex conversions - wildcards = {'*': r'.*?[^\.]', - '#': r'.*?'} + wildcards: Mapping[str, str] = { + '*': r'.*?[^\.]', + '#': r'.*?', + } #: compiled regex cache - _compiled = {} + _compiled: Mapping[str, Pattern] 
= {} - def lookup(self, table, exchange, routing_key, default): + def lookup(self, + table: Mapping, + exchange: str, + routing_key: str, + default: str) -> Set[str]: return { queue for rkey, pattern, queue in table if self._match(pattern, routing_key) } - def deliver(self, message, exchange, routing_key, **kwargs): + def deliver(self, message: MessageT, + exchange: str, routing_key: str, **kwargs) -> None: _lookup = self.channel._lookup _put = self.channel._put deadletter = self.channel.deadletter_queue @@ -101,17 +115,21 @@ class TopicExchange(ExchangeType): if q and q != deadletter]: _put(queue, message, **kwargs) - def prepare_bind(self, queue, exchange, routing_key, arguments): + def prepare_bind(self, + queue: str, + exchange: str, + routing_key: str, + arguments: Mapping) -> Tuple[str, Pattern, str]: return routing_key, self.key_to_pattern(routing_key), queue - def key_to_pattern(self, rkey): + def key_to_pattern(self, rkey: str) -> str: """Get the corresponding regex for any routing key.""" return '^%s$' % ('\.'.join( self.wildcards.get(word, word) for word in escape_regex(rkey, '.#*').split('.') )) - def _match(self, pattern, string): + def _match(self, pattern: str, string: str) -> Match: """Match regular expression (cached). Same as :func:`re.match`, except the regex is compiled and cached, @@ -141,17 +159,19 @@ class FanoutExchange(ExchangeType): type = 'fanout' - def lookup(self, table, exchange, routing_key, default): + def lookup(self, table: Mapping, + exchange: str, routing_key: str, default: str) -> Set[str]: return {queue for _, _, queue in table} - def deliver(self, message, exchange, routing_key, **kwargs): + def deliver(self, message: MessageT, + exchange: str, routing_key: str, **kwargs) -> None: if self.channel.supports_fanout: self.channel._put_fanout( exchange, message, routing_key, **kwargs) #: Map of standard exchange types and corresponding classes. 
-STANDARD_EXCHANGE_TYPES = { +STANDARD_EXCHANGE_TYPES: Mapping[str, type] = { 'direct': DirectExchange, 'topic': TopicExchange, 'fanout': FanoutExchange, diff --git a/kombu/types.py b/kombu/types.py new file mode 100644 index 00000000..5825fb1f --- /dev/null +++ b/kombu/types.py @@ -0,0 +1,893 @@ +import abc +import logging +from collections import deque +from collections.abc import Hashable +from numbers import Number +from typing import ( + Any, Callable, Dict, Iterable, List, Mapping, Optional, + Sequence, TypeVar, Tuple, Union, +) +from typing import Set # noqa +from amqp.types import ChannelT, MessageT as _MessageT, TransportT +from amqp.spec import queue_declare_ok_t + + +def _hasattr(C: Any, attr: str) -> bool: + return any(attr in B.__dict__ for B in C.__mro__) + + +class _AbstractClass(abc.ABCMeta): + __required_attributes__ = frozenset() # type: frozenset + + @classmethod + def _subclasshook_using(cls, parent: Any, C: Any): + return ( + cls is parent and + all(_hasattr(C, attr) for attr in cls.__required_attributes__) + ) or NotImplemented + + +class Revivable(metaclass=_AbstractClass): + + async def revive(self, channel: ChannelT) -> None: + ... + + +class ClientT(Revivable, metaclass=_AbstractClass): + + hostname: str + userid: str + password: str + ssl: Any + login_method: str + port: int = None + virtual_host: str = '/' + connect_timeout: float = 5.0 + alt: Sequence[str] + + uri_prefix: str = None + declared_entities: Set['EntityT'] = None + cycle: Iterable + transport_options: Mapping + failover_strategy: str + heartbeat: float = None + resolve_aliases: Mapping[str, str] + failover_strategies: Mapping[str, Callable] + + @property + @abc.abstractmethod + def default_channel(self) -> ChannelT: + ... 
+ + def __init__( + self, + hostname: str = 'localhost', + userid: str = None, + password: str = None, + virtual_host: str = None, + port: int = None, + insist: bool = False, + ssl: Any = None, + transport: Union[type, str] = None, + connect_timeout: float = 5.0, + transport_options: Mapping = None, + login_method: str = None, + uri_prefix: str = None, + heartbeat: float = None, + failover_strategy: str = 'round-robin', + alternates: Sequence[str] = None, + **kwargs) -> None: + ... + + @abc.abstractmethod + def switch(self, url: str) -> None: + ... + + @abc.abstractmethod + def maybe_switch_next(self) -> None: + ... + + @abc.abstractmethod + async def connect(self) -> None: + ... + + @abc.abstractmethod + def channel(self) -> ChannelT: + ... + + @abc.abstractmethod + async def heartbeat_check(self, rate: int = 2) -> None: + ... + + @abc.abstractmethod + async def drain_events(self, timeout: float = None, **kwargs) -> None: + ... + + @abc.abstractmethod + async def maybe_close_channel(self, channel: ChannelT) -> None: + ... + + @abc.abstractmethod + async def collect(self, socket_timeout: float = None) -> None: + ... + + @abc.abstractmethod + async def release(self) -> None: + ... + + @abc.abstractmethod + async def close(self) -> None: + ... + + @abc.abstractmethod + async def ensure_connection( + self, + errback: Callable = None, + max_retries: int = None, + interval_start: float = 2.0, + interval_step: float = 2.0, + interval_max: float = 30.0, + callback: Callable = None, + reraise_as_library_errors: bool = True) -> 'ClientT': + ... + + @abc.abstractmethod + def completes_cycle(self, retries: int) -> bool: + ... + + @abc.abstractmethod + async def ensure( + self, obj: Revivable, fun: Callable, + errback: Callable = None, + max_retries: int = None, + interval_start: float = 1.0, + interval_step: float = 1.0, + interval_max: float = 1.0, + on_revive: Callable = None) -> Any: + ... 
+ + @abc.abstractmethod + def autoretry(self, fun: Callable, + channel: ChannelT = None, **ensure_options) -> Any: + ... + + @abc.abstractmethod + def create_transport(self) -> TransportT: + ... + + @abc.abstractmethod + def get_transport_cls(self) -> type: + ... + + @abc.abstractmethod + def clone(self, **kwargs) -> 'ClientT': + ... + + @abc.abstractmethod + def get_heartbeat_interval(self) -> float: + ... + + @abc.abstractmethod + def info(self) -> Mapping[str, Any]: + ... + + @abc.abstractmethod + def as_uri( + self, + *, + include_password: bool = False, + mask: str = '**', + getfields: Callable = None) -> str: + ... + + @abc.abstractmethod + def Pool(self, limit: int = None, **kwargs) -> 'ResourceT': + ... + + @abc.abstractmethod + def ChannelPool(self, limit: int = None, **kwargs) -> 'ResourceT': + ... + + @abc.abstractmethod + def Producer(self, channel: ChannelT = None, + *args, **kwargs) -> 'ProducerT': + ... + + @abc.abstractmethod + def Consumer( + self, + queues: Sequence['QueueT'] = None, + channel: ChannelT = None, + *args, **kwargs) -> 'ConsumerT': + ... + + @abc.abstractmethod + def SimpleQueue( + self, name: str, + no_ack: bool = None, + queue_opts: Mapping = None, + exchange_opts: Mapping = None, + channel: ChannelT = None, + **kwargs) -> 'SimpleQueueT': + ... + + +class EntityT(Hashable, Revivable, metaclass=_AbstractClass): + + @property + @abc.abstractmethod + def can_cache_declaration(self) -> bool: + ... + + @property + @abc.abstractmethod + def is_bound(self) -> bool: + ... + + @property + @abc.abstractmethod + def channel(self) -> ChannelT: + ... + + @abc.abstractmethod + def __call__(self, channel: ChannelT) -> 'EntityT': + ... + + @abc.abstractmethod + def bind(self, channel: ChannelT) -> 'EntityT': + ... + + @abc.abstractmethod + def maybe_bind(self, channel: Optional[ChannelT]) -> 'EntityT': + ... + + @abc.abstractmethod + def when_bound(self) -> None: + ... 
+ + @abc.abstractmethod + def as_dict(self, recurse: bool = False) -> Dict: + ... + + +ChannelArgT = TypeVar('ChannelArgT', ChannelT, ClientT) + + +class ExchangeT(EntityT, metaclass=_AbstractClass): + + name: str + type: str + channel: ChannelT + arguments: Mapping + durable: bool = True + auto_delete: bool = False + delivery_mode: int + no_declare: bool = False + passive: bool = False + + def __init__( + self, + name: str = '', + type: str = '', + channel: ChannelArgT = None, + **kwargs) -> None: + ... + + @abc.abstractmethod + async def declare( + self, + nowait: bool = False, + passive: bool = None, + channel: ChannelT = None) -> None: + ... + + @abc.abstractmethod + async def bind_to( + self, + exchange: Union[str, 'ExchangeT'] = '', + routing_key: str = '', + arguments: Mapping = None, + nowait: bool = False, + channel: ChannelT = None, + **kwargs) -> None: + ... + + @abc.abstractmethod + async def unbind_from( + self, + source: Union[str, 'ExchangeT'] = '', + routing_key: str = '', + nowait: bool = False, + arguments: Mapping = None, + channel: ChannelT = None) -> None: + ... + + @abc.abstractmethod + async def publish( + self, message: Union[MessageT, str, bytes], + routing_key: str = None, + mandatory: bool = False, + immediate: bool = False, + exchange: str = None) -> None: + ... + + @abc.abstractmethod + async def delete( + self, + if_unused: bool = False, + nowait: bool = False) -> None: + ... + + @abc.abstractmethod + def binding( + self, + routing_key: str = '', + arguments: Mapping = None, + unbind_arguments: Mapping = None) -> BindingT: + ... + + @abc.abstractmethod + def Message( + self, body: Any, + delivery_mode: Union[str, int] = None, + properties: Mapping = None, + **kwargs) -> Any: + ... 
+ + +class QueueT(EntityT, metaclass=_AbstractClass): + + name: str = '' + exchange: ExchangeT + routing_key: str = '' + alias: str + + bindings: Sequence['BindingT'] = None + + durable: bool = True + exclusive: bool = False + auto_delete: bool = False + no_ack: bool = False + no_declare: bool = False + expires: float + message_ttl: float + max_length: int + max_length_bytes: int + max_priority: int + + queue_arguments: Mapping + binding_arguments: Mapping + consumer_arguments: Mapping + + def __init__( + self, + name: str = '', + exchange: ExchangeT = None, + routing_key: str = '', + channel: ChannelT = None, + bindings: Sequence['BindingT'] = None, + on_declared: Callable = None, + **kwargs) -> None: + ... + + @abc.abstractmethod + async def declare(self, + nowait: bool = False, + channel: ChannelT = None) -> str: + ... + + @abc.abstractmethod + async def queue_declare( + self, + nowait: bool = False, + passive: bool = False, + channel: ChannelT = None) -> queue_declare_ok_t: + ... + + @abc.abstractmethod + async def queue_bind( + self, + nowait: bool = False, + channel: ChannelT = None) -> None: + ... + + @abc.abstractmethod + async def bind_to( + self, + exchange: Union[str, ExchangeT] = '', + routing_key: str = '', + arguments: Mapping = None, + nowait: bool = False, + channel: ChannelT = None) -> None: + ... + + @abc.abstractmethod + async def get( + self, + no_ack: bool = None, + accept: Set[str] = None) -> MessageT: + ... + + @abc.abstractmethod + async def purge(self, nowait: bool = False) -> int: + ... + + @abc.abstractmethod + async def consume( + self, + consumer_tag: str = '', + callback: Callable = None, + no_ack: bool = None, + nowait: bool = False) -> None: + ... + + @abc.abstractmethod + async def cancel(self, consumer_tag: str) -> None: + ... + + @abc.abstractmethod + async def delete( + self, + if_unused: bool = False, + if_empty: bool = False, + nowait: bool = False) -> None: + ... 
+ + @abc.abstractmethod + async def queue_unbind( + self, + arguments: Mapping = None, + nowait: bool = False, + channel: ChannelT = None) -> None: + ... + + @abc.abstractmethod + async def unbind_from( + self, + exchange: str = '', + routing_key: str = '', + arguments: Mapping = None, + nowait: bool = False, + channel: ChannelT = None) -> None: + ... + + @classmethod + @abc.abstractmethod + def from_dict( + self, queue: str, + exchange: str = None, + exchange_type: str = None, + binding_key: str = None, + routing_key: str = None, + delivery_mode: Union[int, str] = None, + bindings: Sequence = None, + durable: bool = None, + queue_durable: bool = None, + exchange_durable: bool = None, + auto_delete: bool = None, + queue_auto_delete: bool = None, + exchange_auto_delete: bool = None, + exchange_arguments: Mapping = None, + queue_arguments: Mapping = None, + binding_arguments: Mapping = None, + consumer_arguments: Mapping = None, + exclusive: bool = None, + no_ack: bool = None, + **options) -> 'QueueT': + ... + + +class BindingT(metaclass=_AbstractClass): + + exchange: ExchangeT + routing_key: str + arguments: Mapping + unbind_arguments: Mapping + + def __init__( + self, + exchange: ExchangeT = None, + routing_key: str = '', + arguments: Mapping = None, + unbind_arguments: Mapping = None) -> None: + ... + + def declare(self, channel: ChannelT, nowait: bool = False) -> None: + ... + + def bind( + self, entity: EntityT, + nowait: bool = False, + channel: ChannelT = None) -> None: + ... + + def unbind( + self, entity: EntityT, + nowait: bool = False, + channel: ChannelT = None) -> None: + ... 
+ + +class ConsumerT(Revivable, metaclass=_AbstractClass): + + ContentDisallowed: type + + channel: ChannelT + queues: Sequence[QueueT] + no_ack: bool + auto_declare: bool = True + callbacks: Sequence[Callable] + on_message: Callable + on_decode_error: Callable + accept: Set[str] + prefetch_count: int = None + tag_prefix: str + + @property + @abc.abstractmethod + def connection(self) -> ClientT: + ... + + def __init__( + self, + channel: ChannelT, + queues: Sequence[QueueT] = None, + no_ack: bool = None, + auto_declare: bool = None, + callbacks: Sequence[Callable] = None, + on_decode_error: Callable = None, + on_message: Callable = None, + accept: Sequence[str] = None, + prefetch_count: int = None, + tag_prefix: str = None) -> None: + ... + + @abc.abstractmethod + async def declare(self) -> None: + ... + + @abc.abstractmethod + def register_callback(self, callback: Callable) -> None: + ... + + @abc.abstractmethod + def __enter__(self) -> 'ConsumerT': + ... + + @abc.abstractmethod + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + ... + + @abc.abstractmethod + async def __aenter__(self) -> 'ConsumerT': + ... + + @abc.abstractmethod + async def __aexit__(self, *exc_info) -> None: + ... + + @abc.abstractmethod + async def add_queue(self, queue: QueueT) -> QueueT: + ... + + @abc.abstractmethod + async def consume(self, no_ack: bool = None) -> None: + ... + + @abc.abstractmethod + async def cancel(self) -> None: + ... + + @abc.abstractmethod + async def cancel_by_queue(self, queue: Union[QueueT, str]) -> None: + ... + + @abc.abstractmethod + def consuming_from(self, queue: Union[QueueT, str]) -> bool: + ... + + @abc.abstractmethod + async def purge(self) -> int: + ... + + @abc.abstractmethod + async def flow(self, active: bool) -> None: + ... + + @abc.abstractmethod + async def qos(self, + prefetch_size: int = 0, + prefetch_count: int = 0, + apply_global: bool = False) -> None: + ... 
+ + @abc.abstractmethod + async def recover(self, requeue: bool = False) -> None: + ... + + @abc.abstractmethod + async def receive(self, body: Any, message: MessageT) -> None: + ... + + +class ProducerT(Revivable, metaclass=_AbstractClass): + + exchange: ExchangeT + routing_key: str = '' + serializer: str + compression: str + auto_declare: bool = True + on_return: Callable = None + + __connection__: ClientT + + channel: ChannelT + + @property + @abc.abstractmethod + def connection(self) -> ClientT: + ... + + def __init__(self, + channel: ChannelArgT, + exchange: ExchangeT = None, + routing_key: str = None, + serializer: str = None, + auto_declare: bool = None, + compression: str = None, + on_return: Callable = None) -> None: + ... + + @abc.abstractmethod + async def declare(self) -> None: + ... + + @abc.abstractmethod + async def maybe_declare(self, entity: EntityT, + retry: bool = False, **retry_policy) -> None: + ... + + @abc.abstractmethod + async def publish( + self, body: Any, + routing_key: str = None, + delivery_mode: Union[int, str] = None, + mandatory: bool = False, + immediate: bool = False, + priority: int = 0, + content_type: str = None, + content_encoding: str = None, + serializer: str = None, + headers: Mapping = None, + compression: str = None, + exchange: Union[ExchangeT, str] = None, + retry: bool = False, + retry_policy: Mapping = None, + declare: Sequence[EntityT] = None, + expiration: Number = None, + **properties) -> None: + ... + + @abc.abstractmethod + def __enter__(self) -> 'ProducerT': + ... + + @abc.abstractmethod + async def __aenter__(self) -> 'ProducerT': + ... + + @abc.abstractmethod + def __exit__(self, *exc_info) -> None: + ... + + @abc.abstractmethod + async def __aexit__(self, *exc_info) -> None: + ... + + @abc.abstractmethod + def release(self) -> None: + ... + + @abc.abstractmethod + def close(self) -> None: + ... 
+ + +class MessageT(_MessageT): + + MessageStateError: type + + _state: str + channel: ChannelT = None + delivery_tag: str + content_type: str + content_encoding: str + delivery_info: Mapping + headers: Mapping + properties: Mapping + accept: Set[str] + body: Any + errors: List[Any] + + @property + @abc.abstractmethod + def acknowledged(self) -> bool: + ... + + @property + @abc.abstractmethod + def payload(self) -> Any: + ... + + def __init__( + self, + body: Any = None, + delivery_tag: str = None, + content_type: str = None, + content_encoding: str = None, + delivery_info: Mapping = None, + properties: Mapping = None, + headers: Mapping = None, + postencode: str = None, + accept: Set[str] = None, + channel: ChannelT = None, + **kwargs) -> None: + ... + + @abc.abstractmethod + async def ack(self, multiple: bool = False) -> None: + ... + + @abc.abstractmethod + async def ack_log_error( + self, logger: logging.Logger, errors: Tuple[type, ...], + multiple: bool = False) -> None: + ... + + @abc.abstractmethod + async def reject_log_error( + self, logger: logging.Logger, errors: Tuple[type, ...], + requeue: bool = False) -> None: + ... + + @abc.abstractmethod + async def reject(self, requeue: bool = False) -> None: + ... + + @abc.abstractmethod + async def requeue(self) -> None: + ... + + @abc.abstractmethod + def decode(self) -> Any: + ... + + +class ResourceT(metaclass=_AbstractClass): + + LimitExceeded: type + + close_after_fork: bool = False + + @property + @abc.abstractmethod + def limit(self) -> int: + ... + + def __init__(self, + limit: int = None, + preload: int = None, + close_after_fork: bool = None) -> None: + ... + + @abc.abstractmethod + def setup(self) -> None: + ... + + @abc.abstractmethod + def acquire(self, block: bool = False, timeout: int = None): + ... + + @abc.abstractmethod + def prepare(self, resource: Any) -> Any: + ... + + @abc.abstractmethod + def close_resource(self, resource: Any) -> None: + ... 
+ + @abc.abstractmethod + def release_resource(self, resource: Any) -> None: + ... + + @abc.abstractmethod + def replace(self, resource: Any) -> None: + ... + + @abc.abstractmethod + def release(self, resource: Any) -> None: + ... + + @abc.abstractmethod + def collect_resource(self, resource: Any) -> None: + ... + + @abc.abstractmethod + def force_close_all(self) -> None: + ... + + @abc.abstractmethod + def resize( + self, limit: int, + force: bool = False, + ignore_errors: bool = False, + reset: bool = False) -> None: + ... + + +class SimpleQueueT(metaclass=_AbstractClass): + + Empty: type + + channel: ChannelT + producer: ProducerT + consumer: ConsumerT + no_ack: bool = False + queue: QueueT + buffer: deque + + def __init__( + self, + channel: ChannelArgT, + name: str, + no_ack: bool = None, + queue_opts: Mapping = None, + exchange_opts: Mapping = None, + serializer: str = None, + compression: str = None, + **kwargs) -> None: + ... + + @abc.abstractmethod + def __enter__(self) -> 'SimpleQueueT': + ... + + @abc.abstractmethod + def __exit__(self, *exc_info) -> None: + ... + + @abc.abstractmethod + def get(self, block: bool = True, timeout: float = None) -> MessageT: + ... + + @abc.abstractmethod + def get_nowait(self) -> MessageT: + ... + + @abc.abstractmethod + def put(self, message: Any, + serializer: str = None, + headers: Mapping = None, + compression: str = None, + routing_key: str = None, + **kwargs) -> None: + ... + + @abc.abstractmethod + def clear(self) -> int: + ... + + @abc.abstractmethod + def qsize(self) -> int: + ... + + @abc.abstractmethod + def close(self) -> None: + ... + + @abc.abstractmethod + def __len__(self) -> int: + ... + + @abc.abstractmethod + def __bool__(self) -> bool: + ... 
diff --git a/kombu/utils/abstract.py b/kombu/utils/abstract.py deleted file mode 100644 index 27a27d45..00000000 --- a/kombu/utils/abstract.py +++ /dev/null @@ -1,49 +0,0 @@ -import abc - -from typing import Any -from typing import Set # noqa - - -def _hasattr(C: Any, attr: str) -> bool: - return any(attr in B.__dict__ for B in C.__mro__) - - -class _AbstractClass(object, metaclass=abc.ABCMeta): - __required_attributes__ = frozenset() # type: frozenset - - @classmethod - def _subclasshook_using(cls, parent: Any, C: Any): - return ( - cls is parent and - all(_hasattr(C, attr) for attr in cls.__required_attributes__) - ) or NotImplemented - - @classmethod - def register(cls, other: Any) -> Any: - # we override `register` to return other for use as a decorator. - type(cls).register(cls, other) - return other - - -class Connection(_AbstractClass): - ... - - -class Entity(_AbstractClass): - ... - - -class Consumer(_AbstractClass): - ... - - -class Producer(_AbstractClass): - ... - - -class Messsage(_AbstractClass): - ... - - -class Resource(_AbstractClass): - ... diff --git a/kombu/utils/amq_manager.py b/kombu/utils/amq_manager.py index fcf1c6d5..c8cb8508 100644 --- a/kombu/utils/amq_manager.py +++ b/kombu/utils/amq_manager.py @@ -1,9 +1,9 @@ """AMQP Management API utilities.""" from typing import Any, Union -from . 
import abstract +from kombu.types import ClientT -def get_manager(client: abstract.Connection, +def get_manager(client: ClientT, hostname: str = None, port: Union[int, str] = None, userid: str = None, diff --git a/kombu/utils/collections.py b/kombu/utils/collections.py index 0ee27399..b202470b 100644 --- a/kombu/utils/collections.py +++ b/kombu/utils/collections.py @@ -1,4 +1,5 @@ """Custom maps, sequences, etc.""" +from typing import Any, Union class HashedSeq(list): @@ -10,15 +11,15 @@ class HashedSeq(list): __slots__ = 'hashvalue' - def __init__(self, *seq): + def __init__(self, *seq) -> None: self[:] = seq self.hashvalue = hash(seq) - def __hash__(self): + def __hash__(self) -> int: return self.hashvalue -def eqhash(o): +def eqhash(o: Any) -> Union[int, HashedSeq]: """Call ``obj.__eqhash__``.""" try: return o.__eqhash__() @@ -29,14 +30,17 @@ def eqhash(o): class EqualityDict(dict): """Dict using the eq operator for keying.""" - def __getitem__(self, key): + def __getitem__(self, key: Any) -> Any: h = eqhash(key) if h not in self: return self.__missing__(key) return dict.__getitem__(self, h) - def __setitem__(self, key, value): - return dict.__setitem__(self, eqhash(key), value) + def __setitem__(self, key: Any, value: Any) -> None: + dict.__setitem__(self, eqhash(key), value) - def __delitem__(self, key): - return dict.__delitem__(self, eqhash(key)) + def __delitem__(self, key: Any) -> None: + dict.__delitem__(self, eqhash(key)) + + def __missing__(self, key: Any) -> Any: + raise KeyError(key) diff --git a/kombu/utils/compat.py b/kombu/utils/compat.py index c9390ec3..b12dd41e 100644 --- a/kombu/utils/compat.py +++ b/kombu/utils/compat.py @@ -3,6 +3,8 @@ import numbers import sys from functools import wraps from contextlib import contextmanager +from typing import Any, Callable, Iterator, Optional, Tuple, Union +from .typing import SupportsFileno try: from io import UnsupportedOperation @@ -21,17 +23,17 @@ except ImportError: # pragma: no cover _environment = 
None -def coro(gen): +def coro(gen: Callable) -> Callable: """Decorator to mark generator as co-routine.""" @wraps(gen) - def wind_up(*args, **kwargs): + def wind_up(*args, **kwargs) -> Any: it = gen(*args, **kwargs) next(it) return it return wind_up -def _detect_environment(): +def _detect_environment() -> str: # ## -eventlet- if 'eventlet' in sys.modules: try: @@ -57,14 +59,15 @@ def _detect_environment(): return 'default' -def detect_environment(): +def detect_environment() -> str: """Detect the current environment: default, eventlet, or gevent.""" global _environment if _environment is None: _environment = _detect_environment() return _environment -def entrypoints(namespace): + +def entrypoints(namespace: str) -> Iterator[Tuple[str, Any]]: """Return setuptools entrypoints for namespace.""" try: from pkg_resources import iter_entry_points @@ -73,14 +76,14 @@ def entrypoints(namespace): return ((ep, ep.load()) for ep in iter_entry_points(namespace)) -def fileno(f): +def fileno(f: Union[SupportsFileno, numbers.Integral]) -> numbers.Integral: """Get fileno from file-like object.""" if isinstance(f, numbers.Integral): return f return f.fileno() -def maybe_fileno(f): +def maybe_fileno(f: Any) -> Optional[numbers.Integral]: """Get object fileno, or :const:`None` if not defined.""" try: return fileno(f) diff --git a/kombu/utils/debug.py b/kombu/utils/debug.py index 6bc59020..df98a04d 100644 --- a/kombu/utils/debug.py +++ b/kombu/utils/debug.py @@ -1,10 +1,7 @@ """Debugging support.""" import logging - -from typing import Any, Optional, Sequence, Union - +from typing import Any, Sequence, Union from vine.utils import wraps - from kombu.log import get_logger __all__ = ['setup_logging', 'Logwrapped'] @@ -29,8 +26,8 @@ class Logwrapped: __ignore = ('__enter__', '__exit__') def __init__(self, instance: Any, - logger: Optional[LoggerArg]=None, - ident: Optional[str]=None) -> None: + logger: LoggerArg = None, + ident: str = None) -> None: self.instance = instance self.logger 
= get_logger(logger) # type: logging.Logger self.ident = ident diff --git a/kombu/utils/div.py b/kombu/utils/div.py index 21b4c59a..0535c30f 100644 --- a/kombu/utils/div.py +++ b/kombu/utils/div.py @@ -1,9 +1,13 @@ """Div. Utilities.""" -from .encoding import default_encode import sys +from typing import Any, Callable, IO +from .encoding import default_encode -def emergency_dump_state(state, open_file=open, dump=None, stderr=None): +def emergency_dump_state(state: Any, + open_file: Callable = open, + dump: Callable = None, + stderr: IO = None): """Dump message state to stdout or file.""" from pprint import pformat from tempfile import mktemp diff --git a/kombu/utils/encoding.py b/kombu/utils/encoding.py index a68fef25..8faf32a1 100644 --- a/kombu/utils/encoding.py +++ b/kombu/utils/encoding.py @@ -8,7 +8,6 @@ applications without crashing from the infamous import sys import traceback - from typing import Any, AnyStr, IO, Optional str_t = str @@ -67,7 +66,7 @@ def default_encode(obj: Any) -> Any: return obj -def safe_str(s: Any, errors: str='replace') -> str: +def safe_str(s: Any, errors: str = 'replace') -> str: """Safe form of str(), void of unicode errors.""" s = bytes_to_str(s) if not isinstance(s, (str, bytes)): @@ -75,7 +74,7 @@ def safe_str(s: Any, errors: str='replace') -> str: return _safe_str(s, errors) -def _safe_str(s: Any, errors: str='replace', file: IO=None) -> str: +def _safe_str(s: Any, errors: str = 'replace', file: IO = None) -> str: if isinstance(s, str): return s try: @@ -85,7 +84,7 @@ def _safe_str(s: Any, errors: str='replace', file: IO=None) -> str: type(s), exc, '\n'.join(traceback.format_stack())) -def safe_repr(o: Any, errors='replace') -> str: +def safe_repr(o: Any, errors: str = 'replace') -> str: """Safe form of repr, void of Unicode errors.""" try: return repr(o) diff --git a/kombu/utils/eventio.py b/kombu/utils/eventio.py index 8be03494..d10f5868 100644 --- a/kombu/utils/eventio.py +++ b/kombu/utils/eventio.py @@ -8,7 +8,7 @@ from 
numbers import Integral from typing import Any, Callable, Optional, Sequence, IO, cast from typing import Set, Tuple # noqa from . import fileno -from .typing import Fd, Timeout +from .typing import Fd from .compat import detect_environment __all__ = ['poll'] @@ -81,7 +81,7 @@ class _epoll(BasePoller): except (PermissionError, FileNotFoundError): pass - def poll(self, timeout: Timeout) -> Optional[Sequence]: + def poll(self, timeout: float) -> Optional[Sequence]: try: return self._epoll.poll(timeout if timeout is not None else -1) except Exception as exc: @@ -148,7 +148,7 @@ class _kqueue(BasePoller): except ValueError: pass - def poll(self, timeout: Timeout) -> Sequence: + def poll(self, timeout: float) -> Sequence: try: kevents = self._kcontrol(None, 1000, timeout) except Exception as exc: @@ -212,7 +212,7 @@ class _poll(BasePoller): self._quick_unregister(fd) return fd - def poll(self, timeout: Timeout, + def poll(self, timeout: float, round: Callable=math.ceil, POLLIN: int=POLLIN, POLLOUT: int=POLLOUT, POLLERR: int=POLLERR, READ: int=READ, WRITE: int=WRITE, ERR: int=ERR, @@ -284,7 +284,7 @@ class _select(BasePoller): self._wfd.discard(fd) self._efd.discard(fd) - def poll(self, timeout: Timeout) -> Sequence: + def poll(self, timeout: float) -> Sequence: try: read, write, error = _selectf( cast(Sequence, self._rfd), diff --git a/kombu/utils/functional.py b/kombu/utils/functional.py index 85feffa5..8d631809 100644 --- a/kombu/utils/functional.py +++ b/kombu/utils/functional.py @@ -6,9 +6,11 @@ from itertools import count, repeat from time import sleep from typing import ( Any, Callable, Dict, Iterable, Iterator, - KeysView, Mapping, Optional, Sequence, Tuple, + KeysView, ItemsView, Mapping, Optional, Sequence, Tuple, ValuesView, ) +from amqp.types import ChannelT from vine.utils import wraps +from ..types import ClientT from .encoding import safe_repr as _safe_repr __all__ = [ @@ -17,27 +19,36 @@ __all__ = [ ] KEYWORD_MARK = object() - MemoizeKeyFun = 
Callable[[Sequence, Mapping], Any] class ChannelPromise(object): + __value__: ChannelT + + def __init__(self, connection: ClientT) -> None: + self.__connection__ = connection - def __init__(self, contract): - self.__contract__ = contract + def __call__(self) -> Any: + try: + return self.__value__ + except AttributeError: + value = self.__value__ = self.__connection__.default_channel + return value - def __call__(self): + async def resolve(self): try: return self.__value__ except AttributeError: - value = self.__value__ = self.__contract__() + await self.__connection__.connect() + value = self.__value__ = self.__connection__.default_channel + await value.open() return value - def __repr__(self): + def __repr__(self) -> str: try: return repr(self.__value__) except AttributeError: - return '<promise: 0x{0:x}>'.format(id(self.__contract__)) + return '<promise: 0x{0:x}>'.format(id(self.__connection__)) class LRUCache(UserDict): @@ -53,7 +64,7 @@ class LRUCache(UserDict): def __init__(self, limit: int = None) -> None: self.limit = limit self.mutex = threading.RLock() - self.data = OrderedDict() # type: OrderedDict + self.data: OrderedDict = OrderedDict() def __getitem__(self, key: Any) -> Any: with self.mutex: @@ -69,7 +80,7 @@ class LRUCache(UserDict): for _ in range(len(data) - limit): data.popitem(last=False) - def popitem(self, last: bool=True) -> Any: + def popitem(self, last: bool = True) -> Any: with self.mutex: return self.data.popitem(last) @@ -83,7 +94,7 @@ class LRUCache(UserDict): def __iter__(self) -> Iterator: return iter(self.data) - def items(self) -> Iterator[Tuple[Any, Any]]: + def items(self) -> ItemsView[Any, Any]: with self.mutex: for k in self: try: @@ -91,7 +102,7 @@ class LRUCache(UserDict): except KeyError: # pragma: no cover pass - def values(self) -> Iterator[Any]: + def values(self) -> ValuesView[Any]: with self.mutex: for k in self: try: @@ -104,7 +115,7 @@ class LRUCache(UserDict): with self.mutex: return self.data.keys() - def incr(self, 
key: Any, delta: int=1) -> Any: + def incr(self, key: Any, delta: int = 1) -> Any: with self.mutex: # this acts as memcached does- store as a string, but return a # integer as long as it exists and we can cast it @@ -122,9 +133,9 @@ class LRUCache(UserDict): self.mutex = threading.RLock() -def memoize(maxsize: Optional[int]=None, - keyfun: Optional[MemoizeKeyFun]=None, - Cache: Any=LRUCache) -> Callable: +def memoize(maxsize: int = None, + keyfun: MemoizeKeyFun = None, + Cache: Any = LRUCache) -> Callable: """Decorator to cache function return value.""" def _memoize(fun: Callable) -> Callable: @@ -189,10 +200,10 @@ class lazy: def __repr__(self) -> str: return repr(self()) - def __eq__(self, rhs) -> bool: + def __eq__(self, rhs: Any) -> bool: return self() == rhs - def __ne__(self, rhs) -> bool: + def __ne__(self, rhs: Any) -> bool: return self() != rhs def __deepcopy__(self, memo: Dict) -> Any: @@ -200,8 +211,10 @@ class lazy: return self def __reduce__(self) -> Any: - return (self.__class__, (self._fun,), {'_args': self._args, - '_kwargs': self._kwargs}) + return (self.__class__, (self._fun,), { + '_args': self._args, + '_kwargs': self._kwargs, + }) def maybe_evaluate(value: Any) -> Any: @@ -212,8 +225,9 @@ def maybe_evaluate(value: Any) -> Any: def is_list(l: Any, - scalars: tuple=(Mapping, str), - iters: tuple=(Iterable,)) -> bool: + *, + scalars: tuple = (Mapping, str), + iters: tuple = (Iterable,)) -> bool: """Return true if the object is iterable. 
Note: @@ -222,18 +236,19 @@ def is_list(l: Any, return isinstance(l, iters) and not isinstance(l, scalars or ()) -def maybe_list(l: Any, scalars: tuple=(Mapping, str)) -> Optional[Sequence]: +def maybe_list(l: Any, *, + scalars: tuple = (Mapping, str)) -> Optional[Sequence]: """Return list of one element if ``l`` is a scalar.""" return l if l is None or is_list(l, scalars) else [l] -def dictfilter(d: Optional[Mapping]=None, **kw) -> Mapping: +def dictfilter(d: Mapping = None, **kw) -> Mapping: """Remove all keys from dict ``d`` whose value is :const:`None`.""" d = kw if d is None else (dict(d, **kw) if kw else d) return {k: v for k, v in d.items() if v is not None} -def shufflecycle(it): +def shufflecycle(it: Sequence) -> Iterator: it = list(it) # don't modify callers list shuffle = random.shuffle for _ in repeat(None): @@ -241,7 +256,10 @@ def shufflecycle(it): yield it[0] -def fxrange(start=1.0, stop=None, step=1.0, repeatlast=False): +def fxrange(start: float = 1.0, + stop: float = None, + step: float = 1.0, + repeatlast: bool = False) -> Iterator[float]: cur = start * 1.0 while 1: if not stop or cur <= stop: @@ -253,7 +271,10 @@ def fxrange(start=1.0, stop=None, step=1.0, repeatlast=False): yield cur - step -def fxrangemax(start=1.0, stop=None, step=1.0, max=100.0): +def fxrangemax(start: float = 1.0, + stop: float = None, + step: float = 1.0, + max: float = 100.0) -> Iterator[float]: sum_, cur = 0, start * 1.0 while 1: if sum_ >= max: @@ -266,9 +287,18 @@ def fxrangemax(start=1.0, stop=None, step=1.0, max=100.0): sum_ += cur -def retry_over_time(fun, catch, args=[], kwargs={}, errback=None, - max_retries=None, interval_start=2, interval_step=2, - interval_max=30, callback=None): +async def retry_over_time( + fun: Callable, + catch: Tuple[Any, ...], + args: Sequence = [], + kwargs: Mapping[str, Any] = {}, + *, + errback: Callable = None, + max_retries: int = None, + interval_start: float = 2.0, + interval_step: float = 2.0, + interval_max: float = 30.0, + 
callback: Callable = None) -> Any: """Retry the function over and over until max retries is exceeded. For each retry we sleep a for a while before we try again, this interval @@ -303,7 +333,7 @@ def retry_over_time(fun, catch, args=[], kwargs={}, errback=None, interval_step, repeatlast=True) for retries in count(): try: - return fun(*args, **kwargs) + return await fun(*args, **kwargs) except catch as exc: if max_retries and retries >= max_retries: raise @@ -320,11 +350,17 @@ def retry_over_time(fun, catch, args=[], kwargs={}, errback=None, sleep(abs(int(tts) - tts)) -def reprkwargs(kwargs, sep=', ', fmt='{0}={1}'): +def reprkwargs(kwargs: Mapping, + *, + sep: str = ', ', fmt: str = '{0}={1}') -> str: return sep.join(fmt.format(k, _safe_repr(v)) for k, v in kwargs.items()) -def reprcall(name, args=(), kwargs={}, sep=', '): +def reprcall(name: str, + args: Sequence = (), + kwargs: Mapping = {}, + *, + sep: str = ', ') -> str: return '{0}({1}{2}{3})'.format( name, sep.join(map(_safe_repr, args or ())), (args and kwargs) and sep or '', diff --git a/kombu/utils/imports.py b/kombu/utils/imports.py index 5262b41d..345ec4bb 100644 --- a/kombu/utils/imports.py +++ b/kombu/utils/imports.py @@ -1,10 +1,18 @@ """Import related utilities.""" import importlib import sys +from typing import Any, Callable, Mapping -def symbol_by_name(name, aliases={}, imp=None, package=None, - sep='.', default=None, **kwargs): +def symbol_by_name( + name: Any, + aliases: Mapping[str, str] = {}, + *, + imp: Callable = None, + package: str = None, + sep: str = '.', + default: Any = None, + **kwargs) -> Any: """Get symbol by qualified name. 
The name should be the full dot-separated path to the class:: diff --git a/kombu/utils/json.py b/kombu/utils/json.py index 6668a73e..b3c6dfb5 100644 --- a/kombu/utils/json.py +++ b/kombu/utils/json.py @@ -10,7 +10,7 @@ from .typing import AnyBuffer try: from django.utils.functional import Promise as DjangoPromise except ImportError: # pragma: no cover - class DjangoPromise(object): # noqa + class DjangoPromise(object): # noqa, type: ignore """Dummy object.""" try: @@ -24,20 +24,21 @@ except ImportError: # pragma: no cover else: from simplejson.decoder import JSONDecodeError as _DecodeError -_encoder_cls = type(json._default_encoder) -_default_encoder = None # ... set to JSONEncoder below. +_encoder_cls: type = type(json._default_encoder) # type: ignore +_default_encoder: type = None # ... set to JSONEncoder below. class JSONEncoder(_encoder_cls): """Kombu custom json encoder.""" - def default(self, o, + def default(self, o: Any, + *, dates=(datetime.datetime, datetime.date), times=(datetime.time,), textual=(decimal.Decimal, uuid.UUID, DjangoPromise), isinstance=isinstance, datetime=datetime.datetime, - str=str): + str=str) -> Any: reducer = getattr(o, '__json__', None) if reducer is not None: o = reducer() @@ -59,6 +60,7 @@ _default_encoder = JSONEncoder def dumps(s: Any, + *, _dumps: Callable = json.dumps, cls: Any = None, default_kwargs: Dict = _json_extra_kwargs, @@ -68,7 +70,7 @@ def dumps(s: Any, **dict(default_kwargs, **kwargs)) -def loads(s: AnyBuffer, _loads: Callable = json.loads) -> Any: +def loads(s: AnyBuffer, *, _loads: Callable = json.loads) -> Any: """Deserialize json from string.""" # None of the json implementations supports decoding from # a buffer/memoryview, or even reading from a stream diff --git a/kombu/utils/limits.py b/kombu/utils/limits.py index ee9c8003..2c34f0de 100644 --- a/kombu/utils/limits.py +++ b/kombu/utils/limits.py @@ -1,8 +1,7 @@ """Token bucket implementation for rate limiting.""" from collections import deque from time import 
monotonic -from typing import Any, Optional - +from typing import Any from .typing import Float __all__ = ['TokenBucket'] @@ -24,21 +23,21 @@ class TokenBucket: """ #: The rate in tokens/second that the bucket will be refilled. - fill_rate = None + fill_rate: float = None #: Maximum number of tokens in the bucket. - capacity = 1.0 # type: float + capacity: float = 1.0 #: Timestamp of the last time a token was taken out of the bucket. - timestamp = None # type: Optional[float] + timestamp: float = None - def __init__(self, fill_rate: Optional[Float], - capacity: Float=1.0) -> None: + def __init__(self, fill_rate: Float = None, + capacity: Float = 1.0) -> None: self.capacity = float(capacity) self._tokens = capacity self.fill_rate = float(fill_rate) self.timestamp = monotonic() - self.contents = deque() # type: deque + self.contents = deque() def add(self, item: Any) -> None: self.contents.append(item) @@ -49,7 +48,7 @@ class TokenBucket: def clear_pending(self) -> None: self.contents.clear() - def can_consume(self, tokens: int=1) -> bool: + def can_consume(self, tokens: int = 1) -> bool: """Check if one or more tokens can be consumed. Returns: diff --git a/kombu/utils/scheduling.py b/kombu/utils/scheduling.py index 74d4f25a..624a06d0 100644 --- a/kombu/utils/scheduling.py +++ b/kombu/utils/scheduling.py @@ -1,8 +1,7 @@ """Scheduling Utilities.""" from itertools import count -from typing import Any, Callable, Iterable, Optional, Sequence, Union +from typing import Any, Callable, Iterable, Sequence, Union from typing import List # noqa - from .imports import symbol_by_name __all__ = [ @@ -24,12 +23,12 @@ class FairCycle: Arguments: fun (Callable): Callback to call. - resources (Sequence[Any]): List of resources. + resources (Sequence): List of resources. predicate (type): Exception predicate. 
""" def __init__(self, fun: Callable, resources: Sequence, - predicate: Any=Exception) -> None: + predicate: Any = Exception) -> None: self.fun = fun self.resources = resources self.predicate = predicate @@ -72,8 +71,8 @@ class BaseCycle: class round_robin_cycle(BaseCycle): """Iterator that cycles between items in round-robin.""" - def __init__(self, it: Optional[Iterable]=None) -> None: - self.items = list(it if it is not None else []) # type: List + def __init__(self, it: Iterable = None) -> None: + self.items = list(it if it is not None else []) def update(self, it: Sequence) -> None: """Update items from iterable.""" diff --git a/kombu/utils/text.py b/kombu/utils/text.py index 6f7977fc..f2e8280d 100644 --- a/kombu/utils/text.py +++ b/kombu/utils/text.py @@ -1,11 +1,18 @@ # -*- coding: utf-8 -*- """Text Utilities.""" from difflib import SequenceMatcher -from typing import Iterator, Sequence, NamedTuple, Tuple - +from numbers import Number +from typing import ( + Iterator, Sequence, NamedTuple, Optional, SupportsInt, Tuple, Union, +) from kombu import version_info_t -fmatch_t = NamedTuple('fmatch_t', [('ratio', float), ('key', str)]) + +class fmatch_t(NamedTuple): + """Return value of :func:`fmatch_iter`.""" + + ratio: float + key: str def escape_regex(p: str, white: str = '') -> str: @@ -30,14 +37,14 @@ def fmatch_iter(needle: str, haystack: Sequence[str], def fmatch_best(needle: str, haystack: Sequence[str], - min_ratio: float = 0.6) -> str: + min_ratio: float = 0.6) -> Optional[str]: """Fuzzy match - Find best match (scalar).""" try: return sorted( fmatch_iter(needle, haystack, min_ratio), reverse=True, )[0][1] except IndexError: - pass + return None def version_string_as_tuple(s: str) -> version_info_t: @@ -52,7 +59,9 @@ def version_string_as_tuple(s: str) -> version_info_t: return v -def _unpack_version(major: int, minor: int = 0, micro: int = 0, +def _unpack_version(major: Union[SupportsInt, str, bytes], + minor: Union[SupportsInt, str, bytes] = 0, + micro: 
Union[SupportsInt, str, bytes] = 0, releaselevel: str = '', serial: str = '') -> version_info_t: return version_info_t(int(major), int(minor), micro, releaselevel, serial) diff --git a/kombu/utils/time.py b/kombu/utils/time.py index 64d31d3b..72448c3f 100644 --- a/kombu/utils/time.py +++ b/kombu/utils/time.py @@ -1,10 +1,8 @@ """Time Utilities.""" -from __future__ import absolute_import, unicode_literals - +from typing import Optional, Union __all__ = ['maybe_s_to_ms'] -def maybe_s_to_ms(v): - # type: (Optional[Union[int, float]]) -> int +def maybe_s_to_ms(v: Optional[Union[int, float]]) -> Optional[int]: """Convert seconds to milliseconds, but return None for None.""" return int(float(v) * 1000.0) if v is not None else v diff --git a/kombu/utils/typing.py b/kombu/utils/typing.py deleted file mode 100644 index 1f4843f6..00000000 --- a/kombu/utils/typing.py +++ /dev/null @@ -1,21 +0,0 @@ -from abc import abstractmethod -from typing import ( - _Protocol, AnyStr, SupportsFloat, SupportsInt, Union, -) - -Float = Union[SupportsInt, SupportsFloat, AnyStr] -Int = Union[SupportsInt, AnyStr] - -Port = Union[SupportsInt, str] - -AnyBuffer = Union[AnyStr, memoryview] - - -class SupportsFileno(_Protocol): - __slots__ = () - - @abstractmethod - def __fileno__(self) -> int: - ... 
- -Fd = Union[int, SupportsFileno] diff --git a/kombu/utils/url.py b/kombu/utils/url.py index 75db4d89..7dae3464 100644 --- a/kombu/utils/url.py +++ b/kombu/utils/url.py @@ -1,30 +1,25 @@ """URL Utilities.""" from functools import partial from typing import Any, Dict, Mapping, NamedTuple +from urllib.parse import parse_qsl, quote, unquote, urlparse from .typing import Port -try: - from urllib.parse import parse_qsl, quote, unquote, urlparse -except ImportError: - from urllib import quote, unquote # noqa - from urlparse import urlparse, parse_qsl # noqa safequote = partial(quote, safe='') -urlparts = NamedTuple('urlparts', [ - ('scheme', str), - ('hostname', str), - ('port', int), - ('username', str), - ('password', str), - ('path', str), - ('query', Dict), -]) +class urlparts(NamedTuple): + scheme: str + hostname: str + port: int + username: str + password: str + path: str + query: Dict -def parse_url(url: str) -> Dict: +def parse_url(url: str) -> Mapping: """Parse URL into mapping of components.""" - scheme, host, port, user, password, path, query = _parse_url(url) + scheme, host, port, user, password, path, query = url_to_parts(url) return dict(transport=scheme, hostname=host, port=port, userid=user, password=password, virtual_host=path, **query) @@ -47,7 +42,6 @@ def url_to_parts(url: str) -> urlparts: unquote(path or '') or None, dict(parse_qsl(parts.query)), ) -_parse_url = url_to_parts # noqa def as_url(scheme: str, @@ -79,7 +73,7 @@ def as_url(scheme: str, def sanitize_url(url: str, mask: str = '**') -> str: """Return copy of URL with password removed.""" - return as_url(*_parse_url(url), sanitize=True, mask=mask) + return as_url(*url_to_parts(url), sanitize=True, mask=mask) def maybe_sanitize_url(url: Any, mask: str = '**') -> Any: diff --git a/kombu/utils/uuid.py b/kombu/utils/uuid.py index 7de82b29..04130ae2 100644 --- a/kombu/utils/uuid.py +++ b/kombu/utils/uuid.py @@ -1,8 +1,9 @@ """UUID utilities.""" from uuid import uuid4 +from typing import Callable 
-def uuid(_uuid=uuid4): +def uuid(*, _uuid: Callable = uuid4) -> str: """Generate unique id in UUID4 format. See Also: diff --git a/kombu/transport/SLMQ.py b/moved/transport/SLMQ.py index 91eebe5f..91eebe5f 100644 --- a/kombu/transport/SLMQ.py +++ b/moved/transport/SLMQ.py diff --git a/kombu/transport/SQS.py b/moved/transport/SQS.py index f5766edb..f5766edb 100644 --- a/kombu/transport/SQS.py +++ b/moved/transport/SQS.py diff --git a/kombu/transport/consul.py b/moved/transport/consul.py index fa77e6e9..fa77e6e9 100644 --- a/kombu/transport/consul.py +++ b/moved/transport/consul.py diff --git a/kombu/transport/etcd.py b/moved/transport/etcd.py index 4d5c652a..3c25d508 100644 --- a/kombu/transport/etcd.py +++ b/moved/transport/etcd.py @@ -4,22 +4,16 @@ It uses Etcd as a store to transport messages in Queues It uses python-etcd for talking to Etcd's HTTP API """ -from __future__ import absolute_import, unicode_literals - import os import socket - from collections import defaultdict from contextlib import contextmanager - +from queue import Empty from kombu.exceptions import ChannelError -from kombu.five import Empty from kombu.log import get_logger from kombu.utils.json import loads, dumps from kombu.utils.objects import cached_property - from . 
import virtual - try: import etcd except ImportError: diff --git a/kombu/transport/filesystem.py b/moved/transport/filesystem.py index 2a5737fb..2a5737fb 100644 --- a/kombu/transport/filesystem.py +++ b/moved/transport/filesystem.py diff --git a/kombu/transport/librabbitmq.py b/moved/transport/librabbitmq.py index 2ea5b779..2ea5b779 100644 --- a/kombu/transport/librabbitmq.py +++ b/moved/transport/librabbitmq.py diff --git a/kombu/transport/memory.py b/moved/transport/memory.py index e3b4e441..e3b4e441 100644 --- a/kombu/transport/memory.py +++ b/moved/transport/memory.py diff --git a/kombu/transport/mongodb.py b/moved/transport/mongodb.py index 7c8b5bb3..7c8b5bb3 100644 --- a/kombu/transport/mongodb.py +++ b/moved/transport/mongodb.py diff --git a/kombu/transport/pyro.py b/moved/transport/pyro.py index c52532aa..c52532aa 100644 --- a/kombu/transport/pyro.py +++ b/moved/transport/pyro.py diff --git a/kombu/transport/qpid.py b/moved/transport/qpid.py index 22046241..9aa4fb48 100644 --- a/kombu/transport/qpid.py +++ b/moved/transport/qpid.py @@ -76,8 +76,6 @@ options override and replace any other default or specified values. If using Celery, this can be accomplished by setting the *BROKER_TRANSPORT_OPTIONS* Celery option. 
""" -from __future__ import absolute_import, unicode_literals - from collections import OrderedDict import os import select @@ -87,6 +85,8 @@ import sys import uuid from gettext import gettext as _ +from time import monotonic +from queue import Empty import amqp.protocol @@ -116,7 +116,6 @@ except ImportError: # pragma: no cover qpid = None -from kombu.five import Empty, items, monotonic from kombu.log import get_logger from kombu.transport.virtual import Base64, Message from kombu.transport import base @@ -1585,7 +1584,7 @@ class Transport(base.Transport): """ conninfo = self.client - for name, default_value in items(self.default_connection_params): + for name, default_value in self.default_connection_params.items(): if not getattr(conninfo, name, None): setattr(conninfo, name, default_value) if conninfo.ssl: diff --git a/kombu/transport/redis.py b/moved/transport/redis.py index e3e6de10..4b65e7e2 100644 --- a/kombu/transport/redis.py +++ b/moved/transport/redis.py @@ -3,9 +3,9 @@ import numbers import socket from bisect import bisect -from collections import namedtuple from contextlib import contextmanager from time import time +from typing import NamedTuple, Tuple from queue import Empty from vine import promise @@ -18,7 +18,7 @@ from kombu.utils.encoding import bytes_to_str from kombu.utils.json import loads, dumps from kombu.utils.objects import cached_property from kombu.utils.scheduling import cycle_by_name -from kombu.utils.url import _parse_url +from kombu.utils.url import url_to_parts from kombu.utils.uuid import uuid from . import virtual @@ -42,15 +42,19 @@ DEFAULT_DB = 0 PRIORITY_STEPS = [0, 3, 6, 9] -error_classes_t = namedtuple('error_classes_t', ( - 'connection_errors', 'channel_errors', -)) - NO_ROUTE_ERROR = """ Cannot route message for exchange {0!r}: Table empty or key no longer exists. Probably the key ({1!r}) has been removed from the Redis database. 
""" + +class error_classes_t(NamedTuple): + """Return value of :func:`get_redis_error_classes`.""" + + connection_errors: Tuple[type, ...] + channel_errors: Tuple[type, ...] + + # This implementation may seem overly complex, but I assure you there is # a good reason for doing it this way. # @@ -892,7 +896,7 @@ class Channel(virtual.Channel): pass host = connparams['host'] if '://' in host: - scheme, _, _, _, _, path, query = _parse_url(host) + scheme, _, _, _, _, path, query = url_to_parts(host) if scheme == 'socket': connparams = self._filter_tcp_connparams(**connparams) connparams.update({ diff --git a/kombu/transport/zookeeper.py b/moved/transport/zookeeper.py index cf9f0dc8..cf9f0dc8 100644 --- a/kombu/transport/zookeeper.py +++ b/moved/transport/zookeeper.py diff --git a/requirements/test.txt b/requirements/test.txt index e1036604..6de0f611 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -1,3 +1,4 @@ pytz>dev case>=1.5.2 pytest +pytest-asyncio diff --git a/t/integration/test_async.py b/t/integration/test_async.py new file mode 100644 index 00000000..094b24c3 --- /dev/null +++ b/t/integration/test_async.py @@ -0,0 +1,51 @@ +import os +import pytest +from kombu import Connection, Exchange, Queue, Producer, Consumer +from kombu.pools import producers + +BROKER_URL = os.environ.get('BROKER_URL', 'pyamqp://') + +queue1 = Queue('testq32', Exchange('testq32', 'direct'), 'testq32') + + +@pytest.fixture() +def connection(): + return Connection(BROKER_URL) + + +@pytest.mark.asyncio +async def test_queue_declare(connection): + async with connection: + async with connection.default_channel as channel: + ret = await queue1(channel).declare() + assert ret == queue1.name + + +@pytest.mark.asyncio +async def test_produce_consume(connection): + + messages_received = [0] + + async def on_message(message): + messages_received[0] += 1 + await message.ack() + + async with connection as connection: + async with Consumer(connection, + queues=[queue1], + 
on_message=on_message) as consumer: + async with connection.clone() as w_connection: + await w_connection.connect() + assert w_connection._connection + async with producers[w_connection].acquire() as producer: + for i in range(10): + await producer.publish( + str(i), + exchange=queue1.exchange, + routing_key=queue1.routing_key, + retry=False, + declare=[queue1], + ) + while messages_received[0] < 10: + await connection.drain_events() + assert messages_received[0] == 10 diff --git a/t/integration/__init__.py b/t/oldint/__init__.py index 23ffb4c1..23ffb4c1 100644 --- a/t/integration/__init__.py +++ b/t/oldint/__init__.py diff --git a/t/integration/tests/__init__.py b/t/oldint/tests/__init__.py index 094b08e0..094b08e0 100644 --- a/t/integration/tests/__init__.py +++ b/t/oldint/tests/__init__.py diff --git a/t/integration/tests/test_SLMQ.py b/t/oldint/tests/test_SLMQ.py index 8428f7d1..8428f7d1 100644 --- a/t/integration/tests/test_SLMQ.py +++ b/t/oldint/tests/test_SLMQ.py diff --git a/t/integration/tests/test_SQS.py b/t/oldint/tests/test_SQS.py index 571cff24..571cff24 100644 --- a/t/integration/tests/test_SQS.py +++ b/t/oldint/tests/test_SQS.py diff --git a/t/integration/tests/test_amqp.py b/t/oldint/tests/test_amqp.py index f7aa762e..f7aa762e 100644 --- a/t/integration/tests/test_amqp.py +++ b/t/oldint/tests/test_amqp.py diff --git a/t/integration/tests/test_librabbitmq.py b/t/oldint/tests/test_librabbitmq.py index 41a21b3c..41a21b3c 100644 --- a/t/integration/tests/test_librabbitmq.py +++ b/t/oldint/tests/test_librabbitmq.py diff --git a/t/integration/tests/test_mongodb.py b/t/oldint/tests/test_mongodb.py index 495208c7..495208c7 100644 --- a/t/integration/tests/test_mongodb.py +++ b/t/oldint/tests/test_mongodb.py diff --git a/t/integration/tests/test_pyamqp.py b/t/oldint/tests/test_pyamqp.py index f7aa762e..f7aa762e 100644 --- a/t/integration/tests/test_pyamqp.py +++ b/t/oldint/tests/test_pyamqp.py diff --git a/t/integration/tests/test_qpid.py 
b/t/oldint/tests/test_qpid.py index a5c9141d..a5c9141d 100644 --- a/t/integration/tests/test_qpid.py +++ b/t/oldint/tests/test_qpid.py diff --git a/t/integration/tests/test_redis.py b/t/oldint/tests/test_redis.py index 610b0149..610b0149 100644 --- a/t/integration/tests/test_redis.py +++ b/t/oldint/tests/test_redis.py diff --git a/t/integration/tests/test_zookeeper.py b/t/oldint/tests/test_zookeeper.py index 150c4a35..150c4a35 100644 --- a/t/integration/tests/test_zookeeper.py +++ b/t/oldint/tests/test_zookeeper.py diff --git a/t/integration/transport.py b/t/oldint/transport.py index 36b4d33f..36b4d33f 100644 --- a/t/integration/transport.py +++ b/t/oldint/transport.py diff --git a/t/unit/__init__.py b/t/unit/__init__.py index 01e6d4f4..e69de29b 100644 --- a/t/unit/__init__.py +++ b/t/unit/__init__.py @@ -1 +0,0 @@ -from __future__ import absolute_import, unicode_literals diff --git a/t/unit/test_clocks.py b/t/unit/test_clocks.py index 5ed30bf1..c4f9587c 100644 --- a/t/unit/test_clocks.py +++ b/t/unit/test_clocks.py @@ -1,12 +1,7 @@ -from __future__ import absolute_import, unicode_literals - import pickle - from heapq import heappush from time import time - from case import Mock - from kombu.clocks import LamportClock, timetuple diff --git a/t/unit/test_entity.py b/t/unit/test_entity.py index 004c98f8..b923d480 100644 --- a/t/unit/test_entity.py +++ b/t/unit/test_entity.py @@ -5,7 +5,7 @@ import pytest from case import Mock, call from kombu import Connection, Exchange, Producer, Queue, binding -from kombu.abstract import MaybeChannelBound +from kombu.abstract import Entity from kombu.exceptions import NotBoundError from kombu.serialization import registry @@ -403,7 +403,7 @@ class test_Queue: assert 'Queue' in repr(b) -class test_MaybeChannelBound: +class test_Entity: def test_repr(self): - assert repr(MaybeChannelBound()) + assert repr(Entity()) diff --git a/t/unit/test_exceptions.py b/t/unit/test_exceptions.py index f72f3d6d..016364db 100644 --- 
a/t/unit/test_exceptions.py +++ b/t/unit/test_exceptions.py @@ -1,7 +1,4 @@ -from __future__ import absolute_import, unicode_literals - from case import Mock - from kombu.exceptions import HttpError diff --git a/t/unit/transport/test_base.py b/t/unit/transport/test_base.py index f1e0b6c9..3ed9cbda 100644 --- a/t/unit/transport/test_base.py +++ b/t/unit/transport/test_base.py @@ -1,4 +1,3 @@ -from __future__ import absolute_import, unicode_literals import pytest from case import Mock from kombu import Connection, Consumer, Exchange, Producer, Queue diff --git a/t/unit/transport/test_etcd.py b/t/unit/transport/test_etcd.py index b61fd7a7..5e33b7ca 100644 --- a/t/unit/transport/test_etcd.py +++ b/t/unit/transport/test_etcd.py @@ -1,11 +1,6 @@ -from __future__ import absolute_import, unicode_literals - import pytest - +from queue import Empty from case import Mock, patch, skip - -from kombu.five import Empty - from kombu.transport.etcd import Channel, Transport diff --git a/t/unit/transport/test_qpid.py b/t/unit/transport/test_qpid.py index bf796212..7d87c789 100644 --- a/t/unit/transport/test_qpid.py +++ b/t/unit/transport/test_qpid.py @@ -1,5 +1,3 @@ -from __future__ import absolute_import, unicode_literals - import pytest import select import ssl @@ -7,16 +5,15 @@ import socket import sys import time import uuid - from collections import Callable, OrderedDict from itertools import count - +from queue import Empty +from time import monotonic from case import Mock, call, patch, skip - -from kombu.five import Empty, keys, range, monotonic -from kombu.transport.qpid import (AuthenticationFailure, Channel, Connection, - ConnectionError, Message, NotFound, QoS, - Transport) +from kombu.transport.qpid import ( + AuthenticationFailure, Channel, Connection, + ConnectionError, Message, NotFound, QoS, Transport, +) from kombu.transport.virtual import Base64 @@ -42,15 +39,15 @@ class ExtraAssertionsMixin(object): Also asserts that the value of each key is the same in a and b 
using the is operator. """ - assert set(keys(a)) == set(keys(b)) - for key in keys(a): + assert set(a.keys()) == set(b.keys()) + for key in a.keys(): assert a[key] == b[key] def assertDictContainsSubset(self, a, b, msg=None): """ Assert that all the key/value pairs in a exist in b. """ - for key in keys(a): + for key in a.keys(): assert key in b assert a[key] == b[key] diff --git a/t/unit/transport/virtual/test_exchange.py b/t/unit/transport/virtual/test_exchange.py index 93d48228..bc67a676 100644 --- a/t/unit/transport/virtual/test_exchange.py +++ b/t/unit/transport/virtual/test_exchange.py @@ -1,12 +1,7 @@ -from __future__ import absolute_import, unicode_literals - import pytest - from case import Mock - from kombu import Connection from kombu.transport.virtual import exchange - from t.mocks import Transport diff --git a/t/unit/utils/test_compat.py b/t/unit/utils/test_compat.py index c79b3202..3c7397b5 100644 --- a/t/unit/utils/test_compat.py +++ b/t/unit/utils/test_compat.py @@ -1,12 +1,7 @@ -from __future__ import absolute_import, unicode_literals - import socket import sys import types - from case import Mock, mock, patch - -from kombu.five import bytes_if_py2 from kombu.utils import compat from kombu.utils.compat import entrypoints, maybe_fileno @@ -70,17 +65,15 @@ class test_detect_environment: def test_detect_environment_no_eventlet_or_gevent(self): try: - sys.modules['eventlet'] = types.ModuleType( - bytes_if_py2('eventlet')) - sys.modules['eventlet.patcher'] = types.ModuleType( - bytes_if_py2('patcher')) + sys.modules['eventlet'] = types.ModuleType('eventlet') + sys.modules['eventlet.patcher'] = types.ModuleType('patcher') assert compat._detect_environment() == 'default' finally: sys.modules.pop('eventlet.patcher', None) sys.modules.pop('eventlet', None) compat._detect_environment() try: - sys.modules['gevent'] = types.ModuleType(bytes_if_py2('gevent')) + sys.modules['gevent'] = types.ModuleType('gevent') assert compat._detect_environment() == 'default' 
finally: sys.modules.pop('gevent', None) diff --git a/t/unit/utils/test_div.py b/t/unit/utils/test_div.py index f0b1a058..bbf60b40 100644 --- a/t/unit/utils/test_div.py +++ b/t/unit/utils/test_div.py @@ -1,9 +1,5 @@ -from __future__ import absolute_import, unicode_literals - import pickle - from io import StringIO, BytesIO - from kombu.utils.div import emergency_dump_state diff --git a/t/unit/utils/test_encoding.py b/t/unit/utils/test_encoding.py index e3d1040a..bff3a321 100644 --- a/t/unit/utils/test_encoding.py +++ b/t/unit/utils/test_encoding.py @@ -1,13 +1,7 @@ # -*- coding: utf-8 -*- -from __future__ import absolute_import, unicode_literals - import sys - from contextlib import contextmanager - from case import patch, skip - -from kombu.five import bytes_t, string_t from kombu.utils.encoding import ( get_default_encoding_file, safe_str, set_default_encoding_file, default_encoding, @@ -50,13 +44,13 @@ class test_default_encoding: @skip.if_python3() def test_str_to_bytes(): with clean_encoding() as e: - assert isinstance(e.str_to_bytes('foobar'), bytes_t) + assert isinstance(e.str_to_bytes('foobar'), bytes) @skip.if_python3() def test_from_utf8(): with clean_encoding() as e: - assert isinstance(e.from_utf8('foobar'), bytes_t) + assert isinstance(e.from_utf8('foobar'), bytes) @skip.if_python3() @@ -75,7 +69,7 @@ class test_safe_str: assert safe_str('foo') == 'foo' def test_when_unicode(self): - assert isinstance(safe_str('foo'), string_t) + assert isinstance(safe_str('foo'), str) def test_when_encoding_utf8(self): self._encoding.return_value = 'utf-8' diff --git a/t/unit/utils/test_scheduling.py b/t/unit/utils/test_scheduling.py index 79216b56..b6ddeeca 100644 --- a/t/unit/utils/test_scheduling.py +++ b/t/unit/utils/test_scheduling.py @@ -1,9 +1,5 @@ -from __future__ import absolute_import, unicode_literals - import pytest - from case import Mock - from kombu.utils.scheduling import FairCycle, cycle_by_name diff --git a/t/unit/utils/test_time.py 
b/t/unit/utils/test_time.py index 4f6ddc0b..33d9fd6c 100644 --- a/t/unit/utils/test_time.py +++ b/t/unit/utils/test_time.py @@ -1,7 +1,4 @@ -from __future__ import absolute_import, unicode_literals - import pytest - from kombu.utils.time import maybe_s_to_ms diff --git a/t/unit/utils/test_url.py b/t/unit/utils/test_url.py index 3d2b0ede..e1bc680b 100644 --- a/t/unit/utils/test_url.py +++ b/t/unit/utils/test_url.py @@ -1,7 +1,4 @@ -from __future__ import absolute_import, unicode_literals - import pytest - from kombu.utils.url import as_url, parse_url, maybe_sanitize_url diff --git a/t/unit/utils/test_utils.py b/t/unit/utils/test_utils.py index f4668b83..82cd182b 100644 --- a/t/unit/utils/test_utils.py +++ b/t/unit/utils/test_utils.py @@ -1,7 +1,4 @@ -from __future__ import absolute_import, unicode_literals - import pytest - from kombu import version_info_t from kombu.utils.text import version_string_as_tuple |