diff options
-rw-r--r-- | kombu/transport/SQS.py | 59 |
1 files changed, 36 insertions, 23 deletions
diff --git a/kombu/transport/SQS.py b/kombu/transport/SQS.py index b97f1241..77e59f76 100644 --- a/kombu/transport/SQS.py +++ b/kombu/transport/SQS.py @@ -18,11 +18,15 @@ from anyjson import serialize, deserialize import boto from boto import exception +from boto import sdb as _sdb +from boto import sqs as _sqs from boto.sdb.domain import Domain +from boto.sdb.connection import SDBConnection +from boto.sqs.connection import SQSConnection from boto.sqs.message import Message from kombu.transport import virtual -from kombu.utils import cached_property +from kombu.utils import cached_property, gen_unique_id # dots are replaced by dash, all other punctuation @@ -39,8 +43,7 @@ class Table(Domain): def routes_for(self, exchange): """Iterator giving all routes for an exchange.""" - for id in self._exchange_members(exchange): - yield self.get_item(id) + return self.select("""exchange = '%s'""" % exchange) def get_queue(self, queue): """Get binding for queue.""" @@ -56,8 +59,9 @@ class Table(Domain): """ item = self.get_queue(queue) if item: - return item - return self.create_item(gen_unique_id()) + return item, item["id"] + id = gen_unique_id() + return self.new_item(id), id def queue_delete(self, queue): """delete queue by name.""" @@ -67,8 +71,8 @@ class Table(Domain): def exchange_delete(self, exchange): """Delete all routes for `exchange`.""" - for id in self._exchange_members(exchange): - domain.delete_item(id) + for item in self.routes_for(exchange): + domain.delete_item(item["id"]) def get_item(self, item_name, consistent_read=True): """Uses `consistent_read` by default.""" @@ -78,14 +82,14 @@ class Table(Domain): def select(self, query='', next_token=None, consistent_read=True, max_items=None): """Uses `consistent_read` by default.""" - return Domain.select(query, next_token, consistent_read, max_items) - - def _exchange_members(self, exchange): - return self.select("""exchange = '%s'""" % exchange) + query = """SELECT * FROM `%s` WHERE %s""" % (self.name, query) + return Domain.select(self, query, next_token, + consistent_read, max_items) def _get_queue_id(self, queue): - for id in self.select("""queue = '%s' limit 1""" % queue, max_items=1): - return id + for item in self.select("""queue = '%s' limit 1""" % queue, + max_items=1): + return item["id"] @@ -99,7 +103,7 @@ class Channel(virtual.Channel): def entity_name(self, name, table=CHARS_REPLACE_TABLE): """Format AMQP queue name into a legal SQS queue name.""" - return name.translate(table) + return name.encode(errors="replace").translate(table) def _new_queue(self, queue, **kwargs): """Ensures a queue exists in SQS.""" @@ -115,11 +119,12 @@ class Channel(virtual.Channel): if not self.supports_fanout: return - binding = self.table.create_binding(queue) + binding, id = self.table.create_binding(queue) binding.update(exchange=exchange, routing_key=routing_key or "", pattern=pattern or "", - queue=queue or "") + queue=queue or "", + id=id) binding.save() def get_table(self, exchange): @@ -191,27 +196,35 @@ class Channel(virtual.Channel): if "can't set attribute" not in str(exc): raise - def _aws_connect_to(self, fun): + def _aws_connect_to(self, fun, regions): conninfo = self.conninfo - return fun(self.region, aws_access_key_id=conninfo.userid, - aws_secret_access_key=conninfo.password, - port=conninfo.port) + region = None + if self.region: + for _r in regions: + if _r.name == self.region: + region = _r + break + return fun(region=region, + aws_access_key_id=conninfo.userid, + aws_secret_access_key=conninfo.password, + port=conninfo.port) @property def sqs(self): if self._sqs is None: - self._sqs = self._aws_connect_to(boto.sqs.connect_to_region) + self._sqs = self._aws_connect_to(SQSConnection, _sqs.regions()) return self._sqs @property def sdb(self): if self._sdb is None: - self._sdb = self._aws_connect_to(boto.sdb.connect_to_region) + self._sdb = self._aws_connect_to(SDBConnection, _sdb.regions()) return self._sdb @property def table(self): - name = self.domain_format % {"vhost": self.conninfo.vhost} + name = self.entity_name(self.domain_format % { + "vhost": self.conninfo.virtual_host}) d = self.sdb.get_object("CreateDomain", {"DomainName": name}, self.Table) d.name = name |