diff options
author | Ask Solem <ask@celeryproject.org> | 2017-02-16 10:44:48 -0800 |
---|---|---|
committer | Ask Solem <ask@celeryproject.org> | 2017-02-16 10:44:48 -0800 |
commit | e6fab2f68b562cf1400bd8167e9b755f0482aafe (patch) | |
tree | b4d6be32dc8c62fa032e3c1a1a74636ac8360a38 /kombu | |
parent | f2f7c67651106e77fb2db60ded134404ccc0a626 (diff) | |
download | kombu-5.0-devel.tar.gz |
WIP5.0-devel
Diffstat (limited to 'kombu')
51 files changed, 2481 insertions, 6395 deletions
diff --git a/kombu/__init__.py b/kombu/__init__.py index 848d2b99..6e0211e1 100644 --- a/kombu/__init__.py +++ b/kombu/__init__.py @@ -71,7 +71,7 @@ for module, items in all_by_module.items(): object_origins[item] = module -class module(ModuleType): +class _module(ModuleType): """Customized Python module.""" def __getattr__(self, name): @@ -100,7 +100,7 @@ except NameError: # pragma: no cover # keep a reference to this module so that it's not garbage collected old_module = sys.modules[__name__] -new_module = sys.modules[__name__] = module(__name__) +new_module = sys.modules[__name__] = _module(__name__) # type: ignore new_module.__dict__.update({ '__file__': __file__, '__path__': __path__, @@ -118,6 +118,7 @@ new_module.__dict__.update({ }) if os.environ.get('KOMBU_LOG_DEBUG'): # pragma: no cover - os.environ.update(KOMBU_LOG_CHANNEL='1', KOMBU_LOG_CONNECTION='1') + os.environ['KOMBU_LOG_CHANNEL'] = '1' + os.environ['KOMBU_LOG_CONNECTION'] = '1' from .utils import debug debug.setup_logging() diff --git a/kombu/abstract.py b/kombu/abstract.py index 916b1eea..4f824415 100644 --- a/kombu/abstract.py +++ b/kombu/abstract.py @@ -1,14 +1,13 @@ -import amqp.abstract - from copy import copy -from typing import Any, Dict +from typing import Any, Dict, Optional from typing import Sequence, Tuple # noqa - +from amqp.types import ChannelT from .connection import maybe_channel from .exceptions import NotBoundError +from .types import EntityT from .utils.functional import ChannelPromise -__all__ = ['Object', 'MaybeChannelBound'] +__all__ = ['Entity'] def unpickle_dict(cls: Any, kwargs: Dict) -> Any: @@ -19,13 +18,16 @@ def _any(v: Any) -> Any: return v -class Object: - """Common base class. - - Supports automatic kwargs->attributes handling, and cloning. - """ +class Entity(EntityT): + """Mixin for classes that can be bound to an AMQP channel.""" attrs = () # type: Sequence[Tuple[str, Any]] + #: Defines whether maybe_declare can skip declaring this entity twice. + can_cache_declaration = False + + _channel: ChannelT + _is_bound: bool = False + def __init__(self, *args, **kwargs) -> None: for name, type_ in self.attrs: value = kwargs.get(name) @@ -37,41 +39,15 @@ class Object: except AttributeError: setattr(self, name, None) - def as_dict(self, recurse: bool=False) -> Dict: - def f(obj: Any, type: Any) -> Any: - if recurse and isinstance(obj, Object): - return obj.as_dict(recurse=True) - return type(obj) if type and obj is not None else obj - return { - attr: f(getattr(self, attr), type) for attr, type in self.attrs - } - - def __reduce__(self) -> Any: - return unpickle_dict, (self.__class__, self.as_dict()) - - def __copy__(self) -> Any: - return self.__class__(**self.as_dict()) - - -class MaybeChannelBound(Object): - """Mixin for classes that can be bound to an AMQP channel.""" - - _channel = None # type: amqp.abstract.Channel - _is_bound = False # type: bool - - #: Defines whether maybe_declare can skip declaring this entity twice. - can_cache_declaration = False - - def __call__(self, channel: amqp.abstract.Channel) -> 'MaybeChannelBound': + def __call__(self, channel: ChannelT) -> EntityT: """`self(channel) -> self.bind(channel)`""" return self.bind(channel) - def bind(self, channel: amqp.abstract.Channel) -> 'MaybeChannelBound': + def bind(self, channel: ChannelT) -> EntityT: """Create copy of the instance that is bound to a channel.""" return copy(self).maybe_bind(channel) - def maybe_bind(self, - channel: amqp.abstract.Channel) -> 'MaybeChannelBound': + def maybe_bind(self, channel: Optional[ChannelT]) -> EntityT: """Bind instance to channel if not already bound.""" if not self.is_bound and channel: self._channel = maybe_channel(channel) @@ -79,7 +55,7 @@ class MaybeChannelBound(Object): self._is_bound = True return self - def revive(self, channel: amqp.abstract.Channel) -> None: + def revive(self, channel: ChannelT) -> None: """Revive channel after the connection has been re-established. Used by :meth:`~kombu.Connection.ensure`. @@ -96,20 +72,35 @@ class MaybeChannelBound(Object): def __repr__(self) -> str: return self._repr_entity(type(self).__name__) - def _repr_entity(self, item: str='') -> str: + def _repr_entity(self, item: str = '') -> str: item = item or type(self).__name__ if self.is_bound: return '<{0} bound to chan:{1}>'.format( item or type(self).__name__, self.channel.channel_id) return '<unbound {0}>'.format(item) + def as_dict(self, recurse: bool = False) -> Dict: + def f(obj: Any, type: Any) -> Any: + if recurse and isinstance(obj, Entity): + return obj.as_dict(recurse=True) + return type(obj) if type and obj is not None else obj + return { + attr: f(getattr(self, attr), type) for attr, type in self.attrs + } + + def __reduce__(self) -> Any: + return unpickle_dict, (self.__class__, self.as_dict()) + + def __copy__(self) -> Any: + return self.__class__(**self.as_dict()) + @property def is_bound(self) -> bool: """Flag set if the channel is bound.""" return self._is_bound and self._channel is not None @property - def channel(self) -> amqp.abstract.Channel: + def channel(self) -> ChannelT: """Current channel if the object is bound.""" channel = self._channel if channel is None: diff --git a/kombu/async/timer.py b/kombu/async/timer.py index 69d93a36..b67f4ddf 100644 --- a/kombu/async/timer.py +++ b/kombu/async/timer.py @@ -3,10 +3,10 @@ import heapq import sys -from collections import namedtuple from datetime import datetime from functools import total_ordering from time import monotonic +from typing import NamedTuple from weakref import proxy as weakrefproxy from vine.utils import wraps @@ -27,7 +27,13 @@ DEFAULT_MAX_INTERVAL = 2 EPOCH = datetime.utcfromtimestamp(0).replace(tzinfo=utc) IS_PYPY = hasattr(sys, 'pypy_version_info') -scheduled = namedtuple('scheduled', ('eta', 'priority', 'entry')) + +class scheduled(NamedTuple): + """Information about scheduled item.""" + + eta: float + priority: int + entry: 'Entry' def to_timestamp(d, default_timezone=utc, time=monotonic): diff --git a/kombu/clocks.py b/kombu/clocks.py index cef4570b..ff82ffb0 100644 --- a/kombu/clocks.py +++ b/kombu/clocks.py @@ -2,9 +2,9 @@ from threading import Lock from itertools import islice from operator import itemgetter -from typing import Any, List, Sequence +from typing import Any, Callable, List, Sequence -__all__ = ['LamportClock', 'timetuple'] +__all__ = ['Clock', 'LamportClock', 'timetuple'] R_CLOCK = '_lamport(clock={0}, timestamp={1}, id={2} {3!r})' @@ -24,7 +24,7 @@ class timetuple(tuple): __slots__ = () def __new__(cls, clock: int, timestamp: float, - id: str, obj: Any=None) -> 'timetuple': + id: str, obj: Any = None) -> 'timetuple': return tuple.__new__(cls, (clock, timestamp, id, obj)) def __repr__(self) -> str: @@ -61,7 +61,54 @@ class timetuple(tuple): obj = property(itemgetter(3)) -class LamportClock: +class Clock: + + value: int = 0 + + def __init__(self, initial_value: int = 0, **kwargs) -> None: + self.value = initial_value + + def adjust(self, other: int) -> int: + raise NotImplementedError() + + def forward(self) -> int: + raise NotImplementedError() + + def __str__(self) -> str: + return str(self.value) + + def __repr__(self) -> str: + return '<{name}: {0.value}>'.format(self, name=type(self).__name__) + + def sort_heap(self, h: List[Sequence]) -> Any: + """Sort heap of events. + + List of tuples containing at least two elements, representing + an event, where the first element is the event's scalar clock value, + and the second element is the id of the process (usually + ``"hostname:pid"``): ``sh([(clock, processid, ...?), (...)])`` + + The list must already be sorted, which is why we refer to it as a + heap. + + The tuple will not be unpacked, so more than two elements can be + present. + + Will return the latest event. + """ + if h[0][0] == h[1][0]: + same = [] + for PN in zip(h, islice(h, 1, None)): + if PN[0][0] != PN[1][0]: + break # Prev and Next's clocks differ + same.append(PN[0]) + # return first item sorted by process id + return sorted(same, key=lambda event: event[1])[0] + # clock values unique, return first item + return h[0] + + +class LamportClock(Clock): """Lamport's logical clock. From Wikipedia: @@ -100,9 +147,10 @@ class LamportClock: #: The clocks current value. value = 0 - def __init__(self, initial_value: int=0, Lock: Any=Lock) -> None: - self.value = initial_value + def __init__(self, initial_value: int = 0, + Lock: Callable = Lock, **kwargs) -> None: self.mutex = Lock() + super().__init__(initial_value) def adjust(self, other: int) -> int: with self.mutex: @@ -113,36 +161,3 @@ class LamportClock: with self.mutex: self.value += 1 return self.value - - def sort_heap(self, h: List[Sequence]) -> Any: - """Sort heap of events. - - List of tuples containing at least two elements, representing - an event, where the first element is the event's scalar clock value, - and the second element is the id of the process (usually - ``"hostname:pid"``): ``sh([(clock, processid, ...?), (...)])`` - - The list must already be sorted, which is why we refer to it as a - heap. - - The tuple will not be unpacked, so more than two elements can be - present. - - Will return the latest event. - """ - if h[0][0] == h[1][0]: - same = [] - for PN in zip(h, islice(h, 1, None)): - if PN[0][0] != PN[1][0]: - break # Prev and Next's clocks differ - same.append(PN[0]) - # return first item sorted by process id - return sorted(same, key=lambda event: event[1])[0] - # clock values unique, return first item - return h[0] - - def __str__(self) -> str: - return str(self.value) - - def __repr__(self) -> str: - return '<LamportClock: {0.value}>'.format(self) diff --git a/kombu/common.py b/kombu/common.py index 0b1c8bbb..85134482 100644 --- a/kombu/common.py +++ b/kombu/common.py @@ -1,5 +1,4 @@ """Common Utilities.""" -import amqp.abstract import os import socket import threading @@ -15,13 +14,15 @@ from typing import ( ) from uuid import uuid4, uuid3, NAMESPACE_OID -from amqp import RecoverableConnectionError +from amqp import ChannelT, RecoverableConnectionError from .entity import Exchange, Queue from .log import get_logger from .serialization import registry as serializers +from .types import ( + ClientT, ConsumerT, EntityT, MessageT, ProducerT, ResourceT, +) from .utils import abstract -from .utils.typing import Timeout from .utils.uuid import uuid try: @@ -89,10 +90,10 @@ class Broadcast(Queue): attrs = Queue.attrs + (('queue', None),) - def __init__(self, name: Optional[str]=None, - queue: Optional[abstract.Entity]=None, - auto_delete: bool=True, - exchange: Optional[Union[Exchange, str]]=None, + def __init__(self, name: str = None, + queue: EntityT = None, + auto_delete: bool = True, + exchange: Union[Exchange, str] = None, alias: Optional[str]=None, **kwargs) -> None: queue = queue or 'bcast.{0}'.format(uuid()) return super().__init__( @@ -106,15 +107,15 @@ class Broadcast(Queue): ) -def declaration_cached(entity: abstract.Entity, - channel: amqp.abstract.Channel) -> bool: +def declaration_cached(entity: EntityT, + channel: ChannelT) -> bool: return entity in channel.connection.client.declared_entities -def maybe_declare(entity: abstract.Entity, - channel: amqp.abstract.Channel = None, - retry: bool = False, - **retry_policy) -> bool: +async def maybe_declare(entity: EntityT, + channel: ChannelT = None, + retry: bool = False, + **retry_policy) -> bool: """Declare entity (cached).""" is_bound = entity.is_bound orig = entity @@ -135,18 +136,18 @@ def maybe_declare(entity: abstract.Entity, return False if retry: - return _imaybe_declare(entity, declared, ident, - channel, orig, **retry_policy) - return _maybe_declare(entity, declared, ident, channel, orig) + return await _imaybe_declare(entity, declared, ident, + channel, orig, **retry_policy) + return await _maybe_declare(entity, declared, ident, channel, orig) -def _maybe_declare(entity: abstract.Entity, declared: MutableSet, ident: int, - channel: Optional[amqp.abstract.Channel], - orig: abstract.Entity = None) -> bool: +async def _maybe_declare(entity: EntityT, declared: MutableSet, ident: int, + channel: Optional[ChannelT], + orig: EntityT = None) -> bool: channel = channel or entity.channel if not channel.connection: raise RecoverableConnectionError('channel disconnected') - entity.declare(channel=channel) + await entity.declare(channel=channel) if declared is not None and ident: declared.add(ident) if orig is not None: @@ -154,22 +155,22 @@ def _maybe_declare(entity: abstract.Entity, declared: MutableSet, ident: int, return True -def _imaybe_declare(entity: abstract.Entity, - declared: MutableSet, - ident: int, - channel: Optional[amqp.abstract.Channel], - orig: Optional[abstract.Entity]=None, - **retry_policy) -> bool: - return entity.channel.connection.client.ensure( +async def _imaybe_declare(entity: EntityT, + declared: MutableSet, + ident: int, + channel: Optional[ChannelT], + orig: EntityT = None, + **retry_policy) -> bool: + return await entity.channel.connection.client.ensure( entity, _maybe_declare, **retry_policy)( entity, declared, ident, channel, orig) def drain_consumer( - consumer: abstract.Consumer, + consumer: ConsumerT, limit: int = 1, - timeout: Timeout = None, - callbacks: Sequence[Callable] = None) -> Iterator[abstract.Message]: + timeout: float = None, + callbacks: Sequence[Callable] = None) -> Iterator[MessageT]: """Drain messages from consumer instance.""" acc = deque() @@ -189,12 +190,12 @@ def drain_consumer( def itermessages( conn: abstract.Connection, - channel: Optional[amqp.abstract.Channel], - queue: abstract.Entity, + channel: Optional[ChannelT], + queue: EntityT, limit: int = 1, - timeout: Timeout = None, + timeout: float = None, callbacks: Sequence[Callable] = None, - **kwargs) -> Iterator[abstract.Message]: + **kwargs) -> Iterator[MessageT]: """Iterator over messages.""" return drain_consumer( conn.Consumer(queues=[queue], channel=channel, **kwargs), @@ -202,9 +203,9 @@ def itermessages( ) -def eventloop(conn: abstract.Connection, +def eventloop(conn: ClientT, limit: int = None, - timeout: Timeout = None, + timeout: float = None, ignore_timeouts: bool = False) -> Iterator[Any]: """Best practice generator wrapper around ``Connection.drain_events``. @@ -243,9 +244,9 @@ def eventloop(conn: abstract.Connection, raise -def send_reply(exchange: Union[abstract.Exchange, str], - req: abstract.Message, msg: Any, - producer: abstract.Producer = None, +def send_reply(exchange: Union[Exchange, str], + req: MessageT, msg: Any, + producer: ProducerT = None, retry: bool = False, retry_policy: Mapping = None, **props) -> None: """Send reply for request. @@ -271,9 +272,9 @@ def send_reply(exchange: Union[abstract.Exchange, str], def collect_replies( - conn: abstract.Connection, - channel: Optional[amqp.abstract.Channel], - queue: abstract.Entity, + conn: ClientT, + channel: Optional[ChannelT], + queue: EntityT, *args, **kwargs) -> Iterator[Any]: """Generator collecting replies from ``queue``.""" no_ack = kwargs.setdefault('no_ack', True) @@ -305,7 +306,7 @@ def _ignore_errors(conn) -> Iterator: pass -def ignore_errors(conn: abstract.Connection, +def ignore_errors(conn: ClientT, fun: Callable = None, *args, **kwargs) -> Any: """Ignore connection and channel errors. @@ -339,14 +340,14 @@ def ignore_errors(conn: abstract.Connection, return _ignore_errors(conn) -def revive_connection(connection: abstract.Connection, - channel: amqp.abstract.Channel, +def revive_connection(connection: ClientT, + channel: ChannelT, on_revive: Callable = None) -> None: if on_revive: on_revive(channel) -def insured(pool: abstract.Resource, +def insured(pool: ResourceT, fun: Callable, args: Sequence, kwargs: Dict, errback: Callable = None, diff --git a/kombu/compression.py b/kombu/compression.py index 8e5ad99a..cec0282c 100644 --- a/kombu/compression.py +++ b/kombu/compression.py @@ -11,15 +11,15 @@ __all__ = [ 'get_decoder', 'compress', 'decompress', ] -TEncoder = Callable[[bytes], bytes] -TDecoder = Callable[[bytes], bytes] +EncoderT = Callable[[bytes], bytes] +DecoderT = Callable[[bytes], bytes] -_aliases = {} # type: MutableMapping[str, str] -_encoders = {} # type: MutableMapping[str, TEncoder] -_decoders = {} # type: MutableMapping[str, TDecoder] +_aliases: MutableMapping[str, str] = {} +_encoders: MutableMapping[str, EncoderT] = {} +_decoders: MutableMapping[str, DecoderT] = {} -def register(encoder: TEncoder, decoder: TDecoder, content_type: str, +def register(encoder: EncoderT, decoder: DecoderT, content_type: str, aliases: Sequence[str]=[]) -> None: """Register new compression method. @@ -42,13 +42,13 @@ def encoders() -> Sequence[str]: return list(_encoders) -def get_encoder(t: str) -> Tuple[TEncoder, str]: +def get_encoder(t: str) -> Tuple[EncoderT, str]: """Get encoder by alias name.""" t = _aliases.get(t, t) return _encoders[t], t -def get_decoder(t: str) -> TDecoder: +def get_decoder(t: str) -> DecoderT: """Get decoder by alias name.""" return _decoders[_aliases.get(t, t)] diff --git a/kombu/connection.py b/kombu/connection.py index c4c6ff7d..11970ea1 100644 --- a/kombu/connection.py +++ b/kombu/connection.py @@ -7,6 +7,9 @@ from collections import OrderedDict from contextlib import contextmanager from itertools import count, cycle from operator import itemgetter +from typing import Any, Callable, Iterable, Mapping, Set, Sequence, Union + +from amqp.types import ChannelT, ConnectionT # jython breaks on relative import for .exceptions for some reason # (Issue #112) @@ -15,7 +18,7 @@ from kombu import exceptions from .log import get_logger from .resource import Resource from .transport import get_transport_cls, supports_librabbitmq -from .utils import abstract +from .types import ClientT, EntityT, ResourceT, TransportT from .utils.collections import HashedSeq from .utils.functional import dictfilter, lazy, retry_over_time, shufflecycle from .utils.objects import cached_property @@ -27,22 +30,21 @@ logger = get_logger(__name__) roundrobin_failover = cycle -resolve_aliases = { +resolve_aliases: Mapping[str, str] = { 'pyamqp': 'amqp', 'librabbitmq': 'amqp', } -failover_strategies = { +failover_strategies: Mapping[str, Callable] = { 'round-robin': roundrobin_failover, 'shuffle': shufflecycle, } -_log_connection = os.environ.get('KOMBU_LOG_CONNECTION', False) -_log_channel = os.environ.get('KOMBU_LOG_CHANNEL', False) +_log_connection = bool(os.environ.get('KOMBU_LOG_CONNECTION', False)) +_log_channel = bool(os.environ.get('KOMBU_LOG_CHANNEL', False)) -@abstract.Connection.register -class Connection: +class Connection(ClientT): """A connection to the broker. Example: @@ -106,48 +108,65 @@ class Connection: :keyword port: Default port if not provided in the URL. """ - port = None - virtual_host = '/' - connect_timeout = 5 + hostname: str + userid: str + password: str + ssl: Any + login_method: str + port: int = None + virtual_host: str = '/' + connect_timeout: float = 5.0 - _closed = None - _connection = None - _default_channel = None - _transport = None - _logger = False - uri_prefix = None + uri_prefix: str = None #: The cache of declared entities is per connection, #: in case the server loses data. - declared_entities = None + declared_entities: Set[EntityT] = None #: Iterator returning the next broker URL to try in the event #: of connection failure (initialized by :attr:`failover_strategy`). - cycle = None + cycle: Iterable = None #: Additional transport specific options, #: passed on to the transport instance. - transport_options = None + transport_options: Mapping = None #: Strategy used to select new hosts when reconnecting after connection #: failure. One of "round-robin", "shuffle" or any custom iterator #: constantly yielding new URLs to try. - failover_strategy = 'round-robin' + failover_strategy: str = 'round-robin' #: Heartbeat value, currently only supported by the py-amqp transport. - heartbeat = None + heartbeat: float = None - resolve_aliases = resolve_aliases - failover_strategies = failover_strategies + resolve_aliases: Mapping[str, str] = resolve_aliases + failover_strategies: Mapping[str, Callable] = failover_strategies - hostname = userid = password = ssl = login_method = None + _closed: bool = None + _connection: ConnectionT = None + _default_channel: ChannelT = None + _transport: TransportT = None + _logger: bool = False + _initial_params: Mapping - def __init__(self, hostname='localhost', userid=None, - password=None, virtual_host=None, port=None, insist=False, - ssl=False, transport=None, connect_timeout=5, - transport_options=None, login_method=None, uri_prefix=None, - heartbeat=0, failover_strategy='round-robin', - alternates=None, **kwargs): + def __init__( + self, + hostname: str = 'localhost', + userid: str = None, + password: str = None, + virtual_host: str = None, + port: int = None, + insist: bool = False, + ssl: Any = None, + transport: Union[type, str] = None, + connect_timeout: float = 5.0, + transport_options: Mapping = None, + login_method: str = None, + uri_prefix: str = None, + heartbeat: float = None, + failover_strategy: str = 'round-robin', + alternates: Sequence[str] = None, + **kwargs): alt = [] if alternates is None else alternates # have to spell the args out, just to get nice docstrings :( params = self._initial_params = { @@ -208,7 +227,7 @@ class Connection: self.declared_entities = set() - def switch(self, url): + def switch(self, url: str) -> None: """Switch connection parameters to use a new URL. Note: @@ -219,14 +238,15 @@ class Connection: self._closed = False self._init_params(**dict(self._initial_params, **parse_url(url))) - def maybe_switch_next(self): + def maybe_switch_next(self) -> None: """Switch to next URL given by the current failover strategy.""" if self.cycle: self.switch(next(self.cycle)) - def _init_params(self, hostname, userid, password, virtual_host, port, - insist, ssl, transport, connect_timeout, - login_method, heartbeat): + def _init_params(self, hostname: str, userid: str, password: str, + virtual_host: str, port: int, insist: bool, ssl: Any, + transport: Union[type, str], connect_timeout: float, + login_method: str, heartbeat: float) -> None: transport = transport or 'amqp' if transport == 'amqp' and supports_librabbitmq(): transport = 'librabbitmq' @@ -245,18 +265,20 @@ class Connection: def register_with_event_loop(self, loop): self.transport.register_with_event_loop(self.connection, loop) - def _debug(self, msg, *args, **kwargs): + def _debug(self, msg: str, *args, **kwargs) -> None: if self._logger: # pragma: no cover fmt = '[Kombu connection:{id:#x}] {msg}' logger.debug(fmt.format(id=id(self), msg=str(msg)), *args, **kwargs) - def connect(self): + async def connect(self) -> None: """Establish connection to server immediately.""" self._closed = False - return self.connection + if self._connection is None: + await self._start_connection() + return self._connection - def channel(self): + def channel(self) -> ChannelT: """Create and return a new channel.""" self._debug('create channel') chan = self.transport.create_channel(self.connection) @@ -266,7 +288,7 @@ class Connection: '[Kombu channel:{0.channel_id}] ') return chan - def heartbeat_check(self, rate=2): + async def heartbeat_check(self, rate: int = 2) -> None: """Check heartbeats. Allow the transport to perform any periodic tasks @@ -283,9 +305,9 @@ class Connection: is called every 3 / 2 seconds, then the rate is 2. This value is currently unused by any transports. """ - return self.transport.heartbeat_check(self.connection, rate=rate) + await self.transport.heartbeat_check(self.connection, rate=rate) - def drain_events(self, **kwargs): + async def drain_events(self, timeout: float = None, **kwargs) -> None: """Wait for a single event from the server. Arguments: @@ -294,30 +316,31 @@ class Connection: Raises: socket.timeout: if the timeout is exceeded. """ - return self.transport.drain_events(self.connection, **kwargs) + await self.transport.drain_events( + self.connection, timeout=timeout, **kwargs) - def maybe_close_channel(self, channel): + async def maybe_close_channel(self, channel: ChannelT) -> None: """Close given channel, but ignore connection and channel errors.""" try: - channel.close() + await channel.close() except (self.connection_errors + self.channel_errors): pass - def _do_close_self(self): + async def _do_close_self(self): # Close only connection and channel(s), but not transport. self.declared_entities.clear() if self._default_channel: - self.maybe_close_channel(self._default_channel) + await self.maybe_close_channel(self._default_channel) if self._connection: try: - self.transport.close_connection(self._connection) + await self.transport.close_connection(self._connection) except self.connection_errors + (AttributeError, socket.error): pass self._connection = None - def _close(self): + async def _close(self): """Really close connection, even if part of a connection pool.""" - self._do_close_self() + await self._do_close_self() self._do_close_transport() self._debug('closed') self._closed = True @@ -327,7 +350,7 @@ class Connection: self._transport.client = None self._transport = None - def collect(self, socket_timeout=None): + async def collect(self, socket_timeout=None): # amqp requires communication to close, we don't need that just # to clear out references, Transport._collect can also be implemented # by other transports that want fast after fork @@ -337,7 +360,7 @@ class Connection: _timeo = socket.getdefaulttimeout() socket.setdefaulttimeout(socket_timeout) try: - self._do_close_self() + await self._do_close_self() except socket.timeout: pass finally: @@ -349,14 +372,22 @@ class Connection: self.declared_entities.clear() self._connection = None - def release(self): + async def release(self): """Close the connection (if open).""" - self._close() - close = release + await self._close() + + async def close(self): + await self._close() - def ensure_connection(self, errback=None, max_retries=None, - interval_start=2, interval_step=2, interval_max=30, - callback=None, reraise_as_library_errors=True): + async def ensure_connection( + self, + errback=None, + max_retries=None, + interval_start=2, + interval_step=2, + interval_max=30, + callback=None, + reraise_as_library_errors=True): """Ensure we have a connection to the server. If not retry establishing the connection with the settings @@ -395,10 +426,12 @@ class Connection: if not reraise_as_library_errors: ctx = self._dummy_context with ctx(): - retry_over_time(self.connect, self.recoverable_connection_errors, - (), {}, on_error, max_retries, - interval_start, interval_step, interval_max, - callback) + await retry_over_time( + self.connect, self.recoverable_connection_errors, + (), {}, on_error, max_retries, + interval_start, interval_step, interval_max, + callback, + ) return self @contextmanager @@ -423,19 +456,23 @@ class Connection: """Return true if the cycle is complete after number of `retries`.""" return not (retries + 1) % len(self.alt) if self.alt else True - def revive(self, new_channel): + async def revive(self, new_channel): """Revive connection after connection re-established.""" if self._default_channel and new_channel is not self._default_channel: - self.maybe_close_channel(self._default_channel) + await self.maybe_close_channel(self._default_channel) self._default_channel = None def _default_ensure_callback(self, exc, interval): logger.error("Ensure: Operation error: %r. Retry in %ss", exc, interval, exc_info=True) - def ensure(self, obj, fun, errback=None, max_retries=None, - interval_start=1, interval_step=1, interval_max=1, - on_revive=None): + async def ensure(self, obj, fun, + errback=None, + max_retries=None, + interval_start=1, + interval_step=1, + interval_max=1, + on_revive=None): """Ensure operation completes. Regardless of any channel/connection errors occurring. @@ -475,7 +512,7 @@ class Connection: ... errback=errback, max_retries=3) >>> publish({'hello': 'world'}, routing_key='dest') """ - def _ensured(*args, **kwargs): + async def _ensured(*args, **kwargs): got_connection = 0 conn_errors = self.recoverable_connection_errors chan_errors = self.recoverable_channel_errors @@ -485,7 +522,7 @@ class Connection: with self._reraise_as_library_errors(): for retries in count(0): # for infinity try: - return fun(*args, **kwargs) + return await fun(*args, **kwargs) except conn_errors as exc: if got_connection and not has_modern_errors: # transport can not distinguish between @@ -497,21 +534,21 @@ class Connection: raise self._debug('ensure connection error: %r', exc, exc_info=1) - self.collect() + await self.collect() errback and errback(exc, 0) remaining_retries = None if max_retries is not None: remaining_retries = max(max_retries - retries, 1) - self.ensure_connection( + await self.ensure_connection( errback, remaining_retries, interval_start, interval_step, interval_max, reraise_as_library_errors=False, ) channel = self.default_channel - obj.revive(channel) + await obj.revive(channel) if on_revive: - on_revive(channel) + await on_revive(channel) got_connection += 1 except chan_errors as exc: if max_retries is not None and retries > max_retries: @@ -641,7 +678,7 @@ class Connection: sanitize=not include_password, mask=mask, ) - def Pool(self, limit=None, **kwargs): + def Pool(self, limit=None, **kwargs) -> ResourceT: """Pool of connections. See Also: @@ -667,7 +704,7 @@ class Connection: """ return ConnectionPool(self, limit, **kwargs) - def ChannelPool(self, limit=None, **kwargs): + def ChannelPool(self, limit=None, **kwargs) -> ResourceT: """Pool of channels. See Also: @@ -746,9 +783,9 @@ class Connection: return SimpleBuffer(channel or self, name, no_ack, queue_opts, exchange_opts, **kwargs) - def _establish_connection(self): + async def _establish_connection(self): self._debug('establishing connection...') - conn = self.transport.establish_connection() + conn = await self.transport.establish_connection() self._debug('connection established: %r', self) return conn @@ -765,11 +802,19 @@ class Connection: return self.__class__, tuple(self.info().values()), None def __enter__(self): + self.connect() return self def __exit__(self, *args): self.release() + async def __aenter__(self): + await self.connect() + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.release() + @property def qos_semantics_matches_spec(self): return self.transport.qos_semantics_matches_spec(self.connection) @@ -781,6 +826,12 @@ class Connection: self._connection is not None and self.transport.verify_connection(self._connection)) + async def _start_connection(self): + self.declared_entities.clear() + self._default_channel = None + self._connection = await self._establish_connection() + self._closed = False + @property def connection(self): """The underlying connection object. @@ -789,13 +840,7 @@ class Connection: This instance is transport specific, so do not depend on the interface of this object. """ - if not self._closed: - if not self.connected: - self.declared_entities.clear() - self._default_channel = None - self._connection = self._establish_connection() - self._closed = False - return self._connection + return self._connection @property def default_channel(self): diff --git a/kombu/entity.py b/kombu/entity.py index 52ae605e..80522b55 100644 --- a/kombu/entity.py +++ b/kombu/entity.py @@ -1,42 +1,46 @@ """Exchange and Queue declarations.""" import numbers - -from .abstract import MaybeChannelBound +from typing import ( + Any, Callable, Dict, Mapping, Set, Sequence, Tuple, Union, +) +from amqp.protocol import queue_declare_ok_t +from amqp.types import ChannelT +from .types import EntityT from .exceptions import ContentDisallowed from .serialization import prepare_accept_content -from .utils import abstract - -TRANSIENT_DELIVERY_MODE = 1 -PERSISTENT_DELIVERY_MODE = 2 -DELIVERY_MODES = {'transient': TRANSIENT_DELIVERY_MODE, - 'persistent': PERSISTENT_DELIVERY_MODE} +from .types import BindingT, ExchangeT, MessageT, QueueT __all__ = ['Exchange', 'Queue', 'binding', 'maybe_delivery_mode'] -INTERNAL_EXCHANGE_PREFIX = ('amq.',) +TRANSIENT_DELIVERY_MODE = 1 +PERSISTENT_DELIVERY_MODE = 2 +DELIVERY_MODES: Mapping[str, int] = { + 'transient': TRANSIENT_DELIVERY_MODE, + 'persistent': PERSISTENT_DELIVERY_MODE, +} +INTERNAL_EXCHANGE_PREFIX: Tuple[str] = ('amq.',) -def _reprstr(s): - s = repr(s) - if isinstance(s, str) and s.startswith("u'"): - return s[2:-1] - return s[1:-1] +def _reprstr(s: Any) -> str: + return repr(s).strip("'") -def pretty_bindings(bindings): +def pretty_bindings(bindings: Sequence) -> str: return '[{0}]'.format(', '.join(map(str, bindings))) def maybe_delivery_mode( - v, modes=DELIVERY_MODES, default=PERSISTENT_DELIVERY_MODE): + v: Union[numbers.Integral, str], + *, + modes: Mapping[str, int] = DELIVERY_MODES, + default: int = PERSISTENT_DELIVERY_MODE) -> int: """Get delivery mode by name (or none if undefined).""" if v: return v if isinstance(v, numbers.Integral) else modes[v] return default -@abstract.Entity.register -class Exchange(MaybeChannelBound): +class Exchange(ExchangeT): """An Exchange declaration. Arguments: @@ -137,7 +141,7 @@ class Exchange(MaybeChannelBound): durable = True auto_delete = False passive = False - delivery_mode = None + delivery_mode = None # type: int no_declare = False attrs = ( @@ -151,21 +155,29 @@ class Exchange(MaybeChannelBound): ('no_declare', bool), ) - def __init__(self, name='', type='', channel=None, **kwargs): + def __init__(self, + name: str = '', + type: str = '', + channel: ChannelT = None, + **kwargs) -> None: super().__init__(**kwargs) self.name = name or self.name self.type = type or self.type self.maybe_bind(channel) - def __hash__(self): + def __hash__(self) -> int: return hash('E|%s' % (self.name,)) - def _can_declare(self): + def _can_declare(self) -> bool: return not self.no_declare and ( self.name and not self.name.startswith( INTERNAL_EXCHANGE_PREFIX)) - def declare(self, nowait=False, passive=None, channel=None): + async def declare( + self, + nowait: bool = False, + passive: bool = None, + channel: ChannelT = None) -> None: """Declare the exchange. Creates the exchange on the broker, unless passive is set @@ -177,14 +189,20 @@ class Exchange(MaybeChannelBound): """ if self._can_declare(): passive = self.passive if passive is None else passive - return (channel or self.channel).exchange_declare( + await (channel or self.channel).exchange_declare( exchange=self.name, type=self.type, durable=self.durable, auto_delete=self.auto_delete, arguments=self.arguments, nowait=nowait, passive=passive, ) - def bind_to(self, exchange='', routing_key='', - arguments=None, nowait=False, channel=None, **kwargs): + async def bind_to( + self, + exchange: Union[str, ExchangeT] = '', + routing_key: str = '', + arguments: Mapping = None, + nowait: bool = False, + channel: ChannelT = None, + **kwargs) -> None: """Bind the exchange to another exchange. Arguments: @@ -192,9 +210,9 @@ class Exchange(MaybeChannelBound): will not block waiting for a response. Default is :const:`False`. """ - if isinstance(exchange, Exchange): + if isinstance(exchange, ExchangeT): exchange = exchange.name - return (channel or self.channel).exchange_bind( + await (channel or self.channel).exchange_bind( destination=self.name, source=exchange, routing_key=routing_key, @@ -202,12 +220,17 @@ class Exchange(MaybeChannelBound): arguments=arguments, ) - def unbind_from(self, source='', routing_key='', - nowait=False, arguments=None, channel=None): + async def unbind_from( + self, + source: Union[str, ExchangeT] = '', + routing_key: str = '', + nowait: bool = False, + arguments: Mapping = None, + channel: ChannelT = None) -> None: """Delete previously created exchange binding from the server.""" - if isinstance(source, Exchange): + if isinstance(source, ExchangeT): source = source.name - return (channel or self.channel).exchange_unbind( + await (channel or self.channel).exchange_unbind( destination=self.name, source=source, routing_key=routing_key, @@ -215,43 +238,11 @@ class Exchange(MaybeChannelBound): arguments=arguments, ) - def Message(self, body, delivery_mode=None, properties=None, **kwargs): - """Create message instance to be sent with :meth:`publish`. - - Arguments: - body (Any): Message body. - - delivery_mode (bool): Set custom delivery mode. - Defaults to :attr:`delivery_mode`. - - priority (int): Message priority, 0 to broker configured - max priority, where higher is better. - - content_type (str): The messages content_type. If content_type - is set, no serialization occurs as it is assumed this is either - a binary object, or you've done your own serialization. - Leave blank if using built-in serialization as our library - properly sets content_type. - - content_encoding (str): The character set in which this object - is encoded. Use "binary" if sending in raw binary objects. - Leave blank if using built-in serialization as our library - properly sets content_encoding. - - properties (Dict): Message properties. - - headers (Dict): Message headers. - """ - # XXX This method is unused by kombu itself AFAICT [ask]. - properties = {} if properties is None else properties - properties['delivery_mode'] = maybe_delivery_mode(self.delivery_mode) - return self.channel.prepare_message( - body, - properties=properties, - **kwargs) - - def publish(self, message, routing_key=None, mandatory=False, - immediate=False, exchange=None): + async def publish(self, message: Union[MessageT, str, bytes], + routing_key: str = None, + mandatory: bool = False, + immediate: bool = False, + exchange: str = None) -> None: """Publish message. Arguments: @@ -264,7 +255,7 @@ class Exchange(MaybeChannelBound): if isinstance(message, str): message = self.Message(message) exchange = exchange or self.name - return self.channel.basic_publish( + await self.channel.basic_publish( message, exchange=exchange, routing_key=routing_key, @@ -272,7 +263,10 @@ class Exchange(MaybeChannelBound): immediate=immediate, ) - def delete(self, if_unused=False, nowait=False): + async def delete( + self, + if_unused: bool = False, + nowait: bool = False) -> None: """Delete the exchange declaration on server. Arguments: @@ -281,14 +275,20 @@ class Exchange(MaybeChannelBound): nowait (bool): If set the server will not respond, and a response will not be waited for. Default is :const:`False`. """ - return self.channel.exchange_delete(exchange=self.name, - if_unused=if_unused, - nowait=nowait) + await self.channel.exchange_delete( + exchange=self.name, + if_unused=if_unused, + nowait=nowait, + ) - def binding(self, routing_key='', arguments=None, unbind_arguments=None): + def binding( + self, + routing_key: str = '', + arguments: Mapping = None, + unbind_arguments: Mapping = None) -> 'binding': return binding(self, routing_key, arguments, unbind_arguments) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if isinstance(other, Exchange): return (self.name == other.name and self.type == other.type and @@ -298,28 +298,70 @@ class Exchange(MaybeChannelBound): self.delivery_mode == other.delivery_mode) return NotImplemented - def __ne__(self, other): - return not self.__eq__(other) + def __ne__(self, other: Any) -> bool: + ret = self.__eq__(other) + if ret is NotImplemented: + return ret + return not ret - def __repr__(self): + def __repr__(self) -> str: return self._repr_entity(self) - def __str__(self): - return 'Exchange {0}({1})'.format( + def __str__(self) -> str: + return '{name} {0}({1})'.format( _reprstr(self.name) or repr(''), self.type, + name=type(self).__name__, ) + def Message( + self, body: Any, + delivery_mode: Union[str, int] = None, + properties: Mapping = None, + **kwargs) -> Any: + """Create message instance to be sent with :meth:`publish`. + + Arguments: + body (Any): Message body. + + delivery_mode (Union[str, int]): Set custom delivery mode. + Defaults to :attr:`delivery_mode`. + + priority (int): Message priority, 0 to broker configured + max priority, where higher is better. + + content_type (str): The messages content_type. If content_type + is set, no serialization occurs as it is assumed this is either + a binary object, or you've done your own serialization. + Leave blank if using built-in serialization as our library + properly sets content_type. + + content_encoding (str): The character set in which this object + is encoded. Use "binary" if sending in raw binary objects. + Leave blank if using built-in serialization as our library + properly sets content_encoding. + + properties (Dict): Message properties. + + headers (Dict): Message headers. + """ + # XXX This method is unused by kombu itself AFAICT [ask]. + properties = {} if properties is None else properties + properties['delivery_mode'] = maybe_delivery_mode(self.delivery_mode) + return self.channel.prepare_message( + body, + properties=properties, + **kwargs) + @property - def can_cache_declaration(self): + def can_cache_declaration(self) -> bool: return not self.auto_delete -@abstract.Entity.register -class binding: +class binding(BindingT): """Represents a queue or exchange binding. Arguments: - exchange (Exchange): Exchange to bind to. + exchange (ExchangeT): Exchange to bind to. routing_key (str): Routing key used as binding key. arguments (Dict): Arguments for bind operation. unbind_arguments (Dict): Arguments for unbind operation. @@ -332,45 +374,58 @@ class binding: ('unbind_arguments', None) ) - def __init__(self, exchange=None, routing_key='', - arguments=None, unbind_arguments=None): + def __init__( + self, + exchange: ExchangeT = None, + routing_key: str = '', + arguments: Mapping = None, + unbind_arguments: Mapping = None) -> None: self.exchange = exchange self.routing_key = routing_key self.arguments = arguments self.unbind_arguments = unbind_arguments - def declare(self, channel, nowait=False): + def declare(self, channel: ChannelT, nowait: bool = False) -> None: """Declare destination exchange.""" if self.exchange and self.exchange.name: self.exchange.declare(channel=channel, nowait=nowait) - def bind(self, entity, nowait=False, channel=None): + def bind( + self, entity: EntityT, + nowait: bool = False, + channel: ChannelT = None) -> None: """Bind entity to this binding.""" - entity.bind_to(exchange=self.exchange, - routing_key=self.routing_key, - arguments=self.arguments, - nowait=nowait, - channel=channel) + entity.bind_to( + exchange=self.exchange, + routing_key=self.routing_key, + arguments=self.arguments, + nowait=nowait, + channel=channel, + ) - def unbind(self, entity, nowait=False, channel=None): + def unbind( + self, entity: EntityT, + nowait: bool = False, + channel: ChannelT = None) -> None: """Unbind entity from this binding.""" - entity.unbind_from(self.exchange, - routing_key=self.routing_key, - arguments=self.unbind_arguments, - nowait=nowait, - channel=channel) + entity.unbind_from( + self.exchange, + routing_key=self.routing_key, + arguments=self.unbind_arguments, + nowait=nowait, + channel=channel + ) - def __repr__(self): + def __repr__(self) -> str: return '<binding: {0}>'.format(self) - def __str__(self): + def __str__(self) -> str: return '{0}->{1}'.format( _reprstr(self.exchange.name), _reprstr(self.routing_key), ) -@abstract.Entity.register -class Queue(MaybeChannelBound): +class Queue(QueueT): """A Queue declaration. Arguments: @@ -561,9 +616,15 @@ class Queue(MaybeChannelBound): ('max_priority', int) ) - def __init__(self, name='', exchange=None, routing_key='', - channel=None, bindings=None, on_declared=None, - **kwargs): + def __init__( + self, + name: str = '', + exchange: ExchangeT = None, + routing_key: str = '', + channel: ChannelT = None, + bindings: Sequence[BindingT] = None, + on_declared: Callable = None, + **kwargs) -> None: super().__init__(**kwargs) self.name = name or self.name self.exchange = exchange or self.exchange @@ -582,44 +643,56 @@ class Queue(MaybeChannelBound): self.auto_delete = True self.maybe_bind(channel) - def bind(self, channel): + def bind(self, channel: ChannelT) -> EntityT: on_declared = self.on_declared bound = super().bind(channel) bound.on_declared = on_declared return bound - def __hash__(self): + def __hash__(self) -> int: return hash('Q|%s' % (self.name,)) - def when_bound(self): + def when_bound(self) -> None: if self.exchange: self.exchange = self.exchange(self.channel) - def declare(self, nowait=False, channel=None): + async def declare(self, + nowait: bool = False, + channel: ChannelT = None) -> str: """Declare queue and exchange then binds queue to exchange.""" if not self.no_declare: # - declare main binding. - self._create_exchange(nowait=nowait, channel=channel) - self._create_queue(nowait=nowait, channel=channel) - self._create_bindings(nowait=nowait, channel=channel) + await self._create_exchange(nowait=nowait, channel=channel) + await self._create_queue(nowait=nowait, channel=channel) + await self._create_bindings(nowait=nowait, channel=channel) return self.name - def _create_exchange(self, nowait=False, channel=None): + async def _create_exchange(self, + nowait: bool = False, + channel: ChannelT = None) -> None: if self.exchange: - self.exchange.declare(nowait=nowait, channel=channel) + await self.exchange.declare(nowait=nowait, channel=channel) - def _create_queue(self, nowait=False, channel=None): - self.queue_declare(nowait=nowait, passive=False, channel=channel) + async def _create_queue(self, + nowait: bool = False, + channel: ChannelT = None) -> None: + await self.queue_declare(nowait=nowait, passive=False, channel=channel) if self.exchange and self.exchange.name: - self.queue_bind(nowait=nowait, channel=channel) + await self.queue_bind(nowait=nowait, channel=channel) - def _create_bindings(self, nowait=False, channel=None): + async def _create_bindings(self, + nowait: bool = False, + channel: ChannelT = None) -> None: for B in self.bindings: channel = channel or self.channel - B.declare(channel) - B.bind(self, nowait=nowait, channel=channel) - - def queue_declare(self, nowait=False, passive=False, channel=None): + await B.declare(channel) + await B.bind(self, nowait=nowait, channel=channel) + + async def queue_declare( + self, + nowait: bool = False, + passive: bool = False, + channel: ChannelT = None) -> queue_declare_ok_t: """Declare queue on the server. Arguments: @@ -637,7 +710,7 @@ class Queue(MaybeChannelBound): max_length_bytes=self.max_length_bytes, max_priority=self.max_priority, ) - ret = channel.queue_declare( + ret = await channel.queue_declare( queue=self.name, passive=passive, durable=self.durable, @@ -652,18 +725,26 @@ class Queue(MaybeChannelBound): self.on_declared(*ret) return ret - def queue_bind(self, nowait=False, channel=None): + async def queue_bind( + self, + nowait: bool = False, + channel: ChannelT = None) -> None: """Create the queue binding on the server.""" - return self.bind_to(self.exchange, self.routing_key, - self.binding_arguments, - channel=channel, nowait=nowait) - - def bind_to(self, exchange='', routing_key='', - arguments=None, nowait=False, channel=None): + await self.bind_to( + self.exchange, self.routing_key, self.binding_arguments, + channel=channel, nowait=nowait) + + async def bind_to( + self, + exchange: Union[str, ExchangeT] = '', + routing_key: str = '', + arguments: Mapping = None, + nowait: bool = False, + channel: ChannelT = None) -> None: if isinstance(exchange, Exchange): exchange = exchange.name - return (channel or self.channel).queue_bind( + await (channel or self.channel).queue_bind( queue=self.name, exchange=exchange, routing_key=routing_key, @@ -671,7 +752,10 @@ class Queue(MaybeChannelBound): nowait=nowait, ) - def get(self, no_ack=None, accept=None): + async def get( + self, + no_ack: bool = None, + accept: Set[str] = None) -> MessageT: """Poll the server for a new message. This method provides direct access to the messages in a @@ -686,10 +770,10 @@ class Queue(MaybeChannelBound): Arguments: no_ack (bool): If enabled the broker will automatically ack messages. - accept (Set[str]): Custom list of accepted content types. + accept (Container): Custom list of accepted content types. """ no_ack = self.no_ack if no_ack is None else no_ack - message = self.channel.basic_get(queue=self.name, no_ack=no_ack) + message = await self.channel.basic_get(queue=self.name, no_ack=no_ack) if message is not None: m2p = getattr(self.channel, 'message_to_python', None) if m2p: @@ -699,13 +783,17 @@ class Queue(MaybeChannelBound): message.accept = prepare_accept_content(accept) return message - def purge(self, nowait=False): + async def purge(self, nowait: bool = False) -> int: """Remove all ready messages from the queue.""" return self.channel.queue_purge(queue=self.name, nowait=nowait) or 0 - def consume(self, consumer_tag='', callback=None, - no_ack=None, nowait=False): + async def consume( + self, + consumer_tag: str = '', + callback: Callable = None, + no_ack: bool = None, + nowait: bool = False) -> None: """Start a queue consumer. Consumers last as long as the channel they were created on, or @@ -726,7 +814,7 @@ class Queue(MaybeChannelBound): """ if no_ack is None: no_ack = self.no_ack - return self.channel.basic_consume( + await self.channel.basic_consume( queue=self.name, no_ack=no_ack, consumer_tag=consumer_tag or '', @@ -734,11 +822,15 @@ class Queue(MaybeChannelBound): nowait=nowait, arguments=self.consumer_arguments) - def cancel(self, consumer_tag): + async def cancel(self, consumer_tag: str) -> None: """Cancel a consumer by consumer tag.""" - return self.channel.basic_cancel(consumer_tag) + await self.channel.basic_cancel(consumer_tag) - def delete(self, if_unused=False, if_empty=False, nowait=False): + async def delete( + self, + if_unused: bool = False, + if_empty: bool = False, + nowait: bool = False) -> None: """Delete the queue. Arguments: @@ -751,19 +843,30 @@ class Queue(MaybeChannelBound): nowait (bool): Do not wait for a reply. """ - return self.channel.queue_delete(queue=self.name, - if_unused=if_unused, - if_empty=if_empty, - nowait=nowait) - - def queue_unbind(self, arguments=None, nowait=False, channel=None): - return self.unbind_from(self.exchange, self.routing_key, - arguments, nowait, channel) + await self.channel.queue_delete( + queue=self.name, + if_unused=if_unused, + if_empty=if_empty, + nowait=nowait, + ) - def unbind_from(self, exchange='', routing_key='', - arguments=None, nowait=False, channel=None): + async def queue_unbind( + self, + arguments: Mapping = None, + nowait: bool = False, + channel: ChannelT = None) -> None: + await self.unbind_from(self.exchange, self.routing_key, + arguments, nowait, channel) + + async def unbind_from( + self, + exchange: str = '', + routing_key: str = '', + arguments: Mapping = None, + nowait: bool = False, + channel: ChannelT = None) -> None: """Unbind queue by deleting the binding from the server.""" - return (channel or self.channel).queue_unbind( + await (channel or self.channel).queue_unbind( queue=self.name, exchange=exchange.name, routing_key=routing_key, @@ -771,7 +874,7 @@ class Queue(MaybeChannelBound): nowait=nowait, ) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if isinstance(other, Queue): return (self.name == other.name and self.exchange == other.exchange and @@ -785,9 +888,12 @@ class Queue(MaybeChannelBound): return NotImplemented def __ne__(self, other): - return not self.__eq__(other) + ret = self.__eq__(other) + if ret is NotImplemented: + return ret + return not ret - def __repr__(self): + def __repr__(self) -> str: if self.bindings: return self._repr_entity('Queue {name} -> {bindings}'.format( name=_reprstr(self.name), @@ -801,30 +907,31 @@ class Queue(MaybeChannelBound): ) @property - def can_cache_declaration(self): + def can_cache_declaration(self) -> bool: return not self.auto_delete @classmethod - def from_dict(self, queue, - exchange=None, - exchange_type=None, - binding_key=None, - routing_key=None, - delivery_mode=None, - bindings=None, - durable=None, - queue_durable=None, - exchange_durable=None, - auto_delete=None, - queue_auto_delete=None, - exchange_auto_delete=None, - exchange_arguments=None, - queue_arguments=None, - binding_arguments=None, - consumer_arguments=None, - exclusive=None, - no_ack=None, - **options): + def from_dict( + self, queue: str, + exchange: str = None, + exchange_type: str = None, + binding_key: str = None, + routing_key: str = None, + delivery_mode: Union[int, str] = None, + bindings: Sequence = None, + durable: bool = None, + queue_durable: bool = None, + exchange_durable: bool = None, + auto_delete: bool = None, + queue_auto_delete: bool = None, + exchange_auto_delete: bool = None, + exchange_arguments: Mapping = None, + queue_arguments: Mapping = None, + binding_arguments: Mapping = None, + consumer_arguments: Mapping = None, + exclusive: bool = None, + no_ack: bool = None, + **options) -> QueueT: return Queue( queue, exchange=Exchange( @@ -850,7 +957,7 @@ class Queue(MaybeChannelBound): bindings=bindings, ) - def as_dict(self, recurse=False): + def as_dict(self, recurse: bool = False) -> Dict: res = super().as_dict(recurse) if not recurse: return res diff --git a/kombu/message.py b/kombu/message.py index dd6f0b4e..9d74571b 100644 --- a/kombu/message.py +++ b/kombu/message.py @@ -1,10 +1,12 @@ """Message class.""" +import logging import sys - +from typing import Any, Callable, List, Mapping, Set, Tuple +from amqp.types import ChannelT +from . import types from .compression import decompress from .exceptions import MessageStateError from .serialization import loads -from .utils import abstract from .utils.functional import dictfilter __all__ = ['Message'] @@ -13,7 +15,7 @@ ACK_STATES = {'ACK', 'REJECTED', 'REQUEUED'} IS_PYPY = hasattr(sys, 'pypy_version_info') -@abstract.Message.register +@types.MessageT.register class Message: """Base class for received messages. @@ -48,7 +50,7 @@ class Message: MessageStateError = MessageStateError - errors = None + errors: List[Any] = None if not IS_PYPY: # pragma: no cover __slots__ = ( @@ -58,10 +60,19 @@ class Message: 'body', '_decoded_cache', 'accept', '__dict__', ) - def __init__(self, body=None, delivery_tag=None, - content_type=None, content_encoding=None, delivery_info={}, - properties=None, headers=None, postencode=None, - accept=None, channel=None, **kwargs): + def __init__( + self, + body: Any = None, + delivery_tag: str = None, + content_type: str = None, + content_encoding: str = None, + delivery_info: Mapping = None, + properties: Mapping = None, + headers: Mapping = None, + postencode: str = None, + accept: Set[str] = None, + channel: ChannelT = None, + **kwargs) -> None: self.errors = [] if self.errors is None else self.errors self.channel = channel self.delivery_tag = delivery_tag @@ -88,7 +99,7 @@ class Message: self.errors.append(sys.exc_info()) self.body = body - def _reraise_error(self, callback=None): + def _reraise_error(self, callback: Callable = None) -> None: try: raise self.errors[0][1].with_traceback(self.errors[0][2]) except Exception as exc: @@ -96,7 +107,7 @@ class Message: raise callback(self, exc) - def ack(self, multiple=False): + async def ack(self, multiple: bool = False) -> None: """Acknowledge this message as being processed. This will remove the message from the queue. @@ -120,24 +131,28 @@ class Message: raise self.MessageStateError( 'Message already acknowledged with state: {0._state}'.format( self)) - self.channel.basic_ack(self.delivery_tag, multiple=multiple) + await self.channel.basic_ack(self.delivery_tag, multiple=multiple) self._state = 'ACK' - def ack_log_error(self, logger, errors, multiple=False): + async def ack_log_error( + self, logger: logging.Logger, errors: Tuple[type, ...], + multiple: bool = False) -> None: try: - self.ack(multiple=multiple) + await self.ack(multiple=multiple) except errors as exc: logger.critical("Couldn't ack %r, reason:%r", self.delivery_tag, exc, exc_info=True) - def reject_log_error(self, logger, errors, requeue=False): + async def reject_log_error( + self, logger: logging.Logger, errors: Tuple[type, ...], + requeue: bool = False) -> None: try: - self.reject(requeue=requeue) + await self.reject(requeue=requeue) except errors as exc: logger.critical("Couldn't reject %r, reason: %r", self.delivery_tag, exc, exc_info=True) - def reject(self, requeue=False): + async def reject(self, requeue: bool = False) -> None: """Reject this message. The message will be discarded by the server. @@ -153,10 +168,10 @@ class Message: raise self.MessageStateError( 'Message already acknowledged with state: {0._state}'.format( self)) - self.channel.basic_reject(self.delivery_tag, requeue=requeue) + await self.channel.basic_reject(self.delivery_tag, requeue=requeue) self._state = 'REJECTED' - def requeue(self): + async def requeue(self) -> None: """Reject this message and put it back on the queue. Warning: @@ -174,10 +189,10 @@ class Message: raise self.MessageStateError( 'Message already acknowledged with state: {0._state}'.format( self)) - self.channel.basic_reject(self.delivery_tag, requeue=True) + await self.channel.basic_reject(self.delivery_tag, requeue=True) self._state = 'REQUEUED' - def decode(self): + def decode(self) -> Any: """Deserialize the message body. Returning the original python structure sent by the publisher. @@ -190,21 +205,21 @@ class Message: self._decoded_cache = self._decode() return self._decoded_cache - def _decode(self): + def _decode(self) -> Any: return loads(self.body, self.content_type, self.content_encoding, accept=self.accept) @property - def acknowledged(self): + def acknowledged(self) -> bool: """Set to true if the message has been acknowledged.""" return self._state in ACK_STATES @property - def payload(self): + def payload(self) -> Any: """The decoded message body.""" return self._decoded_cache if self._decoded_cache else self.decode() - def __repr__(self): + def __repr__(self) -> str: return '<{0} object at {1:#x} with details {2!r}>'.format( type(self).__name__, id(self), dictfilter( state=self._state, diff --git a/kombu/messaging.py b/kombu/messaging.py index 3dd03d14..67ccd15d 100644 --- a/kombu/messaging.py +++ b/kombu/messaging.py @@ -1,20 +1,23 @@ """Sending and receiving messages.""" +import numbers from itertools import count - +from typing import Any, Callable, Mapping, Sequence, Tuple, Union +from amqp import ChannelT +from . import types +from .abstract import Entity from .common import maybe_declare from .compression import compress from .connection import maybe_channel, is_connection from .entity import Exchange, Queue, maybe_delivery_mode from .exceptions import ContentDisallowed from .serialization import dumps, prepare_accept_content -from .utils import abstract +from .types import ChannelArgT, ClientT, MessageT, QueueT from .utils.functional import ChannelPromise, maybe_list __all__ = ['Exchange', 'Queue', 'Producer', 'Consumer'] -@abstract.Producer.register -class Producer: +class Producer(types.ProducerT): """Message Producer. Arguments: @@ -34,7 +37,7 @@ class Producer: """ #: Default exchange - exchange = None + exchange = None # type: Exchange #: Default routing key. routing_key = '' @@ -56,9 +59,16 @@ class Producer: #: default_channel). __connection__ = None - def __init__(self, channel, exchange=None, routing_key=None, - serializer=None, auto_declare=None, compression=None, - on_return=None): + _channel: ChannelT + + def __init__(self, + channel: ChannelArgT, + exchange: Exchange = None, + routing_key: str = None, + serializer: str = None, + auto_declare: bool = None, + compression: str = None, + on_return: Callable = None): self._channel = channel self.exchange = exchange self.routing_key = routing_key or self.routing_key @@ -71,20 +81,17 @@ class Producer: if auto_declare is not None: self.auto_declare = auto_declare - if self._channel: - self.revive(self._channel) - - def __repr__(self): + def __repr__(self) -> str: return '<Producer: {0._channel}>'.format(self) - def __reduce__(self): + def __reduce__(self) -> Tuple[Any, ...]: return self.__class__, self.__reduce_args__() - def __reduce_args__(self): + def __reduce_args__(self) -> Tuple[Any, ...]: return (None, self.exchange, self.routing_key, self.serializer, self.auto_declare, self.compression) - def declare(self): + async def declare(self) -> None: """Declare the exchange. Note: @@ -92,16 +99,18 @@ class Producer: the :attr:`auto_declare` flag is enabled. """ if self.exchange.name: - self.exchange.declare() + await self.exchange.declare() - def maybe_declare(self, entity, retry=False, **retry_policy): + async def maybe_declare(self, entity: Entity, + retry: bool = False, **retry_policy) -> None: """Declare exchange if not already declared during this session.""" if entity: - return maybe_declare(entity, self.channel, retry, **retry_policy) + await maybe_declare(entity, self.channel, retry, **retry_policy) - def _delivery_details(self, exchange, delivery_mode=None, - maybe_delivery_mode=maybe_delivery_mode, - Exchange=Exchange): + def _delivery_details(self, exchange: Union[Exchange, str], + delivery_mode: Union[int, str]=None, + maybe_delivery_mode: Callable = maybe_delivery_mode, + Exchange: Any = Exchange) -> Tuple[str, int]: if isinstance(exchange, Exchange): return exchange.name, maybe_delivery_mode( delivery_mode or exchange.delivery_mode, @@ -112,12 +121,23 @@ class Producer: delivery_mode or self.exchange.delivery_mode, ) - def publish(self, body, routing_key=None, delivery_mode=None, - mandatory=False, immediate=False, priority=0, - content_type=None, content_encoding=None, serializer=None, - headers=None, compression=None, exchange=None, retry=False, - retry_policy=None, declare=None, expiration=None, - **properties): + async def publish(self, body: Any, + routing_key: str = None, + delivery_mode: Union[int, str] = None, + mandatory: bool = False, + immediate: bool = False, + priority: int = 0, + content_type: str = None, + content_encoding: str = None, + serializer: str = None, + headers: Mapping = None, + compression: str = None, + exchange: Union[Exchange, str] = None, + retry: bool = False, + retry_policy: Mapping = None, + declare: Sequence[Entity] = None, + expiration: numbers.Number = None, + **properties) -> None: """Publish message to the specified exchange. Arguments: @@ -172,54 +192,79 @@ class Producer: declare.append(self.exchange) if retry: - _publish = self.connection.ensure(self, _publish, **retry_policy) - return _publish( + _publish = await self.connection.ensure( + self, _publish, **retry_policy) + await _publish( body, priority, content_type, content_encoding, headers, properties, routing_key, mandatory, immediate, exchange_name, declare, ) - def _publish(self, body, priority, content_type, content_encoding, - headers, properties, routing_key, mandatory, - immediate, exchange, declare): - channel = self.channel + async def _publish(self, + body: Any, + priority: int = None, + content_type: str = None, + content_encoding: str = None, + headers: Mapping = None, + properties: Mapping = None, + routing_key: str = None, + mandatory: bool = None, + immediate: bool = None, + exchange: str = None, + declare: Sequence = None) -> None: + channel = await self._resolve_channel() message = channel.prepare_message( body, priority, content_type, content_encoding, headers, properties, ) if declare: maybe_declare = self.maybe_declare - [maybe_declare(entity) for entity in declare] + for entity in declare: + await maybe_declare(entity) # handle autogenerated queue names for reply_to reply_to = properties.get('reply_to') if isinstance(reply_to, Queue): properties['reply_to'] = reply_to.name - return channel.basic_publish( + await channel.basic_publish( message, exchange=exchange, routing_key=routing_key, mandatory=mandatory, immediate=immediate, ) - def _get_channel(self): + async def _resolve_channel(self): + channel = self._channel + if isinstance(channel, ChannelPromise): + channel = self._channel = await channel.resolve() + if self.exchange: + self.exchange.revive(channel) + if self.on_return: + channel.events['basic_return'].add(self.on_return) + else: + channel = maybe_channel(channel) + await channel.open() + return channel + + def _get_channel(self) -> ChannelT: channel = self._channel if isinstance(channel, ChannelPromise): channel = self._channel = channel() - self.exchange.revive(channel) + if self.exchange: + self.exchange.revive(channel) if self.on_return: channel.events['basic_return'].add(self.on_return) return channel - def _set_channel(self, channel): + def _set_channel(self, channel: ChannelT) -> None: self._channel = channel channel = property(_get_channel, _set_channel) - def revive(self, channel): + async def revive(self, channel: ChannelT) -> None: """Revive the producer after connection loss.""" if is_connection(channel): connection = channel self.__connection__ = connection - channel = ChannelPromise(lambda: connection.default_channel) + channel = ChannelPromise(connection) if isinstance(channel, ChannelPromise): self._channel = channel self.exchange = self.exchange(channel) @@ -230,18 +275,29 @@ class Producer: self._channel.events['basic_return'].add(self.on_return) self.exchange = self.exchange(channel) - def __enter__(self): + def __enter__(self) -> 'Producer': return self - def __exit__(self, *exc_info): + async def __aenter__(self) -> 'Producer': + await self.revive(self.channel) + return self + + def __exit__(self, *exc_info) -> None: self.release() - def release(self): - pass + async def __aexit__(self, *exc_info) -> None: + self.release() + + def release(self) -> None: + ... close = release - def _prepare(self, body, serializer=None, content_type=None, - content_encoding=None, compression=None, headers=None): + def _prepare(self, body: Any, + serializer: str = None, + content_type: str = None, + content_encoding: str = None, + compression: str = None, + headers: Mapping = None) -> Tuple[Any, str, str]: # No content_type? Then we're serializing the data internally. if not content_type: @@ -267,15 +323,14 @@ class Producer: return body, content_type, content_encoding @property - def connection(self): + def connection(self) -> ClientT: try: return self.__connection__ or self.channel.connection.client except AttributeError: pass -@abstract.Consumer.register -class Consumer: +class Consumer(types.ConsumerT): """Message consumer. Arguments: @@ -358,13 +413,23 @@ class Consumer: prefetch_count = None #: Mapping of queues we consume from. - _queues = None + _queues: Mapping[str, QueueT] = None _tags = count(1) # global - - def __init__(self, channel, queues=None, no_ack=None, auto_declare=None, - callbacks=None, on_decode_error=None, on_message=None, - accept=None, prefetch_count=None, tag_prefix=None): + _active_tags: Mapping[str, str] + + def __init__( + self, + channel: ChannelT, + queues: Sequence[QueueT] = None, + no_ack: bool = None, + auto_declare: bool = None, + callbacks: Sequence[Callable] = None, + on_decode_error: Callable = None, + on_message: Callable = None, + accept: Sequence[str] = None, + prefetch_count: int = None, + tag_prefix: str = None) -> None: self.channel = channel self.queues = maybe_list(queues or []) self.no_ack = self.no_ack if no_ack is None else no_ack @@ -380,21 +445,19 @@ class Consumer: self.accept = prepare_accept_content(accept) self.prefetch_count = prefetch_count - if self.channel: - self.revive(self.channel) - @property - def queues(self): + def queues(self) -> Sequence[QueueT]: return list(self._queues.values()) @queues.setter - def queues(self, queues): + def queues(self, queues: Sequence[QueueT]) -> None: self._queues = {q.name: q for q in queues} - def revive(self, channel): + async def revive(self, channel: ChannelT) -> None: """Revive consumer after connection loss.""" self._active_tags.clear() channel = self.channel = maybe_channel(channel) + await channel.open() for qname, queue in self._queues.items(): # name may have changed after declare self._queues.pop(qname, None) @@ -402,22 +465,21 @@ class Consumer: queue.revive(channel) if self.auto_declare: - self.declare() + await self.declare() if self.prefetch_count is not None: - self.qos(prefetch_count=self.prefetch_count) + await self.qos(prefetch_count=self.prefetch_count) - def declare(self): + async def declare(self) -> None: """Declare queues, exchanges and bindings. Note: This is done automatically at instantiation when :attr:`auto_declare` is set. """ - for queue in self._queues.values(): - queue.declare() + [await queue.declare() for queue in self._queues.values()] - def register_callback(self, callback): + def register_callback(self, callback: Callable) -> None: """Register a new callback to be called when a message is received. Note: @@ -427,11 +489,26 @@ class Consumer: """ self.callbacks.append(callback) - def __enter__(self): + def __enter__(self) -> 'Consumer': + self.revive(self.channel) self.consume() return self - def __exit__(self, exc_type, exc_val, exc_tb): + async def __aenter__(self) -> 'Consumer': + await self.revive(self.channel) + await self.consume() + return self + + async def __aexit__(self, *exc_info) -> None: + if self.channel: + conn_errors = self.channel.connection.client.connection_errors + if not isinstance(exc_info[1], conn_errors): + try: + await self.cancel() + except Exception: + pass + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: if self.channel: conn_errors = self.channel.connection.client.connection_errors if not isinstance(exc_val, conn_errors): @@ -440,7 +517,7 @@ class Consumer: except Exception: pass - def add_queue(self, queue): + async def add_queue(self, queue: QueueT) -> QueueT: """Add a queue to the list of queues to consume from. Note: @@ -449,11 +526,11 @@ class Consumer: """ queue = queue(self.channel) if self.auto_declare: - queue.declare() + await queue.declare() self._queues[queue.name] = queue return queue - def consume(self, no_ack=None): + async def consume(self, no_ack: bool = None) -> None: """Start consuming messages. Can be called multiple times, but note that while it @@ -470,10 +547,10 @@ class Consumer: H, T = queues[:-1], queues[-1] for queue in H: - self._basic_consume(queue, no_ack=no_ack, nowait=True) - self._basic_consume(T, no_ack=no_ack, nowait=False) + await self._basic_consume(queue, no_ack=no_ack, nowait=True) + await self._basic_consume(T, no_ack=no_ack, nowait=False) - def cancel(self): + async def cancel(self) -> None: """End all active queue consumers. Note: @@ -482,38 +559,36 @@ class Consumer: """ cancel = self.channel.basic_cancel for tag in self._active_tags.values(): - cancel(tag) + await cancel(tag) self._active_tags.clear() close = cancel - def cancel_by_queue(self, queue): + async def cancel_by_queue(self, queue: Union[QueueT, str]) -> None: """Cancel consumer by queue name.""" - qname = queue.name if isinstance(queue, Queue) else queue + qname = queue.name if isinstance(queue, QueueT) else queue try: tag = self._active_tags.pop(qname) except KeyError: pass else: - self.channel.basic_cancel(tag) + await self.channel.basic_cancel(tag) finally: self._queues.pop(qname, None) - def consuming_from(self, queue): + def consuming_from(self, queue: Union[QueueT, str]) -> bool: """Return :const:`True` if currently consuming from queue'.""" - name = queue - if isinstance(queue, Queue): - name = queue.name + name = queue.name if isinstance(queue, QueueT) else queue return name in self._active_tags - def purge(self): + async def purge(self) -> int: """Purge messages from all queues. Warning: This will *delete all ready messages*, there is no undo operation. """ - return sum(queue.purge() for queue in self._queues.values()) + return sum(await queue.purge() for queue in self._queues.values()) - def flow(self, active): + async def flow(self, active: bool) -> None: """Enable/disable flow from peer. This is a simple flow-control mechanism that a peer can use @@ -524,9 +599,12 @@ class Consumer: will finish sending the current content (if any), and then wait until flow is reactivated. """ - self.channel.flow(active) + await self.channel.flow(active) - def qos(self, prefetch_size=0, prefetch_count=0, apply_global=False): + async def qos(self, + prefetch_size: int = 0, + prefetch_count: int = 0, + apply_global: bool = False) -> None: """Specify quality of service. The client can request that messages should be sent in @@ -550,11 +628,10 @@ class Consumer: apply_global (bool): Apply new settings globally on all channels. """ - return self.channel.basic_qos(prefetch_size, - prefetch_count, - apply_global) + await self.channel.basic_qos( + prefetch_size, prefetch_count, apply_global) - def recover(self, requeue=False): + async def recover(self, requeue: bool = False) -> None: """Redeliver unacknowledged messages. Asks the broker to redeliver all unacknowledged messages @@ -566,9 +643,9 @@ class Consumer: server will attempt to requeue the message, potentially then delivering it to an alternative subscriber. """ - return self.channel.basic_recover(requeue=requeue) + await self.channel.basic_recover(requeue=requeue) - def receive(self, body, message): + async def receive(self, body: Any, message: MessageT) -> None: """Method called when a message is received. This dispatches to the registered :attr:`callbacks`. @@ -584,24 +661,27 @@ class Consumer: callbacks = self.callbacks if not callbacks: raise NotImplementedError('Consumer does not have any callbacks') - [callback(body, message) for callback in callbacks] + for callback in callbacks: + await callback(body, message) - def _basic_consume(self, queue, consumer_tag=None, - no_ack=no_ack, nowait=True): + async def _basic_consume(self, queue: QueueT, + consumer_tag: str = None, + no_ack: bool = no_ack, + nowait: bool = True) -> str: tag = self._active_tags.get(queue.name) if tag is None: tag = self._add_tag(queue, consumer_tag) - queue.consume(tag, self._receive_callback, - no_ack=no_ack, nowait=nowait) + await queue.consume(tag, self._receive_callback, + no_ack=no_ack, nowait=nowait) return tag - def _add_tag(self, queue, consumer_tag=None): + def _add_tag(self, queue: QueueT, consumer_tag: str = None) -> str: tag = consumer_tag or '{0}{1}'.format( self.tag_prefix, next(self._tags)) self._active_tags[queue.name] = tag return tag - def _receive_callback(self, message): + async def _receive_callback(self, message: MessageT) -> None: accept = self.accept on_m, channel, decoded = self.on_message, self.channel, None try: @@ -611,20 +691,20 @@ class Consumer: if accept is not None: message.accept = accept if message.errors: - return message._reraise_error(self.on_decode_error) + message._reraise_error(self.on_decode_error) decoded = None if on_m else message.decode() except Exception as exc: if not self.on_decode_error: raise self.on_decode_error(message, exc) else: - return on_m(message) if on_m else self.receive(decoded, message) + await on_m(message) if on_m else self.receive(decoded, message) - def __repr__(self): + def __repr__(self) -> str: return '<{name}: {0.queues}>'.format(self, name=type(self).__name__) @property - def connection(self): + def connection(self) -> ClientT: try: return self.channel.connection.client except AttributeError: diff --git a/kombu/pidbox.py b/kombu/pidbox.py index 0c6196fc..4775a59b 100644 --- a/kombu/pidbox.py +++ b/kombu/pidbox.py @@ -8,12 +8,19 @@ from copy import copy from itertools import count from threading import local from time import time +from typing import Any, Callable, Mapping, Set, Sequence, cast -from . import Exchange, Queue, Consumer, Producer -from .clocks import LamportClock +from amqp.types import ChannelT + +from .clocks import Clock, LamportClock from .common import maybe_declare, oid_from +from .entity import Exchange, Queue from .exceptions import InconsistencyError from .log import get_logger +from .messaging import Consumer, Producer +from .typing import ( + ClientT, ConsumerT, ExchangeT, ProducerT, MessageT, ResourceT, +) from .utils.functional import maybe_evaluate, reprcall from .utils.objects import cached_property from .utils.uuid import uuid @@ -35,22 +42,26 @@ class Node: """Mailbox node.""" #: hostname of the node. - hostname = None + hostname: str = None #: the :class:`Mailbox` this is a node for. - mailbox = None + mailbox: 'Mailbox' = None #: map of method name/handlers. - handlers = None + handlers: Mapping[str, Callable] = None #: current context (passed on to handlers) - state = None + state: Any = None #: current channel. - channel = None - - def __init__(self, hostname, state=None, channel=None, - handlers=None, mailbox=None): + channel: ChannelT = None + + def __init__(self, + hostname: str, + state: Any = None, + channel: ChannelT = None, + handlers: Mapping[str, Callable] = None, + mailbox: 'Mailbox' = None) -> None: self.channel = channel self.mailbox = mailbox self.hostname = hostname @@ -60,10 +71,14 @@ class Node: handlers = {} self.handlers = handlers - def Consumer(self, channel=None, no_ack=True, accept=None, **options): + def Consumer(self, + channel: ChannelT = None, + no_ack: bool = True, + accept: Set[str] = None, + **options) -> ConsumerT: queue = self.mailbox.get_queue(self.hostname) - def verify_exclusive(name, messages, consumers): + def verify_exclusive(name: str, messages: int, consumers: int) -> None: if consumers: warnings.warn(W_PIDBOX_IN_USE.format(node=self)) queue.on_declared = verify_exclusive @@ -74,22 +89,24 @@ class Node: **options ) - def handler(self, fun): + def handler(self, fun: Callable) -> Callable: self.handlers[fun.__name__] = fun return fun - def on_decode_error(self, message, exc): + def on_decode_error(self, message: str, exc: Exception) -> None: error('Cannot decode message: %r', exc, exc_info=1) - def listen(self, channel=None, callback=None): + def listen(self, + channel: ChannelT = None, + callback: Callable = None) -> ConsumerT: consumer = self.Consumer(channel=channel, callbacks=[callback or self.handle_message], on_decode_error=self.on_decode_error) consumer.consume() return consumer - def dispatch(self, method, arguments=None, - reply_to=None, ticket=None, **kwargs): + def dispatch(self, method: str, arguments: Mapping = None, + reply_to: str = None, ticket: str = None, **kwargs) -> Any: arguments = arguments or {} debug('pidbox received method %s [reply_to:%s ticket:%s]', reprcall(method, (), kwargs=arguments), reply_to, ticket) @@ -109,16 +126,17 @@ class Node: ticket=ticket) return reply - def handle(self, method, arguments={}): + def handle(self, method: str, arguments: Mapping = {}) -> Any: return self.handlers[method](self.state, **arguments) - def handle_call(self, method, arguments): + def handle_call(self, method: str, arguments: Mapping) -> Any: return self.handle(method, arguments) - def handle_cast(self, method, arguments): + def handle_cast(self, method: str, arguments: Mapping) -> Any: return self.handle(method, arguments) - def handle_message(self, body, message=None): + def handle_message(self, body: Any, message: MessageT = None) -> None: + body = cast(Mapping, body) destination = body.get('destination') if message: self.adjust_clock(message.headers.get('clock') or 0) @@ -126,7 +144,8 @@ class Node: return self.dispatch(**body) dispatch_from_message = handle_message - def reply(self, data, exchange, routing_key, ticket, **kwargs): + def reply(self, data: Any, exchange: str, routing_key: str, ticket: str, + **kwargs) -> None: self.mailbox._publish_reply(data, exchange, routing_key, ticket, channel=self.channel, serializer=self.mailbox.serializer) @@ -135,36 +154,42 @@ class Node: class Mailbox: """Process Mailbox.""" - node_cls = Node - exchange_fmt = '%s.pidbox' - reply_exchange_fmt = 'reply.%s.pidbox' + node_cls: type = Node + exchange_fmt: str = '%s.pidbox' + reply_exchange_fmt: str = 'reply.%s.pidbox' #: Name of application. - namespace = None + namespace: str = None #: Connection (if bound). - connection = None + connection: ClientT = None #: Exchange type (usually direct, or fanout for broadcast). - type = 'direct' + type: str = 'direct' #: mailbox exchange (init by constructor). - exchange = None + exchange: ExchangeT = None #: exchange to send replies to. - reply_exchange = None + reply_exchange: ExchangeT = None #: Only accepts json messages by default. - accept = ['json'] + accept: Set[str] = {'json'} #: Message serializer - serializer = None - - def __init__(self, namespace, - type='direct', connection=None, clock=None, - accept=None, serializer=None, producer_pool=None, - queue_ttl=None, queue_expires=None, - reply_queue_ttl=None, reply_queue_expires=10.0): + serializer: str = None + + def __init__(self, namespace: str, + type: str = 'direct', + connection: ClientT = None, + clock: Clock = None, + accept: Set[str] = None, + serializer: str = None, + producer_pool: ResourceT = None, + queue_ttl: float = None, + queue_expires: float = None, + reply_queue_ttl: float = None, + reply_queue_expires: float = 10.0) -> None: self.namespace = namespace self.connection = connection self.type = type @@ -181,36 +206,48 @@ class Mailbox: self.reply_queue_expires = reply_queue_expires self._producer_pool = producer_pool - def __call__(self, connection): + def __call__(self, connection: ClientT) -> 'Mailbox': bound = copy(self) bound.connection = connection return bound - def Node(self, hostname=None, state=None, channel=None, handlers=None): + def Node(self, + hostname: str = None, + state: Any = None, + channel: ChannelT = None, + handlers: Mapping[str, Callable] = None) -> Node: hostname = hostname or socket.gethostname() return self.node_cls(hostname, state, channel, handlers, mailbox=self) - def call(self, destination, command, kwargs={}, - timeout=None, callback=None, channel=None): + def call(self, destination: str, command: str, + kwargs: Mapping[str, Any] = {}, + timeout: float = None, + callback: Callable = None, + channel: ChannelT = None) -> Sequence[Mapping]: return self._broadcast(command, kwargs, destination, reply=True, timeout=timeout, callback=callback, channel=channel) - def cast(self, destination, command, kwargs={}): - return self._broadcast(command, kwargs, destination, reply=False) + def cast(self, destination: str, command: str, + kwargs: Mapping[str, Any] = {}) -> None: + self._broadcast(command, kwargs, destination, reply=False) - def abcast(self, command, kwargs={}): - return self._broadcast(command, kwargs, reply=False) + def abcast(self, command: str, kwargs: Mapping[str, Any] = {}) -> None: + self._broadcast(command, kwargs, reply=False) - def multi_call(self, command, kwargs={}, timeout=1, - limit=None, callback=None, channel=None): + def multi_call(self, command: str, + kwargs: Mapping[str, Any] = {}, + timeout: str = 1, + limit: int = None, + callback: Callable = None, + channel: ChannelT = None) -> Sequence[Mapping]: return self._broadcast(command, kwargs, reply=True, timeout=timeout, limit=limit, callback=callback, channel=channel) - def get_reply_queue(self): + def get_reply_queue(self) -> Queue: oid = self.oid return Queue( '%s.%s' % (oid, self.reply_exchange.name), @@ -223,10 +260,10 @@ class Mailbox: ) @cached_property - def reply_queue(self): + def reply_queue(self) -> Queue: return self.get_reply_queue() - def get_queue(self, hostname): + def get_queue(self, hostname: str) -> Queue: return Queue( '%s.%s.pidbox' % (hostname, self.namespace), exchange=self.exchange, @@ -237,7 +274,9 @@ class Mailbox: ) @contextmanager - def producer_or_acquire(self, producer=None, channel=None): + def producer_or_acquire(self, + producer: ProducerT = None, + channel: ChannelT = None) -> ProducerT: if producer: yield producer elif self.producer_pool: @@ -246,8 +285,11 @@ class Mailbox: else: yield Producer(channel, auto_declare=False) - def _publish_reply(self, reply, exchange, routing_key, ticket, - channel=None, producer=None, **opts): + def _publish_reply(self, reply: Any, + exchange: str, routing_key: str, ticket: str, + channel: ChannelT = None, + producer: ProducerT = None, + **opts) -> None: chan = channel or self.connection.default_channel exchange = Exchange(exchange, exchange_type='direct', delivery_mode='transient', @@ -265,9 +307,13 @@ class Mailbox: # queue probably deleted and no one is expecting a reply. pass - def _publish(self, type, arguments, destination=None, - reply_ticket=None, channel=None, timeout=None, - serializer=None, producer=None): + def _publish(self, type: str, arguments: Mapping[str, Any], + destination: str = None, + reply_ticket: str = None, + channel: ChannelT = None, + timeout: float = None, + serializer: str = None, + producer: ProducerT = None) -> None: message = {'method': type, 'arguments': arguments, 'destination': destination} @@ -287,9 +333,15 @@ class Mailbox: serializer=serializer, ) - def _broadcast(self, command, arguments=None, destination=None, - reply=False, timeout=1, limit=None, - callback=None, channel=None, serializer=None): + def _broadcast(self, command: str, + arguments: Mapping[str, Any] = None, + destination: str = None, + reply: bool = False, + timeout: float = 1.0, + limit: int = None, + callback: Callable = None, + channel: ChannelT = None, + serializer: str = None) -> Sequence[Mapping]: if destination is not None and \ not isinstance(destination, (list, tuple)): raise ValueError( @@ -317,9 +369,12 @@ class Mailbox: callback=callback, channel=chan) - def _collect(self, ticket, - limit=None, timeout=1, callback=None, - channel=None, accept=None): + def _collect(self, ticket: str, + limit: str = None, + timeout: float = 1.0, + callback: Callable = None, + channel: ChannelT = None, + accept: Set[str] = None) -> Sequence[Mapping]: if accept is None: accept = self.accept chan = channel or self.connection.default_channel @@ -334,7 +389,7 @@ class Mailbox: except KeyError: pass - def on_message(body, message): + def on_message(body: Any, message: MessageT) -> None: # ticket header added in kombu 2.5 header = message.headers.get adjust_clock(header('clock') or 0) @@ -361,20 +416,20 @@ class Mailbox: finally: chan.after_reply_message_received(queue.name) - def _get_exchange(self, namespace, type): + def _get_exchange(self, namespace: str, type: str) -> Exchange: return Exchange(self.exchange_fmt % namespace, type=type, durable=False, delivery_mode='transient') - def _get_reply_exchange(self, namespace): + def _get_reply_exchange(self, namespace: str) -> Exchange: return Exchange(self.reply_exchange_fmt % namespace, type='direct', durable=False, delivery_mode='transient') @cached_property - def oid(self): + def oid(self) -> str: try: return self._tls.OID except AttributeError: @@ -382,5 +437,5 @@ class Mailbox: return oid @cached_property - def producer_pool(self): + def producer_pool(self) -> ResourceT: return maybe_evaluate(self._producer_pool) diff --git a/kombu/pools.py b/kombu/pools.py index f9168886..88770212 100644 --- a/kombu/pools.py +++ b/kombu/pools.py @@ -1,23 +1,25 @@ """Public resource pools.""" import os - from itertools import chain - +from typing import Any from .connection import Resource from .messaging import Producer +from .types import ClientT, ProducerT, ResourceT from .utils.collections import EqualityDict from .utils.compat import register_after_fork from .utils.functional import lazy -__all__ = ['ProducerPool', 'PoolGroup', 'register_group', - 'connections', 'producers', 'get_limit', 'set_limit', 'reset'] +__all__ = [ + 'ProducerPool', 'PoolGroup', 'register_group', + 'connections', 'producers', 'get_limit', 'set_limit', 'reset', +] _limit = [10] _groups = [] -use_global_limit = object() +use_global_limit = 44444444444444444444444444 disable_limit_protection = os.environ.get('KOMBU_DISABLE_LIMIT_PROTECTION') -def _after_fork_cleanup_group(group): +def _after_fork_cleanup_group(group: 'PoolGroup') -> None: group.clear() @@ -27,15 +29,19 @@ class ProducerPool(Resource): Producer = Producer close_after_fork = True - def __init__(self, connections, *args, Producer=None, **kwargs): + def __init__(self, + connections: ResourceT, + *args, + Producer: type = None, + **kwargs) -> None: self.connections = connections self.Producer = Producer or self.Producer super().__init__(*args, **kwargs) - def _acquire_connection(self): + def _acquire_connection(self) -> ClientT: return self.connections.acquire(block=True) - def create_producer(self): + def create_producer(self) -> ProducerT: conn = self._acquire_connection() try: return self.Producer(conn) @@ -43,18 +49,18 @@ class ProducerPool(Resource): conn.release() raise - def new(self): + def new(self) -> lazy: return lazy(self.create_producer) - def setup(self): + def setup(self) -> None: if self.limit: for _ in range(self.limit): self._resource.put_nowait(self.new()) - def close_resource(self, resource): + def close_resource(self, resource: ProducerT) -> None: ... - def prepare(self, p): + def prepare(self, p: Any) -> None: if callable(p): p = p() if p._channel is None: @@ -66,7 +72,7 @@ class ProducerPool(Resource): raise return p - def release(self, resource): + def release(self, resource: ProducerT) -> None: if resource.__connection__: resource.__connection__.release() resource.channel = None @@ -76,24 +82,25 @@ class ProducerPool(Resource): class PoolGroup(EqualityDict): """Collection of resource pools.""" - def __init__(self, limit=None, close_after_fork=True): + def __init__(self, limit: int = None, + close_after_fork: bool = True) -> None: self.limit = limit self.close_after_fork = close_after_fork if self.close_after_fork and register_after_fork is not None: register_after_fork(self, _after_fork_cleanup_group) - def create(self, resource, limit): + def create(self, connection: ClientT, limit: int) -> Any: raise NotImplementedError('PoolGroups must define ``create``') - def __missing__(self, resource): + def __missing__(self, resource: Any) -> Any: limit = self.limit - if limit is use_global_limit: + if limit == use_global_limit: limit = get_limit() k = self[resource] = self.create(resource, limit) return k -def register_group(group): +def register_group(group: PoolGroup) -> PoolGroup: """Register group (can be used as decorator).""" _groups.append(group) return group @@ -102,7 +109,7 @@ def register_group(group): class Connections(PoolGroup): """Collection of connection pools.""" - def create(self, connection, limit): + def create(self, connection: ClientT, limit: int) -> Any: return connection.Pool(limit=limit) connections = register_group(Connections(limit=use_global_limit)) # noqa: E305 @@ -110,21 +117,24 @@ connections = register_group(Connections(limit=use_global_limit)) # noqa: E305 class Producers(PoolGroup): """Collection of producer pools.""" - def create(self, connection, limit): + def create(self, connection: ClientT, limit: int) -> Any: return ProducerPool(connections[connection], limit=limit) producers = register_group(Producers(limit=use_global_limit)) # noqa: E305 -def _all_pools(): +def _all_pools() -> chain[ResourceT]: return chain(*[(g.values() if g else iter([])) for g in _groups]) -def get_limit(): +def get_limit() -> int: """Get current connection pool limit.""" return _limit[0] -def set_limit(limit, force=False, reset_after=False, ignore_errors=False): +def set_limit(limit: int, + force: bool = False, + reset_after: bool = False, + ignore_errors: bool = False) -> int: """Set new connection pool limit.""" limit = limit or 0 glimit = _limit[0] or 0 @@ -135,7 +145,7 @@ def set_limit(limit, force=False, reset_after=False, ignore_errors=False): return limit -def reset(*args, **kwargs): +def reset(*args, **kwargs) -> None: """Reset all pools by closing open resources.""" for pool in _all_pools(): try: diff --git a/kombu/resource.py b/kombu/resource.py index d98f49e7..1829b466 100644 --- a/kombu/resource.py +++ b/kombu/resource.py @@ -1,16 +1,15 @@ """Generic resource pool implementation.""" import os - from collections import deque from queue import Empty, LifoQueue as _LifoQueue - +from typing import Any from . import exceptions -from .utils import abstract +from .types import ResourceT from .utils.compat import register_after_fork from .utils.functional import lazy -def _after_fork_cleanup_resource(resource): +def _after_fork_cleanup_resource(resource: ResourceT) -> None: try: resource.force_close_all() except Exception: @@ -20,19 +19,21 @@ def _after_fork_cleanup_resource(resource): class LifoQueue(_LifoQueue): """Last in first out version of Queue.""" - def _init(self, maxsize): + def _init(self, maxsize: int) -> None: self.queue = deque() -@abstract.Resource.register -class Resource: +class Resource(ResourceT): """Pool of resources.""" LimitExceeded = exceptions.LimitExceeded close_after_fork = False - def __init__(self, limit=None, preload=None, close_after_fork=None): + def __init__(self, + limit: int = None, + preload: int = None, + close_after_fork: bool = None): self._limit = limit self.preload = preload or 0 self._closed = False @@ -47,10 +48,10 @@ class Resource: register_after_fork(self, _after_fork_cleanup_resource) self.setup() - def setup(self): + def setup(self) -> None: raise NotImplementedError('subclass responsibility') - def _add_when_empty(self): + def _add_when_empty(self) -> None: if self.limit and len(self._dirty) >= self.limit: raise self.LimitExceeded(self.limit) # All taken, put new on the queue and @@ -58,7 +59,7 @@ class Resource: # will get the resource. self._resource.put_nowait(self.new()) - def acquire(self, block=False, timeout=None): + def acquire(self, block: bool = False, timeout: int = None): """Acquire resource. Arguments: @@ -94,7 +95,7 @@ class Resource: else: R = self.prepare(self.new()) - def release(): + def release() -> None: """Release resource so it can be used by another thread. Warnings: @@ -107,16 +108,16 @@ class Resource: return R - def prepare(self, resource): + def prepare(self, resource: Any) -> Any: return resource - def close_resource(self, resource): + def close_resource(self, resource: Any) -> None: resource.close() - def release_resource(self, resource): + def release_resource(self, resource: Any) -> None: ... - def replace(self, resource): + def replace(self, resource: Any) -> None: """Replace existing resource with a new instance. This can be used in case of defective resources. @@ -125,7 +126,7 @@ class Resource: self._dirty.discard(resource) self.close_resource(resource) - def release(self, resource): + def release(self, resource: Any) -> None: if self.limit: self._dirty.discard(resource) self._resource.put_nowait(resource) @@ -133,10 +134,10 @@ class Resource: else: self.close_resource(resource) - def collect_resource(self, resource): + def collect_resource(self, resource: Any) -> None: ... - def force_close_all(self): + def force_close_all(self) -> None: """Close and remove all resources in the pool (also those in use). Used to close resources from parent processes after fork @@ -169,7 +170,11 @@ class Resource: except AttributeError: pass # Issue #78 - def resize(self, limit, force=False, ignore_errors=False, reset=False): + def resize( + self, limit: int, + force: bool = False, + ignore_errors: bool = False, + reset: bool = False) -> None: prev_limit = self._limit if (self._dirty and limit < self._limit) and not ignore_errors: if not force: @@ -187,7 +192,7 @@ class Resource: if limit < prev_limit: self._shrink_down() - def _shrink_down(self): + def _shrink_down(self) -> None: resource = self._resource # Items to the left are last recently used, so we remove those first. with resource.mutex: @@ -195,11 +200,11 @@ class Resource: self.collect_resource(resource.queue.popleft()) @property - def limit(self): + def limit(self) -> int: return self._limit @limit.setter - def limit(self, limit): + def limit(self, limit: int) -> None: self.resize(limit) if os.environ.get('KOMBU_DEBUG_POOL'): # pragma: no cover @@ -208,7 +213,7 @@ class Resource: _next_resource_id = 0 - def acquire(self, *args, **kwargs): # noqa + def acquire(self, *args, **kwargs) -> Any: # noqa import traceback id = self._next_resource_id = self._next_resource_id + 1 print('+{0} ACQUIRE {1}'.format(id, self.__class__.__name__)) @@ -220,7 +225,7 @@ class Resource: r.acquired_by.append(traceback.format_stack()) return r - def release(self, resource): # noqa + def release(self, resource: Any) -> None: # noqa id = resource._resource_id print('+{0} RELEASE {1}'.format(id, self.__class__.__name__)) r = self._orig_release(resource) diff --git a/kombu/serialization.py b/kombu/serialization.py index 7f885d70..a8a9cc48 100644 --- a/kombu/serialization.py +++ b/kombu/serialization.py @@ -9,9 +9,9 @@ try: except ImportError: # pragma: no cover cpickle = None # noqa -from collections import namedtuple from contextlib import contextmanager from io import BytesIO +from typing import Callable, NamedTuple from .exceptions import ( ContentDisallowed, DecodeError, EncodeError, SerializerNotInstalled @@ -37,7 +37,13 @@ pickle_load = pickle.load #: There's a new protocol (3) but this is only supported by Python 3. pickle_protocol = int(os.environ.get('PICKLE_PROTOCOL', 2)) -codec = namedtuple('codec', ('content_type', 'content_encoding', 'encoder')) + +class codec(NamedTuple): + """Codec registration triple.""" + + content_type: str + content_encoding: str + encoder: Callable @contextmanager diff --git a/kombu/simple.py b/kombu/simple.py index 4facedf7..2436a28e 100644 --- a/kombu/simple.py +++ b/kombu/simple.py @@ -1,28 +1,32 @@ """Simple messaging interface.""" import socket - from collections import deque from time import monotonic from queue import Empty - +from typing import Any, Mapping from . import entity from . import messaging from .connection import maybe_channel +from .types import ChannelArgT, ConsumerT, ProducerT, MessageT, SimpleQueueT __all__ = ['SimpleQueue', 'SimpleBuffer'] -class SimpleBase: +class SimpleBase(SimpleQueueT): Empty = Empty - _consuming = False + _consuming: bool = False - def __enter__(self): + def __enter__(self) -> SimpleQueueT: return self - def __exit__(self, *exc_info): + def __exit__(self, *exc_info) -> None: self.close() - def __init__(self, channel, producer, consumer, no_ack=False): + def __init__(self, + channel: ChannelArgT, + producer: ProducerT, + consumer: ConsumerT, + no_ack: bool = False) -> None: self.channel = maybe_channel(channel) self.producer = producer self.consumer = consumer @@ -31,7 +35,7 @@ class SimpleBase: self.buffer = deque() self.consumer.register_callback(self._receive) - def get(self, block=True, timeout=None): + def get(self, block: bool = True, timeout: float = None) -> MessageT: if not block: return self.get_nowait() self._consume() @@ -49,14 +53,18 @@ class SimpleBase: elapsed += monotonic() - time_start remaining = timeout and timeout - elapsed or None - def get_nowait(self): + def get_nowait(self) -> MessageT: m = self.queue.get(no_ack=self.no_ack) if not m: raise self.Empty() return m - def put(self, message, serializer=None, headers=None, compression=None, - routing_key=None, **kwargs): + def put(self, message: Any, + serializer: str = None, + headers: Mapping = None, + compression: str = None, + routing_key: str = None, + **kwargs) -> None: self.producer.publish(message, serializer=serializer, routing_key=routing_key, @@ -64,61 +72,72 @@ class SimpleBase: compression=compression, **kwargs) - def clear(self): + def clear(self) -> int: return self.consumer.purge() - def qsize(self): + def qsize(self) -> int: _, size, _ = self.queue.queue_declare(passive=True) return size - def close(self): + def close(self) -> None: self.consumer.cancel() - def _receive(self, message_data, message): + def _receive(self, message_data: Any, message: MessageT) -> None: self.buffer.append(message) - def _consume(self): + def _consume(self) -> None: if not self._consuming: self.consumer.consume(no_ack=self.no_ack) self._consuming = True - def __len__(self): + def __len__(self) -> int: """`len(self) -> self.qsize()`.""" return self.qsize() - def __bool__(self): + def __bool__(self) -> bool: return True class SimpleQueue(SimpleBase): """Simple API for persistent queues.""" - no_ack = False - queue_opts = {} - exchange_opts = {'type': 'direct'} - - def __init__(self, channel, name, no_ack=None, queue_opts=None, - exchange_opts=None, serializer=None, - compression=None, **kwargs): + no_ack: bool = False + queue_opts: Mapping = {} + exchange_opts: Mapping = {'type': 'direct'} + + def __init__( + self, + channel: ChannelArgT, + name: str, + no_ack: bool = None, + queue_opts: Mapping = None, + exchange_opts: Mapping = None, + serializer: str = None, + compression: str = None, + **kwargs) -> None: queue = name - queue_opts = dict(self.queue_opts, **queue_opts or {}) - exchange_opts = dict(self.exchange_opts, **exchange_opts or {}) + all_queue_opts = dict(self.queue_opts) + if queue_opts: + all_queue_opts.update(queue_opts) + all_exchange_opts = dict(self.exchange_opts) + if exchange_opts: + all_exchange_opts.update(exchange_opts) if no_ack is None: no_ack = self.no_ack if not isinstance(queue, entity.Queue): - exchange = entity.Exchange(name, **exchange_opts) - queue = entity.Queue(name, exchange, name, **queue_opts) + exchange = entity.Exchange(name, **all_exchange_opts) + queue = entity.Queue(name, exchange, name, **all_queue_opts) routing_key = name else: name = queue.name exchange = queue.exchange routing_key = queue.routing_key - consumer = messaging.Consumer(channel, queue) + consumer = messaging.Consumer(channel, [queue]) producer = messaging.Producer(channel, exchange, serializer=serializer, routing_key=routing_key, compression=compression) - super().__init__(channel, producer, consumer, no_ack, **kwargs) + super().__init__(channel, producer, consumer, no_ack) class SimpleBuffer(SimpleQueue): diff --git a/kombu/transport/SLMQ.py b/kombu/transport/SLMQ.py deleted file mode 100644 index 91eebe5f..00000000 --- a/kombu/transport/SLMQ.py +++ /dev/null @@ -1,181 +0,0 @@ -"""SoftLayer Message Queue transport.""" -import os -import socket -import string - -from queue import Empty - -from kombu.utils.encoding import bytes_to_str, safe_str -from kombu.utils.json import loads, dumps -from kombu.utils.objects import cached_property - -from . import virtual - -try: - from softlayer_messaging import get_client - from softlayer_messaging.errors import ResponseError -except ImportError: # pragma: no cover - get_client = ResponseError = None # noqa - -# dots are replaced by dash, all other punctuation replaced by underscore. -CHARS_REPLACE_TABLE = { - ord(c): 0x5f for c in string.punctuation if c not in '_' -} - - -class Channel(virtual.Channel): - """SLMQ Channel.""" - - default_visibility_timeout = 1800 # 30 minutes. - domain_format = 'kombu%(vhost)s' - _slmq = None - _queue_cache = {} - _noack_queues = set() - - def __init__(self, *args, **kwargs): - if get_client is None: - raise ImportError( - 'SLMQ transport requires the softlayer_messaging library', - ) - super().__init__(*args, **kwargs) - queues = self.slmq.queues() - for queue in queues: - self._queue_cache[queue] = queue - - def basic_consume(self, queue, no_ack, *args, **kwargs): - if no_ack: - self._noack_queues.add(queue) - return super().basic_consume(queue, no_ack, *args, **kwargs) - - def basic_cancel(self, consumer_tag): - if consumer_tag in self._consumers: - queue = self._tag_to_queue[consumer_tag] - self._noack_queues.discard(queue) - return super().basic_cancel(consumer_tag) - - def entity_name(self, name, table=CHARS_REPLACE_TABLE): - """Format AMQP queue name into a valid SLQS queue name.""" - return str(safe_str(name)).translate(table) - - def _new_queue(self, queue, **kwargs): - """Ensure a queue exists in SLQS.""" - queue = self.entity_name(self.queue_name_prefix + queue) - try: - return self._queue_cache[queue] - except KeyError: - try: - self.slmq.create_queue( - queue, visibility_timeout=self.visibility_timeout) - except ResponseError: - pass - q = self._queue_cache[queue] = self.slmq.queue(queue) - return q - - def _delete(self, queue, *args, **kwargs): - """Delete queue by name.""" - queue_name = self.entity_name(queue) - self._queue_cache.pop(queue_name, None) - self.slmq.queue(queue_name).delete(force=True) - super()._delete(queue_name) - - def _put(self, queue, message, **kwargs): - """Put message onto queue.""" - q = self._new_queue(queue) - q.push(dumps(message)) - - def _get(self, queue): - """Try to retrieve a single message off ``queue``.""" - q = self._new_queue(queue) - rs = q.pop(1) - if rs['items']: - m = rs['items'][0] - payload = loads(bytes_to_str(m['body'])) - if queue in self._noack_queues: - q.message(m['id']).delete() - else: - payload['properties']['delivery_info'].update({ - 'slmq_message_id': m['id'], 'slmq_queue_name': q.name}) - return payload - raise Empty() - - def basic_ack(self, delivery_tag): - delivery_info = self.qos.get(delivery_tag).delivery_info - try: - queue = delivery_info['slmq_queue_name'] - except KeyError: - pass - else: - self.delete_message(queue, delivery_info['slmq_message_id']) - super().basic_ack(delivery_tag) - - def _size(self, queue): - """Return the number of messages in a queue.""" - return self._new_queue(queue).detail()['message_count'] - - def _purge(self, queue): - """Delete all current messages in a queue.""" - q = self._new_queue(queue) - n = 0 - l = q.pop(10) - while l['items']: - for m in l['items']: - self.delete_message(queue, m['id']) - n += 1 - l = q.pop(10) - return n - - def delete_message(self, queue, message_id): - q = self.slmq.queue(self.entity_name(queue)) - return q.message(message_id).delete() - - @property - def slmq(self): - if self._slmq is None: - conninfo = self.conninfo - account = os.environ.get('SLMQ_ACCOUNT', conninfo.virtual_host) - user = os.environ.get('SL_USERNAME', conninfo.userid) - api_key = os.environ.get('SL_API_KEY', conninfo.password) - host = os.environ.get('SLMQ_HOST', conninfo.hostname) - port = os.environ.get('SLMQ_PORT', conninfo.port) - secure = bool(os.environ.get( - 'SLMQ_SECURE', self.transport_options.get('secure')) or True, - ) - endpoint = '{0}://{1}{2}'.format( - 'https' if secure else 'http', host, - ':{0}'.format(port) if port else '', - ) - - self._slmq = get_client(account, endpoint=endpoint) - self._slmq.authenticate(user, api_key) - return self._slmq - - @property - def conninfo(self): - return self.connection.client - - @property - def transport_options(self): - return self.connection.client.transport_options - - @cached_property - def visibility_timeout(self): - return (self.transport_options.get('visibility_timeout') or - self.default_visibility_timeout) - - @cached_property - def queue_name_prefix(self): - return self.transport_options.get('queue_name_prefix', '') - - -class Transport(virtual.Transport): - """SLMQ Transport.""" - - Channel = Channel - - polling_interval = 1 - default_port = None - connection_errors = ( - virtual.Transport.connection_errors + ( - ResponseError, socket.error - ) - ) diff --git a/kombu/transport/SQS.py b/kombu/transport/SQS.py deleted file mode 100644 index f5766edb..00000000 --- a/kombu/transport/SQS.py +++ /dev/null @@ -1,483 +0,0 @@ -"""Amazon SQS Transport. - -Amazon SQS transport module for Kombu. This package implements an AMQP-like -interface on top of Amazons SQS service, with the goal of being optimized for -high performance and reliability. - -The default settings for this module are focused now on high performance in -task queue situations where tasks are small, idempotent and run very fast. - -SQS Features supported by this transport: - Long Polling: - http://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/ - sqs-long-polling.html - - Long polling is enabled by setting the `wait_time_seconds` transport - option to a number > 1. Amazon supports up to 20 seconds. This is - enabled with 10 seconds by default. - - Batch API Actions: - http://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/ - sqs-batch-api.html - - The default behavior of the SQS Channel.drain_events() method is to - request up to the 'prefetch_count' messages on every request to SQS. - These messages are stored locally in a deque object and passed back - to the Transport until the deque is empty, before triggering a new - API call to Amazon. - - This behavior dramatically speeds up the rate that you can pull tasks - from SQS when you have short-running tasks (or a large number of workers). - - When a Celery worker has multiple queues to monitor, it will pull down - up to 'prefetch_count' messages from queueA and work on them all before - moving on to queueB. If queueB is empty, it will wait up until - 'polling_interval' expires before moving back and checking on queueA. -""" -import collections -import socket -import string - -from queue import Empty - -from vine import transform, ensure_promise, promise - -from kombu.async import get_event_loop -from kombu.async.aws import sqs as _asynsqs -from kombu.async.aws.ext import boto, exception -from kombu.async.aws.sqs.connection import AsyncSQSConnection, SQSConnection -from kombu.async.aws.sqs.ext import regions -from kombu.async.aws.sqs.message import Message -from kombu.log import get_logger -from kombu.utils import scheduling -from kombu.utils.encoding import bytes_to_str, safe_str -from kombu.utils.json import loads, dumps -from kombu.utils.objects import cached_property - -from . import virtual - -logger = get_logger(__name__) - -# dots are replaced by dash, all other punctuation -# replaced by underscore. -CHARS_REPLACE_TABLE = { - ord(c): 0x5f for c in string.punctuation if c not in '-_.' -} -CHARS_REPLACE_TABLE[0x2e] = 0x2d # '.' -> '-' - -#: SQS bulk get supports a maximum of 10 messages at a time. -SQS_MAX_MESSAGES = 10 - - -def maybe_int(x): - """Try to convert x' to int, or return x' if that fails.""" - try: - return int(x) - except ValueError: - return x - - -class Channel(virtual.Channel): - """SQS Channel.""" - - default_region = 'us-east-1' - default_visibility_timeout = 1800 # 30 minutes. - default_wait_time_seconds = 10 # up to 20 seconds max - domain_format = 'kombu%(vhost)s' - _asynsqs = None - _sqs = None - _queue_cache = {} - _noack_queues = set() - - def __init__(self, *args, hub=None, **kwargs): - if boto is None: - raise ImportError('boto is not installed') - super().__init__(*args, **kwargs) - - # SQS blows up if you try to create a new queue when one already - # exists but with a different visibility_timeout. This prepopulates - # the queue_cache to protect us from recreating - # queues that are known to already exist. - self._update_queue_cache(self.queue_name_prefix) - - # The drain_events() method stores extra messages in a local - # Deque object. This allows multiple messages to be requested from - # SQS at once for performance, but maintains the same external API - # to the caller of the drain_events() method. - self._queue_message_cache = collections.deque() - - self.hub = hub or get_event_loop() - - def _update_queue_cache(self, queue_name_prefix): - try: - queues = self.sqs.get_all_queues(prefix=queue_name_prefix) - except exception.SQSError as exc: - if exc.status == 403: - raise RuntimeError( - 'SQS authorization error, access_key={0}'.format( - self.sqs.access_key)) - raise - else: - self._queue_cache.update({ - queue.name: queue for queue in queues - }) - - def basic_consume(self, queue, no_ack, *args, **kwargs): - if no_ack: - self._noack_queues.add(queue) - if self.hub: - self._loop1(queue) - return super().basic_consume(queue, no_ack, *args, **kwargs) - - def basic_cancel(self, consumer_tag): - if consumer_tag in self._consumers: - queue = self._tag_to_queue[consumer_tag] - self._noack_queues.discard(queue) - return super().basic_cancel(consumer_tag) - - def drain_events(self, timeout=None): - """Return a single payload message from one of our queues. - - Raises: - Queue.Empty: if no messages available. - """ - # If we're not allowed to consume or have no consumers, raise Empty - if not self._consumers or not self.qos.can_consume(): - raise Empty() - - # At this point, go and get more messages from SQS - self._poll(self.cycle, self.connection._deliver, timeout=timeout) - - def _reset_cycle(self): - """Reset the consume cycle. - - Returns: - FairCycle: object that points to our _get_bulk() method - rather than the standard _get() method. This allows for - multiple messages to be returned at once from SQS ( - based on the prefetch limit). - """ - self._cycle = scheduling.FairCycle( - self._get_bulk, self._active_queues, Empty, - ) - - def entity_name(self, name, table=CHARS_REPLACE_TABLE): - """Format AMQP queue name into a legal SQS queue name.""" - return str(safe_str(name)).translate(table) - - def _new_queue(self, queue, **kwargs): - """Ensure a queue with given name exists in SQS.""" - if not isinstance(queue, str): - return queue - # Translate to SQS name for consistency with initial - # _queue_cache population. - queue = self.entity_name(self.queue_name_prefix + queue) - - # The SQS ListQueues method only returns 1000 queues. When you have - # so many queues, it's possible that the queue you are looking for is - # not cached. In this case, we could update the cache with the exact - # queue name first. - if queue not in self._queue_cache: - self._update_queue_cache(queue) - try: - return self._queue_cache[queue] - except KeyError: - q = self._queue_cache[queue] = self.sqs.create_queue( - queue, self.visibility_timeout, - ) - return q - - def _delete(self, queue, *args, **kwargs): - """delete queue by name.""" - super()._delete(queue) - self._queue_cache.pop(queue, None) - - def _put(self, queue, message, **kwargs): - """Put message onto queue.""" - q = self._new_queue(queue) - m = Message() - m.set_body(dumps(message)) - q.write(m) - - def _message_to_python(self, message, queue_name, queue): - payload = loads(bytes_to_str(message.get_body())) - if queue_name in self._noack_queues: - queue.delete_message(message) - else: - try: - properties = payload['properties'] - delivery_info = payload['properties']['delivery_info'] - except KeyError: - # json message not sent by kombu? - delivery_info = {} - properties = {'delivery_info': delivery_info} - payload.update({ - 'body': bytes_to_str(message.get_body()), - 'properties': properties, - }) - # set delivery tag to SQS receipt handle - delivery_info.update({ - 'sqs_message': message, 'sqs_queue': queue, - }) - properties['delivery_tag'] = message.receipt_handle - return payload - - def _messages_to_python(self, messages, queue): - """Convert a list of SQS Message objects into Payloads. - - This method handles converting SQS Message objects into - Payloads, and appropriately updating the queue depending on - the 'ack' settings for that queue. - - Arguments: - messages (SQSMessage): A list of SQS Message objects. - queue (str): Name representing the queue they came from. - - Returns: - List: A list of Payload objects - """ - q = self._new_queue(queue) - return [self._message_to_python(m, queue, q) for m in messages] - - def _get_bulk(self, queue, - max_if_unlimited=SQS_MAX_MESSAGES, callback=None): - """Try to retrieve multiple messages off ``queue``. - - Where :meth:`_get` returns a single Payload object, this method - returns a list of Payload objects. The number of objects returned - is determined by the total number of messages available in the queue - and the number of messages the QoS object allows (based on the - prefetch_count). - - Note: - Ignores QoS limits so caller is responsible for checking - that we are allowed to consume at least one message from the - queue. get_bulk will then ask QoS for an estimate of - the number of extra messages that we can consume. - - Arguments: - queue (str): The queue name to pull from. - - Returns: - List[Message] - """ - # drain_events calls `can_consume` first, consuming - # a token, so we know that we are allowed to consume at least - # one message. - maxcount = self._get_message_estimate() - if maxcount: - q = self._new_queue(queue) - messages = q.get_messages(num_messages=maxcount) - - if messages: - for msg in self._messages_to_python(messages, queue): - self.connection._deliver(msg, queue) - return - raise Empty() - - def _get(self, queue): - """Try to retrieve a single message off ``queue``.""" - q = self._new_queue(queue) - messages = q.get_messages(num_messages=1) - if messages: - return self._messages_to_python(messages, queue)[0] - raise Empty() - - def _loop1(self, queue, _=None): - self.hub.call_soon(self._schedule_queue, queue) - - def _schedule_queue(self, queue): - if queue in self._active_queues: - if self.qos.can_consume(): - self._get_bulk_async( - queue, callback=promise(self._loop1, (queue,)), - ) - else: - self._loop1(queue) - - def _get_message_estimate(self, max_if_unlimited=SQS_MAX_MESSAGES): - maxcount = self.qos.can_consume_max_estimate() - return min( - max_if_unlimited if maxcount is None else max(maxcount, 1), - max_if_unlimited, - ) - - def _get_bulk_async(self, queue, - max_if_unlimited=SQS_MAX_MESSAGES, callback=None): - maxcount = self._get_message_estimate() - if maxcount: - return self._get_async(queue, maxcount, callback=callback) - # Not allowed to consume, make sure to notify callback.. - callback = ensure_promise(callback) - callback([]) - return callback - - def _get_async(self, queue, count=1, callback=None): - q = self._new_queue(queue) - return self._get_from_sqs( - q, count=count, connection=self.asynsqs, - callback=transform(self._on_messages_ready, callback, q, queue), - ) - - def _on_messages_ready(self, queue, qname, messages): - if messages: - callbacks = self.connection._callbacks - for raw_message in messages: - message = self._message_to_python(raw_message, qname, queue) - callbacks[qname](message) - - def _get_from_sqs(self, queue, - count=1, connection=None, callback=None): - """Retrieve and handle messages from SQS. - - Uses long polling and returns :class:`~vine.promises.promise`. - """ - connection = connection if connection is not None else queue.connection - return connection.receive_message( - queue, number_messages=count, - wait_time_seconds=self.wait_time_seconds, - callback=callback, - ) - - def _restore(self, message, - unwanted_delivery_info=('sqs_message', 'sqs_queue')): - for unwanted_key in unwanted_delivery_info: - # Remove objects that aren't JSON serializable (Issue #1108). - message.delivery_info.pop(unwanted_key, None) - return super()._restore(message) - - def basic_ack(self, delivery_tag, multiple=False): - delivery_info = self.qos.get(delivery_tag).delivery_info - try: - queue = delivery_info['sqs_queue'] - except KeyError: - pass - else: - queue.delete_message(delivery_info['sqs_message']) - super().basic_ack(delivery_tag) - - def _size(self, queue): - """Return the number of messages in a queue.""" - return self._new_queue(queue).count() - - def _purge(self, queue): - """Delete all current messages in a queue.""" - q = self._new_queue(queue) - # SQS is slow at registering messages, so run for a few - # iterations to ensure messages are deleted. - size = 0 - for i in range(10): - size += q.count() - if not size: - break - q.clear() - return size - - def close(self): - super().close() - for conn in (self._sqs, self._asynsqs): - if conn: - try: - conn.close() - except AttributeError as exc: # FIXME ??? - if "can't set attribute" not in str(exc): - raise - - def _get_regioninfo(self, regions): - if self.regioninfo: - return self.regioninfo - if self.region: - for _r in regions: - if _r.name == self.region: - return _r - - def _aws_connect_to(self, fun, regions): - conninfo = self.conninfo - region = self._get_regioninfo(regions) - is_secure = self.is_secure if self.is_secure is not None else True - port = self.port if self.port is not None else conninfo.port - return fun(region=region, - aws_access_key_id=conninfo.userid, - aws_secret_access_key=conninfo.password, - is_secure=is_secure, - port=port) - - @property - def sqs(self): - if self._sqs is None: - self._sqs = self._aws_connect_to(SQSConnection, regions()) - return self._sqs - - @property - def asynsqs(self): - if self._asynsqs is None: - self._asynsqs = self._aws_connect_to( - AsyncSQSConnection, _asynsqs.regions(), - ) - return self._asynsqs - - @property - def conninfo(self): - return self.connection.client - - @property - def transport_options(self): - return self.connection.client.transport_options - - @cached_property - def visibility_timeout(self): - return (self.transport_options.get('visibility_timeout') or - self.default_visibility_timeout) - - @cached_property - def queue_name_prefix(self): - return self.transport_options.get('queue_name_prefix', '') - - @cached_property - def supports_fanout(self): - return False - - @cached_property - def region(self): - return self.transport_options.get('region') or self.default_region - - @cached_property - def regioninfo(self): - return self.transport_options.get('regioninfo') - - @cached_property - def is_secure(self): - return self.transport_options.get('is_secure') - - @cached_property - def port(self): - return self.transport_options.get('port') - - @cached_property - def wait_time_seconds(self): - return self.transport_options.get('wait_time_seconds', - self.default_wait_time_seconds) - - -class Transport(virtual.Transport): - """SQS Transport.""" - - Channel = Channel - - polling_interval = 1 - wait_time_seconds = 0 - default_port = None - connection_errors = ( - virtual.Transport.connection_errors + - (exception.SQSError, socket.error) - ) - channel_errors = ( - virtual.Transport.channel_errors + (exception.SQSDecodeError,) - ) - driver_type = 'sqs' - driver_name = 'sqs' - - implements = virtual.Transport.implements.extend( - async=True, - exchange_type=frozenset(['direct']), - ) diff --git a/kombu/transport/base.py b/kombu/transport/base.py index 0b1f3f35..059de82f 100644 --- a/kombu/transport/base.py +++ b/kombu/transport/base.py @@ -1,20 +1,19 @@ """Base transport interface.""" -import amqp.abstract +import amqp.types import errno import socket - +from typing import Any, Callable, ChannelT, Dict, Mapping, Sequence, Tuple from amqp.exceptions import RecoverableConnectionError - from kombu.exceptions import ChannelError, ConnectionError -from kombu.five import items from kombu.message import Message +from kombu.types import ClientT, ConnectionT, ConsumerT, ProducerT, TransportT from kombu.utils.functional import dictfilter from kombu.utils.objects import cached_property from kombu.utils.time import maybe_s_to_ms __all__ = ['Message', 'StdChannel', 'Management', 'Transport'] -RABBITMQ_QUEUE_ARGUMENTS = { # type: Mapping[str, Tuple[str, Callable]] +RABBITMQ_QUEUE_ARGUMENTS: Mapping[str, Tuple[str, Callable]] = { 'expires': ('x-expires', maybe_s_to_ms), 'message_ttl': ('x-message-ttl', maybe_s_to_ms), 'max_length': ('x-max-length', int), @@ -23,8 +22,7 @@ RABBITMQ_QUEUE_ARGUMENTS = { # type: Mapping[str, Tuple[str, Callable]] } -def to_rabbitmq_queue_arguments(arguments, **options): - # type: (Mapping, **Any) -> Dict +def to_rabbitmq_queue_arguments(arguments: Mapping, **options) -> Dict: """Convert queue arguments to RabbitMQ queue arguments. This is the implementation for Channel.prepare_queue_arguments @@ -52,40 +50,39 @@ def to_rabbitmq_queue_arguments(arguments, **options): """ prepared = dictfilter(dict( _to_rabbitmq_queue_argument(key, value) - for key, value in items(options) + for key, value in options.items() )) return dict(arguments, **prepared) if prepared else arguments -def _to_rabbitmq_queue_argument(key, value): - # type: (str, Any) -> Tuple[str, Any] +def _to_rabbitmq_queue_argument(key: str, value: Any) -> Tuple[str, Any]: opt, typ = RABBITMQ_QUEUE_ARGUMENTS[key] return opt, typ(value) if value is not None else value -def _LeftBlank(obj, method): +def _LeftBlank(obj: Any, method: str) -> Exception: return NotImplementedError( 'Transport {0.__module__}.{0.__name__} does not implement {1}'.format( obj.__class__, method)) -class StdChannel: +class StdChannel(amqp.types.ChannelT): """Standard channel base class.""" no_ack_consumers = None - def Consumer(self, *args, **kwargs): + def Consumer(self, *args, **kwargs) -> ConsumerT: from kombu.messaging import Consumer return Consumer(self, *args, **kwargs) - def Producer(self, *args, **kwargs): + def Producer(self, *args, **kwargs) -> ProducerT: from kombu.messaging import Producer return Producer(self, *args, **kwargs) - def get_bindings(self): + async def get_bindings(self) -> Sequence[Mapping]: raise _LeftBlank(self, 'get_bindings') - def after_reply_message_received(self, queue): + async def after_reply_message_received(self, queue: str) -> None: """Callback called after RPC reply received. Notes: @@ -94,42 +91,39 @@ class StdChannel: """ ... - def prepare_queue_arguments(self, arguments, **kwargs): + def prepare_queue_arguments(self, arguments: Mapping, **kwargs) -> Mapping: return arguments - def __enter__(self): + def __enter__(self) -> ChannelT: return self - def __exit__(self, *exc_info): + def __exit__(self, *exc_info) -> None: self.close() -amqp.abstract.Channel.register(StdChannel) - - class Management: """AMQP Management API (incomplete).""" - def __init__(self, transport): + def __init__(self, transport: TransportT): self.transport = transport - def get_bindings(self): + async def get_bindings(self) -> Sequence[Mapping]: raise _LeftBlank(self, 'get_bindings') class Implements(dict): """Helper class used to define transport features.""" - def __getattr__(self, key): + def __getattr__(self, key: str) -> bool: try: return self[key] except KeyError: raise AttributeError(key) - def __setattr__(self, key, value): + def __setattr__(self, key: str, value: bool) -> None: self[key] = value - def extend(self, **kwargs): + def extend(self, **kwargs) -> 'Implements': return self.__class__(self, **kwargs) @@ -140,65 +134,66 @@ default_transport_capabilities = Implements( ) -class Transport: +class Transport(amqp.types.ConnectionT): """Base class for transports.""" - Management = Management + Management: type = Management #: The :class:`~kombu.Connection` owning this instance. - client = None + client: ClientT = None #: Set to True if :class:`~kombu.Connection` should pass the URL #: unmodified. - can_parse_url = False + can_parse_url: bool = False #: Default port used when no port has been specified. - default_port = None + default_port: int = None #: Tuple of errors that can happen due to connection failure. - connection_errors = (ConnectionError,) + connection_errors: Tuple[type, ...] = (ConnectionError,) #: Tuple of errors that can happen due to channel/method failure. - channel_errors = (ChannelError,) + channel_errors: Tuple[type, ...] = (ChannelError,) #: Type of driver, can be used to separate transports #: using the AMQP protocol (driver_type: 'amqp'), #: Redis (driver_type: 'redis'), etc... - driver_type = 'N/A' + driver_type: str = 'N/A' #: Name of driver library (e.g. 'py-amqp', 'redis'). - driver_name = 'N/A' + driver_name: str = 'N/A' __reader = None implements = default_transport_capabilities.extend() - def __init__(self, client, **kwargs): + def __init__(self, client: ClientT, **kwargs) -> None: self.client = client - def establish_connection(self): + async def establish_connection(self) -> ConnectionT: raise _LeftBlank(self, 'establish_connection') - def close_connection(self, connection): + async def close_connection(self, connection: ConnectionT) -> None: raise _LeftBlank(self, 'close_connection') - def create_channel(self, connection): + def create_channel(self, connection: ConnectionT) -> ChannelT: raise _LeftBlank(self, 'create_channel') - def close_channel(self, connection): + async def close_channel(self, connection: ConnectionT) -> None: raise _LeftBlank(self, 'close_channel') - def drain_events(self, connection, **kwargs): + async def drain_events(self, connection: ConnectionT, **kwargs) -> None: raise _LeftBlank(self, 'drain_events') - def heartbeat_check(self, connection, rate=2): + async def heartbeat_check(self, connection: ConnectionT, + rate: int = 2) -> None: ... - def driver_version(self): + def driver_version(self) -> str: return 'N/A' - def get_heartbeat_interval(self, connection): - return 0 + def get_heartbeat_interval(self, connection: ConnectionT) -> float: + return 0.0 def register_with_event_loop(self, connection, loop): ... @@ -206,7 +201,7 @@ class Transport: def unregister_from_event_loop(self, connection, loop): ... - def verify_connection(self, connection): + def verify_connection(self, connection: ConnectionT) -> bool: return True def _make_reader(self, connection, timeout=socket.timeout, @@ -228,7 +223,7 @@ class Transport: return _read - def qos_semantics_matches_spec(self, connection): + def qos_semantics_matches_spec(self, connection: ConnectionT) -> bool: return True def on_readable(self, connection, loop): @@ -238,20 +233,20 @@ class Transport: reader(loop) @property - def default_connection_params(self): + def default_connection_params(self) -> Mapping: return {} - def get_manager(self, *args, **kwargs): + def get_manager(self, *args, **kwargs) -> Management: return self.Management(self) @cached_property - def manager(self): + def manager(self) -> Management: return self.get_manager() @property - def supports_heartbeats(self): + def supports_heartbeats(self) -> bool: return self.implements.heartbeats @property - def supports_ev(self): + def supports_ev(self) -> bool: return self.implements.async diff --git a/kombu/transport/consul.py b/kombu/transport/consul.py deleted file mode 100644 index fa77e6e9..00000000 --- a/kombu/transport/consul.py +++ /dev/null @@ -1,291 +0,0 @@ -"""Consul Transport. - -It uses Consul.io's Key/Value store to transport messages in Queues - -It uses python-consul for talking to Consul's HTTP API -""" -import uuid -import socket - -from collections import defaultdict -from contextlib import contextmanager -from time import monotonic -from queue import Empty - -from kombu.exceptions import ChannelError -from kombu.log import get_logger -from kombu.utils.json import loads, dumps -from kombu.utils.objects import cached_property - -from . import virtual - -try: - import consul -except ImportError: - consul = None - -logger = get_logger('kombu.transport.consul') - -DEFAULT_PORT = 8500 -DEFAULT_HOST = 'localhost' - - -class LockError(Exception): - """An error occurred while trying to acquire the lock.""" - - -class Channel(virtual.Channel): - """Consul Channel class which talks to the Consul Key/Value store.""" - - prefix = 'kombu' - index = None - timeout = '10s' - session_ttl = 30 - - def __init__(self, *args, **kwargs): - if consul is None: - raise ImportError('Missing python-consul library') - - super().__init__(*args, **kwargs) - - port = self.connection.client.port or self.connection.default_port - host = self.connection.client.hostname or DEFAULT_HOST - - logger.debug('Host: %s Port: %s Timeout: %s', host, port, self.timeout) - - self.queues = defaultdict(dict) - - self.client = consul.Consul(host=host, port=int(port)) - - def _lock_key(self, queue): - return '{0}/{1}.lock'.format(self.prefix, queue) - - def _key_prefix(self, queue): - return '{0}/{1}'.format(self.prefix, queue) - - def _get_or_create_session(self, queue): - """Get or create consul session. - - Try to renew the session if it exists, otherwise create a new - session in Consul. - - This session is used to acquire a lock inside Consul so that we achieve - read-consistency between the nodes. - - Arguments: - queue (str): The name of the Queue. - - Returns: - str: The ID of the session. - """ - try: - session_id = self.queues[queue]['session_id'] - except KeyError: - session_id = None - return (self._renew_existing_session(session_id) - if session_id is not None else self._create_new_session()) - - def _renew_existing_session(self, session_id): - logger.debug('Trying to renew existing session %s', session_id) - session = self.client.session.renew(session_id=session_id) - return session.get('ID') - - def _create_new_session(self): - logger.debug('Creating session %s with TTL %s', - self.lock_name, self.session_ttl) - session_id = self.client.session.create( - name=self.lock_name, ttl=self.session_ttl) - logger.debug('Created session %s with id %s', - self.lock_name, session_id) - return session_id - - @contextmanager - def _queue_lock(self, queue, raising=LockError): - """Try to acquire a lock on the Queue. - - It does so by creating a object called 'lock' which is locked by the - current session.. - - This way other nodes are not able to write to the lock object which - means that they have to wait before the lock is released. - - Arguments: - queue (str): The name of the Queue. - raising (Exception): Set custom lock error class. - - Raises: - LockError: if the lock cannot be acquired. - - Returns: - bool: success? - """ - self._acquire_lock(queue, raising=raising) - try: - yield - finally: - self._release_lock(queue) - - def _acquire_lock(self, queue, raising=LockError): - session_id = self._get_or_create_session(queue) - lock_key = self._lock_key(queue) - - logger.debug('Trying to create lock object %s with session %s', - lock_key, session_id) - - if self.client.kv.put(key=lock_key, - acquire=session_id, - value=self.lock_name): - self.queues[queue]['session_id'] = session_id - return - logger.info('Could not acquire lock on key %s', lock_key) - raise raising() - - def _release_lock(self, queue): - """Try to release a lock. - - It does so by simply removing the lock key in Consul. - - Arguments: - queue (str): The name of the queue we want to release - the lock from. - """ - logger.debug('Removing lock key %s', self._lock_key(queue)) - self.client.kv.delete(key=self._lock_key(queue)) - - def _destroy_session(self, queue): - """Destroy a previously created Consul session. - - Will release all locks it still might hold. - - Arguments: - queue (str): The name of the Queue. - """ - logger.debug('Destroying session %s', self.queues[queue]['session_id']) - self.client.session.destroy(self.queues[queue]['session_id']) - - def _new_queue(self, queue, **_): - self.queues[queue] = {'session_id': None} - return self.client.kv.put(key=self._key_prefix(queue), value=None) - - def _delete(self, queue, *args, **_): - self._destroy_session(queue) - self.queues.pop(queue, None) - self._purge(queue) - - def _put(self, queue, payload, **_): - """Put `message` onto `queue`. - - This simply writes a key to the K/V store of Consul - """ - key = '{0}/msg/{1}_{2}'.format( - self._key_prefix(queue), - int(round(monotonic() * 1000)), - uuid.uuid4(), - ) - if not self.client.kv.put(key=key, value=dumps(payload), cas=0): - raise ChannelError('Cannot add key {0!r} to consul'.format(key)) - - def _get(self, queue, timeout=None): - """Get the first available message from the queue. - - Before it does so it acquires a lock on the Key/Value store so - only one node reads at the same time. This is for read consistency - """ - with self._queue_lock(queue, raising=Empty): - key = '{0}/msg/'.format(self._key_prefix(queue)) - logger.debug('Fetching key %s with index %s', key, self.index) - self.index, data = self.client.kv.get( - key=key, recurse=True, - index=self.index, wait=self.timeout, - ) - - try: - if data is None: - raise Empty() - - logger.debug('Removing key %s with modifyindex %s', - data[0]['Key'], data[0]['ModifyIndex']) - - self.client.kv.delete(key=data[0]['Key'], - cas=data[0]['ModifyIndex']) - - return loads(data[0]['Value']) - except TypeError: - pass - - raise Empty() - - def _purge(self, queue): - self._destroy_session(queue) - return self.client.kv.delete( - key='{0}/msg/'.format(self._key_prefix(queue)), - recurse=True, - ) - - def _size(self, queue): - size = 0 - try: - key = '{0}/msg/'.format(self._key_prefix(queue)) - logger.debug('Fetching key recursively %s with index %s', - key, self.index) - self.index, data = self.client.kv.get( - key=key, recurse=True, - index=self.index, wait=self.timeout, - ) - size = len(data) - except TypeError: - pass - - logger.debug('Found %s keys under %s with index %s', - size, key, self.index) - return size - - @cached_property - def lock_name(self): - return '{0}'.format(socket.gethostname()) - - -class Transport(virtual.Transport): - """Consul K/V storage Transport for Kombu.""" - - Channel = Channel - - default_port = DEFAULT_PORT - driver_type = 'consul' - driver_name = 'consul' - - def __init__(self, *args, **kwargs): - if consul is None: - raise ImportError('Missing python-consul library') - - super().__init__(*args, **kwargs) - - self.connection_errors = ( - virtual.Transport.connection_errors + ( - consul.ConsulException, consul.base.ConsulException - ) - ) - - self.channel_errors = ( - virtual.Transport.channel_errors + ( - consul.ConsulException, consul.base.ConsulException - ) - ) - - def verify_connection(self, connection): - port = connection.client.port or self.default_port - host = connection.client.hostname or DEFAULT_HOST - - logger.debug('Verify Consul connection to %s:%s', host, port) - - try: - client = consul.Consul(host=host, port=int(port)) - client.agent.self() - return True - except ValueError: - pass - - return False - - def driver_version(self): - return consul.__version__ diff --git a/kombu/transport/etcd.py b/kombu/transport/etcd.py deleted file mode 100644 index 4d5c652a..00000000 --- a/kombu/transport/etcd.py +++ /dev/null @@ -1,272 +0,0 @@ -"""Etcd Transport. - -It uses Etcd as a store to transport messages in Queues - -It uses python-etcd for talking to Etcd's HTTP API -""" -from __future__ import absolute_import, unicode_literals - -import os -import socket - -from collections import defaultdict -from contextlib import contextmanager - -from kombu.exceptions import ChannelError -from kombu.five import Empty -from kombu.log import get_logger -from kombu.utils.json import loads, dumps -from kombu.utils.objects import cached_property - -from . import virtual - -try: - import etcd -except ImportError: - etcd = None - -logger = get_logger('kombu.transport.etcd') - -DEFAULT_PORT = 2379 -DEFAULT_HOST = 'localhost' - - -class Channel(virtual.Channel): - """Etcd Channel class which talks to the Etcd.""" - - prefix = 'kombu' - index = None - timeout = 10 - session_ttl = 30 - lock_ttl = 10 - - def __init__(self, *args, **kwargs): - if etcd is None: - raise ImportError('Missing python-etcd library') - - super(Channel, self).__init__(*args, **kwargs) - - port = self.connection.client.port or self.connection.default_port - host = self.connection.client.hostname or DEFAULT_HOST - - logger.debug('Host: %s Port: %s Timeout: %s', host, port, self.timeout) - - self.queues = defaultdict(dict) - - self.client = etcd.Client(host=host, port=int(port)) - - def _key_prefix(self, queue): - """Create and return the `queue` with the proper prefix. - - Arguments: - queue (str): The name of the queue. - """ - return '{0}/{1}'.format(self.prefix, queue) - - @contextmanager - def _queue_lock(self, queue): - """Try to acquire a lock on the Queue. - - It does so by creating a object called 'lock' which is locked by the - current session.. - - This way other nodes are not able to write to the lock object which - means that they have to wait before the lock is released. - - Arguments: - queue (str): The name of the queue. - """ - lock = etcd.Lock(self.client, queue) - lock._uuid = self.lock_value - logger.debug('Acquiring lock {0}'.format(lock.name)) - lock.acquire(blocking=True, lock_ttl=self.lock_ttl) - try: - yield - finally: - logger.debug('Releasing lock {0}'.format(lock.name)) - lock.release() - - def _new_queue(self, queue, **_): - """Create a new `queue` if the `queue` doesn't already exist. - - Arguments: - queue (str): The name of the queue. - """ - self.queues[queue] = queue - with self._queue_lock(queue): - try: - return self.client.write( - key=self._key_prefix(queue), dir=True, value=None) - except etcd.EtcdNotFile: - logger.debug('Queue "{0}" already exists'.format(queue)) - return self.client.read(key=self._key_prefix(queue)) - - def _has_queue(self, queue, **kwargs): - """Verify that queue exists. - - Returns: - bool: Should return :const:`True` if the queue exists - or :const:`False` otherwise. - """ - try: - self.client.read(self._key_prefix(queue)) - return True - except etcd.EtcdKeyNotFound: - return False - - def _delete(self, queue, *args, **_): - """Delete a `queue`. - - Arguments: - queue (str): The name of the queue. - """ - self.queues.pop(queue, None) - self._purge(queue) - - def _put(self, queue, payload, **_): - """Put `message` onto `queue`. - - This simply writes a key to the Etcd store - - Arguments: - queue (str): The name of the queue. - payload (dict): Message data which will be dumped to etcd. - """ - with self._queue_lock(queue): - key = self._key_prefix(queue) - if not self.client.write( - key=key, - value=dumps(payload), - append=True): - raise ChannelError('Cannot add key {0!r} to etcd'.format(key)) - - def _get(self, queue, timeout=None): - """Get the first available message from the queue. - - Before it does so it acquires a lock on the store so - only one node reads at the same time. This is for read consistency - - Arguments: - queue (str): The name of the queue. - timeout (int): Optional seconds to wait for a response. - """ - with self._queue_lock(queue): - key = self._key_prefix(queue) - logger.debug('Fetching key %s with index %s', key, self.index) - - try: - result = self.client.read( - key=key, recursive=True, - index=self.index, timeout=self.timeout) - - if result is None: - raise Empty() - - item = result._children[-1] - logger.debug('Removing key {0}'.format(item['key'])) - - msg_content = loads(item['value']) - self.client.delete(key=item['key']) - return msg_content - except (TypeError, IndexError, etcd.EtcdError) as error: - logger.debug('_get failed: {0}:{1}'.format(type(error), error)) - - raise Empty() - - def _purge(self, queue): - """Remove all `message`s from a `queue`. - - Arguments: - queue (str): The name of the queue. - """ - with self._queue_lock(queue): - key = self._key_prefix(queue) - logger.debug('Purging queue at key {0}'.format(key)) - return self.client.delete(key=key, recursive=True) - - def _size(self, queue): - """Return the size of the `queue`. - - Arguments: - queue (str): The name of the queue. - """ - with self._queue_lock(queue): - size = 0 - try: - key = self._key_prefix(queue) - logger.debug('Fetching key recursively %s with index %s', - key, self.index) - result = self.client.read( - key=key, recursive=True, - index=self.index) - size = len(result._children) - except TypeError: - pass - - logger.debug('Found %s keys under %s with index %s', - size, key, self.index) - return size - - @cached_property - def lock_value(self): - return '{0}.{1}'.format(socket.gethostname(), os.getpid()) - - -class Transport(virtual.Transport): - """Etcd storage Transport for Kombu.""" - - Channel = Channel - - default_port = DEFAULT_PORT - driver_type = 'etcd' - driver_name = 'python-etcd' - polling_interval = 3 - - implements = virtual.Transport.implements.extend( - exchange_type=frozenset(['direct'])) - - def __init__(self, *args, **kwargs): - """Create a new instance of etcd.Transport.""" - if etcd is None: - raise ImportError('Missing python-etcd library') - - super(Transport, self).__init__(*args, **kwargs) - - self.connection_errors = ( - virtual.Transport.connection_errors + (etcd.EtcdError, ) - ) - - self.channel_errors = ( - virtual.Transport.channel_errors + (etcd.EtcdError, ) - ) - - def verify_connection(self, connection): - """Verify the connection works.""" - port = connection.client.port or self.default_port - host = connection.client.hostname or DEFAULT_HOST - - logger.debug('Verify Etcd connection to %s:%s', host, port) - - try: - etcd.Client(host=host, port=int(port)) - return True - except ValueError: - pass - - return False - - def driver_version(self): - """Return the version of the etcd library. - - .. note:: - - python-etcd has no __version__. This is a workaround. - """ - try: - import pip.commands.freeze - for x in pip.commands.freeze.freeze(): - if x.startswith('python-etcd'): - return x.split('==')[1] - except (ImportError, IndexError): - logger.warn('Unable to find the python-etcd version.') - return 'Unknown' diff --git a/kombu/transport/filesystem.py b/kombu/transport/filesystem.py deleted file mode 100644 index 2a5737fb..00000000 --- a/kombu/transport/filesystem.py +++ /dev/null @@ -1,195 +0,0 @@ -"""File-system Transport. - -Transport using the file-system as the message store. -""" -import os -import shutil -import uuid -import tempfile - -from time import monotonic -from queue import Empty - -from . import virtual -from kombu.exceptions import ChannelError -from kombu.utils.encoding import bytes_to_str, str_to_bytes -from kombu.utils.json import loads, dumps -from kombu.utils.objects import cached_property - - -VERSION = (1, 0, 0) -__version__ = '.'.join(map(str, VERSION)) - -# needs win32all to work on Windows -if os.name == 'nt': - - import win32con - import win32file - import pywintypes - - LOCK_EX = win32con.LOCKFILE_EXCLUSIVE_LOCK - # 0 is the default - LOCK_SH = 0 # noqa - LOCK_NB = win32con.LOCKFILE_FAIL_IMMEDIATELY # noqa - __overlapped = pywintypes.OVERLAPPED() - - def lock(file, flags): - """Create file lock.""" - hfile = win32file._get_osfhandle(file.fileno()) - win32file.LockFileEx(hfile, flags, 0, 0xffff0000, __overlapped) - - def unlock(file): - """Remove file lock.""" - hfile = win32file._get_osfhandle(file.fileno()) - win32file.UnlockFileEx(hfile, 0, 0xffff0000, __overlapped) - -elif os.name == 'posix': - - import fcntl - from fcntl import LOCK_EX, LOCK_SH, LOCK_NB # noqa - - def lock(file, flags): # noqa - """Create file lock.""" - fcntl.flock(file.fileno(), flags) - - def unlock(file): # noqa - """Remove file lock.""" - fcntl.flock(file.fileno(), fcntl.LOCK_UN) -else: - raise RuntimeError( - 'Filesystem plugin only defined for NT and POSIX platforms') - - -class Channel(virtual.Channel): - """Filesystem Channel.""" - - def _put(self, queue, payload, **kwargs): - """Put `message` onto `queue`.""" - filename = '%s_%s.%s.msg' % (int(round(monotonic() * 1000)), - uuid.uuid4(), queue) - filename = os.path.join(self.data_folder_out, filename) - - try: - f = open(filename, 'wb') - lock(f, LOCK_EX) - f.write(str_to_bytes(dumps(payload))) - except (IOError, OSError): - raise ChannelError( - 'Cannot add file {0!r} to directory'.format(filename)) - finally: - unlock(f) - f.close() - - def _get(self, queue): - """Get next message from `queue`.""" - queue_find = '.' + queue + '.msg' - folder = os.listdir(self.data_folder_in) - folder = sorted(folder) - while len(folder) > 0: - filename = folder.pop(0) - - # only handle message for the requested queue - if filename.find(queue_find) < 0: - continue - - if self.store_processed: - processed_folder = self.processed_folder - else: - processed_folder = tempfile.gettempdir() - - try: - # move the file to the tmp/processed folder - shutil.move(os.path.join(self.data_folder_in, filename), - processed_folder) - except IOError: - pass # file could be locked, or removed in meantime so ignore - - filename = os.path.join(processed_folder, filename) - try: - f = open(filename, 'rb') - payload = f.read() - f.close() - if not self.store_processed: - os.remove(filename) - except (IOError, OSError): - raise ChannelError( - 'Cannot read file {0!r} from queue.'.format(filename)) - - return loads(bytes_to_str(payload)) - - raise Empty() - - def _purge(self, queue): - """Remove all messages from `queue`.""" - count = 0 - queue_find = '.' + queue + '.msg' - - folder = os.listdir(self.data_folder_in) - while len(folder) > 0: - filename = folder.pop() - try: - # only purge messages for the requested queue - if filename.find(queue_find) < 0: - continue - - filename = os.path.join(self.data_folder_in, filename) - os.remove(filename) - - count += 1 - - except OSError: - # we simply ignore its existence, as it was probably - # processed by another worker - pass - - return count - - def _size(self, queue): - """Return the number of messages in `queue` as an :class:`int`.""" - count = 0 - - queue_find = '.{0}.msg'.format(queue) - folder = os.listdir(self.data_folder_in) - while len(folder) > 0: - filename = folder.pop() - - # only handle message for the requested queue - if filename.find(queue_find) < 0: - continue - - count += 1 - - return count - - @property - def transport_options(self): - return self.connection.client.transport_options - - @cached_property - def data_folder_in(self): - return self.transport_options.get('data_folder_in', 'data_in') - - @cached_property - def data_folder_out(self): - return self.transport_options.get('data_folder_out', 'data_out') - - @cached_property - def store_processed(self): - return self.transport_options.get('store_processed', False) - - @cached_property - def processed_folder(self): - return self.transport_options.get('processed_folder', 'processed') - - -class Transport(virtual.Transport): - """Filesystem Transport.""" - - Channel = Channel - - default_port = 0 - driver_type = 'filesystem' - driver_name = 'filesystem' - - def driver_version(self): - return 'N/A' diff --git a/kombu/transport/librabbitmq.py b/kombu/transport/librabbitmq.py deleted file mode 100644 index 2ea5b779..00000000 --- a/kombu/transport/librabbitmq.py +++ /dev/null @@ -1,182 +0,0 @@ -"""`librabbitmq`_ transport. - -.. _`librabbitmq`: http://pypi.python.org/librabbitmq/ -""" -import os -import socket -import warnings - -import librabbitmq as amqp -from librabbitmq import ChannelError, ConnectionError - -from kombu.utils.amq_manager import get_manager -from kombu.utils.text import version_string_as_tuple - -from . import base -from .base import to_rabbitmq_queue_arguments - -W_VERSION = """ - librabbitmq version too old to detect RabbitMQ version information - so make sure you are using librabbitmq 1.5 when using rabbitmq > 3.3 -""" -DEFAULT_PORT = 5672 -DEFAULT_SSL_PORT = 5671 - -NO_SSL_ERROR = """\ -ssl not supported by librabbitmq, please use pyamqp:// or stunnel\ -""" - - -class Message(base.Message): - """AMQP Message (librabbitmq).""" - - def __init__(self, channel, props, info, body): - super().__init__( - channel=channel, - body=body, - delivery_info=info, - properties=props, - delivery_tag=info.get('delivery_tag'), - content_type=props.get('content_type'), - content_encoding=props.get('content_encoding'), - headers=props.get('headers')) - - -class Channel(amqp.Channel, base.StdChannel): - """AMQP Channel (librabbitmq).""" - - Message = Message - - def prepare_message(self, body, priority=None, - content_type=None, content_encoding=None, - headers=None, properties=None): - """Encapsulate data into a AMQP message.""" - properties = properties if properties is not None else {} - properties.update({'content_type': content_type, - 'content_encoding': content_encoding, - 'headers': headers, - 'priority': priority}) - return body, properties - - def prepare_queue_arguments(self, arguments, **kwargs): - return to_rabbitmq_queue_arguments(arguments, **kwargs) - - -class Connection(amqp.Connection): - """AMQP Connection (librabbitmq).""" - - Channel = Channel - Message = Message - - -class Transport(base.Transport): - """AMQP Transport (librabbitmq).""" - - Connection = Connection - - default_port = DEFAULT_PORT - default_ssl_port = DEFAULT_SSL_PORT - - connection_errors = ( - base.Transport.connection_errors + ( - ConnectionError, socket.error, IOError, OSError) - ) - channel_errors = ( - base.Transport.channel_errors + (ChannelError,) - ) - driver_type = 'amqp' - driver_name = 'librabbitmq' - - implements = base.Transport.implements.extend( - async=True, - heartbeats=False, - ) - - def __init__(self, client, - default_port=None, default_ssl_port=None, **kwargs): - self.client = client - self.default_port = default_port or self.default_port - self.default_ssl_port = default_ssl_port or self.default_ssl_port - self.__reader = None - - def driver_version(self): - return amqp.__version__ - - def create_channel(self, connection): - return connection.channel() - - def drain_events(self, connection, **kwargs): - return connection.drain_events(**kwargs) - - def establish_connection(self): - """Establish connection to the AMQP broker.""" - conninfo = self.client - for name, default_value in self.default_connection_params.items(): - if not getattr(conninfo, name, None): - setattr(conninfo, name, default_value) - if conninfo.ssl: - raise NotImplementedError(NO_SSL_ERROR) - opts = dict({ - 'host': conninfo.host, - 'userid': conninfo.userid, - 'password': conninfo.password, - 'virtual_host': conninfo.virtual_host, - 'login_method': conninfo.login_method, - 'insist': conninfo.insist, - 'ssl': conninfo.ssl, - 'connect_timeout': conninfo.connect_timeout, - }, **conninfo.transport_options or {}) - conn = self.Connection(**opts) - conn.client = self.client - self.client.drain_events = conn.drain_events - return conn - - def close_connection(self, connection): - """Close the AMQP broker connection.""" - self.client.drain_events = None - connection.close() - - def _collect(self, connection): - if connection is not None: - for channel in connection.channels.values(): - channel.connection = None - try: - os.close(connection.fileno()) - except OSError: - pass - connection.channels.clear() - connection.callbacks.clear() - self.client.drain_events = None - self.client = None - - def verify_connection(self, connection): - return connection.connected - - def register_with_event_loop(self, connection, loop): - loop.add_reader( - connection.fileno(), self.on_readable, connection, loop, - ) - - def get_manager(self, *args, **kwargs): - return get_manager(self.client, *args, **kwargs) - - def qos_semantics_matches_spec(self, connection): - try: - props = connection.server_properties - except AttributeError: - warnings.warn(UserWarning(W_VERSION)) - else: - if props.get('product') == 'RabbitMQ': - return version_string_as_tuple(props['version']) < (3, 3) - return True - - @property - def default_connection_params(self): - return { - 'userid': 'guest', - 'password': 'guest', - 'port': (self.default_ssl_port if self.client.ssl - else self.default_port), - 'hostname': 'localhost', - 'login_method': 'AMQPLAIN', - } diff --git a/kombu/transport/memory.py b/kombu/transport/memory.py deleted file mode 100644 index e3b4e441..00000000 --- a/kombu/transport/memory.py +++ /dev/null @@ -1,76 +0,0 @@ -"""In-memory transport.""" -from queue import Queue - -from . import base -from . import virtual - - -class Channel(virtual.Channel): - """In-memory Channel.""" - - queues = {} - do_restore = False - supports_fanout = True - - def _has_queue(self, queue, **kwargs): - return queue in self.queues - - def _new_queue(self, queue, **kwargs): - if queue not in self.queues: - self.queues[queue] = Queue() - - def _get(self, queue, timeout=None): - return self._queue_for(queue).get(block=False) - - def _queue_for(self, queue): - if queue not in self.queues: - self.queues[queue] = Queue() - return self.queues[queue] - - def _queue_bind(self, *args): - ... - - def _put_fanout(self, exchange, message, routing_key=None, **kwargs): - for queue in self._lookup(exchange, routing_key): - self._queue_for(queue).put(message) - - def _put(self, queue, message, **kwargs): - self._queue_for(queue).put(message) - - def _size(self, queue): - return self._queue_for(queue).qsize() - - def _delete(self, queue, *args, **kwargs): - self.queues.pop(queue, None) - - def _purge(self, queue): - q = self._queue_for(queue) - size = q.qsize() - q.queue.clear() - return size - - def close(self): - super().close() - for queue in self.queues.values(): - queue.empty() - self.queues = {} - - def after_reply_message_received(self, queue): - ... - - -class Transport(virtual.Transport): - """In-memory Transport.""" - - Channel = Channel - - #: memory backend state is global. - state = virtual.BrokerState() - - implements = base.Transport.implements - - driver_type = 'memory' - driver_name = 'memory' - - def driver_version(self): - return 'N/A' diff --git a/kombu/transport/mongodb.py b/kombu/transport/mongodb.py deleted file mode 100644 index 7c8b5bb3..00000000 --- a/kombu/transport/mongodb.py +++ /dev/null @@ -1,460 +0,0 @@ -"""MongoDB transport. - -:copyright: (c) 2010 - 2013 by Flavio Percoco Premoli. -:license: BSD, see LICENSE for more details. -""" -import datetime - -from queue import Empty - -import pymongo -from pymongo import errors -from pymongo import MongoClient, uri_parser -from pymongo.cursor import CursorType - -from kombu.exceptions import VersionMismatch -from kombu.utils.compat import _detect_environment -from kombu.utils.encoding import bytes_to_str -from kombu.utils.json import loads, dumps -from kombu.utils.objects import cached_property - -from . import virtual - -E_SERVER_VERSION = """\ -Kombu requires MongoDB version 1.3+ (server is {0})\ -""" - -E_NO_TTL_INDEXES = """\ -Kombu requires MongoDB version 2.2+ (server is {0}) for TTL indexes support\ -""" - - -class BroadcastCursor: - """Cursor for broadcast queues.""" - - def __init__(self, cursor): - self._cursor = cursor - - self.purge(rewind=False) - - def get_size(self): - return self._cursor.count() - self._offset - - def close(self): - self._cursor.close() - - def purge(self, rewind=True): - if rewind: - self._cursor.rewind() - - # Fast forward the cursor past old events - self._offset = self._cursor.count() - self._cursor = self._cursor.skip(self._offset) - - def __iter__(self): - return self - - def __next__(self): - while True: - try: - msg = next(self._cursor) - except pymongo.errors.OperationFailure as exc: - # In some cases tailed cursor can become invalid - # and have to be reinitalized - if 'not valid at server' in str(exc): - self.purge() - - continue - - raise - else: - break - - self._offset += 1 - - return msg - next = __next__ - - -class Channel(virtual.Channel): - """MongoDB Channel.""" - - supports_fanout = True - - # Mutable container. Shared by all class instances - _fanout_queues = {} - - # Options - ssl = False - ttl = False - connect_timeout = None - capped_queue_size = 100000 - calc_queue_size = True - - default_hostname = '127.0.0.1' - default_port = 27017 - default_database = 'kombu_default' - - messages_collection = 'messages' - routing_collection = 'messages.routing' - broadcast_collection = 'messages.broadcast' - queues_collection = 'messages.queues' - - from_transport_options = (virtual.Channel.from_transport_options + ( - 'connect_timeout', 'ssl', 'ttl', 'capped_queue_size', - 'default_hostname', 'default_port', 'default_database', - 'messages_collection', 'routing_collection', - 'broadcast_collection', 'queues_collection', - 'calc_queue_size', - )) - - def __init__(self, *vargs, **kwargs): - super().__init__(*vargs, **kwargs) - - self._broadcast_cursors = {} - - # Evaluate connection - self.client - - # AbstractChannel/Channel interface implementation - - def _new_queue(self, queue, **kwargs): - if self.ttl: - self.queues.update( - {'_id': queue}, - {'_id': queue, - 'options': kwargs, - 'expire_at': self._get_expire(kwargs, 'x-expires')}, - upsert=True) - - def _get(self, queue): - if queue in self._fanout_queues: - try: - msg = next(self._get_broadcast_cursor(queue)) - except StopIteration: - msg = None - else: - msg = self.messages.find_and_modify( - query={'queue': queue}, - sort=[('priority', pymongo.ASCENDING)], - remove=True, - ) - - if self.ttl: - self._update_queues_expire(queue) - - if msg is None: - raise Empty() - - return loads(bytes_to_str(msg['payload'])) - - def _size(self, queue): - # Do not calculate actual queue size if requested - # for performance considerations - if not self.calc_queue_size: - return super()._size(queue) - - if queue in self._fanout_queues: - return self._get_broadcast_cursor(queue).get_size() - - return self.messages.find({'queue': queue}).count() - - def _put(self, queue, message, **kwargs): - data = { - 'payload': dumps(message), - 'queue': queue, - 'priority': self._get_message_priority(message, reverse=True) - } - - if self.ttl: - data['expire_at'] = self._get_expire(queue, 'x-message-ttl') - - self.messages.insert(data) - - def _put_fanout(self, exchange, message, routing_key, **kwargs): - self.broadcast.insert({'payload': dumps(message), - 'queue': exchange}) - - def _purge(self, queue): - size = self._size(queue) - - if queue in self._fanout_queues: - self._get_broadcast_cursor(queue).purge() - else: - self.messages.remove({'queue': queue}) - - return size - - def get_table(self, exchange): - localRoutes = frozenset(self.state.exchanges[exchange]['table']) - brokerRoutes = self.routing.find( - {'exchange': exchange} - ) - - return localRoutes | frozenset( - (r['routing_key'], r['pattern'], r['queue']) - for r in brokerRoutes - ) - - def _queue_bind(self, exchange, routing_key, pattern, queue): - if self.typeof(exchange).type == 'fanout': - self._create_broadcast_cursor( - exchange, routing_key, pattern, queue) - self._fanout_queues[queue] = exchange - - lookup = { - 'exchange': exchange, - 'queue': queue, - 'routing_key': routing_key, - 'pattern': pattern, - } - - data = lookup.copy() - - if self.ttl: - data['expire_at'] = self._get_expire(queue, 'x-expires') - - self.routing.update(lookup, data, upsert=True) - - def queue_delete(self, queue, **kwargs): - self.routing.remove({'queue': queue}) - - if self.ttl: - self.queues.remove({'_id': queue}) - - super().queue_delete(queue, **kwargs) - - if queue in self._fanout_queues: - try: - cursor = self._broadcast_cursors.pop(queue) - except KeyError: - pass - else: - cursor.close() - - self._fanout_queues.pop(queue) - - # Implementation details - - def _parse_uri(self, scheme='mongodb://'): - # See mongodb uri documentation: - # http://docs.mongodb.org/manual/reference/connection-string/ - client = self.connection.client - hostname = client.hostname - - if not hostname.startswith(scheme): - hostname = scheme + hostname - - if not hostname[len(scheme):]: - hostname += self.default_hostname - - if client.userid and '@' not in hostname: - head, tail = hostname.split('://') - - credentials = client.userid - if client.password: - credentials += ':' + client.password - - hostname = head + '://' + credentials + '@' + tail - - port = client.port if client.port else self.default_port - - parsed = uri_parser.parse_uri(hostname, port) - - dbname = parsed['database'] or client.virtual_host - - if dbname in ('/', None): - dbname = self.default_database - - options = { - 'auto_start_request': True, - 'ssl': self.ssl, - 'connectTimeoutMS': (int(self.connect_timeout * 1000) - if self.connect_timeout else None), - } - options.update(parsed['options']) - - return hostname, dbname, options - - def _prepare_client_options(self, options): - if pymongo.version_tuple >= (3,): - options.pop('auto_start_request', None) - return options - - def _open(self, scheme='mongodb://'): - hostname, dbname, options = self._parse_uri(scheme=scheme) - - conf = self._prepare_client_options(options) - conf['host'] = hostname - - env = _detect_environment() - if env == 'gevent': - from gevent import monkey - monkey.patch_all() - elif env == 'eventlet': - from eventlet import monkey_patch - monkey_patch() - - mongoconn = MongoClient(**conf) - database = mongoconn[dbname] - - version_str = mongoconn.server_info()['version'] - version = tuple(map(int, version_str.split('.'))) - - if version < (1, 3): - raise VersionMismatch(E_SERVER_VERSION.format(version_str)) - elif self.ttl and version < (2, 2): - raise VersionMismatch(E_NO_TTL_INDEXES.format(version_str)) - - return database - - def _create_broadcast(self, database): - """Create capped collection for broadcast messages.""" - if self.broadcast_collection in database.collection_names(): - return - - database.create_collection(self.broadcast_collection, - size=self.capped_queue_size, - capped=True) - - def _ensure_indexes(self, database): - """Ensure indexes on collections.""" - messages = database[self.messages_collection] - messages.ensure_index( - [('queue', 1), ('priority', 1), ('_id', 1)], background=True, - ) - - database[self.broadcast_collection].ensure_index([('queue', 1)]) - - routing = database[self.routing_collection] - routing.ensure_index([('queue', 1), ('exchange', 1)]) - - if self.ttl: - messages.ensure_index([('expire_at', 1)], expireAfterSeconds=0) - routing.ensure_index([('expire_at', 1)], expireAfterSeconds=0) - - database[self.queues_collection].ensure_index( - [('expire_at', 1)], expireAfterSeconds=0) - - def _create_client(self): - """Actualy creates connection.""" - database = self._open() - self._create_broadcast(database) - self._ensure_indexes(database) - - return database - - @cached_property - def client(self): - return self._create_client() - - @cached_property - def messages(self): - return self.client[self.messages_collection] - - @cached_property - def routing(self): - return self.client[self.routing_collection] - - @cached_property - def broadcast(self): - return self.client[self.broadcast_collection] - - @cached_property - def queues(self): - return self.client[self.queues_collection] - - def _get_broadcast_cursor(self, queue): - try: - return self._broadcast_cursors[queue] - except KeyError: - # Cursor may be absent when Channel created more than once. - # _fanout_queues is a class-level mutable attribute so it's - # shared over all Channel instances. - return self._create_broadcast_cursor( - self._fanout_queues[queue], None, None, queue, - ) - - def _create_broadcast_cursor(self, exchange, routing_key, pattern, queue): - if pymongo.version_tuple >= (3, ): - query = dict( - filter={'queue': exchange}, - cursor_type=CursorType.TAILABLE - ) - else: - query = dict( - query={'queue': exchange}, - tailable=True - ) - - cursor = self.broadcast.find(**query) - ret = self._broadcast_cursors[queue] = BroadcastCursor(cursor) - return ret - - def _get_expire(self, queue, argument): - """Get expiration header named `argument` of queue definition. - - Note: - `queue` must be either queue name or options itself. - """ - if isinstance(queue, str): - doc = self.queues.find_one({'_id': queue}) - - if not doc: - return - - data = doc['options'] - else: - data = queue - - try: - value = data['arguments'][argument] - except (KeyError, TypeError): - return - - return self.get_now() + datetime.timedelta(milliseconds=value) - - def _update_queues_expire(self, queue): - """Update expiration field on queues documents.""" - expire_at = self._get_expire(queue, 'x-expires') - - if not expire_at: - return - - self.routing.update( - {'queue': queue}, {'$set': {'expire_at': expire_at}}, - multiple=True) - self.queues.update( - {'_id': queue}, {'$set': {'expire_at': expire_at}}, - multiple=True) - - def get_now(self): - """Return current time in UTC.""" - return datetime.datetime.utcnow() - - -class Transport(virtual.Transport): - """MongoDB Transport.""" - - Channel = Channel - - can_parse_url = True - polling_interval = 1 - default_port = Channel.default_port - connection_errors = ( - virtual.Transport.connection_errors + (errors.ConnectionFailure,) - ) - channel_errors = ( - virtual.Transport.channel_errors + ( - errors.ConnectionFailure, - errors.OperationFailure) - ) - driver_type = 'mongodb' - driver_name = 'pymongo' - - implements = virtual.Transport.implements.extend( - exchange_type=frozenset(['direct', 'topic', 'fanout']), - ) - - def driver_version(self): - return pymongo.version diff --git a/kombu/transport/pyamqp.py b/kombu/transport/pyamqp.py index 9ba7d8b9..b00c63a9 100644 --- a/kombu/transport/pyamqp.py +++ b/kombu/transport/pyamqp.py @@ -96,14 +96,14 @@ class Transport(base.Transport): def create_channel(self, connection): return connection.channel() - def drain_events(self, connection, **kwargs): - return connection.drain_events(**kwargs) + async def drain_events(self, connection, **kwargs): + await connection.drain_events(**kwargs) def _collect(self, connection): if connection is not None: connection.collect() - def establish_connection(self): + async def establish_connection(self): """Establish connection to the AMQP broker.""" conninfo = self.client for name, default_value in self.default_connection_params.items(): @@ -124,16 +124,16 @@ class Transport(base.Transport): }, **conninfo.transport_options or {}) conn = self.Connection(**opts) conn.client = self.client - conn.connect() + await conn.connect() return conn def verify_connection(self, connection): return connection.connected - def close_connection(self, connection): + async def close_connection(self, connection): """Close the AMQP broker connection.""" connection.client = None - connection.close() + await connection.close() def get_heartbeat_interval(self, connection): return connection.heartbeat @@ -142,8 +142,8 @@ class Transport(base.Transport): connection.transport.raise_on_initial_eintr = True loop.add_reader(connection.sock, self.on_readable, connection, loop) - def heartbeat_check(self, connection, rate=2): - return connection.heartbeat_tick(rate=rate) + async def heartbeat_check(self, connection, rate=2): + await connection.heartbeat_tick(rate=rate) def qos_semantics_matches_spec(self, connection): props = connection.server_properties diff --git a/kombu/transport/pyro.py b/kombu/transport/pyro.py deleted file mode 100644 index c52532aa..00000000 --- a/kombu/transport/pyro.py +++ /dev/null @@ -1,94 +0,0 @@ -"""Pyro transport. - -Requires the :mod:`Pyro4` library to be installed. -""" -import sys - -from kombu.utils.objects import cached_property - -from . import virtual - -try: - import Pyro4 as pyro - from Pyro4.errors import NamingError -except ImportError: # pragma: no cover - pyro = NamingError = None # noqa - -DEFAULT_PORT = 9090 -E_LOOKUP = """\ -Unable to locate pyro nameserver {0.virtual_host} on host {0.hostname}\ -""" - - -class Channel(virtual.Channel): - """Pyro Channel.""" - - def queues(self): - return self.shared_queues.get_queue_names() - - def _new_queue(self, queue, **kwargs): - if queue not in self.queues(): - self.shared_queues.new_queue(queue) - - def _get(self, queue, timeout=None): - queue = self._queue_for(queue) - msg = self.shared_queues._get(queue) - return msg - - def _queue_for(self, queue): - if queue not in self.queues(): - self.shared_queues.new_queue(queue) - return queue - - def _put(self, queue, message, **kwargs): - queue = self._queue_for(queue) - self.shared_queues._put(queue, message) - - def _size(self, queue): - return self.shared_queues._size(queue) - - def _delete(self, queue, *args, **kwargs): - self.shared_queues._delete(queue) - - def _purge(self, queue): - return self.shared_queues._purge(queue) - - def after_reply_message_received(self, queue): - ... - - @cached_property - def shared_queues(self): - return self.connection.shared_queues - - -class Transport(virtual.Transport): - """Pyro Transport.""" - - Channel = Channel - - #: memory backend state is global. - state = virtual.BrokerState() - - default_port = DEFAULT_PORT - - driver_type = driver_name = 'pyro' - - def _open(self): - conninfo = self.client - pyro.config.HMAC_KEY = conninfo.virtual_host - try: - nameserver = pyro.locateNS(host=conninfo.hostname, - port=self.default_port) - # name of registered pyro object - uri = nameserver.lookup(conninfo.virtual_host) - return pyro.Proxy(uri) - except NamingError: - raise NamingError(E_LOOKUP.format(conninfo)).with_traceback( - sys.exc_info()[2]) - - def driver_version(self): - return pyro.__version__ - - @cached_property - def shared_queues(self): - return self._open() diff --git a/kombu/transport/qpid.py b/kombu/transport/qpid.py deleted file mode 100644 index 22046241..00000000 --- a/kombu/transport/qpid.py +++ /dev/null @@ -1,1739 +0,0 @@ -"""Qpid Transport. - -`Qpid`_ transport using `qpid-python`_ as the client and `qpid-tools`_ for -broker management. - -The use this transport you must install the necessary dependencies. These -dependencies are available via PyPI and can be installed using the pip -command: - -.. code-block:: console - - $ pip install kombu[qpid] - -or to install the requirements manually: - -.. code-block:: console - - $ pip install qpid-tools qpid-python - -.. admonition:: Python 3 and PyPy Limitations - - The Qpid transport does not support Python 3 or PyPy environments due - to underlying dependencies not being compatible. This version is - tested and works with with Python 2.7. - -.. _`Qpid`: http://qpid.apache.org/ -.. _`qpid-python`: http://pypi.python.org/pypi/qpid-python/ -.. _`qpid-tools`: http://pypi.python.org/pypi/qpid-tools/ - -Authentication -============== - -This transport supports SASL authentication with the Qpid broker. Normally, -SASL mechanisms are negotiated from a client list and a server list of -possible mechanisms, but in practice, different SASL client libraries give -different behaviors. These different behaviors cause the expected SASL -mechanism to not be selected in many cases. As such, this transport restricts -the mechanism types based on Kombu's configuration according to the following -table. - -+------------------------------------+--------------------+ -| **Broker String** | **SASL Mechanism** | -+------------------------------------+--------------------+ -| qpid://hostname/ | ANONYMOUS | -+------------------------------------+--------------------+ -| qpid://username:password@hostname/ | PLAIN | -+------------------------------------+--------------------+ -| see instructions below | EXTERNAL | -+------------------------------------+--------------------+ - -The user can override the above SASL selection behaviors and specify the SASL -string using the :attr:`~kombu.Connection.login_method` argument to the -:class:`~kombu.Connection` object. The string can be a single SASL mechanism -or a space separated list of SASL mechanisms. If you are using Celery with -Kombu, this can be accomplished by setting the *BROKER_LOGIN_METHOD* Celery -option. - -.. note:: - - While using SSL, Qpid users may want to override the SASL mechanism to - use *EXTERNAL*. In that case, Qpid requires a username to be presented - that matches the *CN* of the SSL client certificate. Ensure that the - broker string contains the corresponding username. For example, if the - client certificate has *CN=asdf* and the client connects to *example.com* - on port 5671, the broker string should be: - - **qpid://asdf@example.com:5671/** - -Transport Options -================= - -The :attr:`~kombu.Connection.transport_options` argument to the -:class:`~kombu.Connection` object are passed directly to the -:class:`qpid.messaging.endpoints.Connection` as keyword arguments. These -options override and replace any other default or specified values. If using -Celery, this can be accomplished by setting the -*BROKER_TRANSPORT_OPTIONS* Celery option. -""" -from __future__ import absolute_import, unicode_literals - -from collections import OrderedDict -import os -import select -import socket -import ssl -import sys -import uuid - -from gettext import gettext as _ - -import amqp.protocol - -try: - import fcntl -except ImportError: - fcntl = None # noqa - -try: - import qpidtoollibs -except ImportError: # pragma: no cover - qpidtoollibs = None # noqa - -try: - from qpid.messaging.exceptions import ConnectionError, NotFound - from qpid.messaging.exceptions import Empty as QpidEmpty - from qpid.messaging.exceptions import SessionClosed -except ImportError: # pragma: no cover - ConnectionError = None - NotFound = None - QpidEmpty = None - SessionClosed = None - -try: - import qpid -except ImportError: # pragma: no cover - qpid = None - - -from kombu.five import Empty, items, monotonic -from kombu.log import get_logger -from kombu.transport.virtual import Base64, Message -from kombu.transport import base - - -logger = get_logger(__name__) - - -OBJECT_ALREADY_EXISTS_STRING = 'object already exists' - -VERSION = (1, 0, 0) -__version__ = '.'.join(map(str, VERSION)) - -PY3 = sys.version_info[0] == 3 - - -def dependency_is_none(dependency): - """Return True if the dependency is None, otherwise False. - - This is done using a function so that tests can mock this - behavior easily. - - :param dependency: The module to check if it is None - :return: True if dependency is None otherwise False. - - """ - return dependency is None - - -class AuthenticationFailure(Exception): - """Cannot authenticate with Qpid.""" - - -class QoS(object): - """A helper object for message prefetch and ACKing purposes. - - :keyword prefetch_count: Initial prefetch count, hard set to 1. - :type prefetch_count: int - - - NOTE: prefetch_count is currently hard set to 1, and needs to be improved - - This object is instantiated 1-for-1 with a - :class:`~.kombu.transport.qpid.Channel` instance. QoS allows - ``prefetch_count`` to be set to the number of outstanding messages - the corresponding :class:`~kombu.transport.qpid.Channel` should be - allowed to prefetch. Setting ``prefetch_count`` to 0 disables - prefetch limits, and the object can hold an arbitrary number of messages. - - Messages are added using :meth:`append`, which are held until they are - ACKed asynchronously through a call to :meth:`ack`. Messages that are - received, but not ACKed will not be delivered by the broker to another - consumer until an ACK is received, or the session is closed. Messages - are referred to using delivery_tag, which are unique per - :class:`Channel`. Delivery tags are managed outside of this object and - are passed in with a message to :meth:`append`. Un-ACKed messages can - be looked up from QoS using :meth:`get` and can be rejected and - forgotten using :meth:`reject`. - - """ - - def __init__(self, session, prefetch_count=1): - self.session = session - self.prefetch_count = 1 - self._not_yet_acked = OrderedDict() - - def can_consume(self): - """Return True if the :class:`Channel` can consume more messages. - - Used to ensure the client adheres to currently active prefetch - limits. - - :returns: True, if this QoS object can accept more messages - without violating the prefetch_count. If prefetch_count is 0, - can_consume will always return True. - :rtype: bool - - """ - return ( - not self.prefetch_count or - len(self._not_yet_acked) < self.prefetch_count - ) - - def can_consume_max_estimate(self): - """Return the remaining message capacity. - - Returns an estimated number of outstanding messages that a - :class:`kombu.transport.qpid.Channel` can accept without - exceeding ``prefetch_count``. If ``prefetch_count`` is 0, then - this method returns 1. - - :returns: The number of estimated messages that can be fetched - without violating the prefetch_count. - :rtype: int - - """ - return 1 if not self.prefetch_count else ( - self.prefetch_count - len(self._not_yet_acked) - ) - - def append(self, message, delivery_tag): - """Append message to the list of un-ACKed messages. - - Add a message, referenced by the delivery_tag, for ACKing, - rejecting, or getting later. Messages are saved into an - :class:`collections.OrderedDict` by delivery_tag. - - :param message: A received message that has not yet been ACKed. - :type message: qpid.messaging.Message - :param delivery_tag: A UUID to refer to this message by - upon receipt. - :type delivery_tag: uuid.UUID - - """ - self._not_yet_acked[delivery_tag] = message - - def get(self, delivery_tag): - """Get an un-ACKed message by delivery_tag. - - If called with an invalid delivery_tag a :exc:`KeyError` is raised. - - :param delivery_tag: The delivery tag associated with the message - to be returned. - :type delivery_tag: uuid.UUID - - :return: An un-ACKed message that is looked up by delivery_tag. - :rtype: qpid.messaging.Message - - """ - return self._not_yet_acked[delivery_tag] - - def ack(self, delivery_tag): - """Acknowledge a message by delivery_tag. - - Called asynchronously once the message has been handled and can be - forgotten by the broker. - - :param delivery_tag: the delivery tag associated with the message - to be acknowledged. - :type delivery_tag: uuid.UUID - - """ - message = self._not_yet_acked.pop(delivery_tag) - self.session.acknowledge(message=message) - - def reject(self, delivery_tag, requeue=False): - """Reject a message by delivery_tag. - - Explicitly notify the broker that the channel associated - with this QoS object is rejecting the message that was previously - delivered. - - If requeue is False, then the message is not requeued for delivery - to another consumer. If requeue is True, then the message is - requeued for delivery to another consumer. - - :param delivery_tag: The delivery tag associated with the message - to be rejected. - :type delivery_tag: uuid.UUID - :keyword requeue: If True, the broker will be notified to requeue - the message. If False, the broker will be told to drop the - message entirely. In both cases, the message will be removed - from this object. - :type requeue: bool - - """ - message = self._not_yet_acked.pop(delivery_tag) - QpidDisposition = qpid.messaging.Disposition - if requeue: - disposition = QpidDisposition(qpid.messaging.RELEASED) - else: - disposition = QpidDisposition(qpid.messaging.REJECTED) - self.session.acknowledge(message=message, disposition=disposition) - - -class Channel(base.StdChannel): - """Supports broker configuration and messaging send and receive. - - :param connection: A Connection object that this Channel can - reference. Currently only used to access callbacks. - :type connection: kombu.transport.qpid.Connection - :param transport: The Transport this Channel is associated with. - :type transport: kombu.transport.qpid.Transport - - A channel object is designed to have method-parity with a Channel as - defined in AMQP 0-10 and earlier, which allows for the following broker - actions: - - - exchange declare and delete - - queue declare and delete - - queue bind and unbind operations - - queue length and purge operations - - sending/receiving/rejecting messages - - structuring, encoding, and decoding messages - - supports synchronous and asynchronous reads - - reading state about the exchange, queues, and bindings - - Channels are designed to all share a single TCP connection with a - broker, but provide a level of isolated communication with the broker - while benefiting from a shared TCP connection. The Channel is given - its :class:`~kombu.transport.qpid.Connection` object by the - :class:`~kombu.transport.qpid.Transport` that - instantiates the channel. - - This channel inherits from :class:`~kombu.transport.base.StdChannel`, - which makes this a 'native' channel versus a 'virtual' channel which - would inherit from :class:`kombu.transports.virtual`. - - Messages sent using this channel are assigned a delivery_tag. The - delivery_tag is generated for a message as they are prepared for - sending by :meth:`basic_publish`. The delivery_tag is unique per - channel instance. The delivery_tag has no meaningful context in other - objects, and is only maintained in the memory of this object, and the - underlying :class:`QoS` object that provides support. - - Each channel object instantiates exactly one :class:`QoS` object for - prefetch limiting, and asynchronous ACKing. The :class:`QoS` object is - lazily instantiated through a property method :meth:`qos`. The - :class:`QoS` object is a supporting object that should not be accessed - directly except by the channel itself. - - Synchronous reads on a queue are done using a call to :meth:`basic_get` - which uses :meth:`_get` to perform the reading. These methods read - immediately and do not accept any form of timeout. :meth:`basic_get` - reads synchronously and ACKs messages before returning them. ACKing is - done in all cases, because an application that reads messages using - qpid.messaging, but does not ACK them will experience a memory leak. - The no_ack argument to :meth:`basic_get` does not affect ACKing - functionality. - - Asynchronous reads on a queue are done by starting a consumer using - :meth:`basic_consume`. Each call to :meth:`basic_consume` will cause a - :class:`~qpid.messaging.endpoints.Receiver` to be created on the - :class:`~qpid.messaging.endpoints.Session` started by the :class: - `Transport`. The receiver will asynchronously read using - qpid.messaging, and prefetch messages before the call to - :meth:`Transport.basic_drain` occurs. The prefetch_count value of the - :class:`QoS` object is the capacity value of the new receiver. The new - receiver capacity must always be at least 1, otherwise none of the - receivers will appear to be ready for reading, and will never be read - from. - - Each call to :meth:`basic_consume` creates a consumer, which is given a - consumer tag that is identified by the caller of :meth:`basic_consume`. - Already started consumers can be cancelled using by their consumer_tag - using :meth:`basic_cancel`. Cancellation of a consumer causes the - :class:`~qpid.messaging.endpoints.Receiver` object to be closed. - - Asynchronous message ACKing is supported through :meth:`basic_ack`, - and is referenced by delivery_tag. The Channel object uses its - :class:`QoS` object to perform the message ACKing. - - """ - - #: A class reference that will be instantiated using the qos property. - QoS = QoS - - #: A class reference that identifies - # :class:`~kombu.transport.virtual.Message` as the message class type - Message = Message - - #: Default body encoding. - #: NOTE: ``transport_options['body_encoding']`` will override this value. - body_encoding = 'base64' - - #: Binary <-> ASCII codecs. - codecs = {'base64': Base64()} - - def __init__(self, connection, transport): - self.connection = connection - self.transport = transport - qpid_connection = connection.get_qpid_connection() - self._broker = qpidtoollibs.BrokerAgent(qpid_connection) - self.closed = False - self._tag_to_queue = {} - self._receivers = {} - self._qos = None - - def _get(self, queue): - """Non-blocking, single-message read from a queue. - - An internal method to perform a non-blocking, single-message read - from a queue by name. This method creates a - :class:`~qpid.messaging.endpoints.Receiver` to read from the queue - using the :class:`~qpid.messaging.endpoints.Session` saved on the - associated :class:`~kombu.transport.qpid.Transport`. The receiver - is closed before the method exits. If a message is available, a - :class:`qpid.messaging.Message` object is returned. If no message is - available, a :class:`qpid.messaging.exceptions.Empty` exception is - raised. - - This is an internal method. External calls for get functionality - should be done using :meth:`basic_get`. - - :param queue: The queue name to get the message from - :type queue: str - - :return: The received message. - :rtype: :class:`qpid.messaging.Message` - :raises: :class:`qpid.messaging.exceptions.Empty` if no - message is available. - - """ - rx = self.transport.session.receiver(queue) - try: - message = rx.fetch(timeout=0) - finally: - rx.close() - return message - - def _put(self, routing_key, message, exchange=None, **kwargs): - """Synchronously send a single message onto a queue or exchange. - - An internal method which synchronously sends a single message onto - a given queue or exchange. If exchange is not specified, - the message is sent directly to a queue specified by routing_key. - If no queue is found by the name of routing_key while exchange is - not specified an exception is raised. If an exchange is specified, - then the message is delivered onto the requested - exchange using routing_key. Message sending is synchronous using - sync=True because large messages in kombu funtests were not being - fully sent before the receiver closed. - - This method creates a :class:`qpid.messaging.endpoints.Sender` to - send the message to the queue using the - :class:`qpid.messaging.endpoints.Session` created and referenced by - the associated :class:`~kombu.transport.qpid.Transport`. The sender - is closed before the method exits. - - External calls for put functionality should be done using - :meth:`basic_publish`. - - :param routing_key: If exchange is None, treated as the queue name - to send the message to. If exchange is not None, treated as the - routing_key to use as the message is submitted onto the exchange. - :type routing_key: str - :param message: The message to be sent as prepared by - :meth:`basic_publish`. - :type message: dict - :keyword exchange: keyword parameter of the exchange this message - should be sent on. If no exchange is specified, the message is - sent directly to a queue specified by routing_key. - :type exchange: str - - """ - if not exchange: - address = '%s; {assert: always, node: {type: queue}}' % ( - routing_key,) - msg_subject = None - else: - address = '%s/%s; {assert: always, node: {type: topic}}' % ( - exchange, routing_key) - msg_subject = str(routing_key) - sender = self.transport.session.sender(address) - qpid_message = qpid.messaging.Message(content=message, - subject=msg_subject) - try: - sender.send(qpid_message, sync=True) - finally: - sender.close() - - def _purge(self, queue): - """Purge all undelivered messages from a queue specified by name. - - An internal method to purge all undelivered messages from a queue - specified by name. If the queue does not exist a - :class:`qpid.messaging.exceptions.NotFound` exception is raised. - - The queue message depth is first checked, and then the broker is - asked to purge that number of messages. The integer number of - messages requested to be purged is returned. The actual number of - messages purged may be different than the requested number of - messages to purge (see below). - - Sometimes delivered messages are asked to be purged, but are not. - This case fails silently, which is the correct behavior when a - message that has been delivered to a different consumer, who has - not ACKed the message, and still has an active session with the - broker. Messages in that case are not safe for purging and will be - retained by the broker. The client is unable to change this - delivery behavior. - - This is an internal method. External calls for purge functionality - should be done using :meth:`queue_purge`. - - :param queue: the name of the queue to be purged - :type queue: str - - :return: The number of messages requested to be purged. - :rtype: int - - :raises: :class:`qpid.messaging.exceptions.NotFound` if the queue - being purged cannot be found. - - """ - queue_to_purge = self._broker.getQueue(queue) - if queue_to_purge is None: - error_text = "NOT_FOUND - no queue '{0}'".format(queue) - raise NotFound(code=404, text=error_text) - message_count = queue_to_purge.values['msgDepth'] - if message_count > 0: - queue_to_purge.purge(message_count) - return message_count - - def _size(self, queue): - """Get the number of messages in a queue specified by name. - - An internal method to return the number of messages in a queue - specified by name. It returns an integer count of the number - of messages currently in the queue. - - :param queue: The name of the queue to be inspected for the number - of messages - :type queue: str - - :return the number of messages in the queue specified by name. - :rtype: int - - """ - queue_to_check = self._broker.getQueue(queue) - message_depth = queue_to_check.values['msgDepth'] - return message_depth - - def _delete(self, queue, *args, **kwargs): - """Delete a queue and all messages on that queue. - - An internal method to delete a queue specified by name and all the - messages on it. First, all messages are purged from a queue using a - call to :meth:`_purge`. Second, the broker is asked to delete the - queue. - - This is an internal method. External calls for queue delete - functionality should be done using :meth:`queue_delete`. - - :param queue: The name of the queue to be deleted. - :type queue: str - - """ - self._purge(queue) - self._broker.delQueue(queue) - - def _has_queue(self, queue, **kwargs): - """Determine if the broker has a queue specified by name. - - :param queue: The queue name to check if the queue exists. - :type queue: str - - :return: True if a queue exists on the broker, and false - otherwise. - :rtype: bool - - """ - if self._broker.getQueue(queue): - return True - else: - return False - - def queue_declare(self, queue, passive=False, durable=False, - exclusive=False, auto_delete=True, nowait=False, - arguments=None): - """Create a new queue specified by name. - - If the queue already exists, no change is made to the queue, - and the return value returns information about the existing queue. - - The queue name is required and specified as the first argument. - - If passive is True, the server will not create the queue. The - client can use this to check whether a queue exists without - modifying the server state. Default is False. - - If durable is True, the queue will be durable. Durable queues - remain active when a server restarts. Non-durable queues ( - transient queues) are purged if/when a server restarts. Note that - durable queues do not necessarily hold persistent messages, - although it does not make sense to send persistent messages to a - transient queue. Default is False. - - If exclusive is True, the queue will be exclusive. Exclusive queues - may only be consumed by the current connection. Setting the - 'exclusive' flag always implies 'auto-delete'. Default is False. - - If auto_delete is True, the queue is deleted when all consumers - have finished using it. The last consumer can be cancelled either - explicitly or because its channel is closed. If there was no - consumer ever on the queue, it won't be deleted. Default is True. - - The nowait parameter is unused. It was part of the 0-9-1 protocol, - but this AMQP client implements 0-10 which removed the nowait option. - - The arguments parameter is a set of arguments for the declaration of - the queue. Arguments are passed as a dict or None. This field is - ignored if passive is True. Default is None. - - This method returns a :class:`~collections.namedtuple` with the name - 'queue_declare_ok_t' and the queue name as 'queue', message count - on the queue as 'message_count', and the number of active consumers - as 'consumer_count'. The named tuple values are ordered as queue, - message_count, and consumer_count respectively. - - Due to Celery's non-ACKing of events, a ring policy is set on any - queue that starts with the string 'celeryev' or ends with the string - 'pidbox'. These are celery event queues, and Celery does not ack - them, causing the messages to build-up. Eventually Qpid stops serving - messages unless the 'ring' policy is set, at which point the buffer - backing the queue becomes circular. - - :param queue: The name of the queue to be created. - :type queue: str - :param passive: If True, the sever will not create the queue. - :type passive: bool - :param durable: If True, the queue will be durable. - :type durable: bool - :param exclusive: If True, the queue will be exclusive. - :type exclusive: bool - :param auto_delete: If True, the queue is deleted when all - consumers have finished using it. - :type auto_delete: bool - :param nowait: This parameter is unused since the 0-10 - specification does not include it. - :type nowait: bool - :param arguments: A set of arguments for the declaration of the - queue. - :type arguments: dict or None - - :return: A named tuple representing the declared queue as a named - tuple. The tuple values are ordered as queue, message count, - and the active consumer count. - :rtype: :class:`~collections.namedtuple` - - """ - options = {'passive': passive, - 'durable': durable, - 'exclusive': exclusive, - 'auto-delete': auto_delete, - 'arguments': arguments} - if queue.startswith('celeryev') or queue.endswith('pidbox'): - options['qpid.policy_type'] = 'ring' - try: - self._broker.addQueue(queue, options=options) - except Exception as exc: - if OBJECT_ALREADY_EXISTS_STRING not in str(exc): - raise exc - queue_to_check = self._broker.getQueue(queue) - message_count = queue_to_check.values['msgDepth'] - consumer_count = queue_to_check.values['consumerCount'] - return amqp.protocol.queue_declare_ok_t(queue, message_count, - consumer_count) - - def queue_delete(self, queue, if_unused=False, if_empty=False, **kwargs): - """Delete a queue by name. - - Delete a queue specified by name. Using the if_unused keyword - argument, the delete can only occur if there are 0 consumers bound - to it. Using the if_empty keyword argument, the delete can only - occur if there are 0 messages in the queue. - - :param queue: The name of the queue to be deleted. - :type queue: str - :keyword if_unused: If True, delete only if the queue has 0 - consumers. If False, delete a queue even with consumers bound - to it. - :type if_unused: bool - :keyword if_empty: If True, only delete the queue if it is empty. If - False, delete the queue if it is empty or not. - :type if_empty: bool - - """ - if self._has_queue(queue): - if if_empty and self._size(queue): - return - queue_obj = self._broker.getQueue(queue) - consumer_count = queue_obj.getAttributes()['consumerCount'] - if if_unused and consumer_count > 0: - return - self._delete(queue) - - def exchange_declare(self, exchange='', type='direct', durable=False, - **kwargs): - """Create a new exchange. - - Create an exchange of a specific type, and optionally have the - exchange be durable. If an exchange of the requested name already - exists, no action is taken and no exceptions are raised. Durable - exchanges will survive a broker restart, non-durable exchanges will - not. - - Exchanges provide behaviors based on their type. The expected - behaviors are those defined in the AMQP 0-10 and prior - specifications including 'direct', 'topic', and 'fanout' - functionality. - - :keyword type: The exchange type. Valid values include 'direct', - 'topic', and 'fanout'. - :type type: str - :keyword exchange: The name of the exchange to be created. If no - exchange is specified, then a blank string will be used as the - name. - :type exchange: str - :keyword durable: True if the exchange should be durable, or False - otherwise. - :type durable: bool - - """ - options = {'durable': durable} - try: - self._broker.addExchange(type, exchange, options) - except Exception as exc: - if OBJECT_ALREADY_EXISTS_STRING not in str(exc): - raise exc - - def exchange_delete(self, exchange_name, **kwargs): - """Delete an exchange specified by name. - - :param exchange_name: The name of the exchange to be deleted. - :type exchange_name: str - - """ - self._broker.delExchange(exchange_name) - - def queue_bind(self, queue, exchange, routing_key, **kwargs): - """Bind a queue to an exchange with a bind key. - - Bind a queue specified by name, to an exchange specified by name, - with a specific bind key. The queue and exchange must already - exist on the broker for the bind to complete successfully. Queues - may be bound to exchanges multiple times with different keys. - - :param queue: The name of the queue to be bound. - :type queue: str - :param exchange: The name of the exchange that the queue should be - bound to. - :type exchange: str - :param routing_key: The bind key that the specified queue should - bind to the specified exchange with. - :type routing_key: str - - """ - self._broker.bind(exchange, queue, routing_key) - - def queue_unbind(self, queue, exchange, routing_key, **kwargs): - """Unbind a queue from an exchange with a given bind key. - - Unbind a queue specified by name, from an exchange specified by - name, that is already bound with a bind key. The queue and - exchange must already exist on the broker, and bound with the bind - key for the operation to complete successfully. Queues may be - bound to exchanges multiple times with different keys, thus the - bind key is a required field to unbind in an explicit way. - - :param queue: The name of the queue to be unbound. - :type queue: str - :param exchange: The name of the exchange that the queue should be - unbound from. - :type exchange: str - :param routing_key: The existing bind key between the specified - queue and a specified exchange that should be unbound. - :type routing_key: str - - """ - self._broker.unbind(exchange, queue, routing_key) - - def queue_purge(self, queue, **kwargs): - """Remove all undelivered messages from queue. - - Purge all undelivered messages from a queue specified by name. If the - queue does not exist an exception is raised. The queue message - depth is first checked, and then the broker is asked to purge that - number of messages. The integer number of messages requested to be - purged is returned. The actual number of messages purged may be - different than the requested number of messages to purge. - - Sometimes delivered messages are asked to be purged, but are not. - This case fails silently, which is the correct behavior when a - message that has been delivered to a different consumer, who has - not ACKed the message, and still has an active session with the - broker. Messages in that case are not safe for purging and will be - retained by the broker. The client is unable to change this - delivery behavior. - - Internally, this method relies on :meth:`_purge`. - - :param queue: The name of the queue which should have all messages - removed. - :type queue: str - - :return: The number of messages requested to be purged. - :rtype: int - - :raises: :class:`qpid.messaging.exceptions.NotFound` if the queue - being purged cannot be found. - - """ - return self._purge(queue) - - def basic_get(self, queue, no_ack=False, **kwargs): - """Non-blocking single message get and ACK from a queue by name. - - Internally this method uses :meth:`_get` to fetch the message. If - an :class:`~qpid.messaging.exceptions.Empty` exception is raised by - :meth:`_get`, this method silences it and returns None. If - :meth:`_get` does return a message, that message is ACKed. The no_ack - parameter has no effect on ACKing behavior, and all messages are - ACKed in all cases. This method never adds fetched Messages to the - internal QoS object for asynchronous ACKing. - - This method converts the object type of the method as it passes - through. Fetching from the broker, :meth:`_get` returns a - :class:`qpid.messaging.Message`, but this method takes the payload - of the :class:`qpid.messaging.Message` and instantiates a - :class:`~kombu.transport.virtual.Message` object with the payload - based on the class setting of self.Message. - - :param queue: The queue name to fetch a message from. - :type queue: str - :keyword no_ack: The no_ack parameter has no effect on the ACK - behavior of this method. Un-ACKed messages create a memory leak in - qpid.messaging, and need to be ACKed in all cases. - :type noack: bool - - :return: The received message. - :rtype: :class:`~kombu.transport.virtual.Message` - - """ - try: - qpid_message = self._get(queue) - raw_message = qpid_message.content - message = self.Message(raw_message, channel=self) - self.transport.session.acknowledge(message=qpid_message) - return message - except Empty: - pass - - def basic_ack(self, delivery_tag): - """Acknowledge a message by delivery_tag. - - Acknowledges a message referenced by delivery_tag. Messages can - only be ACKed using :meth:`basic_ack` if they were acquired using - :meth:`basic_consume`. This is the ACKing portion of the - asynchronous read behavior. - - Internally, this method uses the :class:`QoS` object, which stores - messages and is responsible for the ACKing. - - :param delivery_tag: The delivery tag associated with the message - to be acknowledged. - :type delivery_tag: uuid.UUID - - """ - self.qos.ack(delivery_tag) - - def basic_reject(self, delivery_tag, requeue=False): - """Reject a message by delivery_tag. - - Rejects a message that has been received by the Channel, but not - yet acknowledged. Messages are referenced by their delivery_tag. - - If requeue is False, the rejected message will be dropped by the - broker and not delivered to any other consumers. If requeue is - True, then the rejected message will be requeued for delivery to - another consumer, potentially to the same consumer who rejected the - message previously. - - :param delivery_tag: The delivery tag associated with the message - to be rejected. - :type delivery_tag: uuid.UUID - :keyword requeue: If False, the rejected message will be dropped by - the broker and not delivered to any other consumers. If True, - then the rejected message will be requeued for delivery to - another consumer, potentially to the same consumer who rejected - the message previously. - :type requeue: bool - - """ - self.qos.reject(delivery_tag, requeue=requeue) - - def basic_consume(self, queue, no_ack, callback, consumer_tag, **kwargs): - """Start an asynchronous consumer that reads from a queue. - - This method starts a consumer of type - :class:`~qpid.messaging.endpoints.Receiver` using the - :class:`~qpid.messaging.endpoints.Session` created and referenced by - the :class:`Transport` that reads messages from a queue - specified by name until stopped by a call to :meth:`basic_cancel`. - - - Messages are available later through a synchronous call to - :meth:`Transport.drain_events`, which will drain from the consumer - started by this method. :meth:`Transport.drain_events` is - synchronous, but the receiving of messages over the network occurs - asynchronously, so it should still perform well. - :meth:`Transport.drain_events` calls the callback provided here with - the Message of type self.Message. - - Each consumer is referenced by a consumer_tag, which is provided by - the caller of this method. - - This method sets up the callback onto the self.connection object in a - dict keyed by queue name. :meth:`~Transport.drain_events` is - responsible for calling that callback upon message receipt. - - All messages that are received are added to the QoS object to be - saved for asynchronous ACKing later after the message has been - handled by the caller of :meth:`~Transport.drain_events`. Messages - can be ACKed after being received through a call to :meth:`basic_ack`. - - If no_ack is True, The no_ack flag indicates that the receiver of - the message will not call :meth:`basic_ack` later. Since the - message will not be ACKed later, it is ACKed immediately. - - :meth:`basic_consume` transforms the message object type prior to - calling the callback. Initially the message comes in as a - :class:`qpid.messaging.Message`. This method unpacks the payload - of the :class:`qpid.messaging.Message` and creates a new object of - type self.Message. - - This method wraps the user delivered callback in a runtime-built - function which provides the type transformation from - :class:`qpid.messaging.Message` to - :class:`~kombu.transport.virtual.Message`, and adds the message to - the associated :class:`QoS` object for asynchronous ACKing - if necessary. - - :param queue: The name of the queue to consume messages from - :type queue: str - :param no_ack: If True, then messages will not be saved for ACKing - later, but will be ACKed immediately. If False, then messages - will be saved for ACKing later with a call to :meth:`basic_ack`. - :type no_ack: bool - :param callback: a callable that will be called when messages - arrive on the queue. - :type callback: a callable object - :param consumer_tag: a tag to reference the created consumer by. - This consumer_tag is needed to cancel the consumer. - :type consumer_tag: an immutable object - - """ - self._tag_to_queue[consumer_tag] = queue - - def _callback(qpid_message): - raw_message = qpid_message.content - message = self.Message(raw_message, channel=self) - delivery_tag = message.delivery_tag - self.qos.append(qpid_message, delivery_tag) - if no_ack: - # Celery will not ack this message later, so we should ack now - self.basic_ack(delivery_tag) - return callback(message) - - self.connection._callbacks[queue] = _callback - new_receiver = self.transport.session.receiver(queue) - new_receiver.capacity = self.qos.prefetch_count - self._receivers[consumer_tag] = new_receiver - - def basic_cancel(self, consumer_tag): - """Cancel consumer by consumer tag. - - Request the consumer stops reading messages from its queue. The - consumer is a :class:`~qpid.messaging.endpoints.Receiver`, and it is - closed using :meth:`~qpid.messaging.endpoints.Receiver.close`. - - This method also cleans up all lingering references of the consumer. - - :param consumer_tag: The tag which refers to the consumer to be - cancelled. Originally specified when the consumer was created - as a parameter to :meth:`basic_consume`. - :type consumer_tag: an immutable object - - """ - if consumer_tag in self._receivers: - receiver = self._receivers.pop(consumer_tag) - receiver.close() - queue = self._tag_to_queue.pop(consumer_tag, None) - self.connection._callbacks.pop(queue, None) - - def close(self): - """Cancel all associated messages and close the Channel. - - This cancels all consumers by calling :meth:`basic_cancel` for each - known consumer_tag. It also closes the self._broker sessions. Closing - the sessions implicitly causes all outstanding, un-ACKed messages to - be considered undelivered by the broker. - - """ - if not self.closed: - self.closed = True - for consumer_tag in self._receivers.keys(): - self.basic_cancel(consumer_tag) - if self.connection is not None: - self.connection.close_channel(self) - self._broker.close() - - @property - def qos(self): - """:class:`QoS` manager for this channel. - - Lazily instantiates an object of type :class:`QoS` upon access to - the self.qos attribute. - - :return: An already existing, or newly created QoS object - :rtype: :class:`QoS` - - """ - if self._qos is None: - self._qos = self.QoS(self.transport.session) - return self._qos - - def basic_qos(self, prefetch_count, *args): - """Change :class:`QoS` settings for this Channel. - - Set the number of un-acknowledged messages this Channel can fetch and - hold. The prefetch_value is also used as the capacity for any new - :class:`~qpid.messaging.endpoints.Receiver` objects. - - Currently, this value is hard coded to 1. - - :param prefetch_count: Not used. This method is hard-coded to 1. - :type prefetch_count: int - - """ - self.qos.prefetch_count = 1 - - def prepare_message(self, body, priority=None, content_type=None, - content_encoding=None, headers=None, properties=None): - """Prepare message data for sending. - - This message is typically called by - :meth:`kombu.messaging.Producer._publish` as a preparation step in - message publication. - - :param body: The body of the message - :type body: str - :keyword priority: A number between 0 and 9 that sets the priority of - the message. - :type priority: int - :keyword content_type: The content_type the message body should be - treated as. If this is unset, the - :class:`qpid.messaging.endpoints.Sender` object tries to - autodetect the content_type from the body. - :type content_type: str - :keyword content_encoding: The content_encoding the message body is - encoded as. - :type content_encoding: str - :keyword headers: Additional Message headers that should be set. - Passed in as a key-value pair. - :type headers: dict - :keyword properties: Message properties to be set on the message. - :type properties: dict - - :return: Returns a dict object that encapsulates message - attributes. See parameters for more details on attributes that - can be set. - :rtype: dict - - """ - properties = properties or {} - info = properties.setdefault('delivery_info', {}) - info['priority'] = priority or 0 - - return {'body': body, - 'content-encoding': content_encoding, - 'content-type': content_type, - 'headers': headers or {}, - 'properties': properties or {}} - - def basic_publish(self, message, exchange, routing_key, **kwargs): - """Publish message onto an exchange using a routing key. - - Publish a message onto an exchange specified by name using a - routing key specified by routing_key. Prepares the message in the - following ways before sending: - - - encodes the body using :meth:`encode_body` - - wraps the body as a buffer object, so that - :class:`qpid.messaging.endpoints.Sender` uses a content type - that can support arbitrarily large messages. - - sets delivery_tag to a random uuid.UUID - - sets the exchange and routing_key info as delivery_info - - Internally uses :meth:`_put` to send the message synchronously. This - message is typically called by - :class:`kombu.messaging.Producer._publish` as the final step in - message publication. - - :param message: A dict containing key value pairs with the message - data. A valid message dict can be generated using the - :meth:`prepare_message` method. - :type message: dict - :param exchange: The name of the exchange to submit this message - onto. - :type exchange: str - :param routing_key: The routing key to be used as the message is - submitted onto the exchange. - :type routing_key: str - - """ - message['body'], body_encoding = self.encode_body( - message['body'], self.body_encoding, - ) - message['body'] = buffer(message['body']) - props = message['properties'] - props.update( - body_encoding=body_encoding, - delivery_tag=uuid.uuid4(), - ) - props['delivery_info'].update( - exchange=exchange, - routing_key=routing_key, - ) - self._put(routing_key, message, exchange, **kwargs) - - def encode_body(self, body, encoding=None): - """Encode a body using an optionally specified encoding. - - The encoding can be specified by name, and is looked up in - self.codecs. self.codecs uses strings as its keys which specify - the name of the encoding, and then the value is an instantiated - object that can provide encoding/decoding of that type through - encode and decode methods. - - :param body: The body to be encoded. - :type body: str - :keyword encoding: The encoding type to be used. Must be a supported - codec listed in self.codecs. - :type encoding: str - - :return: If encoding is specified, return a tuple with the first - position being the encoded body, and the second position the - encoding used. If encoding is not specified, the body is passed - through unchanged. - :rtype: tuple - - """ - if encoding: - return self.codecs.get(encoding).encode(body), encoding - return body, encoding - - def decode_body(self, body, encoding=None): - """Decode a body using an optionally specified encoding. - - The encoding can be specified by name, and is looked up in - self.codecs. self.codecs uses strings as its keys which specify - the name of the encoding, and then the value is an instantiated - object that can provide encoding/decoding of that type through - encode and decode methods. - - :param body: The body to be encoded. - :type body: str - :keyword encoding: The encoding type to be used. Must be a supported - codec listed in self.codecs. - :type encoding: str - - :return: If encoding is specified, the decoded body is returned. - If encoding is not specified, the body is returned unchanged. - :rtype: str - - """ - if encoding: - return self.codecs.get(encoding).decode(body) - return body - - def typeof(self, exchange, default='direct'): - """Get the exchange type. - - Lookup and return the exchange type for an exchange specified by - name. Exchange types are expected to be 'direct', 'topic', - and 'fanout', which correspond with exchange functionality as - specified in AMQP 0-10 and earlier. If the exchange cannot be - found, the default exchange type is returned. - - :param exchange: The exchange to have its type lookup up. - :type exchange: str - :keyword default: The type of exchange to assume if the exchange does - not exist. - :type default: str - - :return: The exchange type either 'direct', 'topic', or 'fanout'. - :rtype: str - - """ - qpid_exchange = self._broker.getExchange(exchange) - if qpid_exchange: - qpid_exchange_attributes = qpid_exchange.getAttributes() - return qpid_exchange_attributes['type'] - else: - return default - - -class Connection(object): - """Qpid Connection. - - Encapsulate a connection object for the - :class:`~kombu.transport.qpid.Transport`. - - :param host: The host that connections should connect to. - :param port: The port that connection should connect to. - :param username: The username that connections should connect with. - Optional. - :param password: The password that connections should connect with. - Optional but requires a username. - :param transport: The transport type that connections should use. - Either 'tcp', or 'ssl' are expected as values. - :param timeout: the timeout used when a Connection connects - to the broker. - :param sasl_mechanisms: The sasl authentication mechanism type to use. - refer to SASL documentation for an explanation of valid - values. - - .. note:: - - qpid.messaging has an AuthenticationFailure exception type, but - instead raises a ConnectionError with a message that indicates an - authentication failure occurred in those situations. - ConnectionError is listed as a recoverable error type, so kombu - will attempt to retry if a ConnectionError is raised. Retrying - the operation without adjusting the credentials is not correct, - so this method specifically checks for a ConnectionError that - indicates an Authentication Failure occurred. In those - situations, the error type is mutated while preserving the - original message and raised so kombu will allow the exception to - not be considered recoverable. - - - A connection object is created by a - :class:`~kombu.transport.qpid.Transport` during a call to - :meth:`~kombu.transport.qpid.Transport.establish_connection`. The - :class:`~kombu.transport.qpid.Transport` passes in - connection options as keywords that should be used for any connections - created. Each :class:`~kombu.transport.qpid.Transport` creates exactly - one Connection. - - A Connection object maintains a reference to a - :class:`~qpid.messaging.endpoints.Connection` which can be accessed - through a bound getter method named :meth:`get_qpid_connection` method. - Each Channel uses a the Connection for each - :class:`~qpidtoollibs.BrokerAgent`, and the Transport maintains a session - for all senders and receivers. - - The Connection object is also responsible for maintaining the - dictionary of references to callbacks that should be called when - messages are received. These callbacks are saved in _callbacks, - and keyed on the queue name associated with the received message. The - _callbacks are setup in :meth:`Channel.basic_consume`, removed in - :meth:`Channel.basic_cancel`, and called in - :meth:`Transport.drain_events`. - - The following keys are expected to be passed in as keyword arguments - at a minimum: - - All keyword arguments are collected into the connection_options dict - and passed directly through to - :meth:`qpid.messaging.endpoints.Connection.establish`. - - """ - - # A class reference to the :class:`Channel` object - Channel = Channel - - def __init__(self, **connection_options): - self.connection_options = connection_options - self.channels = [] - self._callbacks = {} - self._qpid_conn = None - establish = qpid.messaging.Connection.establish - - # There are several inconsistent behaviors in the sasl libraries - # used on different systems. Although qpid.messaging allows - # multiple space separated sasl mechanisms, this implementation - # only advertises one type to the server. These are either - # ANONYMOUS, PLAIN, or an overridden value specified by the user. - - sasl_mech = connection_options['sasl_mechanisms'] - - try: - msg = _('Attempting to connect to qpid with ' - 'SASL mechanism %s') % sasl_mech - logger.debug(msg) - self._qpid_conn = establish(**self.connection_options) - # connection was successful if we got this far - msg = _('Connected to qpid with SASL ' - 'mechanism %s') % sasl_mech - logger.info(msg) - except ConnectionError as conn_exc: - # if we get one of these errors, do not raise an exception. - # Raising will cause the connection to be retried. Instead, - # just continue on to the next mech. - coded_as_auth_failure = getattr(conn_exc, 'code', None) == 320 - contains_auth_fail_text = \ - 'Authentication failed' in conn_exc.text - contains_mech_fail_text = \ - 'sasl negotiation failed: no mechanism agreed' \ - in conn_exc.text - contains_mech_unavail_text = 'no mechanism available' \ - in conn_exc.text - if coded_as_auth_failure or \ - contains_auth_fail_text or contains_mech_fail_text or \ - contains_mech_unavail_text: - msg = _('Unable to connect to qpid with SASL ' - 'mechanism %s') % sasl_mech - logger.error(msg) - raise AuthenticationFailure(sys.exc_info()[1]) - raise - - def get_qpid_connection(self): - """Return the existing connection (singleton). - - :return: The existing qpid.messaging.Connection - :rtype: :class:`qpid.messaging.endpoints.Connection` - - """ - return self._qpid_conn - - def close(self): - """Close the connection. - - Closing the connection will close all associated session, senders, or - receivers used by the Connection. - - """ - self._qpid_conn.close() - - def close_channel(self, channel): - """Close a Channel. - - Close a channel specified by a reference to the - :class:`~kombu.transport.qpid.Channel` object. - - :param channel: Channel that should be closed. - :type channel: :class:`~kombu.transport.qpid.Channel`. - - """ - try: - self.channels.remove(channel) - except ValueError: - pass - finally: - channel.connection = None - - -class Transport(base.Transport): - """Kombu native transport for a Qpid broker. - - Provide a native transport for Kombu that allows consumers and - producers to read and write messages to/from a broker. This Transport - is capable of supporting both synchronous and asynchronous reading. - All writes are synchronous through the :class:`Channel` objects that - support this Transport. - - Asynchronous reads are done using a call to :meth:`drain_events`, - which synchronously reads messages that were fetched asynchronously, and - then handles them through calls to the callback handlers maintained on - the :class:`Connection` object. - - The Transport also provides methods to establish and close a connection - to the broker. This Transport establishes a factory-like pattern that - allows for singleton pattern to consolidate all Connections into a single - one. - - The Transport can create :class:`Channel` objects to communicate with the - broker with using the :meth:`create_channel` method. - - The Transport identifies recoverable connection errors and recoverable - channel errors according to the Kombu 3.0 interface. These exception are - listed as tuples and store in the Transport class attribute - `recoverable_connection_errors` and `recoverable_channel_errors` - respectively. Any exception raised that is not a member of one of these - tuples is considered non-recoverable. This allows Kombu support for - automatic retry of certain operations to function correctly. - - For backwards compatibility to the pre Kombu 3.0 exception interface, the - recoverable errors are also listed as `connection_errors` and - `channel_errors`. - - """ - - # Reference to the class that should be used as the Connection object - Connection = Connection - - # This Transport does not specify a polling interval. - polling_interval = None - - # This Transport does support the Celery asynchronous event model. - supports_ev = True - - # The driver type and name for identification purposes. - driver_type = 'qpid' - driver_name = 'qpid' - - # Exceptions that can be recovered from, but where the connection must be - # closed and re-established first. - recoverable_connection_errors = ( - ConnectionError, - select.error, - ) - - # Exceptions that can be automatically recovered from without - # re-establishing the connection. - recoverable_channel_errors = ( - NotFound, - ) - - # Support the pre 3.0 Kombu exception labeling interface which treats - # connection_errors and channel_errors both as recoverable via a - # reconnect. - connection_errors = recoverable_connection_errors - channel_errors = recoverable_channel_errors - - def __init__(self, *args, **kwargs): - self.verify_runtime_environment() - super(Transport, self).__init__(*args, **kwargs) - self.use_async_interface = False - - def verify_runtime_environment(self): - """Verify that the runtime environment is acceptable. - - This method is called as part of __init__ and raises a RuntimeError - in Python3 or PyPi environments. This module is not compatible with - Python3 or PyPi. The RuntimeError identifies this to the user up - front along with suggesting Python 2.6+ be used instead. - - This method also checks that the dependencies qpidtoollibs and - qpid.messaging are installed. If either one is not installed a - RuntimeError is raised. - - :raises: RuntimeError if the runtime environment is not acceptable. - - """ - if getattr(sys, 'pypy_version_info', None): - raise RuntimeError( - 'The Qpid transport for Kombu does not ' - 'support PyPy. Try using Python 2.6+', - ) - if PY3: - raise RuntimeError( - 'The Qpid transport for Kombu does not ' - 'support Python 3. Try using Python 2.6+', - ) - - if dependency_is_none(qpidtoollibs): - raise RuntimeError( - 'The Python package "qpidtoollibs" is missing. Install it ' - 'with your package manager. You can also try `pip install ' - 'qpid-tools`.') - - if dependency_is_none(qpid): - raise RuntimeError( - 'The Python package "qpid.messaging" is missing. Install it ' - 'with your package manager. You can also try `pip install ' - 'qpid-python`.') - - def _qpid_message_ready_handler(self, session): - if self.use_async_interface: - os.write(self._w, '0') - - def _qpid_async_exception_notify_handler(self, obj_with_exception, exc): - if self.use_async_interface: - os.write(self._w, 'e') - - def on_readable(self, connection, loop): - """Handle any messages associated with this Transport. - - This method clears a single message from the externally monitored - file descriptor by issuing a read call to the self.r file descriptor - which removes a single '0' character that was placed into the pipe - by the Qpid session message callback handler. Once a '0' is read, - all available events are drained through a call to - :meth:`drain_events`. - - The file descriptor self.r is modified to be non-blocking, ensuring - that an accidental call to this method when no more messages will - not cause indefinite blocking. - - Nothing is expected to be returned from :meth:`drain_events` because - :meth:`drain_events` handles messages by calling callbacks that are - maintained on the :class:`~kombu.transport.qpid.Connection` object. - When :meth:`drain_events` returns, all associated messages have been - handled. - - This method calls drain_events() which reads as many messages as are - available for this Transport, and then returns. It blocks in the - sense that reading and handling a large number of messages may take - time, but it does not block waiting for a new message to arrive. When - :meth:`drain_events` is called a timeout is not specified, which - causes this behavior. - - One interesting behavior of note is where multiple messages are - ready, and this method removes a single '0' character from - self.r, but :meth:`drain_events` may handle an arbitrary amount of - messages. In that case, extra '0' characters may be left on self.r - to be read, where messages corresponding with those '0' characters - have already been handled. The external epoll loop will incorrectly - think additional data is ready for reading, and will call - on_readable unnecessarily, once for each '0' to be read. Additional - calls to :meth:`on_readable` produce no negative side effects, - and will eventually clear out the symbols from the self.r file - descriptor. If new messages show up during this draining period, - they will also be properly handled. - - :param connection: The connection associated with the readable - events, which contains the callbacks that need to be called for - the readable objects. - :type connection: kombu.transport.qpid.Connection - :param loop: The asynchronous loop object that contains epoll like - functionality. - :type loop: kombu.async.Hub - - """ - os.read(self.r, 1) - try: - self.drain_events(connection) - except socket.timeout: - pass - - def register_with_event_loop(self, connection, loop): - """Register a file descriptor and callback with the loop. - - Register the callback self.on_readable to be called when an - external epoll loop sees that the file descriptor registered is - ready for reading. The file descriptor is created by this Transport, - and is written to when a message is available. - - Because supports_ev == True, Celery expects to call this method to - give the Transport an opportunity to register a read file descriptor - for external monitoring by celery using an Event I/O notification - mechanism such as epoll. A callback is also registered that is to - be called once the external epoll loop is ready to handle the epoll - event associated with messages that are ready to be handled for - this Transport. - - The registration call is made exactly once per Transport after the - Transport is instantiated. - - :param connection: A reference to the connection associated with - this Transport. - :type connection: kombu.transport.qpid.Connection - :param loop: A reference to the external loop. - :type loop: kombu.async.hub.Hub - - """ - self.r, self._w = os.pipe() - if fcntl is not None: - fcntl.fcntl(self.r, fcntl.F_SETFL, os.O_NONBLOCK) - self.use_async_interface = True - loop.add_reader(self.r, self.on_readable, connection, loop) - - def establish_connection(self): - """Establish a Connection object. - - Determines the correct options to use when creating any - connections needed by this Transport, and create a - :class:`Connection` object which saves those values for - connections generated as they are needed. The options are a - mixture of what is passed in through the creator of the - Transport, and the defaults provided by - :meth:`default_connection_params`. Options cover broker network - settings, timeout behaviors, authentication, and identity - verification settings. - - This method also creates and stores a - :class:`~qpid.messaging.endpoints.Session` using the - :class:`~qpid.messaging.endpoints.Connection` created by this - method. The Session is stored on self. - - :return: The created :class:`Connection` object is returned. - :rtype: :class:`Connection` - - """ - conninfo = self.client - for name, default_value in items(self.default_connection_params): - if not getattr(conninfo, name, None): - setattr(conninfo, name, default_value) - if conninfo.ssl: - conninfo.qpid_transport = 'ssl' - conninfo.transport_options['ssl_keyfile'] = conninfo.ssl[ - 'keyfile'] - conninfo.transport_options['ssl_certfile'] = conninfo.ssl[ - 'certfile'] - conninfo.transport_options['ssl_trustfile'] = conninfo.ssl[ - 'ca_certs'] - if conninfo.ssl['cert_reqs'] == ssl.CERT_REQUIRED: - conninfo.transport_options['ssl_skip_hostname_check'] = False - else: - conninfo.transport_options['ssl_skip_hostname_check'] = True - else: - conninfo.qpid_transport = 'tcp' - - credentials = {} - if conninfo.login_method is None: - if conninfo.userid is not None and conninfo.password is not None: - sasl_mech = 'PLAIN' - credentials['username'] = conninfo.userid - credentials['password'] = conninfo.password - elif conninfo.userid is None and conninfo.password is not None: - raise Exception( - 'Password configured but no username. SASL PLAIN ' - 'requires a username when using a password.') - elif conninfo.userid is not None and conninfo.password is None: - raise Exception( - 'Username configured but no password. SASL PLAIN ' - 'requires a password when using a username.') - else: - sasl_mech = 'ANONYMOUS' - else: - sasl_mech = conninfo.login_method - if conninfo.userid is not None: - credentials['username'] = conninfo.userid - - opts = { - 'host': conninfo.hostname, - 'port': conninfo.port, - 'sasl_mechanisms': sasl_mech, - 'timeout': conninfo.connect_timeout, - 'transport': conninfo.qpid_transport - } - - opts.update(credentials) - opts.update(conninfo.transport_options) - - conn = self.Connection(**opts) - conn.client = self.client - self.session = conn.get_qpid_connection().session() - self.session.set_message_received_notify_handler( - self._qpid_message_ready_handler - ) - conn.get_qpid_connection().set_async_exception_notify_handler( - self._qpid_async_exception_notify_handler - ) - self.session.set_async_exception_notify_handler( - self._qpid_async_exception_notify_handler - ) - return conn - - def close_connection(self, connection): - """Close the :class:`Connection` object. - - :param connection: The Connection that should be closed. - :type connection: :class:`kombu.transport.qpid.Connection` - - """ - connection.close() - - def drain_events(self, connection, timeout=0, **kwargs): - """Handle and call callbacks for all ready Transport messages. - - Drains all events that are ready from all - :class:`~qpid.messaging.endpoints.Receiver` that are asynchronously - fetching messages. - - For each drained message, the message is called to the appropriate - callback. Callbacks are organized by queue name. - - :param connection: The :class:`~kombu.transport.qpid.Connection` that - contains the callbacks, indexed by queue name, which will be called - by this method. - :type connection: kombu.transport.qpid.Connection - :keyword timeout: The timeout that limits how long this method will - run for. The timeout could interrupt a blocking read that is - waiting for a new message, or cause this method to return before - all messages are drained. Defaults to 0. - :type timeout: int - - """ - start_time = monotonic() - elapsed_time = -1 - while elapsed_time < timeout: - try: - receiver = self.session.next_receiver(timeout=timeout) - message = receiver.fetch() - queue = receiver.source - except QpidEmpty: - raise socket.timeout() - else: - connection._callbacks[queue](message) - elapsed_time = monotonic() - start_time - raise socket.timeout() - - def create_channel(self, connection): - """Create and return a :class:`~kombu.transport.qpid.Channel`. - - Creates a new channel, and appends the channel to the - list of channels known by the Connection. Once the new - channel is created, it is returned. - - :param connection: The connection that should support the new - :class:`~kombu.transport.qpid.Channel`. - :type connection: kombu.transport.qpid.Connection - - :return: The new Channel that is made. - :rtype: :class:`kombu.transport.qpid.Channel`. - - """ - channel = connection.Channel(connection, self) - connection.channels.append(channel) - return channel - - @property - def default_connection_params(self): - """Return a dict with default connection parameters. - - These connection parameters will be used whenever the creator of - Transport does not specify a required parameter. - - :return: A dict containing the default parameters. - :rtype: dict - - """ - return { - 'hostname': 'localhost', - 'port': 5672, - } - - def __del__(self): - """Ensure file descriptors opened in __init__() are closed.""" - if getattr(self, 'use_async_interface', False): - for fd in (self.r, self._w): - try: - os.close(fd) - except OSError: - # ignored - pass diff --git a/kombu/transport/redis.py b/kombu/transport/redis.py deleted file mode 100644 index e3e6de10..00000000 --- a/kombu/transport/redis.py +++ /dev/null @@ -1,1099 +0,0 @@ -"""Redis transport.""" -import numbers -import socket - -from bisect import bisect -from collections import namedtuple -from contextlib import contextmanager -from time import time -from queue import Empty - -from vine import promise - -from kombu.exceptions import InconsistencyError, VersionMismatch -from kombu.log import get_logger -from kombu.utils.compat import register_after_fork -from kombu.utils.eventio import poll, READ, ERR -from kombu.utils.encoding import bytes_to_str -from kombu.utils.json import loads, dumps -from kombu.utils.objects import cached_property -from kombu.utils.scheduling import cycle_by_name -from kombu.utils.url import _parse_url -from kombu.utils.uuid import uuid - -from . import virtual - -try: - import redis -except ImportError: # pragma: no cover - redis = None # noqa - -try: - from redis import sentinel -except ImportError: # pragma: no cover - sentinel = None # noqa - - -logger = get_logger('kombu.transport.redis') -crit, warn = logger.critical, logger.warn - -DEFAULT_PORT = 6379 -DEFAULT_DB = 0 - -PRIORITY_STEPS = [0, 3, 6, 9] - -error_classes_t = namedtuple('error_classes_t', ( - 'connection_errors', 'channel_errors', -)) - -NO_ROUTE_ERROR = """ -Cannot route message for exchange {0!r}: Table empty or key no longer exists. -Probably the key ({1!r}) has been removed from the Redis database. -""" - -# This implementation may seem overly complex, but I assure you there is -# a good reason for doing it this way. -# -# Consuming from several connections enables us to emulate channels, -# which means we can have different service guarantees for individual -# channels. -# -# So we need to consume messages from multiple connections simultaneously, -# and using epoll means we don't have to do so using multiple threads. -# -# Also it means we can easily use PUBLISH/SUBSCRIBE to do fanout -# exchanges (broadcast), as an alternative to pushing messages to fanout-bound -# queues manually. - - -def get_redis_error_classes(): - """Return tuple of redis error classes.""" - from redis import exceptions - # This exception suddenly changed name between redis-py versions - if hasattr(exceptions, 'InvalidData'): - DataError = exceptions.InvalidData - else: - DataError = exceptions.DataError - return error_classes_t( - (virtual.Transport.connection_errors + ( - InconsistencyError, - socket.error, - IOError, - OSError, - exceptions.ConnectionError, - exceptions.AuthenticationError, - exceptions.TimeoutError)), - (virtual.Transport.channel_errors + ( - DataError, - exceptions.InvalidResponse, - exceptions.ResponseError)), - ) - - -def get_redis_ConnectionError(): - """Return the redis ConnectionError exception class.""" - from redis import exceptions - return exceptions.ConnectionError - - -class MutexHeld(Exception): - """Raised when another party holds the lock.""" - - -@contextmanager -def Mutex(client, name, expire): - """The Redis lock implementation (probably shaky).""" - lock_id = uuid() - i_won = client.setnx(name, lock_id) - try: - if i_won: - client.expire(name, expire) - yield - else: - if not client.ttl(name): - client.expire(name, expire) - raise MutexHeld() - finally: - if i_won: - try: - with client.pipeline(True) as pipe: - pipe.watch(name) - if pipe.get(name) == lock_id: - pipe.multi() - pipe.delete(name) - pipe.execute() - pipe.unwatch() - except redis.WatchError: - pass - - -def _after_fork_cleanup_channel(channel): - channel._after_fork() - - -class QoS(virtual.QoS): - """Redis Ack Emulation.""" - - restore_at_shutdown = True - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._vrestore_count = 0 - - def append(self, message, delivery_tag): - delivery = message.delivery_info - EX, RK = delivery['exchange'], delivery['routing_key'] - with self.pipe_or_acquire() as pipe: - pipe.zadd(self.unacked_index_key, time(), delivery_tag) \ - .hset(self.unacked_key, delivery_tag, - dumps([message._raw, EX, RK])) \ - .execute() - super().append(message, delivery_tag) - - def restore_unacked(self, client=None): - with self.channel.conn_or_acquire(client) as client: - for tag in self._delivered: - self.restore_by_tag(tag, client=client) - self._delivered.clear() - - def ack(self, delivery_tag): - self._remove_from_indices(delivery_tag).execute() - super().ack(delivery_tag) - - def reject(self, delivery_tag, requeue=False): - if requeue: - self.restore_by_tag(delivery_tag, leftmost=True) - self.ack(delivery_tag) - - @contextmanager - def pipe_or_acquire(self, pipe=None, client=None): - if pipe: - yield pipe - else: - with self.channel.conn_or_acquire(client) as client: - yield client.pipeline() - - def _remove_from_indices(self, delivery_tag, pipe=None): - with self.pipe_or_acquire(pipe) as pipe: - return pipe.zrem(self.unacked_index_key, delivery_tag) \ - .hdel(self.unacked_key, delivery_tag) - - def restore_visible(self, start=0, num=10, interval=10): - self._vrestore_count += 1 - if (self._vrestore_count - 1) % interval: - return - with self.channel.conn_or_acquire() as client: - ceil = time() - self.visibility_timeout - try: - with Mutex(client, self.unacked_mutex_key, - self.unacked_mutex_expire): - visible = client.zrevrangebyscore( - self.unacked_index_key, ceil, 0, - start=num and start, num=num, withscores=True) - for tag, score in visible or []: - self.restore_by_tag(tag, client) - except MutexHeld: - pass - - def restore_by_tag(self, tag, client=None, leftmost=False): - with self.channel.conn_or_acquire(client) as client: - with client.pipeline() as pipe: - p, _, _ = self._remove_from_indices( - tag, pipe.hget(self.unacked_key, tag)).execute() - if p: - M, EX, RK = loads(bytes_to_str(p)) # json is unicode - self.channel._do_restore_message(M, EX, RK, client, leftmost) - - @cached_property - def unacked_key(self): - return self.channel.unacked_key - - @cached_property - def unacked_index_key(self): - return self.channel.unacked_index_key - - @cached_property - def unacked_mutex_key(self): - return self.channel.unacked_mutex_key - - @cached_property - def unacked_mutex_expire(self): - return self.channel.unacked_mutex_expire - - @cached_property - def visibility_timeout(self): - return self.channel.visibility_timeout - - -class MultiChannelPoller: - """Async I/O poller for Redis transport.""" - - eventflags = READ | ERR - - #: Set by :meth:`get` while reading from the socket. - _in_protected_read = False - - #: Set of one-shot callbacks to call after reading from socket. - after_read = None - - def __init__(self): - # active channels - self._channels = set() - # file descriptor -> channel map. - self._fd_to_chan = {} - # channel -> socket map - self._chan_to_sock = {} - # poll implementation (epoll/kqueue/select) - self.poller = poll() - # one-shot callbacks called after reading from socket. - self.after_read = set() - - def close(self): - for fd in self._chan_to_sock.values(): - try: - self.poller.unregister(fd) - except (KeyError, ValueError): - pass - self._channels.clear() - self._fd_to_chan.clear() - self._chan_to_sock.clear() - - def add(self, channel): - self._channels.add(channel) - - def discard(self, channel): - self._channels.discard(channel) - - def _on_connection_disconnect(self, connection): - try: - self.poller.unregister(connection._sock) - except (AttributeError, TypeError): - pass - - def _register(self, channel, client, type): - if (channel, client, type) in self._chan_to_sock: - self._unregister(channel, client, type) - if client.connection._sock is None: # not connected yet. - client.connection.connect() - sock = client.connection._sock - self._fd_to_chan[sock.fileno()] = (channel, type) - self._chan_to_sock[(channel, client, type)] = sock - self.poller.register(sock, self.eventflags) - - def _unregister(self, channel, client, type): - self.poller.unregister(self._chan_to_sock[(channel, client, type)]) - - def _client_registered(self, channel, client, cmd): - if getattr(client, 'connection', None) is None: - client.connection = client.connection_pool.get_connection('_') - return (client.connection._sock is not None and - (channel, client, cmd) in self._chan_to_sock) - - def _register_BRPOP(self, channel): - """Enable BRPOP mode for channel.""" - ident = channel, channel.client, 'BRPOP' - if not self._client_registered(channel, channel.client, 'BRPOP'): - channel._in_poll = False - self._register(*ident) - if not channel._in_poll: # send BRPOP - channel._brpop_start() - - def _register_LISTEN(self, channel): - """Enable LISTEN mode for channel.""" - if not self._client_registered(channel, channel.subclient, 'LISTEN'): - channel._in_listen = False - self._register(channel, channel.subclient, 'LISTEN') - if not channel._in_listen: - channel._subscribe() # send SUBSCRIBE - - def on_poll_start(self): - for channel in self._channels: - if channel.active_queues: # BRPOP mode? - if channel.qos.can_consume(): - self._register_BRPOP(channel) - if channel.active_fanout_queues: # LISTEN mode? - self._register_LISTEN(channel) - - def on_poll_init(self, poller): - self.poller = poller - for channel in self._channels: - return channel.qos.restore_visible( - num=channel.unacked_restore_limit, - ) - - def maybe_restore_messages(self): - for channel in self._channels: - if channel.active_queues: - # only need to do this once, as they are not local to channel. - return channel.qos.restore_visible( - num=channel.unacked_restore_limit, - ) - - def on_readable(self, fileno): - chan, type = self._fd_to_chan[fileno] - if chan.qos.can_consume(): - chan.handlers[type]() - - def handle_event(self, fileno, event): - if event & READ: - return self.on_readable(fileno), self - elif event & ERR: - chan, type = self._fd_to_chan[fileno] - chan._poll_error(type) - - def get(self, callback, timeout=None): - self._in_protected_read = True - try: - for channel in self._channels: - if channel.active_queues: # BRPOP mode? - if channel.qos.can_consume(): - self._register_BRPOP(channel) - if channel.active_fanout_queues: # LISTEN mode? - self._register_LISTEN(channel) - - events = self.poller.poll(timeout) - if events: - for fileno, event in events: - ret = self.handle_event(fileno, event) - if ret: - return - # - no new data, so try to restore messages. - # - reset active redis commands. - self.maybe_restore_messages() - raise Empty() - finally: - self._in_protected_read = False - while self.after_read: - try: - fun = self.after_read.pop() - except KeyError: - break - else: - fun() - - @property - def fds(self): - return self._fd_to_chan - - -class Channel(virtual.Channel): - """Redis Channel.""" - - QoS = QoS - - _client = None - _subclient = None - _closing = False - supports_fanout = True - keyprefix_queue = '_kombu.binding.%s' - keyprefix_fanout = '/{db}.' - sep = '\x06\x16' - _in_poll = False - _in_listen = False - _fanout_queues = {} - ack_emulation = True - unacked_key = 'unacked' - unacked_index_key = 'unacked_index' - unacked_mutex_key = 'unacked_mutex' - unacked_mutex_expire = 300 # 5 minutes - unacked_restore_limit = None - visibility_timeout = 3600 # 1 hour - priority_steps = PRIORITY_STEPS - socket_timeout = None - socket_connect_timeout = None - socket_keepalive = None - socket_keepalive_options = None - max_connections = 10 - #: Transport option to disable fanout keyprefix. - #: Can also be string, in which case it changes the default - #: prefix ('/{db}.') into to something else. The prefix must - #: include a leading slash and a trailing dot. - #: - #: Enabled by default since Kombu 4.x. - #: Disable for backwards compatibility with Kombu 3.x. - fanout_prefix = True - - #: If enabled the fanout exchange will support patterns in routing - #: and binding keys (like a topic exchange but using PUB/SUB). - #: - #: Enabled by default since Kombu 4.x. - #: Disable for backwards compatibility with Kombu 3.x. - fanout_patterns = True - - #: Order in which we consume from queues. - #: - #: Can be either string alias, or a cycle strategy class - #: - #: - ``round_robin`` - #: (:class:`~kombu.utils.scheduling.round_robin_cycle`). - #: - #: Make sure each queue has an equal opportunity to be consumed from. - #: - #: - ``sorted`` - #: (:class:`~kombu.utils.scheduling.sorted_cycle`). - #: - #: Consume from queues in alphabetical order. - #: If the first queue in the sorted list always contains messages, - #: then the rest of the queues will never be consumed from. - #: - #: - ``priority`` - #: (:class:`~kombu.utils.scheduling.priority_cycle`). - #: - #: Consume from queues in original order, so that if the first - #: queue always contains messages, the rest of the queues - #: in the list will never be consumed from. - #: - #: The default is to consume from queues in round robin. - queue_order_strategy = 'round_robin' - - _async_pool = None - _pool = None - - from_transport_options = ( - virtual.Channel.from_transport_options + - ('ack_emulation', - 'unacked_key', - 'unacked_index_key', - 'unacked_mutex_key', - 'unacked_mutex_expire', - 'visibility_timeout', - 'unacked_restore_limit', - 'fanout_prefix', - 'fanout_patterns', - 'socket_timeout', - 'socket_connect_timeout', - 'socket_keepalive', - 'socket_keepalive_options', - 'queue_order_strategy', - 'max_connections', - 'priority_steps') # <-- do not add comma here! - ) - - connection_class = redis.Connection if redis else None - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - if not self.ack_emulation: # disable visibility timeout - self.QoS = virtual.QoS - - self._queue_cycle = cycle_by_name(self.queue_order_strategy)() - self.Client = self._get_client() - self.ResponseError = self._get_response_error() - self.active_fanout_queues = set() - self.auto_delete_queues = set() - self._fanout_to_queue = {} - self.handlers = {'BRPOP': self._brpop_read, 'LISTEN': self._receive} - - if self.fanout_prefix: - if isinstance(self.fanout_prefix, str): - self.keyprefix_fanout = self.fanout_prefix - else: - # previous versions did not set a fanout, so cannot enable - # by default. - self.keyprefix_fanout = '' - - # Evaluate connection. - try: - self.client.ping() - except Exception: - self._disconnect_pools() - raise - - self.connection.cycle.add(self) # add to channel poller. - # copy errors, in case channel closed but threads still - # are still waiting for data. - self.connection_errors = self.connection.connection_errors - - if register_after_fork is not None: - register_after_fork(self, _after_fork_cleanup_channel) - - def _after_fork(self): - self._disconnect_pools() - - def _disconnect_pools(self): - pool = self._pool - async_pool = self._async_pool - - self._async_pool = self._pool = None - - if pool is not None: - pool.disconnect() - - if async_pool is not None: - async_pool.disconnect() - - def _on_connection_disconnect(self, connection): - if self._in_poll is connection: - self._in_poll = None - if self._in_listen is connection: - self._in_listen = None - if self.connection and self.connection.cycle: - self.connection.cycle._on_connection_disconnect(connection) - - def _do_restore_message(self, payload, exchange, routing_key, - client=None, leftmost=False): - with self.conn_or_acquire(client) as client: - try: - try: - payload['headers']['redelivered'] = True - except KeyError: - pass - for queue in self._lookup(exchange, routing_key): - (client.lpush if leftmost else client.rpush)( - queue, dumps(payload), - ) - except Exception: - crit('Could not restore message: %r', payload, exc_info=True) - - def _restore(self, message, leftmost=False): - if not self.ack_emulation: - return super()._restore(message) - tag = message.delivery_tag - with self.conn_or_acquire() as client: - with client.pipeline() as pipe: - P, _ = pipe.hget(self.unacked_key, tag) \ - .hdel(self.unacked_key, tag) \ - .execute() - if P: - M, EX, RK = loads(bytes_to_str(P)) # json is unicode - self._do_restore_message(M, EX, RK, client, leftmost) - - def _restore_at_beginning(self, message): - return self._restore(message, leftmost=True) - - def basic_consume(self, queue, *args, **kwargs): - if queue in self._fanout_queues: - exchange, _ = self._fanout_queues[queue] - self.active_fanout_queues.add(queue) - self._fanout_to_queue[exchange] = queue - ret = super().basic_consume(queue, *args, **kwargs) - - # Update fair cycle between queues. - # - # We cycle between queues fairly to make sure that - # each queue is equally likely to be consumed from, - # so that a very busy queue will not block others. - # - # This works by using Redis's `BRPOP` command and - # by rotating the most recently used queue to the - # and of the list. See Kombu github issue #166 for - # more discussion of this method. - self._update_queue_cycle() - return ret - - def basic_cancel(self, consumer_tag): - # If we are busy reading messages we may experience - # a race condition where a message is consumed after - # canceling, so we must delay this operation until reading - # is complete (Issue celery/celery#1773). - connection = self.connection - if connection: - if connection.cycle._in_protected_read: - return connection.cycle.after_read.add( - promise(self._basic_cancel, (consumer_tag,)), - ) - return self._basic_cancel(consumer_tag) - - def _basic_cancel(self, consumer_tag): - try: - queue = self._tag_to_queue[consumer_tag] - except KeyError: - return - try: - self.active_fanout_queues.remove(queue) - except KeyError: - pass - else: - self._unsubscribe_from(queue) - try: - exchange, _ = self._fanout_queues[queue] - self._fanout_to_queue.pop(exchange) - except KeyError: - pass - ret = super().basic_cancel(consumer_tag) - self._update_queue_cycle() - return ret - - def _get_publish_topic(self, exchange, routing_key): - if routing_key and self.fanout_patterns: - return ''.join([self.keyprefix_fanout, exchange, '/', routing_key]) - return ''.join([self.keyprefix_fanout, exchange]) - - def _get_subscribe_topic(self, queue): - exchange, routing_key = self._fanout_queues[queue] - return self._get_publish_topic(exchange, routing_key) - - def _subscribe(self): - keys = [self._get_subscribe_topic(queue) - for queue in self.active_fanout_queues] - if not keys: - return - c = self.subclient - if c.connection._sock is None: - c.connection.connect() - self._in_listen = c.connection - c.psubscribe(keys) - - def _unsubscribe_from(self, queue): - topic = self._get_subscribe_topic(queue) - c = self.subclient - if c.connection and c.connection._sock: - c.unsubscribe([topic]) - - def _handle_message(self, client, r): - if bytes_to_str(r[0]) == 'unsubscribe' and r[2] == 0: - client.subscribed = False - return - - if bytes_to_str(r[0]) == 'pmessage': - type, pattern, channel, data = r[0], r[1], r[2], r[3] - else: - type, pattern, channel, data = r[0], None, r[1], r[2] - return { - 'type': type, - 'pattern': pattern, - 'channel': channel, - 'data': data, - } - - def _receive(self): - c = self.subclient - ret = [] - try: - ret.append(self._receive_one(c)) - except Empty: - pass - if c.connection is not None: - while c.connection.can_read(timeout=0): - ret.append(self._receive_one(c)) - return any(ret) - - def _receive_one(self, c): - response = None - try: - response = c.parse_response() - except self.connection_errors: - self._in_listen = None - raise - if response is not None: - payload = self._handle_message(c, response) - if bytes_to_str(payload['type']).endswith('message'): - channel = bytes_to_str(payload['channel']) - if payload['data']: - if channel[0] == '/': - _, _, channel = channel.partition('.') - try: - message = loads(bytes_to_str(payload['data'])) - except (TypeError, ValueError): - warn('Cannot process event on channel %r: %s', - channel, repr(payload)[:4096], exc_info=1) - raise Empty() - exchange = channel.split('/', 1)[0] - self.connection._deliver( - message, self._fanout_to_queue[exchange]) - return True - - def _brpop_start(self, timeout=1): - queues = self._queue_cycle.consume(len(self.active_queues)) - if not queues: - return - keys = [self._q_for_pri(queue, pri) for pri in self.priority_steps - for queue in queues] + [timeout or 0] - self._in_poll = self.client.connection - self.client.connection.send_command('BRPOP', *keys) - - def _brpop_read(self, **options): - try: - try: - dest__item = self.client.parse_response(self.client.connection, - 'BRPOP', - **options) - except self.connection_errors: - # if there's a ConnectionError, disconnect so the next - # iteration will reconnect automatically. - self.client.connection.disconnect() - raise - if dest__item: - dest, item = dest__item - dest = bytes_to_str(dest).rsplit(self.sep, 1)[0] - self._queue_cycle.rotate(dest) - self.connection._deliver(loads(bytes_to_str(item)), dest) - return True - else: - raise Empty() - finally: - self._in_poll = None - - def _poll_error(self, type, **options): - if type == 'LISTEN': - self.subclient.parse_response() - else: - self.client.parse_response(self.client.connection, type) - - def _get(self, queue): - with self.conn_or_acquire() as client: - for pri in self.priority_steps: - item = client.rpop(self._q_for_pri(queue, pri)) - if item: - return loads(bytes_to_str(item)) - raise Empty() - - def _size(self, queue): - with self.conn_or_acquire() as client: - with client.pipeline() as pipe: - for pri in self.priority_steps: - pipe = pipe.llen(self._q_for_pri(queue, pri)) - sizes = pipe.execute() - return sum(size for size in sizes - if isinstance(size, numbers.Integral)) - - def _q_for_pri(self, queue, pri): - pri = self.priority(pri) - return '%s%s%s' % ((queue, self.sep, pri) if pri else (queue, '', '')) - - def priority(self, n): - steps = self.priority_steps - return steps[bisect(steps, n) - 1] - - def _put(self, queue, message, **kwargs): - """Deliver message.""" - pri = self._get_message_priority(message, reverse=False) - - with self.conn_or_acquire() as client: - client.lpush(self._q_for_pri(queue, pri), dumps(message)) - - def _put_fanout(self, exchange, message, routing_key, **kwargs): - """Deliver fanout message.""" - with self.conn_or_acquire() as client: - client.publish( - self._get_publish_topic(exchange, routing_key), - dumps(message), - ) - - def _new_queue(self, queue, auto_delete=False, **kwargs): - if auto_delete: - self.auto_delete_queues.add(queue) - - def _queue_bind(self, exchange, routing_key, pattern, queue): - if self.typeof(exchange).type == 'fanout': - # Mark exchange as fanout. - self._fanout_queues[queue] = ( - exchange, routing_key.replace('#', '*'), - ) - with self.conn_or_acquire() as client: - client.sadd(self.keyprefix_queue % (exchange,), - self.sep.join([routing_key or '', - pattern or '', - queue or ''])) - - def _delete(self, queue, exchange, routing_key, pattern, - *args, client=None, **kwargs): - self.auto_delete_queues.discard(queue) - with self.conn_or_acquire(client=client) as client: - client.srem(self.keyprefix_queue % (exchange,), - self.sep.join([routing_key or '', - pattern or '', - queue or ''])) - with client.pipeline() as pipe: - for pri in self.priority_steps: - pipe = pipe.delete(self._q_for_pri(queue, pri)) - pipe.execute() - - def _has_queue(self, queue, **kwargs): - with self.conn_or_acquire() as client: - with client.pipeline() as pipe: - for pri in self.priority_steps: - pipe = pipe.exists(self._q_for_pri(queue, pri)) - return any(pipe.execute()) - - def get_table(self, exchange): - key = self.keyprefix_queue % exchange - with self.conn_or_acquire() as client: - values = client.smembers(key) - if not values: - raise InconsistencyError(NO_ROUTE_ERROR.format(exchange, key)) - return [tuple(bytes_to_str(val).split(self.sep)) for val in values] - - def _purge(self, queue): - with self.conn_or_acquire() as client: - with client.pipeline() as pipe: - for pri in self.priority_steps: - priq = self._q_for_pri(queue, pri) - pipe = pipe.llen(priq).delete(priq) - sizes = pipe.execute() - return sum(sizes[::2]) - - def close(self): - self._closing = True - if not self.closed: - # remove from channel poller. - self.connection.cycle.discard(self) - - # delete fanout bindings - client = self.__dict__.get('client') # only if property cached - if client is not None: - for queue in self._fanout_queues: - if queue in self.auto_delete_queues: - self.queue_delete(queue, client=client) - self._disconnect_pools() - self._close_clients() - super().close() - - def _close_clients(self): - # Close connections - for attr in 'client', 'subclient': - try: - client = self.__dict__[attr] - connection, client.connection = client.connection, None - connection.disconnect() - except (KeyError, AttributeError, self.ResponseError): - pass - - def _prepare_virtual_host(self, vhost): - if not isinstance(vhost, numbers.Integral): - if not vhost or vhost == '/': - vhost = DEFAULT_DB - elif vhost.startswith('/'): - vhost = vhost[1:] - try: - vhost = int(vhost) - except ValueError: - raise ValueError( - 'Database is int between 0 and limit - 1, not {0}'.format( - vhost, - )) - return vhost - - def _filter_tcp_connparams(self, socket_keepalive=None, - socket_keepalive_options=None, **params): - return params - - def _connparams(self, async=False): - conninfo = self.connection.client - connparams = { - 'host': conninfo.hostname or '127.0.0.1', - 'port': conninfo.port or self.connection.default_port, - 'virtual_host': conninfo.virtual_host, - 'password': conninfo.password, - 'max_connections': self.max_connections, - 'socket_timeout': self.socket_timeout, - 'socket_connect_timeout': self.socket_connect_timeout, - 'socket_keepalive': self.socket_keepalive, - 'socket_keepalive_options': self.socket_keepalive_options, - } - if conninfo.ssl: - # Connection(ssl={}) must be a dict containing the keys: - # 'ssl_cert_reqs', 'ssl_ca_certs', 'ssl_certfile', 'ssl_keyfile' - try: - connparams.update(conninfo.ssl) - connparams['connection_class'] = redis.SSLConnection - except TypeError: - pass - host = connparams['host'] - if '://' in host: - scheme, _, _, _, _, path, query = _parse_url(host) - if scheme == 'socket': - connparams = self._filter_tcp_connparams(**connparams) - connparams.update({ - 'connection_class': redis.UnixDomainSocketConnection, - 'path': '/' + path}, **query) - - connparams.pop('socket_connect_timeout', None) - connparams.pop('socket_keepalive', None) - connparams.pop('socket_keepalive_options', None) - - connparams.pop('host', None) - connparams.pop('port', None) - connparams['db'] = self._prepare_virtual_host( - connparams.pop('virtual_host', None)) - - channel = self - connection_cls = ( - connparams.get('connection_class') or - self.connection_class - ) - - if async: - class Connection(connection_cls): - def disconnect(self): - super().disconnect() - channel._on_connection_disconnect(self) - connection_cls = Connection - - connparams['connection_class'] = connection_cls - - return connparams - - def _create_client(self, async=False): - if async: - return self.Client(connection_pool=self.async_pool) - return self.Client(connection_pool=self.pool) - - def _get_pool(self, async=False): - params = self._connparams(async=async) - self.keyprefix_fanout = self.keyprefix_fanout.format(db=params['db']) - return redis.ConnectionPool(**params) - - def _get_client(self): - if redis.VERSION < (2, 10, 0): - raise VersionMismatch( - 'Redis transport requires redis-py versions 2.10.0 or later. ' - 'You have {0.__version__}'.format(redis)) - return redis.StrictRedis - - @contextmanager - def conn_or_acquire(self, client=None): - if client: - yield client - else: - yield self._create_client() - - @property - def pool(self): - if self._pool is None: - self._pool = self._get_pool() - return self._pool - - @property - def async_pool(self): - if self._async_pool is None: - self._async_pool = self._get_pool(async=True) - return self._async_pool - - @cached_property - def client(self): - """Client used to publish messages, BRPOP etc.""" - return self._create_client(async=True) - - @cached_property - def subclient(self): - """Pub/Sub connection used to consume fanout queues.""" - client = self._create_client(async=True) - return client.pubsub() - - def _update_queue_cycle(self): - self._queue_cycle.update(self.active_queues) - - def _get_response_error(self): - from redis import exceptions - return exceptions.ResponseError - - @property - def active_queues(self): - """Set of queues being consumed from (excluding fanout queues).""" - return {queue for queue in self._active_queues - if queue not in self.active_fanout_queues} - - -class Transport(virtual.Transport): - """Redis Transport.""" - - Channel = Channel - - polling_interval = None # disable sleep between unsuccessful polls. - default_port = DEFAULT_PORT - driver_type = 'redis' - driver_name = 'redis' - - implements = virtual.Transport.implements.extend( - async=True, - exchange_type=frozenset(['direct', 'topic', 'fanout']) - ) - - def __init__(self, *args, **kwargs): - if redis is None: - raise ImportError('Missing redis library (pip install redis)') - super().__init__(*args, **kwargs) - - # Get redis-py exceptions. - self.connection_errors, self.channel_errors = self._get_errors() - # All channels share the same poller. - self.cycle = MultiChannelPoller() - - def driver_version(self): - return redis.__version__ - - def register_with_event_loop(self, connection, loop): - cycle = self.cycle - cycle.on_poll_init(loop.poller) - cycle_poll_start = cycle.on_poll_start - add_reader = loop.add_reader - on_readable = self.on_readable - - def _on_disconnect(connection): - if connection._sock: - loop.remove(connection._sock) - cycle._on_connection_disconnect = _on_disconnect - - def on_poll_start(): - cycle_poll_start() - [add_reader(fd, on_readable, fd) for fd in cycle.fds] - loop.on_tick.add(on_poll_start) - loop.call_repeatedly(10, cycle.maybe_restore_messages) - - def on_readable(self, fileno): - """Handle AIO event for one of our file descriptors.""" - self.cycle.on_readable(fileno) - - def _get_errors(self): - """Utility to import redis-py's exceptions at runtime.""" - return get_redis_error_classes() - - -class SentinelChannel(Channel): - """Channel with explicit Redis Sentinel knowledge. - - Broker url is supposed to look like: - - sentinel://0.0.0.0:26379;sentinel://0.0.0.0:26380/... - - where each sentinel is separated by a `;`. Multiple sentinels are handled - by :class:`kombu.Connection` constructor, and placed in the alternative - list of servers to connect to in case of connection failure. - - Other arguments for the sentinel should come from the transport options - (see :method:`Celery.connection` which is in charge of creating the - `Connection` object). - - You must provide at least one option in Transport options: - * `service_name` - name of the redis group to poll - """ - - from_transport_options = Channel.from_transport_options + ( - 'master_name', - 'min_other_sentinels', - 'sentinel_kwargs') - - connection_class = sentinel.SentinelManagedConnection if sentinel else None - - def _sentinel_managed_pool(self, async=False): - connparams = self._connparams(async) - - additional_params = connparams.copy() - - additional_params.pop('host', None) - additional_params.pop('port', None) - - sentinel_inst = sentinel.Sentinel( - [(connparams['host'], connparams['port'])], - min_other_sentinels=getattr(self, 'min_other_sentinels', 0), - sentinel_kwargs=getattr(self, 'sentinel_kwargs', {}), - **additional_params) - - master_name = getattr(self, 'master_name', None) - - return sentinel_inst.master_for( - master_name, - self.Client, - ).connection_pool - - def _get_pool(self, async=False): - return self._sentinel_managed_pool(async) - - -class SentinelTransport(Transport): - """Redis Sentinel Transport.""" - - default_port = 26379 - Channel = SentinelChannel diff --git a/kombu/transport/virtual/__init__.py b/kombu/transport/virtual/__init__.py index 6960104c..b3bf63f3 100644 --- a/kombu/transport/virtual/__init__.py +++ b/kombu/transport/virtual/__init__.py @@ -1,5 +1,3 @@ -from __future__ import absolute_import, unicode_literals - from .base import ( Base64, NotEquivalentError, UndeliverableWarning, BrokerState, QoS, Message, AbstractChannel, Channel, Management, Transport, diff --git a/kombu/transport/virtual/base.py b/kombu/transport/virtual/base.py index 75bca055..052a620d 100644 --- a/kombu/transport/virtual/base.py +++ b/kombu/transport/virtual/base.py @@ -2,24 +2,30 @@ Emulates the AMQ API for non-AMQ transports. """ -from __future__ import absolute_import, print_function, unicode_literals - +import abc import base64 import socket import sys import warnings from array import array -from collections import OrderedDict, defaultdict, namedtuple +from asyncio import sleep +from collections import OrderedDict, defaultdict from itertools import count from multiprocessing.util import Finalize -from time import sleep +from queue import Empty +from time import monotonic +from typing import ( + Any, AnyStr, Callable, IO, Iterable, List, Mapping, MutableMapping, + NamedTuple, Optional, Set, Sequence, Tuple, +) from amqp.protocol import queue_declare_ok_t +from amqp.types import ChannelT, ConnectionT from kombu.exceptions import ResourceError, ChannelError -from kombu.five import Empty, items, monotonic from kombu.log import get_logger +from kombu.types import ClientT, MessageT, TransportT from kombu.utils.encoding import str_to_bytes, bytes_to_str from kombu.utils.div import emergency_dump_state from kombu.utils.scheduling import FairCycle @@ -27,7 +33,7 @@ from kombu.utils.uuid import uuid from kombu.transport import base -from .exchange import STANDARD_EXCHANGE_TYPES +from .exchange import STANDARD_EXCHANGE_TYPES, ExchangeType ARRAY_TYPE_H = 'H' if sys.version_info[0] == 3 else b'H' @@ -50,24 +56,41 @@ RESTORE_PANIC_FMT = 'UNABLE TO RESTORE {0} MESSAGES: {1}' logger = get_logger(__name__) -#: Key format used for queue argument lookups in BrokerState.bindings. -binding_key_t = namedtuple('binding_key_t', ( - 'queue', 'exchange', 'routing_key', -)) -#: BrokerState.queue_bindings generates tuples in this format. -queue_binding_t = namedtuple('queue_binding_t', ( - 'exchange', 'routing_key', 'arguments', -)) +class binding_key_t(NamedTuple): + """Key format used for queue argument lookups in BrokerState.bindings.""" + + queue: str + exchange: str + routing_key: str + + +class queue_binding_t(NamedTuple): + """BrokerState.queue_bindings generates tuples in this format.""" + + exchange: str + routing_key: str + arguments: Mapping -class Base64(object): +class Codec(metaclass=abc.ABCMeta): + + @abc.abstractmethod + def encode(self, s: AnyStr) -> str: + ... + + @abc.abstractmethod + def decode(self, s: AnyStr) -> str: + ... + + +class Base64(Codec): """Base64 codec.""" - def encode(self, s): + def encode(self, s: AnyStr) -> str: return bytes_to_str(base64.b64encode(str_to_bytes(s))) - def decode(self, s): + def decode(self, s: AnyStr) -> str: return base64.b64decode(str_to_bytes(s)) @@ -84,7 +107,7 @@ class BrokerState(object): #: Mapping of exchange name to #: :class:`kombu.transport.virtual.exchange.ExchangeType` - exchanges = None + exchanges: Mapping[str, ExchangeType] = None #: This is the actual bindings registry, used to store bindings and to #: test 'in' relationships in constant time. It has the following @@ -94,7 +117,7 @@ class BrokerState(object): #: (queue, exchange, routing_key): arguments, #: # ..., #: } - bindings = None + bindings: Mapping[binding_key_t, Mapping] = None #: The queue index is used to access directly (constant time) #: all the bindings of a certain queue. It has the following structure:: @@ -105,27 +128,32 @@ class BrokerState(object): #: }, #: # ..., #: } - queue_index = None + queue_index: Mapping[str, binding_key_t] = None - def __init__(self, exchanges=None): + def __init__(self, exchanges: Mapping[str, ExchangeType] = None) -> None: self.exchanges = {} if exchanges is None else exchanges self.bindings = {} self.queue_index = defaultdict(set) - def clear(self): + def clear(self) -> None: self.exchanges.clear() self.bindings.clear() self.queue_index.clear() - def has_binding(self, queue, exchange, routing_key): + def has_binding(self, queue: str, exchange: str, routing_key: str) -> bool: return (queue, exchange, routing_key) in self.bindings - def binding_declare(self, queue, exchange, routing_key, arguments): + def binding_declare(self, + queue: str, + exchange: str, + routing_key: str, + arguments: Mapping) -> None: key = binding_key_t(queue, exchange, routing_key) self.bindings.setdefault(key, arguments) self.queue_index[queue].add(key) - def binding_delete(self, queue, exchange, routing_key): + def binding_delete( + self, queue: str, exchange: str, routing_key: str) -> None: key = binding_key_t(queue, exchange, routing_key) try: del self.bindings[key] @@ -134,7 +162,7 @@ class BrokerState(object): else: self.queue_index[queue].remove(key) - def queue_bindings_delete(self, queue): + def queue_bindings_delete(self, queue: str) -> None: try: bindings = self.queue_index.pop(queue) except KeyError: @@ -142,7 +170,7 @@ class BrokerState(object): else: [self.bindings.pop(binding, None) for binding in bindings] - def queue_bindings(self, queue): + def queue_bindings(self, queue: str) -> Iterable[binding_key_t]: return ( queue_binding_t(key.exchange, key.routing_key, self.bindings[key]) for key in self.queue_index[queue] @@ -160,22 +188,22 @@ class QoS(object): """ #: current prefetch count value - prefetch_count = 0 + prefetch_count: int = 0 #: :class:`~collections.OrderedDict` of active messages. #: *NOTE*: Can only be modified by the consuming thread. - _delivered = None + _delivered: OrderedDict = None #: acks can be done by other threads than the consuming thread. #: Instead of a mutex, which doesn't perform well here, we mark #: the delivery tags as dirty, so subsequent calls to append() can remove #: them. - _dirty = None + _dirty: Set[MessageT] = None #: If disabled, unacked messages won't be restored at shutdown. - restore_at_shutdown = True + restore_at_shutdown: bool = True - def __init__(self, channel, prefetch_count=0): + def __init__(self, channel: ChannelT, prefetch_count: int = 0) -> None: self.channel = channel self.prefetch_count = prefetch_count or 0 @@ -188,7 +216,7 @@ class QoS(object): self, self.restore_unacked_once, exitpriority=1, ) - def can_consume(self): + def can_consume(self) -> bool: """Return true if the channel can be consumed from. Used to ensure the client adhers to currently active @@ -197,7 +225,7 @@ class QoS(object): pcount = self.prefetch_count return not pcount or len(self._delivered) - len(self._dirty) < pcount - def can_consume_max_estimate(self): + def can_consume_max_estimate(self) -> int: """Return the maximum number of messages allowed to be returned. Returns an estimated number of messages that a consumer may be allowed @@ -212,16 +240,16 @@ class QoS(object): if pcount: return max(pcount - (len(self._delivered) - len(self._dirty)), 0) - def append(self, message, delivery_tag): + def append(self, message: MessageT, delivery_tag: str) -> None: """Append message to transactional state.""" if self._dirty: self._flush() self._quick_append(delivery_tag, message) - def get(self, delivery_tag): + def get(self, delivery_tag: str) -> MessageT: return self._delivered[delivery_tag] - def _flush(self): + def _flush(self) -> None: """Flush dirty (acked/rejected) tags from.""" dirty = self._dirty delivered = self._delivered @@ -232,17 +260,17 @@ class QoS(object): break delivered.pop(dirty_tag, None) - def ack(self, delivery_tag): + def ack(self, delivery_tag: str) -> None: """Acknowledge message and remove from transactional state.""" self._quick_ack(delivery_tag) - def reject(self, delivery_tag, requeue=False): + def reject(self, delivery_tag: str, requeue: bool = False) -> None: """Remove from transactional state and requeue message.""" if requeue: self.channel._restore_at_beginning(self._delivered[delivery_tag]) self._quick_ack(delivery_tag) - def restore_unacked(self): + async def restore_unacked(self) -> None: """Restore all unacknowledged messages.""" self._flush() delivered = self._delivered @@ -257,13 +285,13 @@ class QoS(object): break try: - restore(message) + await restore(message) except BaseException as exc: errors.append((exc, message)) delivered.clear() return errors - def restore_unacked_once(self, stderr=None): + async def restore_unacked_once(self, stderr: IO = None) -> None: """Restore all unacknowledged messages at shutdown/gc collect. Note: @@ -284,7 +312,7 @@ class QoS(object): if state: print(RESTORING_FMT.format(len(self._delivered)), file=stderr) - unrestored = self.restore_unacked() + unrestored = await self.restore_unacked() if unrestored: errors, messages = list(zip(*unrestored)) @@ -294,7 +322,7 @@ class QoS(object): finally: state.restored = True - def restore_visible(self, *args, **kwargs): + async def restore_visible(self, *args, **kwargs) -> None: """Restore any pending unackwnowledged messages. To be filled in for visibility_timeout style implementations. @@ -303,13 +331,14 @@ class QoS(object): This is implementation optional, and currently only used by the Redis transport. """ - pass + ... class Message(base.Message): """Message object.""" - def __init__(self, payload, channel=None, **kwargs): + def __init__(self, payload: Any, + channel: ChannelT = None, **kwargs) -> None: self._raw = payload properties = payload['properties'] body = payload.get('body') @@ -327,7 +356,7 @@ class Message(base.Message): postencode='utf-8', **kwargs) - def serializable(self): + def serializable(self) -> Mapping: props = self.properties body, _ = self.channel.encode_body(self.body, props.get('body_encoding')) @@ -354,23 +383,23 @@ class AbstractChannel(object): from :class:`Channel`. """ - def _get(self, queue, timeout=None): + async def _get(self, queue: str, timeout: float = None) -> MessageT: """Get next message from `queue`.""" raise NotImplementedError('Virtual channels must implement _get') - def _put(self, queue, message): + async def _put(self, queue: str, message: MessageT) -> None: """Put `message` onto `queue`.""" raise NotImplementedError('Virtual channels must implement _put') - def _purge(self, queue): + async def _purge(self, queue: str) -> int: """Remove all messages from `queue`.""" raise NotImplementedError('Virtual channels must implement _purge') - def _size(self, queue): + async def _size(self, queue: str) -> int: """Return the number of messages in `queue` as an :class:`int`.""" return 0 - def _delete(self, queue, *args, **kwargs): + async def _delete(self, queue: str, *args, **kwargs) -> None: """Delete `queue`. Note: @@ -379,16 +408,16 @@ class AbstractChannel(object): """ self._purge(queue) - def _new_queue(self, queue, **kwargs): + async def _new_queue(self, queue: str, **kwargs) -> None: """Create new queue. Note: Your transport can override this method if it needs to do something whenever a new queue is declared. """ - pass + ... - def _has_queue(self, queue, **kwargs): + async def _has_queue(self, queue: str, **kwargs) -> bool: """Verify that queue exists. Returns: @@ -397,13 +426,14 @@ class AbstractChannel(object): """ return True - def _poll(self, cycle, callback, timeout=None): + async def _poll(self, cycle: Any, callback: Callable, + timeout: float = None) -> Any: """Poll a list of queues for available messages.""" - return cycle.get(callback) + return await cycle.get(callback) - def _get_and_deliver(self, queue, callback): - message = self._get(queue) - callback(message, queue) + async def _get_and_deliver(self, queue: str, callback: Callable) -> None: + message = await self._get(queue) + await callback(message, queue) class Channel(AbstractChannel, base.StdChannel): @@ -415,44 +445,56 @@ class Channel(AbstractChannel, base.StdChannel): """ #: message class used. - Message = Message + Message: type = Message #: QoS class used. - QoS = QoS + QoS: type = QoS #: flag to restore unacked messages when channel #: goes out of scope. - do_restore = True + do_restore: bool = True #: mapping of exchange types and corresponding classes. - exchange_types = dict(STANDARD_EXCHANGE_TYPES) + exchange_type_classes: Mapping[str, type] = dict( + STANDARD_EXCHANGE_TYPES) + + exchange_types: Mapping[str, ExchangeType] = None #: flag set if the channel supports fanout exchanges. - supports_fanout = False + supports_fanout: bool = False #: Binary <-> ASCII codecs. - codecs = {'base64': Base64()} + codecs: Mapping[str, Codec] = {'base64': Base64()} #: Default body encoding. #: NOTE: ``transport_options['body_encoding']`` will override this value. - body_encoding = 'base64' + body_encoding: str = 'base64' #: counter used to generate delivery tags for this channel. _delivery_tags = count(1) #: Optional queue where messages with no route is delivered. #: Set by ``transport_options['deadletter_queue']``. - deadletter_queue = None + deadletter_queue: str = None # List of options to transfer from :attr:`transport_options`. - from_transport_options = ('body_encoding', 'deadletter_queue') + from_transport_options: Tuple[str, ...] = ( + 'body_encoding', + 'deadletter_queue', + ) # Priority defaults default_priority = 0 min_priority = 0 max_priority = 9 - def __init__(self, connection, **kwargs): + _consumers: Set[str] + _tag_to_queue: Mapping[str, str] + _active_queues: List[str] + closed: bool = False + _qos: QoS = None + + def __init__(self, connection: ConnectionT, **kwargs) -> None: self.connection = connection self._consumers = set() self._cycle = None @@ -462,9 +504,9 @@ class Channel(AbstractChannel, base.StdChannel): self.closed = False # instantiate exchange types - self.exchange_types = dict( - (typ, cls(self)) for typ, cls in items(self.exchange_types) - ) + self.exchange_types = { + typ: cls(self) for typ, cls in self.exchange_type_classes.items() + } try: self.channel_id = self.connection._avail_channel_ids.pop() @@ -482,9 +524,15 @@ class Channel(AbstractChannel, base.StdChannel): except KeyError: pass - def exchange_declare(self, exchange=None, type='direct', durable=False, - auto_delete=False, arguments=None, - nowait=False, passive=False): + async def exchange_declare( + self, + exchange: str = None, + type: str = 'direct', + durable: bool = False, + auto_delete: bool = False, + arguments: Mapping = None, + nowait: bool = False, + passive: bool = False) -> None: """Declare exchange.""" type = type or 'direct' exchange = exchange or 'amq.%s' % type @@ -512,49 +560,74 @@ class Channel(AbstractChannel, base.StdChannel): 'table': [], } - def exchange_delete(self, exchange, if_unused=False, nowait=False): + async def exchange_delete( + self, + exchange: str, + if_unused: bool = False, + nowait: bool = False) -> None: """Delete `exchange` and all its bindings.""" for rkey, _, queue in self.get_table(exchange): - self.queue_delete(queue, if_unused=True, if_empty=True) + await self.queue_delete(queue, if_unused=True, if_empty=True) self.state.exchanges.pop(exchange, None) - def queue_declare(self, queue=None, passive=False, **kwargs): + async def queue_declare( + self, + queue: str = None, + passive: bool = False, + **kwargs) -> queue_declare_ok_t: """Declare queue.""" queue = queue or 'amq.gen-%s' % uuid() - if passive and not self._has_queue(queue, **kwargs): + if passive and not await self._has_queue(queue, **kwargs): raise ChannelError( 'NOT_FOUND - no queue {0!r} in vhost {1!r}'.format( queue, self.connection.client.virtual_host or '/'), (50, 10), 'Channel.queue_declare', '404', ) else: - self._new_queue(queue, **kwargs) - return queue_declare_ok_t(queue, self._size(queue), 0) - - def queue_delete(self, queue, if_unused=False, if_empty=False, **kwargs): + await self._new_queue(queue, **kwargs) + return queue_declare_ok_t(queue, await self._size(queue), 0) + + async def queue_delete( + self, + queue: str, + if_unused: bool = False, + if_empty: bool = False, + **kwargs) -> None: """Delete queue.""" - if if_empty and self._size(queue): + if if_empty and await self._size(queue): return for exchange, routing_key, args in self.state.queue_bindings(queue): meta = self.typeof(exchange).prepare_bind( queue, exchange, routing_key, args, ) - self._delete(queue, exchange, *meta, **kwargs) + await self._delete(queue, exchange, *meta, **kwargs) self.state.queue_bindings_delete(queue) - def after_reply_message_received(self, queue): - self.queue_delete(queue) + async def after_reply_message_received(self, queue: str) -> None: + await self.queue_delete(queue) - def exchange_bind(self, destination, source='', routing_key='', - nowait=False, arguments=None): + async def exchange_bind( + self, destination: str, + source: str = '', + routing_key: str = '', + nowait: bool = False, + arguments: Mapping = None) -> None: raise NotImplementedError('transport does not support exchange_bind') - def exchange_unbind(self, destination, source='', routing_key='', - nowait=False, arguments=None): + async def exchange_unbind( + self, destination: str, + source: str = '', + routing_key: str = '', + nowait: bool = False, + arguments: Mapping = None) -> None: raise NotImplementedError('transport does not support exchange_unbind') - def queue_bind(self, queue, exchange=None, routing_key='', - arguments=None, **kwargs): + async def queue_bind( + self, queue: str, + exchange: str = None, + routing_key: str = '', + arguments: Mapping = None, + **kwargs) -> None: """Bind `queue` to `exchange` with `routing key`.""" exchange = exchange or 'amq.direct' if self.state.has_binding(queue, exchange, routing_key): @@ -568,10 +641,14 @@ class Channel(AbstractChannel, base.StdChannel): ) table.append(meta) if self.supports_fanout: - self._queue_bind(exchange, *meta) - - def queue_unbind(self, queue, exchange=None, routing_key='', - arguments=None, **kwargs): + await self._queue_bind(exchange, *meta) + + async def queue_unbind( + self, queue: str, + exchange: str = None, + routing_key: str = '', + arguments: Mapping = None, + **kwargs) -> None: # Remove queue binding: self.state.binding_delete(queue, exchange, routing_key) try: @@ -585,19 +662,21 @@ class Channel(AbstractChannel, base.StdChannel): # Should be optimized. Modifying table in place. table[:] = [meta for meta in table if meta != binding_meta] - def list_bindings(self): - return ((queue, exchange, rkey) + async def list_bindings(self) -> Iterable[queue_binding_t]: + return (queue_binding_t(queue, exchange, rkey) for exchange in self.state.exchanges for rkey, pattern, queue in self.get_table(exchange)) - def queue_purge(self, queue, **kwargs): + async def queue_purge(self, queue: str, **kwargs) -> int: """Remove all ready messages from queue.""" - return self._purge(queue) + return await self._purge(queue) - def _next_delivery_tag(self): + def _next_delivery_tag(self) -> str: return uuid() - def basic_publish(self, message, exchange, routing_key, **kwargs): + async def basic_publish( + self, message: MessageT, exchange: str, routing_key: str, + **kwargs) -> None: """Publish message.""" self._inplace_augment_message(message, exchange, routing_key) if exchange: @@ -605,9 +684,10 @@ class Channel(AbstractChannel, base.StdChannel): message, exchange, routing_key, **kwargs ) # anon exchange: routing_key is the destination queue - return self._put(routing_key, message, **kwargs) + return await self._put(routing_key, message, **kwargs) - def _inplace_augment_message(self, message, exchange, routing_key): + def _inplace_augment_message( + self, message: MessageT, exchange: str, routing_key: str) -> None: message['body'], body_encoding = self.encode_body( message['body'], self.body_encoding, ) @@ -621,23 +701,28 @@ class Channel(AbstractChannel, base.StdChannel): routing_key=routing_key, ) - def basic_consume(self, queue, no_ack, callback, consumer_tag, **kwargs): + async def basic_consume( + self, queue: str, + no_ack: bool = False, + callback: Callable = None, + consumer_tag: str = None, + **kwargs) -> None: """Consume from `queue`.""" self._tag_to_queue[consumer_tag] = queue self._active_queues.append(queue) - def _callback(raw_message): + async def _callback(raw_message): message = self.Message(raw_message, channel=self) if not no_ack: self.qos.append(message, message.delivery_tag) - return callback(message) + return await callback(message) self.connection._callbacks[queue] = _callback self._consumers.add(consumer_tag) self._reset_cycle() - def basic_cancel(self, consumer_tag): + await def basic_cancel(self, consumer_tag: str) -> None: """Cancel consumer by consumer tag.""" if consumer_tag in self._consumers: self._consumers.remove(consumer_tag) @@ -649,32 +734,38 @@ class Channel(AbstractChannel, base.StdChannel): pass self.connection._callbacks.pop(queue, None) - def basic_get(self, queue, no_ack=False, **kwargs): + async def basic_get( + self, queue: str, + no_ack: bool = False, + **kwargs) -> Optional[MessageT]: """Get message by direct access (synchronous).""" try: - message = self.Message(self._get(queue), channel=self) + message = self.Message(await self._get(queue), channel=self) if not no_ack: self.qos.append(message, message.delivery_tag) return message except Empty: pass - def basic_ack(self, delivery_tag, multiple=False): + async def basic_ack(self, delivery_tag: str, multiple: bool = False) -> None: """Acknowledge message.""" self.qos.ack(delivery_tag) - def basic_recover(self, requeue=False): + async def basic_recover(self, requeue: bool = False) -> None: """Recover unacked messages.""" if requeue: - return self.qos.restore_unacked() + return await self.qos.restore_unacked() raise NotImplementedError('Does not support recover(requeue=False)') - def basic_reject(self, delivery_tag, requeue=False): + async def basic_reject(self, delivery_tag: str, requeue: bool = False) -> None: """Reject message.""" self.qos.reject(delivery_tag, requeue=requeue) - def basic_qos(self, prefetch_size=0, prefetch_count=0, - apply_global=False): + async def basic_qos( + self, + prefetch_size: int = 0, + prefetch_count: int = 0, + apply_global: bool = False) -> None: """Change QoS settings for this channel. Note: @@ -682,14 +773,14 @@ class Channel(AbstractChannel, base.StdChannel): """ self.qos.prefetch_count = prefetch_count - def get_exchanges(self): + def get_exchanges(self) -> Sequence[str]: return list(self.state.exchanges) - def get_table(self, exchange): + def get_table(self, exchange: str) -> Mapping: """Get table of bindings for `exchange`.""" return self.state.exchanges[exchange]['table'] - def typeof(self, exchange, default='direct'): + def typeof(self, exchange: str, default: str = 'direct') -> ExchangeType: """Get the exchange type instance for `exchange`.""" try: type = self.state.exchanges[exchange]['type'] @@ -697,7 +788,9 @@ class Channel(AbstractChannel, base.StdChannel): type = default return self.exchange_types[type] - def _lookup(self, exchange, routing_key, default=None): + def _lookup( + self, exchange: str, routing_key: str, + default: str = None) -> Sequence[str]: """Find all queues matching `routing_key` for the given `exchange`. Returns: @@ -725,34 +818,42 @@ class Channel(AbstractChannel, base.StdChannel): R = [default] return R - def _restore(self, message): + async def _restore(self, message: MessageT) -> None: """Redeliver message to its original destination.""" delivery_info = message.delivery_info message = message.serializable() message['redelivered'] = True for queue in self._lookup( delivery_info['exchange'], delivery_info['routing_key']): - self._put(queue, message) + await self._put(queue, message) - def _restore_at_beginning(self, message): - return self._restore(message) + async def _restore_at_beginning(self, message: MessageT) -> None: + await self._restore(message) - def drain_events(self, timeout=None, callback=None): + async def drain_events( + self, + timeout: float = None, + callback: Callable = None) -> None: callback = callback or self.connection._deliver if self._consumers and self.qos.can_consume(): if hasattr(self, '_get_many'): - return self._get_many(self._active_queues, timeout=timeout) - return self._poll(self.cycle, callback, timeout=timeout) + await self._get_many(self._active_queues, timeout=timeout) + else: + await self._poll(self.cycle, callback, timeout=timeout) raise Empty() - def message_to_python(self, raw_message): + def message_to_python(self, raw_message: Any) -> MessageT: """Convert raw message to :class:`Message` instance.""" if not isinstance(raw_message, self.Message): return self.Message(payload=raw_message, channel=self) return raw_message - def prepare_message(self, body, priority=None, content_type=None, - content_encoding=None, headers=None, properties=None): + def prepare_message(self, body: Any, + priority: int = None, + content_type: str = None, + content_encoding: str = None, + headers: Mapping = None, + properties: Mapping = None) -> Mapping: """Prepare message data.""" properties = properties or {} properties.setdefault('delivery_info', {}) @@ -764,7 +865,7 @@ class Channel(AbstractChannel, base.StdChannel): 'headers': headers or {}, 'properties': properties or {}} - def flow(self, active=True): + async def flow(self, active: bool = True) -> None: """Enable/disable message flow. Raises: @@ -773,7 +874,7 @@ class Channel(AbstractChannel, base.StdChannel): """ raise NotImplementedError('virtual channels do not support flow.') - def close(self): + async def close(self) -> None: """Close channel. Cancel all consumers, and requeue unacked messages. @@ -781,55 +882,62 @@ class Channel(AbstractChannel, base.StdChannel): if not self.closed: self.closed = True for consumer in list(self._consumers): - self.basic_cancel(consumer) + await self.basic_cancel(consumer) if self._qos: - self._qos.restore_unacked_once() + await self._qos.restore_unacked_once() if self._cycle is not None: self._cycle.close() self._cycle = None if self.connection is not None: - self.connection.close_channel(self) + await self.connection.close_channel(self) self.exchange_types = None - def encode_body(self, body, encoding=None): + def encode_body(self, body: Any, encoding: str = None) -> Tuple[Any, str]: if encoding: return self.codecs.get(encoding).encode(body), encoding return body, encoding - def decode_body(self, body, encoding=None): + def decode_body(self, body: Any, encoding: str = None) -> Any: if encoding: return self.codecs.get(encoding).decode(body) return body - def _reset_cycle(self): + def _reset_cycle(self) -> None: self._cycle = FairCycle( self._get_and_deliver, self._active_queues, Empty) - def __enter__(self): + def __enter__(self) -> 'Channel': return self - def __exit__(self, *exc_info): + def __exit__(self, *exc_info) -> None: self.close() + async def __aenter__(self) -> 'Channel': + return self + + async def __aexit__(self, *exc_info) -> None: + ... + @property - def state(self): + def state(self) -> BrokerState: """Broker state containing exchanges and bindings.""" return self.connection.state @property - def qos(self): + def qos(self) -> QoS: """:class:`QoS` manager for this channel.""" if self._qos is None: self._qos = self.QoS(self) return self._qos @property - def cycle(self): + def cycle(self) -> FairCycle: if self._cycle is None: self._reset_cycle() return self._cycle - def _get_message_priority(self, message, reverse=False): + def _get_message_priority(self, message: MessageT, + *, reverse: bool = False) -> int: """Get priority from message. The value is limited to within a boundary of 0 to 9. @@ -852,15 +960,15 @@ class Channel(AbstractChannel, base.StdChannel): class Management(base.Management): """Base class for the AMQP management API.""" - def __init__(self, transport): + def __init__(self, transport: TransportT) -> None: super(Management, self).__init__(transport) self.channel = transport.client.channel() - def get_bindings(self): + def get_bindings(self) -> Sequence[Mapping]: return [dict(destination=q, source=e, routing_key=r) for q, e, r in self.channel.list_bindings()] - def close(self): + def close(self) -> None: self.channel.close() @@ -871,31 +979,31 @@ class Transport(base.Transport): client (kombu.Connection): The client this is a transport for. """ - Channel = Channel - Cycle = FairCycle - Management = Management + Channel: type = Channel + Cycle: type = FairCycle + Management: type = Management #: Global :class:`BrokerState` containing declared exchanges and bindings. - state = BrokerState() + state: BrokerState = BrokerState() #: :class:`~kombu.utils.scheduling.FairCycle` instance #: used to fairly drain events from channels (set by constructor). - cycle = None + cycle: FairCycle = None #: port number used when no port is specified. - default_port = None + default_port: int = None #: active channels. - channels = None + channels: Sequence[ChannelT] = None #: queue/callback map. - _callbacks = None + _callbacks: MutableMapping[str, Callable] = None #: Time to sleep between unsuccessful polls. - polling_interval = 1.0 + polling_interval: float = 1.0 #: Max number of channels - channel_max = 65535 + channel_max: int = 65535 implements = base.Transport.implements.extend( async=False, @@ -903,7 +1011,7 @@ class Transport(base.Transport): heartbeats=False, ) - def __init__(self, client, **kwargs): + def __init__(self, client: ClientT, **kwargs): self.client = client self.channels = [] self._avail_channels = [] @@ -916,7 +1024,7 @@ class Transport(base.Transport): ARRAY_TYPE_H, range(self.channel_max, 0, -1), ) - def create_channel(self, connection): + def create_channel(self, connection: ConnectionT) -> ChannelT: try: return self._avail_channels.pop() except IndexError: @@ -924,7 +1032,7 @@ class Transport(base.Transport): self.channels.append(channel) return channel - def close_channel(self, channel): + async def close_channel(self, channel: ChannelT) -> None: try: self._avail_channel_ids.append(channel.channel_id) try: @@ -934,14 +1042,14 @@ class Transport(base.Transport): finally: channel.connection = None - def establish_connection(self): + async def establish_connection(self) -> 'Transport': # creates channel to verify connection. # this channel is then used as the next requested channel. # (returned by ``create_channel``). self._avail_channels.append(self.create_channel(self)) return self # for drain events - def close_connection(self, connection): + async def close_connection(self, connection: ConnectionT) -> None: self.cycle.close() for l in self._avail_channels, self.channels: while l: @@ -950,24 +1058,25 @@ class Transport(base.Transport): except LookupError: # pragma: no cover pass else: - channel.close() + await channel.close() - def drain_events(self, connection, timeout=None): + async def drain_events(self, connection: ConnectionT, + timeout: float = None) -> None: time_start = monotonic() get = self.cycle.get polling_interval = self.polling_interval while 1: try: - get(self._deliver, timeout=timeout) + await get(self._deliver, timeout=timeout) except Empty: if timeout is not None and monotonic() - time_start >= timeout: raise socket.timeout() if polling_interval is not None: - sleep(polling_interval) + await sleep(polling_interval) else: break - def _deliver(self, message, queue): + async def _deliver(self, message: MessageT, queue: str) -> None: if not queue: raise KeyError( 'Received message without destination queue: {0}'.format( @@ -980,7 +1089,7 @@ class Transport(base.Transport): else: callback(message) - def _reject_inbound_message(self, raw_message): + def _reject_inbound_message(self, raw_message: Any) -> None: for channel in self.channels: if channel: message = channel.Message(raw_message, channel=channel) @@ -988,16 +1097,18 @@ class Transport(base.Transport): channel.basic_reject(message.delivery_tag, requeue=True) break - def on_message_ready(self, channel, message, queue): + def on_message_ready( + self, channel: ChannelT, message: MessageT, queue: str) -> None: if not queue or queue not in self._callbacks: raise KeyError( 'Message for queue {0!r} without consumers: {1}'.format( queue, message)) self._callbacks[queue](message) - def _drain_channel(self, channel, callback, timeout=None): + def _drain_channel(self, channel: ChannelT, callback: Callable, + timeout: float = None) -> None: return channel.drain_events(callback=callback, timeout=timeout) @property - def default_connection_params(self): - return {'port': self.default_port, 'hostname': 'localhost'} + def default_connection_params(self) -> Mapping: + return {'port': self.default_port, 'hostnaime': 'localhost'} diff --git a/kombu/transport/virtual/exchange.py b/kombu/transport/virtual/exchange.py index 33d63066..98c0e6db 100644 --- a/kombu/transport/virtual/exchange.py +++ b/kombu/transport/virtual/exchange.py @@ -3,10 +3,11 @@ Implementations of the standard exchanges defined by the AMQ protocol (excluding the `headers` exchange). """ - import re - +from typing import Mapping, Match, Pattern, Set, Tuple +from amqp.types import ChannelT from kombu.utils.text import escape_regex +from kombu.types import MessageT class ExchangeType: @@ -18,9 +19,10 @@ class ExchangeType: channel (ChannelT): AMQ Channel. """ - type = None + type: str + channel: ChannelT - def __init__(self, channel): + def __init__(self, channel: ChannelT) -> None: self.channel = channel def lookup(self, table, exchange, routing_key, default): @@ -57,13 +59,18 @@ class DirectExchange(ExchangeType): type = 'direct' - def lookup(self, table, exchange, routing_key, default): + def lookup(self, + table: Mapping, + exchange: str, + routing_key: str, + default: str) -> Set[str]: return { queue for rkey, _, queue in table if rkey == routing_key } - def deliver(self, message, exchange, routing_key, **kwargs): + def deliver(self, message: MessageT, + exchange: str, routing_key: str, **kwargs) -> None: _lookup = self.channel._lookup _put = self.channel._put for queue in _lookup(exchange, routing_key): @@ -81,19 +88,26 @@ class TopicExchange(ExchangeType): type = 'topic' #: map of wildcard to regex conversions - wildcards = {'*': r'.*?[^\.]', - '#': r'.*?'} + wildcards: Mapping[str, str] = { + '*': r'.*?[^\.]', + '#': r'.*?', + } #: compiled regex cache - _compiled = {} + _compiled: Mapping[str, Pattern] = {} - def lookup(self, table, exchange, routing_key, default): + def lookup(self, + table: Mapping, + exchange: str, + routing_key: str, + default: str) -> Set[str]: return { queue for rkey, pattern, queue in table if self._match(pattern, routing_key) } - def deliver(self, message, exchange, routing_key, **kwargs): + def deliver(self, message: MessageT, + exchange: str, routing_key: str, **kwargs) -> None: _lookup = self.channel._lookup _put = self.channel._put deadletter = self.channel.deadletter_queue @@ -101,17 +115,21 @@ class TopicExchange(ExchangeType): if q and q != deadletter]: _put(queue, message, **kwargs) - def prepare_bind(self, queue, exchange, routing_key, arguments): + def prepare_bind(self, + queue: str, + exchange: str, + routing_key: str, + arguments: Mapping) -> Tuple[str, Pattern, str]: return routing_key, self.key_to_pattern(routing_key), queue - def key_to_pattern(self, rkey): + def key_to_pattern(self, rkey: str) -> str: """Get the corresponding regex for any routing key.""" return '^%s$' % ('\.'.join( self.wildcards.get(word, word) for word in escape_regex(rkey, '.#*').split('.') )) - def _match(self, pattern, string): + def _match(self, pattern: str, string: str) -> Match: """Match regular expression (cached). Same as :func:`re.match`, except the regex is compiled and cached, @@ -141,17 +159,19 @@ class FanoutExchange(ExchangeType): type = 'fanout' - def lookup(self, table, exchange, routing_key, default): + def lookup(self, table: Mapping, + exchange: str, routing_key: str, default: str) -> Set[str]: return {queue for _, _, queue in table} - def deliver(self, message, exchange, routing_key, **kwargs): + def deliver(self, message: MessageT, + exchange: str, routing_key: str, **kwargs) -> None: if self.channel.supports_fanout: self.channel._put_fanout( exchange, message, routing_key, **kwargs) #: Map of standard exchange types and corresponding classes. -STANDARD_EXCHANGE_TYPES = { +STANDARD_EXCHANGE_TYPES: Mapping[str, type] = { 'direct': DirectExchange, 'topic': TopicExchange, 'fanout': FanoutExchange, diff --git a/kombu/transport/zookeeper.py b/kombu/transport/zookeeper.py deleted file mode 100644 index cf9f0dc8..00000000 --- a/kombu/transport/zookeeper.py +++ /dev/null @@ -1,198 +0,0 @@ -"""Zookeeper transport. - -:copyright: (c) 2010 - 2013 by Mahendra M. -:license: BSD, see LICENSE for more details. - -**Synopsis** - -Connects to a zookeeper node as <server>:<port>/<vhost> -The <vhost> becomes the base for all the other znodes. So we can use -it like a vhost. - -This uses the built-in kazoo recipe for queues - -**References** - -- https://zookeeper.apache.org/doc/trunk/recipes.html#sc_recipes_Queues -- https://kazoo.readthedocs.io/en/latest/api/recipe/queue.html - -**Limitations** -This queue does not offer reliable consumption. An entry is removed from -the queue prior to being processed. So if an error occurs, the consumer -has to re-queue the item or it will be lost. -""" - -import os -import socket - -from queue import Empty - -from kombu.utils.encoding import bytes_to_str -from kombu.utils.json import loads, dumps - -from . import virtual - -try: - import kazoo - from kazoo.client import KazooClient - from kazoo.recipe.queue import Queue - - KZ_CONNECTION_ERRORS = ( - kazoo.exceptions.SystemErrorException, - kazoo.exceptions.ConnectionLossException, - kazoo.exceptions.MarshallingErrorException, - kazoo.exceptions.UnimplementedException, - kazoo.exceptions.OperationTimeoutException, - kazoo.exceptions.NoAuthException, - kazoo.exceptions.InvalidACLException, - kazoo.exceptions.AuthFailedException, - kazoo.exceptions.SessionExpiredException, - ) - - KZ_CHANNEL_ERRORS = ( - kazoo.exceptions.RuntimeInconsistencyException, - kazoo.exceptions.DataInconsistencyException, - kazoo.exceptions.BadArgumentsException, - kazoo.exceptions.MarshallingErrorException, - kazoo.exceptions.UnimplementedException, - kazoo.exceptions.OperationTimeoutException, - kazoo.exceptions.ApiErrorException, - kazoo.exceptions.NoNodeException, - kazoo.exceptions.NoAuthException, - kazoo.exceptions.NodeExistsException, - kazoo.exceptions.NoChildrenForEphemeralsException, - kazoo.exceptions.NotEmptyException, - kazoo.exceptions.SessionExpiredException, - kazoo.exceptions.InvalidCallbackException, - socket.error, - ) -except ImportError: - kazoo = None # noqa - KZ_CONNECTION_ERRORS = KZ_CHANNEL_ERRORS = () # noqa - -DEFAULT_PORT = 2181 - -__author__ = 'Mahendra M <mahendra.m@gmail.com>' - - -class Channel(virtual.Channel): - """Zookeeper Channel.""" - - _client = None - _queues = {} - - def _get_path(self, queue_name): - return os.path.join(self.vhost, queue_name) - - def _get_queue(self, queue_name): - queue = self._queues.get(queue_name, None) - - if queue is None: - queue = Queue(self.client, self._get_path(queue_name)) - self._queues[queue_name] = queue - - # Ensure that the queue is created - len(queue) - - return queue - - def _put(self, queue, message, **kwargs): - return self._get_queue(queue).put( - dumps(message), - priority=self._get_message_priority(message, reverse=True), - ) - - def _get(self, queue): - queue = self._get_queue(queue) - msg = queue.get() - - if msg is None: - raise Empty() - - return loads(bytes_to_str(msg)) - - def _purge(self, queue): - count = 0 - queue = self._get_queue(queue) - - while True: - msg = queue.get() - if msg is None: - break - count += 1 - - return count - - def _delete(self, queue, *args, **kwargs): - if self._has_queue(queue): - self._purge(queue) - self.client.delete(self._get_path(queue)) - - def _size(self, queue): - queue = self._get_queue(queue) - return len(queue) - - def _new_queue(self, queue, **kwargs): - if not self._has_queue(queue): - queue = self._get_queue(queue) - - def _has_queue(self, queue): - return self.client.exists(self._get_path(queue)) is not None - - def _open(self): - conninfo = self.connection.client - self.vhost = os.path.join('/', conninfo.virtual_host[0:-1]) - hosts = [] - if conninfo.alt: - for host_port in conninfo.alt: - if host_port.startswith('zookeeper://'): - host_port = host_port[len('zookeeper://'):] - if not host_port: - continue - try: - host, port = host_port.split(':', 1) - host_port = (host, int(port)) - except ValueError: - if host_port == conninfo.hostname: - host_port = (host_port, conninfo.port or DEFAULT_PORT) - else: - host_port = (host_port, DEFAULT_PORT) - hosts.append(host_port) - host_port = (conninfo.hostname, conninfo.port or DEFAULT_PORT) - if host_port not in hosts: - hosts.insert(0, host_port) - conn_str = ','.join(['%s:%s' % (h, p) for h, p in hosts]) - conn = KazooClient(conn_str) - conn.start() - return conn - - @property - def client(self): - if self._client is None: - self._client = self._open() - return self._client - - -class Transport(virtual.Transport): - """Zookeeper Transport.""" - - Channel = Channel - polling_interval = 1 - default_port = DEFAULT_PORT - connection_errors = ( - virtual.Transport.connection_errors + KZ_CONNECTION_ERRORS - ) - channel_errors = ( - virtual.Transport.channel_errors + KZ_CHANNEL_ERRORS - ) - driver_type = 'zookeeper' - driver_name = 'kazoo' - - def __init__(self, *args, **kwargs): - if kazoo is None: - raise ImportError('The kazoo library is not installed') - - super().__init__(*args, **kwargs) - - def driver_version(self): - return kazoo.__version__ diff --git a/kombu/types.py b/kombu/types.py new file mode 100644 index 00000000..5825fb1f --- /dev/null +++ b/kombu/types.py @@ -0,0 +1,893 @@ +import abc +import logging +from collections import Hashable, deque +from numbers import Number +from typing import ( + Any, Callable, Dict, Iterable, List, Mapping, Optional, + Sequence, TypeVar, Tuple, Union, +) +from typing import Set # noqa +from amqp.types import ChannelT, MessageT as _MessageT, TransportT +from amqp.spec import queue_declare_ok_t + + +def _hasattr(C: Any, attr: str) -> bool: + return any(attr in B.__dict__ for B in C.__mro__) + + +class _AbstractClass(abc.ABCMeta): + __required_attributes__ = frozenset() # type: frozenset + + @classmethod + def _subclasshook_using(cls, parent: Any, C: Any): + return ( + cls is parent and + all(_hasattr(C, attr) for attr in cls.__required_attributes__) + ) or NotImplemented + + +class Revivable(metaclass=_AbstractClass): + + async def revive(self, channel: ChannelT) -> None: + ... + + +class ClientT(Revivable, metaclass=_AbstractClass): + + hostname: str + userid: str + password: str + ssl: Any + login_method: str + port: int = None + virtual_host: str = '/' + connect_timeout: float = 5.0 + alt: Sequence[str] + + uri_prefix: str = None + declared_entities: Set['EntityT'] = None + cycle: Iterable + transport_options: Mapping + failover_strategy: str + heartbeat: float = None + resolve_aliases: Mapping[str, str] = resolve_aliases + failover_strategies: Mapping[str, Callable] = failover_strategies + + @property + @abc.abstractmethod + def default_channel(self) -> ChannelT: + ... + + def __init__( + self, + hostname: str = 'localhost', + userid: str = None, + password: str = None, + virtual_host: str = None, + port: int = None, + insist: bool = False, + ssl: Any = None, + transport: Union[type, str] = None, + connect_timeout: float = 5.0, + transport_options: Mapping = None, + login_method: str = None, + uri_prefix: str = None, + heartbeat: float = None, + failover_strategy: str = 'round-robin', + alternates: Sequence[str] = None, + **kwargs) -> None: + ... + + @abc.abstractmethod + def switch(self, url: str) -> None: + ... + + @abc.abstractmethod + def maybe_switch_next(self) -> None: + ... + + @abc.abstractmethod + async def connect(self) -> None: + ... + + @abc.abstractmethod + def channel(self) -> ChannelT: + ... + + @abc.abstractmethod + async def heartbeat_check(self, rate: int = 2) -> None: + ... + + @abc.abstractmethod + async def drain_events(self, timeout: float = None, **kwargs) -> None: + ... + + @abc.abstractmethod + async def maybe_close_channel(self, channel: ChannelT) -> None: + ... + + @abc.abstractmethod + async def collect(self, socket_timeout: float = None) -> None: + ... + + @abc.abstractmethod + async def release(self) -> None: + ... + + @abc.abstractmethod + async def close(self) -> None: + ... + + @abc.abstractmethod + async def ensure_connection( + self, + errback: Callable = None, + max_retries: int = None, + interval_start: float = 2.0, + interval_step: float = 2.0, + interval_max: float = 30.0, + callback: Callable = None, + reraise_as_library_errors: bool = True) -> 'ClientT': + ... + + @abc.abstractmethod + def completes_cycle(self, retries: int) -> bool: + ... + + @abc.abstractmethod + async def ensure( + self, obj: Revivable, fun: Callable, + errback: Callable = None, + max_retries: int = None, + interval_start: float = 1.0, + interval_step: float = 1.0, + interval_max: float = 1.0, + on_revive: Callable = None) -> Any: + ... + + @abc.abstractmethod + def autoretry(self, fun: Callable, + channel: ChannelT = None, **ensure_options) -> Any: + ... + + @abc.abstractmethod + def create_transport(self) -> TransportT: + ... + + @abc.abstractmethod + def get_transport_cls(self) -> type: + ... + + @abc.abstractmethod + def clone(self, **kwargs) -> 'ClientT': + ... + + @abc.abstractmethod + def get_heartbeat_interval(self) -> float: + ... + + @abc.abstractmethod + def info(self) -> Mapping[str, Any]: + ... + + @abc.abstractmethod + def as_uri( + self, + *, + include_password: bool = False, + mask: str = '**', + getfields: Callable = None) -> str: + ... + + @abc.abstractmethod + def Pool(self, limit: int = None, **kwargs) -> 'ResourceT': + ... + + @abc.abstractmethod + def ChannelPool(self, limit: int = None, **kwargs) -> 'ResourceT': + ... + + @abc.abstractmethod + def Producer(self, channel: ChannelT = None, + *args, **kwargs) -> 'ProducerT': + ... + + @abc.abstractmethod + def Consumer( + self, + queues: Sequence['QueueT'] = None, + channel: ChannelT = None, + *args, **kwargs) -> 'ConsumerT': + ... + + @abc.abstractmethod + def SimpleQueue( + self, name: str, + no_ack: bool = None, + queue_opts: Mapping = None, + exchange_opts: Mapping = None, + channel: ChannelT = None, + **kwargs) -> 'SimpleQueueT': + ... + + +class EntityT(Hashable, Revivable, metaclass=_AbstractClass): + + @property + @abc.abstractmethod + def can_cache_declaration(self) -> bool: + ... + + @property + @abc.abstractmethod + def is_bound(self) -> bool: + ... + + @property + @abc.abstractmethod + def channel(self) -> ChannelT: + ... + + @abc.abstractmethod + def __call__(self, channel: ChannelT) -> 'EntityT': + ... + + @abc.abstractmethod + def bind(self, channel: ChannelT) -> 'EntityT': + ... + + @abc.abstractmethod + def maybe_bind(self, channel: Optional[ChannelT]) -> 'EntityT': + ... + + @abc.abstractmethod + def when_bound(self) -> None: + ... + + @abc.abstractmethod + def as_dict(self, recurse: bool = False) -> Dict: + ... + + +ChannelArgT = TypeVar('ChannelArgT', ChannelT, ClientT) + + +class ExchangeT(EntityT, metaclass=_AbstractClass): + + name: str + type: str + channel: ChannelT + arguments: Mapping + durable: bool = True + auto_delete: bool = False + delivery_mode: int + no_declare: bool = False + passive: bool = False + + def __init__( + self, + name: str = '', + type: str = '', + channel: ChannelArgT = None, + **kwargs) -> None: + ... + + @abc.abstractmethod + async def declare( + self, + nowait: bool = False, + passive: bool = None, + channel: ChannelT = None) -> None: + ... + + @abc.abstractmethod + async def bind_to( + self, + exchange: Union[str, 'ExchangeT'] = '', + routing_key: str = '', + arguments: Mapping = None, + nowait: bool = False, + channel: ChannelT = None, + **kwargs) -> None: + ... + + @abc.abstractmethod + async def unbind_from( + self, + source: Union[str, 'ExchangeT'] = '', + routing_key: str = '', + nowait: bool = False, + arguments: Mapping = None, + channel: ChannelT = None) -> None: + ... + + @abc.abstractmethod + async def publish( + self, message: Union[MessageT, str, bytes], + routing_key: str = None, + mandatory: bool = False, + immediate: bool = False, + exchange: str = None) -> None: + ... + + @abc.abstractmethod + async def delete( + self, + if_unused: bool = False, + nowait: bool = False) -> None: + ... + + @abc.abstractmethod + def binding( + self, + routing_key: str = '', + arguments: Mapping = None, + unbind_arguments: Mapping = None) -> BindingT: + ... + + @abc.abstractmethod + def Message( + self, body: Any, + delivery_mode: Union[str, int] = None, + properties: Mapping = None, + **kwargs) -> Any: + ... + + +class QueueT(EntityT, metaclass=_AbstractClass): + + name: str = '' + exchange: ExchangeT + routing_key: str = '' + alias: str + + bindings: Sequence['BindingT'] = None + + durable: bool = True + exclusive: bool = False + auto_delete: bool = False + no_ack: bool = False + no_declare: bool = False + expires: float + message_ttl: float + max_length: int + max_length_bytes: int + max_priority: int + + queue_arguments: Mapping + binding_arguments: Mapping + consumer_arguments: Mapping + + def __init__( + self, + name: str = '', + exchange: ExchangeT = None, + routing_key: str = '', + channel: ChannelT = None, + bindings: Sequence['BindingT'] = None, + on_declared: Callable = None, + **kwargs) -> None: + ... + + @abc.abstractmethod + async def declare(self, + nowait: bool = False, + channel: ChannelT = None) -> str: + ... + + @abc.abstractmethod + async def queue_declare( + self, + nowait: bool = False, + passive: bool = False, + channel: ChannelT = None) -> queue_declare_ok_t: + ... + + @abc.abstractmethod + async def queue_bind( + self, + nowait: bool = False, + channel: ChannelT = None) -> None: + ... + + @abc.abstractmethod + async def bind_to( + self, + exchange: Union[str, ExchangeT] = '', + routing_key: str = '', + arguments: Mapping = None, + nowait: bool = False, + channel: ChannelT = None) -> None: + ... + + @abc.abstractmethod + async def get( + self, + no_ack: bool = None, + accept: Set[str] = None) -> MessageT: + ... + + @abc.abstractmethod + async def purge(self, nowait: bool = False) -> int: + ... + + @abc.abstractmethod + async def consume( + self, + consumer_tag: str = '', + callback: Callable = None, + no_ack: bool = None, + nowait: bool = False) -> None: + ... + + @abc.abstractmethod + async def cancel(self, consumer_tag: str) -> None: + ... + + @abc.abstractmethod + async def delete( + self, + if_unused: bool = False, + if_empty: bool = False, + nowait: bool = False) -> None: + ... + + @abc.abstractmethod + async def queue_unbind( + self, + arguments: Mapping = None, + nowait: bool = False, + channel: ChannelT = None) -> None: + ... + + @abc.abstractmethod + async def unbind_from( + self, + exchange: str = '', + routing_key: str = '', + arguments: Mapping = None, + nowait: bool = False, + channel: ChannelT = None) -> None: + ... + + @classmethod + @abc.abstractmethod + def from_dict( + self, queue: str, + exchange: str = None, + exchange_type: str = None, + binding_key: str = None, + routing_key: str = None, + delivery_mode: Union[int, str] = None, + bindings: Sequence = None, + durable: bool = None, + queue_durable: bool = None, + exchange_durable: bool = None, + auto_delete: bool = None, + queue_auto_delete: bool = None, + exchange_auto_delete: bool = None, + exchange_arguments: Mapping = None, + queue_arguments: Mapping = None, + binding_arguments: Mapping = None, + consumer_arguments: Mapping = None, + exclusive: bool = None, + no_ack: bool = None, + **options) -> 'QueueT': + ... + + +class BindingT(metaclass=_AbstractClass): + + exchange: ExchangeT + routing_key: str + arguments: Mapping + unbind_arguments: Mapping + + def __init__( + self, + exchange: ExchangeT = None, + routing_key: str = '', + arguments: Mapping = None, + unbind_arguments: Mapping = None) -> None: + ... + + def declare(self, channel: ChannelT, nowait: bool = False) -> None: + ... + + def bind( + self, entity: EntityT, + nowait: bool = False, + channel: ChannelT = None) -> None: + ... + + def unbind( + self, entity: EntityT, + nowait: bool = False, + channel: ChannelT = None) -> None: + ... + + +class ConsumerT(Revivable, metaclass=_AbstractClass): + + ContentDisallowed: type + + channel: ChannelT + queues: Sequence[QueueT] + no_ack: bool + auto_declare: bool = True + callbacks: Sequence[Callable] + on_message: Callable + on_decode_error: Callable + accept: Set[str] + prefetch_count: int = None + tag_prefix: str + + @property + @abc.abstractmethod + def connection(self) -> ClientT: + ... + + def __init__( + self, + channel: ChannelT, + queues: Sequence[QueueT] = None, + no_ack: bool = None, + auto_declare: bool = None, + callbacks: Sequence[Callable] = None, + on_decode_error: Callable = None, + on_message: Callable = None, + accept: Sequence[str] = None, + prefetch_count: int = None, + tag_prefix: str = None) -> None: + ... + + @abc.abstractmethod + async def declare(self) -> None: + ... + + @abc.abstractmethod + def register_callback(self, callback: Callable) -> None: + ... + + @abc.abstractmethod + def __enter__(self) -> 'ConsumerT': + ... + + @abc.abstractmethod + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + ... + + @abc.abstractmethod + async def __aenter__(self) -> 'ConsumerT': + ... + + @abc.abstractmethod + async def __aexit__(self, *exc_info) -> None: + ... + + @abc.abstractmethod + async def add_queue(self, queue: QueueT) -> QueueT: + ... + + @abc.abstractmethod + async def consume(self, no_ack: bool = None) -> None: + ... + + @abc.abstractmethod + async def cancel(self) -> None: + ... + + @abc.abstractmethod + async def cancel_by_queue(self, queue: Union[QueueT, str]) -> None: + ... + + @abc.abstractmethod + def consuming_from(self, queue: Union[QueueT, str]) -> bool: + ... + + @abc.abstractmethod + async def purge(self) -> int: + ... + + @abc.abstractmethod + async def flow(self, active: bool) -> None: + ... + + @abc.abstractmethod + async def qos(self, + prefetch_size: int = 0, + prefetch_count: int = 0, + apply_global: bool = False) -> None: + ... + + @abc.abstractmethod + async def recover(self, requeue: bool = False) -> None: + ... + + @abc.abstractmethod + async def receive(self, body: Any, message: MessageT) -> None: + ... + + +class ProducerT(Revivable, metaclass=_AbstractClass): + + exchange: ExchangeT + routing_key: str = '' + serializer: str + compression: str + auto_declare: bool = True + on_return: Callable = None + + __connection__: ClientT + + channel: ChannelT + + @property + @abc.abstractmethod + def connection(self) -> ClientT: + ... + + def __init__(self, + channel: ChannelArgT, + exchange: ExchangeT = None, + routing_key: str = None, + serializer: str = None, + auto_declare: bool = None, + compression: str = None, + on_return: Callable = None) -> None: + ... + + @abc.abstractmethod + async def declare(self) -> None: + ... + + @abc.abstractmethod + async def maybe_declare(self, entity: EntityT, + retry: bool = False, **retry_policy) -> None: + ... + + @abc.abstractmethod + async def publish( + self, body: Any, + routing_key: str = None, + delivery_mode: Union[int, str] = None, + mandatory: bool = False, + immediate: bool = False, + priority: int = 0, + content_type: str = None, + content_encoding: str = None, + serializer: str = None, + headers: Mapping = None, + compression: str = None, + exchange: Union[ExchangeT, str] = None, + retry: bool = False, + retry_policy: Mapping = None, + declare: Sequence[EntityT] = None, + expiration: Number = None, + **properties) -> None: + ... + + @abc.abstractmethod + def __enter__(self) -> 'ProducerT': + ... + + @abc.abstractmethod + async def __aenter__(self) -> 'ProducerT': + ... + + @abc.abstractmethod + def __exit__(self, *exc_info) -> None: + ... + + @abc.abstractmethod + async def __aexit__(self, *exc_info) -> None: + ... + + @abc.abstractmethod + def release(self) -> None: + ... + + @abc.abstractmethod + def close(self) -> None: + ... + + +class MessageT(_MessageT): + + MessageStateError: type + + _state: str + channel: ChannelT = None + delivery_tag: str + content_type: str + content_encoding: str + delivery_info: Mapping + headers: Mapping + properties: Mapping + accept: Set[str] + body: Any + errors: List[Any] + + @property + @abc.abstractmethod + def acknowledged(self) -> bool: + ... + + @property + @abc.abstractmethod + def payload(self) -> Any: + ... + + def __init__( + self, + body: Any = None, + delivery_tag: str = None, + content_type: str = None, + content_encoding: str = None, + delivery_info: Mapping = None, + properties: Mapping = None, + headers: Mapping = None, + postencode: str = None, + accept: Set[str] = None, + channel: ChannelT = None, + **kwargs) -> None: + ... + + @abc.abstractmethod + async def ack(self, multiple: bool = False) -> None: + ... + + @abc.abstractmethod + async def ack_log_error( + self, logger: logging.Logger, errors: Tuple[type, ...], + multiple: bool = False) -> None: + ... + + @abc.abstractmethod + async def reject_log_error( + self, logger: logging.Logger, errors: Tuple[type, ...], + requeue: bool = False) -> None: + ... + + @abc.abstractmethod + async def reject(self, requeue: bool = False) -> None: + ... + + @abc.abstractmethod + async def requeue(self) -> None: + ... + + @abc.abstractmethod + def decode(self) -> Any: + ... + + +class ResourceT(metaclass=_AbstractClass): + + LimitedExceeeded: type + + close_after_fork: bool = False + + @property + @abc.abstractmethod + def limit(self) -> int: + ... + + def __init__(self, + limit: int = None, + preload: int = None, + close_after_fork: bool = None) -> None: + ... + + @abc.abstractmethod + def setup(self) -> None: + ... + + @abc.abstractmethod + def acquire(self, block: bool = False, timeout: int = None): + ... + + @abc.abstractmethod + def prepare(self, resource: Any) -> Any: + ... + + @abc.abstractmethod + def close_resource(self, resource: Any) -> None: + ... + + @abc.abstractmethod + def release_resource(self, resource: Any) -> None: + ... + + @abc.abstractmethod + def replace(self, resource: Any) -> None: + ... + + @abc.abstractmethod + def release(self, resource: Any) -> None: + ... + + @abc.abstractmethod + def collect_resource(self, resource: Any) -> None: + ... + + @abc.abstractmethod + def force_close_all(self) -> None: + ... + + @abc.abstractmethod + def resize( + self, limit: int, + force: bool = False, + ignore_errors: bool = False, + reset: bool = False) -> None: + ... + + +class SimpleQueueT(metaclass=_AbstractClass): + + Empty: type + + channel: ChannelT + producer: ProducerT + consumer: ConsumerT + no_ack: bool = False + queue: QueueT + buffer: deque + + def __init__( + self, + channel: ChannelArgT, + name: str, + no_ack: bool = None, + queue_opts: Mapping = None, + exchange_opts: Mapping = None, + serializer: str = None, + compression: str = None, + **kwargs) -> None: + ... + + @abc.abstractmethod + def __enter__(self) -> 'SimpleQueueT': + ... + + @abc.abstractmethod + def __exit__(self, *exc_info) -> None: + ... + + @abc.abstractmethod + def get(self, block: bool = True, timeout: float = None) -> MessageT: + ... + + @abc.abstractmethod + def get_nowait(self) -> MessageT: + ... + + @abc.abstractmethod + def put(self, message: Any, + serializer: str = None, + headers: Mapping = None, + compression: str = None, + routing_key: str = None, + **kwargs) -> None: + ... + + @abc.abstractmethod + def clear(self) -> int: + ... + + @abc.abstractmethod + def qsize(self) -> int: + ... + + @abc.abstractmethod + def close(self) -> None: + ... + + @abc.abstractmethod + def __len__(self) -> int: + ... + + @abc.abstractmethod + def __bool__(self) -> bool: + ... diff --git a/kombu/utils/abstract.py b/kombu/utils/abstract.py deleted file mode 100644 index 27a27d45..00000000 --- a/kombu/utils/abstract.py +++ /dev/null @@ -1,49 +0,0 @@ -import abc - -from typing import Any -from typing import Set # noqa - - -def _hasattr(C: Any, attr: str) -> bool: - return any(attr in B.__dict__ for B in C.__mro__) - - -class _AbstractClass(object, metaclass=abc.ABCMeta): - __required_attributes__ = frozenset() # type: frozenset - - @classmethod - def _subclasshook_using(cls, parent: Any, C: Any): - return ( - cls is parent and - all(_hasattr(C, attr) for attr in cls.__required_attributes__) - ) or NotImplemented - - @classmethod - def register(cls, other: Any) -> Any: - # we override `register` to return other for use as a decorator. - type(cls).register(cls, other) - return other - - -class Connection(_AbstractClass): - ... - - -class Entity(_AbstractClass): - ... - - -class Consumer(_AbstractClass): - ... - - -class Producer(_AbstractClass): - ... - - -class Messsage(_AbstractClass): - ... - - -class Resource(_AbstractClass): - ... diff --git a/kombu/utils/amq_manager.py b/kombu/utils/amq_manager.py index fcf1c6d5..c8cb8508 100644 --- a/kombu/utils/amq_manager.py +++ b/kombu/utils/amq_manager.py @@ -1,9 +1,9 @@ """AMQP Management API utilities.""" from typing import Any, Union -from . import abstract +from kombu.types import ClientT -def get_manager(client: abstract.Connection, +def get_manager(client: ClientT, hostname: str = None, port: Union[int, str] = None, userid: str = None, diff --git a/kombu/utils/collections.py b/kombu/utils/collections.py index 0ee27399..b202470b 100644 --- a/kombu/utils/collections.py +++ b/kombu/utils/collections.py @@ -1,4 +1,5 @@ """Custom maps, sequences, etc.""" +from typing import Any, Union class HashedSeq(list): @@ -10,15 +11,15 @@ class HashedSeq(list): __slots__ = 'hashvalue' - def __init__(self, *seq): + def __init__(self, *seq) -> None: self[:] = seq self.hashvalue = hash(seq) - def __hash__(self): + def __hash__(self) -> int: return self.hashvalue -def eqhash(o): +def eqhash(o: Any) -> Union[int, HashedSeq]: """Call ``obj.__eqhash__``.""" try: return o.__eqhash__() @@ -29,14 +30,17 @@ def eqhash(o): class EqualityDict(dict): """Dict using the eq operator for keying.""" - def __getitem__(self, key): + def __getitem__(self, key: Any) -> Any: h = eqhash(key) if h not in self: return self.__missing__(key) return dict.__getitem__(self, h) - def __setitem__(self, key, value): - return dict.__setitem__(self, eqhash(key), value) + def __setitem__(self, key: Any, value: Any) -> None: + dict.__setitem__(self, eqhash(key), value) - def __delitem__(self, key): - return dict.__delitem__(self, eqhash(key)) + def __delitem__(self, key: Any) -> None: + dict.__delitem__(self, eqhash(key)) + + def __missing__(self, key: Any) -> Any: + raise KeyError(key) diff --git a/kombu/utils/compat.py b/kombu/utils/compat.py index c9390ec3..b12dd41e 100644 --- a/kombu/utils/compat.py +++ b/kombu/utils/compat.py @@ -3,6 +3,8 @@ import numbers import sys from functools import wraps from contextlib import contextmanager +from typing import Any, Callable, Iterator, Optional, Tuple, Union +from .typing import SupportsFileno try: from io import UnsupportedOperation @@ -21,17 +23,17 @@ except ImportError: # pragma: no cover _environment = None -def coro(gen): +def coro(gen: Callable) -> Callable: """Decorator to mark generator as co-routine.""" @wraps(gen) - def wind_up(*args, **kwargs): + def wind_up(*args, **kwargs) -> Any: it = gen(*args, **kwargs) next(it) return it return wind_up -def _detect_environment(): +def _detect_environment() -> str: # ## -eventlet- if 'eventlet' in sys.modules: try: @@ -57,14 +59,15 @@ def _detect_environment(): return 'default' -def detect_environment(): +def detect_environment() -> str: """Detect the current environment: default, eventlet, or gevent.""" global _environment if _environment is None: _environment = _detect_environment() return _environment -def entrypoints(namespace): + +def entrypoints(namespace: str) -> Iterator[Tuple[str, Any]]: """Return setuptools entrypoints for namespace.""" try: from pkg_resources import iter_entry_points @@ -73,14 +76,14 @@ def entrypoints(namespace): return ((ep, ep.load()) for ep in iter_entry_points(namespace)) -def fileno(f): +def fileno(f: Union[SupportsFileno, numbers.Integral]) -> numbers.Integral: """Get fileno from file-like object.""" if isinstance(f, numbers.Integral): return f return f.fileno() -def maybe_fileno(f): +def maybe_fileno(f: Any) -> Optional[numbers.Integral]: """Get object fileno, or :const:`None` if not defined.""" try: return fileno(f) diff --git a/kombu/utils/debug.py b/kombu/utils/debug.py index 6bc59020..df98a04d 100644 --- a/kombu/utils/debug.py +++ b/kombu/utils/debug.py @@ -1,10 +1,7 @@ """Debugging support.""" import logging - -from typing import Any, Optional, Sequence, Union - +from typing import Any, Sequence, Union from vine.utils import wraps - from kombu.log import get_logger __all__ = ['setup_logging', 'Logwrapped'] @@ -29,8 +26,8 @@ class Logwrapped: __ignore = ('__enter__', '__exit__') def __init__(self, instance: Any, - logger: Optional[LoggerArg]=None, - ident: Optional[str]=None) -> None: + logger: LoggerArg = None, + ident: str = None) -> None: self.instance = instance self.logger = get_logger(logger) # type: logging.Logger self.ident = ident diff --git a/kombu/utils/div.py b/kombu/utils/div.py index 21b4c59a..0535c30f 100644 --- a/kombu/utils/div.py +++ b/kombu/utils/div.py @@ -1,9 +1,13 @@ """Div. Utilities.""" -from .encoding import default_encode import sys +from typing import Any, Callable, IO +from .encoding import default_encode -def emergency_dump_state(state, open_file=open, dump=None, stderr=None): +def emergency_dump_state(state: Any, + open_file: Callable = open, + dump: Callable = None, + stderr: IO = None): """Dump message state to stdout or file.""" from pprint import pformat from tempfile import mktemp diff --git a/kombu/utils/encoding.py b/kombu/utils/encoding.py index a68fef25..8faf32a1 100644 --- a/kombu/utils/encoding.py +++ b/kombu/utils/encoding.py @@ -8,7 +8,6 @@ applications without crashing from the infamous import sys import traceback - from typing import Any, AnyStr, IO, Optional str_t = str @@ -67,7 +66,7 @@ def default_encode(obj: Any) -> Any: return obj -def safe_str(s: Any, errors: str='replace') -> str: +def safe_str(s: Any, errors: str = 'replace') -> str: """Safe form of str(), void of unicode errors.""" s = bytes_to_str(s) if not isinstance(s, (str, bytes)): @@ -75,7 +74,7 @@ def safe_str(s: Any, errors: str='replace') -> str: return _safe_str(s, errors) -def _safe_str(s: Any, errors: str='replace', file: IO=None) -> str: +def _safe_str(s: Any, errors: str = 'replace', file: IO = None) -> str: if isinstance(s, str): return s try: @@ -85,7 +84,7 @@ def _safe_str(s: Any, errors: str='replace', file: IO=None) -> str: type(s), exc, '\n'.join(traceback.format_stack())) -def safe_repr(o: Any, errors='replace') -> str: +def safe_repr(o: Any, errors: str = 'replace') -> str: """Safe form of repr, void of Unicode errors.""" try: return repr(o) diff --git a/kombu/utils/eventio.py b/kombu/utils/eventio.py index 8be03494..d10f5868 100644 --- a/kombu/utils/eventio.py +++ b/kombu/utils/eventio.py @@ -8,7 +8,7 @@ from numbers import Integral from typing import Any, Callable, Optional, Sequence, IO, cast from typing import Set, Tuple # noqa from . import fileno -from .typing import Fd, Timeout +from .typing import Fd from .compat import detect_environment __all__ = ['poll'] @@ -81,7 +81,7 @@ class _epoll(BasePoller): except (PermissionError, FileNotFoundError): pass - def poll(self, timeout: Timeout) -> Optional[Sequence]: + def poll(self, timeout: float) -> Optional[Sequence]: try: return self._epoll.poll(timeout if timeout is not None else -1) except Exception as exc: @@ -148,7 +148,7 @@ class _kqueue(BasePoller): except ValueError: pass - def poll(self, timeout: Timeout) -> Sequence: + def poll(self, timeout: float) -> Sequence: try: kevents = self._kcontrol(None, 1000, timeout) except Exception as exc: @@ -212,7 +212,7 @@ class _poll(BasePoller): self._quick_unregister(fd) return fd - def poll(self, timeout: Timeout, + def poll(self, timeout: float, round: Callable=math.ceil, POLLIN: int=POLLIN, POLLOUT: int=POLLOUT, POLLERR: int=POLLERR, READ: int=READ, WRITE: int=WRITE, ERR: int=ERR, @@ -284,7 +284,7 @@ class _select(BasePoller): self._wfd.discard(fd) self._efd.discard(fd) - def poll(self, timeout: Timeout) -> Sequence: + def poll(self, timeout: float) -> Sequence: try: read, write, error = _selectf( cast(Sequence, self._rfd), diff --git a/kombu/utils/functional.py b/kombu/utils/functional.py index 85feffa5..8d631809 100644 --- a/kombu/utils/functional.py +++ b/kombu/utils/functional.py @@ -6,9 +6,11 @@ from itertools import count, repeat from time import sleep from typing import ( Any, Callable, Dict, Iterable, Iterator, - KeysView, Mapping, Optional, Sequence, Tuple, + KeysView, ItemsView, Mapping, Optional, Sequence, Tuple, ValuesView, ) +from amqp.types import ChannelT from vine.utils import wraps +from ..types import ClientT from .encoding import safe_repr as _safe_repr __all__ = [ @@ -17,27 +19,36 @@ __all__ = [ ] KEYWORD_MARK = object() - MemoizeKeyFun = Callable[[Sequence, Mapping], Any] class ChannelPromise(object): + __value__: ChannelT + + def __init__(self, connection: ClientT) -> None: + self.__connection__ = connection - def __init__(self, contract): - self.__contract__ = contract + def __call__(self) -> Any: + try: + return self.__value__ + except AttributeError: + value = self.__value__ = self.__connection__.default_channel + return value - def __call__(self): + async def resolve(self): try: return self.__value__ except AttributeError: - value = self.__value__ = self.__contract__() + await self.__connection__.connect() + value = self.__value__ = self.__connection__.default_channel + await value.open() return value - def __repr__(self): + def __repr__(self) -> str: try: return repr(self.__value__) except AttributeError: - return '<promise: 0x{0:x}>'.format(id(self.__contract__)) + return '<promise: 0x{0:x}>'.format(id(self.__connection__)) class LRUCache(UserDict): @@ -53,7 +64,7 @@ class LRUCache(UserDict): def __init__(self, limit: int = None) -> None: self.limit = limit self.mutex = threading.RLock() - self.data = OrderedDict() # type: OrderedDict + self.data: OrderedDict = OrderedDict() def __getitem__(self, key: Any) -> Any: with self.mutex: @@ -69,7 +80,7 @@ class LRUCache(UserDict): for _ in range(len(data) - limit): data.popitem(last=False) - def popitem(self, last: bool=True) -> Any: + def popitem(self, last: bool = True) -> Any: with self.mutex: return self.data.popitem(last) @@ -83,7 +94,7 @@ class LRUCache(UserDict): def __iter__(self) -> Iterator: return iter(self.data) - def items(self) -> Iterator[Tuple[Any, Any]]: + def items(self) -> ItemsView[Any, Any]: with self.mutex: for k in self: try: @@ -91,7 +102,7 @@ class LRUCache(UserDict): except KeyError: # pragma: no cover pass - def values(self) -> Iterator[Any]: + def values(self) -> ValuesView[Any]: with self.mutex: for k in self: try: @@ -104,7 +115,7 @@ class LRUCache(UserDict): with self.mutex: return self.data.keys() - def incr(self, key: Any, delta: int=1) -> Any: + def incr(self, key: Any, delta: int = 1) -> Any: with self.mutex: # this acts as memcached does- store as a string, but return a # integer as long as it exists and we can cast it @@ -122,9 +133,9 @@ class LRUCache(UserDict): self.mutex = threading.RLock() -def memoize(maxsize: Optional[int]=None, - keyfun: Optional[MemoizeKeyFun]=None, - Cache: Any=LRUCache) -> Callable: +def memoize(maxsize: int = None, + keyfun: MemoizeKeyFun = None, + Cache: Any = LRUCache) -> Callable: """Decorator to cache function return value.""" def _memoize(fun: Callable) -> Callable: @@ -189,10 +200,10 @@ class lazy: def __repr__(self) -> str: return repr(self()) - def __eq__(self, rhs) -> bool: + def __eq__(self, rhs: Any) -> bool: return self() == rhs - def __ne__(self, rhs) -> bool: + def __ne__(self, rhs: Any) -> bool: return self() != rhs def __deepcopy__(self, memo: Dict) -> Any: @@ -200,8 +211,10 @@ class lazy: return self def __reduce__(self) -> Any: - return (self.__class__, (self._fun,), {'_args': self._args, - '_kwargs': self._kwargs}) + return (self.__class__, (self._fun,), { + '_args': self._args, + '_kwargs': self._kwargs, + }) def maybe_evaluate(value: Any) -> Any: @@ -212,8 +225,9 @@ def maybe_evaluate(value: Any) -> Any: def is_list(l: Any, - scalars: tuple=(Mapping, str), - iters: tuple=(Iterable,)) -> bool: + *, + scalars: tuple = (Mapping, str), + iters: tuple = (Iterable,)) -> bool: """Return true if the object is iterable. Note: @@ -222,18 +236,19 @@ def is_list(l: Any, return isinstance(l, iters) and not isinstance(l, scalars or ()) -def maybe_list(l: Any, scalars: tuple=(Mapping, str)) -> Optional[Sequence]: +def maybe_list(l: Any, *, + scalars: tuple = (Mapping, str)) -> Optional[Sequence]: """Return list of one element if ``l`` is a scalar.""" return l if l is None or is_list(l, scalars) else [l] -def dictfilter(d: Optional[Mapping]=None, **kw) -> Mapping: +def dictfilter(d: Mapping = None, **kw) -> Mapping: """Remove all keys from dict ``d`` whose value is :const:`None`.""" d = kw if d is None else (dict(d, **kw) if kw else d) return {k: v for k, v in d.items() if v is not None} -def shufflecycle(it): +def shufflecycle(it: Sequence) -> Iterator: it = list(it) # don't modify callers list shuffle = random.shuffle for _ in repeat(None): @@ -241,7 +256,10 @@ def shufflecycle(it): yield it[0] -def fxrange(start=1.0, stop=None, step=1.0, repeatlast=False): +def fxrange(start: float = 1.0, + stop: float = None, + step: float = 1.0, + repeatlast: bool = False) -> Iterator[float]: cur = start * 1.0 while 1: if not stop or cur <= stop: @@ -253,7 +271,10 @@ def fxrange(start=1.0, stop=None, step=1.0, repeatlast=False): yield cur - step -def fxrangemax(start=1.0, stop=None, step=1.0, max=100.0): +def fxrangemax(start: float = 1.0, + stop: float = None, + step: float = 1.0, + max: float = 100.0) -> Iterator[float]: sum_, cur = 0, start * 1.0 while 1: if sum_ >= max: @@ -266,9 +287,18 @@ def fxrangemax(start=1.0, stop=None, step=1.0, max=100.0): sum_ += cur -def retry_over_time(fun, catch, args=[], kwargs={}, errback=None, - max_retries=None, interval_start=2, interval_step=2, - interval_max=30, callback=None): +async def retry_over_time( + fun: Callable, + catch: Tuple[Any, ...], + args: Sequence = [], + kwargs: Mapping[str, Any] = {}, + *, + errback: Callable = None, + max_retries: int = None, + interval_start: float = 2.0, + interval_step: float = 2.0, + interval_max: float = 30.0, + callback: Callable = None) -> Any: """Retry the function over and over until max retries is exceeded. For each retry we sleep a for a while before we try again, this interval @@ -303,7 +333,7 @@ def retry_over_time(fun, catch, args=[], kwargs={}, errback=None, interval_step, repeatlast=True) for retries in count(): try: - return fun(*args, **kwargs) + return await fun(*args, **kwargs) except catch as exc: if max_retries and retries >= max_retries: raise @@ -320,11 +350,17 @@ def retry_over_time(fun, catch, args=[], kwargs={}, errback=None, sleep(abs(int(tts) - tts)) -def reprkwargs(kwargs, sep=', ', fmt='{0}={1}'): +def reprkwargs(kwargs: Mapping, + *, + sep: str = ', ', fmt: str = '{0}={1}') -> str: return sep.join(fmt.format(k, _safe_repr(v)) for k, v in kwargs.items()) -def reprcall(name, args=(), kwargs={}, sep=', '): +def reprcall(name: str, + args: Sequence = (), + kwargs: Mapping = {}, + *, + sep: str = ', ') -> str: return '{0}({1}{2}{3})'.format( name, sep.join(map(_safe_repr, args or ())), (args and kwargs) and sep or '', diff --git a/kombu/utils/imports.py b/kombu/utils/imports.py index 5262b41d..345ec4bb 100644 --- a/kombu/utils/imports.py +++ b/kombu/utils/imports.py @@ -1,10 +1,18 @@ """Import related utilities.""" import importlib import sys +from typing import Any, Callable, Mapping -def symbol_by_name(name, aliases={}, imp=None, package=None, - sep='.', default=None, **kwargs): +def symbol_by_name( + name: Any, + aliases: Mapping[str, str] = {}, + *, + imp: Callable = None, + package: str = None, + sep: str = '.', + default: Any = None, + **kwargs) -> Any: """Get symbol by qualified name. The name should be the full dot-separated path to the class:: diff --git a/kombu/utils/json.py b/kombu/utils/json.py index 6668a73e..b3c6dfb5 100644 --- a/kombu/utils/json.py +++ b/kombu/utils/json.py @@ -10,7 +10,7 @@ from .typing import AnyBuffer try: from django.utils.functional import Promise as DjangoPromise except ImportError: # pragma: no cover - class DjangoPromise(object): # noqa + class DjangoPromise(object): # noqa, type: ignore """Dummy object.""" try: @@ -24,20 +24,21 @@ except ImportError: # pragma: no cover else: from simplejson.decoder import JSONDecodeError as _DecodeError -_encoder_cls = type(json._default_encoder) -_default_encoder = None # ... set to JSONEncoder below. +_encoder_cls: type = type(json._default_encoder) # type: ignore +_default_encoder: type = None # ... set to JSONEncoder below. class JSONEncoder(_encoder_cls): """Kombu custom json encoder.""" - def default(self, o, + def default(self, o: Any, + *, dates=(datetime.datetime, datetime.date), times=(datetime.time,), textual=(decimal.Decimal, uuid.UUID, DjangoPromise), isinstance=isinstance, datetime=datetime.datetime, - str=str): + str=str) -> Any: reducer = getattr(o, '__json__', None) if reducer is not None: o = reducer() @@ -59,6 +60,7 @@ _default_encoder = JSONEncoder def dumps(s: Any, + *, _dumps: Callable = json.dumps, cls: Any = None, default_kwargs: Dict = _json_extra_kwargs, @@ -68,7 +70,7 @@ def dumps(s: Any, **dict(default_kwargs, **kwargs)) -def loads(s: AnyBuffer, _loads: Callable = json.loads) -> Any: +def loads(s: AnyBuffer, *, _loads: Callable = json.loads) -> Any: """Deserialize json from string.""" # None of the json implementations supports decoding from # a buffer/memoryview, or even reading from a stream diff --git a/kombu/utils/limits.py b/kombu/utils/limits.py index ee9c8003..2c34f0de 100644 --- a/kombu/utils/limits.py +++ b/kombu/utils/limits.py @@ -1,8 +1,7 @@ """Token bucket implementation for rate limiting.""" from collections import deque from time import monotonic -from typing import Any, Optional - +from typing import Any from .typing import Float __all__ = ['TokenBucket'] @@ -24,21 +23,21 @@ class TokenBucket: """ #: The rate in tokens/second that the bucket will be refilled. - fill_rate = None + fill_rate: float = None #: Maximum number of tokens in the bucket. - capacity = 1.0 # type: float + capacity: float = 1.0 #: Timestamp of the last time a token was taken out of the bucket. - timestamp = None # type: Optional[float] + timestamp: float = None - def __init__(self, fill_rate: Optional[Float], - capacity: Float=1.0) -> None: + def __init__(self, fill_rate: Float = None, + capacity: Float = 1.0) -> None: self.capacity = float(capacity) self._tokens = capacity self.fill_rate = float(fill_rate) self.timestamp = monotonic() - self.contents = deque() # type: deque + self.contents = deque() def add(self, item: Any) -> None: self.contents.append(item) @@ -49,7 +48,7 @@ class TokenBucket: def clear_pending(self) -> None: self.contents.clear() - def can_consume(self, tokens: int=1) -> bool: + def can_consume(self, tokens: int = 1) -> bool: """Check if one or more tokens can be consumed. Returns: diff --git a/kombu/utils/scheduling.py b/kombu/utils/scheduling.py index 74d4f25a..624a06d0 100644 --- a/kombu/utils/scheduling.py +++ b/kombu/utils/scheduling.py @@ -1,8 +1,7 @@ """Scheduling Utilities.""" from itertools import count -from typing import Any, Callable, Iterable, Optional, Sequence, Union +from typing import Any, Callable, Iterable, Sequence, Union from typing import List # noqa - from .imports import symbol_by_name __all__ = [ @@ -24,12 +23,12 @@ class FairCycle: Arguments: fun (Callable): Callback to call. - resources (Sequence[Any]): List of resources. + resources (Sequence): List of resources. predicate (type): Exception predicate. """ def __init__(self, fun: Callable, resources: Sequence, - predicate: Any=Exception) -> None: + predicate: Any = Exception) -> None: self.fun = fun self.resources = resources self.predicate = predicate @@ -72,8 +71,8 @@ class BaseCycle: class round_robin_cycle(BaseCycle): """Iterator that cycles between items in round-robin.""" - def __init__(self, it: Optional[Iterable]=None) -> None: - self.items = list(it if it is not None else []) # type: List + def __init__(self, it: Iterable = None) -> None: + self.items = list(it if it is not None else []) def update(self, it: Sequence) -> None: """Update items from iterable.""" diff --git a/kombu/utils/text.py b/kombu/utils/text.py index 6f7977fc..f2e8280d 100644 --- a/kombu/utils/text.py +++ b/kombu/utils/text.py @@ -1,11 +1,18 @@ # -*- coding: utf-8 -*- """Text Utilities.""" from difflib import SequenceMatcher -from typing import Iterator, Sequence, NamedTuple, Tuple - +from numbers import Number +from typing import ( + Iterator, Sequence, NamedTuple, Optional, SupportsInt, Tuple, Union, +) from kombu import version_info_t -fmatch_t = NamedTuple('fmatch_t', [('ratio', float), ('key', str)]) + +class fmatch_t(NamedTuple): + """Return value of :func:`fmatch_iter`.""" + + ratio: float + key: str def escape_regex(p: str, white: str = '') -> str: @@ -30,14 +37,14 @@ def fmatch_iter(needle: str, haystack: Sequence[str], def fmatch_best(needle: str, haystack: Sequence[str], - min_ratio: float = 0.6) -> str: + min_ratio: float = 0.6) -> Optional[str]: """Fuzzy match - Find best match (scalar).""" try: return sorted( fmatch_iter(needle, haystack, min_ratio), reverse=True, )[0][1] except IndexError: - pass + return None def version_string_as_tuple(s: str) -> version_info_t: @@ -52,7 +59,9 @@ def version_string_as_tuple(s: str) -> version_info_t: return v -def _unpack_version(major: int, minor: int = 0, micro: int = 0, +def _unpack_version(major: Union[SupportsInt, str, bytes], + minor: Union[SupportsInt, str, bytes] = 0, + micro: Union[SupportsInt, str, bytes] = 0, releaselevel: str = '', serial: str = '') -> version_info_t: return version_info_t(int(major), int(minor), micro, releaselevel, serial) diff --git a/kombu/utils/time.py b/kombu/utils/time.py index 64d31d3b..72448c3f 100644 --- a/kombu/utils/time.py +++ b/kombu/utils/time.py @@ -1,10 +1,8 @@ """Time Utilities.""" -from __future__ import absolute_import, unicode_literals - +from typing import Optional, Union __all__ = ['maybe_s_to_ms'] -def maybe_s_to_ms(v): - # type: (Optional[Union[int, float]]) -> int +def maybe_s_to_ms(v: Optional[Union[int, float]]) -> Optional[int]: """Convert seconds to milliseconds, but return None for None.""" return int(float(v) * 1000.0) if v is not None else v diff --git a/kombu/utils/typing.py b/kombu/utils/typing.py deleted file mode 100644 index 1f4843f6..00000000 --- a/kombu/utils/typing.py +++ /dev/null @@ -1,21 +0,0 @@ -from abc import abstractmethod -from typing import ( - _Protocol, AnyStr, SupportsFloat, SupportsInt, Union, -) - -Float = Union[SupportsInt, SupportsFloat, AnyStr] -Int = Union[SupportsInt, AnyStr] - -Port = Union[SupportsInt, str] - -AnyBuffer = Union[AnyStr, memoryview] - - -class SupportsFileno(_Protocol): - __slots__ = () - - @abstractmethod - def __fileno__(self) -> int: - ... - -Fd = Union[int, SupportsFileno] diff --git a/kombu/utils/url.py b/kombu/utils/url.py index 75db4d89..7dae3464 100644 --- a/kombu/utils/url.py +++ b/kombu/utils/url.py @@ -1,30 +1,25 @@ """URL Utilities.""" from functools import partial from typing import Any, Dict, Mapping, NamedTuple +from urllib.parse import parse_qsl, quote, unquote, urlparse from .typing import Port -try: - from urllib.parse import parse_qsl, quote, unquote, urlparse -except ImportError: - from urllib import quote, unquote # noqa - from urlparse import urlparse, parse_qsl # noqa safequote = partial(quote, safe='') -urlparts = NamedTuple('urlparts', [ - ('scheme', str), - ('hostname', str), - ('port', int), - ('username', str), - ('password', str), - ('path', str), - ('query', Dict), -]) +class urlparts(NamedTuple): + scheme: str + hostname: str + port: int + username: str + password: str + path: str + query: Dict -def parse_url(url: str) -> Dict: +def parse_url(url: str) -> Mapping: """Parse URL into mapping of components.""" - scheme, host, port, user, password, path, query = _parse_url(url) + scheme, host, port, user, password, path, query = url_to_parts(url) return dict(transport=scheme, hostname=host, port=port, userid=user, password=password, virtual_host=path, **query) @@ -47,7 +42,6 @@ def url_to_parts(url: str) -> urlparts: unquote(path or '') or None, dict(parse_qsl(parts.query)), ) -_parse_url = url_to_parts # noqa def as_url(scheme: str, @@ -79,7 +73,7 @@ def as_url(scheme: str, def sanitize_url(url: str, mask: str = '**') -> str: """Return copy of URL with password removed.""" - return as_url(*_parse_url(url), sanitize=True, mask=mask) + return as_url(*url_to_parts(url), sanitize=True, mask=mask) def maybe_sanitize_url(url: Any, mask: str = '**') -> Any: diff --git a/kombu/utils/uuid.py b/kombu/utils/uuid.py index 7de82b29..04130ae2 100644 --- a/kombu/utils/uuid.py +++ b/kombu/utils/uuid.py @@ -1,8 +1,9 @@ """UUID utilities.""" from uuid import uuid4 +from typing import Callable -def uuid(_uuid=uuid4): +def uuid(*, _uuid: Callable = uuid4) -> str: """Generate unique id in UUID4 format. See Also: |