"""Test doubles for kombu: mock context managers/promises and a fake transport.

Nothing here talks to a real broker.  The fake :class:`Channel` records the
names of the methods called on it so tests can assert on them.
"""

from __future__ import annotations

import time
from itertools import count
from typing import TYPE_CHECKING
from unittest.mock import Mock

from kombu.transport import base
from kombu.utils import json

if TYPE_CHECKING:
    from types import TracebackType


class _ContextMock(Mock):
    """Mock subclass that defines ``__enter__`` and ``__exit__`` on the class.

    The :keyword:`with` statement looks these methods up on the class, not
    the instance, so attaching them to an instance alone is not enough.
    """

    def __enter__(self):
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        pass


def ContextMock(*args, **kwargs):
    """Return a mock usable as a :keyword:`with` statement context.

    ``__enter__`` returns the mock itself.  ``__exit__`` returns ``None``, so
    exceptions raised inside the ``with`` body propagate.
    """
    obj = _ContextMock(*args, **kwargs)
    obj.attach_mock(_ContextMock(), '__enter__')
    obj.attach_mock(_ContextMock(), '__exit__')
    obj.__enter__.return_value = obj
    # If __exit__ returns a true value the exception is suppressed,
    # so it must return None here.
    obj.__exit__.return_value = None
    return obj


def PromiseMock(*args, **kwargs):
    """Return a :class:`~unittest.mock.Mock` whose error-raising methods raise.

    ``throw``, ``set_error_state`` and ``throw1`` raise the exception passed
    as their first argument.  When called without one, they re-raise the
    exception currently being handled (a bare ``raise``, which becomes a
    :exc:`RuntimeError` if no exception is active).
    """
    m = Mock(*args, **kwargs)

    def on_throw(exc=None, *args, **kwargs):
        if exc:
            raise exc
        raise

    m.throw.side_effect = on_throw
    m.set_error_state.side_effect = on_throw
    m.throw1.side_effect = on_throw
    return m


class MockPool:
    """Minimal pool stand-in whose :meth:`acquire` always returns one object."""

    def __init__(self, value=None):
        # Only fall back to a fresh ContextMock when no value was supplied, so
        # a falsy-but-valid value is not silently replaced.
        self.value = ContextMock() if value is None else value

    def acquire(self, **kwargs):
        """Return the pooled object; keyword arguments are ignored."""
        return self.value


class Message(base.Message):
    """:class:`kombu.transport.base.Message` that can be told to fail decoding.

    Passing ``throw_decode_error=True`` makes :meth:`decode` raise
    :exc:`ValueError` instead of decoding the body.
    """

    def __init__(self, *args, **kwargs):
        # The flag is read here and still forwarded to the base constructor
        # along with the other keyword arguments.
        self.throw_decode_error = kwargs.get('throw_decode_error', False)
        super().__init__(*args, **kwargs)

    def decode(self):
        if self.throw_decode_error:
            raise ValueError("can't decode message")
        return super().decode()


