summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAsk Solem <ask@celeryproject.org>2017-02-16 10:44:48 -0800
committerAsk Solem <ask@celeryproject.org>2017-02-16 10:44:48 -0800
commite6fab2f68b562cf1400bd8167e9b755f0482aafe (patch)
treeb4d6be32dc8c62fa032e3c1a1a74636ac8360a38
parentf2f7c67651106e77fb2db60ded134404ccc0a626 (diff)
downloadkombu-5.0-devel.tar.gz
-rw-r--r--Makefile3
-rw-r--r--docs/reference/kombu.abstract.rst2
-rw-r--r--kombu/__init__.py7
-rw-r--r--kombu/abstract.py75
-rw-r--r--kombu/async/timer.py10
-rw-r--r--kombu/clocks.py93
-rw-r--r--kombu/common.py93
-rw-r--r--kombu/compression.py16
-rw-r--r--kombu/connection.py213
-rw-r--r--kombu/entity.py465
-rw-r--r--kombu/message.py63
-rw-r--r--kombu/messaging.py288
-rw-r--r--kombu/pidbox.py191
-rw-r--r--kombu/pools.py60
-rw-r--r--kombu/resource.py55
-rw-r--r--kombu/serialization.py10
-rw-r--r--kombu/simple.py81
-rw-r--r--kombu/transport/base.py101
-rw-r--r--kombu/transport/pyamqp.py16
-rw-r--r--kombu/transport/virtual/__init__.py2
-rw-r--r--kombu/transport/virtual/base.py459
-rw-r--r--kombu/transport/virtual/exchange.py54
-rw-r--r--kombu/types.py893
-rw-r--r--kombu/utils/abstract.py49
-rw-r--r--kombu/utils/amq_manager.py4
-rw-r--r--kombu/utils/collections.py20
-rw-r--r--kombu/utils/compat.py17
-rw-r--r--kombu/utils/debug.py9
-rw-r--r--kombu/utils/div.py8
-rw-r--r--kombu/utils/encoding.py7
-rw-r--r--kombu/utils/eventio.py10
-rw-r--r--kombu/utils/functional.py102
-rw-r--r--kombu/utils/imports.py12
-rw-r--r--kombu/utils/json.py14
-rw-r--r--kombu/utils/limits.py17
-rw-r--r--kombu/utils/scheduling.py11
-rw-r--r--kombu/utils/text.py21
-rw-r--r--kombu/utils/time.py6
-rw-r--r--kombu/utils/typing.py21
-rw-r--r--kombu/utils/url.py30
-rw-r--r--kombu/utils/uuid.py3
-rw-r--r--moved/transport/SLMQ.py (renamed from kombu/transport/SLMQ.py)0
-rw-r--r--moved/transport/SQS.py (renamed from kombu/transport/SQS.py)0
-rw-r--r--moved/transport/consul.py (renamed from kombu/transport/consul.py)0
-rw-r--r--moved/transport/etcd.py (renamed from kombu/transport/etcd.py)8
-rw-r--r--moved/transport/filesystem.py (renamed from kombu/transport/filesystem.py)0
-rw-r--r--moved/transport/librabbitmq.py (renamed from kombu/transport/librabbitmq.py)0
-rw-r--r--moved/transport/memory.py (renamed from kombu/transport/memory.py)0
-rw-r--r--moved/transport/mongodb.py (renamed from kombu/transport/mongodb.py)0
-rw-r--r--moved/transport/pyro.py (renamed from kombu/transport/pyro.py)0
-rw-r--r--moved/transport/qpid.py (renamed from kombu/transport/qpid.py)7
-rw-r--r--moved/transport/redis.py (renamed from kombu/transport/redis.py)18
-rw-r--r--moved/transport/zookeeper.py (renamed from kombu/transport/zookeeper.py)0
-rw-r--r--requirements/test.txt1
-rw-r--r--t/integration/test_async.py51
-rw-r--r--t/oldint/__init__.py (renamed from t/integration/__init__.py)0
-rw-r--r--t/oldint/tests/__init__.py (renamed from t/integration/tests/__init__.py)0
-rw-r--r--t/oldint/tests/test_SLMQ.py (renamed from t/integration/tests/test_SLMQ.py)0
-rw-r--r--t/oldint/tests/test_SQS.py (renamed from t/integration/tests/test_SQS.py)0
-rw-r--r--t/oldint/tests/test_amqp.py (renamed from t/integration/tests/test_amqp.py)0
-rw-r--r--t/oldint/tests/test_librabbitmq.py (renamed from t/integration/tests/test_librabbitmq.py)0
-rw-r--r--t/oldint/tests/test_mongodb.py (renamed from t/integration/tests/test_mongodb.py)0
-rw-r--r--t/oldint/tests/test_pyamqp.py (renamed from t/integration/tests/test_pyamqp.py)0
-rw-r--r--t/oldint/tests/test_qpid.py (renamed from t/integration/tests/test_qpid.py)0
-rw-r--r--t/oldint/tests/test_redis.py (renamed from t/integration/tests/test_redis.py)0
-rw-r--r--t/oldint/tests/test_zookeeper.py (renamed from t/integration/tests/test_zookeeper.py)0
-rw-r--r--t/oldint/transport.py (renamed from t/integration/transport.py)0
-rw-r--r--t/unit/__init__.py1
-rw-r--r--t/unit/test_clocks.py5
-rw-r--r--t/unit/test_entity.py6
-rw-r--r--t/unit/test_exceptions.py3
-rw-r--r--t/unit/transport/test_base.py1
-rw-r--r--t/unit/transport/test_etcd.py7
-rw-r--r--t/unit/transport/test_qpid.py21
-rw-r--r--t/unit/transport/virtual/test_exchange.py5
-rw-r--r--t/unit/utils/test_compat.py13
-rw-r--r--t/unit/utils/test_div.py4
-rw-r--r--t/unit/utils/test_encoding.py12
-rw-r--r--t/unit/utils/test_scheduling.py4
-rw-r--r--t/unit/utils/test_time.py3
-rw-r--r--t/unit/utils/test_url.py3
-rw-r--r--t/unit/utils/test_utils.py3
82 files changed, 2571 insertions, 1216 deletions
diff --git a/Makefile b/Makefile
index fbf715d7..2e119cf0 100644
--- a/Makefile
+++ b/Makefile
@@ -151,3 +151,6 @@ build:
distcheck: lint test clean
dist: readme contrib clean-dist build
+
+typecheck:
+ $(PYTHON) -m mypy --fast-parser --python-version=3.6 --ignore-missing-imports $(PROJ)
diff --git a/docs/reference/kombu.abstract.rst b/docs/reference/kombu.abstract.rst
index 436bc535..c51f5b6a 100644
--- a/docs/reference/kombu.abstract.rst
+++ b/docs/reference/kombu.abstract.rst
@@ -9,6 +9,6 @@
.. contents::
:local:
- .. autoclass:: MaybeChannelBound
+ .. autoclass:: Entity
:members:
:undoc-members:
diff --git a/kombu/__init__.py b/kombu/__init__.py
index 848d2b99..6e0211e1 100644
--- a/kombu/__init__.py
+++ b/kombu/__init__.py
@@ -71,7 +71,7 @@ for module, items in all_by_module.items():
object_origins[item] = module
-class module(ModuleType):
+class _module(ModuleType):
"""Customized Python module."""
def __getattr__(self, name):
@@ -100,7 +100,7 @@ except NameError: # pragma: no cover
# keep a reference to this module so that it's not garbage collected
old_module = sys.modules[__name__]
-new_module = sys.modules[__name__] = module(__name__)
+new_module = sys.modules[__name__] = _module(__name__) # type: ignore
new_module.__dict__.update({
'__file__': __file__,
'__path__': __path__,
@@ -118,6 +118,7 @@ new_module.__dict__.update({
})
if os.environ.get('KOMBU_LOG_DEBUG'): # pragma: no cover
- os.environ.update(KOMBU_LOG_CHANNEL='1', KOMBU_LOG_CONNECTION='1')
+ os.environ['KOMBU_LOG_CHANNEL'] = '1'
+ os.environ['KOMBU_LOG_CONNECTION'] = '1'
from .utils import debug
debug.setup_logging()
diff --git a/kombu/abstract.py b/kombu/abstract.py
index 916b1eea..4f824415 100644
--- a/kombu/abstract.py
+++ b/kombu/abstract.py
@@ -1,14 +1,13 @@
-import amqp.abstract
-
from copy import copy
-from typing import Any, Dict
+from typing import Any, Dict, Optional
from typing import Sequence, Tuple # noqa
-
+from amqp.types import ChannelT
from .connection import maybe_channel
from .exceptions import NotBoundError
+from .types import EntityT
from .utils.functional import ChannelPromise
-__all__ = ['Object', 'MaybeChannelBound']
+__all__ = ['Entity']
def unpickle_dict(cls: Any, kwargs: Dict) -> Any:
@@ -19,13 +18,16 @@ def _any(v: Any) -> Any:
return v
-class Object:
- """Common base class.
-
- Supports automatic kwargs->attributes handling, and cloning.
- """
+class Entity(EntityT):
+ """Mixin for classes that can be bound to an AMQP channel."""
attrs = () # type: Sequence[Tuple[str, Any]]
+ #: Defines whether maybe_declare can skip declaring this entity twice.
+ can_cache_declaration = False
+
+ _channel: ChannelT
+ _is_bound: bool = False
+
def __init__(self, *args, **kwargs) -> None:
for name, type_ in self.attrs:
value = kwargs.get(name)
@@ -37,41 +39,15 @@ class Object:
except AttributeError:
setattr(self, name, None)
- def as_dict(self, recurse: bool=False) -> Dict:
- def f(obj: Any, type: Any) -> Any:
- if recurse and isinstance(obj, Object):
- return obj.as_dict(recurse=True)
- return type(obj) if type and obj is not None else obj
- return {
- attr: f(getattr(self, attr), type) for attr, type in self.attrs
- }
-
- def __reduce__(self) -> Any:
- return unpickle_dict, (self.__class__, self.as_dict())
-
- def __copy__(self) -> Any:
- return self.__class__(**self.as_dict())
-
-
-class MaybeChannelBound(Object):
- """Mixin for classes that can be bound to an AMQP channel."""
-
- _channel = None # type: amqp.abstract.Channel
- _is_bound = False # type: bool
-
- #: Defines whether maybe_declare can skip declaring this entity twice.
- can_cache_declaration = False
-
- def __call__(self, channel: amqp.abstract.Channel) -> 'MaybeChannelBound':
+ def __call__(self, channel: ChannelT) -> EntityT:
"""`self(channel) -> self.bind(channel)`"""
return self.bind(channel)
- def bind(self, channel: amqp.abstract.Channel) -> 'MaybeChannelBound':
+ def bind(self, channel: ChannelT) -> EntityT:
"""Create copy of the instance that is bound to a channel."""
return copy(self).maybe_bind(channel)
- def maybe_bind(self,
- channel: amqp.abstract.Channel) -> 'MaybeChannelBound':
+ def maybe_bind(self, channel: Optional[ChannelT]) -> EntityT:
"""Bind instance to channel if not already bound."""
if not self.is_bound and channel:
self._channel = maybe_channel(channel)
@@ -79,7 +55,7 @@ class MaybeChannelBound(Object):
self._is_bound = True
return self
- def revive(self, channel: amqp.abstract.Channel) -> None:
+ def revive(self, channel: ChannelT) -> None:
"""Revive channel after the connection has been re-established.
Used by :meth:`~kombu.Connection.ensure`.
@@ -96,20 +72,35 @@ class MaybeChannelBound(Object):
def __repr__(self) -> str:
return self._repr_entity(type(self).__name__)
- def _repr_entity(self, item: str='') -> str:
+ def _repr_entity(self, item: str = '') -> str:
item = item or type(self).__name__
if self.is_bound:
return '<{0} bound to chan:{1}>'.format(
item or type(self).__name__, self.channel.channel_id)
return '<unbound {0}>'.format(item)
+ def as_dict(self, recurse: bool = False) -> Dict:
+ def f(obj: Any, type: Any) -> Any:
+ if recurse and isinstance(obj, Entity):
+ return obj.as_dict(recurse=True)
+ return type(obj) if type and obj is not None else obj
+ return {
+ attr: f(getattr(self, attr), type) for attr, type in self.attrs
+ }
+
+ def __reduce__(self) -> Any:
+ return unpickle_dict, (self.__class__, self.as_dict())
+
+ def __copy__(self) -> Any:
+ return self.__class__(**self.as_dict())
+
@property
def is_bound(self) -> bool:
"""Flag set if the channel is bound."""
return self._is_bound and self._channel is not None
@property
- def channel(self) -> amqp.abstract.Channel:
+ def channel(self) -> ChannelT:
"""Current channel if the object is bound."""
channel = self._channel
if channel is None:
diff --git a/kombu/async/timer.py b/kombu/async/timer.py
index 69d93a36..b67f4ddf 100644
--- a/kombu/async/timer.py
+++ b/kombu/async/timer.py
@@ -3,10 +3,10 @@
import heapq
import sys
-from collections import namedtuple
from datetime import datetime
from functools import total_ordering
from time import monotonic
+from typing import NamedTuple
from weakref import proxy as weakrefproxy
from vine.utils import wraps
@@ -27,7 +27,13 @@ DEFAULT_MAX_INTERVAL = 2
EPOCH = datetime.utcfromtimestamp(0).replace(tzinfo=utc)
IS_PYPY = hasattr(sys, 'pypy_version_info')
-scheduled = namedtuple('scheduled', ('eta', 'priority', 'entry'))
+
+class scheduled(NamedTuple):
+ """Information about scheduled item."""
+
+ eta: float
+ priority: int
+ entry: 'Entry'
def to_timestamp(d, default_timezone=utc, time=monotonic):
diff --git a/kombu/clocks.py b/kombu/clocks.py
index cef4570b..ff82ffb0 100644
--- a/kombu/clocks.py
+++ b/kombu/clocks.py
@@ -2,9 +2,9 @@
from threading import Lock
from itertools import islice
from operator import itemgetter
-from typing import Any, List, Sequence
+from typing import Any, Callable, List, Sequence
-__all__ = ['LamportClock', 'timetuple']
+__all__ = ['Clock', 'LamportClock', 'timetuple']
R_CLOCK = '_lamport(clock={0}, timestamp={1}, id={2} {3!r})'
@@ -24,7 +24,7 @@ class timetuple(tuple):
__slots__ = ()
def __new__(cls, clock: int, timestamp: float,
- id: str, obj: Any=None) -> 'timetuple':
+ id: str, obj: Any = None) -> 'timetuple':
return tuple.__new__(cls, (clock, timestamp, id, obj))
def __repr__(self) -> str:
@@ -61,7 +61,54 @@ class timetuple(tuple):
obj = property(itemgetter(3))
-class LamportClock:
+class Clock:
+
+ value: int = 0
+
+ def __init__(self, initial_value: int = 0, **kwargs) -> None:
+ self.value = initial_value
+
+ def adjust(self, other: int) -> int:
+ raise NotImplementedError()
+
+ def forward(self) -> int:
+ raise NotImplementedError()
+
+ def __str__(self) -> str:
+ return str(self.value)
+
+ def __repr__(self) -> str:
+ return '<{name}: {0.value}>'.format(self, name=type(self).__name__)
+
+ def sort_heap(self, h: List[Sequence]) -> Any:
+ """Sort heap of events.
+
+ List of tuples containing at least two elements, representing
+ an event, where the first element is the event's scalar clock value,
+ and the second element is the id of the process (usually
+ ``"hostname:pid"``): ``sh([(clock, processid, ...?), (...)])``
+
+ The list must already be sorted, which is why we refer to it as a
+ heap.
+
+ The tuple will not be unpacked, so more than two elements can be
+ present.
+
+ Will return the latest event.
+ """
+ if h[0][0] == h[1][0]:
+ same = []
+ for PN in zip(h, islice(h, 1, None)):
+ if PN[0][0] != PN[1][0]:
+ break # Prev and Next's clocks differ
+ same.append(PN[0])
+ # return first item sorted by process id
+ return sorted(same, key=lambda event: event[1])[0]
+ # clock values unique, return first item
+ return h[0]
+
+
+class LamportClock(Clock):
"""Lamport's logical clock.
From Wikipedia:
@@ -100,9 +147,10 @@ class LamportClock:
#: The clocks current value.
value = 0
- def __init__(self, initial_value: int=0, Lock: Any=Lock) -> None:
- self.value = initial_value
+ def __init__(self, initial_value: int = 0,
+ Lock: Callable = Lock, **kwargs) -> None:
self.mutex = Lock()
+ super().__init__(initial_value)
def adjust(self, other: int) -> int:
with self.mutex:
@@ -113,36 +161,3 @@ class LamportClock:
with self.mutex:
self.value += 1
return self.value
-
- def sort_heap(self, h: List[Sequence]) -> Any:
- """Sort heap of events.
-
- List of tuples containing at least two elements, representing
- an event, where the first element is the event's scalar clock value,
- and the second element is the id of the process (usually
- ``"hostname:pid"``): ``sh([(clock, processid, ...?), (...)])``
-
- The list must already be sorted, which is why we refer to it as a
- heap.
-
- The tuple will not be unpacked, so more than two elements can be
- present.
-
- Will return the latest event.
- """
- if h[0][0] == h[1][0]:
- same = []
- for PN in zip(h, islice(h, 1, None)):
- if PN[0][0] != PN[1][0]:
- break # Prev and Next's clocks differ
- same.append(PN[0])
- # return first item sorted by process id
- return sorted(same, key=lambda event: event[1])[0]
- # clock values unique, return first item
- return h[0]
-
- def __str__(self) -> str:
- return str(self.value)
-
- def __repr__(self) -> str:
- return '<LamportClock: {0.value}>'.format(self)
diff --git a/kombu/common.py b/kombu/common.py
index 0b1c8bbb..85134482 100644
--- a/kombu/common.py
+++ b/kombu/common.py
@@ -1,5 +1,4 @@
"""Common Utilities."""
-import amqp.abstract
import os
import socket
import threading
@@ -15,13 +14,15 @@ from typing import (
)
from uuid import uuid4, uuid3, NAMESPACE_OID
-from amqp import RecoverableConnectionError
+from amqp import ChannelT, RecoverableConnectionError
from .entity import Exchange, Queue
from .log import get_logger
from .serialization import registry as serializers
+from .types import (
+ ClientT, ConsumerT, EntityT, MessageT, ProducerT, ResourceT,
+)
from .utils import abstract
-from .utils.typing import Timeout
from .utils.uuid import uuid
try:
@@ -89,10 +90,10 @@ class Broadcast(Queue):
attrs = Queue.attrs + (('queue', None),)
- def __init__(self, name: Optional[str]=None,
- queue: Optional[abstract.Entity]=None,
- auto_delete: bool=True,
- exchange: Optional[Union[Exchange, str]]=None,
+ def __init__(self, name: str = None,
+ queue: EntityT = None,
+ auto_delete: bool = True,
+ exchange: Union[Exchange, str] = None,
alias: Optional[str]=None, **kwargs) -> None:
queue = queue or 'bcast.{0}'.format(uuid())
return super().__init__(
@@ -106,15 +107,15 @@ class Broadcast(Queue):
)
-def declaration_cached(entity: abstract.Entity,
- channel: amqp.abstract.Channel) -> bool:
+def declaration_cached(entity: EntityT,
+ channel: ChannelT) -> bool:
return entity in channel.connection.client.declared_entities
-def maybe_declare(entity: abstract.Entity,
- channel: amqp.abstract.Channel = None,
- retry: bool = False,
- **retry_policy) -> bool:
+async def maybe_declare(entity: EntityT,
+ channel: ChannelT = None,
+ retry: bool = False,
+ **retry_policy) -> bool:
"""Declare entity (cached)."""
is_bound = entity.is_bound
orig = entity
@@ -135,18 +136,18 @@ def maybe_declare(entity: abstract.Entity,
return False
if retry:
- return _imaybe_declare(entity, declared, ident,
- channel, orig, **retry_policy)
- return _maybe_declare(entity, declared, ident, channel, orig)
+ return await _imaybe_declare(entity, declared, ident,
+ channel, orig, **retry_policy)
+ return await _maybe_declare(entity, declared, ident, channel, orig)
-def _maybe_declare(entity: abstract.Entity, declared: MutableSet, ident: int,
- channel: Optional[amqp.abstract.Channel],
- orig: abstract.Entity = None) -> bool:
+async def _maybe_declare(entity: EntityT, declared: MutableSet, ident: int,
+ channel: Optional[ChannelT],
+ orig: EntityT = None) -> bool:
channel = channel or entity.channel
if not channel.connection:
raise RecoverableConnectionError('channel disconnected')
- entity.declare(channel=channel)
+ await entity.declare(channel=channel)
if declared is not None and ident:
declared.add(ident)
if orig is not None:
@@ -154,22 +155,22 @@ def _maybe_declare(entity: abstract.Entity, declared: MutableSet, ident: int,
return True
-def _imaybe_declare(entity: abstract.Entity,
- declared: MutableSet,
- ident: int,
- channel: Optional[amqp.abstract.Channel],
- orig: Optional[abstract.Entity]=None,
- **retry_policy) -> bool:
- return entity.channel.connection.client.ensure(
+async def _imaybe_declare(entity: EntityT,
+ declared: MutableSet,
+ ident: int,
+ channel: Optional[ChannelT],
+ orig: EntityT = None,
+ **retry_policy) -> bool:
+ return await entity.channel.connection.client.ensure(
entity, _maybe_declare, **retry_policy)(
entity, declared, ident, channel, orig)
def drain_consumer(
- consumer: abstract.Consumer,
+ consumer: ConsumerT,
limit: int = 1,
- timeout: Timeout = None,
- callbacks: Sequence[Callable] = None) -> Iterator[abstract.Message]:
+ timeout: float = None,
+ callbacks: Sequence[Callable] = None) -> Iterator[MessageT]:
"""Drain messages from consumer instance."""
acc = deque()
@@ -189,12 +190,12 @@ def drain_consumer(
def itermessages(
conn: abstract.Connection,
- channel: Optional[amqp.abstract.Channel],
- queue: abstract.Entity,
+ channel: Optional[ChannelT],
+ queue: EntityT,
limit: int = 1,
- timeout: Timeout = None,
+ timeout: float = None,
callbacks: Sequence[Callable] = None,
- **kwargs) -> Iterator[abstract.Message]:
+ **kwargs) -> Iterator[MessageT]:
"""Iterator over messages."""
return drain_consumer(
conn.Consumer(queues=[queue], channel=channel, **kwargs),
@@ -202,9 +203,9 @@ def itermessages(
)
-def eventloop(conn: abstract.Connection,
+def eventloop(conn: ClientT,
limit: int = None,
- timeout: Timeout = None,
+ timeout: float = None,
ignore_timeouts: bool = False) -> Iterator[Any]:
"""Best practice generator wrapper around ``Connection.drain_events``.
@@ -243,9 +244,9 @@ def eventloop(conn: abstract.Connection,
raise
-def send_reply(exchange: Union[abstract.Exchange, str],
- req: abstract.Message, msg: Any,
- producer: abstract.Producer = None,
+def send_reply(exchange: Union[Exchange, str],
+ req: MessageT, msg: Any,
+ producer: ProducerT = None,
retry: bool = False,
retry_policy: Mapping = None, **props) -> None:
"""Send reply for request.
@@ -271,9 +272,9 @@ def send_reply(exchange: Union[abstract.Exchange, str],
def collect_replies(
- conn: abstract.Connection,
- channel: Optional[amqp.abstract.Channel],
- queue: abstract.Entity,
+ conn: ClientT,
+ channel: Optional[ChannelT],
+ queue: EntityT,
*args, **kwargs) -> Iterator[Any]:
"""Generator collecting replies from ``queue``."""
no_ack = kwargs.setdefault('no_ack', True)
@@ -305,7 +306,7 @@ def _ignore_errors(conn) -> Iterator:
pass
-def ignore_errors(conn: abstract.Connection,
+def ignore_errors(conn: ClientT,
fun: Callable = None, *args, **kwargs) -> Any:
"""Ignore connection and channel errors.
@@ -339,14 +340,14 @@ def ignore_errors(conn: abstract.Connection,
return _ignore_errors(conn)
-def revive_connection(connection: abstract.Connection,
- channel: amqp.abstract.Channel,
+def revive_connection(connection: ClientT,
+ channel: ChannelT,
on_revive: Callable = None) -> None:
if on_revive:
on_revive(channel)
-def insured(pool: abstract.Resource,
+def insured(pool: ResourceT,
fun: Callable,
args: Sequence, kwargs: Dict,
errback: Callable = None,
diff --git a/kombu/compression.py b/kombu/compression.py
index 8e5ad99a..cec0282c 100644
--- a/kombu/compression.py
+++ b/kombu/compression.py
@@ -11,15 +11,15 @@ __all__ = [
'get_decoder', 'compress', 'decompress',
]
-TEncoder = Callable[[bytes], bytes]
-TDecoder = Callable[[bytes], bytes]
+EncoderT = Callable[[bytes], bytes]
+DecoderT = Callable[[bytes], bytes]
-_aliases = {} # type: MutableMapping[str, str]
-_encoders = {} # type: MutableMapping[str, TEncoder]
-_decoders = {} # type: MutableMapping[str, TDecoder]
+_aliases: MutableMapping[str, str] = {}
+_encoders: MutableMapping[str, EncoderT] = {}
+_decoders: MutableMapping[str, DecoderT] = {}
-def register(encoder: TEncoder, decoder: TDecoder, content_type: str,
+def register(encoder: EncoderT, decoder: DecoderT, content_type: str,
aliases: Sequence[str]=[]) -> None:
"""Register new compression method.
@@ -42,13 +42,13 @@ def encoders() -> Sequence[str]:
return list(_encoders)
-def get_encoder(t: str) -> Tuple[TEncoder, str]:
+def get_encoder(t: str) -> Tuple[EncoderT, str]:
"""Get encoder by alias name."""
t = _aliases.get(t, t)
return _encoders[t], t
-def get_decoder(t: str) -> TDecoder:
+def get_decoder(t: str) -> DecoderT:
"""Get decoder by alias name."""
return _decoders[_aliases.get(t, t)]
diff --git a/kombu/connection.py b/kombu/connection.py
index c4c6ff7d..11970ea1 100644
--- a/kombu/connection.py
+++ b/kombu/connection.py
@@ -7,6 +7,9 @@ from collections import OrderedDict
from contextlib import contextmanager
from itertools import count, cycle
from operator import itemgetter
+from typing import Any, Callable, Iterable, Mapping, Set, Sequence, Union
+
+from amqp.types import ChannelT, ConnectionT
# jython breaks on relative import for .exceptions for some reason
# (Issue #112)
@@ -15,7 +18,7 @@ from kombu import exceptions
from .log import get_logger
from .resource import Resource
from .transport import get_transport_cls, supports_librabbitmq
-from .utils import abstract
+from .types import ClientT, EntityT, ResourceT, TransportT
from .utils.collections import HashedSeq
from .utils.functional import dictfilter, lazy, retry_over_time, shufflecycle
from .utils.objects import cached_property
@@ -27,22 +30,21 @@ logger = get_logger(__name__)
roundrobin_failover = cycle
-resolve_aliases = {
+resolve_aliases: Mapping[str, str] = {
'pyamqp': 'amqp',
'librabbitmq': 'amqp',
}
-failover_strategies = {
+failover_strategies: Mapping[str, Callable] = {
'round-robin': roundrobin_failover,
'shuffle': shufflecycle,
}
-_log_connection = os.environ.get('KOMBU_LOG_CONNECTION', False)
-_log_channel = os.environ.get('KOMBU_LOG_CHANNEL', False)
+_log_connection = bool(os.environ.get('KOMBU_LOG_CONNECTION', False))
+_log_channel = bool(os.environ.get('KOMBU_LOG_CHANNEL', False))
-@abstract.Connection.register
-class Connection:
+class Connection(ClientT):
"""A connection to the broker.
Example:
@@ -106,48 +108,65 @@ class Connection:
:keyword port: Default port if not provided in the URL.
"""
- port = None
- virtual_host = '/'
- connect_timeout = 5
+ hostname: str
+ userid: str
+ password: str
+ ssl: Any
+ login_method: str
+ port: int = None
+ virtual_host: str = '/'
+ connect_timeout: float = 5.0
- _closed = None
- _connection = None
- _default_channel = None
- _transport = None
- _logger = False
- uri_prefix = None
+ uri_prefix: str = None
#: The cache of declared entities is per connection,
#: in case the server loses data.
- declared_entities = None
+ declared_entities: Set[EntityT] = None
#: Iterator returning the next broker URL to try in the event
#: of connection failure (initialized by :attr:`failover_strategy`).
- cycle = None
+ cycle: Iterable = None
#: Additional transport specific options,
#: passed on to the transport instance.
- transport_options = None
+ transport_options: Mapping = None
#: Strategy used to select new hosts when reconnecting after connection
#: failure. One of "round-robin", "shuffle" or any custom iterator
#: constantly yielding new URLs to try.
- failover_strategy = 'round-robin'
+ failover_strategy: str = 'round-robin'
#: Heartbeat value, currently only supported by the py-amqp transport.
- heartbeat = None
+ heartbeat: float = None
- resolve_aliases = resolve_aliases
- failover_strategies = failover_strategies
+ resolve_aliases: Mapping[str, str] = resolve_aliases
+ failover_strategies: Mapping[str, Callable] = failover_strategies
- hostname = userid = password = ssl = login_method = None
+ _closed: bool = None
+ _connection: ConnectionT = None
+ _default_channel: ChannelT = None
+ _transport: TransportT = None
+ _logger: bool = False
+ _initial_params: Mapping
- def __init__(self, hostname='localhost', userid=None,
- password=None, virtual_host=None, port=None, insist=False,
- ssl=False, transport=None, connect_timeout=5,
- transport_options=None, login_method=None, uri_prefix=None,
- heartbeat=0, failover_strategy='round-robin',
- alternates=None, **kwargs):
+ def __init__(
+ self,
+ hostname: str = 'localhost',
+ userid: str = None,
+ password: str = None,
+ virtual_host: str = None,
+ port: int = None,
+ insist: bool = False,
+ ssl: Any = None,
+ transport: Union[type, str] = None,
+ connect_timeout: float = 5.0,
+ transport_options: Mapping = None,
+ login_method: str = None,
+ uri_prefix: str = None,
+ heartbeat: float = None,
+ failover_strategy: str = 'round-robin',
+ alternates: Sequence[str] = None,
+ **kwargs):
alt = [] if alternates is None else alternates
# have to spell the args out, just to get nice docstrings :(
params = self._initial_params = {
@@ -208,7 +227,7 @@ class Connection:
self.declared_entities = set()
- def switch(self, url):
+ def switch(self, url: str) -> None:
"""Switch connection parameters to use a new URL.
Note:
@@ -219,14 +238,15 @@ class Connection:
self._closed = False
self._init_params(**dict(self._initial_params, **parse_url(url)))
- def maybe_switch_next(self):
+ def maybe_switch_next(self) -> None:
"""Switch to next URL given by the current failover strategy."""
if self.cycle:
self.switch(next(self.cycle))
- def _init_params(self, hostname, userid, password, virtual_host, port,
- insist, ssl, transport, connect_timeout,
- login_method, heartbeat):
+ def _init_params(self, hostname: str, userid: str, password: str,
+ virtual_host: str, port: int, insist: bool, ssl: Any,
+ transport: Union[type, str], connect_timeout: float,
+ login_method: str, heartbeat: float) -> None:
transport = transport or 'amqp'
if transport == 'amqp' and supports_librabbitmq():
transport = 'librabbitmq'
@@ -245,18 +265,20 @@ class Connection:
def register_with_event_loop(self, loop):
self.transport.register_with_event_loop(self.connection, loop)
- def _debug(self, msg, *args, **kwargs):
+ def _debug(self, msg: str, *args, **kwargs) -> None:
if self._logger: # pragma: no cover
fmt = '[Kombu connection:{id:#x}] {msg}'
logger.debug(fmt.format(id=id(self), msg=str(msg)),
*args, **kwargs)
- def connect(self):
+ async def connect(self) -> None:
"""Establish connection to server immediately."""
self._closed = False
- return self.connection
+ if self._connection is None:
+ await self._start_connection()
+ return self._connection
- def channel(self):
+ def channel(self) -> ChannelT:
"""Create and return a new channel."""
self._debug('create channel')
chan = self.transport.create_channel(self.connection)
@@ -266,7 +288,7 @@ class Connection:
'[Kombu channel:{0.channel_id}] ')
return chan
- def heartbeat_check(self, rate=2):
+ async def heartbeat_check(self, rate: int = 2) -> None:
"""Check heartbeats.
Allow the transport to perform any periodic tasks
@@ -283,9 +305,9 @@ class Connection:
is called every 3 / 2 seconds, then the rate is 2.
This value is currently unused by any transports.
"""
- return self.transport.heartbeat_check(self.connection, rate=rate)
+ await self.transport.heartbeat_check(self.connection, rate=rate)
- def drain_events(self, **kwargs):
+ async def drain_events(self, timeout: float = None, **kwargs) -> None:
"""Wait for a single event from the server.
Arguments:
@@ -294,30 +316,31 @@ class Connection:
Raises:
socket.timeout: if the timeout is exceeded.
"""
- return self.transport.drain_events(self.connection, **kwargs)
+ await self.transport.drain_events(
+ self.connection, timeout=timeout, **kwargs)
- def maybe_close_channel(self, channel):
+ async def maybe_close_channel(self, channel: ChannelT) -> None:
"""Close given channel, but ignore connection and channel errors."""
try:
- channel.close()
+ await channel.close()
except (self.connection_errors + self.channel_errors):
pass
- def _do_close_self(self):
+ async def _do_close_self(self):
# Close only connection and channel(s), but not transport.
self.declared_entities.clear()
if self._default_channel:
- self.maybe_close_channel(self._default_channel)
+ await self.maybe_close_channel(self._default_channel)
if self._connection:
try:
- self.transport.close_connection(self._connection)
+ await self.transport.close_connection(self._connection)
except self.connection_errors + (AttributeError, socket.error):
pass
self._connection = None
- def _close(self):
+ async def _close(self):
"""Really close connection, even if part of a connection pool."""
- self._do_close_self()
+ await self._do_close_self()
self._do_close_transport()
self._debug('closed')
self._closed = True
@@ -327,7 +350,7 @@ class Connection:
self._transport.client = None
self._transport = None
- def collect(self, socket_timeout=None):
+ async def collect(self, socket_timeout=None):
# amqp requires communication to close, we don't need that just
# to clear out references, Transport._collect can also be implemented
# by other transports that want fast after fork
@@ -337,7 +360,7 @@ class Connection:
_timeo = socket.getdefaulttimeout()
socket.setdefaulttimeout(socket_timeout)
try:
- self._do_close_self()
+ await self._do_close_self()
except socket.timeout:
pass
finally:
@@ -349,14 +372,22 @@ class Connection:
self.declared_entities.clear()
self._connection = None
- def release(self):
+ async def release(self):
"""Close the connection (if open)."""
- self._close()
- close = release
+ await self._close()
+
+ async def close(self):
+ await self._close()
- def ensure_connection(self, errback=None, max_retries=None,
- interval_start=2, interval_step=2, interval_max=30,
- callback=None, reraise_as_library_errors=True):
+ async def ensure_connection(
+ self,
+ errback=None,
+ max_retries=None,
+ interval_start=2,
+ interval_step=2,
+ interval_max=30,
+ callback=None,
+ reraise_as_library_errors=True):
"""Ensure we have a connection to the server.
If not retry establishing the connection with the settings
@@ -395,10 +426,12 @@ class Connection:
if not reraise_as_library_errors:
ctx = self._dummy_context
with ctx():
- retry_over_time(self.connect, self.recoverable_connection_errors,
- (), {}, on_error, max_retries,
- interval_start, interval_step, interval_max,
- callback)
+ await retry_over_time(
+ self.connect, self.recoverable_connection_errors,
+ (), {}, on_error, max_retries,
+ interval_start, interval_step, interval_max,
+ callback,
+ )
return self
@contextmanager
@@ -423,19 +456,23 @@ class Connection:
"""Return true if the cycle is complete after number of `retries`."""
return not (retries + 1) % len(self.alt) if self.alt else True
- def revive(self, new_channel):
+ async def revive(self, new_channel):
"""Revive connection after connection re-established."""
if self._default_channel and new_channel is not self._default_channel:
- self.maybe_close_channel(self._default_channel)
+ await self.maybe_close_channel(self._default_channel)
self._default_channel = None
def _default_ensure_callback(self, exc, interval):
logger.error("Ensure: Operation error: %r. Retry in %ss",
exc, interval, exc_info=True)
- def ensure(self, obj, fun, errback=None, max_retries=None,
- interval_start=1, interval_step=1, interval_max=1,
- on_revive=None):
+ async def ensure(self, obj, fun,
+ errback=None,
+ max_retries=None,
+ interval_start=1,
+ interval_step=1,
+ interval_max=1,
+ on_revive=None):
"""Ensure operation completes.
Regardless of any channel/connection errors occurring.
@@ -475,7 +512,7 @@ class Connection:
... errback=errback, max_retries=3)
>>> publish({'hello': 'world'}, routing_key='dest')
"""
- def _ensured(*args, **kwargs):
+ async def _ensured(*args, **kwargs):
got_connection = 0
conn_errors = self.recoverable_connection_errors
chan_errors = self.recoverable_channel_errors
@@ -485,7 +522,7 @@ class Connection:
with self._reraise_as_library_errors():
for retries in count(0): # for infinity
try:
- return fun(*args, **kwargs)
+ return await fun(*args, **kwargs)
except conn_errors as exc:
if got_connection and not has_modern_errors:
# transport can not distinguish between
@@ -497,21 +534,21 @@ class Connection:
raise
self._debug('ensure connection error: %r',
exc, exc_info=1)
- self.collect()
+ await self.collect()
errback and errback(exc, 0)
remaining_retries = None
if max_retries is not None:
remaining_retries = max(max_retries - retries, 1)
- self.ensure_connection(
+ await self.ensure_connection(
errback,
remaining_retries,
interval_start, interval_step, interval_max,
reraise_as_library_errors=False,
)
channel = self.default_channel
- obj.revive(channel)
+ await obj.revive(channel)
if on_revive:
- on_revive(channel)
+ await on_revive(channel)
got_connection += 1
except chan_errors as exc:
if max_retries is not None and retries > max_retries:
@@ -641,7 +678,7 @@ class Connection:
sanitize=not include_password, mask=mask,
)
- def Pool(self, limit=None, **kwargs):
+ def Pool(self, limit=None, **kwargs) -> ResourceT:
"""Pool of connections.
See Also:
@@ -667,7 +704,7 @@ class Connection:
"""
return ConnectionPool(self, limit, **kwargs)
- def ChannelPool(self, limit=None, **kwargs):
+ def ChannelPool(self, limit=None, **kwargs) -> ResourceT:
"""Pool of channels.
See Also:
@@ -746,9 +783,9 @@ class Connection:
return SimpleBuffer(channel or self, name, no_ack, queue_opts,
exchange_opts, **kwargs)
- def _establish_connection(self):
+ async def _establish_connection(self):
self._debug('establishing connection...')
- conn = self.transport.establish_connection()
+ conn = await self.transport.establish_connection()
self._debug('connection established: %r', self)
return conn
@@ -765,11 +802,19 @@ class Connection:
return self.__class__, tuple(self.info().values()), None
def __enter__(self):
+ self.connect()
return self
def __exit__(self, *args):
self.release()
+ async def __aenter__(self):
+ await self.connect()
+ return self
+
+ async def __aexit__(self, exc_type, exc_value, traceback):
+ await self.release()
+
@property
def qos_semantics_matches_spec(self):
return self.transport.qos_semantics_matches_spec(self.connection)
@@ -781,6 +826,12 @@ class Connection:
self._connection is not None and
self.transport.verify_connection(self._connection))
+ async def _start_connection(self):
+ self.declared_entities.clear()
+ self._default_channel = None
+ self._connection = await self._establish_connection()
+ self._closed = False
+
@property
def connection(self):
"""The underlying connection object.
@@ -789,13 +840,7 @@ class Connection:
This instance is transport specific, so do not
depend on the interface of this object.
"""
- if not self._closed:
- if not self.connected:
- self.declared_entities.clear()
- self._default_channel = None
- self._connection = self._establish_connection()
- self._closed = False
- return self._connection
+ return self._connection
@property
def default_channel(self):
diff --git a/kombu/entity.py b/kombu/entity.py
index 52ae605e..80522b55 100644
--- a/kombu/entity.py
+++ b/kombu/entity.py
@@ -1,42 +1,46 @@
"""Exchange and Queue declarations."""
import numbers
-
-from .abstract import MaybeChannelBound
+from typing import (
+ Any, Callable, Dict, Mapping, Set, Sequence, Tuple, Union,
+)
+from amqp.protocol import queue_declare_ok_t
+from amqp.types import ChannelT
+from .types import EntityT
from .exceptions import ContentDisallowed
from .serialization import prepare_accept_content
-from .utils import abstract
-
-TRANSIENT_DELIVERY_MODE = 1
-PERSISTENT_DELIVERY_MODE = 2
-DELIVERY_MODES = {'transient': TRANSIENT_DELIVERY_MODE,
- 'persistent': PERSISTENT_DELIVERY_MODE}
+from .types import BindingT, ExchangeT, MessageT, QueueT
__all__ = ['Exchange', 'Queue', 'binding', 'maybe_delivery_mode']
-INTERNAL_EXCHANGE_PREFIX = ('amq.',)
+TRANSIENT_DELIVERY_MODE = 1
+PERSISTENT_DELIVERY_MODE = 2
+DELIVERY_MODES: Mapping[str, int] = {
+ 'transient': TRANSIENT_DELIVERY_MODE,
+ 'persistent': PERSISTENT_DELIVERY_MODE,
+}
+INTERNAL_EXCHANGE_PREFIX: Tuple[str] = ('amq.',)
-def _reprstr(s):
- s = repr(s)
- if isinstance(s, str) and s.startswith("u'"):
- return s[2:-1]
- return s[1:-1]
+def _reprstr(s: Any) -> str:
+ return repr(s).strip("'")
-def pretty_bindings(bindings):
+def pretty_bindings(bindings: Sequence) -> str:
return '[{0}]'.format(', '.join(map(str, bindings)))
def maybe_delivery_mode(
- v, modes=DELIVERY_MODES, default=PERSISTENT_DELIVERY_MODE):
+ v: Union[numbers.Integral, str],
+ *,
+ modes: Mapping[str, int] = DELIVERY_MODES,
+ default: int = PERSISTENT_DELIVERY_MODE) -> int:
"""Get delivery mode by name (or none if undefined)."""
if v:
return v if isinstance(v, numbers.Integral) else modes[v]
return default
-@abstract.Entity.register
-class Exchange(MaybeChannelBound):
+class Exchange(ExchangeT):
"""An Exchange declaration.
Arguments:
@@ -137,7 +141,7 @@ class Exchange(MaybeChannelBound):
durable = True
auto_delete = False
passive = False
- delivery_mode = None
+ delivery_mode = None # type: int
no_declare = False
attrs = (
@@ -151,21 +155,29 @@ class Exchange(MaybeChannelBound):
('no_declare', bool),
)
- def __init__(self, name='', type='', channel=None, **kwargs):
+ def __init__(self,
+ name: str = '',
+ type: str = '',
+ channel: ChannelT = None,
+ **kwargs) -> None:
super().__init__(**kwargs)
self.name = name or self.name
self.type = type or self.type
self.maybe_bind(channel)
- def __hash__(self):
+ def __hash__(self) -> int:
return hash('E|%s' % (self.name,))
- def _can_declare(self):
+ def _can_declare(self) -> bool:
return not self.no_declare and (
self.name and not self.name.startswith(
INTERNAL_EXCHANGE_PREFIX))
- def declare(self, nowait=False, passive=None, channel=None):
+ async def declare(
+ self,
+ nowait: bool = False,
+ passive: bool = None,
+ channel: ChannelT = None) -> None:
"""Declare the exchange.
Creates the exchange on the broker, unless passive is set
@@ -177,14 +189,20 @@ class Exchange(MaybeChannelBound):
"""
if self._can_declare():
passive = self.passive if passive is None else passive
- return (channel or self.channel).exchange_declare(
+ await (channel or self.channel).exchange_declare(
exchange=self.name, type=self.type, durable=self.durable,
auto_delete=self.auto_delete, arguments=self.arguments,
nowait=nowait, passive=passive,
)
- def bind_to(self, exchange='', routing_key='',
- arguments=None, nowait=False, channel=None, **kwargs):
+ async def bind_to(
+ self,
+ exchange: Union[str, ExchangeT] = '',
+ routing_key: str = '',
+ arguments: Mapping = None,
+ nowait: bool = False,
+ channel: ChannelT = None,
+ **kwargs) -> None:
"""Bind the exchange to another exchange.
Arguments:
@@ -192,9 +210,9 @@ class Exchange(MaybeChannelBound):
will not block waiting for a response.
Default is :const:`False`.
"""
- if isinstance(exchange, Exchange):
+ if isinstance(exchange, ExchangeT):
exchange = exchange.name
- return (channel or self.channel).exchange_bind(
+ await (channel or self.channel).exchange_bind(
destination=self.name,
source=exchange,
routing_key=routing_key,
@@ -202,12 +220,17 @@ class Exchange(MaybeChannelBound):
arguments=arguments,
)
- def unbind_from(self, source='', routing_key='',
- nowait=False, arguments=None, channel=None):
+ async def unbind_from(
+ self,
+ source: Union[str, ExchangeT] = '',
+ routing_key: str = '',
+ nowait: bool = False,
+ arguments: Mapping = None,
+ channel: ChannelT = None) -> None:
"""Delete previously created exchange binding from the server."""
- if isinstance(source, Exchange):
+ if isinstance(source, ExchangeT):
source = source.name
- return (channel or self.channel).exchange_unbind(
+ await (channel or self.channel).exchange_unbind(
destination=self.name,
source=source,
routing_key=routing_key,
@@ -215,43 +238,11 @@ class Exchange(MaybeChannelBound):
arguments=arguments,
)
- def Message(self, body, delivery_mode=None, properties=None, **kwargs):
- """Create message instance to be sent with :meth:`publish`.
-
- Arguments:
- body (Any): Message body.
-
- delivery_mode (bool): Set custom delivery mode.
- Defaults to :attr:`delivery_mode`.
-
- priority (int): Message priority, 0 to broker configured
- max priority, where higher is better.
-
- content_type (str): The messages content_type. If content_type
- is set, no serialization occurs as it is assumed this is either
- a binary object, or you've done your own serialization.
- Leave blank if using built-in serialization as our library
- properly sets content_type.
-
- content_encoding (str): The character set in which this object
- is encoded. Use "binary" if sending in raw binary objects.
- Leave blank if using built-in serialization as our library
- properly sets content_encoding.
-
- properties (Dict): Message properties.
-
- headers (Dict): Message headers.
- """
- # XXX This method is unused by kombu itself AFAICT [ask].
- properties = {} if properties is None else properties
- properties['delivery_mode'] = maybe_delivery_mode(self.delivery_mode)
- return self.channel.prepare_message(
- body,
- properties=properties,
- **kwargs)
-
- def publish(self, message, routing_key=None, mandatory=False,
- immediate=False, exchange=None):
+ async def publish(self, message: Union[MessageT, str, bytes],
+ routing_key: str = None,
+ mandatory: bool = False,
+ immediate: bool = False,
+ exchange: str = None) -> None:
"""Publish message.
Arguments:
@@ -264,7 +255,7 @@ class Exchange(MaybeChannelBound):
if isinstance(message, str):
message = self.Message(message)
exchange = exchange or self.name
- return self.channel.basic_publish(
+ await self.channel.basic_publish(
message,
exchange=exchange,
routing_key=routing_key,
@@ -272,7 +263,10 @@ class Exchange(MaybeChannelBound):
immediate=immediate,
)
- def delete(self, if_unused=False, nowait=False):
+ async def delete(
+ self,
+ if_unused: bool = False,
+ nowait: bool = False) -> None:
"""Delete the exchange declaration on server.
Arguments:
@@ -281,14 +275,20 @@ class Exchange(MaybeChannelBound):
nowait (bool): If set the server will not respond, and a
response will not be waited for. Default is :const:`False`.
"""
- return self.channel.exchange_delete(exchange=self.name,
- if_unused=if_unused,
- nowait=nowait)
+ await self.channel.exchange_delete(
+ exchange=self.name,
+ if_unused=if_unused,
+ nowait=nowait,
+ )
- def binding(self, routing_key='', arguments=None, unbind_arguments=None):
+ def binding(
+ self,
+ routing_key: str = '',
+ arguments: Mapping = None,
+ unbind_arguments: Mapping = None) -> 'binding':
return binding(self, routing_key, arguments, unbind_arguments)
- def __eq__(self, other):
+ def __eq__(self, other: Any) -> bool:
if isinstance(other, Exchange):
return (self.name == other.name and
self.type == other.type and
@@ -298,28 +298,70 @@ class Exchange(MaybeChannelBound):
self.delivery_mode == other.delivery_mode)
return NotImplemented
- def __ne__(self, other):
- return not self.__eq__(other)
+ def __ne__(self, other: Any) -> bool:
+ ret = self.__eq__(other)
+ if ret is NotImplemented:
+ return ret
+ return not ret
- def __repr__(self):
+ def __repr__(self) -> str:
return self._repr_entity(self)
- def __str__(self):
- return 'Exchange {0}({1})'.format(
+ def __str__(self) -> str:
+ return '{name} {0}({1})'.format(
_reprstr(self.name) or repr(''), self.type,
+ name=type(self).__name__,
)
+ def Message(
+ self, body: Any,
+ delivery_mode: Union[str, int] = None,
+ properties: Mapping = None,
+ **kwargs) -> Any:
+ """Create message instance to be sent with :meth:`publish`.
+
+ Arguments:
+ body (Any): Message body.
+
+ delivery_mode (Union[str, int]): Set custom delivery mode.
+ Defaults to :attr:`delivery_mode`.
+
+ priority (int): Message priority, 0 to broker configured
+ max priority, where higher is better.
+
+ content_type (str): The messages content_type. If content_type
+ is set, no serialization occurs as it is assumed this is either
+ a binary object, or you've done your own serialization.
+ Leave blank if using built-in serialization as our library
+ properly sets content_type.
+
+ content_encoding (str): The character set in which this object
+ is encoded. Use "binary" if sending in raw binary objects.
+ Leave blank if using built-in serialization as our library
+ properly sets content_encoding.
+
+ properties (Dict): Message properties.
+
+ headers (Dict): Message headers.
+ """
+ # XXX This method is unused by kombu itself AFAICT [ask].
+ properties = {} if properties is None else properties
+ properties['delivery_mode'] = maybe_delivery_mode(self.delivery_mode)
+ return self.channel.prepare_message(
+ body,
+ properties=properties,
+ **kwargs)
+
@property
- def can_cache_declaration(self):
+ def can_cache_declaration(self) -> bool:
return not self.auto_delete
-@abstract.Entity.register
-class binding:
+class binding(BindingT):
"""Represents a queue or exchange binding.
Arguments:
- exchange (Exchange): Exchange to bind to.
+ exchange (ExchangeT): Exchange to bind to.
routing_key (str): Routing key used as binding key.
arguments (Dict): Arguments for bind operation.
unbind_arguments (Dict): Arguments for unbind operation.
@@ -332,45 +374,58 @@ class binding:
('unbind_arguments', None)
)
- def __init__(self, exchange=None, routing_key='',
- arguments=None, unbind_arguments=None):
+ def __init__(
+ self,
+ exchange: ExchangeT = None,
+ routing_key: str = '',
+ arguments: Mapping = None,
+ unbind_arguments: Mapping = None) -> None:
self.exchange = exchange
self.routing_key = routing_key
self.arguments = arguments
self.unbind_arguments = unbind_arguments
- def declare(self, channel, nowait=False):
+ def declare(self, channel: ChannelT, nowait: bool = False) -> None:
"""Declare destination exchange."""
if self.exchange and self.exchange.name:
self.exchange.declare(channel=channel, nowait=nowait)
- def bind(self, entity, nowait=False, channel=None):
+ def bind(
+ self, entity: EntityT,
+ nowait: bool = False,
+ channel: ChannelT = None) -> None:
"""Bind entity to this binding."""
- entity.bind_to(exchange=self.exchange,
- routing_key=self.routing_key,
- arguments=self.arguments,
- nowait=nowait,
- channel=channel)
+ entity.bind_to(
+ exchange=self.exchange,
+ routing_key=self.routing_key,
+ arguments=self.arguments,
+ nowait=nowait,
+ channel=channel,
+ )
- def unbind(self, entity, nowait=False, channel=None):
+ def unbind(
+ self, entity: EntityT,
+ nowait: bool = False,
+ channel: ChannelT = None) -> None:
"""Unbind entity from this binding."""
- entity.unbind_from(self.exchange,
- routing_key=self.routing_key,
- arguments=self.unbind_arguments,
- nowait=nowait,
- channel=channel)
+ entity.unbind_from(
+ self.exchange,
+ routing_key=self.routing_key,
+ arguments=self.unbind_arguments,
+ nowait=nowait,
+ channel=channel
+ )
- def __repr__(self):
+ def __repr__(self) -> str:
return '<binding: {0}>'.format(self)
- def __str__(self):
+ def __str__(self) -> str:
return '{0}->{1}'.format(
_reprstr(self.exchange.name), _reprstr(self.routing_key),
)
-@abstract.Entity.register
-class Queue(MaybeChannelBound):
+class Queue(QueueT):
"""A Queue declaration.
Arguments:
@@ -561,9 +616,15 @@ class Queue(MaybeChannelBound):
('max_priority', int)
)
- def __init__(self, name='', exchange=None, routing_key='',
- channel=None, bindings=None, on_declared=None,
- **kwargs):
+ def __init__(
+ self,
+ name: str = '',
+ exchange: ExchangeT = None,
+ routing_key: str = '',
+ channel: ChannelT = None,
+ bindings: Sequence[BindingT] = None,
+ on_declared: Callable = None,
+ **kwargs) -> None:
super().__init__(**kwargs)
self.name = name or self.name
self.exchange = exchange or self.exchange
@@ -582,44 +643,56 @@ class Queue(MaybeChannelBound):
self.auto_delete = True
self.maybe_bind(channel)
- def bind(self, channel):
+ def bind(self, channel: ChannelT) -> EntityT:
on_declared = self.on_declared
bound = super().bind(channel)
bound.on_declared = on_declared
return bound
- def __hash__(self):
+ def __hash__(self) -> int:
return hash('Q|%s' % (self.name,))
- def when_bound(self):
+ def when_bound(self) -> None:
if self.exchange:
self.exchange = self.exchange(self.channel)
- def declare(self, nowait=False, channel=None):
+ async def declare(self,
+ nowait: bool = False,
+ channel: ChannelT = None) -> str:
"""Declare queue and exchange then binds queue to exchange."""
if not self.no_declare:
# - declare main binding.
- self._create_exchange(nowait=nowait, channel=channel)
- self._create_queue(nowait=nowait, channel=channel)
- self._create_bindings(nowait=nowait, channel=channel)
+ await self._create_exchange(nowait=nowait, channel=channel)
+ await self._create_queue(nowait=nowait, channel=channel)
+ await self._create_bindings(nowait=nowait, channel=channel)
return self.name
- def _create_exchange(self, nowait=False, channel=None):
+ async def _create_exchange(self,
+ nowait: bool = False,
+ channel: ChannelT = None) -> None:
if self.exchange:
- self.exchange.declare(nowait=nowait, channel=channel)
+ await self.exchange.declare(nowait=nowait, channel=channel)
- def _create_queue(self, nowait=False, channel=None):
- self.queue_declare(nowait=nowait, passive=False, channel=channel)
+ async def _create_queue(self,
+ nowait: bool = False,
+ channel: ChannelT = None) -> None:
+ await self.queue_declare(nowait=nowait, passive=False, channel=channel)
if self.exchange and self.exchange.name:
- self.queue_bind(nowait=nowait, channel=channel)
+ await self.queue_bind(nowait=nowait, channel=channel)
- def _create_bindings(self, nowait=False, channel=None):
+ async def _create_bindings(self,
+ nowait: bool = False,
+ channel: ChannelT = None) -> None:
for B in self.bindings:
channel = channel or self.channel
- B.declare(channel)
- B.bind(self, nowait=nowait, channel=channel)
-
- def queue_declare(self, nowait=False, passive=False, channel=None):
+ await B.declare(channel)
+ await B.bind(self, nowait=nowait, channel=channel)
+
+ async def queue_declare(
+ self,
+ nowait: bool = False,
+ passive: bool = False,
+ channel: ChannelT = None) -> queue_declare_ok_t:
"""Declare queue on the server.
Arguments:
@@ -637,7 +710,7 @@ class Queue(MaybeChannelBound):
max_length_bytes=self.max_length_bytes,
max_priority=self.max_priority,
)
- ret = channel.queue_declare(
+ ret = await channel.queue_declare(
queue=self.name,
passive=passive,
durable=self.durable,
@@ -652,18 +725,26 @@ class Queue(MaybeChannelBound):
self.on_declared(*ret)
return ret
- def queue_bind(self, nowait=False, channel=None):
+ async def queue_bind(
+ self,
+ nowait: bool = False,
+ channel: ChannelT = None) -> None:
"""Create the queue binding on the server."""
- return self.bind_to(self.exchange, self.routing_key,
- self.binding_arguments,
- channel=channel, nowait=nowait)
-
- def bind_to(self, exchange='', routing_key='',
- arguments=None, nowait=False, channel=None):
+ await self.bind_to(
+ self.exchange, self.routing_key, self.binding_arguments,
+ channel=channel, nowait=nowait)
+
+ async def bind_to(
+ self,
+ exchange: Union[str, ExchangeT] = '',
+ routing_key: str = '',
+ arguments: Mapping = None,
+ nowait: bool = False,
+ channel: ChannelT = None) -> None:
if isinstance(exchange, Exchange):
exchange = exchange.name
- return (channel or self.channel).queue_bind(
+ await (channel or self.channel).queue_bind(
queue=self.name,
exchange=exchange,
routing_key=routing_key,
@@ -671,7 +752,10 @@ class Queue(MaybeChannelBound):
nowait=nowait,
)
- def get(self, no_ack=None, accept=None):
+ async def get(
+ self,
+ no_ack: bool = None,
+ accept: Set[str] = None) -> MessageT:
"""Poll the server for a new message.
This method provides direct access to the messages in a
@@ -686,10 +770,10 @@ class Queue(MaybeChannelBound):
Arguments:
no_ack (bool): If enabled the broker will
automatically ack messages.
- accept (Set[str]): Custom list of accepted content types.
+ accept (Container): Custom list of accepted content types.
"""
no_ack = self.no_ack if no_ack is None else no_ack
- message = self.channel.basic_get(queue=self.name, no_ack=no_ack)
+ message = await self.channel.basic_get(queue=self.name, no_ack=no_ack)
if message is not None:
m2p = getattr(self.channel, 'message_to_python', None)
if m2p:
@@ -699,13 +783,17 @@ class Queue(MaybeChannelBound):
message.accept = prepare_accept_content(accept)
return message
- def purge(self, nowait=False):
+ async def purge(self, nowait: bool = False) -> int:
"""Remove all ready messages from the queue."""
return self.channel.queue_purge(queue=self.name,
nowait=nowait) or 0
- def consume(self, consumer_tag='', callback=None,
- no_ack=None, nowait=False):
+ async def consume(
+ self,
+ consumer_tag: str = '',
+ callback: Callable = None,
+ no_ack: bool = None,
+ nowait: bool = False) -> None:
"""Start a queue consumer.
Consumers last as long as the channel they were created on, or
@@ -726,7 +814,7 @@ class Queue(MaybeChannelBound):
"""
if no_ack is None:
no_ack = self.no_ack
- return self.channel.basic_consume(
+ await self.channel.basic_consume(
queue=self.name,
no_ack=no_ack,
consumer_tag=consumer_tag or '',
@@ -734,11 +822,15 @@ class Queue(MaybeChannelBound):
nowait=nowait,
arguments=self.consumer_arguments)
- def cancel(self, consumer_tag):
+ async def cancel(self, consumer_tag: str) -> None:
"""Cancel a consumer by consumer tag."""
- return self.channel.basic_cancel(consumer_tag)
+ await self.channel.basic_cancel(consumer_tag)
- def delete(self, if_unused=False, if_empty=False, nowait=False):
+ async def delete(
+ self,
+ if_unused: bool = False,
+ if_empty: bool = False,
+ nowait: bool = False) -> None:
"""Delete the queue.
Arguments:
@@ -751,19 +843,30 @@ class Queue(MaybeChannelBound):
nowait (bool): Do not wait for a reply.
"""
- return self.channel.queue_delete(queue=self.name,
- if_unused=if_unused,
- if_empty=if_empty,
- nowait=nowait)
-
- def queue_unbind(self, arguments=None, nowait=False, channel=None):
- return self.unbind_from(self.exchange, self.routing_key,
- arguments, nowait, channel)
+ await self.channel.queue_delete(
+ queue=self.name,
+ if_unused=if_unused,
+ if_empty=if_empty,
+ nowait=nowait,
+ )
- def unbind_from(self, exchange='', routing_key='',
- arguments=None, nowait=False, channel=None):
+ async def queue_unbind(
+ self,
+ arguments: Mapping = None,
+ nowait: bool = False,
+ channel: ChannelT = None) -> None:
+ await self.unbind_from(self.exchange, self.routing_key,
+ arguments, nowait, channel)
+
+ async def unbind_from(
+ self,
+ exchange: str = '',
+ routing_key: str = '',
+ arguments: Mapping = None,
+ nowait: bool = False,
+ channel: ChannelT = None) -> None:
"""Unbind queue by deleting the binding from the server."""
- return (channel or self.channel).queue_unbind(
+ await (channel or self.channel).queue_unbind(
queue=self.name,
exchange=exchange.name,
routing_key=routing_key,
@@ -771,7 +874,7 @@ class Queue(MaybeChannelBound):
nowait=nowait,
)
- def __eq__(self, other):
+ def __eq__(self, other: Any) -> bool:
if isinstance(other, Queue):
return (self.name == other.name and
self.exchange == other.exchange and
@@ -785,9 +888,12 @@ class Queue(MaybeChannelBound):
return NotImplemented
def __ne__(self, other):
- return not self.__eq__(other)
+ ret = self.__eq__(other)
+ if ret is NotImplemented:
+ return ret
+ return not ret
- def __repr__(self):
+ def __repr__(self) -> str:
if self.bindings:
return self._repr_entity('Queue {name} -> {bindings}'.format(
name=_reprstr(self.name),
@@ -801,30 +907,31 @@ class Queue(MaybeChannelBound):
)
@property
- def can_cache_declaration(self):
+ def can_cache_declaration(self) -> bool:
return not self.auto_delete
@classmethod
- def from_dict(self, queue,
- exchange=None,
- exchange_type=None,
- binding_key=None,
- routing_key=None,
- delivery_mode=None,
- bindings=None,
- durable=None,
- queue_durable=None,
- exchange_durable=None,
- auto_delete=None,
- queue_auto_delete=None,
- exchange_auto_delete=None,
- exchange_arguments=None,
- queue_arguments=None,
- binding_arguments=None,
- consumer_arguments=None,
- exclusive=None,
- no_ack=None,
- **options):
+ def from_dict(
+ self, queue: str,
+ exchange: str = None,
+ exchange_type: str = None,
+ binding_key: str = None,
+ routing_key: str = None,
+ delivery_mode: Union[int, str] = None,
+ bindings: Sequence = None,
+ durable: bool = None,
+ queue_durable: bool = None,
+ exchange_durable: bool = None,
+ auto_delete: bool = None,
+ queue_auto_delete: bool = None,
+ exchange_auto_delete: bool = None,
+ exchange_arguments: Mapping = None,
+ queue_arguments: Mapping = None,
+ binding_arguments: Mapping = None,
+ consumer_arguments: Mapping = None,
+ exclusive: bool = None,
+ no_ack: bool = None,
+ **options) -> QueueT:
return Queue(
queue,
exchange=Exchange(
@@ -850,7 +957,7 @@ class Queue(MaybeChannelBound):
bindings=bindings,
)
- def as_dict(self, recurse=False):
+ def as_dict(self, recurse: bool = False) -> Dict:
res = super().as_dict(recurse)
if not recurse:
return res
diff --git a/kombu/message.py b/kombu/message.py
index dd6f0b4e..9d74571b 100644
--- a/kombu/message.py
+++ b/kombu/message.py
@@ -1,10 +1,12 @@
"""Message class."""
+import logging
import sys
-
+from typing import Any, Callable, List, Mapping, Set, Tuple
+from amqp.types import ChannelT
+from . import types
from .compression import decompress
from .exceptions import MessageStateError
from .serialization import loads
-from .utils import abstract
from .utils.functional import dictfilter
__all__ = ['Message']
@@ -13,7 +15,7 @@ ACK_STATES = {'ACK', 'REJECTED', 'REQUEUED'}
IS_PYPY = hasattr(sys, 'pypy_version_info')
-@abstract.Message.register
+@types.MessageT.register
class Message:
"""Base class for received messages.
@@ -48,7 +50,7 @@ class Message:
MessageStateError = MessageStateError
- errors = None
+ errors: List[Any] = None
if not IS_PYPY: # pragma: no cover
__slots__ = (
@@ -58,10 +60,19 @@ class Message:
'body', '_decoded_cache', 'accept', '__dict__',
)
- def __init__(self, body=None, delivery_tag=None,
- content_type=None, content_encoding=None, delivery_info={},
- properties=None, headers=None, postencode=None,
- accept=None, channel=None, **kwargs):
+ def __init__(
+ self,
+ body: Any = None,
+ delivery_tag: str = None,
+ content_type: str = None,
+ content_encoding: str = None,
+ delivery_info: Mapping = None,
+ properties: Mapping = None,
+ headers: Mapping = None,
+ postencode: str = None,
+ accept: Set[str] = None,
+ channel: ChannelT = None,
+ **kwargs) -> None:
self.errors = [] if self.errors is None else self.errors
self.channel = channel
self.delivery_tag = delivery_tag
@@ -88,7 +99,7 @@ class Message:
self.errors.append(sys.exc_info())
self.body = body
- def _reraise_error(self, callback=None):
+ def _reraise_error(self, callback: Callable = None) -> None:
try:
raise self.errors[0][1].with_traceback(self.errors[0][2])
except Exception as exc:
@@ -96,7 +107,7 @@ class Message:
raise
callback(self, exc)
- def ack(self, multiple=False):
+ async def ack(self, multiple: bool = False) -> None:
"""Acknowledge this message as being processed.
This will remove the message from the queue.
@@ -120,24 +131,28 @@ class Message:
raise self.MessageStateError(
'Message already acknowledged with state: {0._state}'.format(
self))
- self.channel.basic_ack(self.delivery_tag, multiple=multiple)
+ await self.channel.basic_ack(self.delivery_tag, multiple=multiple)
self._state = 'ACK'
- def ack_log_error(self, logger, errors, multiple=False):
+ async def ack_log_error(
+ self, logger: logging.Logger, errors: Tuple[type, ...],
+ multiple: bool = False) -> None:
try:
- self.ack(multiple=multiple)
+ await self.ack(multiple=multiple)
except errors as exc:
logger.critical("Couldn't ack %r, reason:%r",
self.delivery_tag, exc, exc_info=True)
- def reject_log_error(self, logger, errors, requeue=False):
+ async def reject_log_error(
+ self, logger: logging.Logger, errors: Tuple[type, ...],
+ requeue: bool = False) -> None:
try:
- self.reject(requeue=requeue)
+ await self.reject(requeue=requeue)
except errors as exc:
logger.critical("Couldn't reject %r, reason: %r",
self.delivery_tag, exc, exc_info=True)
- def reject(self, requeue=False):
+ async def reject(self, requeue: bool = False) -> None:
"""Reject this message.
The message will be discarded by the server.
@@ -153,10 +168,10 @@ class Message:
raise self.MessageStateError(
'Message already acknowledged with state: {0._state}'.format(
self))
- self.channel.basic_reject(self.delivery_tag, requeue=requeue)
+ await self.channel.basic_reject(self.delivery_tag, requeue=requeue)
self._state = 'REJECTED'
- def requeue(self):
+ async def requeue(self) -> None:
"""Reject this message and put it back on the queue.
Warning:
@@ -174,10 +189,10 @@ class Message:
raise self.MessageStateError(
'Message already acknowledged with state: {0._state}'.format(
self))
- self.channel.basic_reject(self.delivery_tag, requeue=True)
+ await self.channel.basic_reject(self.delivery_tag, requeue=True)
self._state = 'REQUEUED'
- def decode(self):
+ def decode(self) -> Any:
"""Deserialize the message body.
Returning the original python structure sent by the publisher.
@@ -190,21 +205,21 @@ class Message:
self._decoded_cache = self._decode()
return self._decoded_cache
- def _decode(self):
+ def _decode(self) -> Any:
return loads(self.body, self.content_type,
self.content_encoding, accept=self.accept)
@property
- def acknowledged(self):
+ def acknowledged(self) -> bool:
"""Set to true if the message has been acknowledged."""
return self._state in ACK_STATES
@property
- def payload(self):
+ def payload(self) -> Any:
"""The decoded message body."""
return self._decoded_cache if self._decoded_cache else self.decode()
- def __repr__(self):
+ def __repr__(self) -> str:
return '<{0} object at {1:#x} with details {2!r}>'.format(
type(self).__name__, id(self), dictfilter(
state=self._state,
diff --git a/kombu/messaging.py b/kombu/messaging.py
index 3dd03d14..67ccd15d 100644
--- a/kombu/messaging.py
+++ b/kombu/messaging.py
@@ -1,20 +1,23 @@
"""Sending and receiving messages."""
+import numbers
from itertools import count
-
+from typing import Any, Callable, Mapping, Sequence, Tuple, Union
+from amqp import ChannelT
+from . import types
+from .abstract import Entity
from .common import maybe_declare
from .compression import compress
from .connection import maybe_channel, is_connection
from .entity import Exchange, Queue, maybe_delivery_mode
from .exceptions import ContentDisallowed
from .serialization import dumps, prepare_accept_content
-from .utils import abstract
+from .types import ChannelArgT, ClientT, MessageT, QueueT
from .utils.functional import ChannelPromise, maybe_list
__all__ = ['Exchange', 'Queue', 'Producer', 'Consumer']
-@abstract.Producer.register
-class Producer:
+class Producer(types.ProducerT):
"""Message Producer.
Arguments:
@@ -34,7 +37,7 @@ class Producer:
"""
#: Default exchange
- exchange = None
+ exchange = None # type: Exchange
#: Default routing key.
routing_key = ''
@@ -56,9 +59,16 @@ class Producer:
#: default_channel).
__connection__ = None
- def __init__(self, channel, exchange=None, routing_key=None,
- serializer=None, auto_declare=None, compression=None,
- on_return=None):
+ _channel: ChannelT
+
+ def __init__(self,
+ channel: ChannelArgT,
+ exchange: Exchange = None,
+ routing_key: str = None,
+ serializer: str = None,
+ auto_declare: bool = None,
+ compression: str = None,
+ on_return: Callable = None):
self._channel = channel
self.exchange = exchange
self.routing_key = routing_key or self.routing_key
@@ -71,20 +81,17 @@ class Producer:
if auto_declare is not None:
self.auto_declare = auto_declare
- if self._channel:
- self.revive(self._channel)
-
- def __repr__(self):
+ def __repr__(self) -> str:
return '<Producer: {0._channel}>'.format(self)
- def __reduce__(self):
+ def __reduce__(self) -> Tuple[Any, ...]:
return self.__class__, self.__reduce_args__()
- def __reduce_args__(self):
+ def __reduce_args__(self) -> Tuple[Any, ...]:
return (None, self.exchange, self.routing_key, self.serializer,
self.auto_declare, self.compression)
- def declare(self):
+ async def declare(self) -> None:
"""Declare the exchange.
Note:
@@ -92,16 +99,18 @@ class Producer:
the :attr:`auto_declare` flag is enabled.
"""
if self.exchange.name:
- self.exchange.declare()
+ await self.exchange.declare()
- def maybe_declare(self, entity, retry=False, **retry_policy):
+ async def maybe_declare(self, entity: Entity,
+ retry: bool = False, **retry_policy) -> None:
"""Declare exchange if not already declared during this session."""
if entity:
- return maybe_declare(entity, self.channel, retry, **retry_policy)
+ await maybe_declare(entity, self.channel, retry, **retry_policy)
- def _delivery_details(self, exchange, delivery_mode=None,
- maybe_delivery_mode=maybe_delivery_mode,
- Exchange=Exchange):
+ def _delivery_details(self, exchange: Union[Exchange, str],
+ delivery_mode: Union[int, str]=None,
+ maybe_delivery_mode: Callable = maybe_delivery_mode,
+ Exchange: Any = Exchange) -> Tuple[str, int]:
if isinstance(exchange, Exchange):
return exchange.name, maybe_delivery_mode(
delivery_mode or exchange.delivery_mode,
@@ -112,12 +121,23 @@ class Producer:
delivery_mode or self.exchange.delivery_mode,
)
- def publish(self, body, routing_key=None, delivery_mode=None,
- mandatory=False, immediate=False, priority=0,
- content_type=None, content_encoding=None, serializer=None,
- headers=None, compression=None, exchange=None, retry=False,
- retry_policy=None, declare=None, expiration=None,
- **properties):
+ async def publish(self, body: Any,
+ routing_key: str = None,
+ delivery_mode: Union[int, str] = None,
+ mandatory: bool = False,
+ immediate: bool = False,
+ priority: int = 0,
+ content_type: str = None,
+ content_encoding: str = None,
+ serializer: str = None,
+ headers: Mapping = None,
+ compression: str = None,
+ exchange: Union[Exchange, str] = None,
+ retry: bool = False,
+ retry_policy: Mapping = None,
+ declare: Sequence[Entity] = None,
+ expiration: numbers.Number = None,
+ **properties) -> None:
"""Publish message to the specified exchange.
Arguments:
@@ -172,54 +192,79 @@ class Producer:
declare.append(self.exchange)
if retry:
- _publish = self.connection.ensure(self, _publish, **retry_policy)
- return _publish(
+ _publish = await self.connection.ensure(
+ self, _publish, **retry_policy)
+ await _publish(
body, priority, content_type, content_encoding,
headers, properties, routing_key, mandatory, immediate,
exchange_name, declare,
)
- def _publish(self, body, priority, content_type, content_encoding,
- headers, properties, routing_key, mandatory,
- immediate, exchange, declare):
- channel = self.channel
+ async def _publish(self,
+ body: Any,
+ priority: int = None,
+ content_type: str = None,
+ content_encoding: str = None,
+ headers: Mapping = None,
+ properties: Mapping = None,
+ routing_key: str = None,
+ mandatory: bool = None,
+ immediate: bool = None,
+ exchange: str = None,
+ declare: Sequence = None) -> None:
+ channel = await self._resolve_channel()
message = channel.prepare_message(
body, priority, content_type,
content_encoding, headers, properties,
)
if declare:
maybe_declare = self.maybe_declare
- [maybe_declare(entity) for entity in declare]
+ for entity in declare:
+ await maybe_declare(entity)
# handle autogenerated queue names for reply_to
reply_to = properties.get('reply_to')
if isinstance(reply_to, Queue):
properties['reply_to'] = reply_to.name
- return channel.basic_publish(
+ await channel.basic_publish(
message,
exchange=exchange, routing_key=routing_key,
mandatory=mandatory, immediate=immediate,
)
- def _get_channel(self):
+ async def _resolve_channel(self):
+ channel = self._channel
+ if isinstance(channel, ChannelPromise):
+ channel = self._channel = await channel.resolve()
+ if self.exchange:
+ self.exchange.revive(channel)
+ if self.on_return:
+ channel.events['basic_return'].add(self.on_return)
+ else:
+ channel = maybe_channel(channel)
+ await channel.open()
+ return channel
+
+ def _get_channel(self) -> ChannelT:
channel = self._channel
if isinstance(channel, ChannelPromise):
channel = self._channel = channel()
- self.exchange.revive(channel)
+ if self.exchange:
+ self.exchange.revive(channel)
if self.on_return:
channel.events['basic_return'].add(self.on_return)
return channel
- def _set_channel(self, channel):
+ def _set_channel(self, channel: ChannelT) -> None:
self._channel = channel
channel = property(_get_channel, _set_channel)
- def revive(self, channel):
+ async def revive(self, channel: ChannelT) -> None:
"""Revive the producer after connection loss."""
if is_connection(channel):
connection = channel
self.__connection__ = connection
- channel = ChannelPromise(lambda: connection.default_channel)
+ channel = ChannelPromise(connection)
if isinstance(channel, ChannelPromise):
self._channel = channel
self.exchange = self.exchange(channel)
@@ -230,18 +275,29 @@ class Producer:
self._channel.events['basic_return'].add(self.on_return)
self.exchange = self.exchange(channel)
- def __enter__(self):
+ def __enter__(self) -> 'Producer':
return self
- def __exit__(self, *exc_info):
+ async def __aenter__(self) -> 'Producer':
+ await self.revive(self.channel)
+ return self
+
+ def __exit__(self, *exc_info) -> None:
self.release()
- def release(self):
- pass
+ async def __aexit__(self, *exc_info) -> None:
+ self.release()
+
+ def release(self) -> None:
+ ...
close = release
- def _prepare(self, body, serializer=None, content_type=None,
- content_encoding=None, compression=None, headers=None):
+ def _prepare(self, body: Any,
+ serializer: str = None,
+ content_type: str = None,
+ content_encoding: str = None,
+ compression: str = None,
+ headers: Mapping = None) -> Tuple[Any, str, str]:
# No content_type? Then we're serializing the data internally.
if not content_type:
@@ -267,15 +323,14 @@ class Producer:
return body, content_type, content_encoding
@property
- def connection(self):
+ def connection(self) -> ClientT:
try:
return self.__connection__ or self.channel.connection.client
except AttributeError:
pass
-@abstract.Consumer.register
-class Consumer:
+class Consumer(types.ConsumerT):
"""Message consumer.
Arguments:
@@ -358,13 +413,23 @@ class Consumer:
prefetch_count = None
#: Mapping of queues we consume from.
- _queues = None
+ _queues: Mapping[str, QueueT] = None
_tags = count(1) # global
-
- def __init__(self, channel, queues=None, no_ack=None, auto_declare=None,
- callbacks=None, on_decode_error=None, on_message=None,
- accept=None, prefetch_count=None, tag_prefix=None):
+ _active_tags: Mapping[str, str]
+
+ def __init__(
+ self,
+ channel: ChannelT,
+ queues: Sequence[QueueT] = None,
+ no_ack: bool = None,
+ auto_declare: bool = None,
+ callbacks: Sequence[Callable] = None,
+ on_decode_error: Callable = None,
+ on_message: Callable = None,
+ accept: Sequence[str] = None,
+ prefetch_count: int = None,
+ tag_prefix: str = None) -> None:
self.channel = channel
self.queues = maybe_list(queues or [])
self.no_ack = self.no_ack if no_ack is None else no_ack
@@ -380,21 +445,19 @@ class Consumer:
self.accept = prepare_accept_content(accept)
self.prefetch_count = prefetch_count
- if self.channel:
- self.revive(self.channel)
-
@property
- def queues(self):
+ def queues(self) -> Sequence[QueueT]:
return list(self._queues.values())
@queues.setter
- def queues(self, queues):
+ def queues(self, queues: Sequence[QueueT]) -> None:
self._queues = {q.name: q for q in queues}
- def revive(self, channel):
+ async def revive(self, channel: ChannelT) -> None:
"""Revive consumer after connection loss."""
self._active_tags.clear()
channel = self.channel = maybe_channel(channel)
+ await channel.open()
for qname, queue in self._queues.items():
# name may have changed after declare
self._queues.pop(qname, None)
@@ -402,22 +465,21 @@ class Consumer:
queue.revive(channel)
if self.auto_declare:
- self.declare()
+ await self.declare()
if self.prefetch_count is not None:
- self.qos(prefetch_count=self.prefetch_count)
+ await self.qos(prefetch_count=self.prefetch_count)
- def declare(self):
+ async def declare(self) -> None:
"""Declare queues, exchanges and bindings.
Note:
This is done automatically at instantiation
when :attr:`auto_declare` is set.
"""
- for queue in self._queues.values():
- queue.declare()
+ [await queue.declare() for queue in self._queues.values()]
- def register_callback(self, callback):
+ def register_callback(self, callback: Callable) -> None:
"""Register a new callback to be called when a message is received.
Note:
@@ -427,11 +489,26 @@ class Consumer:
"""
self.callbacks.append(callback)
- def __enter__(self):
+ def __enter__(self) -> 'Consumer':
+ self.revive(self.channel)
self.consume()
return self
- def __exit__(self, exc_type, exc_val, exc_tb):
+ async def __aenter__(self) -> 'Consumer':
+ await self.revive(self.channel)
+ await self.consume()
+ return self
+
+ async def __aexit__(self, *exc_info) -> None:
+ if self.channel:
+ conn_errors = self.channel.connection.client.connection_errors
+ if not isinstance(exc_info[1], conn_errors):
+ try:
+ await self.cancel()
+ except Exception:
+ pass
+
+ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
if self.channel:
conn_errors = self.channel.connection.client.connection_errors
if not isinstance(exc_val, conn_errors):
@@ -440,7 +517,7 @@ class Consumer:
except Exception:
pass
- def add_queue(self, queue):
+ async def add_queue(self, queue: QueueT) -> QueueT:
"""Add a queue to the list of queues to consume from.
Note:
@@ -449,11 +526,11 @@ class Consumer:
"""
queue = queue(self.channel)
if self.auto_declare:
- queue.declare()
+ await queue.declare()
self._queues[queue.name] = queue
return queue
- def consume(self, no_ack=None):
+ async def consume(self, no_ack: bool = None) -> None:
"""Start consuming messages.
Can be called multiple times, but note that while it
@@ -470,10 +547,10 @@ class Consumer:
H, T = queues[:-1], queues[-1]
for queue in H:
- self._basic_consume(queue, no_ack=no_ack, nowait=True)
- self._basic_consume(T, no_ack=no_ack, nowait=False)
+ await self._basic_consume(queue, no_ack=no_ack, nowait=True)
+ await self._basic_consume(T, no_ack=no_ack, nowait=False)
- def cancel(self):
+ async def cancel(self) -> None:
"""End all active queue consumers.
Note:
@@ -482,38 +559,36 @@ class Consumer:
"""
cancel = self.channel.basic_cancel
for tag in self._active_tags.values():
- cancel(tag)
+ await cancel(tag)
self._active_tags.clear()
close = cancel
- def cancel_by_queue(self, queue):
+ async def cancel_by_queue(self, queue: Union[QueueT, str]) -> None:
"""Cancel consumer by queue name."""
- qname = queue.name if isinstance(queue, Queue) else queue
+ qname = queue.name if isinstance(queue, QueueT) else queue
try:
tag = self._active_tags.pop(qname)
except KeyError:
pass
else:
- self.channel.basic_cancel(tag)
+ await self.channel.basic_cancel(tag)
finally:
self._queues.pop(qname, None)
- def consuming_from(self, queue):
+ def consuming_from(self, queue: Union[QueueT, str]) -> bool:
"""Return :const:`True` if currently consuming from queue'."""
- name = queue
- if isinstance(queue, Queue):
- name = queue.name
+ name = queue.name if isinstance(queue, QueueT) else queue
return name in self._active_tags
- def purge(self):
+ async def purge(self) -> int:
"""Purge messages from all queues.
Warning:
This will *delete all ready messages*, there is no undo operation.
"""
- return sum(queue.purge() for queue in self._queues.values())
+ return sum(await queue.purge() for queue in self._queues.values())
- def flow(self, active):
+ async def flow(self, active: bool) -> None:
"""Enable/disable flow from peer.
This is a simple flow-control mechanism that a peer can use
@@ -524,9 +599,12 @@ class Consumer:
will finish sending the current content (if any), and then wait
until flow is reactivated.
"""
- self.channel.flow(active)
+ await self.channel.flow(active)
- def qos(self, prefetch_size=0, prefetch_count=0, apply_global=False):
+ async def qos(self,
+ prefetch_size: int = 0,
+ prefetch_count: int = 0,
+ apply_global: bool = False) -> None:
"""Specify quality of service.
The client can request that messages should be sent in
@@ -550,11 +628,10 @@ class Consumer:
apply_global (bool): Apply new settings globally on all channels.
"""
- return self.channel.basic_qos(prefetch_size,
- prefetch_count,
- apply_global)
+ await self.channel.basic_qos(
+ prefetch_size, prefetch_count, apply_global)
- def recover(self, requeue=False):
+ async def recover(self, requeue: bool = False) -> None:
"""Redeliver unacknowledged messages.
Asks the broker to redeliver all unacknowledged messages
@@ -566,9 +643,9 @@ class Consumer:
server will attempt to requeue the message, potentially then
delivering it to an alternative subscriber.
"""
- return self.channel.basic_recover(requeue=requeue)
+ await self.channel.basic_recover(requeue=requeue)
- def receive(self, body, message):
+ async def receive(self, body: Any, message: MessageT) -> None:
"""Method called when a message is received.
This dispatches to the registered :attr:`callbacks`.
@@ -584,24 +661,27 @@ class Consumer:
callbacks = self.callbacks
if not callbacks:
raise NotImplementedError('Consumer does not have any callbacks')
- [callback(body, message) for callback in callbacks]
+ for callback in callbacks:
+ await callback(body, message)
- def _basic_consume(self, queue, consumer_tag=None,
- no_ack=no_ack, nowait=True):
+ async def _basic_consume(self, queue: QueueT,
+ consumer_tag: str = None,
+ no_ack: bool = no_ack,
+ nowait: bool = True) -> str:
tag = self._active_tags.get(queue.name)
if tag is None:
tag = self._add_tag(queue, consumer_tag)
- queue.consume(tag, self._receive_callback,
- no_ack=no_ack, nowait=nowait)
+ await queue.consume(tag, self._receive_callback,
+ no_ack=no_ack, nowait=nowait)
return tag
- def _add_tag(self, queue, consumer_tag=None):
+ def _add_tag(self, queue: QueueT, consumer_tag: str = None) -> str:
tag = consumer_tag or '{0}{1}'.format(
self.tag_prefix, next(self._tags))
self._active_tags[queue.name] = tag
return tag
- def _receive_callback(self, message):
+ async def _receive_callback(self, message: MessageT) -> None:
accept = self.accept
on_m, channel, decoded = self.on_message, self.channel, None
try:
@@ -611,20 +691,20 @@ class Consumer:
if accept is not None:
message.accept = accept
if message.errors:
- return message._reraise_error(self.on_decode_error)
+ message._reraise_error(self.on_decode_error)
decoded = None if on_m else message.decode()
except Exception as exc:
if not self.on_decode_error:
raise
self.on_decode_error(message, exc)
else:
- return on_m(message) if on_m else self.receive(decoded, message)
+ await on_m(message) if on_m else self.receive(decoded, message)
- def __repr__(self):
+ def __repr__(self) -> str:
return '<{name}: {0.queues}>'.format(self, name=type(self).__name__)
@property
- def connection(self):
+ def connection(self) -> ClientT:
try:
return self.channel.connection.client
except AttributeError:
diff --git a/kombu/pidbox.py b/kombu/pidbox.py
index 0c6196fc..4775a59b 100644
--- a/kombu/pidbox.py
+++ b/kombu/pidbox.py
@@ -8,12 +8,19 @@ from copy import copy
from itertools import count
from threading import local
from time import time
+from typing import Any, Callable, Mapping, Set, Sequence, cast
-from . import Exchange, Queue, Consumer, Producer
-from .clocks import LamportClock
+from amqp.types import ChannelT
+
+from .clocks import Clock, LamportClock
from .common import maybe_declare, oid_from
+from .entity import Exchange, Queue
from .exceptions import InconsistencyError
from .log import get_logger
+from .messaging import Consumer, Producer
+from .typing import (
+ ClientT, ConsumerT, ExchangeT, ProducerT, MessageT, ResourceT,
+)
from .utils.functional import maybe_evaluate, reprcall
from .utils.objects import cached_property
from .utils.uuid import uuid
@@ -35,22 +42,26 @@ class Node:
"""Mailbox node."""
#: hostname of the node.
- hostname = None
+ hostname: str = None
#: the :class:`Mailbox` this is a node for.
- mailbox = None
+ mailbox: 'Mailbox' = None
#: map of method name/handlers.
- handlers = None
+ handlers: Mapping[str, Callable] = None
#: current context (passed on to handlers)
- state = None
+ state: Any = None
#: current channel.
- channel = None
-
- def __init__(self, hostname, state=None, channel=None,
- handlers=None, mailbox=None):
+ channel: ChannelT = None
+
+ def __init__(self,
+ hostname: str,
+ state: Any = None,
+ channel: ChannelT = None,
+ handlers: Mapping[str, Callable] = None,
+ mailbox: 'Mailbox' = None) -> None:
self.channel = channel
self.mailbox = mailbox
self.hostname = hostname
@@ -60,10 +71,14 @@ class Node:
handlers = {}
self.handlers = handlers
- def Consumer(self, channel=None, no_ack=True, accept=None, **options):
+ def Consumer(self,
+ channel: ChannelT = None,
+ no_ack: bool = True,
+ accept: Set[str] = None,
+ **options) -> ConsumerT:
queue = self.mailbox.get_queue(self.hostname)
- def verify_exclusive(name, messages, consumers):
+ def verify_exclusive(name: str, messages: int, consumers: int) -> None:
if consumers:
warnings.warn(W_PIDBOX_IN_USE.format(node=self))
queue.on_declared = verify_exclusive
@@ -74,22 +89,24 @@ class Node:
**options
)
- def handler(self, fun):
+ def handler(self, fun: Callable) -> Callable:
self.handlers[fun.__name__] = fun
return fun
- def on_decode_error(self, message, exc):
+ def on_decode_error(self, message: str, exc: Exception) -> None:
error('Cannot decode message: %r', exc, exc_info=1)
- def listen(self, channel=None, callback=None):
+ def listen(self,
+ channel: ChannelT = None,
+ callback: Callable = None) -> ConsumerT:
consumer = self.Consumer(channel=channel,
callbacks=[callback or self.handle_message],
on_decode_error=self.on_decode_error)
consumer.consume()
return consumer
- def dispatch(self, method, arguments=None,
- reply_to=None, ticket=None, **kwargs):
+ def dispatch(self, method: str, arguments: Mapping = None,
+ reply_to: str = None, ticket: str = None, **kwargs) -> Any:
arguments = arguments or {}
debug('pidbox received method %s [reply_to:%s ticket:%s]',
reprcall(method, (), kwargs=arguments), reply_to, ticket)
@@ -109,16 +126,17 @@ class Node:
ticket=ticket)
return reply
- def handle(self, method, arguments={}):
+ def handle(self, method: str, arguments: Mapping = {}) -> Any:
return self.handlers[method](self.state, **arguments)
- def handle_call(self, method, arguments):
+ def handle_call(self, method: str, arguments: Mapping) -> Any:
return self.handle(method, arguments)
- def handle_cast(self, method, arguments):
+ def handle_cast(self, method: str, arguments: Mapping) -> Any:
return self.handle(method, arguments)
- def handle_message(self, body, message=None):
+ def handle_message(self, body: Any, message: MessageT = None) -> None:
+ body = cast(Mapping, body)
destination = body.get('destination')
if message:
self.adjust_clock(message.headers.get('clock') or 0)
@@ -126,7 +144,8 @@ class Node:
return self.dispatch(**body)
dispatch_from_message = handle_message
- def reply(self, data, exchange, routing_key, ticket, **kwargs):
+ def reply(self, data: Any, exchange: str, routing_key: str, ticket: str,
+ **kwargs) -> None:
self.mailbox._publish_reply(data, exchange, routing_key, ticket,
channel=self.channel,
serializer=self.mailbox.serializer)
@@ -135,36 +154,42 @@ class Node:
class Mailbox:
"""Process Mailbox."""
- node_cls = Node
- exchange_fmt = '%s.pidbox'
- reply_exchange_fmt = 'reply.%s.pidbox'
+ node_cls: type = Node
+ exchange_fmt: str = '%s.pidbox'
+ reply_exchange_fmt: str = 'reply.%s.pidbox'
#: Name of application.
- namespace = None
+ namespace: str = None
#: Connection (if bound).
- connection = None
+ connection: ClientT = None
#: Exchange type (usually direct, or fanout for broadcast).
- type = 'direct'
+ type: str = 'direct'
#: mailbox exchange (init by constructor).
- exchange = None
+ exchange: ExchangeT = None
#: exchange to send replies to.
- reply_exchange = None
+ reply_exchange: ExchangeT = None
#: Only accepts json messages by default.
- accept = ['json']
+ accept: Set[str] = {'json'}
#: Message serializer
- serializer = None
-
- def __init__(self, namespace,
- type='direct', connection=None, clock=None,
- accept=None, serializer=None, producer_pool=None,
- queue_ttl=None, queue_expires=None,
- reply_queue_ttl=None, reply_queue_expires=10.0):
+ serializer: str = None
+
+ def __init__(self, namespace: str,
+ type: str = 'direct',
+ connection: ClientT = None,
+ clock: Clock = None,
+ accept: Set[str] = None,
+ serializer: str = None,
+ producer_pool: ResourceT = None,
+ queue_ttl: float = None,
+ queue_expires: float = None,
+ reply_queue_ttl: float = None,
+ reply_queue_expires: float = 10.0) -> None:
self.namespace = namespace
self.connection = connection
self.type = type
@@ -181,36 +206,48 @@ class Mailbox:
self.reply_queue_expires = reply_queue_expires
self._producer_pool = producer_pool
- def __call__(self, connection):
+ def __call__(self, connection: ClientT) -> 'Mailbox':
bound = copy(self)
bound.connection = connection
return bound
- def Node(self, hostname=None, state=None, channel=None, handlers=None):
+ def Node(self,
+ hostname: str = None,
+ state: Any = None,
+ channel: ChannelT = None,
+ handlers: Mapping[str, Callable] = None) -> Node:
hostname = hostname or socket.gethostname()
return self.node_cls(hostname, state, channel, handlers, mailbox=self)
- def call(self, destination, command, kwargs={},
- timeout=None, callback=None, channel=None):
+ def call(self, destination: str, command: str,
+ kwargs: Mapping[str, Any] = {},
+ timeout: float = None,
+ callback: Callable = None,
+ channel: ChannelT = None) -> Sequence[Mapping]:
return self._broadcast(command, kwargs, destination,
reply=True, timeout=timeout,
callback=callback,
channel=channel)
- def cast(self, destination, command, kwargs={}):
- return self._broadcast(command, kwargs, destination, reply=False)
+ def cast(self, destination: str, command: str,
+ kwargs: Mapping[str, Any] = {}) -> None:
+ self._broadcast(command, kwargs, destination, reply=False)
- def abcast(self, command, kwargs={}):
- return self._broadcast(command, kwargs, reply=False)
+ def abcast(self, command: str, kwargs: Mapping[str, Any] = {}) -> None:
+ self._broadcast(command, kwargs, reply=False)
- def multi_call(self, command, kwargs={}, timeout=1,
- limit=None, callback=None, channel=None):
+ def multi_call(self, command: str,
+ kwargs: Mapping[str, Any] = {},
+ timeout: str = 1,
+ limit: int = None,
+ callback: Callable = None,
+ channel: ChannelT = None) -> Sequence[Mapping]:
return self._broadcast(command, kwargs, reply=True,
timeout=timeout, limit=limit,
callback=callback,
channel=channel)
- def get_reply_queue(self):
+ def get_reply_queue(self) -> Queue:
oid = self.oid
return Queue(
'%s.%s' % (oid, self.reply_exchange.name),
@@ -223,10 +260,10 @@ class Mailbox:
)
@cached_property
- def reply_queue(self):
+ def reply_queue(self) -> Queue:
return self.get_reply_queue()
- def get_queue(self, hostname):
+ def get_queue(self, hostname: str) -> Queue:
return Queue(
'%s.%s.pidbox' % (hostname, self.namespace),
exchange=self.exchange,
@@ -237,7 +274,9 @@ class Mailbox:
)
@contextmanager
- def producer_or_acquire(self, producer=None, channel=None):
+ def producer_or_acquire(self,
+ producer: ProducerT = None,
+ channel: ChannelT = None) -> ProducerT:
if producer:
yield producer
elif self.producer_pool:
@@ -246,8 +285,11 @@ class Mailbox:
else:
yield Producer(channel, auto_declare=False)
- def _publish_reply(self, reply, exchange, routing_key, ticket,
- channel=None, producer=None, **opts):
+ def _publish_reply(self, reply: Any,
+ exchange: str, routing_key: str, ticket: str,
+ channel: ChannelT = None,
+ producer: ProducerT = None,
+ **opts) -> None:
chan = channel or self.connection.default_channel
exchange = Exchange(exchange, exchange_type='direct',
delivery_mode='transient',
@@ -265,9 +307,13 @@ class Mailbox:
# queue probably deleted and no one is expecting a reply.
pass
- def _publish(self, type, arguments, destination=None,
- reply_ticket=None, channel=None, timeout=None,
- serializer=None, producer=None):
+ def _publish(self, type: str, arguments: Mapping[str, Any],
+ destination: str = None,
+ reply_ticket: str = None,
+ channel: ChannelT = None,
+ timeout: float = None,
+ serializer: str = None,
+ producer: ProducerT = None) -> None:
message = {'method': type,
'arguments': arguments,
'destination': destination}
@@ -287,9 +333,15 @@ class Mailbox:
serializer=serializer,
)
- def _broadcast(self, command, arguments=None, destination=None,
- reply=False, timeout=1, limit=None,
- callback=None, channel=None, serializer=None):
+ def _broadcast(self, command: str,
+ arguments: Mapping[str, Any] = None,
+ destination: str = None,
+ reply: bool = False,
+ timeout: float = 1.0,
+ limit: int = None,
+ callback: Callable = None,
+ channel: ChannelT = None,
+ serializer: str = None) -> Sequence[Mapping]:
if destination is not None and \
not isinstance(destination, (list, tuple)):
raise ValueError(
@@ -317,9 +369,12 @@ class Mailbox:
callback=callback,
channel=chan)
- def _collect(self, ticket,
- limit=None, timeout=1, callback=None,
- channel=None, accept=None):
+ def _collect(self, ticket: str,
+ limit: str = None,
+ timeout: float = 1.0,
+ callback: Callable = None,
+ channel: ChannelT = None,
+ accept: Set[str] = None) -> Sequence[Mapping]:
if accept is None:
accept = self.accept
chan = channel or self.connection.default_channel
@@ -334,7 +389,7 @@ class Mailbox:
except KeyError:
pass
- def on_message(body, message):
+ def on_message(body: Any, message: MessageT) -> None:
# ticket header added in kombu 2.5
header = message.headers.get
adjust_clock(header('clock') or 0)
@@ -361,20 +416,20 @@ class Mailbox:
finally:
chan.after_reply_message_received(queue.name)
- def _get_exchange(self, namespace, type):
+ def _get_exchange(self, namespace: str, type: str) -> Exchange:
return Exchange(self.exchange_fmt % namespace,
type=type,
durable=False,
delivery_mode='transient')
- def _get_reply_exchange(self, namespace):
+ def _get_reply_exchange(self, namespace: str) -> Exchange:
return Exchange(self.reply_exchange_fmt % namespace,
type='direct',
durable=False,
delivery_mode='transient')
@cached_property
- def oid(self):
+ def oid(self) -> str:
try:
return self._tls.OID
except AttributeError:
@@ -382,5 +437,5 @@ class Mailbox:
return oid
@cached_property
- def producer_pool(self):
+ def producer_pool(self) -> ResourceT:
return maybe_evaluate(self._producer_pool)
diff --git a/kombu/pools.py b/kombu/pools.py
index f9168886..88770212 100644
--- a/kombu/pools.py
+++ b/kombu/pools.py
@@ -1,23 +1,25 @@
"""Public resource pools."""
import os
-
from itertools import chain
-
+from typing import Any
from .connection import Resource
from .messaging import Producer
+from .types import ClientT, ProducerT, ResourceT
from .utils.collections import EqualityDict
from .utils.compat import register_after_fork
from .utils.functional import lazy
-__all__ = ['ProducerPool', 'PoolGroup', 'register_group',
- 'connections', 'producers', 'get_limit', 'set_limit', 'reset']
+__all__ = [
+ 'ProducerPool', 'PoolGroup', 'register_group',
+ 'connections', 'producers', 'get_limit', 'set_limit', 'reset',
+]
_limit = [10]
_groups = []
-use_global_limit = object()
+use_global_limit = 44444444444444444444444444
disable_limit_protection = os.environ.get('KOMBU_DISABLE_LIMIT_PROTECTION')
-def _after_fork_cleanup_group(group):
+def _after_fork_cleanup_group(group: 'PoolGroup') -> None:
group.clear()
@@ -27,15 +29,19 @@ class ProducerPool(Resource):
Producer = Producer
close_after_fork = True
- def __init__(self, connections, *args, Producer=None, **kwargs):
+ def __init__(self,
+ connections: ResourceT,
+ *args,
+ Producer: type = None,
+ **kwargs) -> None:
self.connections = connections
self.Producer = Producer or self.Producer
super().__init__(*args, **kwargs)
- def _acquire_connection(self):
+ def _acquire_connection(self) -> ClientT:
return self.connections.acquire(block=True)
- def create_producer(self):
+ def create_producer(self) -> ProducerT:
conn = self._acquire_connection()
try:
return self.Producer(conn)
@@ -43,18 +49,18 @@ class ProducerPool(Resource):
conn.release()
raise
- def new(self):
+ def new(self) -> lazy:
return lazy(self.create_producer)
- def setup(self):
+ def setup(self) -> None:
if self.limit:
for _ in range(self.limit):
self._resource.put_nowait(self.new())
- def close_resource(self, resource):
+ def close_resource(self, resource: ProducerT) -> None:
...
- def prepare(self, p):
+ def prepare(self, p: Any) -> None:
if callable(p):
p = p()
if p._channel is None:
@@ -66,7 +72,7 @@ class ProducerPool(Resource):
raise
return p
- def release(self, resource):
+ def release(self, resource: ProducerT) -> None:
if resource.__connection__:
resource.__connection__.release()
resource.channel = None
@@ -76,24 +82,25 @@ class ProducerPool(Resource):
class PoolGroup(EqualityDict):
"""Collection of resource pools."""
- def __init__(self, limit=None, close_after_fork=True):
+ def __init__(self, limit: int = None,
+ close_after_fork: bool = True) -> None:
self.limit = limit
self.close_after_fork = close_after_fork
if self.close_after_fork and register_after_fork is not None:
register_after_fork(self, _after_fork_cleanup_group)
- def create(self, resource, limit):
+ def create(self, connection: ClientT, limit: int) -> Any:
raise NotImplementedError('PoolGroups must define ``create``')
- def __missing__(self, resource):
+ def __missing__(self, resource: Any) -> Any:
limit = self.limit
- if limit is use_global_limit:
+ if limit == use_global_limit:
limit = get_limit()
k = self[resource] = self.create(resource, limit)
return k
-def register_group(group):
+def register_group(group: PoolGroup) -> PoolGroup:
"""Register group (can be used as decorator)."""
_groups.append(group)
return group
@@ -102,7 +109,7 @@ def register_group(group):
class Connections(PoolGroup):
"""Collection of connection pools."""
- def create(self, connection, limit):
+ def create(self, connection: ClientT, limit: int) -> Any:
return connection.Pool(limit=limit)
connections = register_group(Connections(limit=use_global_limit)) # noqa: E305
@@ -110,21 +117,24 @@ connections = register_group(Connections(limit=use_global_limit)) # noqa: E305
class Producers(PoolGroup):
"""Collection of producer pools."""
- def create(self, connection, limit):
+ def create(self, connection: ClientT, limit: int) -> Any:
return ProducerPool(connections[connection], limit=limit)
producers = register_group(Producers(limit=use_global_limit)) # noqa: E305
-def _all_pools():
+def _all_pools() -> chain[ResourceT]:
return chain(*[(g.values() if g else iter([])) for g in _groups])
-def get_limit():
+def get_limit() -> int:
"""Get current connection pool limit."""
return _limit[0]
-def set_limit(limit, force=False, reset_after=False, ignore_errors=False):
+def set_limit(limit: int,
+ force: bool = False,
+ reset_after: bool = False,
+ ignore_errors: bool = False) -> int:
"""Set new connection pool limit."""
limit = limit or 0
glimit = _limit[0] or 0
@@ -135,7 +145,7 @@ def set_limit(limit, force=False, reset_after=False, ignore_errors=False):
return limit
-def reset(*args, **kwargs):
+def reset(*args, **kwargs) -> None:
"""Reset all pools by closing open resources."""
for pool in _all_pools():
try:
diff --git a/kombu/resource.py b/kombu/resource.py
index d98f49e7..1829b466 100644
--- a/kombu/resource.py
+++ b/kombu/resource.py
@@ -1,16 +1,15 @@
"""Generic resource pool implementation."""
import os
-
from collections import deque
from queue import Empty, LifoQueue as _LifoQueue
-
+from typing import Any
from . import exceptions
-from .utils import abstract
+from .types import ResourceT
from .utils.compat import register_after_fork
from .utils.functional import lazy
-def _after_fork_cleanup_resource(resource):
+def _after_fork_cleanup_resource(resource: ResourceT) -> None:
try:
resource.force_close_all()
except Exception:
@@ -20,19 +19,21 @@ def _after_fork_cleanup_resource(resource):
class LifoQueue(_LifoQueue):
"""Last in first out version of Queue."""
- def _init(self, maxsize):
+ def _init(self, maxsize: int) -> None:
self.queue = deque()
-@abstract.Resource.register
-class Resource:
+class Resource(ResourceT):
"""Pool of resources."""
LimitExceeded = exceptions.LimitExceeded
close_after_fork = False
- def __init__(self, limit=None, preload=None, close_after_fork=None):
+ def __init__(self,
+ limit: int = None,
+ preload: int = None,
+ close_after_fork: bool = None):
self._limit = limit
self.preload = preload or 0
self._closed = False
@@ -47,10 +48,10 @@ class Resource:
register_after_fork(self, _after_fork_cleanup_resource)
self.setup()
- def setup(self):
+ def setup(self) -> None:
raise NotImplementedError('subclass responsibility')
- def _add_when_empty(self):
+ def _add_when_empty(self) -> None:
if self.limit and len(self._dirty) >= self.limit:
raise self.LimitExceeded(self.limit)
# All taken, put new on the queue and
@@ -58,7 +59,7 @@ class Resource:
# will get the resource.
self._resource.put_nowait(self.new())
- def acquire(self, block=False, timeout=None):
+ def acquire(self, block: bool = False, timeout: int = None):
"""Acquire resource.
Arguments:
@@ -94,7 +95,7 @@ class Resource:
else:
R = self.prepare(self.new())
- def release():
+ def release() -> None:
"""Release resource so it can be used by another thread.
Warnings:
@@ -107,16 +108,16 @@ class Resource:
return R
- def prepare(self, resource):
+ def prepare(self, resource: Any) -> Any:
return resource
- def close_resource(self, resource):
+ def close_resource(self, resource: Any) -> None:
resource.close()
- def release_resource(self, resource):
+ def release_resource(self, resource: Any) -> None:
...
- def replace(self, resource):
+ def replace(self, resource: Any) -> None:
"""Replace existing resource with a new instance.
This can be used in case of defective resources.
@@ -125,7 +126,7 @@ class Resource:
self._dirty.discard(resource)
self.close_resource(resource)
- def release(self, resource):
+ def release(self, resource: Any) -> None:
if self.limit:
self._dirty.discard(resource)
self._resource.put_nowait(resource)
@@ -133,10 +134,10 @@ class Resource:
else:
self.close_resource(resource)
- def collect_resource(self, resource):
+ def collect_resource(self, resource: Any) -> None:
...
- def force_close_all(self):
+ def force_close_all(self) -> None:
"""Close and remove all resources in the pool (also those in use).
Used to close resources from parent processes after fork
@@ -169,7 +170,11 @@ class Resource:
except AttributeError:
pass # Issue #78
- def resize(self, limit, force=False, ignore_errors=False, reset=False):
+ def resize(
+ self, limit: int,
+ force: bool = False,
+ ignore_errors: bool = False,
+ reset: bool = False) -> None:
prev_limit = self._limit
if (self._dirty and limit < self._limit) and not ignore_errors:
if not force:
@@ -187,7 +192,7 @@ class Resource:
if limit < prev_limit:
self._shrink_down()
- def _shrink_down(self):
+ def _shrink_down(self) -> None:
resource = self._resource
# Items to the left are last recently used, so we remove those first.
with resource.mutex:
@@ -195,11 +200,11 @@ class Resource:
self.collect_resource(resource.queue.popleft())
@property
- def limit(self):
+ def limit(self) -> int:
return self._limit
@limit.setter
- def limit(self, limit):
+ def limit(self, limit: int) -> None:
self.resize(limit)
if os.environ.get('KOMBU_DEBUG_POOL'): # pragma: no cover
@@ -208,7 +213,7 @@ class Resource:
_next_resource_id = 0
- def acquire(self, *args, **kwargs): # noqa
+ def acquire(self, *args, **kwargs) -> Any: # noqa
import traceback
id = self._next_resource_id = self._next_resource_id + 1
print('+{0} ACQUIRE {1}'.format(id, self.__class__.__name__))
@@ -220,7 +225,7 @@ class Resource:
r.acquired_by.append(traceback.format_stack())
return r
- def release(self, resource): # noqa
+ def release(self, resource: Any) -> None: # noqa
id = resource._resource_id
print('+{0} RELEASE {1}'.format(id, self.__class__.__name__))
r = self._orig_release(resource)
diff --git a/kombu/serialization.py b/kombu/serialization.py
index 7f885d70..a8a9cc48 100644
--- a/kombu/serialization.py
+++ b/kombu/serialization.py
@@ -9,9 +9,9 @@ try:
except ImportError: # pragma: no cover
cpickle = None # noqa
-from collections import namedtuple
from contextlib import contextmanager
from io import BytesIO
+from typing import Callable, NamedTuple
from .exceptions import (
ContentDisallowed, DecodeError, EncodeError, SerializerNotInstalled
@@ -37,7 +37,13 @@ pickle_load = pickle.load
#: There's a new protocol (3) but this is only supported by Python 3.
pickle_protocol = int(os.environ.get('PICKLE_PROTOCOL', 2))
-codec = namedtuple('codec', ('content_type', 'content_encoding', 'encoder'))
+
+class codec(NamedTuple):
+ """Codec registration triple."""
+
+ content_type: str
+ content_encoding: str
+ encoder: Callable
@contextmanager
diff --git a/kombu/simple.py b/kombu/simple.py
index 4facedf7..2436a28e 100644
--- a/kombu/simple.py
+++ b/kombu/simple.py
@@ -1,28 +1,32 @@
"""Simple messaging interface."""
import socket
-
from collections import deque
from time import monotonic
from queue import Empty
-
+from typing import Any, Mapping
from . import entity
from . import messaging
from .connection import maybe_channel
+from .types import ChannelArgT, ConsumerT, ProducerT, MessageT, SimpleQueueT
__all__ = ['SimpleQueue', 'SimpleBuffer']
-class SimpleBase:
+class SimpleBase(SimpleQueueT):
Empty = Empty
- _consuming = False
+ _consuming: bool = False
- def __enter__(self):
+ def __enter__(self) -> SimpleQueueT:
return self
- def __exit__(self, *exc_info):
+ def __exit__(self, *exc_info) -> None:
self.close()
- def __init__(self, channel, producer, consumer, no_ack=False):
+ def __init__(self,
+ channel: ChannelArgT,
+ producer: ProducerT,
+ consumer: ConsumerT,
+ no_ack: bool = False) -> None:
self.channel = maybe_channel(channel)
self.producer = producer
self.consumer = consumer
@@ -31,7 +35,7 @@ class SimpleBase:
self.buffer = deque()
self.consumer.register_callback(self._receive)
- def get(self, block=True, timeout=None):
+ def get(self, block: bool = True, timeout: float = None) -> MessageT:
if not block:
return self.get_nowait()
self._consume()
@@ -49,14 +53,18 @@ class SimpleBase:
elapsed += monotonic() - time_start
remaining = timeout and timeout - elapsed or None
- def get_nowait(self):
+ def get_nowait(self) -> MessageT:
m = self.queue.get(no_ack=self.no_ack)
if not m:
raise self.Empty()
return m
- def put(self, message, serializer=None, headers=None, compression=None,
- routing_key=None, **kwargs):
+ def put(self, message: Any,
+ serializer: str = None,
+ headers: Mapping = None,
+ compression: str = None,
+ routing_key: str = None,
+ **kwargs) -> None:
self.producer.publish(message,
serializer=serializer,
routing_key=routing_key,
@@ -64,61 +72,72 @@ class SimpleBase:
compression=compression,
**kwargs)
- def clear(self):
+ def clear(self) -> int:
return self.consumer.purge()
- def qsize(self):
+ def qsize(self) -> int:
_, size, _ = self.queue.queue_declare(passive=True)
return size
- def close(self):
+ def close(self) -> None:
self.consumer.cancel()
- def _receive(self, message_data, message):
+ def _receive(self, message_data: Any, message: MessageT) -> None:
self.buffer.append(message)
- def _consume(self):
+ def _consume(self) -> None:
if not self._consuming:
self.consumer.consume(no_ack=self.no_ack)
self._consuming = True
- def __len__(self):
+ def __len__(self) -> int:
"""`len(self) -> self.qsize()`."""
return self.qsize()
- def __bool__(self):
+ def __bool__(self) -> bool:
return True
class SimpleQueue(SimpleBase):
"""Simple API for persistent queues."""
- no_ack = False
- queue_opts = {}
- exchange_opts = {'type': 'direct'}
-
- def __init__(self, channel, name, no_ack=None, queue_opts=None,
- exchange_opts=None, serializer=None,
- compression=None, **kwargs):
+ no_ack: bool = False
+ queue_opts: Mapping = {}
+ exchange_opts: Mapping = {'type': 'direct'}
+
+ def __init__(
+ self,
+ channel: ChannelArgT,
+ name: str,
+ no_ack: bool = None,
+ queue_opts: Mapping = None,
+ exchange_opts: Mapping = None,
+ serializer: str = None,
+ compression: str = None,
+ **kwargs) -> None:
queue = name
- queue_opts = dict(self.queue_opts, **queue_opts or {})
- exchange_opts = dict(self.exchange_opts, **exchange_opts or {})
+ all_queue_opts = dict(self.queue_opts)
+ if queue_opts:
+ all_queue_opts.update(queue_opts)
+ all_exchange_opts = dict(self.exchange_opts)
+ if exchange_opts:
+ all_exchange_opts.update(exchange_opts)
if no_ack is None:
no_ack = self.no_ack
if not isinstance(queue, entity.Queue):
- exchange = entity.Exchange(name, **exchange_opts)
- queue = entity.Queue(name, exchange, name, **queue_opts)
+ exchange = entity.Exchange(name, **all_exchange_opts)
+ queue = entity.Queue(name, exchange, name, **all_queue_opts)
routing_key = name
else:
name = queue.name
exchange = queue.exchange
routing_key = queue.routing_key
- consumer = messaging.Consumer(channel, queue)
+ consumer = messaging.Consumer(channel, [queue])
producer = messaging.Producer(channel, exchange,
serializer=serializer,
routing_key=routing_key,
compression=compression)
- super().__init__(channel, producer, consumer, no_ack, **kwargs)
+ super().__init__(channel, producer, consumer, no_ack)
class SimpleBuffer(SimpleQueue):
diff --git a/kombu/transport/base.py b/kombu/transport/base.py
index 0b1f3f35..059de82f 100644
--- a/kombu/transport/base.py
+++ b/kombu/transport/base.py
@@ -1,20 +1,19 @@
"""Base transport interface."""
-import amqp.abstract
+import amqp.types
import errno
import socket
-
+from typing import Any, Callable, ChannelT, Dict, Mapping, Sequence, Tuple
from amqp.exceptions import RecoverableConnectionError
-
from kombu.exceptions import ChannelError, ConnectionError
-from kombu.five import items
from kombu.message import Message
+from kombu.types import ClientT, ConnectionT, ConsumerT, ProducerT, TransportT
from kombu.utils.functional import dictfilter
from kombu.utils.objects import cached_property
from kombu.utils.time import maybe_s_to_ms
__all__ = ['Message', 'StdChannel', 'Management', 'Transport']
-RABBITMQ_QUEUE_ARGUMENTS = { # type: Mapping[str, Tuple[str, Callable]]
+RABBITMQ_QUEUE_ARGUMENTS: Mapping[str, Tuple[str, Callable]] = {
'expires': ('x-expires', maybe_s_to_ms),
'message_ttl': ('x-message-ttl', maybe_s_to_ms),
'max_length': ('x-max-length', int),
@@ -23,8 +22,7 @@ RABBITMQ_QUEUE_ARGUMENTS = { # type: Mapping[str, Tuple[str, Callable]]
}
-def to_rabbitmq_queue_arguments(arguments, **options):
- # type: (Mapping, **Any) -> Dict
+def to_rabbitmq_queue_arguments(arguments: Mapping, **options) -> Dict:
"""Convert queue arguments to RabbitMQ queue arguments.
This is the implementation for Channel.prepare_queue_arguments
@@ -52,40 +50,39 @@ def to_rabbitmq_queue_arguments(arguments, **options):
"""
prepared = dictfilter(dict(
_to_rabbitmq_queue_argument(key, value)
- for key, value in items(options)
+ for key, value in options.items()
))
return dict(arguments, **prepared) if prepared else arguments
-def _to_rabbitmq_queue_argument(key, value):
- # type: (str, Any) -> Tuple[str, Any]
+def _to_rabbitmq_queue_argument(key: str, value: Any) -> Tuple[str, Any]:
opt, typ = RABBITMQ_QUEUE_ARGUMENTS[key]
return opt, typ(value) if value is not None else value
-def _LeftBlank(obj, method):
+def _LeftBlank(obj: Any, method: str) -> Exception:
return NotImplementedError(
'Transport {0.__module__}.{0.__name__} does not implement {1}'.format(
obj.__class__, method))
-class StdChannel:
+class StdChannel(amqp.types.ChannelT):
"""Standard channel base class."""
no_ack_consumers = None
- def Consumer(self, *args, **kwargs):
+ def Consumer(self, *args, **kwargs) -> ConsumerT:
from kombu.messaging import Consumer
return Consumer(self, *args, **kwargs)
- def Producer(self, *args, **kwargs):
+ def Producer(self, *args, **kwargs) -> ProducerT:
from kombu.messaging import Producer
return Producer(self, *args, **kwargs)
- def get_bindings(self):
+ async def get_bindings(self) -> Sequence[Mapping]:
raise _LeftBlank(self, 'get_bindings')
- def after_reply_message_received(self, queue):
+ async def after_reply_message_received(self, queue: str) -> None:
"""Callback called after RPC reply received.
Notes:
@@ -94,42 +91,39 @@ class StdChannel:
"""
...
- def prepare_queue_arguments(self, arguments, **kwargs):
+ def prepare_queue_arguments(self, arguments: Mapping, **kwargs) -> Mapping:
return arguments
- def __enter__(self):
+ def __enter__(self) -> ChannelT:
return self
- def __exit__(self, *exc_info):
+ def __exit__(self, *exc_info) -> None:
self.close()
-amqp.abstract.Channel.register(StdChannel)
-
-
class Management:
"""AMQP Management API (incomplete)."""
- def __init__(self, transport):
+ def __init__(self, transport: TransportT):
self.transport = transport
- def get_bindings(self):
+ async def get_bindings(self) -> Sequence[Mapping]:
raise _LeftBlank(self, 'get_bindings')
class Implements(dict):
"""Helper class used to define transport features."""
- def __getattr__(self, key):
+ def __getattr__(self, key: str) -> bool:
try:
return self[key]
except KeyError:
raise AttributeError(key)
- def __setattr__(self, key, value):
+ def __setattr__(self, key: str, value: bool) -> None:
self[key] = value
- def extend(self, **kwargs):
+ def extend(self, **kwargs) -> 'Implements':
return self.__class__(self, **kwargs)
@@ -140,65 +134,66 @@ default_transport_capabilities = Implements(
)
-class Transport:
+class Transport(amqp.types.ConnectionT):
"""Base class for transports."""
- Management = Management
+ Management: type = Management
#: The :class:`~kombu.Connection` owning this instance.
- client = None
+ client: ClientT = None
#: Set to True if :class:`~kombu.Connection` should pass the URL
#: unmodified.
- can_parse_url = False
+ can_parse_url: bool = False
#: Default port used when no port has been specified.
- default_port = None
+ default_port: int = None
#: Tuple of errors that can happen due to connection failure.
- connection_errors = (ConnectionError,)
+ connection_errors: Tuple[type, ...] = (ConnectionError,)
#: Tuple of errors that can happen due to channel/method failure.
- channel_errors = (ChannelError,)
+ channel_errors: Tuple[type, ...] = (ChannelError,)
#: Type of driver, can be used to separate transports
#: using the AMQP protocol (driver_type: 'amqp'),
#: Redis (driver_type: 'redis'), etc...
- driver_type = 'N/A'
+ driver_type: str = 'N/A'
#: Name of driver library (e.g. 'py-amqp', 'redis').
- driver_name = 'N/A'
+ driver_name: str = 'N/A'
__reader = None
implements = default_transport_capabilities.extend()
- def __init__(self, client, **kwargs):
+ def __init__(self, client: ClientT, **kwargs) -> None:
self.client = client
- def establish_connection(self):
+ async def establish_connection(self) -> ConnectionT:
raise _LeftBlank(self, 'establish_connection')
- def close_connection(self, connection):
+ async def close_connection(self, connection: ConnectionT) -> None:
raise _LeftBlank(self, 'close_connection')
- def create_channel(self, connection):
+ def create_channel(self, connection: ConnectionT) -> ChannelT:
raise _LeftBlank(self, 'create_channel')
- def close_channel(self, connection):
+ async def close_channel(self, connection: ConnectionT) -> None:
raise _LeftBlank(self, 'close_channel')
- def drain_events(self, connection, **kwargs):
+ async def drain_events(self, connection: ConnectionT, **kwargs) -> None:
raise _LeftBlank(self, 'drain_events')
- def heartbeat_check(self, connection, rate=2):
+ async def heartbeat_check(self, connection: ConnectionT,
+ rate: int = 2) -> None:
...
- def driver_version(self):
+ def driver_version(self) -> str:
return 'N/A'
- def get_heartbeat_interval(self, connection):
- return 0
+ def get_heartbeat_interval(self, connection: ConnectionT) -> float:
+ return 0.0
def register_with_event_loop(self, connection, loop):
...
@@ -206,7 +201,7 @@ class Transport:
def unregister_from_event_loop(self, connection, loop):
...
- def verify_connection(self, connection):
+ def verify_connection(self, connection: ConnectionT) -> bool:
return True
def _make_reader(self, connection, timeout=socket.timeout,
@@ -228,7 +223,7 @@ class Transport:
return _read
- def qos_semantics_matches_spec(self, connection):
+ def qos_semantics_matches_spec(self, connection: ConnectionT) -> bool:
return True
def on_readable(self, connection, loop):
@@ -238,20 +233,20 @@ class Transport:
reader(loop)
@property
- def default_connection_params(self):
+ def default_connection_params(self) -> Mapping:
return {}
- def get_manager(self, *args, **kwargs):
+ def get_manager(self, *args, **kwargs) -> Management:
return self.Management(self)
@cached_property
- def manager(self):
+ def manager(self) -> Management:
return self.get_manager()
@property
- def supports_heartbeats(self):
+ def supports_heartbeats(self) -> bool:
return self.implements.heartbeats
@property
- def supports_ev(self):
+ def supports_ev(self) -> bool:
return self.implements.async
diff --git a/kombu/transport/pyamqp.py b/kombu/transport/pyamqp.py
index 9ba7d8b9..b00c63a9 100644
--- a/kombu/transport/pyamqp.py
+++ b/kombu/transport/pyamqp.py
@@ -96,14 +96,14 @@ class Transport(base.Transport):
def create_channel(self, connection):
return connection.channel()
- def drain_events(self, connection, **kwargs):
- return connection.drain_events(**kwargs)
+ async def drain_events(self, connection, **kwargs):
+ await connection.drain_events(**kwargs)
def _collect(self, connection):
if connection is not None:
connection.collect()
- def establish_connection(self):
+ async def establish_connection(self):
"""Establish connection to the AMQP broker."""
conninfo = self.client
for name, default_value in self.default_connection_params.items():
@@ -124,16 +124,16 @@ class Transport(base.Transport):
}, **conninfo.transport_options or {})
conn = self.Connection(**opts)
conn.client = self.client
- conn.connect()
+ await conn.connect()
return conn
def verify_connection(self, connection):
return connection.connected
- def close_connection(self, connection):
+ async def close_connection(self, connection):
"""Close the AMQP broker connection."""
connection.client = None
- connection.close()
+ await connection.close()
def get_heartbeat_interval(self, connection):
return connection.heartbeat
@@ -142,8 +142,8 @@ class Transport(base.Transport):
connection.transport.raise_on_initial_eintr = True
loop.add_reader(connection.sock, self.on_readable, connection, loop)
- def heartbeat_check(self, connection, rate=2):
- return connection.heartbeat_tick(rate=rate)
+ async def heartbeat_check(self, connection, rate=2):
+ await connection.heartbeat_tick(rate=rate)
def qos_semantics_matches_spec(self, connection):
props = connection.server_properties
diff --git a/kombu/transport/virtual/__init__.py b/kombu/transport/virtual/__init__.py
index 6960104c..b3bf63f3 100644
--- a/kombu/transport/virtual/__init__.py
+++ b/kombu/transport/virtual/__init__.py
@@ -1,5 +1,3 @@
-from __future__ import absolute_import, unicode_literals
-
from .base import (
Base64, NotEquivalentError, UndeliverableWarning, BrokerState,
QoS, Message, AbstractChannel, Channel, Management, Transport,
diff --git a/kombu/transport/virtual/base.py b/kombu/transport/virtual/base.py
index 75bca055..052a620d 100644
--- a/kombu/transport/virtual/base.py
+++ b/kombu/transport/virtual/base.py
@@ -2,24 +2,30 @@
Emulates the AMQ API for non-AMQ transports.
"""
-from __future__ import absolute_import, print_function, unicode_literals
-
+import abc
import base64
import socket
import sys
import warnings
from array import array
-from collections import OrderedDict, defaultdict, namedtuple
+from asyncio import sleep
+from collections import OrderedDict, defaultdict
from itertools import count
from multiprocessing.util import Finalize
-from time import sleep
+from queue import Empty
+from time import monotonic
+from typing import (
+ Any, AnyStr, Callable, IO, Iterable, List, Mapping, MutableMapping,
+ NamedTuple, Optional, Set, Sequence, Tuple,
+)
from amqp.protocol import queue_declare_ok_t
+from amqp.types import ChannelT, ConnectionT
from kombu.exceptions import ResourceError, ChannelError
-from kombu.five import Empty, items, monotonic
from kombu.log import get_logger
+from kombu.types import ClientT, MessageT, TransportT
from kombu.utils.encoding import str_to_bytes, bytes_to_str
from kombu.utils.div import emergency_dump_state
from kombu.utils.scheduling import FairCycle
@@ -27,7 +33,7 @@ from kombu.utils.uuid import uuid
from kombu.transport import base
-from .exchange import STANDARD_EXCHANGE_TYPES
+from .exchange import STANDARD_EXCHANGE_TYPES, ExchangeType
ARRAY_TYPE_H = 'H' if sys.version_info[0] == 3 else b'H'
@@ -50,24 +56,41 @@ RESTORE_PANIC_FMT = 'UNABLE TO RESTORE {0} MESSAGES: {1}'
logger = get_logger(__name__)
-#: Key format used for queue argument lookups in BrokerState.bindings.
-binding_key_t = namedtuple('binding_key_t', (
- 'queue', 'exchange', 'routing_key',
-))
-#: BrokerState.queue_bindings generates tuples in this format.
-queue_binding_t = namedtuple('queue_binding_t', (
- 'exchange', 'routing_key', 'arguments',
-))
+class binding_key_t(NamedTuple):
+ """Key format used for queue argument lookups in BrokerState.bindings."""
+
+ queue: str
+ exchange: str
+ routing_key: str
+
+
+class queue_binding_t(NamedTuple):
+ """BrokerState.queue_bindings generates tuples in this format."""
+
+ exchange: str
+ routing_key: str
+ arguments: Mapping
-class Base64(object):
+class Codec(metaclass=abc.ABCMeta):
+
+ @abc.abstractmethod
+ def encode(self, s: AnyStr) -> str:
+ ...
+
+ @abc.abstractmethod
+ def decode(self, s: AnyStr) -> str:
+ ...
+
+
+class Base64(Codec):
"""Base64 codec."""
- def encode(self, s):
+ def encode(self, s: AnyStr) -> str:
return bytes_to_str(base64.b64encode(str_to_bytes(s)))
- def decode(self, s):
+ def decode(self, s: AnyStr) -> str:
return base64.b64decode(str_to_bytes(s))
@@ -84,7 +107,7 @@ class BrokerState(object):
#: Mapping of exchange name to
#: :class:`kombu.transport.virtual.exchange.ExchangeType`
- exchanges = None
+ exchanges: Mapping[str, ExchangeType] = None
#: This is the actual bindings registry, used to store bindings and to
#: test 'in' relationships in constant time. It has the following
@@ -94,7 +117,7 @@ class BrokerState(object):
#: (queue, exchange, routing_key): arguments,
#: # ...,
#: }
- bindings = None
+ bindings: Mapping[binding_key_t, Mapping] = None
#: The queue index is used to access directly (constant time)
#: all the bindings of a certain queue. It has the following structure::
@@ -105,27 +128,32 @@ class BrokerState(object):
#: },
#: # ...,
#: }
- queue_index = None
+ queue_index: Mapping[str, binding_key_t] = None
- def __init__(self, exchanges=None):
+ def __init__(self, exchanges: Mapping[str, ExchangeType] = None) -> None:
self.exchanges = {} if exchanges is None else exchanges
self.bindings = {}
self.queue_index = defaultdict(set)
- def clear(self):
+ def clear(self) -> None:
self.exchanges.clear()
self.bindings.clear()
self.queue_index.clear()
- def has_binding(self, queue, exchange, routing_key):
+ def has_binding(self, queue: str, exchange: str, routing_key: str) -> bool:
return (queue, exchange, routing_key) in self.bindings
- def binding_declare(self, queue, exchange, routing_key, arguments):
+ def binding_declare(self,
+ queue: str,
+ exchange: str,
+ routing_key: str,
+ arguments: Mapping) -> None:
key = binding_key_t(queue, exchange, routing_key)
self.bindings.setdefault(key, arguments)
self.queue_index[queue].add(key)
- def binding_delete(self, queue, exchange, routing_key):
+ def binding_delete(
+ self, queue: str, exchange: str, routing_key: str) -> None:
key = binding_key_t(queue, exchange, routing_key)
try:
del self.bindings[key]
@@ -134,7 +162,7 @@ class BrokerState(object):
else:
self.queue_index[queue].remove(key)
- def queue_bindings_delete(self, queue):
+ def queue_bindings_delete(self, queue: str) -> None:
try:
bindings = self.queue_index.pop(queue)
except KeyError:
@@ -142,7 +170,7 @@ class BrokerState(object):
else:
[self.bindings.pop(binding, None) for binding in bindings]
- def queue_bindings(self, queue):
+ def queue_bindings(self, queue: str) -> Iterable[binding_key_t]:
return (
queue_binding_t(key.exchange, key.routing_key, self.bindings[key])
for key in self.queue_index[queue]
@@ -160,22 +188,22 @@ class QoS(object):
"""
#: current prefetch count value
- prefetch_count = 0
+ prefetch_count: int = 0
#: :class:`~collections.OrderedDict` of active messages.
#: *NOTE*: Can only be modified by the consuming thread.
- _delivered = None
+ _delivered: OrderedDict = None
#: acks can be done by other threads than the consuming thread.
#: Instead of a mutex, which doesn't perform well here, we mark
#: the delivery tags as dirty, so subsequent calls to append() can remove
#: them.
- _dirty = None
+ _dirty: Set[MessageT] = None
#: If disabled, unacked messages won't be restored at shutdown.
- restore_at_shutdown = True
+ restore_at_shutdown: bool = True
- def __init__(self, channel, prefetch_count=0):
+ def __init__(self, channel: ChannelT, prefetch_count: int = 0) -> None:
self.channel = channel
self.prefetch_count = prefetch_count or 0
@@ -188,7 +216,7 @@ class QoS(object):
self, self.restore_unacked_once, exitpriority=1,
)
- def can_consume(self):
+ def can_consume(self) -> bool:
"""Return true if the channel can be consumed from.
Used to ensure the client adhers to currently active
@@ -197,7 +225,7 @@ class QoS(object):
pcount = self.prefetch_count
return not pcount or len(self._delivered) - len(self._dirty) < pcount
- def can_consume_max_estimate(self):
+ def can_consume_max_estimate(self) -> int:
"""Return the maximum number of messages allowed to be returned.
Returns an estimated number of messages that a consumer may be allowed
@@ -212,16 +240,16 @@ class QoS(object):
if pcount:
return max(pcount - (len(self._delivered) - len(self._dirty)), 0)
- def append(self, message, delivery_tag):
+ def append(self, message: MessageT, delivery_tag: str) -> None:
"""Append message to transactional state."""
if self._dirty:
self._flush()
self._quick_append(delivery_tag, message)
- def get(self, delivery_tag):
+ def get(self, delivery_tag: str) -> MessageT:
return self._delivered[delivery_tag]
- def _flush(self):
+ def _flush(self) -> None:
"""Flush dirty (acked/rejected) tags from."""
dirty = self._dirty
delivered = self._delivered
@@ -232,17 +260,17 @@ class QoS(object):
break
delivered.pop(dirty_tag, None)
- def ack(self, delivery_tag):
+ def ack(self, delivery_tag: str) -> None:
"""Acknowledge message and remove from transactional state."""
self._quick_ack(delivery_tag)
- def reject(self, delivery_tag, requeue=False):
+ def reject(self, delivery_tag: str, requeue: bool = False) -> None:
"""Remove from transactional state and requeue message."""
if requeue:
self.channel._restore_at_beginning(self._delivered[delivery_tag])
self._quick_ack(delivery_tag)
- def restore_unacked(self):
+ async def restore_unacked(self) -> None:
"""Restore all unacknowledged messages."""
self._flush()
delivered = self._delivered
@@ -257,13 +285,13 @@ class QoS(object):
break
try:
- restore(message)
+ await restore(message)
except BaseException as exc:
errors.append((exc, message))
delivered.clear()
return errors
- def restore_unacked_once(self, stderr=None):
+ async def restore_unacked_once(self, stderr: IO = None) -> None:
"""Restore all unacknowledged messages at shutdown/gc collect.
Note:
@@ -284,7 +312,7 @@ class QoS(object):
if state:
print(RESTORING_FMT.format(len(self._delivered)),
file=stderr)
- unrestored = self.restore_unacked()
+ unrestored = await self.restore_unacked()
if unrestored:
errors, messages = list(zip(*unrestored))
@@ -294,7 +322,7 @@ class QoS(object):
finally:
state.restored = True
- def restore_visible(self, *args, **kwargs):
+ async def restore_visible(self, *args, **kwargs) -> None:
"""Restore any pending unackwnowledged messages.
To be filled in for visibility_timeout style implementations.
@@ -303,13 +331,14 @@ class QoS(object):
This is implementation optional, and currently only
used by the Redis transport.
"""
- pass
+ ...
class Message(base.Message):
"""Message object."""
- def __init__(self, payload, channel=None, **kwargs):
+ def __init__(self, payload: Any,
+ channel: ChannelT = None, **kwargs) -> None:
self._raw = payload
properties = payload['properties']
body = payload.get('body')
@@ -327,7 +356,7 @@ class Message(base.Message):
postencode='utf-8',
**kwargs)
- def serializable(self):
+ def serializable(self) -> Mapping:
props = self.properties
body, _ = self.channel.encode_body(self.body,
props.get('body_encoding'))
@@ -354,23 +383,23 @@ class AbstractChannel(object):
from :class:`Channel`.
"""
- def _get(self, queue, timeout=None):
+ async def _get(self, queue: str, timeout: float = None) -> MessageT:
"""Get next message from `queue`."""
raise NotImplementedError('Virtual channels must implement _get')
- def _put(self, queue, message):
+ async def _put(self, queue: str, message: MessageT) -> None:
"""Put `message` onto `queue`."""
raise NotImplementedError('Virtual channels must implement _put')
- def _purge(self, queue):
+ async def _purge(self, queue: str) -> int:
"""Remove all messages from `queue`."""
raise NotImplementedError('Virtual channels must implement _purge')
- def _size(self, queue):
+ async def _size(self, queue: str) -> int:
"""Return the number of messages in `queue` as an :class:`int`."""
return 0
- def _delete(self, queue, *args, **kwargs):
+ async def _delete(self, queue: str, *args, **kwargs) -> None:
"""Delete `queue`.
Note:
@@ -379,16 +408,16 @@ class AbstractChannel(object):
"""
self._purge(queue)
- def _new_queue(self, queue, **kwargs):
+ async def _new_queue(self, queue: str, **kwargs) -> None:
"""Create new queue.
Note:
Your transport can override this method if it needs
to do something whenever a new queue is declared.
"""
- pass
+ ...
- def _has_queue(self, queue, **kwargs):
+ async def _has_queue(self, queue: str, **kwargs) -> bool:
"""Verify that queue exists.
Returns:
@@ -397,13 +426,14 @@ class AbstractChannel(object):
"""
return True
- def _poll(self, cycle, callback, timeout=None):
+ async def _poll(self, cycle: Any, callback: Callable,
+ timeout: float = None) -> Any:
"""Poll a list of queues for available messages."""
- return cycle.get(callback)
+ return await cycle.get(callback)
- def _get_and_deliver(self, queue, callback):
- message = self._get(queue)
- callback(message, queue)
+ async def _get_and_deliver(self, queue: str, callback: Callable) -> None:
+ message = await self._get(queue)
+ await callback(message, queue)
class Channel(AbstractChannel, base.StdChannel):
@@ -415,44 +445,56 @@ class Channel(AbstractChannel, base.StdChannel):
"""
#: message class used.
- Message = Message
+ Message: type = Message
#: QoS class used.
- QoS = QoS
+ QoS: type = QoS
#: flag to restore unacked messages when channel
#: goes out of scope.
- do_restore = True
+ do_restore: bool = True
#: mapping of exchange types and corresponding classes.
- exchange_types = dict(STANDARD_EXCHANGE_TYPES)
+ exchange_type_classes: Mapping[str, type] = dict(
+ STANDARD_EXCHANGE_TYPES)
+
+ exchange_types: Mapping[str, ExchangeType] = None
#: flag set if the channel supports fanout exchanges.
- supports_fanout = False
+ supports_fanout: bool = False
#: Binary <-> ASCII codecs.
- codecs = {'base64': Base64()}
+ codecs: Mapping[str, Codec] = {'base64': Base64()}
#: Default body encoding.
#: NOTE: ``transport_options['body_encoding']`` will override this value.
- body_encoding = 'base64'
+ body_encoding: str = 'base64'
#: counter used to generate delivery tags for this channel.
_delivery_tags = count(1)
#: Optional queue where messages with no route is delivered.
#: Set by ``transport_options['deadletter_queue']``.
- deadletter_queue = None
+ deadletter_queue: str = None
# List of options to transfer from :attr:`transport_options`.
- from_transport_options = ('body_encoding', 'deadletter_queue')
+ from_transport_options: Tuple[str, ...] = (
+ 'body_encoding',
+ 'deadletter_queue',
+ )
# Priority defaults
default_priority = 0
min_priority = 0
max_priority = 9
- def __init__(self, connection, **kwargs):
+ _consumers: Set[str]
+ _tag_to_queue: Mapping[str, str]
+ _active_queues: List[str]
+ closed: bool = False
+ _qos: QoS = None
+
+ def __init__(self, connection: ConnectionT, **kwargs) -> None:
self.connection = connection
self._consumers = set()
self._cycle = None
@@ -462,9 +504,9 @@ class Channel(AbstractChannel, base.StdChannel):
self.closed = False
# instantiate exchange types
- self.exchange_types = dict(
- (typ, cls(self)) for typ, cls in items(self.exchange_types)
- )
+ self.exchange_types = {
+ typ: cls(self) for typ, cls in self.exchange_type_classes.items()
+ }
try:
self.channel_id = self.connection._avail_channel_ids.pop()
@@ -482,9 +524,15 @@ class Channel(AbstractChannel, base.StdChannel):
except KeyError:
pass
- def exchange_declare(self, exchange=None, type='direct', durable=False,
- auto_delete=False, arguments=None,
- nowait=False, passive=False):
+ async def exchange_declare(
+ self,
+ exchange: str = None,
+ type: str = 'direct',
+ durable: bool = False,
+ auto_delete: bool = False,
+ arguments: Mapping = None,
+ nowait: bool = False,
+ passive: bool = False) -> None:
"""Declare exchange."""
type = type or 'direct'
exchange = exchange or 'amq.%s' % type
@@ -512,49 +560,74 @@ class Channel(AbstractChannel, base.StdChannel):
'table': [],
}
- def exchange_delete(self, exchange, if_unused=False, nowait=False):
+ async def exchange_delete(
+ self,
+ exchange: str,
+ if_unused: bool = False,
+ nowait: bool = False) -> None:
"""Delete `exchange` and all its bindings."""
for rkey, _, queue in self.get_table(exchange):
- self.queue_delete(queue, if_unused=True, if_empty=True)
+ await self.queue_delete(queue, if_unused=True, if_empty=True)
self.state.exchanges.pop(exchange, None)
- def queue_declare(self, queue=None, passive=False, **kwargs):
+ async def queue_declare(
+ self,
+ queue: str = None,
+ passive: bool = False,
+ **kwargs) -> queue_declare_ok_t:
"""Declare queue."""
queue = queue or 'amq.gen-%s' % uuid()
- if passive and not self._has_queue(queue, **kwargs):
+ if passive and not await self._has_queue(queue, **kwargs):
raise ChannelError(
'NOT_FOUND - no queue {0!r} in vhost {1!r}'.format(
queue, self.connection.client.virtual_host or '/'),
(50, 10), 'Channel.queue_declare', '404',
)
else:
- self._new_queue(queue, **kwargs)
- return queue_declare_ok_t(queue, self._size(queue), 0)
-
- def queue_delete(self, queue, if_unused=False, if_empty=False, **kwargs):
+ await self._new_queue(queue, **kwargs)
+ return queue_declare_ok_t(queue, await self._size(queue), 0)
+
+ async def queue_delete(
+ self,
+ queue: str,
+ if_unused: bool = False,
+ if_empty: bool = False,
+ **kwargs) -> None:
"""Delete queue."""
- if if_empty and self._size(queue):
+ if if_empty and await self._size(queue):
return
for exchange, routing_key, args in self.state.queue_bindings(queue):
meta = self.typeof(exchange).prepare_bind(
queue, exchange, routing_key, args,
)
- self._delete(queue, exchange, *meta, **kwargs)
+ await self._delete(queue, exchange, *meta, **kwargs)
self.state.queue_bindings_delete(queue)
- def after_reply_message_received(self, queue):
- self.queue_delete(queue)
+ async def after_reply_message_received(self, queue: str) -> None:
+ await self.queue_delete(queue)
- def exchange_bind(self, destination, source='', routing_key='',
- nowait=False, arguments=None):
+ async def exchange_bind(
+ self, destination: str,
+ source: str = '',
+ routing_key: str = '',
+ nowait: bool = False,
+ arguments: Mapping = None) -> None:
raise NotImplementedError('transport does not support exchange_bind')
- def exchange_unbind(self, destination, source='', routing_key='',
- nowait=False, arguments=None):
+ async def exchange_unbind(
+ self, destination: str,
+ source: str = '',
+ routing_key: str = '',
+ nowait: bool = False,
+ arguments: Mapping = None) -> None:
raise NotImplementedError('transport does not support exchange_unbind')
- def queue_bind(self, queue, exchange=None, routing_key='',
- arguments=None, **kwargs):
+ async def queue_bind(
+ self, queue: str,
+ exchange: str = None,
+ routing_key: str = '',
+ arguments: Mapping = None,
+ **kwargs) -> None:
"""Bind `queue` to `exchange` with `routing key`."""
exchange = exchange or 'amq.direct'
if self.state.has_binding(queue, exchange, routing_key):
@@ -568,10 +641,14 @@ class Channel(AbstractChannel, base.StdChannel):
)
table.append(meta)
if self.supports_fanout:
- self._queue_bind(exchange, *meta)
-
- def queue_unbind(self, queue, exchange=None, routing_key='',
- arguments=None, **kwargs):
+ await self._queue_bind(exchange, *meta)
+
+ async def queue_unbind(
+ self, queue: str,
+ exchange: str = None,
+ routing_key: str = '',
+ arguments: Mapping = None,
+ **kwargs) -> None:
# Remove queue binding:
self.state.binding_delete(queue, exchange, routing_key)
try:
@@ -585,19 +662,21 @@ class Channel(AbstractChannel, base.StdChannel):
# Should be optimized. Modifying table in place.
table[:] = [meta for meta in table if meta != binding_meta]
- def list_bindings(self):
- return ((queue, exchange, rkey)
+ async def list_bindings(self) -> Iterable[queue_binding_t]:
+ return (queue_binding_t(queue, exchange, rkey)
for exchange in self.state.exchanges
for rkey, pattern, queue in self.get_table(exchange))
- def queue_purge(self, queue, **kwargs):
+ async def queue_purge(self, queue: str, **kwargs) -> int:
"""Remove all ready messages from queue."""
- return self._purge(queue)
+ return await self._purge(queue)
- def _next_delivery_tag(self):
+ def _next_delivery_tag(self) -> str:
return uuid()
- def basic_publish(self, message, exchange, routing_key, **kwargs):
+ async def basic_publish(
+ self, message: MessageT, exchange: str, routing_key: str,
+ **kwargs) -> None:
"""Publish message."""
self._inplace_augment_message(message, exchange, routing_key)
if exchange:
@@ -605,9 +684,10 @@ class Channel(AbstractChannel, base.StdChannel):
message, exchange, routing_key, **kwargs
)
# anon exchange: routing_key is the destination queue
- return self._put(routing_key, message, **kwargs)
+ return await self._put(routing_key, message, **kwargs)
- def _inplace_augment_message(self, message, exchange, routing_key):
+ def _inplace_augment_message(
+ self, message: MessageT, exchange: str, routing_key: str) -> None:
message['body'], body_encoding = self.encode_body(
message['body'], self.body_encoding,
)
@@ -621,23 +701,28 @@ class Channel(AbstractChannel, base.StdChannel):
routing_key=routing_key,
)
- def basic_consume(self, queue, no_ack, callback, consumer_tag, **kwargs):
+ async def basic_consume(
+ self, queue: str,
+ no_ack: bool = False,
+ callback: Callable = None,
+ consumer_tag: str = None,
+ **kwargs) -> None:
"""Consume from `queue`."""
self._tag_to_queue[consumer_tag] = queue
self._active_queues.append(queue)
- def _callback(raw_message):
+ async def _callback(raw_message):
message = self.Message(raw_message, channel=self)
if not no_ack:
self.qos.append(message, message.delivery_tag)
- return callback(message)
+ return await callback(message)
self.connection._callbacks[queue] = _callback
self._consumers.add(consumer_tag)
self._reset_cycle()
- def basic_cancel(self, consumer_tag):
+ await def basic_cancel(self, consumer_tag: str) -> None:
"""Cancel consumer by consumer tag."""
if consumer_tag in self._consumers:
self._consumers.remove(consumer_tag)
@@ -649,32 +734,38 @@ class Channel(AbstractChannel, base.StdChannel):
pass
self.connection._callbacks.pop(queue, None)
- def basic_get(self, queue, no_ack=False, **kwargs):
+ async def basic_get(
+ self, queue: str,
+ no_ack: bool = False,
+ **kwargs) -> Optional[MessageT]:
"""Get message by direct access (synchronous)."""
try:
- message = self.Message(self._get(queue), channel=self)
+ message = self.Message(await self._get(queue), channel=self)
if not no_ack:
self.qos.append(message, message.delivery_tag)
return message
except Empty:
pass
- def basic_ack(self, delivery_tag, multiple=False):
+ async def basic_ack(self, delivery_tag: str, multiple: bool = False) -> None:
"""Acknowledge message."""
self.qos.ack(delivery_tag)
- def basic_recover(self, requeue=False):
+ async def basic_recover(self, requeue: bool = False) -> None:
"""Recover unacked messages."""
if requeue:
- return self.qos.restore_unacked()
+ return await self.qos.restore_unacked()
raise NotImplementedError('Does not support recover(requeue=False)')
- def basic_reject(self, delivery_tag, requeue=False):
+ async def basic_reject(self, delivery_tag: str, requeue: bool = False) -> None:
"""Reject message."""
self.qos.reject(delivery_tag, requeue=requeue)
- def basic_qos(self, prefetch_size=0, prefetch_count=0,
- apply_global=False):
+ async def basic_qos(
+ self,
+ prefetch_size: int = 0,
+ prefetch_count: int = 0,
+ apply_global: bool = False) -> None:
"""Change QoS settings for this channel.
Note:
@@ -682,14 +773,14 @@ class Channel(AbstractChannel, base.StdChannel):
"""
self.qos.prefetch_count = prefetch_count
- def get_exchanges(self):
+ def get_exchanges(self) -> Sequence[str]:
return list(self.state.exchanges)
- def get_table(self, exchange):
+ def get_table(self, exchange: str) -> Mapping:
"""Get table of bindings for `exchange`."""
return self.state.exchanges[exchange]['table']
- def typeof(self, exchange, default='direct'):
+ def typeof(self, exchange: str, default: str = 'direct') -> ExchangeType:
"""Get the exchange type instance for `exchange`."""
try:
type = self.state.exchanges[exchange]['type']
@@ -697,7 +788,9 @@ class Channel(AbstractChannel, base.StdChannel):
type = default
return self.exchange_types[type]
- def _lookup(self, exchange, routing_key, default=None):
+ def _lookup(
+ self, exchange: str, routing_key: str,
+ default: str = None) -> Sequence[str]:
"""Find all queues matching `routing_key` for the given `exchange`.
Returns:
@@ -725,34 +818,42 @@ class Channel(AbstractChannel, base.StdChannel):
R = [default]
return R
- def _restore(self, message):
+ async def _restore(self, message: MessageT) -> None:
"""Redeliver message to its original destination."""
delivery_info = message.delivery_info
message = message.serializable()
message['redelivered'] = True
for queue in self._lookup(
delivery_info['exchange'], delivery_info['routing_key']):
- self._put(queue, message)
+ await self._put(queue, message)
- def _restore_at_beginning(self, message):
- return self._restore(message)
+ async def _restore_at_beginning(self, message: MessageT) -> None:
+ await self._restore(message)
- def drain_events(self, timeout=None, callback=None):
+ async def drain_events(
+ self,
+ timeout: float = None,
+ callback: Callable = None) -> None:
callback = callback or self.connection._deliver
if self._consumers and self.qos.can_consume():
if hasattr(self, '_get_many'):
- return self._get_many(self._active_queues, timeout=timeout)
- return self._poll(self.cycle, callback, timeout=timeout)
+ await self._get_many(self._active_queues, timeout=timeout)
+ else:
+ await self._poll(self.cycle, callback, timeout=timeout)
raise Empty()
- def message_to_python(self, raw_message):
+ def message_to_python(self, raw_message: Any) -> MessageT:
"""Convert raw message to :class:`Message` instance."""
if not isinstance(raw_message, self.Message):
return self.Message(payload=raw_message, channel=self)
return raw_message
- def prepare_message(self, body, priority=None, content_type=None,
- content_encoding=None, headers=None, properties=None):
+ def prepare_message(self, body: Any,
+ priority: int = None,
+ content_type: str = None,
+ content_encoding: str = None,
+ headers: Mapping = None,
+ properties: Mapping = None) -> Mapping:
"""Prepare message data."""
properties = properties or {}
properties.setdefault('delivery_info', {})
@@ -764,7 +865,7 @@ class Channel(AbstractChannel, base.StdChannel):
'headers': headers or {},
'properties': properties or {}}
- def flow(self, active=True):
+ async def flow(self, active: bool = True) -> None:
"""Enable/disable message flow.
Raises:
@@ -773,7 +874,7 @@ class Channel(AbstractChannel, base.StdChannel):
"""
raise NotImplementedError('virtual channels do not support flow.')
- def close(self):
+ async def close(self) -> None:
"""Close channel.
Cancel all consumers, and requeue unacked messages.
@@ -781,55 +882,62 @@ class Channel(AbstractChannel, base.StdChannel):
if not self.closed:
self.closed = True
for consumer in list(self._consumers):
- self.basic_cancel(consumer)
+ await self.basic_cancel(consumer)
if self._qos:
- self._qos.restore_unacked_once()
+ await self._qos.restore_unacked_once()
if self._cycle is not None:
self._cycle.close()
self._cycle = None
if self.connection is not None:
- self.connection.close_channel(self)
+ await self.connection.close_channel(self)
self.exchange_types = None
- def encode_body(self, body, encoding=None):
+ def encode_body(self, body: Any, encoding: str = None) -> Tuple[Any, str]:
if encoding:
return self.codecs.get(encoding).encode(body), encoding
return body, encoding
- def decode_body(self, body, encoding=None):
+ def decode_body(self, body: Any, encoding: str = None) -> Any:
if encoding:
return self.codecs.get(encoding).decode(body)
return body
- def _reset_cycle(self):
+ def _reset_cycle(self) -> None:
self._cycle = FairCycle(
self._get_and_deliver, self._active_queues, Empty)
- def __enter__(self):
+ def __enter__(self) -> 'Channel':
return self
- def __exit__(self, *exc_info):
+ def __exit__(self, *exc_info) -> None:
self.close()
+ async def __aenter__(self) -> 'Channel':
+ return self
+
+ async def __aexit__(self, *exc_info) -> None:
+ ...
+
@property
- def state(self):
+ def state(self) -> BrokerState:
"""Broker state containing exchanges and bindings."""
return self.connection.state
@property
- def qos(self):
+ def qos(self) -> QoS:
""":class:`QoS` manager for this channel."""
if self._qos is None:
self._qos = self.QoS(self)
return self._qos
@property
- def cycle(self):
+ def cycle(self) -> FairCycle:
if self._cycle is None:
self._reset_cycle()
return self._cycle
- def _get_message_priority(self, message, reverse=False):
+ def _get_message_priority(self, message: MessageT,
+ *, reverse: bool = False) -> int:
"""Get priority from message.
The value is limited to within a boundary of 0 to 9.
@@ -852,15 +960,15 @@ class Channel(AbstractChannel, base.StdChannel):
class Management(base.Management):
"""Base class for the AMQP management API."""
- def __init__(self, transport):
+ def __init__(self, transport: TransportT) -> None:
super(Management, self).__init__(transport)
self.channel = transport.client.channel()
- def get_bindings(self):
+ def get_bindings(self) -> Sequence[Mapping]:
return [dict(destination=q, source=e, routing_key=r)
for q, e, r in self.channel.list_bindings()]
- def close(self):
+ def close(self) -> None:
self.channel.close()
@@ -871,31 +979,31 @@ class Transport(base.Transport):
client (kombu.Connection): The client this is a transport for.
"""
- Channel = Channel
- Cycle = FairCycle
- Management = Management
+ Channel: type = Channel
+ Cycle: type = FairCycle
+ Management: type = Management
#: Global :class:`BrokerState` containing declared exchanges and bindings.
- state = BrokerState()
+ state: BrokerState = BrokerState()
#: :class:`~kombu.utils.scheduling.FairCycle` instance
#: used to fairly drain events from channels (set by constructor).
- cycle = None
+ cycle: FairCycle = None
#: port number used when no port is specified.
- default_port = None
+ default_port: int = None
#: active channels.
- channels = None
+ channels: Sequence[ChannelT] = None
#: queue/callback map.
- _callbacks = None
+ _callbacks: MutableMapping[str, Callable] = None
#: Time to sleep between unsuccessful polls.
- polling_interval = 1.0
+ polling_interval: float = 1.0
#: Max number of channels
- channel_max = 65535
+ channel_max: int = 65535
implements = base.Transport.implements.extend(
async=False,
@@ -903,7 +1011,7 @@ class Transport(base.Transport):
heartbeats=False,
)
- def __init__(self, client, **kwargs):
+ def __init__(self, client: ClientT, **kwargs):
self.client = client
self.channels = []
self._avail_channels = []
@@ -916,7 +1024,7 @@ class Transport(base.Transport):
ARRAY_TYPE_H, range(self.channel_max, 0, -1),
)
- def create_channel(self, connection):
+ def create_channel(self, connection: ConnectionT) -> ChannelT:
try:
return self._avail_channels.pop()
except IndexError:
@@ -924,7 +1032,7 @@ class Transport(base.Transport):
self.channels.append(channel)
return channel
- def close_channel(self, channel):
+ async def close_channel(self, channel: ChannelT) -> None:
try:
self._avail_channel_ids.append(channel.channel_id)
try:
@@ -934,14 +1042,14 @@ class Transport(base.Transport):
finally:
channel.connection = None
- def establish_connection(self):
+ async def establish_connection(self) -> 'Transport':
# creates channel to verify connection.
# this channel is then used as the next requested channel.
# (returned by ``create_channel``).
self._avail_channels.append(self.create_channel(self))
return self # for drain events
- def close_connection(self, connection):
+ async def close_connection(self, connection: ConnectionT) -> None:
self.cycle.close()
for l in self._avail_channels, self.channels:
while l:
@@ -950,24 +1058,25 @@ class Transport(base.Transport):
except LookupError: # pragma: no cover
pass
else:
- channel.close()
+ await channel.close()
- def drain_events(self, connection, timeout=None):
+ async def drain_events(self, connection: ConnectionT,
+ timeout: float = None) -> None:
time_start = monotonic()
get = self.cycle.get
polling_interval = self.polling_interval
while 1:
try:
- get(self._deliver, timeout=timeout)
+ await get(self._deliver, timeout=timeout)
except Empty:
if timeout is not None and monotonic() - time_start >= timeout:
raise socket.timeout()
if polling_interval is not None:
- sleep(polling_interval)
+ await sleep(polling_interval)
else:
break
- def _deliver(self, message, queue):
+ async def _deliver(self, message: MessageT, queue: str) -> None:
if not queue:
raise KeyError(
'Received message without destination queue: {0}'.format(
@@ -980,7 +1089,7 @@ class Transport(base.Transport):
else:
callback(message)
- def _reject_inbound_message(self, raw_message):
+ def _reject_inbound_message(self, raw_message: Any) -> None:
for channel in self.channels:
if channel:
message = channel.Message(raw_message, channel=channel)
@@ -988,16 +1097,18 @@ class Transport(base.Transport):
channel.basic_reject(message.delivery_tag, requeue=True)
break
- def on_message_ready(self, channel, message, queue):
+ def on_message_ready(
+ self, channel: ChannelT, message: MessageT, queue: str) -> None:
if not queue or queue not in self._callbacks:
raise KeyError(
'Message for queue {0!r} without consumers: {1}'.format(
queue, message))
self._callbacks[queue](message)
- def _drain_channel(self, channel, callback, timeout=None):
+ def _drain_channel(self, channel: ChannelT, callback: Callable,
+ timeout: float = None) -> None:
return channel.drain_events(callback=callback, timeout=timeout)
@property
- def default_connection_params(self):
- return {'port': self.default_port, 'hostname': 'localhost'}
+ def default_connection_params(self) -> Mapping:
+ return {'port': self.default_port, 'hostnaime': 'localhost'}
diff --git a/kombu/transport/virtual/exchange.py b/kombu/transport/virtual/exchange.py
index 33d63066..98c0e6db 100644
--- a/kombu/transport/virtual/exchange.py
+++ b/kombu/transport/virtual/exchange.py
@@ -3,10 +3,11 @@
Implementations of the standard exchanges defined
by the AMQ protocol (excluding the `headers` exchange).
"""
-
import re
-
+from typing import Mapping, Match, Pattern, Set, Tuple
+from amqp.types import ChannelT
from kombu.utils.text import escape_regex
+from kombu.types import MessageT
class ExchangeType:
@@ -18,9 +19,10 @@ class ExchangeType:
channel (ChannelT): AMQ Channel.
"""
- type = None
+ type: str
+ channel: ChannelT
- def __init__(self, channel):
+ def __init__(self, channel: ChannelT) -> None:
self.channel = channel
def lookup(self, table, exchange, routing_key, default):
@@ -57,13 +59,18 @@ class DirectExchange(ExchangeType):
type = 'direct'
- def lookup(self, table, exchange, routing_key, default):
+ def lookup(self,
+ table: Mapping,
+ exchange: str,
+ routing_key: str,
+ default: str) -> Set[str]:
return {
queue for rkey, _, queue in table
if rkey == routing_key
}
- def deliver(self, message, exchange, routing_key, **kwargs):
+ def deliver(self, message: MessageT,
+ exchange: str, routing_key: str, **kwargs) -> None:
_lookup = self.channel._lookup
_put = self.channel._put
for queue in _lookup(exchange, routing_key):
@@ -81,19 +88,26 @@ class TopicExchange(ExchangeType):
type = 'topic'
#: map of wildcard to regex conversions
- wildcards = {'*': r'.*?[^\.]',
- '#': r'.*?'}
+ wildcards: Mapping[str, str] = {
+ '*': r'.*?[^\.]',
+ '#': r'.*?',
+ }
#: compiled regex cache
- _compiled = {}
+ _compiled: Mapping[str, Pattern] = {}
- def lookup(self, table, exchange, routing_key, default):
+ def lookup(self,
+ table: Mapping,
+ exchange: str,
+ routing_key: str,
+ default: str) -> Set[str]:
return {
queue for rkey, pattern, queue in table
if self._match(pattern, routing_key)
}
- def deliver(self, message, exchange, routing_key, **kwargs):
+ def deliver(self, message: MessageT,
+ exchange: str, routing_key: str, **kwargs) -> None:
_lookup = self.channel._lookup
_put = self.channel._put
deadletter = self.channel.deadletter_queue
@@ -101,17 +115,21 @@ class TopicExchange(ExchangeType):
if q and q != deadletter]:
_put(queue, message, **kwargs)
- def prepare_bind(self, queue, exchange, routing_key, arguments):
+ def prepare_bind(self,
+ queue: str,
+ exchange: str,
+ routing_key: str,
+ arguments: Mapping) -> Tuple[str, Pattern, str]:
return routing_key, self.key_to_pattern(routing_key), queue
- def key_to_pattern(self, rkey):
+ def key_to_pattern(self, rkey: str) -> str:
"""Get the corresponding regex for any routing key."""
return '^%s$' % ('\.'.join(
self.wildcards.get(word, word)
for word in escape_regex(rkey, '.#*').split('.')
))
- def _match(self, pattern, string):
+ def _match(self, pattern: str, string: str) -> Match:
"""Match regular expression (cached).
Same as :func:`re.match`, except the regex is compiled and cached,
@@ -141,17 +159,19 @@ class FanoutExchange(ExchangeType):
type = 'fanout'
- def lookup(self, table, exchange, routing_key, default):
+ def lookup(self, table: Mapping,
+ exchange: str, routing_key: str, default: str) -> Set[str]:
return {queue for _, _, queue in table}
- def deliver(self, message, exchange, routing_key, **kwargs):
+ def deliver(self, message: MessageT,
+ exchange: str, routing_key: str, **kwargs) -> None:
if self.channel.supports_fanout:
self.channel._put_fanout(
exchange, message, routing_key, **kwargs)
#: Map of standard exchange types and corresponding classes.
-STANDARD_EXCHANGE_TYPES = {
+STANDARD_EXCHANGE_TYPES: Mapping[str, type] = {
'direct': DirectExchange,
'topic': TopicExchange,
'fanout': FanoutExchange,
diff --git a/kombu/types.py b/kombu/types.py
new file mode 100644
index 00000000..5825fb1f
--- /dev/null
+++ b/kombu/types.py
@@ -0,0 +1,893 @@
+import abc
+import logging
+from collections import Hashable, deque
+from numbers import Number
+from typing import (
+ Any, Callable, Dict, Iterable, List, Mapping, Optional,
+ Sequence, TypeVar, Tuple, Union,
+)
+from typing import Set # noqa
+from amqp.types import ChannelT, MessageT as _MessageT, TransportT
+from amqp.spec import queue_declare_ok_t
+
+
+def _hasattr(C: Any, attr: str) -> bool:
+ return any(attr in B.__dict__ for B in C.__mro__)
+
+
+class _AbstractClass(abc.ABCMeta):
+ __required_attributes__ = frozenset() # type: frozenset
+
+ @classmethod
+ def _subclasshook_using(cls, parent: Any, C: Any):
+ return (
+ cls is parent and
+ all(_hasattr(C, attr) for attr in cls.__required_attributes__)
+ ) or NotImplemented
+
+
+class Revivable(metaclass=_AbstractClass):
+
+ async def revive(self, channel: ChannelT) -> None:
+ ...
+
+
+class ClientT(Revivable, metaclass=_AbstractClass):
+
+ hostname: str
+ userid: str
+ password: str
+ ssl: Any
+ login_method: str
+ port: int = None
+ virtual_host: str = '/'
+ connect_timeout: float = 5.0
+ alt: Sequence[str]
+
+ uri_prefix: str = None
+ declared_entities: Set['EntityT'] = None
+ cycle: Iterable
+ transport_options: Mapping
+ failover_strategy: str
+ heartbeat: float = None
+ resolve_aliases: Mapping[str, str] = resolve_aliases
+ failover_strategies: Mapping[str, Callable] = failover_strategies
+
+ @property
+ @abc.abstractmethod
+ def default_channel(self) -> ChannelT:
+ ...
+
+ def __init__(
+ self,
+ hostname: str = 'localhost',
+ userid: str = None,
+ password: str = None,
+ virtual_host: str = None,
+ port: int = None,
+ insist: bool = False,
+ ssl: Any = None,
+ transport: Union[type, str] = None,
+ connect_timeout: float = 5.0,
+ transport_options: Mapping = None,
+ login_method: str = None,
+ uri_prefix: str = None,
+ heartbeat: float = None,
+ failover_strategy: str = 'round-robin',
+ alternates: Sequence[str] = None,
+ **kwargs) -> None:
+ ...
+
+ @abc.abstractmethod
+ def switch(self, url: str) -> None:
+ ...
+
+ @abc.abstractmethod
+ def maybe_switch_next(self) -> None:
+ ...
+
+ @abc.abstractmethod
+ async def connect(self) -> None:
+ ...
+
+ @abc.abstractmethod
+ def channel(self) -> ChannelT:
+ ...
+
+ @abc.abstractmethod
+ async def heartbeat_check(self, rate: int = 2) -> None:
+ ...
+
+ @abc.abstractmethod
+ async def drain_events(self, timeout: float = None, **kwargs) -> None:
+ ...
+
+ @abc.abstractmethod
+ async def maybe_close_channel(self, channel: ChannelT) -> None:
+ ...
+
+ @abc.abstractmethod
+ async def collect(self, socket_timeout: float = None) -> None:
+ ...
+
+ @abc.abstractmethod
+ async def release(self) -> None:
+ ...
+
+ @abc.abstractmethod
+ async def close(self) -> None:
+ ...
+
+ @abc.abstractmethod
+ async def ensure_connection(
+ self,
+ errback: Callable = None,
+ max_retries: int = None,
+ interval_start: float = 2.0,
+ interval_step: float = 2.0,
+ interval_max: float = 30.0,
+ callback: Callable = None,
+ reraise_as_library_errors: bool = True) -> 'ClientT':
+ ...
+
+ @abc.abstractmethod
+ def completes_cycle(self, retries: int) -> bool:
+ ...
+
+ @abc.abstractmethod
+ async def ensure(
+ self, obj: Revivable, fun: Callable,
+ errback: Callable = None,
+ max_retries: int = None,
+ interval_start: float = 1.0,
+ interval_step: float = 1.0,
+ interval_max: float = 1.0,
+ on_revive: Callable = None) -> Any:
+ ...
+
+ @abc.abstractmethod
+ def autoretry(self, fun: Callable,
+ channel: ChannelT = None, **ensure_options) -> Any:
+ ...
+
+ @abc.abstractmethod
+ def create_transport(self) -> TransportT:
+ ...
+
+ @abc.abstractmethod
+ def get_transport_cls(self) -> type:
+ ...
+
+ @abc.abstractmethod
+ def clone(self, **kwargs) -> 'ClientT':
+ ...
+
+ @abc.abstractmethod
+ def get_heartbeat_interval(self) -> float:
+ ...
+
+ @abc.abstractmethod
+ def info(self) -> Mapping[str, Any]:
+ ...
+
+ @abc.abstractmethod
+ def as_uri(
+ self,
+ *,
+ include_password: bool = False,
+ mask: str = '**',
+ getfields: Callable = None) -> str:
+ ...
+
+ @abc.abstractmethod
+ def Pool(self, limit: int = None, **kwargs) -> 'ResourceT':
+ ...
+
+ @abc.abstractmethod
+ def ChannelPool(self, limit: int = None, **kwargs) -> 'ResourceT':
+ ...
+
+ @abc.abstractmethod
+ def Producer(self, channel: ChannelT = None,
+ *args, **kwargs) -> 'ProducerT':
+ ...
+
+ @abc.abstractmethod
+ def Consumer(
+ self,
+ queues: Sequence['QueueT'] = None,
+ channel: ChannelT = None,
+ *args, **kwargs) -> 'ConsumerT':
+ ...
+
+ @abc.abstractmethod
+ def SimpleQueue(
+ self, name: str,
+ no_ack: bool = None,
+ queue_opts: Mapping = None,
+ exchange_opts: Mapping = None,
+ channel: ChannelT = None,
+ **kwargs) -> 'SimpleQueueT':
+ ...
+
+
+class EntityT(Hashable, Revivable, metaclass=_AbstractClass):
+
+ @property
+ @abc.abstractmethod
+ def can_cache_declaration(self) -> bool:
+ ...
+
+ @property
+ @abc.abstractmethod
+ def is_bound(self) -> bool:
+ ...
+
+ @property
+ @abc.abstractmethod
+ def channel(self) -> ChannelT:
+ ...
+
+ @abc.abstractmethod
+ def __call__(self, channel: ChannelT) -> 'EntityT':
+ ...
+
+ @abc.abstractmethod
+ def bind(self, channel: ChannelT) -> 'EntityT':
+ ...
+
+ @abc.abstractmethod
+ def maybe_bind(self, channel: Optional[ChannelT]) -> 'EntityT':
+ ...
+
+ @abc.abstractmethod
+ def when_bound(self) -> None:
+ ...
+
+ @abc.abstractmethod
+ def as_dict(self, recurse: bool = False) -> Dict:
+ ...
+
+
+ChannelArgT = TypeVar('ChannelArgT', ChannelT, ClientT)
+
+
+class ExchangeT(EntityT, metaclass=_AbstractClass):
+
+ name: str
+ type: str
+ channel: ChannelT
+ arguments: Mapping
+ durable: bool = True
+ auto_delete: bool = False
+ delivery_mode: int
+ no_declare: bool = False
+ passive: bool = False
+
+ def __init__(
+ self,
+ name: str = '',
+ type: str = '',
+ channel: ChannelArgT = None,
+ **kwargs) -> None:
+ ...
+
+ @abc.abstractmethod
+ async def declare(
+ self,
+ nowait: bool = False,
+ passive: bool = None,
+ channel: ChannelT = None) -> None:
+ ...
+
+ @abc.abstractmethod
+ async def bind_to(
+ self,
+ exchange: Union[str, 'ExchangeT'] = '',
+ routing_key: str = '',
+ arguments: Mapping = None,
+ nowait: bool = False,
+ channel: ChannelT = None,
+ **kwargs) -> None:
+ ...
+
+ @abc.abstractmethod
+ async def unbind_from(
+ self,
+ source: Union[str, 'ExchangeT'] = '',
+ routing_key: str = '',
+ nowait: bool = False,
+ arguments: Mapping = None,
+ channel: ChannelT = None) -> None:
+ ...
+
+ @abc.abstractmethod
+ async def publish(
+ self, message: Union[MessageT, str, bytes],
+ routing_key: str = None,
+ mandatory: bool = False,
+ immediate: bool = False,
+ exchange: str = None) -> None:
+ ...
+
+ @abc.abstractmethod
+ async def delete(
+ self,
+ if_unused: bool = False,
+ nowait: bool = False) -> None:
+ ...
+
+ @abc.abstractmethod
+ def binding(
+ self,
+ routing_key: str = '',
+ arguments: Mapping = None,
+ unbind_arguments: Mapping = None) -> BindingT:
+ ...
+
+ @abc.abstractmethod
+ def Message(
+ self, body: Any,
+ delivery_mode: Union[str, int] = None,
+ properties: Mapping = None,
+ **kwargs) -> Any:
+ ...
+
+
+class QueueT(EntityT, metaclass=_AbstractClass):
+
+ name: str = ''
+ exchange: ExchangeT
+ routing_key: str = ''
+ alias: str
+
+ bindings: Sequence['BindingT'] = None
+
+ durable: bool = True
+ exclusive: bool = False
+ auto_delete: bool = False
+ no_ack: bool = False
+ no_declare: bool = False
+ expires: float
+ message_ttl: float
+ max_length: int
+ max_length_bytes: int
+ max_priority: int
+
+ queue_arguments: Mapping
+ binding_arguments: Mapping
+ consumer_arguments: Mapping
+
+ def __init__(
+ self,
+ name: str = '',
+ exchange: ExchangeT = None,
+ routing_key: str = '',
+ channel: ChannelT = None,
+ bindings: Sequence['BindingT'] = None,
+ on_declared: Callable = None,
+ **kwargs) -> None:
+ ...
+
+ @abc.abstractmethod
+ async def declare(self,
+ nowait: bool = False,
+ channel: ChannelT = None) -> str:
+ ...
+
+ @abc.abstractmethod
+ async def queue_declare(
+ self,
+ nowait: bool = False,
+ passive: bool = False,
+ channel: ChannelT = None) -> queue_declare_ok_t:
+ ...
+
+ @abc.abstractmethod
+ async def queue_bind(
+ self,
+ nowait: bool = False,
+ channel: ChannelT = None) -> None:
+ ...
+
+ @abc.abstractmethod
+ async def bind_to(
+ self,
+ exchange: Union[str, ExchangeT] = '',
+ routing_key: str = '',
+ arguments: Mapping = None,
+ nowait: bool = False,
+ channel: ChannelT = None) -> None:
+ ...
+
+ @abc.abstractmethod
+ async def get(
+ self,
+ no_ack: bool = None,
+ accept: Set[str] = None) -> MessageT:
+ ...
+
+ @abc.abstractmethod
+ async def purge(self, nowait: bool = False) -> int:
+ ...
+
+ @abc.abstractmethod
+ async def consume(
+ self,
+ consumer_tag: str = '',
+ callback: Callable = None,
+ no_ack: bool = None,
+ nowait: bool = False) -> None:
+ ...
+
+ @abc.abstractmethod
+ async def cancel(self, consumer_tag: str) -> None:
+ ...
+
+ @abc.abstractmethod
+ async def delete(
+ self,
+ if_unused: bool = False,
+ if_empty: bool = False,
+ nowait: bool = False) -> None:
+ ...
+
+ @abc.abstractmethod
+ async def queue_unbind(
+ self,
+ arguments: Mapping = None,
+ nowait: bool = False,
+ channel: ChannelT = None) -> None:
+ ...
+
+ @abc.abstractmethod
+ async def unbind_from(
+ self,
+ exchange: str = '',
+ routing_key: str = '',
+ arguments: Mapping = None,
+ nowait: bool = False,
+ channel: ChannelT = None) -> None:
+ ...
+
+ @classmethod
+ @abc.abstractmethod
+ def from_dict(
+ self, queue: str,
+ exchange: str = None,
+ exchange_type: str = None,
+ binding_key: str = None,
+ routing_key: str = None,
+ delivery_mode: Union[int, str] = None,
+ bindings: Sequence = None,
+ durable: bool = None,
+ queue_durable: bool = None,
+ exchange_durable: bool = None,
+ auto_delete: bool = None,
+ queue_auto_delete: bool = None,
+ exchange_auto_delete: bool = None,
+ exchange_arguments: Mapping = None,
+ queue_arguments: Mapping = None,
+ binding_arguments: Mapping = None,
+ consumer_arguments: Mapping = None,
+ exclusive: bool = None,
+ no_ack: bool = None,
+ **options) -> 'QueueT':
+ ...
+
+
+class BindingT(metaclass=_AbstractClass):
+
+ exchange: ExchangeT
+ routing_key: str
+ arguments: Mapping
+ unbind_arguments: Mapping
+
+ def __init__(
+ self,
+ exchange: ExchangeT = None,
+ routing_key: str = '',
+ arguments: Mapping = None,
+ unbind_arguments: Mapping = None) -> None:
+ ...
+
+ def declare(self, channel: ChannelT, nowait: bool = False) -> None:
+ ...
+
+ def bind(
+ self, entity: EntityT,
+ nowait: bool = False,
+ channel: ChannelT = None) -> None:
+ ...
+
+ def unbind(
+ self, entity: EntityT,
+ nowait: bool = False,
+ channel: ChannelT = None) -> None:
+ ...
+
+
+class ConsumerT(Revivable, metaclass=_AbstractClass):
+
+ ContentDisallowed: type
+
+ channel: ChannelT
+ queues: Sequence[QueueT]
+ no_ack: bool
+ auto_declare: bool = True
+ callbacks: Sequence[Callable]
+ on_message: Callable
+ on_decode_error: Callable
+ accept: Set[str]
+ prefetch_count: int = None
+ tag_prefix: str
+
+ @property
+ @abc.abstractmethod
+ def connection(self) -> ClientT:
+ ...
+
+ def __init__(
+ self,
+ channel: ChannelT,
+ queues: Sequence[QueueT] = None,
+ no_ack: bool = None,
+ auto_declare: bool = None,
+ callbacks: Sequence[Callable] = None,
+ on_decode_error: Callable = None,
+ on_message: Callable = None,
+ accept: Sequence[str] = None,
+ prefetch_count: int = None,
+ tag_prefix: str = None) -> None:
+ ...
+
+ @abc.abstractmethod
+ async def declare(self) -> None:
+ ...
+
+ @abc.abstractmethod
+ def register_callback(self, callback: Callable) -> None:
+ ...
+
+ @abc.abstractmethod
+ def __enter__(self) -> 'ConsumerT':
+ ...
+
+ @abc.abstractmethod
+ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
+ ...
+
+ @abc.abstractmethod
+ async def __aenter__(self) -> 'ConsumerT':
+ ...
+
+ @abc.abstractmethod
+ async def __aexit__(self, *exc_info) -> None:
+ ...
+
+ @abc.abstractmethod
+ async def add_queue(self, queue: QueueT) -> QueueT:
+ ...
+
+ @abc.abstractmethod
+ async def consume(self, no_ack: bool = None) -> None:
+ ...
+
+ @abc.abstractmethod
+ async def cancel(self) -> None:
+ ...
+
+ @abc.abstractmethod
+ async def cancel_by_queue(self, queue: Union[QueueT, str]) -> None:
+ ...
+
+ @abc.abstractmethod
+ def consuming_from(self, queue: Union[QueueT, str]) -> bool:
+ ...
+
+ @abc.abstractmethod
+ async def purge(self) -> int:
+ ...
+
+ @abc.abstractmethod
+ async def flow(self, active: bool) -> None:
+ ...
+
+ @abc.abstractmethod
+ async def qos(self,
+ prefetch_size: int = 0,
+ prefetch_count: int = 0,
+ apply_global: bool = False) -> None:
+ ...
+
+ @abc.abstractmethod
+ async def recover(self, requeue: bool = False) -> None:
+ ...
+
+ @abc.abstractmethod
+ async def receive(self, body: Any, message: MessageT) -> None:
+ ...
+
+
+class ProducerT(Revivable, metaclass=_AbstractClass):
+
+ exchange: ExchangeT
+ routing_key: str = ''
+ serializer: str
+ compression: str
+ auto_declare: bool = True
+ on_return: Callable = None
+
+ __connection__: ClientT
+
+ channel: ChannelT
+
+ @property
+ @abc.abstractmethod
+ def connection(self) -> ClientT:
+ ...
+
+ def __init__(self,
+ channel: ChannelArgT,
+ exchange: ExchangeT = None,
+ routing_key: str = None,
+ serializer: str = None,
+ auto_declare: bool = None,
+ compression: str = None,
+ on_return: Callable = None) -> None:
+ ...
+
+ @abc.abstractmethod
+ async def declare(self) -> None:
+ ...
+
+ @abc.abstractmethod
+ async def maybe_declare(self, entity: EntityT,
+ retry: bool = False, **retry_policy) -> None:
+ ...
+
+ @abc.abstractmethod
+ async def publish(
+ self, body: Any,
+ routing_key: str = None,
+ delivery_mode: Union[int, str] = None,
+ mandatory: bool = False,
+ immediate: bool = False,
+ priority: int = 0,
+ content_type: str = None,
+ content_encoding: str = None,
+ serializer: str = None,
+ headers: Mapping = None,
+ compression: str = None,
+ exchange: Union[ExchangeT, str] = None,
+ retry: bool = False,
+ retry_policy: Mapping = None,
+ declare: Sequence[EntityT] = None,
+ expiration: Number = None,
+ **properties) -> None:
+ ...
+
+ @abc.abstractmethod
+ def __enter__(self) -> 'ProducerT':
+ ...
+
+ @abc.abstractmethod
+ async def __aenter__(self) -> 'ProducerT':
+ ...
+
+ @abc.abstractmethod
+ def __exit__(self, *exc_info) -> None:
+ ...
+
+ @abc.abstractmethod
+ async def __aexit__(self, *exc_info) -> None:
+ ...
+
+ @abc.abstractmethod
+ def release(self) -> None:
+ ...
+
+ @abc.abstractmethod
+ def close(self) -> None:
+ ...
+
+
+class MessageT(_MessageT):
+
+ MessageStateError: type
+
+ _state: str
+ channel: ChannelT = None
+ delivery_tag: str
+ content_type: str
+ content_encoding: str
+ delivery_info: Mapping
+ headers: Mapping
+ properties: Mapping
+ accept: Set[str]
+ body: Any
+ errors: List[Any]
+
+ @property
+ @abc.abstractmethod
+ def acknowledged(self) -> bool:
+ ...
+
+ @property
+ @abc.abstractmethod
+ def payload(self) -> Any:
+ ...
+
+ def __init__(
+ self,
+ body: Any = None,
+ delivery_tag: str = None,
+ content_type: str = None,
+ content_encoding: str = None,
+ delivery_info: Mapping = None,
+ properties: Mapping = None,
+ headers: Mapping = None,
+ postencode: str = None,
+ accept: Set[str] = None,
+ channel: ChannelT = None,
+ **kwargs) -> None:
+ ...
+
+ @abc.abstractmethod
+ async def ack(self, multiple: bool = False) -> None:
+ ...
+
+ @abc.abstractmethod
+ async def ack_log_error(
+ self, logger: logging.Logger, errors: Tuple[type, ...],
+ multiple: bool = False) -> None:
+ ...
+
+ @abc.abstractmethod
+ async def reject_log_error(
+ self, logger: logging.Logger, errors: Tuple[type, ...],
+ requeue: bool = False) -> None:
+ ...
+
+ @abc.abstractmethod
+ async def reject(self, requeue: bool = False) -> None:
+ ...
+
+ @abc.abstractmethod
+ async def requeue(self) -> None:
+ ...
+
+ @abc.abstractmethod
+ def decode(self) -> Any:
+ ...
+
+
+class ResourceT(metaclass=_AbstractClass):
+
+ LimitedExceeeded: type
+
+ close_after_fork: bool = False
+
+ @property
+ @abc.abstractmethod
+ def limit(self) -> int:
+ ...
+
+ def __init__(self,
+ limit: int = None,
+ preload: int = None,
+ close_after_fork: bool = None) -> None:
+ ...
+
+ @abc.abstractmethod
+ def setup(self) -> None:
+ ...
+
+ @abc.abstractmethod
+ def acquire(self, block: bool = False, timeout: int = None):
+ ...
+
+ @abc.abstractmethod
+ def prepare(self, resource: Any) -> Any:
+ ...
+
+ @abc.abstractmethod
+ def close_resource(self, resource: Any) -> None:
+ ...
+
+ @abc.abstractmethod
+ def release_resource(self, resource: Any) -> None:
+ ...
+
+ @abc.abstractmethod
+ def replace(self, resource: Any) -> None:
+ ...
+
+ @abc.abstractmethod
+ def release(self, resource: Any) -> None:
+ ...
+
+ @abc.abstractmethod
+ def collect_resource(self, resource: Any) -> None:
+ ...
+
+ @abc.abstractmethod
+ def force_close_all(self) -> None:
+ ...
+
+ @abc.abstractmethod
+ def resize(
+ self, limit: int,
+ force: bool = False,
+ ignore_errors: bool = False,
+ reset: bool = False) -> None:
+ ...
+
+
+class SimpleQueueT(metaclass=_AbstractClass):
+
+ Empty: type
+
+ channel: ChannelT
+ producer: ProducerT
+ consumer: ConsumerT
+ no_ack: bool = False
+ queue: QueueT
+ buffer: deque
+
+ def __init__(
+ self,
+ channel: ChannelArgT,
+ name: str,
+ no_ack: bool = None,
+ queue_opts: Mapping = None,
+ exchange_opts: Mapping = None,
+ serializer: str = None,
+ compression: str = None,
+ **kwargs) -> None:
+ ...
+
+ @abc.abstractmethod
+ def __enter__(self) -> 'SimpleQueueT':
+ ...
+
+ @abc.abstractmethod
+ def __exit__(self, *exc_info) -> None:
+ ...
+
+ @abc.abstractmethod
+ def get(self, block: bool = True, timeout: float = None) -> MessageT:
+ ...
+
+ @abc.abstractmethod
+ def get_nowait(self) -> MessageT:
+ ...
+
+ @abc.abstractmethod
+ def put(self, message: Any,
+ serializer: str = None,
+ headers: Mapping = None,
+ compression: str = None,
+ routing_key: str = None,
+ **kwargs) -> None:
+ ...
+
+ @abc.abstractmethod
+ def clear(self) -> int:
+ ...
+
+ @abc.abstractmethod
+ def qsize(self) -> int:
+ ...
+
+ @abc.abstractmethod
+ def close(self) -> None:
+ ...
+
+ @abc.abstractmethod
+ def __len__(self) -> int:
+ ...
+
+ @abc.abstractmethod
+ def __bool__(self) -> bool:
+ ...
diff --git a/kombu/utils/abstract.py b/kombu/utils/abstract.py
deleted file mode 100644
index 27a27d45..00000000
--- a/kombu/utils/abstract.py
+++ /dev/null
@@ -1,49 +0,0 @@
-import abc
-
-from typing import Any
-from typing import Set # noqa
-
-
-def _hasattr(C: Any, attr: str) -> bool:
- return any(attr in B.__dict__ for B in C.__mro__)
-
-
-class _AbstractClass(object, metaclass=abc.ABCMeta):
- __required_attributes__ = frozenset() # type: frozenset
-
- @classmethod
- def _subclasshook_using(cls, parent: Any, C: Any):
- return (
- cls is parent and
- all(_hasattr(C, attr) for attr in cls.__required_attributes__)
- ) or NotImplemented
-
- @classmethod
- def register(cls, other: Any) -> Any:
- # we override `register` to return other for use as a decorator.
- type(cls).register(cls, other)
- return other
-
-
-class Connection(_AbstractClass):
- ...
-
-
-class Entity(_AbstractClass):
- ...
-
-
-class Consumer(_AbstractClass):
- ...
-
-
-class Producer(_AbstractClass):
- ...
-
-
-class Messsage(_AbstractClass):
- ...
-
-
-class Resource(_AbstractClass):
- ...
diff --git a/kombu/utils/amq_manager.py b/kombu/utils/amq_manager.py
index fcf1c6d5..c8cb8508 100644
--- a/kombu/utils/amq_manager.py
+++ b/kombu/utils/amq_manager.py
@@ -1,9 +1,9 @@
"""AMQP Management API utilities."""
from typing import Any, Union
-from . import abstract
+from kombu.types import ClientT
-def get_manager(client: abstract.Connection,
+def get_manager(client: ClientT,
hostname: str = None,
port: Union[int, str] = None,
userid: str = None,
diff --git a/kombu/utils/collections.py b/kombu/utils/collections.py
index 0ee27399..b202470b 100644
--- a/kombu/utils/collections.py
+++ b/kombu/utils/collections.py
@@ -1,4 +1,5 @@
"""Custom maps, sequences, etc."""
+from typing import Any, Union
class HashedSeq(list):
@@ -10,15 +11,15 @@ class HashedSeq(list):
__slots__ = 'hashvalue'
- def __init__(self, *seq):
+ def __init__(self, *seq) -> None:
self[:] = seq
self.hashvalue = hash(seq)
- def __hash__(self):
+ def __hash__(self) -> int:
return self.hashvalue
-def eqhash(o):
+def eqhash(o: Any) -> Union[int, HashedSeq]:
"""Call ``obj.__eqhash__``."""
try:
return o.__eqhash__()
@@ -29,14 +30,17 @@ def eqhash(o):
class EqualityDict(dict):
"""Dict using the eq operator for keying."""
- def __getitem__(self, key):
+ def __getitem__(self, key: Any) -> Any:
h = eqhash(key)
if h not in self:
return self.__missing__(key)
return dict.__getitem__(self, h)
- def __setitem__(self, key, value):
- return dict.__setitem__(self, eqhash(key), value)
+ def __setitem__(self, key: Any, value: Any) -> None:
+ dict.__setitem__(self, eqhash(key), value)
- def __delitem__(self, key):
- return dict.__delitem__(self, eqhash(key))
+ def __delitem__(self, key: Any) -> None:
+ dict.__delitem__(self, eqhash(key))
+
+ def __missing__(self, key: Any) -> Any:
+ raise KeyError(key)
diff --git a/kombu/utils/compat.py b/kombu/utils/compat.py
index c9390ec3..b12dd41e 100644
--- a/kombu/utils/compat.py
+++ b/kombu/utils/compat.py
@@ -3,6 +3,8 @@ import numbers
import sys
from functools import wraps
from contextlib import contextmanager
+from typing import Any, Callable, Iterator, Optional, Tuple, Union
+from .typing import SupportsFileno
try:
from io import UnsupportedOperation
@@ -21,17 +23,17 @@ except ImportError: # pragma: no cover
_environment = None
-def coro(gen):
+def coro(gen: Callable) -> Callable:
"""Decorator to mark generator as co-routine."""
@wraps(gen)
- def wind_up(*args, **kwargs):
+ def wind_up(*args, **kwargs) -> Any:
it = gen(*args, **kwargs)
next(it)
return it
return wind_up
-def _detect_environment():
+def _detect_environment() -> str:
# ## -eventlet-
if 'eventlet' in sys.modules:
try:
@@ -57,14 +59,15 @@ def _detect_environment():
return 'default'
-def detect_environment():
+def detect_environment() -> str:
"""Detect the current environment: default, eventlet, or gevent."""
global _environment
if _environment is None:
_environment = _detect_environment()
return _environment
-def entrypoints(namespace):
+
+def entrypoints(namespace: str) -> Iterator[Tuple[str, Any]]:
"""Return setuptools entrypoints for namespace."""
try:
from pkg_resources import iter_entry_points
@@ -73,14 +76,14 @@ def entrypoints(namespace):
return ((ep, ep.load()) for ep in iter_entry_points(namespace))
-def fileno(f):
+def fileno(f: Union[SupportsFileno, numbers.Integral]) -> numbers.Integral:
"""Get fileno from file-like object."""
if isinstance(f, numbers.Integral):
return f
return f.fileno()
-def maybe_fileno(f):
+def maybe_fileno(f: Any) -> Optional[numbers.Integral]:
"""Get object fileno, or :const:`None` if not defined."""
try:
return fileno(f)
diff --git a/kombu/utils/debug.py b/kombu/utils/debug.py
index 6bc59020..df98a04d 100644
--- a/kombu/utils/debug.py
+++ b/kombu/utils/debug.py
@@ -1,10 +1,7 @@
"""Debugging support."""
import logging
-
-from typing import Any, Optional, Sequence, Union
-
+from typing import Any, Sequence, Union
from vine.utils import wraps
-
from kombu.log import get_logger
__all__ = ['setup_logging', 'Logwrapped']
@@ -29,8 +26,8 @@ class Logwrapped:
__ignore = ('__enter__', '__exit__')
def __init__(self, instance: Any,
- logger: Optional[LoggerArg]=None,
- ident: Optional[str]=None) -> None:
+ logger: LoggerArg = None,
+ ident: str = None) -> None:
self.instance = instance
self.logger = get_logger(logger) # type: logging.Logger
self.ident = ident
diff --git a/kombu/utils/div.py b/kombu/utils/div.py
index 21b4c59a..0535c30f 100644
--- a/kombu/utils/div.py
+++ b/kombu/utils/div.py
@@ -1,9 +1,13 @@
"""Div. Utilities."""
-from .encoding import default_encode
import sys
+from typing import Any, Callable, IO
+from .encoding import default_encode
-def emergency_dump_state(state, open_file=open, dump=None, stderr=None):
+def emergency_dump_state(state: Any,
+ open_file: Callable = open,
+ dump: Callable = None,
+ stderr: IO = None):
"""Dump message state to stdout or file."""
from pprint import pformat
from tempfile import mktemp
diff --git a/kombu/utils/encoding.py b/kombu/utils/encoding.py
index a68fef25..8faf32a1 100644
--- a/kombu/utils/encoding.py
+++ b/kombu/utils/encoding.py
@@ -8,7 +8,6 @@ applications without crashing from the infamous
import sys
import traceback
-
from typing import Any, AnyStr, IO, Optional
str_t = str
@@ -67,7 +66,7 @@ def default_encode(obj: Any) -> Any:
return obj
-def safe_str(s: Any, errors: str='replace') -> str:
+def safe_str(s: Any, errors: str = 'replace') -> str:
"""Safe form of str(), void of unicode errors."""
s = bytes_to_str(s)
if not isinstance(s, (str, bytes)):
@@ -75,7 +74,7 @@ def safe_str(s: Any, errors: str='replace') -> str:
return _safe_str(s, errors)
-def _safe_str(s: Any, errors: str='replace', file: IO=None) -> str:
+def _safe_str(s: Any, errors: str = 'replace', file: IO = None) -> str:
if isinstance(s, str):
return s
try:
@@ -85,7 +84,7 @@ def _safe_str(s: Any, errors: str='replace', file: IO=None) -> str:
type(s), exc, '\n'.join(traceback.format_stack()))
-def safe_repr(o: Any, errors='replace') -> str:
+def safe_repr(o: Any, errors: str = 'replace') -> str:
"""Safe form of repr, void of Unicode errors."""
try:
return repr(o)
diff --git a/kombu/utils/eventio.py b/kombu/utils/eventio.py
index 8be03494..d10f5868 100644
--- a/kombu/utils/eventio.py
+++ b/kombu/utils/eventio.py
@@ -8,7 +8,7 @@ from numbers import Integral
from typing import Any, Callable, Optional, Sequence, IO, cast
from typing import Set, Tuple # noqa
from . import fileno
-from .typing import Fd, Timeout
+from .typing import Fd
from .compat import detect_environment
__all__ = ['poll']
@@ -81,7 +81,7 @@ class _epoll(BasePoller):
except (PermissionError, FileNotFoundError):
pass
- def poll(self, timeout: Timeout) -> Optional[Sequence]:
+ def poll(self, timeout: float) -> Optional[Sequence]:
try:
return self._epoll.poll(timeout if timeout is not None else -1)
except Exception as exc:
@@ -148,7 +148,7 @@ class _kqueue(BasePoller):
except ValueError:
pass
- def poll(self, timeout: Timeout) -> Sequence:
+ def poll(self, timeout: float) -> Sequence:
try:
kevents = self._kcontrol(None, 1000, timeout)
except Exception as exc:
@@ -212,7 +212,7 @@ class _poll(BasePoller):
self._quick_unregister(fd)
return fd
- def poll(self, timeout: Timeout,
+ def poll(self, timeout: float,
round: Callable=math.ceil,
POLLIN: int=POLLIN, POLLOUT: int=POLLOUT, POLLERR: int=POLLERR,
READ: int=READ, WRITE: int=WRITE, ERR: int=ERR,
@@ -284,7 +284,7 @@ class _select(BasePoller):
self._wfd.discard(fd)
self._efd.discard(fd)
- def poll(self, timeout: Timeout) -> Sequence:
+ def poll(self, timeout: float) -> Sequence:
try:
read, write, error = _selectf(
cast(Sequence, self._rfd),
diff --git a/kombu/utils/functional.py b/kombu/utils/functional.py
index 85feffa5..8d631809 100644
--- a/kombu/utils/functional.py
+++ b/kombu/utils/functional.py
@@ -6,9 +6,11 @@ from itertools import count, repeat
from time import sleep
from typing import (
Any, Callable, Dict, Iterable, Iterator,
- KeysView, Mapping, Optional, Sequence, Tuple,
+ KeysView, ItemsView, Mapping, Optional, Sequence, Tuple, ValuesView,
)
+from amqp.types import ChannelT
from vine.utils import wraps
+from ..types import ClientT
from .encoding import safe_repr as _safe_repr
__all__ = [
@@ -17,27 +19,36 @@ __all__ = [
]
KEYWORD_MARK = object()
-
MemoizeKeyFun = Callable[[Sequence, Mapping], Any]
class ChannelPromise(object):
+ __value__: ChannelT
+
+ def __init__(self, connection: ClientT) -> None:
+ self.__connection__ = connection
- def __init__(self, contract):
- self.__contract__ = contract
+ def __call__(self) -> Any:
+ try:
+ return self.__value__
+ except AttributeError:
+ value = self.__value__ = self.__connection__.default_channel
+ return value
- def __call__(self):
+ async def resolve(self):
try:
return self.__value__
except AttributeError:
- value = self.__value__ = self.__contract__()
+ await self.__connection__.connect()
+ value = self.__value__ = self.__connection__.default_channel
+ await value.open()
return value
- def __repr__(self):
+ def __repr__(self) -> str:
try:
return repr(self.__value__)
except AttributeError:
- return '<promise: 0x{0:x}>'.format(id(self.__contract__))
+ return '<promise: 0x{0:x}>'.format(id(self.__connection__))
class LRUCache(UserDict):
@@ -53,7 +64,7 @@ class LRUCache(UserDict):
def __init__(self, limit: int = None) -> None:
self.limit = limit
self.mutex = threading.RLock()
- self.data = OrderedDict() # type: OrderedDict
+ self.data: OrderedDict = OrderedDict()
def __getitem__(self, key: Any) -> Any:
with self.mutex:
@@ -69,7 +80,7 @@ class LRUCache(UserDict):
for _ in range(len(data) - limit):
data.popitem(last=False)
- def popitem(self, last: bool=True) -> Any:
+ def popitem(self, last: bool = True) -> Any:
with self.mutex:
return self.data.popitem(last)
@@ -83,7 +94,7 @@ class LRUCache(UserDict):
def __iter__(self) -> Iterator:
return iter(self.data)
- def items(self) -> Iterator[Tuple[Any, Any]]:
+ def items(self) -> ItemsView[Any, Any]:
with self.mutex:
for k in self:
try:
@@ -91,7 +102,7 @@ class LRUCache(UserDict):
except KeyError: # pragma: no cover
pass
- def values(self) -> Iterator[Any]:
+ def values(self) -> ValuesView[Any]:
with self.mutex:
for k in self:
try:
@@ -104,7 +115,7 @@ class LRUCache(UserDict):
with self.mutex:
return self.data.keys()
- def incr(self, key: Any, delta: int=1) -> Any:
+ def incr(self, key: Any, delta: int = 1) -> Any:
with self.mutex:
# this acts as memcached does- store as a string, but return a
# integer as long as it exists and we can cast it
@@ -122,9 +133,9 @@ class LRUCache(UserDict):
self.mutex = threading.RLock()
-def memoize(maxsize: Optional[int]=None,
- keyfun: Optional[MemoizeKeyFun]=None,
- Cache: Any=LRUCache) -> Callable:
+def memoize(maxsize: int = None,
+ keyfun: MemoizeKeyFun = None,
+ Cache: Any = LRUCache) -> Callable:
"""Decorator to cache function return value."""
def _memoize(fun: Callable) -> Callable:
@@ -189,10 +200,10 @@ class lazy:
def __repr__(self) -> str:
return repr(self())
- def __eq__(self, rhs) -> bool:
+ def __eq__(self, rhs: Any) -> bool:
return self() == rhs
- def __ne__(self, rhs) -> bool:
+ def __ne__(self, rhs: Any) -> bool:
return self() != rhs
def __deepcopy__(self, memo: Dict) -> Any:
@@ -200,8 +211,10 @@ class lazy:
return self
def __reduce__(self) -> Any:
- return (self.__class__, (self._fun,), {'_args': self._args,
- '_kwargs': self._kwargs})
+ return (self.__class__, (self._fun,), {
+ '_args': self._args,
+ '_kwargs': self._kwargs,
+ })
def maybe_evaluate(value: Any) -> Any:
@@ -212,8 +225,9 @@ def maybe_evaluate(value: Any) -> Any:
def is_list(l: Any,
- scalars: tuple=(Mapping, str),
- iters: tuple=(Iterable,)) -> bool:
+ *,
+ scalars: tuple = (Mapping, str),
+ iters: tuple = (Iterable,)) -> bool:
"""Return true if the object is iterable.
Note:
@@ -222,18 +236,19 @@ def is_list(l: Any,
return isinstance(l, iters) and not isinstance(l, scalars or ())
-def maybe_list(l: Any, scalars: tuple=(Mapping, str)) -> Optional[Sequence]:
+def maybe_list(l: Any, *,
+ scalars: tuple = (Mapping, str)) -> Optional[Sequence]:
"""Return list of one element if ``l`` is a scalar."""
return l if l is None or is_list(l, scalars) else [l]
-def dictfilter(d: Optional[Mapping]=None, **kw) -> Mapping:
+def dictfilter(d: Mapping = None, **kw) -> Mapping:
"""Remove all keys from dict ``d`` whose value is :const:`None`."""
d = kw if d is None else (dict(d, **kw) if kw else d)
return {k: v for k, v in d.items() if v is not None}
-def shufflecycle(it):
+def shufflecycle(it: Sequence) -> Iterator:
it = list(it) # don't modify callers list
shuffle = random.shuffle
for _ in repeat(None):
@@ -241,7 +256,10 @@ def shufflecycle(it):
yield it[0]
-def fxrange(start=1.0, stop=None, step=1.0, repeatlast=False):
+def fxrange(start: float = 1.0,
+ stop: float = None,
+ step: float = 1.0,
+ repeatlast: bool = False) -> Iterator[float]:
cur = start * 1.0
while 1:
if not stop or cur <= stop:
@@ -253,7 +271,10 @@ def fxrange(start=1.0, stop=None, step=1.0, repeatlast=False):
yield cur - step
-def fxrangemax(start=1.0, stop=None, step=1.0, max=100.0):
+def fxrangemax(start: float = 1.0,
+ stop: float = None,
+ step: float = 1.0,
+ max: float = 100.0) -> Iterator[float]:
sum_, cur = 0, start * 1.0
while 1:
if sum_ >= max:
@@ -266,9 +287,18 @@ def fxrangemax(start=1.0, stop=None, step=1.0, max=100.0):
sum_ += cur
-def retry_over_time(fun, catch, args=[], kwargs={}, errback=None,
- max_retries=None, interval_start=2, interval_step=2,
- interval_max=30, callback=None):
+async def retry_over_time(
+ fun: Callable,
+ catch: Tuple[Any, ...],
+ args: Sequence = [],
+ kwargs: Mapping[str, Any] = {},
+ *,
+ errback: Callable = None,
+ max_retries: int = None,
+ interval_start: float = 2.0,
+ interval_step: float = 2.0,
+ interval_max: float = 30.0,
+ callback: Callable = None) -> Any:
"""Retry the function over and over until max retries is exceeded.
For each retry we sleep a for a while before we try again, this interval
@@ -303,7 +333,7 @@ def retry_over_time(fun, catch, args=[], kwargs={}, errback=None,
interval_step, repeatlast=True)
for retries in count():
try:
- return fun(*args, **kwargs)
+ return await fun(*args, **kwargs)
except catch as exc:
if max_retries and retries >= max_retries:
raise
@@ -320,11 +350,17 @@ def retry_over_time(fun, catch, args=[], kwargs={}, errback=None,
sleep(abs(int(tts) - tts))
-def reprkwargs(kwargs, sep=', ', fmt='{0}={1}'):
+def reprkwargs(kwargs: Mapping,
+ *,
+ sep: str = ', ', fmt: str = '{0}={1}') -> str:
return sep.join(fmt.format(k, _safe_repr(v)) for k, v in kwargs.items())
-def reprcall(name, args=(), kwargs={}, sep=', '):
+def reprcall(name: str,
+ args: Sequence = (),
+ kwargs: Mapping = {},
+ *,
+ sep: str = ', ') -> str:
return '{0}({1}{2}{3})'.format(
name, sep.join(map(_safe_repr, args or ())),
(args and kwargs) and sep or '',
diff --git a/kombu/utils/imports.py b/kombu/utils/imports.py
index 5262b41d..345ec4bb 100644
--- a/kombu/utils/imports.py
+++ b/kombu/utils/imports.py
@@ -1,10 +1,18 @@
"""Import related utilities."""
import importlib
import sys
+from typing import Any, Callable, Mapping
-def symbol_by_name(name, aliases={}, imp=None, package=None,
- sep='.', default=None, **kwargs):
+def symbol_by_name(
+ name: Any,
+ aliases: Mapping[str, str] = {},
+ *,
+ imp: Callable = None,
+ package: str = None,
+ sep: str = '.',
+ default: Any = None,
+ **kwargs) -> Any:
"""Get symbol by qualified name.
The name should be the full dot-separated path to the class::
diff --git a/kombu/utils/json.py b/kombu/utils/json.py
index 6668a73e..b3c6dfb5 100644
--- a/kombu/utils/json.py
+++ b/kombu/utils/json.py
@@ -10,7 +10,7 @@ from .typing import AnyBuffer
try:
from django.utils.functional import Promise as DjangoPromise
except ImportError: # pragma: no cover
- class DjangoPromise(object): # noqa
+ class DjangoPromise(object): # noqa, type: ignore
"""Dummy object."""
try:
@@ -24,20 +24,21 @@ except ImportError: # pragma: no cover
else:
from simplejson.decoder import JSONDecodeError as _DecodeError
-_encoder_cls = type(json._default_encoder)
-_default_encoder = None # ... set to JSONEncoder below.
+_encoder_cls: type = type(json._default_encoder) # type: ignore
+_default_encoder: type = None # ... set to JSONEncoder below.
class JSONEncoder(_encoder_cls):
"""Kombu custom json encoder."""
- def default(self, o,
+ def default(self, o: Any,
+ *,
dates=(datetime.datetime, datetime.date),
times=(datetime.time,),
textual=(decimal.Decimal, uuid.UUID, DjangoPromise),
isinstance=isinstance,
datetime=datetime.datetime,
- str=str):
+ str=str) -> Any:
reducer = getattr(o, '__json__', None)
if reducer is not None:
o = reducer()
@@ -59,6 +60,7 @@ _default_encoder = JSONEncoder
def dumps(s: Any,
+ *,
_dumps: Callable = json.dumps,
cls: Any = None,
default_kwargs: Dict = _json_extra_kwargs,
@@ -68,7 +70,7 @@ def dumps(s: Any,
**dict(default_kwargs, **kwargs))
-def loads(s: AnyBuffer, _loads: Callable = json.loads) -> Any:
+def loads(s: AnyBuffer, *, _loads: Callable = json.loads) -> Any:
"""Deserialize json from string."""
# None of the json implementations supports decoding from
# a buffer/memoryview, or even reading from a stream
diff --git a/kombu/utils/limits.py b/kombu/utils/limits.py
index ee9c8003..2c34f0de 100644
--- a/kombu/utils/limits.py
+++ b/kombu/utils/limits.py
@@ -1,8 +1,7 @@
"""Token bucket implementation for rate limiting."""
from collections import deque
from time import monotonic
-from typing import Any, Optional
-
+from typing import Any
from .typing import Float
__all__ = ['TokenBucket']
@@ -24,21 +23,21 @@ class TokenBucket:
"""
#: The rate in tokens/second that the bucket will be refilled.
- fill_rate = None
+ fill_rate: float = None
#: Maximum number of tokens in the bucket.
- capacity = 1.0 # type: float
+ capacity: float = 1.0
#: Timestamp of the last time a token was taken out of the bucket.
- timestamp = None # type: Optional[float]
+ timestamp: float = None
- def __init__(self, fill_rate: Optional[Float],
- capacity: Float=1.0) -> None:
+ def __init__(self, fill_rate: Float = None,
+ capacity: Float = 1.0) -> None:
self.capacity = float(capacity)
self._tokens = capacity
self.fill_rate = float(fill_rate)
self.timestamp = monotonic()
- self.contents = deque() # type: deque
+ self.contents = deque()
def add(self, item: Any) -> None:
self.contents.append(item)
@@ -49,7 +48,7 @@ class TokenBucket:
def clear_pending(self) -> None:
self.contents.clear()
- def can_consume(self, tokens: int=1) -> bool:
+ def can_consume(self, tokens: int = 1) -> bool:
"""Check if one or more tokens can be consumed.
Returns:
diff --git a/kombu/utils/scheduling.py b/kombu/utils/scheduling.py
index 74d4f25a..624a06d0 100644
--- a/kombu/utils/scheduling.py
+++ b/kombu/utils/scheduling.py
@@ -1,8 +1,7 @@
"""Scheduling Utilities."""
from itertools import count
-from typing import Any, Callable, Iterable, Optional, Sequence, Union
+from typing import Any, Callable, Iterable, Sequence, Union
from typing import List # noqa
-
from .imports import symbol_by_name
__all__ = [
@@ -24,12 +23,12 @@ class FairCycle:
Arguments:
fun (Callable): Callback to call.
- resources (Sequence[Any]): List of resources.
+ resources (Sequence): List of resources.
predicate (type): Exception predicate.
"""
def __init__(self, fun: Callable, resources: Sequence,
- predicate: Any=Exception) -> None:
+ predicate: Any = Exception) -> None:
self.fun = fun
self.resources = resources
self.predicate = predicate
@@ -72,8 +71,8 @@ class BaseCycle:
class round_robin_cycle(BaseCycle):
"""Iterator that cycles between items in round-robin."""
- def __init__(self, it: Optional[Iterable]=None) -> None:
- self.items = list(it if it is not None else []) # type: List
+ def __init__(self, it: Iterable = None) -> None:
+ self.items = list(it if it is not None else [])
def update(self, it: Sequence) -> None:
"""Update items from iterable."""
diff --git a/kombu/utils/text.py b/kombu/utils/text.py
index 6f7977fc..f2e8280d 100644
--- a/kombu/utils/text.py
+++ b/kombu/utils/text.py
@@ -1,11 +1,18 @@
# -*- coding: utf-8 -*-
"""Text Utilities."""
from difflib import SequenceMatcher
-from typing import Iterator, Sequence, NamedTuple, Tuple
-
+from numbers import Number
+from typing import (
+ Iterator, Sequence, NamedTuple, Optional, SupportsInt, Tuple, Union,
+)
from kombu import version_info_t
-fmatch_t = NamedTuple('fmatch_t', [('ratio', float), ('key', str)])
+
+class fmatch_t(NamedTuple):
+ """Return value of :func:`fmatch_iter`."""
+
+ ratio: float
+ key: str
def escape_regex(p: str, white: str = '') -> str:
@@ -30,14 +37,14 @@ def fmatch_iter(needle: str, haystack: Sequence[str],
def fmatch_best(needle: str, haystack: Sequence[str],
- min_ratio: float = 0.6) -> str:
+ min_ratio: float = 0.6) -> Optional[str]:
"""Fuzzy match - Find best match (scalar)."""
try:
return sorted(
fmatch_iter(needle, haystack, min_ratio), reverse=True,
)[0][1]
except IndexError:
- pass
+ return None
def version_string_as_tuple(s: str) -> version_info_t:
@@ -52,7 +59,9 @@ def version_string_as_tuple(s: str) -> version_info_t:
return v
-def _unpack_version(major: int, minor: int = 0, micro: int = 0,
+def _unpack_version(major: Union[SupportsInt, str, bytes],
+ minor: Union[SupportsInt, str, bytes] = 0,
+ micro: Union[SupportsInt, str, bytes] = 0,
releaselevel: str = '',
serial: str = '') -> version_info_t:
return version_info_t(int(major), int(minor), micro, releaselevel, serial)
diff --git a/kombu/utils/time.py b/kombu/utils/time.py
index 64d31d3b..72448c3f 100644
--- a/kombu/utils/time.py
+++ b/kombu/utils/time.py
@@ -1,10 +1,8 @@
"""Time Utilities."""
-from __future__ import absolute_import, unicode_literals
-
+from typing import Optional, Union
__all__ = ['maybe_s_to_ms']
-def maybe_s_to_ms(v):
- # type: (Optional[Union[int, float]]) -> int
+def maybe_s_to_ms(v: Optional[Union[int, float]]) -> Optional[int]:
"""Convert seconds to milliseconds, but return None for None."""
return int(float(v) * 1000.0) if v is not None else v
diff --git a/kombu/utils/typing.py b/kombu/utils/typing.py
deleted file mode 100644
index 1f4843f6..00000000
--- a/kombu/utils/typing.py
+++ /dev/null
@@ -1,21 +0,0 @@
-from abc import abstractmethod
-from typing import (
- _Protocol, AnyStr, SupportsFloat, SupportsInt, Union,
-)
-
-Float = Union[SupportsInt, SupportsFloat, AnyStr]
-Int = Union[SupportsInt, AnyStr]
-
-Port = Union[SupportsInt, str]
-
-AnyBuffer = Union[AnyStr, memoryview]
-
-
-class SupportsFileno(_Protocol):
- __slots__ = ()
-
- @abstractmethod
- def __fileno__(self) -> int:
- ...
-
-Fd = Union[int, SupportsFileno]
diff --git a/kombu/utils/url.py b/kombu/utils/url.py
index 75db4d89..7dae3464 100644
--- a/kombu/utils/url.py
+++ b/kombu/utils/url.py
@@ -1,30 +1,25 @@
"""URL Utilities."""
from functools import partial
from typing import Any, Dict, Mapping, NamedTuple
+from urllib.parse import parse_qsl, quote, unquote, urlparse
from .typing import Port
-try:
- from urllib.parse import parse_qsl, quote, unquote, urlparse
-except ImportError:
- from urllib import quote, unquote # noqa
- from urlparse import urlparse, parse_qsl # noqa
safequote = partial(quote, safe='')
-urlparts = NamedTuple('urlparts', [
- ('scheme', str),
- ('hostname', str),
- ('port', int),
- ('username', str),
- ('password', str),
- ('path', str),
- ('query', Dict),
-])
+class urlparts(NamedTuple):
+ scheme: str
+ hostname: str
+ port: int
+ username: str
+ password: str
+ path: str
+ query: Dict
-def parse_url(url: str) -> Dict:
+def parse_url(url: str) -> Mapping:
"""Parse URL into mapping of components."""
- scheme, host, port, user, password, path, query = _parse_url(url)
+ scheme, host, port, user, password, path, query = url_to_parts(url)
return dict(transport=scheme, hostname=host,
port=port, userid=user,
password=password, virtual_host=path, **query)
@@ -47,7 +42,6 @@ def url_to_parts(url: str) -> urlparts:
unquote(path or '') or None,
dict(parse_qsl(parts.query)),
)
-_parse_url = url_to_parts # noqa
def as_url(scheme: str,
@@ -79,7 +73,7 @@ def as_url(scheme: str,
def sanitize_url(url: str, mask: str = '**') -> str:
"""Return copy of URL with password removed."""
- return as_url(*_parse_url(url), sanitize=True, mask=mask)
+ return as_url(*url_to_parts(url), sanitize=True, mask=mask)
def maybe_sanitize_url(url: Any, mask: str = '**') -> Any:
diff --git a/kombu/utils/uuid.py b/kombu/utils/uuid.py
index 7de82b29..04130ae2 100644
--- a/kombu/utils/uuid.py
+++ b/kombu/utils/uuid.py
@@ -1,8 +1,9 @@
"""UUID utilities."""
from uuid import uuid4
+from typing import Callable
-def uuid(_uuid=uuid4):
+def uuid(*, _uuid: Callable = uuid4) -> str:
"""Generate unique id in UUID4 format.
See Also:
diff --git a/kombu/transport/SLMQ.py b/moved/transport/SLMQ.py
index 91eebe5f..91eebe5f 100644
--- a/kombu/transport/SLMQ.py
+++ b/moved/transport/SLMQ.py
diff --git a/kombu/transport/SQS.py b/moved/transport/SQS.py
index f5766edb..f5766edb 100644
--- a/kombu/transport/SQS.py
+++ b/moved/transport/SQS.py
diff --git a/kombu/transport/consul.py b/moved/transport/consul.py
index fa77e6e9..fa77e6e9 100644
--- a/kombu/transport/consul.py
+++ b/moved/transport/consul.py
diff --git a/kombu/transport/etcd.py b/moved/transport/etcd.py
index 4d5c652a..3c25d508 100644
--- a/kombu/transport/etcd.py
+++ b/moved/transport/etcd.py
@@ -4,22 +4,16 @@ It uses Etcd as a store to transport messages in Queues
It uses python-etcd for talking to Etcd's HTTP API
"""
-from __future__ import absolute_import, unicode_literals
-
import os
import socket
-
from collections import defaultdict
from contextlib import contextmanager
-
+from queue import Empty
from kombu.exceptions import ChannelError
-from kombu.five import Empty
from kombu.log import get_logger
from kombu.utils.json import loads, dumps
from kombu.utils.objects import cached_property
-
from . import virtual
-
try:
import etcd
except ImportError:
diff --git a/kombu/transport/filesystem.py b/moved/transport/filesystem.py
index 2a5737fb..2a5737fb 100644
--- a/kombu/transport/filesystem.py
+++ b/moved/transport/filesystem.py
diff --git a/kombu/transport/librabbitmq.py b/moved/transport/librabbitmq.py
index 2ea5b779..2ea5b779 100644
--- a/kombu/transport/librabbitmq.py
+++ b/moved/transport/librabbitmq.py
diff --git a/kombu/transport/memory.py b/moved/transport/memory.py
index e3b4e441..e3b4e441 100644
--- a/kombu/transport/memory.py
+++ b/moved/transport/memory.py
diff --git a/kombu/transport/mongodb.py b/moved/transport/mongodb.py
index 7c8b5bb3..7c8b5bb3 100644
--- a/kombu/transport/mongodb.py
+++ b/moved/transport/mongodb.py
diff --git a/kombu/transport/pyro.py b/moved/transport/pyro.py
index c52532aa..c52532aa 100644
--- a/kombu/transport/pyro.py
+++ b/moved/transport/pyro.py
diff --git a/kombu/transport/qpid.py b/moved/transport/qpid.py
index 22046241..9aa4fb48 100644
--- a/kombu/transport/qpid.py
+++ b/moved/transport/qpid.py
@@ -76,8 +76,6 @@ options override and replace any other default or specified values. If using
Celery, this can be accomplished by setting the
*BROKER_TRANSPORT_OPTIONS* Celery option.
"""
-from __future__ import absolute_import, unicode_literals
-
from collections import OrderedDict
import os
import select
@@ -87,6 +85,8 @@ import sys
import uuid
from gettext import gettext as _
+from time import monotonic
+from queue import Empty
import amqp.protocol
@@ -116,7 +116,6 @@ except ImportError: # pragma: no cover
qpid = None
-from kombu.five import Empty, items, monotonic
from kombu.log import get_logger
from kombu.transport.virtual import Base64, Message
from kombu.transport import base
@@ -1585,7 +1584,7 @@ class Transport(base.Transport):
"""
conninfo = self.client
- for name, default_value in items(self.default_connection_params):
+ for name, default_value in self.default_connection_params.items():
if not getattr(conninfo, name, None):
setattr(conninfo, name, default_value)
if conninfo.ssl:
diff --git a/kombu/transport/redis.py b/moved/transport/redis.py
index e3e6de10..4b65e7e2 100644
--- a/kombu/transport/redis.py
+++ b/moved/transport/redis.py
@@ -3,9 +3,9 @@ import numbers
import socket
from bisect import bisect
-from collections import namedtuple
from contextlib import contextmanager
from time import time
+from typing import NamedTuple, Tuple
from queue import Empty
from vine import promise
@@ -18,7 +18,7 @@ from kombu.utils.encoding import bytes_to_str
from kombu.utils.json import loads, dumps
from kombu.utils.objects import cached_property
from kombu.utils.scheduling import cycle_by_name
-from kombu.utils.url import _parse_url
+from kombu.utils.url import url_to_parts
from kombu.utils.uuid import uuid
from . import virtual
@@ -42,15 +42,19 @@ DEFAULT_DB = 0
PRIORITY_STEPS = [0, 3, 6, 9]
-error_classes_t = namedtuple('error_classes_t', (
- 'connection_errors', 'channel_errors',
-))
-
NO_ROUTE_ERROR = """
Cannot route message for exchange {0!r}: Table empty or key no longer exists.
Probably the key ({1!r}) has been removed from the Redis database.
"""
+
+class error_classes_t(NamedTuple):
+ """Return value of :func:`get_redis_error_classes`."""
+
+ connection_errors: Tuple[type, ...]
+ channel_errors: Tuple[type, ...]
+
+
# This implementation may seem overly complex, but I assure you there is
# a good reason for doing it this way.
#
@@ -892,7 +896,7 @@ class Channel(virtual.Channel):
pass
host = connparams['host']
if '://' in host:
- scheme, _, _, _, _, path, query = _parse_url(host)
+ scheme, _, _, _, _, path, query = url_to_parts(host)
if scheme == 'socket':
connparams = self._filter_tcp_connparams(**connparams)
connparams.update({
diff --git a/kombu/transport/zookeeper.py b/moved/transport/zookeeper.py
index cf9f0dc8..cf9f0dc8 100644
--- a/kombu/transport/zookeeper.py
+++ b/moved/transport/zookeeper.py
diff --git a/requirements/test.txt b/requirements/test.txt
index e1036604..6de0f611 100644
--- a/requirements/test.txt
+++ b/requirements/test.txt
@@ -1,3 +1,4 @@
pytz>dev
case>=1.5.2
pytest
+pytest-asyncio
diff --git a/t/integration/test_async.py b/t/integration/test_async.py
new file mode 100644
index 00000000..094b24c3
--- /dev/null
+++ b/t/integration/test_async.py
@@ -0,0 +1,51 @@
+import os
+import pytest
+from kombu import Connection, Exchange, Queue, Producer, Consumer
+from kombu.pools import producers
+
+BROKER_URL = os.environ.get('BROKER_URL', 'pyamqp://')
+
+queue1 = Queue('testq32', Exchange('testq32', 'direct'), 'testq32')
+
+
+@pytest.fixture()
+def connection():
+ return Connection(BROKER_URL)
+
+
+@pytest.mark.asyncio
+async def test_queue_declare(connection):
+ async with connection:
+ async with connection.default_channel as channel:
+ ret = await queue1(channel).declare()
+ assert ret == queue1.name
+
+
+@pytest.mark.asyncio
+async def test_produce_consume(connection):
+
+ messages_received = [0]
+
+ async def on_message(message):
+ messages_received[0] += 1
+ await message.ack()
+
+ async with connection as connection:
+ async with Consumer(connection,
+ queues=[queue1],
+ on_message=on_message) as consumer:
+ async with connection.clone() as w_connection:
+ await w_connection.connect()
+ assert w_connection._connection
+ async with producers[w_connection].acquire() as producer:
+ for i in range(10):
+ await producer.publish(
+ str(i),
+ exchange=queue1.exchange,
+ routing_key=queue1.routing_key,
+ retry=False,
+ declare=[queue1],
+ )
+ while messages_received[0] < 10:
+ await connection.drain_events()
+ assert messages_received[0] == 10
diff --git a/t/integration/__init__.py b/t/oldint/__init__.py
index 23ffb4c1..23ffb4c1 100644
--- a/t/integration/__init__.py
+++ b/t/oldint/__init__.py
diff --git a/t/integration/tests/__init__.py b/t/oldint/tests/__init__.py
index 094b08e0..094b08e0 100644
--- a/t/integration/tests/__init__.py
+++ b/t/oldint/tests/__init__.py
diff --git a/t/integration/tests/test_SLMQ.py b/t/oldint/tests/test_SLMQ.py
index 8428f7d1..8428f7d1 100644
--- a/t/integration/tests/test_SLMQ.py
+++ b/t/oldint/tests/test_SLMQ.py
diff --git a/t/integration/tests/test_SQS.py b/t/oldint/tests/test_SQS.py
index 571cff24..571cff24 100644
--- a/t/integration/tests/test_SQS.py
+++ b/t/oldint/tests/test_SQS.py
diff --git a/t/integration/tests/test_amqp.py b/t/oldint/tests/test_amqp.py
index f7aa762e..f7aa762e 100644
--- a/t/integration/tests/test_amqp.py
+++ b/t/oldint/tests/test_amqp.py
diff --git a/t/integration/tests/test_librabbitmq.py b/t/oldint/tests/test_librabbitmq.py
index 41a21b3c..41a21b3c 100644
--- a/t/integration/tests/test_librabbitmq.py
+++ b/t/oldint/tests/test_librabbitmq.py
diff --git a/t/integration/tests/test_mongodb.py b/t/oldint/tests/test_mongodb.py
index 495208c7..495208c7 100644
--- a/t/integration/tests/test_mongodb.py
+++ b/t/oldint/tests/test_mongodb.py
diff --git a/t/integration/tests/test_pyamqp.py b/t/oldint/tests/test_pyamqp.py
index f7aa762e..f7aa762e 100644
--- a/t/integration/tests/test_pyamqp.py
+++ b/t/oldint/tests/test_pyamqp.py
diff --git a/t/integration/tests/test_qpid.py b/t/oldint/tests/test_qpid.py
index a5c9141d..a5c9141d 100644
--- a/t/integration/tests/test_qpid.py
+++ b/t/oldint/tests/test_qpid.py
diff --git a/t/integration/tests/test_redis.py b/t/oldint/tests/test_redis.py
index 610b0149..610b0149 100644
--- a/t/integration/tests/test_redis.py
+++ b/t/oldint/tests/test_redis.py
diff --git a/t/integration/tests/test_zookeeper.py b/t/oldint/tests/test_zookeeper.py
index 150c4a35..150c4a35 100644
--- a/t/integration/tests/test_zookeeper.py
+++ b/t/oldint/tests/test_zookeeper.py
diff --git a/t/integration/transport.py b/t/oldint/transport.py
index 36b4d33f..36b4d33f 100644
--- a/t/integration/transport.py
+++ b/t/oldint/transport.py
diff --git a/t/unit/__init__.py b/t/unit/__init__.py
index 01e6d4f4..e69de29b 100644
--- a/t/unit/__init__.py
+++ b/t/unit/__init__.py
@@ -1 +0,0 @@
-from __future__ import absolute_import, unicode_literals
diff --git a/t/unit/test_clocks.py b/t/unit/test_clocks.py
index 5ed30bf1..c4f9587c 100644
--- a/t/unit/test_clocks.py
+++ b/t/unit/test_clocks.py
@@ -1,12 +1,7 @@
-from __future__ import absolute_import, unicode_literals
-
import pickle
-
from heapq import heappush
from time import time
-
from case import Mock
-
from kombu.clocks import LamportClock, timetuple
diff --git a/t/unit/test_entity.py b/t/unit/test_entity.py
index 004c98f8..b923d480 100644
--- a/t/unit/test_entity.py
+++ b/t/unit/test_entity.py
@@ -5,7 +5,7 @@ import pytest
from case import Mock, call
from kombu import Connection, Exchange, Producer, Queue, binding
-from kombu.abstract import MaybeChannelBound
+from kombu.abstract import Entity
from kombu.exceptions import NotBoundError
from kombu.serialization import registry
@@ -403,7 +403,7 @@ class test_Queue:
assert 'Queue' in repr(b)
-class test_MaybeChannelBound:
+class test_Entity:
def test_repr(self):
- assert repr(MaybeChannelBound())
+ assert repr(Entity())
diff --git a/t/unit/test_exceptions.py b/t/unit/test_exceptions.py
index f72f3d6d..016364db 100644
--- a/t/unit/test_exceptions.py
+++ b/t/unit/test_exceptions.py
@@ -1,7 +1,4 @@
-from __future__ import absolute_import, unicode_literals
-
from case import Mock
-
from kombu.exceptions import HttpError
diff --git a/t/unit/transport/test_base.py b/t/unit/transport/test_base.py
index f1e0b6c9..3ed9cbda 100644
--- a/t/unit/transport/test_base.py
+++ b/t/unit/transport/test_base.py
@@ -1,4 +1,3 @@
-from __future__ import absolute_import, unicode_literals
import pytest
from case import Mock
from kombu import Connection, Consumer, Exchange, Producer, Queue
diff --git a/t/unit/transport/test_etcd.py b/t/unit/transport/test_etcd.py
index b61fd7a7..5e33b7ca 100644
--- a/t/unit/transport/test_etcd.py
+++ b/t/unit/transport/test_etcd.py
@@ -1,11 +1,6 @@
-from __future__ import absolute_import, unicode_literals
-
import pytest
-
+from queue import Empty
from case import Mock, patch, skip
-
-from kombu.five import Empty
-
from kombu.transport.etcd import Channel, Transport
diff --git a/t/unit/transport/test_qpid.py b/t/unit/transport/test_qpid.py
index bf796212..7d87c789 100644
--- a/t/unit/transport/test_qpid.py
+++ b/t/unit/transport/test_qpid.py
@@ -1,5 +1,3 @@
-from __future__ import absolute_import, unicode_literals
-
import pytest
import select
import ssl
@@ -7,16 +5,15 @@ import socket
import sys
import time
import uuid
-
from collections import Callable, OrderedDict
from itertools import count
-
+from queue import Empty
+from time import monotonic
from case import Mock, call, patch, skip
-
-from kombu.five import Empty, keys, range, monotonic
-from kombu.transport.qpid import (AuthenticationFailure, Channel, Connection,
- ConnectionError, Message, NotFound, QoS,
- Transport)
+from kombu.transport.qpid import (
+ AuthenticationFailure, Channel, Connection,
+ ConnectionError, Message, NotFound, QoS, Transport,
+)
from kombu.transport.virtual import Base64
@@ -42,15 +39,15 @@ class ExtraAssertionsMixin(object):
Also asserts that the value of each key is the same in a and b using
the is operator.
"""
- assert set(keys(a)) == set(keys(b))
- for key in keys(a):
+ assert set(a.keys()) == set(b.keys())
+ for key in a.keys():
assert a[key] == b[key]
def assertDictContainsSubset(self, a, b, msg=None):
"""
Assert that all the key/value pairs in a exist in b.
"""
- for key in keys(a):
+ for key in a.keys():
assert key in b
assert a[key] == b[key]
diff --git a/t/unit/transport/virtual/test_exchange.py b/t/unit/transport/virtual/test_exchange.py
index 93d48228..bc67a676 100644
--- a/t/unit/transport/virtual/test_exchange.py
+++ b/t/unit/transport/virtual/test_exchange.py
@@ -1,12 +1,7 @@
-from __future__ import absolute_import, unicode_literals
-
import pytest
-
from case import Mock
-
from kombu import Connection
from kombu.transport.virtual import exchange
-
from t.mocks import Transport
diff --git a/t/unit/utils/test_compat.py b/t/unit/utils/test_compat.py
index c79b3202..3c7397b5 100644
--- a/t/unit/utils/test_compat.py
+++ b/t/unit/utils/test_compat.py
@@ -1,12 +1,7 @@
-from __future__ import absolute_import, unicode_literals
-
import socket
import sys
import types
-
from case import Mock, mock, patch
-
-from kombu.five import bytes_if_py2
from kombu.utils import compat
from kombu.utils.compat import entrypoints, maybe_fileno
@@ -70,17 +65,15 @@ class test_detect_environment:
def test_detect_environment_no_eventlet_or_gevent(self):
try:
- sys.modules['eventlet'] = types.ModuleType(
- bytes_if_py2('eventlet'))
- sys.modules['eventlet.patcher'] = types.ModuleType(
- bytes_if_py2('patcher'))
+ sys.modules['eventlet'] = types.ModuleType('eventlet')
+ sys.modules['eventlet.patcher'] = types.ModuleType('patcher')
assert compat._detect_environment() == 'default'
finally:
sys.modules.pop('eventlet.patcher', None)
sys.modules.pop('eventlet', None)
compat._detect_environment()
try:
- sys.modules['gevent'] = types.ModuleType(bytes_if_py2('gevent'))
+ sys.modules['gevent'] = types.ModuleType('gevent')
assert compat._detect_environment() == 'default'
finally:
sys.modules.pop('gevent', None)
diff --git a/t/unit/utils/test_div.py b/t/unit/utils/test_div.py
index f0b1a058..bbf60b40 100644
--- a/t/unit/utils/test_div.py
+++ b/t/unit/utils/test_div.py
@@ -1,9 +1,5 @@
-from __future__ import absolute_import, unicode_literals
-
import pickle
-
from io import StringIO, BytesIO
-
from kombu.utils.div import emergency_dump_state
diff --git a/t/unit/utils/test_encoding.py b/t/unit/utils/test_encoding.py
index e3d1040a..bff3a321 100644
--- a/t/unit/utils/test_encoding.py
+++ b/t/unit/utils/test_encoding.py
@@ -1,13 +1,7 @@
# -*- coding: utf-8 -*-
-from __future__ import absolute_import, unicode_literals
-
import sys
-
from contextlib import contextmanager
-
from case import patch, skip
-
-from kombu.five import bytes_t, string_t
from kombu.utils.encoding import (
get_default_encoding_file, safe_str,
set_default_encoding_file, default_encoding,
@@ -50,13 +44,13 @@ class test_default_encoding:
@skip.if_python3()
def test_str_to_bytes():
with clean_encoding() as e:
- assert isinstance(e.str_to_bytes('foobar'), bytes_t)
+ assert isinstance(e.str_to_bytes('foobar'), bytes)
@skip.if_python3()
def test_from_utf8():
with clean_encoding() as e:
- assert isinstance(e.from_utf8('foobar'), bytes_t)
+ assert isinstance(e.from_utf8('foobar'), bytes)
@skip.if_python3()
@@ -75,7 +69,7 @@ class test_safe_str:
assert safe_str('foo') == 'foo'
def test_when_unicode(self):
- assert isinstance(safe_str('foo'), string_t)
+ assert isinstance(safe_str('foo'), str)
def test_when_encoding_utf8(self):
self._encoding.return_value = 'utf-8'
diff --git a/t/unit/utils/test_scheduling.py b/t/unit/utils/test_scheduling.py
index 79216b56..b6ddeeca 100644
--- a/t/unit/utils/test_scheduling.py
+++ b/t/unit/utils/test_scheduling.py
@@ -1,9 +1,5 @@
-from __future__ import absolute_import, unicode_literals
-
import pytest
-
from case import Mock
-
from kombu.utils.scheduling import FairCycle, cycle_by_name
diff --git a/t/unit/utils/test_time.py b/t/unit/utils/test_time.py
index 4f6ddc0b..33d9fd6c 100644
--- a/t/unit/utils/test_time.py
+++ b/t/unit/utils/test_time.py
@@ -1,7 +1,4 @@
-from __future__ import absolute_import, unicode_literals
-
import pytest
-
from kombu.utils.time import maybe_s_to_ms
diff --git a/t/unit/utils/test_url.py b/t/unit/utils/test_url.py
index 3d2b0ede..e1bc680b 100644
--- a/t/unit/utils/test_url.py
+++ b/t/unit/utils/test_url.py
@@ -1,7 +1,4 @@
-from __future__ import absolute_import, unicode_literals
-
import pytest
-
from kombu.utils.url import as_url, parse_url, maybe_sanitize_url
diff --git a/t/unit/utils/test_utils.py b/t/unit/utils/test_utils.py
index f4668b83..82cd182b 100644
--- a/t/unit/utils/test_utils.py
+++ b/t/unit/utils/test_utils.py
@@ -1,7 +1,4 @@
-from __future__ import absolute_import, unicode_literals
-
import pytest
-
from kombu import version_info_t
from kombu.utils.text import version_string_as_tuple