summaryrefslogtreecommitdiff
path: root/kombu/transport/SQS.py
diff options
context:
space:
mode:
authorAsk Solem <ask@celeryproject.org>2011-06-05 19:02:51 +0100
committerAsk Solem <ask@celeryproject.org>2011-06-05 19:02:51 +0100
commit3e44c3851593460ad4e6ff8d394b6d509d3278d3 (patch)
treecd153ce0804a242c18db71f39afb9d16919fe417 /kombu/transport/SQS.py
parent04f104e312e68a1218715d32adbc9c5c94a824c7 (diff)
downloadkombu-3e44c3851593460ad4e6ff8d394b6d509d3278d3.tar.gz
SQS: SDB related fixes and cleanup
Diffstat (limited to 'kombu/transport/SQS.py')
-rw-r--r--kombu/transport/SQS.py214
1 files changed, 132 insertions, 82 deletions
diff --git a/kombu/transport/SQS.py b/kombu/transport/SQS.py
index 80fb1d61..b97f1241 100644
--- a/kombu/transport/SQS.py
+++ b/kombu/transport/SQS.py
@@ -18,6 +18,7 @@ from anyjson import serialize, deserialize
import boto
from boto import exception
+from boto.sdb.domain import Domain
from boto.sqs.message import Message
from kombu.transport import virtual
@@ -33,86 +34,136 @@ CHARS_REPLACE_TABLE = string.maketrans(CHARS_REPLACE + '.',
"_" * len(CHARS_REPLACE) + '-')
+class Table(Domain):
+ """Amazon SimpleDB domain describing the message routing table."""
+
+ def routes_for(self, exchange):
+ """Iterator giving all routes for an exchange."""
+ for id in self._exchange_members(exchange):
+ yield self.get_item(id)
+
+ def get_queue(self, queue):
+ """Get binding for queue."""
+ qid = self._get_queue_id(queue)
+ if qid:
+ return self.get_item(qid)
+
+ def create_binding(self, queue):
+ """Get binding item for queue.
+
+ Creates the item if it doesn't exist.
+
+ """
+ item = self.get_queue(queue)
+ if item:
+ return item
+ return self.create_item(gen_unique_id())
+
+ def queue_delete(self, queue):
+ """delete queue by name."""
+ qid = self._get_queue_id(queue)
+ if qid:
+ self.delete_item(qid)
+
+ def exchange_delete(self, exchange):
+ """Delete all routes for `exchange`."""
+ for id in self._exchange_members(exchange):
+ domain.delete_item(id)
+
+ def get_item(self, item_name, consistent_read=True):
+ """Uses `consistent_read` by default."""
+ # Domain is an old-style class, can't use super().
+ return Domain.get_item(self, item_name, consistent_read)
+
+ def select(self, query='', next_token=None, consistent_read=True,
+ max_items=None):
+ """Uses `consistent_read` by default."""
+ return Domain.select(query, next_token, consistent_read, max_items)
+
+ def _exchange_members(self, exchange):
+ return self.select("""exchange = '%s'""" % exchange)
+
+ def _get_queue_id(self, queue):
+ for id in self.select("""queue = '%s' limit 1""" % queue, max_items=1):
+ return id
+
+
+
class Channel(virtual.Channel):
- keyprefix_queue = "_kombu.binding.%(exchange)s"
- keyprefix_domain = '_kombu.%(vhost)s"
- sep = '\x06\x16'
+ Table = Table
- _client = None
- _fanout_queues = {} # can be global
+ default_region = "us-east-1"
+ domain_format = "kombu%(vhost)s"
+ _sdb = None
+ _sqs = None
def entity_name(self, name, table=CHARS_REPLACE_TABLE):
+ """Format AMQP queue name into a legal SQS queue name."""
return name.translate(table)
def _new_queue(self, queue, **kwargs):
- return self.client.create_queue(self.entity_name(queue),
- self.visibility_timeout)
+ """Ensures a queue exists in SQS."""
+ return self.sqs.create_queue(self.entity_name(queue),
+ self.visibility_timeout)
- def _get_or_create_item(self, name):
- item = self.sdb_domain.get_attributes(name, consistent_read=True)
- if item is None:
- return self.sdb_domain.new_item(name), False
- return item, True
+ def _queue_bind(self, exchange, routing_key, pattern, queue):
+ """Bind ``queue`` to ``exchange`` with routing key.
+ Route will be stored in SDB if so enabled.
- def _queue_bind(self, exchange, routing_key, pattern, queue):
+ """
if not self.supports_fanout:
return
- if self.typeof(exchange).type == "fanout":
- # Mark exchange as fanout locally
- self._fanout_queues[queue] = exchange
- binding = self._create_binding(queue)
+ binding = self.table.create_binding(queue)
binding.update(exchange=exchange,
routing_key=routing_key or "",
pattern=pattern or "",
queue=queue or "")
binding.save()
- def _find_queue(self, queue):
- domain = self.sdb_domain
- for id in domain.select("""queue = '%s' limit 1""" % queue,
- max_items=1):
- return domain.get_item(id, consistent_read=True)
+ def get_table(self, exchange):
+ """Get routing table.
- def _create_binding(self, queue):
- item = self._find_queue(queue)
- if item:
- return item
- return self.sdb_domain.create_item(gen_unique_id())
-
- def _get_table(self, exchange):
- table = []
- domain = self.sdb_domain
- for id in domain.select("""exchange = '%s'""" % exchange):
- ex = domain.get_item(id, consistent_read=True)
- table.append((ex["routing_key",
- ex["pattern"],
- ex["queue"]))
- return table
+ Retrieved from SDB if :attr:`supports_fanout`.
+
+ """
+ if self.supports_fanout:
+ return [(r["routing_key"], r["pattern"], r["queue"])
+ for r in self.table.routes_for(exchange)]
+ return super(Channel, self).get_table(exchange)
def _delete(self, queue):
"""delete queue by name."""
- for id in domain.select("""queue = '%s' limit 1""" % queue):
- domain.delete_item(id)
+ self.table.queue_delete(queue)
super(Channel, self)._delete(queue)
+ def exchange_delete(self, exchange, **kwargs):
+ """Delete exchange by name."""
+ if self.supports_fanout:
+ self.table.exchange_delete(exchange)
+ super(Channel, self).exchange_delete(exchange, **kwargs)
+
def _has_queue(self, queue, **kwargs):
- return bool(self._find_queue(queue))
+ """Returns True if ``queue`` has been previously declared."""
+ if self.supports_fanout:
+ return bool(self.table.get_queue(queue))
+ return super(Channel, self)._has_queue(queue)
- def _put_fanout(self, exchange, message, **kwargs):
- domain = self.sdb_domain
- for id in domain.select("""exchange = '%s'""" % exchange):
- item = domain.get_item(id, consistent_read=True)
- self._put(item["queue"], message, **kwargs)
+ def _put(self, queue, message, **kwargs):
+ """Put message onto queue."""
+ q = self._new_queue(queue)
+ m = Message()
+ m.set_body(serialize(message))
+ q.write(m)
- def basic_consume(self, queue, *args, **kwargs):
- if queue in self._fanout_queues:
- exchange = self._fanout_queues[queue]
- self.active_fanout_queues.add(queue)
- return super(Channel, self).basic_consume(queue, *args, **kwargs)
+ def _put_fanout(self, exchange, message, **kwargs):
+ """Deliver fanout message to all queues in ``exchange``."""
+ for route in self.table.routes_for(exchange):
+ self._put(route["queue"], message, **kwargs)
def _get(self, queue):
+ """Try to retrieve a single message off ``queue``."""
q = self._new_queue(queue)
rs = q.get_messages(1)
if rs:
@@ -120,15 +171,11 @@ class Channel(virtual.Channel):
raise Empty()
def _size(self, queue):
+ """Returns the number of messages in a queue."""
return self._new_queue(queue).count()
- def _put(self, queue, message, **kwargs):
- q = self._new_queue(queue)
- m = Message()
- m.set_body(serialize(message))
- q.write(m)
-
def _purge(self, queue):
+ """Deletes all current messages in a queue."""
q = self._new_queue(queue)
size = q.count()
q.clear()
@@ -136,34 +183,39 @@ class Channel(virtual.Channel):
def close(self):
super(Channel, self).close()
- if self._client:
- try:
- self._client.close()
- except AttributeError, exc: # FIXME ???
- if "can't set attribute" not in str(exc):
- raise
-
- def _open(self):
- return boto.connect_sqs(self.conninfo.userid, self.conninfo.password)
-
- def _open_sdb(self):
- return boto.connect_sdb(self.conninfo.userid, self.conninfo.password)
+ for conn in (self._sqs, self._sdb):
+ if conn:
+ try:
+ conn.close()
+ except AttributeError, exc: # FIXME ???
+ if "can't set attribute" not in str(exc):
+ raise
+
+ def _aws_connect_to(self, fun):
+ conninfo = self.conninfo
+ return fun(self.region, aws_access_key_id=conninfo.userid,
+ aws_secret_access_key=conninfo.password,
+ port=conninfo.port)
@property
- def client(self):
- if self._client is None:
- self._client = self._open()
- return self._client
+ def sqs(self):
+ if self._sqs is None:
+ self._sqs = self._aws_connect_to(boto.sqs.connect_to_region)
+ return self._sqs
@property
def sdb(self):
if self._sdb is None:
- self._sdb = self._open_sdb()
+ self._sdb = self._aws_connect_to(boto.sdb.connect_to_region)
+ return self._sdb
@property
- def sdb_domain(self):
- return self._sdb.create_domain(self.keyprefix_domain % {
- "vhost": self.connection.client.vhost})
+ def table(self):
+ name = self.domain_format % {"vhost": self.conninfo.vhost}
+ d = self.sdb.get_object("CreateDomain", {"DomainName": name},
+ self.Table)
+ d.name = name
+ return d
@property
def conninfo(self):
@@ -181,10 +233,9 @@ class Channel(virtual.Channel):
def supports_fanout(self):
return self.transport_options.get("sdb_persistence", True)
- @cached_property
- def sdb_domain(self):
- return self.sdb.new_domain(self.keyprefix_domain % {
- "vhost": self.conninfo.vhost})
+ @property
+ def region(self):
+ return self.transport_options.get("region") or self.default_region
class Transport(virtual.Transport):
@@ -192,6 +243,5 @@ class Transport(virtual.Transport):
interval = 1
default_port = None
- connection_errors = (exception.SQSError,
- socket.error)
+ connection_errors = (exception.SQSError, socket.error)
channel_errors = (exception.SQSDecodeError, )