diff options
author | Ismael Jiménez Sánchez <ismaeljs11@gmail.com> | 2023-04-11 18:33:26 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-04-11 22:33:26 +0600 |
commit | ff031f73f26b8cd75b1489727f08ed085186a873 (patch) | |
tree | be49563523016b32cfcaccf72e812f1ee77781bd | |
parent | 9659f11ae1b1633cd9897ac3f49d029b17a29010 (diff) | |
download | kombu-ff031f73f26b8cd75b1489727f08ed085186a873.tar.gz |
fix: handle keyerror in azureservicebus transport when message is not found in qos and perform basic_ack (#1691)
* fix: handle keyerror when message is not found in qos and perform basic_ack
* fix: added tests for basic_ack
* fix: limit line length
-rw-r--r-- | kombu/transport/azureservicebus.py | 29 | ||||
-rw-r--r-- | t/unit/transport/test_azureservicebus.py | 86 |
2 files changed, 101 insertions, 14 deletions
diff --git a/kombu/transport/azureservicebus.py b/kombu/transport/azureservicebus.py index e7e2c0cc..01b12965 100644 --- a/kombu/transport/azureservicebus.py +++ b/kombu/transport/azureservicebus.py @@ -273,23 +273,24 @@ class Channel(virtual.Channel): return msg def basic_ack(self, delivery_tag: str, multiple: bool = False) -> None: - delivery_info = self.qos.get(delivery_tag).delivery_info - - if delivery_info['exchange'] in self._noack_queues: - return super().basic_ack(delivery_tag) - - queue = delivery_info['azure_queue_name'] - # recv_mode is PEEK_LOCK when ack'ing messages - queue_obj = self._get_asb_receiver(queue) - try: - queue_obj.receiver.complete_message(delivery_info['azure_message']) - except azure.servicebus.exceptions.MessageAlreadySettled: + delivery_info = self.qos.get(delivery_tag).delivery_info + except KeyError: super().basic_ack(delivery_tag) - except Exception: - super().basic_reject(delivery_tag) else: - super().basic_ack(delivery_tag) + queue = delivery_info['azure_queue_name'] + # recv_mode is PEEK_LOCK when ack'ing messages + queue_obj = self._get_asb_receiver(queue) + + try: + queue_obj.receiver.complete_message( + delivery_info['azure_message']) + except azure.servicebus.exceptions.MessageAlreadySettled: + super().basic_ack(delivery_tag) + except Exception: + super().basic_reject(delivery_tag) + else: + super().basic_ack(delivery_tag) def _size(self, queue: str) -> int: """Return the number of messages in a queue.""" diff --git a/t/unit/transport/test_azureservicebus.py b/t/unit/transport/test_azureservicebus.py index 5de93c2f..72510e09 100644 --- a/t/unit/transport/test_azureservicebus.py +++ b/t/unit/transport/test_azureservicebus.py @@ -333,3 +333,89 @@ def test_custom_entity_name(): assert channel.entity_name('test_celery') == 'test_celery' assert channel.entity_name('test:celery') == 'test_celery' assert channel.entity_name('test+celery') == 'test_celery' + + +def test_basic_ack_complete_message(mock_queue: MockQueue): + mock_queue.producer.publish("test message") + message = mock_queue.channel._get(mock_queue.queue_name) + mock_queue.channel.qos.get = MagicMock( + return_value=mock_queue.channel.Message( + message, mock_queue.channel + ) + ) + receiver_mock = MagicMock() + receiver_mock.complete_message = MagicMock(return_value=None) + queue_object_mock = MagicMock() + queue_object_mock.receiver = receiver_mock + mock_queue.channel._get_asb_receiver = MagicMock( + return_value=queue_object_mock) + with patch( + 'kombu.transport.virtual.base.Channel.basic_ack' + ) as super_basic_ack: + mock_queue.channel.basic_ack("test_delivery_tag") + assert mock_queue.channel.qos.get.call_count == 1 + assert mock_queue.channel._get_asb_receiver.call_count == 1 + assert queue_object_mock.receiver.complete_message.call_count == 1 + assert super_basic_ack.call_count == 1 + + +def test_basic_ack_when_already_settled(mock_queue: MockQueue): + mock_queue.producer.publish("test message") + message = mock_queue.channel._get(mock_queue.queue_name) + mock_queue.channel.qos.get = MagicMock( + return_value=mock_queue.channel.Message( + message, mock_queue.channel + ) + ) + receiver_mock = MagicMock() + receiver_mock.complete_message = MagicMock( + side_effect=azure.servicebus.exceptions.MessageAlreadySettled()) + queue_object_mock = MagicMock() + queue_object_mock.receiver = receiver_mock + mock_queue.channel._get_asb_receiver = MagicMock( + return_value=queue_object_mock) + with patch( + 'kombu.transport.virtual.base.Channel.basic_ack' + ) as super_basic_ack: + mock_queue.channel.basic_ack("test_delivery_tag") + assert mock_queue.channel.qos.get.call_count == 1 + assert mock_queue.channel._get_asb_receiver.call_count == 1 + assert queue_object_mock.receiver.complete_message.call_count == 1 + assert super_basic_ack.call_count == 1 + + +def test_basic_ack_when_qos_raises_keyerror(mock_queue: MockQueue): + """Test that basic_ack calls super method when keyerror""" + mock_queue.channel.qos.get = MagicMock(side_effect=KeyError()) + with patch( + 'kombu.transport.virtual.base.Channel.basic_ack' + ) as super_basic_ack: + mock_queue.channel.basic_ack("invented_delivery_tag") + assert super_basic_ack.call_count == 1 + assert mock_queue.channel.qos.get.call_count == 1 + + +def test_basic_ack_reject_message_when_raises_exception( + mock_queue: MockQueue +): + mock_queue.producer.publish("test message") + message = mock_queue.channel._get(mock_queue.queue_name) + mock_queue.channel.qos.get = MagicMock( + return_value=mock_queue.channel.Message( + message, mock_queue.channel + ) + ) + receiver_mock = MagicMock() + receiver_mock.complete_message = MagicMock(side_effect=Exception()) + queue_object_mock = MagicMock() + queue_object_mock.receiver = receiver_mock + mock_queue.channel._get_asb_receiver = MagicMock( + return_value=queue_object_mock) + with patch( + 'kombu.transport.virtual.base.Channel.basic_reject' + ) as super_basic_reject: + mock_queue.channel.basic_ack("test_delivery_tag") + assert mock_queue.channel.qos.get.call_count == 1 + assert mock_queue.channel._get_asb_receiver.call_count == 1 + assert queue_object_mock.receiver.complete_message.call_count == 1 + assert super_basic_reject.call_count == 1 |