summaryrefslogtreecommitdiff
path: root/kombu/abstract.py
blob: 48a917c9cec8bac100c5721e7198911f1ac9fad1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
"""Object utilities."""

from __future__ import annotations

from copy import copy
from typing import TYPE_CHECKING, Any, Callable, TypeVar

from .connection import maybe_channel
from .exceptions import NotBoundError
from .utils.functional import ChannelPromise

if TYPE_CHECKING:
    from kombu.connection import Connection
    from kombu.transport.virtual import Channel


__all__ = ('Object', 'MaybeChannelBound')

_T = TypeVar("_T")
_ObjectType = TypeVar("_ObjectType", bound="Object")
_MaybeChannelBoundType = TypeVar(
    "_MaybeChannelBoundType", bound="MaybeChannelBound"
)


def unpickle_dict(
    cls: type[_ObjectType], kwargs: dict[str, Any]
) -> _ObjectType:
    """Rebuild an instance by calling ``cls`` with ``kwargs`` as keywords.

    This is the reconstruction callable returned by
    :meth:`Object.__reduce__`.
    """
    instance = cls(**kwargs)
    return instance


def _any(v: _T) -> _T:
    """Identity converter: hand ``v`` back untouched.

    ``Object.__init__`` falls back to this when an ``attrs`` entry has a
    falsy converter.
    """
    return v


class Object:
    """Base class that maps keyword arguments onto declared attributes.

    Subclasses list their fields in :attr:`attrs`.  Instances are populated
    from keyword arguments, can be exported with :meth:`as_dict`, and
    support copying and pickling through that exported dict.
    """

    #: ``(name, converter)`` pairs.  A falsy converter means the value is
    #: stored and exported unchanged.
    attrs: tuple[tuple[str, Any], ...] = ()

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Populate every attribute declared in ``attrs``.

        A keyword argument whose value is not ``None`` is run through the
        attribute's converter and stored on the instance.  Otherwise an
        attribute that is already reachable (for example a class-level
        default) is left alone, and a missing one is set to ``None``.
        Positional arguments are accepted but not used here.
        """
        for name, converter in self.attrs:
            given = kwargs.get(name)
            if given is None:
                # Nothing supplied: only fill in ``None`` when no default
                # is reachable through normal attribute lookup.
                if not hasattr(self, name):
                    setattr(self, name, None)
                continue
            convert = converter or _any
            setattr(self, name, convert(given))

    def as_dict(self, recurse: bool = False) -> dict[str, Any]:
        """Return the declared attributes as a ``{name: value}`` dict.

        Each value that is not ``None`` is passed through its converter
        (when the converter is truthy).  With ``recurse=True``, values that
        are themselves :class:`Object` instances are exported via their own
        ``as_dict(recurse=True)`` instead of being converted.
        """
        def export(value: Any, converter: Callable[[Any], Any]) -> Any:
            if recurse and isinstance(value, Object):
                return value.as_dict(recurse=True)
            if converter and value is not None:
                return converter(value)
            return value

        exported: dict[str, Any] = {}
        for name, converter in self.attrs:
            exported[name] = export(getattr(self, name), converter)
        return exported

    def __reduce__(self: _ObjectType) -> tuple[
        Callable[[type[_ObjectType], dict[str, Any]], _ObjectType],
        tuple[type[_ObjectType], dict[str, Any]]
    ]:
        """Pickle as ``unpickle_dict(cls, as_dict())``."""
        rebuild_args = (self.__class__, self.as_dict())
        return unpickle_dict, rebuild_args

    def __copy__(self: _ObjectType) -> _ObjectType:
        """Create a new instance of the same class from :meth:`as_dict`."""
        cls = self.__class__
        return cls(**self.as_dict())


class MaybeChannelBound(Object):
    """Base for entities that may optionally be attached to a channel.

    An instance starts out unbound.  Binding stores a channel on the
    instance, runs :meth:`when_bound` and marks the instance as bound.
    """

    # Channel (or channel promise) this instance is attached to, if any.
    _channel: Channel | None = None
    # Set to True by maybe_bind() once binding has completed.
    _is_bound = False

    #: Defines whether maybe_declare can skip declaring this entity twice.
    can_cache_declaration = False

    def __call__(
        self: _MaybeChannelBoundType, channel: (Channel | Connection)
    ) -> _MaybeChannelBoundType:
        """Shortcut: ``self(channel)`` is the same as ``self.bind(channel)``."""
        bound = self.bind(channel)
        return bound

    def bind(
        self: _MaybeChannelBoundType, channel: (Channel | Connection)
    ) -> _MaybeChannelBoundType:
        """Return a copy of this instance after trying to bind it.

        The original instance is not modified; binding is attempted on the
        copy through :meth:`maybe_bind`.
        """
        clone = copy(self)
        return clone.maybe_bind(channel)

    def maybe_bind(
        self: _MaybeChannelBoundType, channel: (Channel | Connection)
    ) -> _MaybeChannelBoundType:
        """Bind this instance in place unless it is already bound.

        Nothing happens when the instance is already bound or when
        ``channel`` is falsy.  Always returns ``self``.
        """
        if self.is_bound or not channel:
            return self
        self._channel = maybe_channel(channel)
        # The bound flag is only set after the hook has returned.
        self.when_bound()
        self._is_bound = True
        return self

    def revive(self, channel: Channel) -> None:
        """Swap in a new channel once the connection is re-established.

        Only has an effect on instances that are currently bound; the
        :meth:`when_bound` hook is run again with the new channel in place.

        Used by :meth:`~kombu.Connection.ensure`.

        """
        if not self.is_bound:
            return
        self._channel = channel
        self.when_bound()

    def when_bound(self) -> None:
        """Hook run after a channel has been assigned; no-op by default."""

    def __repr__(self) -> str:
        """Describe the instance using its class name and binding state."""
        class_name = type(self).__name__
        return self._repr_entity(class_name)

    def _repr_entity(self, item: str = '') -> str:
        """Format ``item`` (default: the class name) with the binding state.

        Bound instances include the ``channel_id`` of their channel.
        """
        label = item or type(self).__name__
        if not self.is_bound:
            return f'<unbound {label}>'
        channel_id = self.channel.channel_id
        return '<{} bound to chan:{}>'.format(label, channel_id)

    @property
    def is_bound(self) -> bool:
        """Whether the bound flag is set and a channel is stored."""
        return self._is_bound and self._channel is not None

    @property
    def channel(self) -> Channel:
        """Channel this instance is attached to.

        Raises:
            NotBoundError: If no channel is stored on the instance.
        """
        current = self._channel
        if current is None:
            raise NotBoundError(
                "Can't call method on {} not bound to a channel".format(
                    type(self).__name__))
        if isinstance(current, ChannelPromise):
            # Call the promise and keep its result in place of the promise.
            current = current()
            self._channel = current
        return current