diff options
author | Ask Solem <ask@celeryproject.org> | 2013-12-13 14:41:03 +0000 |
---|---|---|
committer | Ask Solem <ask@celeryproject.org> | 2013-12-13 14:41:03 +0000 |
commit | 49bf684f710984911d0d59c8702acc46014ae12a (patch) | |
tree | 47a763a15eb0d056192de9f6fd9c2cdc7935dd12 | |
parent | 2e7410ae0de2632676f28bc7605c1ebd4096eec8 (diff) | |
download | kombu-sqsopts.tar.gz |
Improvements and cosmetics for Issue #281 (branch: sqsopts)
-rw-r--r-- | kombu/tests/transport/test_SQS.py | 32 | ||||
-rw-r--r-- | kombu/transport/SQS.py | 139 |
2 files changed, 88 insertions, 83 deletions
diff --git a/kombu/tests/transport/test_SQS.py b/kombu/tests/transport/test_SQS.py index 4e59bb70..a655bca8 100644 --- a/kombu/tests/transport/test_SQS.py +++ b/kombu/tests/transport/test_SQS.py @@ -7,8 +7,6 @@ slightly. from __future__ import absolute_import -import os - from kombu import Connection from kombu import messaging from kombu import five @@ -179,20 +177,22 @@ class test_Channel(Case): def test_get_bulk_raises_empty(self): self.assertRaises(five.Empty, self.channel._get_bulk, self.queue_name) - def test_messages_to_payloads(self): + def test_messages_to_python(self): message_count = 3 # Create several test messages and publish them - for i in xrange(message_count): - message = "message: %s" % i + for i in range(message_count): + message = 'message: %s' % i self.producer.publish(message) # Get the messages now - messages = self.channel._get_from_sqs(self.queue_name, - count=message_count) + messages = self.channel._get_from_sqs( + self.queue_name, count=message_count, + ) # Now convert them to payloads - payloads = self.channel._messages_to_payloads(messages, - self.queue_name) + payloads = self.channel._messages_to_python( + messages, self.queue_name, + ) # We got the same number of payloads back, right? 
self.assertEquals(len(payloads), message_count) @@ -202,23 +202,23 @@ class test_Channel(Case): self.assertTrue('properties' in p) def test_put_and_get(self): - message = "my test message" + message = 'my test message' self.producer.publish(message) results = self.queue(self.channel).get().payload self.assertEquals(message, results) def test_puts_and_gets(self): for i in xrange(3): - message = "message: %s" % i + message = 'message: %s' % i self.producer.publish(message) for i in xrange(3): - self.assertEquals("message: %s" % i, + self.assertEquals('message: %s' % i, self.queue(self.channel).get().payload) def test_put_and_get_bulk(self): # With QoS.prefetch_count = 0 - message = "my test message" + message = 'my test message' self.producer.publish(message) results = self.channel._get_bulk(self.queue_name) self.assertEquals(1, len(results)) @@ -232,7 +232,7 @@ class test_Channel(Case): # Now, generate all the messages for i in xrange(message_count): - message = "message: %s" % i + message = 'message: %s' % i self.producer.publish(message) # Count how many messages are retrieved the first time. 
Should @@ -260,7 +260,7 @@ class test_Channel(Case): # Now, generate all the messages for i in xrange(message_count): - self.producer.publish("message: %s" % i) + self.producer.publish('message: %s' % i) # Now drain all the events for i in xrange(message_count): @@ -281,7 +281,7 @@ class test_Channel(Case): # Now, generate all the messages for i in xrange(message_count): - self.producer.publish("message: %s" % i) + self.producer.publish('message: %s' % i) # Now drain all the events for i in xrange(message_count): diff --git a/kombu/transport/SQS.py b/kombu/transport/SQS.py index 888a7ca9..fec600cb 100644 --- a/kombu/transport/SQS.py +++ b/kombu/transport/SQS.py @@ -56,13 +56,14 @@ from boto.sqs.connection import SQSConnection from boto.sqs.message import Message from kombu.five import Empty, range, text_t +from kombu.log import get_logger from kombu.utils import cached_property, uuid from kombu.utils.encoding import bytes_to_str, safe_str from kombu.transport.virtual import scheduling from . import virtual -log = logging.getLogger(__name__) +logger = get_logger(__name__) # dots are replaced by dash, all other punctuation # replaced by underscore. @@ -195,8 +196,9 @@ class Channel(virtual.Channel): def basic_consume(self, queue, no_ack, *args, **kwargs): if no_ack: self._noack_queues.add(queue) - return super(Channel, self).basic_consume(queue, no_ack, - *args, **kwargs) + return super(Channel, self).basic_consume( + queue, no_ack, *args, **kwargs + ) def basic_cancel(self, consumer_tag): if consumer_tag in self._consumers: @@ -205,54 +207,52 @@ class Channel(virtual.Channel): return super(Channel, self).basic_cancel(consumer_tag) def drain_events(self, timeout=None): - """Returns a single payload message from one of our queues. + """Return a single payload message from one of our queues. + + :raises Empty: if no messages available. 
+ """ # If we're not allowed to consume or have no consumers, raise Empty if not self._consumers or not self.qos.can_consume(): - log.debug('No consumers available, or qos.can_consume() ' - 'returned false. Raising Empty.') raise Empty() + message_cache = self._queue_message_cache # Check if there are any items in our buffer. If there are any, pop # off that queue first. - available_items_in_cache = len(self._queue_message_cache) - if available_items_in_cache > 0: - log.debug('%s messages were found in local cache. Returning one.' - % available_items_in_cache) - return self._queue_message_cache.popleft() + try: + return message_cache.popleft() + except IndexError: + pass # At this point, go and get more messages from SQS - log.debug('Requesting new messages from SQS') - (res, queue) = self._poll(self.cycle, timeout=timeout) - for r in res: - self._queue_message_cache.append((r, queue)) - log.debug('Message queue cache now has %s items.' % - len(self._queue_message_cache)) + res, queue = self._poll(self.cycle, timeout=timeout) + message_cache.extend((r, queue) for r in res) # Now try to pop off the queue again. try: - return self._queue_message_cache.popleft() + return message_cache.popleft() except IndexError: raise Empty() def _reset_cycle(self): - """Returns a FairCycle object. + """Reset the consume cycle. + + :returns: a FairCycle object that points to our _get_bulk() method + rather than the standard _get() method. This allows for multiple + messages to be returned at once from SQS (based on the prefetch + limit). - Returns a FairCycle object that points to our _get_bulk() method - rather than the standard _get() method. This allows for multiple - messages to be returned at once from SQS (based on the prefetch - limit). 
""" - self._cycle = scheduling.FairCycle(self._get_bulk, - self._active_queues, - Empty) + self._cycle = scheduling.FairCycle( + self._get_bulk, self._active_queues, Empty, + ) def entity_name(self, name, table=CHARS_REPLACE_TABLE): """Format AMQP queue name into a legal SQS queue name.""" return text_t(safe_str(name)).translate(table) def _new_queue(self, queue, **kwargs): - """Ensures a queue exists in SQS.""" + """Ensure a queue with given name exists in SQS.""" # Translate to SQS name for consistency with initial # _queue_cache population. queue = self.entity_name(self.queue_name_prefix + queue) @@ -328,45 +328,46 @@ class Channel(virtual.Channel): self._put(route['queue'], message, **kwargs) def _get_from_sqs(self, queue, count=1): - """Retrieves messages from SQS and returns the raw SQS message objects. + """Retrieve messages from SQS and returns the raw SQS message objects. + + :returns: List of SQS message objects - returns: - List of SQS message objects """ q = self._new_queue(queue) if W_LONG_POLLING and queue not in self._fanout_queues: - return q.get_messages(count, - wait_time_seconds=self.wait_time_seconds) + return q.get_messages( + count, wait_time_seconds=self.wait_time_seconds, + ) else: # boto < 2.8 return q.get_messages(count) - def _messages_to_payloads(self, messages, queue): - """Converts a list of SQS Message objects into Payloads. + def _message_to_python(self, message, queue_name, queue): + payload = loads(bytes_to_str(message.get_body())) + if queue_name in self._noack_queues: + queue.delete_message(message) + else: + payload['properties']['delivery_info'].update({ + 'sqs_message': message, 'sqs_queue': queue, + }) + return payload + + def _messages_to_python(self, messages, queue): + """Convert a list of SQS Message objects into Payloads. This method handles converting SQS Message objects into Payloads, and appropriately updating the queue depending on the 'ack' settings for that queue. 
- args: - messages: A list of SQS Message Objects - queue: String name representing the queue they came from + :param messages: A list of SQS Message objects. + :param queue: String name representing the queue they came from + + :returns: A list of Payload objects - returns: - payloads: A list of Payload objects """ q = self._new_queue(queue) - payloads = [] - for m in messages: - payload = loads(bytes_to_str(m.get_body())) - if queue in self._noack_queues: - q.delete_message(m) - else: - payload['properties']['delivery_info'].update({ - 'sqs_message': m, 'sqs_queue': q, }) - payloads.append(payload) - return payloads - - def _get_bulk(self, queue): + return [self._message_to_python(m, queue, q) for m in messages] + + def _get_bulk(self, queue, max_if_unlimited=10): """Try to retrieve multiple messages off ``queue``. Where _get() returns a single Payload object, this method returns a @@ -375,33 +376,37 @@ class Channel(virtual.Channel): number of messages that the QoS object allows (based on the prefetch_count). + .. note:: + Ignores QoS limits so caller is responsible for checking + that we are allowed to consume at least one message from the + queue. get_bulk will then ask QoS for an estimate of + the number of extra messages that we can consume. + args: queue: The queue name (string) to pull from returns: payloads: A list of payload objects returned """ - messages_to_consume = self.qos.can_consume_max_estimate() - log.debug('Retrieving up to %s messages from queue %s' % - (messages_to_consume, queue)) - messages = self._get_from_sqs(queue, count=messages_to_consume) - payloads = self._messages_to_payloads(messages, queue) - - if len(payloads) > 0: - return payloads - - raise Empty() + # drain_events calls `can_consume` first, consuming + # a token, so we know that we are allowed to consume at least + # one message. 
+ maxcount = self.qos.can_consume_max_estimate() + maxcount = max_if_unlimited if maxcount is None else max(maxcount, 1) + messages = self._get_from_sqs(queue, count=maxcount) + + if not messages: + raise Empty() + return self._messages_to_python(messages, queue) def _get(self, queue): """Try to retrieve a single message off ``queue``.""" - log.debug('Retrieving a single message from queue %s' % queue) - messages = self._get_from_sqs(queue) - payloads = self._messages_to_payloads(messages, queue) + messages = self._get_from_sqs(queue, count=1) - if len(payloads) > 0: - return payloads[0] + if not messages: + raise Empty() - raise Empty() + return self._messages_to_python(messages, queue)[0] def _restore(self, message, unwanted_delivery_info=('sqs_message', 'sqs_queue')): |