summaryrefslogtreecommitdiff
path: root/kombu/simple.py
blob: a33e5f9e1860d25f45eeb078128864d04149d8ca (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
"""Simple messaging interface.

Provides a small :class:`queue.Queue`-like API (``get``, ``get_nowait``,
``put``, ``qsize``, ``clear``) on top of a kombu producer/consumer pair:
:class:`SimpleQueue` for persistent queues and :class:`SimpleBuffer` for
ephemeral ones.
"""

from __future__ import annotations

import socket
from collections import deque
from queue import Empty
from time import monotonic
from typing import TYPE_CHECKING

from . import entity, messaging
from .connection import maybe_channel

if TYPE_CHECKING:
    from types import TracebackType

# Public API of this module; SimpleBase is the shared, non-exported base.
__all__ = ('SimpleQueue', 'SimpleBuffer')


class SimpleBase:
    Empty = Empty
    _consuming = False

    def __enter__(self):
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None
    ) -> None:
        self.close()

    def __init__(self, channel, producer, consumer, no_ack=False):
        self.channel = maybe_channel(channel)
        self.producer = producer
        self.consumer = consumer
        self.no_ack = no_ack
        self.queue = self.consumer.queues[0]
        self.buffer = deque()
        self.consumer.register_callback(self._receive)

    def get(self, block=True, timeout=None):
        if not block:
            return self.get_nowait()

        self._consume()

        time_start = monotonic()
        remaining = timeout
        while True:
            if self.buffer:
                return self.buffer.popleft()

            if remaining is not None and remaining <= 0.0:
                raise self.Empty()

            try:
                # The `drain_events` method will
                # block on the socket connection to rabbitmq. if any
                # application-level messages are received, it will put them
                # into `self.buffer`.
                # * The method will block for UP TO `timeout` milliseconds.
                # * The method may raise a socket.timeout exception; or...
                # * The method may return without having put anything on
                #    `self.buffer`.  This is because internal heartbeat
                #    messages are sent over the same socket; also POSIX makes
                #    no guarantees against socket calls returning early.
                self.channel.connection.client.drain_events(timeout=remaining)
            except socket.timeout:
                raise self.Empty()

            if remaining is not None:
                elapsed = monotonic() - time_start
                remaining = timeout - elapsed

    def get_nowait(self):
        m = self.queue.get(no_ack=self.no_ack, accept=self.consumer.accept)
        if not m:
            raise self.Empty()
        return m

    def put(self, message, serializer=None, headers=None, compression=None,
            routing_key=None, **kwargs):
        self.producer.publish(message,
                              serializer=serializer,
                              routing_key=routing_key,
                              headers=headers,
                              compression=compression,
                              **kwargs)

    def clear(self):
        return self.consumer.purge()

    def qsize(self):
        _, size, _ = self.queue.queue_declare(passive=True)
        return size

    def close(self):
        self.consumer.cancel()

    def _receive(self, message_data, message):
        self.buffer.append(message)

    def _consume(self):
        if not self._consuming:
            self.consumer.consume(no_ack=self.no_ack)
            self._consuming = True

    def __len__(self):
        """`len(self) -> self.qsize()`."""
        return self.qsize()

    def __bool__(self):
        return True
    __nonzero__ = __bool__


class SimpleQueue(SimpleBase):
    """Simple API for persistent queues.

    *name* is either a queue name or a ready-made ``entity.Queue``.  For a
    plain name, an exchange and a queue are both created under that name,
    and the name also serves as the routing key.  The class-level
    ``queue_opts``, ``queue_args`` and ``exchange_opts`` are merged with the
    per-call values, and the per-call values win.  For a ready-made Queue,
    its own exchange and routing key are used.
    """

    no_ack = False
    queue_opts = {}
    queue_args = {}
    exchange_opts = {'type': 'direct'}

    def __init__(self, channel, name, no_ack=None, queue_opts=None,
                 queue_args=None, exchange_opts=None, serializer=None,
                 compression=None, accept=None):
        def merged(defaults, overrides):
            # Copy the class-level defaults, then layer the per-call
            # overrides on top; the class dicts are never mutated.
            return dict(defaults, **(overrides or {}))

        q_opts = merged(self.queue_opts, queue_opts)
        q_args = merged(self.queue_args, queue_args)
        x_opts = merged(self.exchange_opts, exchange_opts)
        ack_mode = self.no_ack if no_ack is None else no_ack

        if isinstance(name, entity.Queue):
            # Reuse the supplied Queue together with its exchange/routing key.
            queue = name
            exchange, routing_key = queue.exchange, queue.routing_key
        else:
            exchange = entity.Exchange(name, **x_opts)
            queue = entity.Queue(name, exchange, name,
                                 queue_arguments=q_args,
                                 **q_opts)
            routing_key = name

        consumer = messaging.Consumer(channel, queue, accept=accept)
        producer = messaging.Producer(channel, exchange,
                                      serializer=serializer,
                                      routing_key=routing_key,
                                      compression=compression)
        super().__init__(channel, producer, consumer, ack_mode)


class SimpleBuffer(SimpleQueue):
    """Simple API for ephemeral queues.

    Same interface as the persistent variant; only the class-level defaults
    differ (non-durable, auto_delete queue and exchange, no_ack consuming).
    """

    # Consume without explicit acknowledgement.
    no_ack = True
    # Non-durable queue with auto_delete enabled.
    queue_opts = dict(durable=False,
                      auto_delete=True)
    # Replaces (does not extend) the parent's {'type': 'direct'} default, so
    # no exchange type is passed; presumably the Exchange default applies.
    exchange_opts = dict(durable=False,
                         delivery_mode='transient',
                         auto_delete=True)