diff options
Diffstat (limited to 't')
-rw-r--r-- | t/mocks.py | 11 | ||||
-rw-r--r-- | t/unit/asynchronous/test_semaphore.py | 6 | ||||
-rw-r--r-- | t/unit/test_common.py | 11 | ||||
-rw-r--r-- | t/unit/test_compat.py | 27 | ||||
-rw-r--r-- | t/unit/test_connection.py | 14 | ||||
-rw-r--r-- | t/unit/test_messaging.py | 5 | ||||
-rw-r--r-- | t/unit/test_simple.py | 5 | ||||
-rw-r--r-- | t/unit/transport/test_redis.py | 11 | ||||
-rw-r--r-- | t/unit/transport/virtual/test_base.py | 5 |
9 files changed, 62 insertions, 33 deletions
@@ -1,9 +1,13 @@ from itertools import count +from typing import TYPE_CHECKING, Optional, Type from unittest.mock import Mock from kombu.transport import base from kombu.utils import json +if TYPE_CHECKING: + from types import TracebackType + class _ContextMock(Mock): """Dummy class implementing __enter__ and __exit__ @@ -13,7 +17,12 @@ class _ContextMock(Mock): def __enter__(self): return self - def __exit__(self, *exc_info): + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional['TracebackType'] + ) -> None: pass diff --git a/t/unit/asynchronous/test_semaphore.py b/t/unit/asynchronous/test_semaphore.py index 8767ca91..485d0800 100644 --- a/t/unit/asynchronous/test_semaphore.py +++ b/t/unit/asynchronous/test_semaphore.py @@ -1,11 +1,13 @@ +from typing import List + from kombu.asynchronous.semaphore import LaxBoundedSemaphore class test_LaxBoundedSemaphore: - def test_over_release(self): + def test_over_release(self) -> None: x = LaxBoundedSemaphore(2) - calls = [] + calls: List[int] = [] for i in range(1, 21): x.acquire(calls.append, i) x.release() diff --git a/t/unit/test_common.py b/t/unit/test_common.py index 0f669b7d..4780c0a4 100644 --- a/t/unit/test_common.py +++ b/t/unit/test_common.py @@ -1,4 +1,5 @@ import socket +from typing import TYPE_CHECKING, Optional, Type from unittest.mock import Mock, patch import pytest @@ -10,6 +11,9 @@ from kombu.common import (PREFETCH_COUNT_MAX, Broadcast, QoS, collect_replies, maybe_declare, send_reply) from t.mocks import ContextMock, MockPool +if TYPE_CHECKING: + from types import TracebackType + def test_generate_oid(): from uuid import NAMESPACE_OID @@ -338,7 +342,12 @@ class MockConsumer: self.consumers.add(self) return self - def __exit__(self, *exc_info): + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional['TracebackType'] + ) -> None: self.consumers.discard(self) diff --git a/t/unit/test_compat.py b/t/unit/test_compat.py index d75ce5df..31eb97ea 100644 --- a/t/unit/test_compat.py +++ b/t/unit/test_compat.py @@ -115,12 +115,14 @@ class test_Publisher: pub.close() def test__enter__exit__(self): - pub = compat.Publisher(self.connection, - exchange='test_Publisher_send', - routing_key='rkey') - x = pub.__enter__() - assert x is pub - x.__exit__() + pub = compat.Publisher( + self.connection, + exchange='test_Publisher_send', + routing_key='rkey' + ) + with pub as x: + assert x is pub + assert pub._closed @@ -158,11 +160,14 @@ class test_Consumer: assert q2.exchange.auto_delete def test__enter__exit__(self, n='test__enter__exit__'): - c = compat.Consumer(self.connection, queue=n, exchange=n, - routing_key='rkey') - x = c.__enter__() - assert x is c - x.__exit__() + c = compat.Consumer( + self.connection, + queue=n, + exchange=n, + routing_key='rkey' + ) + with c as x: + assert x is c assert c._closed def test_revive(self, n='test_revive'): diff --git a/t/unit/test_connection.py b/t/unit/test_connection.py index 17ea7b34..703e237d 100644 --- a/t/unit/test_connection.py +++ b/t/unit/test_connection.py @@ -398,14 +398,12 @@ class test_Connection: qsms.assert_called_with(self.conn.connection) def test__enter____exit__(self): - conn = self.conn - context = conn.__enter__() - assert context is conn - conn.connect() - assert conn.connection.connected - conn.__exit__() - assert conn.connection is None - conn.close() # again + with self.conn as context: + assert context is self.conn + self.conn.connect() + assert self.conn.connection.connected + assert self.conn.connection is None + self.conn.close() # again def test_close_survives_connerror(self): diff --git a/t/unit/test_messaging.py b/t/unit/test_messaging.py index f8ed437c..68c85b5f 100644 --- a/t/unit/test_messaging.py +++ b/t/unit/test_messaging.py @@ -188,9 +188,8 @@ class test_Producer: def test_enter_exit(self): p = self.connection.Producer() p.release = Mock() - - assert p.__enter__() is p - p.__exit__() + with p as x: + assert x is p p.release.assert_called_with() def test_connection_property_handles_AttributeError(self): diff --git a/t/unit/test_simple.py b/t/unit/test_simple.py index a5cd899a..6a9a9b09 100644 --- a/t/unit/test_simple.py +++ b/t/unit/test_simple.py @@ -91,9 +91,8 @@ class SimpleBase: def test_enter_exit(self): q = self.Queue('test_enter_exit') q.close = Mock() - - assert q.__enter__() is q - q.__exit__() + with q as x: + assert x is q q.close.assert_called_with() def test_qsize(self): diff --git a/t/unit/transport/test_redis.py b/t/unit/transport/test_redis.py index 029f7901..9905f6ee 100644 --- a/t/unit/transport/test_redis.py +++ b/t/unit/transport/test_redis.py @@ -6,6 +6,7 @@ from collections import defaultdict from itertools import count from queue import Empty from queue import Queue as _Queue +from typing import TYPE_CHECKING, Optional, Type from unittest.mock import ANY, Mock, call, patch import pytest @@ -16,6 +17,9 @@ from kombu.transport import virtual from kombu.utils import eventio # patch poll from kombu.utils.json import dumps +if TYPE_CHECKING: + from types import TracebackType + def _redis_modules(): @@ -231,7 +235,12 @@ class Pipeline: def __enter__(self): return self - def __exit__(self, *exc_info): + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional['TracebackType'] + ) -> None: pass def __getattr__(self, key): diff --git a/t/unit/transport/virtual/test_base.py b/t/unit/transport/virtual/test_base.py index a5685ab9..97df370a 100644 --- a/t/unit/transport/virtual/test_base.py +++ b/t/unit/transport/virtual/test_base.py @@ -462,9 +462,8 @@ class test_Channel: assert 'could not be delivered' in log[0].message.args[0] def test_context(self): - x = self.channel.__enter__() - assert x is self.channel - x.__exit__() + with self.channel as x: + assert x is self.channel assert x.closed def test_cycle_property(self): |