summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTom Sparrow <793763+sparrowt@users.noreply.github.com>2023-04-19 05:01:07 +0100
committerGitHub <noreply@github.com>2023-04-19 10:01:07 +0600
commit5062d53f94dd32e6b36a0e23e6b0c6dcffa79cb1 (patch)
tree5bcc7d9d9e8408f2246fb16bbfaa6c8bede6a953
parentf86f1fc6e1caab6bcd2351e3b95424cece6015e4 (diff)
downloadkombu-5062d53f94dd32e6b36a0e23e6b0c6dcffa79cb1.tar.gz
SQS: avoid excessive GetQueueURL calls by using cached queue url (#1621)
* Fix #1618: avoid re-fetching queue URL when we already have it `_get_from_sqs` was unnecessarily calling `get_queue_url` every time even though the only place which calls `_get_from_sqs` (that is `_get_async`) actually already knows the queue URL. This change avoids hundreds of `GetQueueUrl` AWS API calls per hour when using this SQS backend with celery. Also `connection` is set by the one-and-only caller (and `queue` is actually the queue name string now anyway so couldn't ever have `.connection`) so remove the None default and unused fallback code. * Clarify that `_new_queue` returns the queue URL It seems that prior to 129a9e4ed05b it returned a queue object but this is no longer the case, so update comments and variable names accordingly to make it clearer. Also remove the incorrect fallback which cannot be correct any more given the return value has to be the queue URL which must be a string. * Unit test coverage for SQS async codepath This key code path (which as far as I can see is the main route when using celery with SQS) was missing test coverage. This test adds coverage for: `_get_bulk_async` -> `_get_async` -> `_get_from_sqs`
-rw-r--r--kombu/transport/SQS.py73
-rw-r--r--kombu/transport/virtual/base.py4
-rw-r--r--t/unit/transport/test_SQS.py32
3 files changed, 69 insertions, 40 deletions
diff --git a/kombu/transport/SQS.py b/kombu/transport/SQS.py
index ac199aa1..8a935cdd 100644
--- a/kombu/transport/SQS.py
+++ b/kombu/transport/SQS.py
@@ -238,7 +238,7 @@ class Channel(virtual.Channel):
_predefined_queue_async_clients = {} # A client for each predefined queue
_sqs = None
_predefined_queue_clients = {} # A client for each predefined queue
- _queue_cache = {}
+ _queue_cache = {} # SQS queue name => SQS queue URL
_noack_queues = set()
QoS = QoS
@@ -341,34 +341,38 @@ class Channel(virtual.Channel):
return self.entity_name(self.queue_name_prefix + queue_name)
def _new_queue(self, queue, **kwargs):
- """Ensure a queue with given name exists in SQS."""
- if not isinstance(queue, str):
- return queue
+ """
+ Ensure a queue with given name exists in SQS.
+ Arguments:
+ queue (str): the AMQP queue name
+ Returns:
+ str: the SQS queue URL
+ """
# Translate to SQS name for consistency with initial
# _queue_cache population.
- queue = self.canonical_queue_name(queue)
+ sqs_qname = self.canonical_queue_name(queue)
# The SQS ListQueues method only returns 1000 queues. When you have
# so many queues, it's possible that the queue you are looking for is
# not cached. In this case, we could update the cache with the exact
# queue name first.
- if queue not in self._queue_cache:
- self._update_queue_cache(queue)
+ if sqs_qname not in self._queue_cache:
+ self._update_queue_cache(sqs_qname)
try:
- return self._queue_cache[queue]
+ return self._queue_cache[sqs_qname]
except KeyError:
if self.predefined_queues:
raise UndefinedQueueException((
"Queue with name '{}' must be "
"defined in 'predefined_queues'."
- ).format(queue))
+ ).format(sqs_qname))
attributes = {'VisibilityTimeout': str(self.visibility_timeout)}
- if queue.endswith('.fifo'):
+ if sqs_qname.endswith('.fifo'):
attributes['FifoQueue'] = 'true'
- resp = self._create_queue(queue, attributes)
- self._queue_cache[queue] = resp['QueueUrl']
+ resp = self._create_queue(sqs_qname, attributes)
+ self._queue_cache[sqs_qname] = resp['QueueUrl']
return resp['QueueUrl']
def _create_queue(self, queue_name, attributes):
@@ -441,13 +445,13 @@ class Channel(virtual.Channel):
pass
return byte_string
- def _message_to_python(self, message, queue_name, queue):
+ def _message_to_python(self, message, queue_name, q_url):
body = self._optional_b64_decode(message['Body'].encode())
payload = loads(bytes_to_str(body))
if queue_name in self._noack_queues:
- queue = self._new_queue(queue_name)
+ q_url = self._new_queue(queue_name)
self.asynsqs(queue=queue_name).delete_message(
- queue,
+ q_url,
message['ReceiptHandle'],
)
else:
@@ -464,7 +468,7 @@ class Channel(virtual.Channel):
})
# set delivery tag to SQS receipt handle
delivery_info.update({
- 'sqs_message': message, 'sqs_queue': queue,
+ 'sqs_message': message, 'sqs_queue': q_url,
})
properties['delivery_tag'] = message['ReceiptHandle']
return payload
@@ -483,8 +487,8 @@ class Channel(virtual.Channel):
Returns:
List: A list of Payload objects
"""
- q = self._new_queue(queue)
- return [self._message_to_python(m, queue, q) for m in messages]
+ q_url = self._new_queue(queue)
+ return [self._message_to_python(m, queue, q_url) for m in messages]
def _get_bulk(self, queue,
max_if_unlimited=SQS_MAX_MESSAGES, callback=None):
@@ -569,11 +573,14 @@ class Channel(virtual.Channel):
return callback
def _get_async(self, queue, count=1, callback=None):
- q = self._new_queue(queue)
+ q_url = self._new_queue(queue)
qname = self.canonical_queue_name(queue)
return self._get_from_sqs(
- qname, count=count, connection=self.asynsqs(queue=qname),
- callback=transform(self._on_messages_ready, callback, q, queue),
+ queue_name=qname, queue_url=q_url, count=count,
+ connection=self.asynsqs(queue=qname),
+ callback=transform(
+ self._on_messages_ready, callback, q_url, queue
+ ),
)
def _on_messages_ready(self, queue, qname, messages):
@@ -583,24 +590,14 @@ class Channel(virtual.Channel):
msg_parsed = self._message_to_python(msg, qname, queue)
callbacks[qname](msg_parsed)
- def _get_from_sqs(self, queue,
- count=1, connection=None, callback=None):
+ def _get_from_sqs(self, queue_name, queue_url,
+ connection, count=1, callback=None):
"""Retrieve and handle messages from SQS.
Uses long polling and returns :class:`~vine.promises.promise`.
"""
- connection = connection if connection is not None else queue.connection
- if self.predefined_queues:
- if queue not in self._queue_cache:
- raise UndefinedQueueException((
- "Queue with name '{}' must be defined in "
- "'predefined_queues'."
- ).format(queue))
- queue_url = self._queue_cache[queue]
- else:
- queue_url = connection.get_queue_url(queue)
return connection.receive_message(
- queue, queue_url, number_messages=count,
+ queue_name, queue_url, number_messages=count,
wait_time_seconds=self.wait_time_seconds,
callback=callback,
)
@@ -635,16 +632,16 @@ class Channel(virtual.Channel):
def _size(self, queue):
"""Return the number of messages in a queue."""
- url = self._new_queue(queue)
+ q_url = self._new_queue(queue)
c = self.sqs(queue=self.canonical_queue_name(queue))
resp = c.get_queue_attributes(
- QueueUrl=url,
+ QueueUrl=q_url,
AttributeNames=['ApproximateNumberOfMessages'])
return int(resp['Attributes']['ApproximateNumberOfMessages'])
def _purge(self, queue):
"""Delete all current messages in a queue."""
- q = self._new_queue(queue)
+ q_url = self._new_queue(queue)
# SQS is slow at registering messages, so run for a few
# iterations to ensure messages are detected and deleted.
size = 0
@@ -652,7 +649,7 @@ class Channel(virtual.Channel):
size += int(self._size(queue))
if not size:
break
- self.sqs(queue=queue).purge_queue(QueueUrl=q)
+ self.sqs(queue=queue).purge_queue(QueueUrl=q_url)
return size
def close(self):
diff --git a/kombu/transport/virtual/base.py b/kombu/transport/virtual/base.py
index 552ebec7..ec47ecb8 100644
--- a/kombu/transport/virtual/base.py
+++ b/kombu/transport/virtual/base.py
@@ -698,8 +698,8 @@ class Channel(AbstractChannel, base.StdChannel):
"""Find all queues matching `routing_key` for the given `exchange`.
Returns:
- str: queue name -- must return the string `default`
- if no queues matched.
+ list[str]: queue names -- must return `[default]`
+ if default is set and no queues matched.
"""
if default is None:
default = self.deadletter_queue
diff --git a/t/unit/transport/test_SQS.py b/t/unit/transport/test_SQS.py
index 2b1219fc..303823be 100644
--- a/t/unit/transport/test_SQS.py
+++ b/t/unit/transport/test_SQS.py
@@ -443,6 +443,38 @@ class test_Channel:
self.channel._get_bulk(self.queue_name)
self.channel.connection._deliver.assert_called_once()
+ # hub required for successful instantiation of AsyncSQSConnection
+ @pytest.mark.usefixtures('hub')
+ def test_get_async(self):
+ """Basic coverage of async code typically used via:
+ basic_consume > _loop1 > _schedule_queue > _get_bulk_async"""
+ # Prepare
+ for i in range(3):
+ message = 'message: %s' % i
+ self.producer.publish(message)
+
+ # SQS.Channel.asynsqs constructs AsyncSQSConnection using self.sqs
+ # which is already a mock thanks to `setup` above, we just need to
+ # mock the async-specific methods (as test_AsyncSQSConnection does)
+ async_sqs_conn = self.channel.asynsqs(self.queue_name)
+ async_sqs_conn.get_list = Mock(name='X.get_list')
+
+ # Call key method
+ self.channel._get_bulk_async(self.queue_name)
+
+ assert async_sqs_conn.get_list.call_count == 1
+ get_list_args = async_sqs_conn.get_list.call_args[0]
+ get_list_kwargs = async_sqs_conn.get_list.call_args[1]
+ assert get_list_args[0] == 'ReceiveMessage'
+ assert get_list_args[1] == {
+ 'MaxNumberOfMessages': SQS.SQS_MAX_MESSAGES,
+ 'AttributeName.1': 'ApproximateReceiveCount',
+ 'WaitTimeSeconds': self.channel.wait_time_seconds,
+ }
+ assert get_list_args[3] == \
+ self.channel.sqs().get_queue_url(self.queue_name).url
+ assert get_list_kwargs['parent'] == self.queue_name
+
def test_drain_events_with_empty_list(self):
def mock_can_consume():
return False