diff options
Diffstat (limited to 'kombu/transport/SQS.py')
-rw-r--r-- | kombu/transport/SQS.py | 73 |
1 files changed, 35 insertions, 38 deletions
diff --git a/kombu/transport/SQS.py b/kombu/transport/SQS.py index ac199aa1..8a935cdd 100644 --- a/kombu/transport/SQS.py +++ b/kombu/transport/SQS.py @@ -238,7 +238,7 @@ class Channel(virtual.Channel): _predefined_queue_async_clients = {} # A client for each predefined queue _sqs = None _predefined_queue_clients = {} # A client for each predefined queue - _queue_cache = {} + _queue_cache = {} # SQS queue name => SQS queue URL _noack_queues = set() QoS = QoS @@ -341,34 +341,38 @@ class Channel(virtual.Channel): return self.entity_name(self.queue_name_prefix + queue_name) def _new_queue(self, queue, **kwargs): - """Ensure a queue with given name exists in SQS.""" - if not isinstance(queue, str): - return queue + """ + Ensure a queue with given name exists in SQS. + Arguments: + queue (str): the AMQP queue name + Returns: + str: the SQS queue URL + """ # Translate to SQS name for consistency with initial # _queue_cache population. - queue = self.canonical_queue_name(queue) + sqs_qname = self.canonical_queue_name(queue) # The SQS ListQueues method only returns 1000 queues. When you have # so many queues, it's possible that the queue you are looking for is # not cached. In this case, we could update the cache with the exact # queue name first. - if queue not in self._queue_cache: - self._update_queue_cache(queue) + if sqs_qname not in self._queue_cache: + self._update_queue_cache(sqs_qname) try: - return self._queue_cache[queue] + return self._queue_cache[sqs_qname] except KeyError: if self.predefined_queues: raise UndefinedQueueException(( "Queue with name '{}' must be " "defined in 'predefined_queues'." - ).format(queue)) + ).format(sqs_qname)) attributes = {'VisibilityTimeout': str(self.visibility_timeout)} - if queue.endswith('.fifo'): + if sqs_qname.endswith('.fifo'): attributes['FifoQueue'] = 'true' - resp = self._create_queue(queue, attributes) - self._queue_cache[queue] = resp['QueueUrl'] + resp = self._create_queue(sqs_qname, attributes) + self._queue_cache[sqs_qname] = resp['QueueUrl'] return resp['QueueUrl'] def _create_queue(self, queue_name, attributes): @@ -441,13 +445,13 @@ class Channel(virtual.Channel): pass return byte_string - def _message_to_python(self, message, queue_name, queue): + def _message_to_python(self, message, queue_name, q_url): body = self._optional_b64_decode(message['Body'].encode()) payload = loads(bytes_to_str(body)) if queue_name in self._noack_queues: - queue = self._new_queue(queue_name) + q_url = self._new_queue(queue_name) self.asynsqs(queue=queue_name).delete_message( - queue, + q_url, message['ReceiptHandle'], ) else: @@ -464,7 +468,7 @@ class Channel(virtual.Channel): }) # set delivery tag to SQS receipt handle delivery_info.update({ - 'sqs_message': message, 'sqs_queue': queue, + 'sqs_message': message, 'sqs_queue': q_url, }) properties['delivery_tag'] = message['ReceiptHandle'] return payload @@ -483,8 +487,8 @@ class Channel(virtual.Channel): Returns: List: A list of Payload objects """ - q = self._new_queue(queue) - return [self._message_to_python(m, queue, q) for m in messages] + q_url = self._new_queue(queue) + return [self._message_to_python(m, queue, q_url) for m in messages] def _get_bulk(self, queue, max_if_unlimited=SQS_MAX_MESSAGES, callback=None): @@ -569,11 +573,14 @@ class Channel(virtual.Channel): return callback def _get_async(self, queue, count=1, callback=None): - q = self._new_queue(queue) + q_url = self._new_queue(queue) qname = self.canonical_queue_name(queue) return self._get_from_sqs( - qname, count=count, connection=self.asynsqs(queue=qname), - callback=transform(self._on_messages_ready, callback, q, queue), + queue_name=qname, queue_url=q_url, count=count, + connection=self.asynsqs(queue=qname), + callback=transform( + self._on_messages_ready, callback, q_url, queue + ), ) def _on_messages_ready(self, queue, qname, messages): @@ -583,24 +590,14 @@ class Channel(virtual.Channel): msg_parsed = self._message_to_python(msg, qname, queue) callbacks[qname](msg_parsed) - def _get_from_sqs(self, queue, - count=1, connection=None, callback=None): + def _get_from_sqs(self, queue_name, queue_url, + connection, count=1, callback=None): """Retrieve and handle messages from SQS. Uses long polling and returns :class:`~vine.promises.promise`. """ - connection = connection if connection is not None else queue.connection - if self.predefined_queues: - if queue not in self._queue_cache: - raise UndefinedQueueException(( - "Queue with name '{}' must be defined in " - "'predefined_queues'." - ).format(queue)) - queue_url = self._queue_cache[queue] - else: - queue_url = connection.get_queue_url(queue) return connection.receive_message( - queue, queue_url, number_messages=count, + queue_name, queue_url, number_messages=count, wait_time_seconds=self.wait_time_seconds, callback=callback, ) @@ -635,16 +632,16 @@ class Channel(virtual.Channel): def _size(self, queue): """Return the number of messages in a queue.""" - url = self._new_queue(queue) + q_url = self._new_queue(queue) c = self.sqs(queue=self.canonical_queue_name(queue)) resp = c.get_queue_attributes( - QueueUrl=url, + QueueUrl=q_url, AttributeNames=['ApproximateNumberOfMessages']) return int(resp['Attributes']['ApproximateNumberOfMessages']) def _purge(self, queue): """Delete all current messages in a queue.""" - q = self._new_queue(queue) + q_url = self._new_queue(queue) # SQS is slow at registering messages, so run for a few # iterations to ensure messages are detected and deleted. size = 0 @@ -652,7 +649,7 @@ class Channel(virtual.Channel): size += int(self._size(queue)) if not size: break - self.sqs(queue=queue).purge_queue(QueueUrl=q) + self.sqs(queue=queue).purge_queue(QueueUrl=q_url) return size def close(self): |