diff options
author | Matt Wise <matt@nextdoor.com> | 2013-12-02 10:17:38 -0800 |
---|---|---|
committer | Matt Wise <matt@nextdoor.com> | 2013-12-02 10:17:38 -0800 |
commit | b9e8a3240489ac8d5b408a70ddac196944a7d0a6 (patch) | |
tree | 0904e49e6f37e9ef2188e2f77bf6befe674bf612 | |
parent | c8a919cd8341be47f10499d3ed51f6aa12137b80 (diff) | |
download | kombu-b9e8a3240489ac8d5b408a70ddac196944a7d0a6.tar.gz |
Make _get() simpler, move most logic into private methods.
In preparation for building a _get_bulk() method to
handle getting many messages from SQS at once, this commit
moves most of the logic from _get() into separate methods
that can be more easily unit tested and re-used.
-rw-r--r-- | kombu/tests/transport/test_SQS.py | 38 | ||||
-rw-r--r-- | kombu/transport/SQS.py | 72 |
2 files changed, 94 insertions, 16 deletions
diff --git a/kombu/tests/transport/test_SQS.py b/kombu/tests/transport/test_SQS.py index 55194617..12dd81b8 100644 --- a/kombu/tests/transport/test_SQS.py +++ b/kombu/tests/transport/test_SQS.py @@ -241,16 +241,52 @@ class test_Channel(Case): self.removeMockedQueueFile(queue_name) self.assertNotIn(queue_name, self.channel._queue_cache) + def test_get_from_sqs(self): + # Test getting a single message + message = "my test message" + self.producer.publish(message) + results = self.channel._get_from_sqs(self.queue_name) + self.assertEquals(len(results), 1) + + # Now test getting many messages + for i in xrange(3): + message = "message: %s" % i + self.producer.publish(message) + + results = self.channel._get_from_sqs(self.queue_name, count=3) + self.assertEquals(len(results), 3) + def test_get_with_empty_list(self): self.assertRaises(five.Empty, self.channel._get, self.queue_name) + def test_messages_to_payloads(self): + message_count = 3 + # Create several test messages and publish them + for i in xrange(message_count): + message = "message: %s" % i + self.producer.publish(message) + + # Get the messages now + messages = self.channel._get_from_sqs(self.queue_name, + count=message_count) + + # Now convert them to payloads + payloads = self.channel._messages_to_payloads(messages, + self.queue_name) + + # We got the same number of payloads back, right? + self.assertEquals(len(payloads), message_count) + + # Make sure they're payload-style objects + for p in payloads: + self.assertTrue('properties' in p) + def test_put_and_get(self): message = "my test message" self.producer.publish(message) results = self.queue(self.channel).get().payload self.assertEquals(message, results) - def test_puts_and_gets(self): for i in xrange(3): message = "message: %s" % i diff --git a/kombu/transport/SQS.py b/kombu/transport/SQS.py index 33e0dcdc..619c4a99 100644 --- a/kombu/transport/SQS.py +++ b/kombu/transport/SQS.py @@ -2,11 +2,26 @@ kombu.transport.SQS =================== -Amazon SQS transport. +Amazon SQS transport module for Kombu. This package implements an AMQP-like +interface on top of Amazons SQS service, with the goal of being optimized for +high performance and reliability. +The default settings for this module are focused now on high performance in +task queue situations where tasks are small, idempotent and run very fast. + +SQS Features supported by this transport: + Long Polling: + http://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/ + sqs-long-polling.html + + Long polling is enabled by setting the `wait_time_seconds` transport + option to a number > 1. Amazon supports up to 20 seconds. This is + disabled for now, but will be enabled by default in the near future. """ + from __future__ import absolute_import +import logging import socket import string @@ -27,6 +42,8 @@ from kombu.utils.encoding import bytes_to_str, safe_str from . import virtual +log = logging.getLogger(__name__) + # dots are replaced by dash, all other punctuation # replaced by underscore. CHARS_REPLACE_TABLE = dict((ord(c), 0x5f) @@ -131,7 +148,6 @@ class Channel(virtual.Channel): default_region = 'us-east-1' default_visibility_timeout = 1800 # 30 minutes. default_wait_time_seconds = 0 # disabled see #198 - default_messages_to_fetch = 1 domain_format = 'kombu%(vhost)s' _sdb = None _sqs = None @@ -242,22 +258,53 @@ class Channel(virtual.Channel): for route in self.table.routes_for(exchange): self._put(route['queue'], message, **kwargs) - def _get(self, queue): - """Try to retrieve a single message off ``queue``.""" + def _get_from_sqs(self, queue, count=1): + """Retrieves messages from SQS and returns the raw SQS message objects. + + returns: + List of SQS message objects + """ q = self._new_queue(queue) if W_LONG_POLLING and queue not in self._fanout_queues: - rs = q.get_messages(1, wait_time_seconds=self.wait_time_seconds) + return q.get_messages(count, + wait_time_seconds=self.wait_time_seconds) else: # boto < 2.8 - rs = q.get_messages(1) - if rs: - m = rs[0] - payload = loads(bytes_to_str(rs[0].get_body())) + return q.get_messages(count) + + def _messages_to_payloads(self, messages, queue): + """Converts a list of SQS Message objects into Payloads. + + This method handles converting SQS Message objects into + Payloads, and appropriately updating the queue depending on + the 'ack' settings for that queue. + + args: + messages: A list of SQS Message Objects + queue: String name representing the queue they came from + + returns: + payloads: A list of Payload objects + """ + q = self._new_queue(queue) + payloads = [] + for m in messages: + payload = loads(bytes_to_str(m.get_body())) if queue in self._noack_queues: q.delete_message(m) else: payload['properties']['delivery_info'].update({ 'sqs_message': m, 'sqs_queue': q, }) - return payload + payloads.append(payload) + return payloads + + def _get(self, queue): + """Try to retrieve a single message off ``queue``.""" + messages = self._get_from_sqs(queue) + payloads = self._messages_to_payloads(messages, queue) + + if len(payloads) > 0: + return payloads[0] + raise Empty() def _restore(self, message, @@ -372,11 +419,6 @@ class Channel(virtual.Channel): return self.transport_options.get('wait_time_seconds', self.default_wait_time_seconds) - @cached_property - def messages_to_fetch(self): - return self.transport_options.get('messages_to_fetch', - self.default_messages_to_fetch) - class Transport(virtual.Transport): Channel = Channel |