summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--kombu/transport/SQS.py73
-rw-r--r--kombu/transport/virtual/base.py4
-rw-r--r--t/unit/transport/test_SQS.py32
3 files changed, 69 insertions, 40 deletions
diff --git a/kombu/transport/SQS.py b/kombu/transport/SQS.py
index ac199aa1..8a935cdd 100644
--- a/kombu/transport/SQS.py
+++ b/kombu/transport/SQS.py
@@ -238,7 +238,7 @@ class Channel(virtual.Channel):
_predefined_queue_async_clients = {} # A client for each predefined queue
_sqs = None
_predefined_queue_clients = {} # A client for each predefined queue
- _queue_cache = {}
+ _queue_cache = {} # SQS queue name => SQS queue URL
_noack_queues = set()
QoS = QoS
@@ -341,34 +341,38 @@ class Channel(virtual.Channel):
return self.entity_name(self.queue_name_prefix + queue_name)
def _new_queue(self, queue, **kwargs):
- """Ensure a queue with given name exists in SQS."""
- if not isinstance(queue, str):
- return queue
+ """
+ Ensure a queue with given name exists in SQS.
+ Arguments:
+ queue (str): the AMQP queue name
+ Returns:
+ str: the SQS queue URL
+ """
# Translate to SQS name for consistency with initial
# _queue_cache population.
- queue = self.canonical_queue_name(queue)
+ sqs_qname = self.canonical_queue_name(queue)
# The SQS ListQueues method only returns 1000 queues. When you have
# so many queues, it's possible that the queue you are looking for is
# not cached. In this case, we could update the cache with the exact
# queue name first.
- if queue not in self._queue_cache:
- self._update_queue_cache(queue)
+ if sqs_qname not in self._queue_cache:
+ self._update_queue_cache(sqs_qname)
try:
- return self._queue_cache[queue]
+ return self._queue_cache[sqs_qname]
except KeyError:
if self.predefined_queues:
raise UndefinedQueueException((
"Queue with name '{}' must be "
"defined in 'predefined_queues'."
- ).format(queue))
+ ).format(sqs_qname))
attributes = {'VisibilityTimeout': str(self.visibility_timeout)}
- if queue.endswith('.fifo'):
+ if sqs_qname.endswith('.fifo'):
attributes['FifoQueue'] = 'true'
- resp = self._create_queue(queue, attributes)
- self._queue_cache[queue] = resp['QueueUrl']
+ resp = self._create_queue(sqs_qname, attributes)
+ self._queue_cache[sqs_qname] = resp['QueueUrl']
return resp['QueueUrl']
def _create_queue(self, queue_name, attributes):
@@ -441,13 +445,13 @@ class Channel(virtual.Channel):
pass
return byte_string
- def _message_to_python(self, message, queue_name, queue):
+ def _message_to_python(self, message, queue_name, q_url):
body = self._optional_b64_decode(message['Body'].encode())
payload = loads(bytes_to_str(body))
if queue_name in self._noack_queues:
- queue = self._new_queue(queue_name)
+ q_url = self._new_queue(queue_name)
self.asynsqs(queue=queue_name).delete_message(
- queue,
+ q_url,
message['ReceiptHandle'],
)
else:
@@ -464,7 +468,7 @@ class Channel(virtual.Channel):
})
# set delivery tag to SQS receipt handle
delivery_info.update({
- 'sqs_message': message, 'sqs_queue': queue,
+ 'sqs_message': message, 'sqs_queue': q_url,
})
properties['delivery_tag'] = message['ReceiptHandle']
return payload
@@ -483,8 +487,8 @@ class Channel(virtual.Channel):
Returns:
List: A list of Payload objects
"""
- q = self._new_queue(queue)
- return [self._message_to_python(m, queue, q) for m in messages]
+ q_url = self._new_queue(queue)
+ return [self._message_to_python(m, queue, q_url) for m in messages]
def _get_bulk(self, queue,
max_if_unlimited=SQS_MAX_MESSAGES, callback=None):
@@ -569,11 +573,14 @@ class Channel(virtual.Channel):
return callback
def _get_async(self, queue, count=1, callback=None):
- q = self._new_queue(queue)
+ q_url = self._new_queue(queue)
qname = self.canonical_queue_name(queue)
return self._get_from_sqs(
- qname, count=count, connection=self.asynsqs(queue=qname),
- callback=transform(self._on_messages_ready, callback, q, queue),
+ queue_name=qname, queue_url=q_url, count=count,
+ connection=self.asynsqs(queue=qname),
+ callback=transform(
+ self._on_messages_ready, callback, q_url, queue
+ ),
)
def _on_messages_ready(self, queue, qname, messages):
@@ -583,24 +590,14 @@ class Channel(virtual.Channel):
msg_parsed = self._message_to_python(msg, qname, queue)
callbacks[qname](msg_parsed)
- def _get_from_sqs(self, queue,
- count=1, connection=None, callback=None):
+ def _get_from_sqs(self, queue_name, queue_url,
+ connection, count=1, callback=None):
"""Retrieve and handle messages from SQS.
Uses long polling and returns :class:`~vine.promises.promise`.
"""
- connection = connection if connection is not None else queue.connection
- if self.predefined_queues:
- if queue not in self._queue_cache:
- raise UndefinedQueueException((
- "Queue with name '{}' must be defined in "
- "'predefined_queues'."
- ).format(queue))
- queue_url = self._queue_cache[queue]
- else:
- queue_url = connection.get_queue_url(queue)
return connection.receive_message(
- queue, queue_url, number_messages=count,
+ queue_name, queue_url, number_messages=count,
wait_time_seconds=self.wait_time_seconds,
callback=callback,
)
@@ -635,16 +632,16 @@ class Channel(virtual.Channel):
def _size(self, queue):
"""Return the number of messages in a queue."""
- url = self._new_queue(queue)
+ q_url = self._new_queue(queue)
c = self.sqs(queue=self.canonical_queue_name(queue))
resp = c.get_queue_attributes(
- QueueUrl=url,
+ QueueUrl=q_url,
AttributeNames=['ApproximateNumberOfMessages'])
return int(resp['Attributes']['ApproximateNumberOfMessages'])
def _purge(self, queue):
"""Delete all current messages in a queue."""
- q = self._new_queue(queue)
+ q_url = self._new_queue(queue)
# SQS is slow at registering messages, so run for a few
# iterations to ensure messages are detected and deleted.
size = 0
@@ -652,7 +649,7 @@ class Channel(virtual.Channel):
size += int(self._size(queue))
if not size:
break
- self.sqs(queue=queue).purge_queue(QueueUrl=q)
+ self.sqs(queue=queue).purge_queue(QueueUrl=q_url)
return size
def close(self):
diff --git a/kombu/transport/virtual/base.py b/kombu/transport/virtual/base.py
index 552ebec7..ec47ecb8 100644
--- a/kombu/transport/virtual/base.py
+++ b/kombu/transport/virtual/base.py
@@ -698,8 +698,8 @@ class Channel(AbstractChannel, base.StdChannel):
"""Find all queues matching `routing_key` for the given `exchange`.
Returns:
- str: queue name -- must return the string `default`
- if no queues matched.
+ list[str]: queue names -- must return `[default]`
+ if default is set and no queues matched.
"""
if default is None:
default = self.deadletter_queue
diff --git a/t/unit/transport/test_SQS.py b/t/unit/transport/test_SQS.py
index 2b1219fc..303823be 100644
--- a/t/unit/transport/test_SQS.py
+++ b/t/unit/transport/test_SQS.py
@@ -443,6 +443,38 @@ class test_Channel:
self.channel._get_bulk(self.queue_name)
self.channel.connection._deliver.assert_called_once()
+ # hub required for successful instantiation of AsyncSQSConnection
+ @pytest.mark.usefixtures('hub')
+ def test_get_async(self):
+ """Basic coverage of async code typically used via:
+ basic_consume > _loop1 > _schedule_queue > _get_bulk_async"""
+ # Prepare
+ for i in range(3):
+ message = 'message: %s' % i
+ self.producer.publish(message)
+
+ # SQS.Channel.asynsqs constructs AsyncSQSConnection using self.sqs
+ # which is already a mock thanks to `setup` above, we just need to
+ # mock the async-specific methods (as test_AsyncSQSConnection does)
+ async_sqs_conn = self.channel.asynsqs(self.queue_name)
+ async_sqs_conn.get_list = Mock(name='X.get_list')
+
+ # Call key method
+ self.channel._get_bulk_async(self.queue_name)
+
+ assert async_sqs_conn.get_list.call_count == 1
+ get_list_args = async_sqs_conn.get_list.call_args[0]
+ get_list_kwargs = async_sqs_conn.get_list.call_args[1]
+ assert get_list_args[0] == 'ReceiveMessage'
+ assert get_list_args[1] == {
+ 'MaxNumberOfMessages': SQS.SQS_MAX_MESSAGES,
+ 'AttributeName.1': 'ApproximateReceiveCount',
+ 'WaitTimeSeconds': self.channel.wait_time_seconds,
+ }
+ assert get_list_args[3] == \
+ self.channel.sqs().get_queue_url(self.queue_name).url
+ assert get_list_kwargs['parent'] == self.queue_name
+
def test_drain_events_with_empty_list(self):
def mock_can_consume():
return False