class Channel(base.StdChannel):
    """Fake channel that records every method call by name.

    Each call name is appended to :attr:`called`, and ``name in channel``
    tests whether that method has been called.  :meth:`basic_get` pops
    messages from :attr:`to_deliver`.
    """

    open = True
    throw_decode_error = False
    # Class-level counter shared by all instances, so each channel created
    # gets a distinct ``channel_id``.
    _ids = count(1)

    def __init__(self, connection):
        self.connection = connection
        self.called = []
        self.deliveries = count(1)
        self.to_deliver = []
        self.events = {'basic_return': set()}
        self.channel_id = next(self._ids)

    def _called(self, name):
        self.called.append(name)

    def __contains__(self, key):
        return key in self.called

    def exchange_declare(self, *args, **kwargs):
        self._called('exchange_declare')

    def prepare_message(self, body, priority=0, content_type=None,
                        content_encoding=None, headers=None, properties=None):
        """Record the call and return the message fields as a plain dict."""
        self._called('prepare_message')
        if properties is None:
            # A fresh dict per call: a shared ``{}`` default would leak any
            # mutation made by one caller into every later message.
            properties = {}
        return {'body': body,
                'headers': headers,
                'properties': properties,
                'priority': priority,
                'content_type': content_type,
                'content_encoding': content_encoding}

    def basic_publish(self, message, exchange='', routing_key='',
                      mandatory=False, immediate=False, **kwargs):
        """Record the call and echo back ``(message, exchange, routing_key)``."""
        self._called('basic_publish')
        return message, exchange, routing_key

    def exchange_delete(self, *args, **kwargs):
        self._called('exchange_delete')

    def queue_declare(self, *args, **kwargs):
        self._called('queue_declare')

    def queue_bind(self, *args, **kwargs):
        self._called('queue_bind')

    def queue_unbind(self, *args, **kwargs):
        self._called('queue_unbind')

    def queue_delete(self, queue, if_unused=False, if_empty=False, **kwargs):
        self._called('queue_delete')

    def basic_get(self, *args, **kwargs):
        """Pop and return the last queued delivery, or ``None`` if empty."""
        self._called('basic_get')
        try:
            return self.to_deliver.pop()
        except IndexError:
            pass

    def queue_purge(self, *args, **kwargs):
        self._called('queue_purge')

    def basic_consume(self, *args, **kwargs):
        self._called('basic_consume')

    def basic_cancel(self, *args, **kwargs):
        self._called('basic_cancel')

    def basic_ack(self, *args, **kwargs):
        self._called('basic_ack')

    def basic_recover(self, requeue=False):
        self._called('basic_recover')

    def exchange_bind(self, *args, **kwargs):
        self._called('exchange_bind')

    def exchange_unbind(self, *args, **kwargs):
        self._called('exchange_unbind')

    def close(self):
        self._called('close')

    def message_to_python(self, message, *args, **kwargs):
        """Wrap *message*, JSON-encoded, in a :class:`Message`.

        Each call takes the next delivery tag from this channel's counter and
        propagates the channel's ``throw_decode_error`` flag.
        """
        self._called('message_to_python')
        return Message(body=json.dumps(message),
                       channel=self,
                       delivery_tag=next(self.deliveries),
                       throw_decode_error=self.throw_decode_error,
                       content_type='application/json',
                       content_encoding='utf-8')

    def flow(self, active):
        self._called('flow')

    def basic_reject(self, delivery_tag, requeue=False):
        # Requeued and plain rejects are recorded under different names so
        # tests can tell them apart.
        if requeue:
            return self._called('basic_reject:requeue')
        return self._called('basic_reject')

    def basic_qos(self, prefetch_size=0, prefetch_count=0,
                  apply_global=False):
        self._called('basic_qos')


class Connection:
    """Fake broker connection that creates a new :class:`Channel` per call."""

    connected = True

    def __init__(self, client):
        self.client = client

    def channel(self):
        return Channel(self)


class Transport(base.Transport):
    """Transport built on the fake :class:`Connection` / :class:`Channel`."""

    def establish_connection(self):
        return Connection(self.client)

    def create_channel(self, connection):
        return connection.channel()

    def drain_events(self, connection, **kwargs):
        # Does not wait for anything; always returns the string 'event'.
        return 'event'

    def close_connection(self, connection):
        connection.connected = False


class TimeoutingTransport(Transport):
    """Transport whose connection attempts always time out.

    :meth:`establish_connection` sleeps for ``connect_timeout`` seconds and
    then raises :exc:`TimeoutError`.  That exception is listed in
    ``recoverable_connection_errors``.
    """

    recoverable_connection_errors = (TimeoutError,)

    def __init__(self, connect_timeout=1, **kwargs):
        self.connect_timeout = connect_timeout
        super().__init__(**kwargs)

    def establish_connection(self):
        time.sleep(self.connect_timeout)
        raise TimeoutError('timed out')