diff options
-rw-r--r-- | kombu/transport/SQS.py | 73 | ||||
-rw-r--r-- | kombu/transport/virtual/base.py | 4 | ||||
-rw-r--r-- | t/unit/transport/test_SQS.py | 32 |
3 files changed, 69 insertions, 40 deletions
diff --git a/kombu/transport/SQS.py b/kombu/transport/SQS.py index ac199aa1..8a935cdd 100644 --- a/kombu/transport/SQS.py +++ b/kombu/transport/SQS.py @@ -238,7 +238,7 @@ class Channel(virtual.Channel): _predefined_queue_async_clients = {} # A client for each predefined queue _sqs = None _predefined_queue_clients = {} # A client for each predefined queue - _queue_cache = {} + _queue_cache = {} # SQS queue name => SQS queue URL _noack_queues = set() QoS = QoS @@ -341,34 +341,38 @@ class Channel(virtual.Channel): return self.entity_name(self.queue_name_prefix + queue_name) def _new_queue(self, queue, **kwargs): - """Ensure a queue with given name exists in SQS.""" - if not isinstance(queue, str): - return queue + """ + Ensure a queue with given name exists in SQS. + Arguments: + queue (str): the AMQP queue name + Returns: + str: the SQS queue URL + """ # Translate to SQS name for consistency with initial # _queue_cache population. - queue = self.canonical_queue_name(queue) + sqs_qname = self.canonical_queue_name(queue) # The SQS ListQueues method only returns 1000 queues. When you have # so many queues, it's possible that the queue you are looking for is # not cached. In this case, we could update the cache with the exact # queue name first. - if queue not in self._queue_cache: - self._update_queue_cache(queue) + if sqs_qname not in self._queue_cache: + self._update_queue_cache(sqs_qname) try: - return self._queue_cache[queue] + return self._queue_cache[sqs_qname] except KeyError: if self.predefined_queues: raise UndefinedQueueException(( "Queue with name '{}' must be " "defined in 'predefined_queues'." - ).format(queue)) + ).format(sqs_qname)) attributes = {'VisibilityTimeout': str(self.visibility_timeout)} - if queue.endswith('.fifo'): + if sqs_qname.endswith('.fifo'): attributes['FifoQueue'] = 'true' - resp = self._create_queue(queue, attributes) - self._queue_cache[queue] = resp['QueueUrl'] + resp = self._create_queue(sqs_qname, attributes) + self._queue_cache[sqs_qname] = resp['QueueUrl'] return resp['QueueUrl'] def _create_queue(self, queue_name, attributes): @@ -441,13 +445,13 @@ class Channel(virtual.Channel): pass return byte_string - def _message_to_python(self, message, queue_name, queue): + def _message_to_python(self, message, queue_name, q_url): body = self._optional_b64_decode(message['Body'].encode()) payload = loads(bytes_to_str(body)) if queue_name in self._noack_queues: - queue = self._new_queue(queue_name) + q_url = self._new_queue(queue_name) self.asynsqs(queue=queue_name).delete_message( - queue, + q_url, message['ReceiptHandle'], ) else: @@ -464,7 +468,7 @@ class Channel(virtual.Channel): }) # set delivery tag to SQS receipt handle delivery_info.update({ - 'sqs_message': message, 'sqs_queue': queue, + 'sqs_message': message, 'sqs_queue': q_url, }) properties['delivery_tag'] = message['ReceiptHandle'] return payload @@ -483,8 +487,8 @@ class Channel(virtual.Channel): Returns: List: A list of Payload objects """ - q = self._new_queue(queue) - return [self._message_to_python(m, queue, q) for m in messages] + q_url = self._new_queue(queue) + return [self._message_to_python(m, queue, q_url) for m in messages] def _get_bulk(self, queue, max_if_unlimited=SQS_MAX_MESSAGES, callback=None): @@ -569,11 +573,14 @@ class Channel(virtual.Channel): return callback def _get_async(self, queue, count=1, callback=None): - q = self._new_queue(queue) + q_url = self._new_queue(queue) qname = self.canonical_queue_name(queue) return self._get_from_sqs( - qname, count=count, connection=self.asynsqs(queue=qname), - callback=transform(self._on_messages_ready, callback, q, queue), + queue_name=qname, queue_url=q_url, count=count, + connection=self.asynsqs(queue=qname), + callback=transform( + self._on_messages_ready, callback, q_url, queue + ), ) def _on_messages_ready(self, queue, qname, messages): @@ -583,24 +590,14 @@ class Channel(virtual.Channel): msg_parsed = self._message_to_python(msg, qname, queue) callbacks[qname](msg_parsed) - def _get_from_sqs(self, queue, - count=1, connection=None, callback=None): + def _get_from_sqs(self, queue_name, queue_url, + connection, count=1, callback=None): """Retrieve and handle messages from SQS. Uses long polling and returns :class:`~vine.promises.promise`. """ - connection = connection if connection is not None else queue.connection - if self.predefined_queues: - if queue not in self._queue_cache: - raise UndefinedQueueException(( - "Queue with name '{}' must be defined in " - "'predefined_queues'." - ).format(queue)) - queue_url = self._queue_cache[queue] - else: - queue_url = connection.get_queue_url(queue) return connection.receive_message( - queue, queue_url, number_messages=count, + queue_name, queue_url, number_messages=count, wait_time_seconds=self.wait_time_seconds, callback=callback, ) @@ -635,16 +632,16 @@ class Channel(virtual.Channel): def _size(self, queue): """Return the number of messages in a queue.""" - url = self._new_queue(queue) + q_url = self._new_queue(queue) c = self.sqs(queue=self.canonical_queue_name(queue)) resp = c.get_queue_attributes( - QueueUrl=url, + QueueUrl=q_url, AttributeNames=['ApproximateNumberOfMessages']) return int(resp['Attributes']['ApproximateNumberOfMessages']) def _purge(self, queue): """Delete all current messages in a queue.""" - q = self._new_queue(queue) + q_url = self._new_queue(queue) # SQS is slow at registering messages, so run for a few # iterations to ensure messages are detected and deleted. size = 0 @@ -652,7 +649,7 @@ class Channel(virtual.Channel): size += int(self._size(queue)) if not size: break - self.sqs(queue=queue).purge_queue(QueueUrl=q) + self.sqs(queue=queue).purge_queue(QueueUrl=q_url) return size def close(self): diff --git a/kombu/transport/virtual/base.py b/kombu/transport/virtual/base.py index 552ebec7..ec47ecb8 100644 --- a/kombu/transport/virtual/base.py +++ b/kombu/transport/virtual/base.py @@ -698,8 +698,8 @@ class Channel(AbstractChannel, base.StdChannel): """Find all queues matching `routing_key` for the given `exchange`. Returns: - str: queue name -- must return the string `default` - if no queues matched. + list[str]: queue names -- must return `[default]` + if default is set and no queues matched. """ if default is None: default = self.deadletter_queue diff --git a/t/unit/transport/test_SQS.py b/t/unit/transport/test_SQS.py index 2b1219fc..303823be 100644 --- a/t/unit/transport/test_SQS.py +++ b/t/unit/transport/test_SQS.py @@ -443,6 +443,38 @@ class test_Channel: self.channel._get_bulk(self.queue_name) self.channel.connection._deliver.assert_called_once() + # hub required for successful instantiation of AsyncSQSConnection + @pytest.mark.usefixtures('hub') + def test_get_async(self): + """Basic coverage of async code typically used via: + basic_consume > _loop1 > _schedule_queue > _get_bulk_async""" + # Prepare + for i in range(3): + message = 'message: %s' % i + self.producer.publish(message) + + # SQS.Channel.asynsqs constructs AsyncSQSConnection using self.sqs + # which is already a mock thanks to `setup` above, we just need to + # mock the async-specific methods (as test_AsyncSQSConnection does) + async_sqs_conn = self.channel.asynsqs(self.queue_name) + async_sqs_conn.get_list = Mock(name='X.get_list') + + # Call key method + self.channel._get_bulk_async(self.queue_name) + + assert async_sqs_conn.get_list.call_count == 1 + get_list_args = async_sqs_conn.get_list.call_args[0] + get_list_kwargs = async_sqs_conn.get_list.call_args[1] + assert get_list_args[0] == 'ReceiveMessage' + assert get_list_args[1] == { + 'MaxNumberOfMessages': SQS.SQS_MAX_MESSAGES, + 'AttributeName.1': 'ApproximateReceiveCount', + 'WaitTimeSeconds': self.channel.wait_time_seconds, + } + assert get_list_args[3] == \ + self.channel.sqs().get_queue_url(self.queue_name).url + assert get_list_kwargs['parent'] == self.queue_name + def test_drain_events_with_empty_list(self): def mock_can_consume(): return False |