summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatt Wise <matt@nextdoor.com>2013-12-02 10:17:38 -0800
committerMatt Wise <matt@nextdoor.com>2013-12-02 10:17:38 -0800
commitb9e8a3240489ac8d5b408a70ddac196944a7d0a6 (patch)
tree0904e49e6f37e9ef2188e2f77bf6befe674bf612
parentc8a919cd8341be47f10499d3ed51f6aa12137b80 (diff)
downloadkombu-b9e8a3240489ac8d5b408a70ddac196944a7d0a6.tar.gz
Make _get() simpler, move most logic into private methods.
In preparation for building a _get_bulk() method to handle getting many messages from SQS at once, this commit moves most of the logic from _get() into separate methods that can be more easily unit tested and re-used.
-rw-r--r--kombu/tests/transport/test_SQS.py38
-rw-r--r--kombu/transport/SQS.py72
2 files changed, 94 insertions, 16 deletions
diff --git a/kombu/tests/transport/test_SQS.py b/kombu/tests/transport/test_SQS.py
index 55194617..12dd81b8 100644
--- a/kombu/tests/transport/test_SQS.py
+++ b/kombu/tests/transport/test_SQS.py
@@ -241,16 +241,52 @@ class test_Channel(Case):
self.removeMockedQueueFile(queue_name)
self.assertNotIn(queue_name, self.channel._queue_cache)
+ def test_get_from_sqs(self):
+ # Test getting a single message
+ message = "my test message"
+ self.producer.publish(message)
+ results = self.channel._get_from_sqs(self.queue_name)
+ self.assertEquals(len(results), 1)
+
+ # Now test getting many messages
+ for i in xrange(3):
+ message = "message: %s" % i
+ self.producer.publish(message)
+
+ results = self.channel._get_from_sqs(self.queue_name, count=3)
+ self.assertEquals(len(results), 3)
+
def test_get_with_empty_list(self):
self.assertRaises(five.Empty, self.channel._get, self.queue_name)
+ def test_messages_to_payloads(self):
+ message_count = 3
+ # Create several test messages and publish them
+ for i in xrange(message_count):
+ message = "message: %s" % i
+ self.producer.publish(message)
+
+ # Get the messages now
+ messages = self.channel._get_from_sqs(self.queue_name,
+ count=message_count)
+
+ # Now convert them to payloads
+ payloads = self.channel._messages_to_payloads(messages,
+ self.queue_name)
+
+ # We got the same number of payloads back, right?
+ self.assertEquals(len(payloads), message_count)
+
+ # Make sure they're payload-style objects
+ for p in payloads:
+ self.assertTrue('properties' in p)
+
def test_put_and_get(self):
message = "my test message"
self.producer.publish(message)
results = self.queue(self.channel).get().payload
self.assertEquals(message, results)
-
def test_puts_and_gets(self):
for i in xrange(3):
message = "message: %s" % i
diff --git a/kombu/transport/SQS.py b/kombu/transport/SQS.py
index 33e0dcdc..619c4a99 100644
--- a/kombu/transport/SQS.py
+++ b/kombu/transport/SQS.py
@@ -2,11 +2,26 @@
kombu.transport.SQS
===================
-Amazon SQS transport.
+Amazon SQS transport module for Kombu. This package implements an AMQP-like
+interface on top of Amazons SQS service, with the goal of being optimized for
+high performance and reliability.
+The default settings for this module are focused now on high performance in
+task queue situations where tasks are small, idempotent and run very fast.
+
+SQS Features supported by this transport:
+ Long Polling:
+ http://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/
+ sqs-long-polling.html
+
+ Long polling is enabled by setting the `wait_time_seconds` transport
+ option to a number > 1. Amazon supports up to 20 seconds. This is
+ disabled for now, but will be enabled by default in the near future.
"""
+
from __future__ import absolute_import
+import logging
import socket
import string
@@ -27,6 +42,8 @@ from kombu.utils.encoding import bytes_to_str, safe_str
from . import virtual
+log = logging.getLogger(__name__)
+
# dots are replaced by dash, all other punctuation
# replaced by underscore.
CHARS_REPLACE_TABLE = dict((ord(c), 0x5f)
@@ -131,7 +148,6 @@ class Channel(virtual.Channel):
default_region = 'us-east-1'
default_visibility_timeout = 1800 # 30 minutes.
default_wait_time_seconds = 0 # disabled see #198
- default_messages_to_fetch = 1
domain_format = 'kombu%(vhost)s'
_sdb = None
_sqs = None
@@ -242,22 +258,53 @@ class Channel(virtual.Channel):
for route in self.table.routes_for(exchange):
self._put(route['queue'], message, **kwargs)
- def _get(self, queue):
- """Try to retrieve a single message off ``queue``."""
+ def _get_from_sqs(self, queue, count=1):
+ """Retrieves messages from SQS and returns the raw SQS message objects.
+
+ returns:
+ List of SQS message objects
+ """
q = self._new_queue(queue)
if W_LONG_POLLING and queue not in self._fanout_queues:
- rs = q.get_messages(1, wait_time_seconds=self.wait_time_seconds)
+ return q.get_messages(count,
+ wait_time_seconds=self.wait_time_seconds)
else: # boto < 2.8
- rs = q.get_messages(1)
- if rs:
- m = rs[0]
- payload = loads(bytes_to_str(rs[0].get_body()))
+ return q.get_messages(count)
+
+ def _messages_to_payloads(self, messages, queue):
+ """Converts a list of SQS Message objects into Payloads.
+
+ This method handles converting SQS Message objects into
+ Payloads, and appropriately updating the queue depending on
+ the 'ack' settings for that queue.
+
+ args:
+ messages: A list of SQS Message Objects
+ queue: String name representing the queue they came from
+
+ returns:
+ payloads: A list of Payload objects
+ """
+ q = self._new_queue(queue)
+ payloads = []
+ for m in messages:
+ payload = loads(bytes_to_str(m.get_body()))
if queue in self._noack_queues:
q.delete_message(m)
else:
payload['properties']['delivery_info'].update({
'sqs_message': m, 'sqs_queue': q, })
- return payload
+ payloads.append(payload)
+ return payloads
+
+ def _get(self, queue):
+ """Try to retrieve a single message off ``queue``."""
+ messages = self._get_from_sqs(queue)
+ payloads = self._messages_to_payloads(messages, queue)
+
+ if len(payloads) > 0:
+ return payloads[0]
+
raise Empty()
def _restore(self, message,
@@ -372,11 +419,6 @@ class Channel(virtual.Channel):
return self.transport_options.get('wait_time_seconds',
self.default_wait_time_seconds)
- @cached_property
- def messages_to_fetch(self):
- return self.transport_options.get('messages_to_fetch',
- self.default_messages_to_fetch)
-
class Transport(virtual.Transport):
Channel = Channel