diff options
Diffstat (limited to 'kombu')
-rw-r--r-- | kombu/transport/SQS.py | 214 |
1 files changed, 132 insertions, 82 deletions
diff --git a/kombu/transport/SQS.py b/kombu/transport/SQS.py index 80fb1d61..b97f1241 100644 --- a/kombu/transport/SQS.py +++ b/kombu/transport/SQS.py @@ -18,6 +18,7 @@ from anyjson import serialize, deserialize import boto from boto import exception +from boto.sdb.domain import Domain from boto.sqs.message import Message from kombu.transport import virtual @@ -33,86 +34,136 @@ CHARS_REPLACE_TABLE = string.maketrans(CHARS_REPLACE + '.', "_" * len(CHARS_REPLACE) + '-') +class Table(Domain): + """Amazon SimpleDB domain describing the message routing table.""" + + def routes_for(self, exchange): + """Iterator giving all routes for an exchange.""" + for id in self._exchange_members(exchange): + yield self.get_item(id) + + def get_queue(self, queue): + """Get binding for queue.""" + qid = self._get_queue_id(queue) + if qid: + return self.get_item(qid) + + def create_binding(self, queue): + """Get binding item for queue. + + Creates the item if it doesn't exist. + + """ + item = self.get_queue(queue) + if item: + return item + return self.create_item(gen_unique_id()) + + def queue_delete(self, queue): + """delete queue by name.""" + qid = self._get_queue_id(queue) + if qid: + self.delete_item(qid) + + def exchange_delete(self, exchange): + """Delete all routes for `exchange`.""" + for id in self._exchange_members(exchange): + domain.delete_item(id) + + def get_item(self, item_name, consistent_read=True): + """Uses `consistent_read` by default.""" + # Domain is an old-style class, can't use super(). + return Domain.get_item(self, item_name, consistent_read) + + def select(self, query='', next_token=None, consistent_read=True, + max_items=None): + """Uses `consistent_read` by default.""" + return Domain.select(query, next_token, consistent_read, max_items) + + def _exchange_members(self, exchange): + return self.select("""exchange = '%s'""" % exchange) + + def _get_queue_id(self, queue): + for id in self.select("""queue = '%s' limit 1""" % queue, max_items=1): + return id + + + class Channel(virtual.Channel): - keyprefix_queue = "_kombu.binding.%(exchange)s" - keyprefix_domain = '_kombu.%(vhost)s" - sep = '\x06\x16' + Table = Table - _client = None - _fanout_queues = {} # can be global + default_region = "us-east-1" + domain_format = "kombu%(vhost)s" + _sdb = None + _sqs = None def entity_name(self, name, table=CHARS_REPLACE_TABLE): + """Format AMQP queue name into a legal SQS queue name.""" return name.translate(table) def _new_queue(self, queue, **kwargs): - return self.client.create_queue(self.entity_name(queue), - self.visibility_timeout) + """Ensures a queue exists in SQS.""" + return self.sqs.create_queue(self.entity_name(queue), + self.visibility_timeout) - def _get_or_create_item(self, name): - item = self.sdb_domain.get_attributes(name, consistent_read=True) - if item is None: - return self.sdb_domain.new_item(name), False - return item, True + def _queue_bind(self, exchange, routing_key, pattern, queue): + """Bind ``queue`` to ``exchange`` with routing key. + Route will be stored in SDB if so enabled. - def _queue_bind(self, exchange, routing_key, pattern, queue): + """ if not self.supports_fanout: return - if self.typeof(exchange).type == "fanout": - # Mark exchange as fanout locally - self._fanout_queues[queue] = exchange - binding = self._create_binding(queue) + binding = self.table.create_binding(queue) binding.update(exchange=exchange, routing_key=routing_key or "", pattern=pattern or "", queue=queue or "") binding.save() - def _find_queue(self, queue): - domain = self.sdb_domain - for id in domain.select("""queue = '%s' limit 1""" % queue, - max_items=1): - return domain.get_item(id, consistent_read=True) + def get_table(self, exchange): + """Get routing table. - def _create_binding(self, queue): - item = self._find_queue(queue) - if item: - return item - return self.sdb_domain.create_item(gen_unique_id()) - - def _get_table(self, exchange): - table = [] - domain = self.sdb_domain - for id in domain.select("""exchange = '%s'""" % exchange): - ex = domain.get_item(id, consistent_read=True) - table.append((ex["routing_key", - ex["pattern"], - ex["queue"])) - return table + Retrieved from SDB if :attr:`supports_fanout`. + + """ + if self.supports_fanout: + return [(r["routing_key"], r["pattern"], r["queue"]) + for r in self.table.routes_for(exchange)] + return super(Channel, self).get_table(exchange) def _delete(self, queue): """delete queue by name.""" - for id in domain.select("""queue = '%s' limit 1""" % queue): - domain.delete_item(id) + self.table.queue_delete(queue) super(Channel, self)._delete(queue) + def exchange_delete(self, exchange, **kwargs): + """Delete exchange by name.""" + if self.supports_fanout: + self.table.exchange_delete(exchange) + super(Channel, self).exchange_delete(exchange, **kwargs) + def _has_queue(self, queue, **kwargs): - return bool(self._find_queue(queue)) + """Returns True if ``queue`` has been previously declared.""" + if self.supports_fanout: + return bool(self.table.get_queue(queue)) + return super(Channel, self)._has_queue(queue) - def _put_fanout(self, exchange, message, **kwargs): - domain = self.sdb_domain - for id in domain.select("""exchange = '%s'""" % exchange): - item = domain.get_item(id, consistent_read=True) - self._put(item["queue"], message, **kwargs) + def _put(self, queue, message, **kwargs): + """Put message onto queue.""" + q = self._new_queue(queue) + m = Message() + m.set_body(serialize(message)) + q.write(m) - def basic_consume(self, queue, *args, **kwargs): - if queue in self._fanout_queues: - exchange = self._fanout_queues[queue] - self.active_fanout_queues.add(queue) - return super(Channel, self).basic_consume(queue, *args, **kwargs) + def _put_fanout(self, exchange, message, **kwargs): + """Deliver fanout message to all queues in ``exchange``.""" + for route in self.table.routes_for(exchange): + self._put(route["queue"], message, **kwargs) def _get(self, queue): + """Try to retrieve a single message off ``queue``.""" q = self._new_queue(queue) rs = q.get_messages(1) if rs: @@ -120,15 +171,11 @@ class Channel(virtual.Channel): raise Empty() def _size(self, queue): + """Returns the number of messages in a queue.""" return self._new_queue(queue).count() - def _put(self, queue, message, **kwargs): - q = self._new_queue(queue) - m = Message() - m.set_body(serialize(message)) - q.write(m) - def _purge(self, queue): + """Deletes all current messages in a queue.""" q = self._new_queue(queue) size = q.count() q.clear() @@ -136,34 +183,39 @@ class Channel(virtual.Channel): def close(self): super(Channel, self).close() - if self._client: - try: - self._client.close() - except AttributeError, exc: # FIXME ??? - if "can't set attribute" not in str(exc): - raise - - def _open(self): - return boto.connect_sqs(self.conninfo.userid, self.conninfo.password) - - def _open_sdb(self): - return boto.connect_sdb(self.conninfo.userid, self.conninfo.password) + for conn in (self._sqs, self._sdb): + if conn: + try: + conn.close() + except AttributeError, exc: # FIXME ??? + if "can't set attribute" not in str(exc): + raise + + def _aws_connect_to(self, fun): + conninfo = self.conninfo + return fun(self.region, aws_access_key_id=conninfo.userid, + aws_secret_access_key=conninfo.password, + port=conninfo.port) @property - def client(self): - if self._client is None: - self._client = self._open() - return self._client + def sqs(self): + if self._sqs is None: + self._sqs = self._aws_connect_to(boto.sqs.connect_to_region) + return self._sqs @property def sdb(self): if self._sdb is None: - self._sdb = self._open_sdb() + self._sdb = self._aws_connect_to(boto.sdb.connect_to_region) + return self._sdb @property - def sdb_domain(self): - return self._sdb.create_domain(self.keyprefix_domain % { - "vhost": self.connection.client.vhost}) + def table(self): + name = self.domain_format % {"vhost": self.conninfo.vhost} + d = self.sdb.get_object("CreateDomain", {"DomainName": name}, + self.Table) + d.name = name + return d @property def conninfo(self): @@ -181,10 +233,9 @@ class Channel(virtual.Channel): def supports_fanout(self): return self.transport_options.get("sdb_persistence", True) - @cached_property - def sdb_domain(self): - return self.sdb.new_domain(self.keyprefix_domain % { - "vhost": self.conninfo.vhost}) + @property + def region(self): + return self.transport_options.get("region") or self.default_region class Transport(virtual.Transport): @@ -192,6 +243,5 @@ class Transport(virtual.Transport): interval = 1 default_port = None - connection_errors = (exception.SQSError, - socket.error) + connection_errors = (exception.SQSError, socket.error) channel_errors = (exception.SQSDecodeError, ) |