summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAsk Solem <ask@celeryproject.org>2014-04-23 23:00:03 +0100
committerAsk Solem <ask@celeryproject.org>2014-05-03 22:28:43 +0100
commit6a1abb7e946085befb3c7ea4a3e6d703160356e4 (patch)
tree4bbcf5ee7de3053ee2644011bdc1ed2cb518e5ef
parent01779ce7e64df48f98b051e5183b5fb38529c591 (diff)
downloadkombu-6a1abb7e946085befb3c7ea4a3e6d703160356e4.tar.gz
Initial import of kombu.async.aws
Tests: Use setup/teardown instead of setUp/tearDown Tests async.http: Auth support Tests for kombu.async.aws.connection All tests for SQS and implemented message batch operations SQS: Removes poor fanout support using SDB Async SQS working somewhat More fixes flakes SQS: Improves async queue/channel scheduling Adds Transport.implements for introspection of transport capabilities
-rw-r--r--kombu/async/aws/__init__.py9
-rw-r--r--kombu/async/aws/connection.py276
-rw-r--r--kombu/async/aws/sqs/__init__.py18
-rw-r--r--kombu/async/aws/sqs/connection.py197
-rw-r--r--kombu/async/aws/sqs/jsonmessage.py10
-rw-r--r--kombu/async/aws/sqs/message.py36
-rw-r--r--kombu/async/aws/sqs/queue.py130
-rw-r--r--kombu/async/http/__init__.py13
-rw-r--r--kombu/async/http/base.py9
-rw-r--r--kombu/async/http/curl.py36
-rw-r--r--kombu/async/hub.py9
-rw-r--r--kombu/connection.py4
-rw-r--r--kombu/tests/async/aws/__init__.py0
-rw-r--r--kombu/tests/async/aws/sqs/__init__.py0
-rw-r--r--kombu/tests/async/aws/sqs/test_connection.py310
-rw-r--r--kombu/tests/async/aws/sqs/test_message.py35
-rw-r--r--kombu/tests/async/aws/sqs/test_queue.py201
-rw-r--r--kombu/tests/async/aws/sqs/test_sqs.py25
-rw-r--r--kombu/tests/async/aws/test_aws.py14
-rw-r--r--kombu/tests/async/aws/test_connection.py427
-rw-r--r--kombu/tests/async/test_http.py65
-rw-r--r--kombu/tests/async/test_hub.py8
-rw-r--r--kombu/tests/case.py43
-rw-r--r--kombu/tests/test_compat.py6
-rw-r--r--kombu/tests/test_compression.py2
-rw-r--r--kombu/tests/test_connection.py10
-rw-r--r--kombu/tests/test_entities.py2
-rw-r--r--kombu/tests/test_log.py2
-rw-r--r--kombu/tests/test_messaging.py4
-rw-r--r--kombu/tests/test_mixins.py2
-rw-r--r--kombu/tests/test_pidbox.py2
-rw-r--r--kombu/tests/test_pools.py2
-rw-r--r--kombu/tests/test_simple.py4
-rw-r--r--kombu/tests/transport/test_SQS.py43
-rw-r--r--kombu/tests/transport/test_amqplib.py5
-rw-r--r--kombu/tests/transport/test_base.py13
-rw-r--r--kombu/tests/transport/test_filesystem.py2
-rw-r--r--kombu/tests/transport/test_librabbitmq.py6
-rw-r--r--kombu/tests/transport/test_memory.py2
-rw-r--r--kombu/tests/transport/test_pyamqp.py6
-rw-r--r--kombu/tests/transport/test_redis.py8
-rw-r--r--kombu/tests/transport/test_sqlalchemy.py2
-rw-r--r--kombu/tests/transport/virtual/test_base.py10
-rw-r--r--kombu/tests/transport/virtual/test_exchange.py6
-rw-r--r--kombu/tests/utils/test_encoding.py6
-rw-r--r--kombu/tests/utils/test_utils.py2
-rw-r--r--kombu/transport/SQS.py286
-rw-r--r--kombu/transport/amqplib.py6
-rw-r--r--kombu/transport/base.py43
-rw-r--r--kombu/transport/librabbitmq.py5
-rw-r--r--kombu/transport/mongodb.py4
-rw-r--r--kombu/transport/pyamqp.py7
-rw-r--r--kombu/transport/redis.py6
-rw-r--r--kombu/transport/virtual/__init__.py13
-rw-r--r--kombu/transport/zmq.py4
-rw-r--r--requirements/extras/sqs.txt2
56 files changed, 2041 insertions, 357 deletions
diff --git a/kombu/async/aws/__init__.py b/kombu/async/aws/__init__.py
new file mode 100644
index 00000000..43a01abd
--- /dev/null
+++ b/kombu/async/aws/__init__.py
@@ -0,0 +1,9 @@
+# -*- coding: utf-8 -*-
+from __future__ import absolute_import
+
+
+def connect_sqs(aws_access_key_id=None, aws_secret_access_key=None, **kwargs):
+ from .sqs.connection import AsyncSQSConnection
+ return AsyncSQSConnection(
+ aws_access_key_id, aws_secret_access_key, **kwargs
+ )
diff --git a/kombu/async/aws/connection.py b/kombu/async/aws/connection.py
new file mode 100644
index 00000000..532c8a7e
--- /dev/null
+++ b/kombu/async/aws/connection.py
@@ -0,0 +1,276 @@
+# -*- coding: utf-8 -*-
+from __future__ import absolute_import
+
+import mimetools
+
+from urlparse import urlunsplit
+from xml.sax import parseString as sax_parse
+
+from amqp.promise import promise, transform
+
+from kombu.async.http import Headers, Request, get_client
+from kombu.five import StringIO, items
+
+try:
+ import boto
+ from boto.connection import AWSAuthConnection, AWSQueryConnection
+ from boto.handler import XmlHandler
+ from boto.resultset import ResultSet
+except ImportError: # pragma: no cover
+ boto = ResultSet = XmlHandler = None # noqa
+
+ class _void(object):
+ pass
+ AWSAuthConnection, AWSQueryConnection = _void # noqa
+
+__all__ = ['AsyncHTTPConnection', 'AsyncHTTPSConnection',
+ 'AsyncHTTPResponse', 'AsyncConnection',
+ 'AsyncAWSAuthConnection', 'AsyncAWSQueryConnection']
+
+
+class AsyncHTTPResponse(object):
+
+ def __init__(self, response):
+ self.response = response
+ self._msg = None
+ self.version = 10
+
+ def read(self, *args, **kwargs):
+ return self.response.body
+
+ def getheader(self, name, default=None):
+ return self.response.headers.get(name, default)
+
+ def getheaders(self):
+ return list(items(self.response.headers))
+
+ @property
+ def msg(self):
+ if self._msg is None:
+ self._msg = mimetools.Message(
+ StringIO('\r\n'.join(
+ '{0}: {1}'.format(*h) for h in self.getheaders())
+ )
+ )
+ return self._msg
+
+ @property
+ def status(self):
+ return self.response.code
+
+ @property
+ def reason(self):
+ if self.response.error:
+ return self.response.error.message
+ return ''
+
+ def __repr__(self):
+ return repr(self.response)
+
+
+class AsyncHTTPConnection(object):
+ Request = Request
+ Response = AsyncHTTPResponse
+
+ method = 'GET'
+ path = '/'
+ body = None
+ scheme = 'http'
+ default_ports = {'http': 80, 'https': 443}
+
+ def __init__(self, host, port=None,
+ strict=None, timeout=20.0, http_client=None, **kwargs):
+ self.host = host
+ self.port = port
+ self.headers = []
+ self.timeout = timeout
+ self.strict = strict
+ self.http_client = http_client or get_client()
+
+ def request(self, method, path, body=None, headers=None):
+ self.path = path
+ self.method = method
+ if body is not None:
+ try:
+ read = body.read
+ except AttributeError:
+ self.body = body
+ else:
+ self.body = read()
+ if headers is not None:
+ self.headers.extend(list(items(headers)))
+
+ def getrequest(self, scheme=None):
+ scheme = scheme if scheme else self.scheme
+ host = self.host
+ if self.port and self.port != self.default_ports[scheme]:
+ host = '{0}:{1}'.format(host, self.port)
+ url = urlunsplit((scheme, host, self.path, '', ''))
+ headers = Headers(self.headers)
+ return self.Request(url, method=self.method, headers=headers,
+ body=self.body, connect_timeout=self.timeout,
+ request_timeout=self.timeout, validate_cert=False)
+
+ def getresponse(self, callback=None):
+ request = self.getrequest()
+ request.then(transform(self.Response, callback))
+ return self.http_client.add_request(request)
+
+ def set_debuglevel(self, level):
+ pass
+
+ def connect(self):
+ pass
+
+ def close(self):
+ pass
+
+ def putrequest(self, method, path, **kwargs):
+ self.method = method
+ self.path = path
+
+ def putheader(self, header, value):
+ self.headers.append((header, value))
+
+ def endheaders(self):
+ pass
+
+ def send(self, data):
+ if self.body:
+ self.body += data
+ else:
+ self.body = data
+
+ def __repr__(self):
+ return '<AsyncHTTPConnection: {0!r}>'.format(self.getrequest())
+
+
+class AsyncHTTPSConnection(AsyncHTTPConnection):
+ scheme = 'https'
+
+
+class AsyncConnection(object):
+
+ def __init__(self, http_client=None, **kwargs):
+ self._httpclient = http_client or get_client()
+
+ def get_http_connection(self, host, port, is_secure):
+ return (AsyncHTTPSConnection if is_secure else AsyncHTTPConnection)(
+ host, port, http_client=self._httpclient,
+ )
+
+ def _mexe(self, request, sender=None, callback=None):
+ callback = callback or promise()
+ boto.log.debug(
+ 'HTTP %s %s/%s headers=%s body=%s',
+ request.host, request.path,
+ request.headers, request.body,
+ )
+
+ conn = self.get_http_connection(
+ request.host, request.port, self.is_secure,
+ )
+ request.authorize(connection=self)
+
+ if callable(sender):
+ sender(conn, request.method, request.path, request.body,
+ request.headers, callback)
+ else:
+ conn.request(request.method, request.path,
+ request.body, request.headers)
+ conn.getresponse(callback=callback)
+ return callback
+
+
+class AsyncAWSAuthConnection(AsyncConnection, AWSAuthConnection):
+
+ def __init__(self, host,
+ http_client=None, http_client_params={}, **kwargs):
+ AsyncConnection.__init__(self, http_client, **http_client_params)
+ AWSAuthConnection.__init__(self, host, **kwargs)
+
+ def make_request(self, method, path, headers=None, data='', host=None,
+ auth_path=None, sender=None, callback=None, **kwargs):
+ req = self.build_base_http_request(
+ method, path, auth_path, {}, headers, data, host,
+ )
+ return self._mexe(req, sender=sender, callback=callback)
+
+
+class AsyncAWSQueryConnection(AsyncConnection, AWSQueryConnection):
+
+ def __init__(self, host,
+ http_client=None, http_client_params={}, **kwargs):
+ AsyncConnection.__init__(self, http_client, **http_client_params)
+ AWSAuthConnection.__init__(self, host, **kwargs)
+
+ def make_request(self, action, params, path, verb, callback=None):
+ request = self.build_base_http_request(
+ verb, path, None, params, {}, '', self.server_name())
+ if action:
+ request.params['Action'] = action
+ request.params['Version'] = self.APIVersion
+ return self._mexe(request, callback=callback)
+
+ def get_list(self, action, params, markers,
+ path='/', parent=None, verb='GET', callback=None):
+ return self.make_request(
+ action, params, path, verb,
+ callback=transform(
+ self._on_list_ready, callback, parent or self, markers,
+ ),
+ )
+
+ def get_object(self, action, params, cls,
+ path='/', parent=None, verb='GET', callback=None):
+ return self.make_request(
+ action, params, path, verb,
+ callback=transform(
+ self._on_obj_ready, callback, parent or self, cls,
+ ),
+ )
+
+ def get_status(self, action, params,
+ path='/', parent=None, verb='GET', callback=None):
+ return self.make_request(
+ action, params, path, verb,
+ callback=transform(
+ self._on_status_ready, callback, parent or self,
+ ),
+ )
+
+ def _on_list_ready(self, parent, markers, response):
+ body = response.read()
+ if response.status == 200 and body:
+ rs = ResultSet(markers)
+ h = XmlHandler(rs, parent)
+ sax_parse(body, h)
+ return rs
+ else:
+ raise self._for_status(response, body)
+
+ def _on_obj_ready(self, parent, cls, response):
+ body = response.read()
+ if response.status == 200 and body:
+ obj = cls(parent)
+ h = XmlHandler(obj, parent)
+ sax_parse(body, h)
+ return obj
+ else:
+ raise self._for_status(response, body)
+
+ def _on_status_ready(self, parent, response):
+ body = response.read()
+ if response.status == 200 and body:
+ rs = ResultSet()
+ h = XmlHandler(rs, parent)
+ sax_parse(body, h)
+ return rs.status
+ else:
+ raise self._for_status(response, body)
+
+ def _for_status(self, response, body):
+ context = 'Empty body' if not body else 'HTTP Error'
+ exc = self.ResponseError(response.status, response.reason, body)
+ boto.log.error('{0}: %r'.format(context), exc)
+ return exc
diff --git a/kombu/async/aws/sqs/__init__.py b/kombu/async/aws/sqs/__init__.py
new file mode 100644
index 00000000..d5fb1bbb
--- /dev/null
+++ b/kombu/async/aws/sqs/__init__.py
@@ -0,0 +1,18 @@
+# -*- coding: utf-8 -*-
+from __future__ import absolute_import
+
+from boto.regioninfo import get_regions
+
+from .connection import AsyncSQSConnection
+
+__all__ = ['regions', 'connect_to_region']
+
+
+def regions():
+ return get_regions('sqs', connection_cls=AsyncSQSConnection)
+
+
+def connect_to_region(region_name, **kwargs):
+ for region in regions():
+ if region.name == region_name:
+ return region.connect(**kwargs)
diff --git a/kombu/async/aws/sqs/connection.py b/kombu/async/aws/sqs/connection.py
new file mode 100644
index 00000000..4df9fbd3
--- /dev/null
+++ b/kombu/async/aws/sqs/connection.py
@@ -0,0 +1,197 @@
+# -*- coding: utf-8 -*-
+from __future__ import absolute_import
+
+from amqp.promise import transform
+
+from boto.regioninfo import RegionInfo
+from boto.sqs import connection as _connection
+from boto.sqs.attributes import Attributes
+from boto.sqs.batchresults import BatchResults
+
+from kombu.async.aws.connection import AsyncAWSQueryConnection
+
+from .message import AsyncMessage
+from .queue import AsyncQueue
+
+__all__ = ['AsyncSQSConnection']
+
+
+class AsyncSQSConnection(AsyncAWSQueryConnection, _connection.SQSConnection):
+
+ def __init__(self, aws_access_key_id=None, aws_secret_access_key=None,
+ is_secure=True, port=None, proxy=None, proxy_port=None,
+ proxy_user=None, proxy_pass=None, debug=0,
+ https_connection_factory=None, region=None, *args, **kwargs):
+ self.region = region or RegionInfo(
+ self, self.DefaultRegionName, self.DefaultRegionEndpoint,
+ connection_cls=type(self),
+ )
+ AsyncAWSQueryConnection.__init__(
+ self,
+ aws_access_key_id=aws_access_key_id,
+ aws_secret_access_key=aws_secret_access_key,
+ is_secure=is_secure, port=port,
+ proxy=proxy, proxy_port=proxy_port,
+ proxy_user=proxy_user, proxy_pass=proxy_pass,
+ host=self.region.endpoint, debug=debug,
+ https_connection_factory=https_connection_factory, **kwargs
+ )
+
+ def create_queue(self, queue_name,
+ visibility_timeout=None, callback=None):
+ params = {'QueueName': queue_name}
+ if visibility_timeout:
+ params['DefaultVisibilityTimeout'] = format(
+ visibility_timeout, 'd',
+ )
+ return self.get_object('CreateQueue', params, AsyncQueue,
+ callback=callback)
+
+ def delete_queue(self, queue, force_deletion=False, callback=None):
+ return self.get_status('DeleteQueue', None, queue.id,
+ callback=callback)
+
+ def get_queue_attributes(self, queue, attribute='All', callback=None):
+ return self.get_object(
+ 'GetQueueAttributes', {'AttributeName': attribute},
+ Attributes, queue.id, callback=callback,
+ )
+
+ def set_queue_attribute(self, queue, attribute, value, callback=None):
+ return self.get_status(
+ 'SetQueueAttribute',
+ {'Attribute.Name': attribute, 'Attribute.Value': value},
+ queue.id, callback=callback,
+ )
+
+ def receive_message(self, queue,
+ number_messages=1, visibility_timeout=None,
+ attributes=None, wait_time_seconds=None,
+ callback=None):
+ params = {'MaxNumberOfMessages': number_messages}
+ if visibility_timeout:
+ params['VisibilityTimeout'] = visibility_timeout
+ if attributes:
+ self.build_list_params(params, attributes, 'AttributeName')
+ if wait_time_seconds is not None:
+ params['WaitTimeSeconds'] = wait_time_seconds
+ return self.get_list(
+ 'ReceiveMessage', params, [('Message', queue.message_class)],
+ queue.id, callback=callback,
+ )
+
+ def delete_message(self, queue, message, callback=None):
+ return self.delete_message_from_handle(
+ queue, message.receipt_handle, callback,
+ )
+
+ def delete_message_batch(self, queue, messages, callback=None):
+ params = {}
+ for i, m in enumerate(messages):
+ prefix = 'DeleteMessageBatchRequestEntry.{0}'.format(i + 1)
+ params.update({
+ '{0}.Id'.format(prefix): m.id,
+ '{0}.ReceiptHandle'.format(prefix): m.receipt_handle,
+ })
+ return self.get_object(
+ 'DeleteMessageBatch', params, BatchResults, queue.id,
+ verb='POST', callback=callback,
+ )
+
+ def delete_message_from_handle(self, queue, receipt_handle,
+ callback=None):
+ return self.get_status(
+ 'DeleteMessage', {'ReceiptHandle': receipt_handle},
+ queue.id, callback=callback,
+ )
+
+ def send_message(self, queue, message_content,
+ delay_seconds=None, callback=None):
+ params = {'MessageBody': message_content}
+ if delay_seconds:
+ params['DelaySeconds'] = int(delay_seconds)
+ return self.get_object(
+ 'SendMessage', params, AsyncMessage, queue.id,
+ verb='POST', callback=callback,
+ )
+
+ def send_message_batch(self, queue, messages, callback=None):
+ params = {}
+ for i, msg in enumerate(messages):
+ prefix = 'SendMessageBatchRequestEntry.{0}'.format(i + 1)
+ params.update({
+ '{0}.Id'.format(prefix): msg[0],
+ '{0}.MessageBody'.format(prefix): msg[1],
+ '{0}.DelaySeconds'.format(prefix): msg[2],
+ })
+ return self.get_object(
+ 'SendMessageBatch', params, BatchResults, queue.id,
+ verb='POST', callback=callback,
+ )
+
+ def change_message_visibility(self, queue, receipt_handle,
+ visibility_timeout, callback=None):
+ return self.get_status(
+ 'ChangeMessageVisibility',
+ {'ReceiptHandle': receipt_handle,
+ 'VisibilityTimeout': visibility_timeout},
+ queue.id, callback=callback,
+ )
+
+ def change_message_visibility_batch(self, queue, messages, callback=None):
+ params = {}
+ for i, t in enumerate(messages):
+ pre = 'ChangeMessageVisibilityBatchRequestEntry.{0}'.format(i + 1)
+ params.update({
+ '{0}.Id'.format(pre): t[0].id,
+ '{0}.ReceiptHandle'.format(pre): t[0].receipt_handle,
+ '{0}.VisibilityTimeout'.format(pre): t[1],
+ })
+ return self.get_object(
+ 'ChangeMessageVisibilityBatch', params, BatchResults, queue.id,
+ verb='POST', callback=callback,
+ )
+
+ def get_all_queues(self, prefix='', callback=None):
+ params = {}
+ if prefix:
+ params['QueueNamePrefix'] = prefix
+ return self.get_list(
+ 'ListQueues', params, [('QueueUrl', AsyncQueue)],
+ callback=callback,
+ )
+
+ def get_queue(self, queue_name, callback=None):
+ # TODO Does not support owner_acct_id argument
+ return self.get_all_queues(
+ queue_name,
+ transform(self._on_queue_ready, callback, queue_name),
+ )
+ lookup = get_queue
+
+ def _on_queue_ready(self, name, queues):
+ return next(
+ (q for q in queues if q.url.endswith(name)), None,
+ )
+
+ def get_dead_letter_source_queues(self, queue, callback=None):
+ return self.get_list(
+ 'ListDeadLetterSourceQueues', {'QueueUrl': queue.url},
+ [('QueueUrl', AsyncQueue)],
+ callback=callback,
+ )
+
+ def add_permission(self, queue, label, aws_account_id, action_name,
+ callback=None):
+ return self.get_status(
+ 'AddPermission',
+ {'Label': label,
+ 'AWSAccountId': aws_account_id,
+ 'ActionName': action_name},
+ queue.id, callback=callback,
+ )
+
+ def remove_permission(self, queue, label, callback=None):
+ return self.get_status(
+ 'RemovePermission', {'Label': label}, queue.id, callback=callback,
+ )
diff --git a/kombu/async/aws/sqs/jsonmessage.py b/kombu/async/aws/sqs/jsonmessage.py
new file mode 100644
index 00000000..e5ed0e7f
--- /dev/null
+++ b/kombu/async/aws/sqs/jsonmessage.py
@@ -0,0 +1,10 @@
+# -*- coding: utf-8 -*-
+from __future__ import absolute_import
+
+from boto.sqs import jsonmessage as _jsonmessage
+
+from .message import BaseAsyncMessage
+
+
+class AsyncJSONMessage(BaseAsyncMessage, _jsonmessage.JSONMessage):
+ pass
diff --git a/kombu/async/aws/sqs/message.py b/kombu/async/aws/sqs/message.py
new file mode 100644
index 00000000..8141b6d2
--- /dev/null
+++ b/kombu/async/aws/sqs/message.py
@@ -0,0 +1,36 @@
+# -*- coding: utf-8 -*-
+from __future__ import absolute_import
+
+from boto.sqs import message as _message
+
+__all__ = ['BaseAsyncMessage', 'AsyncRawMessage', 'AsyncMessage',
+ 'AsyncMHMessage', 'AsyncEncodedMHMessage']
+
+
+class BaseAsyncMessage(object):
+
+ def delete(self, callback=None):
+ if self.queue:
+ return self.queue.delete_message(self, callback)
+
+ def change_visibility(self, visibility_timeout, callback=None):
+ if self.queue:
+ return self.queue.connection.change_message_visibility(
+ self.queue, self.receipt_handle, visibility_timeout, callback,
+ )
+
+
+class AsyncRawMessage(BaseAsyncMessage, _message.RawMessage):
+ pass
+
+
+class AsyncMessage(BaseAsyncMessage, _message.Message):
+ pass
+
+
+class AsyncMHMessage(BaseAsyncMessage, _message.MHMessage):
+ pass
+
+
+class AsyncEncodedMHMessage(BaseAsyncMessage, _message.EncodedMHMessage):
+ pass
diff --git a/kombu/async/aws/sqs/queue.py b/kombu/async/aws/sqs/queue.py
new file mode 100644
index 00000000..038c6a01
--- /dev/null
+++ b/kombu/async/aws/sqs/queue.py
@@ -0,0 +1,130 @@
+# -*- coding: utf-8 -*-
+from __future__ import absolute_import
+
+from amqp.promise import transform
+
+from boto.exception import BotoClientError
+from boto.sqs import queue as _queue
+
+from .message import AsyncMessage
+
+_all__ = ['AsyncQueue']
+
+
+def list_first(rs):
+ return rs[0] if len(rs) == 1 else None
+
+
+class AsyncQueue(_queue.Queue):
+
+ def __init__(self, connection=None, url=None, message_class=AsyncMessage):
+ self.connection = connection
+ self.url = url
+ self.message_class = message_class
+ self.visibility_timeout = None
+
+ def _NA(self, *args, **kwargs):
+ raise BotoClientError('Not implemented')
+ count_slow = dump = save_to_file = save_to_filename = save = \
+ save_to_s3 = load_from_s3 = load_from_file = load_from_filename = \
+ load = clear = _NA
+
+ def get_attributes(self, attributes='All', callback=None):
+ return self.connection.get_queue_attributes(
+ self, attributes, callback,
+ )
+
+ def set_attribute(self, attribute, value, callback=None):
+ return self.connection.set_queue_attribute(
+ self, attribute, value, callback,
+ )
+
+ def get_timeout(self, callback=None, _attr='VisibilityTimeout'):
+ return self.get_attributes(
+ _attr, transform(
+ self._coerce_field_value, callback, _attr, int,
+ ),
+ )
+
+ def _coerce_field_value(self, key, type, response):
+ return type(response[key])
+
+ def set_timeout(self, visibility_timeout, callback=None):
+ return self.set_attribute(
+ 'VisibilityTimeout', visibility_timeout,
+ transform(
+ self._on_timeout_set, callback,
+ )
+ )
+
+ def _on_timeout_set(self, visibility_timeout):
+ if visibility_timeout:
+ self.visibility_timeout = visibility_timeout
+ return self.visibility_timeout
+
+ def add_permission(self, label, aws_account_id, action_name,
+ callback=None):
+ return self.connection.add_permission(
+ self, label, aws_account_id, action_name, callback,
+ )
+
+ def remove_permission(self, label, callback=None):
+ return self.connection.remove_permission(self, label, callback)
+
+ def read(self, visibility_timeout=None, wait_time_seconds=None,
+ callback=None):
+ return self.get_messages(
+ 1, visibility_timeout,
+ wait_time_seconds=wait_time_seconds,
+ callback=transform(list_first, callback),
+ )
+
+ def write(self, message, delay_seconds=None, callback=None):
+ return self.connection.send_message(
+ self, message.get_body_encoded(), delay_seconds,
+ callback=transform(self._on_message_sent, callback, message),
+ )
+
+ def write_batch(self, messages, callback=None):
+ return self.connection.send_message_batch(
+ self, messages, callback=callback,
+ )
+
+ def _on_message_sent(self, orig_message, new_message):
+ orig_message.id = new_message.id
+ orig_message.md5 = new_message.md5
+ return new_message
+
+ def get_messages(self, num_messages=1, visibility_timeout=None,
+ attributes=None, wait_time_seconds=None, callback=None):
+ return self.connection.receive_message(
+ self, number_messages=num_messages,
+ visibility_timeout=visibility_timeout,
+ attributes=attributes,
+ wait_time_seconds=wait_time_seconds,
+ callback=callback,
+ )
+
+ def delete_message(self, message, callback=None):
+ return self.connection.delete_message(self, message, callback)
+
+ def delete_message_batch(self, messages, callback=None):
+ return self.connection.delete_message_batch(
+ self, messages, callback=callback,
+ )
+
+ def change_message_visibility_batch(self, messages, callback=None):
+ return self.connection.change_message_visibility_batch(
+ self, messages, callback=callback,
+ )
+
+ def delete(self, callback=None):
+ return self.connection.delete_queue(self, callback=callback)
+
+ def count(self, page_size=10, vtimeout=10, callback=None,
+ _attr='ApproximateNumberOfMessages'):
+ return self.get_attributes(
+ _attr, callback=transform(
+ self._coerce_field_value, callback, _attr, int,
+ ),
+ )
diff --git a/kombu/async/http/__init__.py b/kombu/async/http/__init__.py
index 8c18dfa9..d1b3deda 100644
--- a/kombu/async/http/__init__.py
+++ b/kombu/async/http/__init__.py
@@ -1,10 +1,21 @@
from __future__ import absolute_import
+from kombu.async import get_event_loop
+
from .base import Request, Headers, Response
__all__ = ['Client', 'Headers', 'Response', 'Request']
-def Client(hub, **kwargs):
+def Client(hub=None, **kwargs):
from .curl import CurlClient
return CurlClient(hub, **kwargs)
+
+
+def get_client(hub=None, **kwargs):
+ hub = hub or get_event_loop()
+ try:
+ return hub._current_http_client
+ except AttributeError:
+ client = hub._current_http_client = Client(hub, **kwargs)
+ return client
diff --git a/kombu/async/http/base.py b/kombu/async/http/base.py
index 7a3dfd28..35fbd831 100644
--- a/kombu/async/http/base.py
+++ b/kombu/async/http/base.py
@@ -6,7 +6,7 @@ from amqp.promise import Thenable, promise, maybe_promise
from kombu.exceptions import HttpError
from kombu.five import items
-from kombu.utils import cached_property, coro
+from kombu.utils import coro
from kombu.utils.encoding import bytes_to_str
from kombu.utils.functional import maybe_list, memoize
@@ -55,6 +55,9 @@ class Request(object):
default).
:keyword validate_cert: Set to true if the server certificate should be
verified when performing ``https://`` requests (enabled by default).
+ :keyword auth_username: Username for HTTP authentication.
+ :keyword auth_password: Password for HTTP authentication.
+ :keyword auth_mode: Type of HTTP authentication (``basic`` or ``digest``).
:keyword user_agent: Custom user agent for this request.
:keyword network_interace: Network interface to use for this request.
:keyword on_ready: Callback to be called when the response has been
@@ -84,6 +87,7 @@ class Request(object):
"""
body = user_agent = network_interface = \
+ auth_username = auth_password = auth_mode = \
proxy_host = proxy_port = proxy_username = proxy_password = \
ca_certs = client_key = client_cert = None
@@ -118,6 +122,9 @@ class Request(object):
def then(self, callback, errback=None):
self.on_ready.then(callback, errback)
+
+ def __repr__(self):
+ return '<Request: {0.method} {0.url} {0.body}>'.format(self)
Thenable.register(Request)
diff --git a/kombu/async/http/curl.py b/kombu/async/http/curl.py
index b4c516a5..d76c8e75 100644
--- a/kombu/async/http/curl.py
+++ b/kombu/async/http/curl.py
@@ -7,7 +7,7 @@ from functools import partial
from io import BytesIO
from time import time
-from kombu.async.hub import READ, WRITE
+from kombu.async.hub import READ, WRITE, get_event_loop
from kombu.exceptions import HttpError
from kombu.five import items
from kombu.utils.encoding import bytes_to_str
@@ -37,9 +37,10 @@ EXTRA_METHODS = frozenset(['DELETE', 'OPTIONS', 'PATCH'])
class CurlClient(BaseClient):
Curl = Curl
- def __init__(self, hub, max_clients=10):
+ def __init__(self, hub=None, max_clients=10):
if pycurl is None:
raise ImportError('The curl client requires the pycurl library.')
+ hub = hub or get_event_loop()
super(CurlClient, self).__init__(hub)
self.max_clients = max_clients
@@ -71,6 +72,7 @@ class CurlClient(BaseClient):
self._pending.append(request)
self._process_queue()
self._set_timeout(0)
+ return request
def _handle_socket(self, event, fd, multi, data, _pycurl=pycurl):
if event == _pycurl.POLL_REMOVE:
@@ -186,9 +188,9 @@ class CurlClient(BaseClient):
request.headers.setdefault('Expect', '')
request.headers.setdefault('Pragma', '')
- curl.setopt(
+ setopt(
_pycurl.HTTPHEADER,
- ['%s %s'.format(h) for h in items(request.headers)],
+ ['{0}: {1}'.format(*h) for h in items(request.headers)],
)
setopt(
@@ -240,8 +242,8 @@ class CurlClient(BaseClient):
setopt(meth, True)
if request.method in ('POST', 'PUT'):
- assert request.body is not None
- reqbuffer = BytesIO(request.body)
+ body = request.body or ''
+ reqbuffer = BytesIO(body)
setopt(_pycurl.READFUNCTION, reqbuffer.read)
if request.method == 'POST':
@@ -249,14 +251,24 @@ class CurlClient(BaseClient):
if cmd == _pycurl.IOCMD_RESTARTREAD:
reqbuffer.seek(0)
setopt(_pycurl.IOCTLFUNCTION, ioctl)
- setopt(_pycurl.POSTFIELDSIZE, len(request.body))
+ setopt(_pycurl.POSTFIELDSIZE, len(body))
else:
- setopt(_pycurl.INFILESIZE, len(request.body))
+ setopt(_pycurl.INFILESIZE, len(body))
elif request.method == 'GET':
- assert request.body is None
-
- # TODO Does not support Basic AUTH
- curl.unsetopt(_pycurl.USERPWD)
+ assert not request.body
+
+ if request.auth_username is not None:
+ auth_mode = {
+ 'basic': _pycurl.HTTPAUTH_BASIC,
+ 'digest': _pycurl.HTTPAUTH_DIGEST
+ }[request.auth_mode or 'basic']
+ setopt(_pycurl.HTTPAUTH, auth_mode)
+ userpwd = '{0}:{1}'.format(
+ request.auth_username, request.auth_password or '',
+ )
+ setopt(_pycurl.USERPWD, userpwd)
+ else:
+ curl.unsetopt(_pycurl.USERPWD)
if request.client_cert is not None:
setopt(_pycurl.SSLCERT, request.client_cert)
diff --git a/kombu/async/hub.py b/kombu/async/hub.py
index 264f8188..25c71cdb 100644
--- a/kombu/async/hub.py
+++ b/kombu/async/hub.py
@@ -15,7 +15,7 @@ from contextlib import contextmanager
from time import sleep
from types import GeneratorType as generator
-from amqp import promise
+from amqp.promise import promise, Thenable
from kombu.five import Empty, range
from kombu.log import get_logger
@@ -184,9 +184,10 @@ class Hub(object):
self._loop = None
def call_soon(self, callback, *args):
- handle = promise(callback, args)
- self._ready.append(handle)
- return handle
+ if not isinstance(callback, Thenable):
+ callback = promise(callback, args)
+ self._ready.append(callback)
+ return callback
def call_later(self, delay, callback, *args):
return self.timer.call_after(delay, callback, args)
diff --git a/kombu/connection.py b/kombu/connection.py
index 85b8f5e9..873b422b 100644
--- a/kombu/connection.py
+++ b/kombu/connection.py
@@ -829,11 +829,11 @@ class Connection(object):
@property
def supports_heartbeats(self):
- return self.transport.supports_heartbeats
+ return self.transport.implements.heartbeats
@property
def is_evented(self):
- return self.transport.supports_ev
+ return self.transport.implements.async
BrokerConnection = Connection
diff --git a/kombu/tests/async/aws/__init__.py b/kombu/tests/async/aws/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/kombu/tests/async/aws/__init__.py
diff --git a/kombu/tests/async/aws/sqs/__init__.py b/kombu/tests/async/aws/sqs/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/kombu/tests/async/aws/sqs/__init__.py
diff --git a/kombu/tests/async/aws/sqs/test_connection.py b/kombu/tests/async/aws/sqs/test_connection.py
new file mode 100644
index 00000000..1ea6ba09
--- /dev/null
+++ b/kombu/tests/async/aws/sqs/test_connection.py
@@ -0,0 +1,310 @@
+# -*- coding: utf-8 -*-
+from __future__ import absolute_import
+
+from kombu.async.aws.sqs.connection import (
+ AsyncSQSConnection, Attributes, BatchResults,
+)
+from kombu.async.aws.sqs.message import AsyncMessage
+from kombu.async.aws.sqs.queue import AsyncQueue
+from kombu.utils import uuid
+
+from kombu.tests.case import HubCase, PromiseMock, Mock
+
+
+class test_AsyncSQSConnection(HubCase):
+
+ def setup(self):
+ self.x = AsyncSQSConnection('ak', 'sk')
+ self.x.get_object = Mock(name='X.get_object')
+ self.x.get_status = Mock(name='X.get_status')
+ self.x.get_list = Mock(nanme='X.get_list')
+ self.callback = PromiseMock(name='callback')
+
+ def test_default_region(self):
+ self.assertTrue(self.x.region)
+ self.assertTrue(issubclass(
+ self.x.region.connection_cls, AsyncSQSConnection,
+ ))
+
+ def test_create_queue(self):
+ self.x.create_queue('foo', callback=self.callback)
+ self.x.get_object.assert_called_with(
+ 'CreateQueue', {'QueueName': 'foo'}, AsyncQueue,
+ callback=self.callback,
+ )
+
+ def test_create_queue__with_visibility_timeout(self):
+ self.x.create_queue(
+ 'foo', visibility_timeout=33, callback=self.callback,
+ )
+ self.x.get_object.assert_called_with(
+ 'CreateQueue', {
+ 'QueueName': 'foo',
+ 'DefaultVisibilityTimeout': '33'
+ },
+ AsyncQueue, callback=self.callback
+ )
+
+ def test_delete_queue(self):
+ queue = Mock(name='queue')
+ self.x.delete_queue(queue, callback=self.callback)
+ self.x.get_status.assert_called_with(
+ 'DeleteQueue', None, queue.id, callback=self.callback,
+ )
+
+ def test_get_queue_attributes(self):
+ queue = Mock(name='queue')
+ self.x.get_queue_attributes(
+ queue, attribute='QueueSize', callback=self.callback,
+ )
+ self.x.get_object.assert_called_with(
+ 'GetQueueAttributes', {'AttributeName': 'QueueSize'},
+ Attributes, queue.id, callback=self.callback,
+ )
+
+ def test_set_queue_attribute(self):
+ queue = Mock(name='queue')
+ self.x.set_queue_attribute(
+ queue, 'Expires', '3600', callback=self.callback,
+ )
+ self.x.get_status.assert_called_with(
+ 'SetQueueAttribute', {
+ 'Attribute.Name': 'Expires',
+ 'Attribute.Value': '3600',
+ },
+ queue.id, callback=self.callback,
+ )
+
+ def test_receive_message(self):
+ queue = Mock(name='queue')
+ self.x.receive_message(queue, 4, callback=self.callback)
+ self.x.get_list.assert_called_with(
+ 'ReceiveMessage', {'MaxNumberOfMessages': 4},
+ [('Message', queue.message_class)],
+ queue.id, callback=self.callback,
+ )
+
+ def test_receive_message__with_visibility_timeout(self):
+ queue = Mock(name='queue')
+ self.x.receive_message(queue, 4, 3666, callback=self.callback)
+ self.x.get_list.assert_called_with(
+ 'ReceiveMessage', {
+ 'MaxNumberOfMessages': 4,
+ 'VisibilityTimeout': 3666,
+ },
+ [('Message', queue.message_class)],
+ queue.id, callback=self.callback,
+ )
+
+ def test_receive_message__with_wait_time_seconds(self):
+ queue = Mock(name='queue')
+ self.x.receive_message(
+ queue, 4, wait_time_seconds=303, callback=self.callback,
+ )
+ self.x.get_list.assert_called_with(
+ 'ReceiveMessage', {
+ 'MaxNumberOfMessages': 4,
+ 'WaitTimeSeconds': 303,
+ },
+ [('Message', queue.message_class)],
+ queue.id, callback=self.callback,
+ )
+
+ def test_receive_message__with_attributes(self):
+ queue = Mock(name='queue')
+ self.x.receive_message(
+ queue, 4, attributes=['foo', 'bar'], callback=self.callback,
+ )
+ self.x.get_list.assert_called_with(
+ 'ReceiveMessage', {
+ 'AttributeName.1': 'foo',
+ 'AttributeName.2': 'bar',
+ 'MaxNumberOfMessages': 4,
+ },
+ [('Message', queue.message_class)],
+ queue.id, callback=self.callback,
+ )
+
+ def MockMessage(self, id=None, receipt_handle=None, body=None):
+ m = Mock(name='message')
+ m.id = id or uuid()
+ m.receipt_handle = receipt_handle or uuid()
+ m._body = body
+
+ def _get_body():
+ return m._body
+ m.get_body.side_effect = _get_body
+
+ def _set_body(value):
+ m._body = value
+ m.set_body.side_effect = _set_body
+
+ return m
+
+ def test_delete_message(self):
+ queue = Mock(name='queue')
+ message = self.MockMessage()
+ self.x.delete_message(queue, message, callback=self.callback)
+ self.x.get_status.assert_called_with(
+ 'DeleteMessage', {'ReceiptHandle': message.receipt_handle},
+ queue.id, callback=self.callback,
+ )
+
+ def test_delete_message_batch(self):
+ queue = Mock(name='queue')
+ messages = [self.MockMessage('1', 'r1'),
+ self.MockMessage('2', 'r2')]
+ self.x.delete_message_batch(queue, messages, callback=self.callback)
+ self.x.get_object.assert_called_with(
+ 'DeleteMessageBatch', {
+ 'DeleteMessageBatchRequestEntry.1.Id': '1',
+ 'DeleteMessageBatchRequestEntry.1.ReceiptHandle': 'r1',
+ 'DeleteMessageBatchRequestEntry.2.Id': '2',
+ 'DeleteMessageBatchRequestEntry.2.ReceiptHandle': 'r2',
+ },
+ BatchResults, queue.id, verb='POST', callback=self.callback,
+ )
+
+ def test_send_message(self):
+ queue = Mock(name='queue')
+ self.x.send_message(queue, 'hello', callback=self.callback)
+ self.x.get_object.assert_called_with(
+ 'SendMessage', {'MessageBody': 'hello'},
+ AsyncMessage, queue.id, verb='POST', callback=self.callback,
+ )
+
+ def test_send_message__with_delay_seconds(self):
+ queue = Mock(name='queue')
+ self.x.send_message(
+ queue, 'hello', delay_seconds='303', callback=self.callback,
+ )
+ self.x.get_object.assert_called_with(
+ 'SendMessage', {'MessageBody': 'hello', 'DelaySeconds': 303},
+ AsyncMessage, queue.id, verb='POST', callback=self.callback,
+ )
+
+ def test_send_message_batch(self):
+ queue = Mock(name='queue')
+ messages = [self.MockMessage('1', 'r1', 'A'),
+ self.MockMessage('2', 'r2', 'B')]
+ self.x.send_message_batch(
+ queue, [(m.id, m.get_body(), 303) for m in messages],
+ callback=self.callback
+ )
+ self.x.get_object.assert_called_with(
+ 'SendMessageBatch', {
+ 'SendMessageBatchRequestEntry.1.Id': '1',
+ 'SendMessageBatchRequestEntry.1.MessageBody': 'A',
+ 'SendMessageBatchRequestEntry.1.DelaySeconds': 303,
+ 'SendMessageBatchRequestEntry.2.Id': '2',
+ 'SendMessageBatchRequestEntry.2.MessageBody': 'B',
+ 'SendMessageBatchRequestEntry.2.DelaySeconds': 303,
+ },
+ BatchResults, queue.id, verb='POST', callback=self.callback,
+ )
+
+ def test_change_message_visibility(self):
+ queue = Mock(name='queue')
+ self.x.change_message_visibility(
+ queue, 'rcpt', 33, callback=self.callback,
+ )
+ self.x.get_status.assert_called_with(
+ 'ChangeMessageVisibility', {
+ 'ReceiptHandle': 'rcpt',
+ 'VisibilityTimeout': 33,
+ },
+ queue.id, callback=self.callback,
+ )
+
+ def test_change_message_visibility_batch(self):
+ queue = Mock(name='queue')
+ messages = [
+ (self.MockMessage('1', 'r1'), 303),
+ (self.MockMessage('2', 'r2'), 909),
+ ]
+ self.x.change_message_visibility_batch(
+ queue, messages, callback=self.callback,
+ )
+
+ def preamble(n):
+ return '.'.join(['ChangeMessageVisibilityBatchRequestEntry', n])
+
+ self.x.get_object.assert_called_with(
+ 'ChangeMessageVisibilityBatch', {
+ preamble('1.Id'): '1',
+ preamble('1.ReceiptHandle'): 'r1',
+ preamble('1.VisibilityTimeout'): 303,
+ preamble('2.Id'): '2',
+ preamble('2.ReceiptHandle'): 'r2',
+ preamble('2.VisibilityTimeout'): 909,
+ },
+ BatchResults, queue.id, verb='POST', callback=self.callback,
+ )
+
+ def test_get_all_queues(self):
+ self.x.get_all_queues(callback=self.callback)
+ self.x.get_list.assert_called_with(
+ 'ListQueues', {}, [('QueueUrl', AsyncQueue)],
+ callback=self.callback,
+ )
+
+ def test_get_all_queues__with_prefix(self):
+ self.x.get_all_queues(prefix='kombu.', callback=self.callback)
+ self.x.get_list.assert_called_with(
+ 'ListQueues', {'QueueNamePrefix': 'kombu.'},
+ [('QueueUrl', AsyncQueue)],
+ callback=self.callback,
+ )
+
+ def MockQueue(self, url):
+ q = Mock(name='Queue')
+ q.url = url
+ return q
+
+ def test_get_queue(self):
+ self.x.get_queue('foo', callback=self.callback)
+ self.assertTrue(self.x.get_list.called)
+ on_ready = self.x.get_list.call_args[1]['callback']
+ queues = [
+ self.MockQueue('/queues/bar'),
+ self.MockQueue('/queues/baz'),
+ self.MockQueue('/queues/foo'),
+ ]
+ on_ready(queues)
+ self.callback.assert_called_with(queues[-1])
+
+ self.x.get_list.assert_called_with(
+ 'ListQueues', {'QueueNamePrefix': 'foo'},
+ [('QueueUrl', AsyncQueue)],
+ callback=on_ready,
+ )
+
+ def test_get_dead_letter_source_queues(self):
+ queue = Mock(name='queue')
+ self.x.get_dead_letter_source_queues(queue, callback=self.callback)
+ self.x.get_list.assert_called_with(
+ 'ListDeadLetterSourceQueues', {'QueueUrl': queue.url},
+ [('QueueUrl', AsyncQueue)], callback=self.callback,
+ )
+
+ def test_add_permission(self):
+ queue = Mock(name='queue')
+ self.x.add_permission(
+ queue, 'label', 'accid', 'action', callback=self.callback,
+ )
+ self.x.get_status.assert_called_with(
+ 'AddPermission', {
+ 'Label': 'label',
+ 'AWSAccountId': 'accid',
+ 'ActionName': 'action',
+ },
+ queue.id, callback=self.callback,
+ )
+
+ def test_remove_permission(self):
+ queue = Mock(name='queue')
+ self.x.remove_permission(queue, 'label', callback=self.callback)
+ self.x.get_status.assert_called_with(
+ 'RemovePermission', {'Label': 'label'}, queue.id,
+ callback=self.callback,
+ )
diff --git a/kombu/tests/async/aws/sqs/test_message.py b/kombu/tests/async/aws/sqs/test_message.py
new file mode 100644
index 00000000..e7f32f02
--- /dev/null
+++ b/kombu/tests/async/aws/sqs/test_message.py
@@ -0,0 +1,35 @@
+# -*- coding: utf-8 -*-
+from __future__ import absolute_import
+
+from kombu.async.aws.sqs.message import AsyncMessage
+
+from kombu.tests.case import HubCase, PromiseMock, Mock
+from kombu.utils import uuid
+
+
+class test_AsyncMessage(HubCase):
+
+ def setup(self):
+ self.queue = Mock(name='queue')
+ self.callback = PromiseMock(name='callback')
+ self.x = AsyncMessage(self.queue, 'body')
+ self.x.receipt_handle = uuid()
+
+ def test_delete(self):
+ self.assertTrue(self.x.delete(callback=self.callback))
+ self.x.queue.delete_message.assert_called_with(
+ self.x, self.callback,
+ )
+
+ self.x.queue = None
+ self.assertIsNone(self.x.delete(callback=self.callback))
+
+ def test_change_visibility(self):
+ self.assertTrue(self.x.change_visibility(303, callback=self.callback))
+ self.x.queue.connection.change_message_visibility.assert_called_with(
+ self.x.queue, self.x.receipt_handle, 303, self.callback,
+ )
+ self.x.queue = None
+ self.assertIsNone(self.x.change_visibility(
+ 303, callback=self.callback,
+ ))
diff --git a/kombu/tests/async/aws/sqs/test_queue.py b/kombu/tests/async/aws/sqs/test_queue.py
new file mode 100644
index 00000000..417b09d3
--- /dev/null
+++ b/kombu/tests/async/aws/sqs/test_queue.py
@@ -0,0 +1,201 @@
+# -*- coding: utf-8 -*-
+from __future__ import absolute_import
+
+from kombu.async.aws.sqs.message import AsyncMessage
+from kombu.async.aws.sqs.queue import AsyncQueue, BotoClientError
+
+from kombu.tests.case import HubCase, PromiseMock, Mock
+
+
+class test_AsyncQueue(HubCase):
+
+ def setup(self):
+ self.conn = Mock(name='connection')
+ self.x = AsyncQueue(self.conn, '/url')
+ self.callback = PromiseMock(name='callback')
+
+ def test_message_class(self):
+ self.assertTrue(issubclass(self.x.message_class, AsyncMessage))
+
+ def test_get_attributes(self):
+ self.x.get_attributes(attributes='QueueSize', callback=self.callback)
+ self.x.connection.get_queue_attributes.assert_called_with(
+ self.x, 'QueueSize', self.callback,
+ )
+
+ def test_set_attribute(self):
+ self.x.set_attribute('key', 'value', callback=self.callback)
+ self.x.connection.set_queue_attribute.assert_called_with(
+ self.x, 'key', 'value', self.callback,
+ )
+
+ def test_get_timeout(self):
+ self.x.get_timeout(callback=self.callback)
+ self.assertTrue(self.x.connection.get_queue_attributes.called)
+ on_ready = self.x.connection.get_queue_attributes.call_args[0][2]
+ self.x.connection.get_queue_attributes.assert_called_with(
+ self.x, 'VisibilityTimeout', on_ready,
+ )
+
+ on_ready({'VisibilityTimeout': '303'})
+ self.callback.assert_called_with(303)
+
+ def test_set_timeout(self):
+ self.x.set_timeout(808, callback=self.callback)
+ self.assertTrue(self.x.connection.set_queue_attribute.called)
+ on_ready = self.x.connection.set_queue_attribute.call_args[0][3]
+ self.x.connection.set_queue_attribute.assert_called_with(
+ self.x, 'VisibilityTimeout', 808, on_ready,
+ )
+ on_ready(808)
+ self.callback.assert_called_with(808)
+ self.assertEqual(self.x.visibility_timeout, 808)
+
+ on_ready(None)
+ self.assertEqual(self.x.visibility_timeout, 808)
+
+ def test_add_permission(self):
+ self.x.add_permission(
+ 'label', 'accid', 'action', callback=self.callback,
+ )
+ self.x.connection.add_permission.assert_called_with(
+ self.x, 'label', 'accid', 'action', self.callback,
+ )
+
+ def test_remove_permission(self):
+ self.x.remove_permission('label', callback=self.callback)
+ self.x.connection.remove_permission.assert_called_with(
+ self.x, 'label', self.callback,
+ )
+
+ def test_read(self):
+ self.x.read(visibility_timeout=909, callback=self.callback)
+ self.assertTrue(self.x.connection.receive_message.called)
+ on_ready = self.x.connection.receive_message.call_args[1]['callback']
+ self.x.connection.receive_message.assert_called_with(
+ self.x, number_messages=1, visibility_timeout=909,
+ attributes=None, wait_time_seconds=None, callback=on_ready,
+ )
+
+ messages = [Mock(name='message1')]
+ on_ready(messages)
+
+ self.callback.assert_called_with(messages[0])
+
+ def MockMessage(self, id, md5):
+ m = Mock(name='Message-{0}'.format(id))
+ m.id = id
+ m.md5 = md5
+ return m
+
+ def test_write(self):
+ message = self.MockMessage('id1', 'digest1')
+ self.x.write(message, delay_seconds=303, callback=self.callback)
+ self.assertTrue(self.x.connection.send_message.called)
+ on_ready = self.x.connection.send_message.call_args[1]['callback']
+ self.x.connection.send_message.assert_called_with(
+ self.x, message.get_body_encoded(), 303,
+ callback=on_ready,
+ )
+
+ new_message = self.MockMessage('id2', 'digest2')
+ on_ready(new_message)
+ self.assertEqual(message.id, 'id2')
+ self.assertEqual(message.md5, 'digest2')
+
+ def test_write_batch(self):
+ messages = [('id1', 'A', 0), ('id2', 'B', 303)]
+ self.x.write_batch(messages, callback=self.callback)
+ self.x.connection.send_message_batch.assert_called_with(
+ self.x, messages, callback=self.callback,
+ )
+
+ def test_delete_message(self):
+ message = self.MockMessage('id1', 'digest1')
+ self.x.delete_message(message, callback=self.callback)
+ self.x.connection.delete_message.assert_called_with(
+ self.x, message, self.callback,
+ )
+
+ def test_delete_message_batch(self):
+ messages = [
+ self.MockMessage('id1', 'r1'),
+ self.MockMessage('id2', 'r2'),
+ ]
+ self.x.delete_message_batch(messages, callback=self.callback)
+ self.x.connection.delete_message_batch.assert_called_with(
+ self.x, messages, callback=self.callback,
+ )
+
+ def test_change_message_visibility_batch(self):
+ messages = [
+ (self.MockMessage('id1', 'r1'), 303),
+ (self.MockMessage('id2', 'r2'), 909),
+ ]
+ self.x.change_message_visibility_batch(
+ messages, callback=self.callback,
+ )
+ self.x.connection.change_message_visibility_batch.assert_called_with(
+ self.x, messages, callback=self.callback,
+ )
+
+ def test_delete(self):
+ self.x.delete(callback=self.callback)
+ self.x.connection.delete_queue.assert_called_with(
+ self.x, callback=self.callback,
+ )
+
+ def test_count(self):
+ self.x.count(callback=self.callback)
+ self.assertTrue(self.x.connection.get_queue_attributes.called)
+ on_ready = self.x.connection.get_queue_attributes.call_args[0][2]
+ self.x.connection.get_queue_attributes.assert_called_with(
+ self.x, 'ApproximateNumberOfMessages', on_ready,
+ )
+
+ on_ready({'ApproximateNumberOfMessages': '909'})
+ self.callback.assert_called_with(909)
+
+ def test_interface__count_slow(self):
+ with self.assertRaises(BotoClientError):
+ self.x.count_slow()
+
+ def test_interface__dump(self):
+ with self.assertRaises(BotoClientError):
+ self.x.dump()
+
+ def test_interface__save_to_file(self):
+ with self.assertRaises(BotoClientError):
+ self.x.save_to_file()
+
+ def test_interface__save_to_filename(self):
+ with self.assertRaises(BotoClientError):
+ self.x.save_to_filename()
+
+ def test_interface__save(self):
+ with self.assertRaises(BotoClientError):
+ self.x.save()
+
+ def test_interface__save_to_s3(self):
+ with self.assertRaises(BotoClientError):
+ self.x.save_to_s3()
+
+ def test_interface__load_from_s3(self):
+ with self.assertRaises(BotoClientError):
+ self.x.load_from_s3()
+
+ def test_interface__load_from_file(self):
+ with self.assertRaises(BotoClientError):
+ self.x.load_from_file()
+
+ def test_interface__load_from_filename(self):
+ with self.assertRaises(BotoClientError):
+ self.x.load_from_filename()
+
+ def test_interface__load(self):
+ with self.assertRaises(BotoClientError):
+ self.x.load()
+
+ def test_interface__clear(self):
+ with self.assertRaises(BotoClientError):
+ self.x.clear()
diff --git a/kombu/tests/async/aws/sqs/test_sqs.py b/kombu/tests/async/aws/sqs/test_sqs.py
new file mode 100644
index 00000000..7625c402
--- /dev/null
+++ b/kombu/tests/async/aws/sqs/test_sqs.py
@@ -0,0 +1,25 @@
+# -*- coding: utf-8 -*-
+from __future__ import absolute_import
+
+from kombu.async.aws.sqs import regions, connect_to_region
+from kombu.async.aws.sqs.connection import AsyncSQSConnection
+
+from kombu.tests.case import HubCase, Mock, patch
+
+
+class test_connect_to_region(HubCase):
+
+ def test_using_async_connection(self):
+ for region in regions():
+ self.assertIs(region.connection_cls, AsyncSQSConnection)
+
+ def test_connect_to_region(self):
+ with patch('kombu.async.aws.sqs.regions') as regions:
+ region = Mock(name='region')
+ region.name = 'us-west-1'
+ regions.return_value = [region]
+ conn = connect_to_region('us-west-1', kw=3.33)
+ self.assertIs(conn, region.connect.return_value)
+ region.connect.assert_called_with(kw=3.33)
+
+ self.assertIsNone(connect_to_region('foo'))
diff --git a/kombu/tests/async/aws/test_aws.py b/kombu/tests/async/aws/test_aws.py
new file mode 100644
index 00000000..4d7b4163
--- /dev/null
+++ b/kombu/tests/async/aws/test_aws.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+from __future__ import absolute_import
+
+from kombu.async.aws import connect_sqs
+
+from kombu.tests.case import HubCase
+
+
+class test_connect_sqs(HubCase):
+
+ def test_connection(self):
+ x = connect_sqs('AAKI', 'ASAK')
+ self.assertTrue(x)
+ self.assertTrue(x.connection)
diff --git a/kombu/tests/async/aws/test_connection.py b/kombu/tests/async/aws/test_connection.py
new file mode 100644
index 00000000..efca9695
--- /dev/null
+++ b/kombu/tests/async/aws/test_connection.py
@@ -0,0 +1,427 @@
+# -*- coding: utf-8 -*-
+from __future__ import absolute_import
+
+from contextlib import contextmanager
+
+from amqp.promise import Thenable
+
+from kombu.exceptions import HttpError
+from kombu.five import StringIO
+
+from kombu.async import http
+from kombu.async.aws.connection import (
+ AsyncHTTPConnection,
+ AsyncHTTPSConnection,
+ AsyncHTTPResponse,
+ AsyncConnection,
+ AsyncAWSAuthConnection,
+ AsyncAWSQueryConnection,
+)
+
+from kombu.tests.case import HubCase, PromiseMock, Mock, patch
+
+# Not currently working
+VALIDATES_CERT = False
+
+
+def passthrough(*args, **kwargs):
+ m = Mock(*args, **kwargs)
+
+ def side_effect(ret):
+ return ret
+ m.side_effect = side_effect
+ return m
+
+
+class test_AsyncHTTPConnection(HubCase):
+
+ def test_AsyncHTTPSConnection(self):
+ x = AsyncHTTPSConnection('aws.vandelay.com')
+ self.assertEqual(x.scheme, 'https')
+
+ def test_http_client(self):
+ x = AsyncHTTPConnection('aws.vandelay.com')
+ self.assertIs(x.http_client, http.get_client())
+ client = Mock(name='http_client')
+ y = AsyncHTTPConnection('aws.vandelay.com', http_client=client)
+ self.assertIs(y.http_client, client)
+
+ def test_args(self):
+ x = AsyncHTTPConnection(
+ 'aws.vandelay.com', 8083, strict=True, timeout=33.3,
+ )
+ self.assertEqual(x.host, 'aws.vandelay.com')
+ self.assertEqual(x.port, 8083)
+ self.assertTrue(x.strict)
+ self.assertEqual(x.timeout, 33.3)
+ self.assertEqual(x.scheme, 'http')
+
+ def test_request(self):
+ x = AsyncHTTPConnection('aws.vandelay.com')
+ x.request('PUT', '/importer-exporter')
+ self.assertEqual(x.path, '/importer-exporter')
+ self.assertEqual(x.method, 'PUT')
+
+ def test_request_with_body_buffer(self):
+ x = AsyncHTTPConnection('aws.vandelay.com')
+ body = Mock(name='body')
+ body.read.return_value = 'Vandelay Industries'
+ x.request('PUT', '/importer-exporter', body)
+ self.assertEqual(x.method, 'PUT')
+ self.assertEqual(x.path, '/importer-exporter')
+ self.assertEqual(x.body, 'Vandelay Industries')
+ body.read.assert_called_with()
+
+ def test_request_with_body_text(self):
+ x = AsyncHTTPConnection('aws.vandelay.com')
+ x.request('PUT', '/importer-exporter', 'Vandelay Industries')
+ self.assertEqual(x.method, 'PUT')
+ self.assertEqual(x.path, '/importer-exporter')
+ self.assertEqual(x.body, 'Vandelay Industries')
+
+ def test_request_with_headers(self):
+ x = AsyncHTTPConnection('aws.vandelay.com')
+ headers = {'Proxy': 'proxy.vandelay.com'}
+ x.request('PUT', '/importer-exporter', None, headers)
+ self.assertIn('Proxy', dict(x.headers))
+ self.assertEqual(dict(x.headers)['Proxy'], 'proxy.vandelay.com')
+
+ def assertRequestCreatedWith(self, url, conn):
+ conn.Request.assert_called_with(
+ url, method=conn.method,
+ headers=http.Headers(conn.headers), body=conn.body,
+ connect_timeout=conn.timeout, request_timeout=conn.timeout,
+ validate_cert=VALIDATES_CERT,
+ )
+
+ def test_getrequest_AsyncHTTPSConnection(self):
+ x = AsyncHTTPSConnection('aws.vandelay.com')
+ x.Request = Mock(name='Request')
+ x.getrequest()
+ self.assertRequestCreatedWith('https://aws.vandelay.com/', x)
+
+ def test_getrequest_nondefault_port(self):
+ x = AsyncHTTPConnection('aws.vandelay.com', port=8080)
+ x.Request = Mock(name='Request')
+ x.getrequest()
+ self.assertRequestCreatedWith('http://aws.vandelay.com:8080/', x)
+
+ y = AsyncHTTPSConnection('aws.vandelay.com', port=8443)
+ y.Request = Mock(name='Request')
+ y.getrequest()
+ self.assertRequestCreatedWith('https://aws.vandelay.com:8443/', y)
+
+ def test_getresponse(self):
+ client = Mock(name='client')
+ client.add_request = passthrough(name='client.add_request')
+ x = AsyncHTTPConnection('aws.vandelay.com', http_client=client)
+ x.Response = Mock(name='x.Response')
+ request = x.getresponse()
+ x.http_client.add_request.assert_called_with(request)
+ self.assertIsInstance(request, Thenable)
+ self.assertIsInstance(request.on_ready, Thenable)
+
+ response = Mock(name='Response')
+ request.on_ready(response)
+ x.Response.assert_called_with(response)
+
+ def test_getresponse__real_response(self):
+ client = Mock(name='client')
+ client.add_request = passthrough(name='client.add_request')
+ callback = PromiseMock(name='callback')
+ x = AsyncHTTPConnection('aws.vandelay.com', http_client=client)
+ request = x.getresponse(callback)
+ x.http_client.add_request.assert_called_with(request)
+
+ buf = StringIO()
+ buf.write('The quick brown fox jumps')
+
+ headers = http.Headers({'X-Foo': 'Hello', 'X-Bar': 'World'})
+
+ response = http.Response(request, 200, headers, buf)
+ request.on_ready(response)
+ self.assertTrue(callback.called)
+ wresponse = callback.call_args[0][0]
+
+ self.assertEqual(wresponse.read(), 'The quick brown fox jumps')
+ self.assertEqual(wresponse.status, 200)
+ self.assertEqual(wresponse.getheader('X-Foo'), 'Hello')
+ self.assertDictEqual(dict(wresponse.getheaders()), headers)
+ self.assertTrue(wresponse.msg)
+ self.assertTrue(wresponse.msg)
+ self.assertTrue(repr(wresponse))
+
+ def test_repr(self):
+ self.assertTrue(repr(AsyncHTTPConnection('aws.vandelay.com')))
+
+ def test_putrequest(self):
+ x = AsyncHTTPConnection('aws.vandelay.com')
+ x.putrequest('UPLOAD', '/new')
+ self.assertEqual(x.method, 'UPLOAD')
+ self.assertEqual(x.path, '/new')
+
+ def test_putheader(self):
+ x = AsyncHTTPConnection('aws.vandelay.com')
+ x.putheader('X-Foo', 'bar')
+ self.assertListEqual(x.headers, [('X-Foo', 'bar')])
+ x.putheader('X-Bar', 'baz')
+ self.assertListEqual(x.headers, [
+ ('X-Foo', 'bar'),
+ ('X-Bar', 'baz'),
+ ])
+
+ def test_send(self):
+ x = AsyncHTTPConnection('aws.vandelay.com')
+ x.send('foo')
+ self.assertEqual(x.body, 'foo')
+ x.send('bar')
+ self.assertEqual(x.body, 'foobar')
+
+ def test_interface(self):
+ x = AsyncHTTPConnection('aws.vandelay.com')
+ self.assertIsNone(x.set_debuglevel(3))
+ self.assertIsNone(x.connect())
+ self.assertIsNone(x.close())
+ self.assertIsNone(x.endheaders())
+
+
+class test_AsyncHTTPResponse(HubCase):
+
+ def test_with_error(self):
+ r = Mock(name='response')
+ r.error = HttpError(404, 'NotFound')
+ x = AsyncHTTPResponse(r)
+ self.assertEqual(x.reason, 'NotFound')
+
+ r.error = None
+ self.assertFalse(x.reason)
+
+
+class test_AsyncConnection(HubCase):
+
+ def test_client(self):
+ x = AsyncConnection()
+ self.assertIs(x._httpclient, http.get_client())
+ client = Mock(name='client')
+ y = AsyncConnection(http_client=client)
+ self.assertIs(y._httpclient, client)
+
+ def test_get_http_connection(self):
+ x = AsyncConnection(client=Mock(name='client'))
+ self.assertIsInstance(
+ x.get_http_connection('aws.vandelay.com', 80, False),
+ AsyncHTTPConnection,
+ )
+ self.assertIsInstance(
+ x.get_http_connection('aws.vandelay.com', 443, True),
+ AsyncHTTPSConnection,
+ )
+
+ conn = x.get_http_connection('aws.vandelay.com', 80, False)
+ self.assertIs(conn.http_client, x._httpclient)
+ self.assertEqual(conn.host, 'aws.vandelay.com')
+ self.assertEqual(conn.port, 80)
+
+
+class test_AsyncAWSAuthConnection(HubCase):
+
+ @patch('boto.log', create=True)
+ def test_make_request(self, _):
+ x = AsyncAWSAuthConnection('aws.vandelay.com',
+ http_client=Mock(name='client'))
+ Conn = x.get_http_connection = Mock(name='get_http_connection')
+ callback = PromiseMock(name='callback')
+ ret = x.make_request('GET', '/foo', callback=callback)
+ self.assertIs(ret, callback)
+ self.assertTrue(Conn.return_value.request.called)
+ Conn.return_value.getresponse.assert_called_with(
+ callback=callback,
+ )
+
+ @patch('boto.log', create=True)
+ def test_mexe(self, _):
+ x = AsyncAWSAuthConnection('aws.vandelay.com',
+ http_client=Mock(name='client'))
+ Conn = x.get_http_connection = Mock(name='get_http_connection')
+ request = x.build_base_http_request('GET', 'foo', '/auth')
+ callback = PromiseMock(name='callback')
+ x._mexe(request, callback=callback)
+ Conn.return_value.request.assert_called_with(
+ request.method, request.path, request.body, request.headers,
+ )
+ Conn.return_value.getresponse.assert_called_with(
+ callback=callback,
+ )
+
+ no_callback_ret = x._mexe(request)
+ self.assertIsInstance(
+ no_callback_ret, Thenable, '_mexe always returns promise',
+ )
+
+ @patch('boto.log', create=True)
+ def test_mexe__with_sender(self, _):
+ x = AsyncAWSAuthConnection('aws.vandelay.com',
+ http_client=Mock(name='client'))
+ Conn = x.get_http_connection = Mock(name='get_http_connection')
+ request = x.build_base_http_request('GET', 'foo', '/auth')
+ sender = Mock(name='sender')
+ callback = PromiseMock(name='callback')
+ x._mexe(request, sender=sender, callback=callback)
+ sender.assert_called_with(
+ Conn.return_value, request.method, request.path,
+ request.body, request.headers, callback,
+ )
+
+
+class test_AsyncAWSQueryConnection(HubCase):
+
+ def setup(self):
+ self.x = AsyncAWSQueryConnection('aws.vandelay.com',
+ http_client=Mock(name='client'))
+
+ @patch('boto.log', create=True)
+ def test_make_request(self, _):
+ _mexe, self.x._mexe = self.x._mexe, Mock(name='_mexe')
+ Conn = self.x.get_http_connection = Mock(name='get_http_connection')
+ callback = PromiseMock(name='callback')
+ self.x.make_request(
+ 'action', {'foo': 1}, '/', 'GET', callback=callback,
+ )
+ self.assertTrue(self.x._mexe.called)
+ request = self.x._mexe.call_args[0][0]
+ self.assertEqual(request.params['Action'], 'action')
+ self.assertEqual(request.params['Version'], self.x.APIVersion)
+
+ ret = _mexe(request, callback=callback)
+ self.assertIs(ret, callback)
+ self.assertTrue(Conn.return_value.request.called)
+ Conn.return_value.getresponse.assert_called_with(
+ callback=callback,
+ )
+
+ @patch('boto.log', create=True)
+ def test_make_request__no_action(self, _):
+ self.x._mexe = Mock(name='_mexe')
+ self.x.get_http_connection = Mock(name='get_http_connection')
+ callback = PromiseMock(name='callback')
+ self.x.make_request(
+ None, {'foo': 1}, '/', 'GET', callback=callback,
+ )
+ self.assertTrue(self.x._mexe.called)
+ request = self.x._mexe.call_args[0][0]
+ self.assertNotIn('Action', request.params)
+ self.assertEqual(request.params['Version'], self.x.APIVersion)
+
+ @contextmanager
+ def mock_sax_parse(self, parser):
+ with patch('kombu.async.aws.connection.sax_parse') as sax_parse:
+ with patch('kombu.async.aws.connection.XmlHandler') as xh:
+
+ def effect(body, h):
+ return parser(xh.call_args[0][0], body, h)
+ sax_parse.side_effect = effect
+ yield (sax_parse, xh)
+ self.assertTrue(sax_parse.called)
+
+ def Response(self, status, body):
+ r = Mock(name='response')
+ r.status = status
+ r.read.return_value = body
+ return r
+
+ @contextmanager
+ def mock_make_request(self):
+ self.x.make_request = Mock(name='make_request')
+ callback = PromiseMock(name='callback')
+ yield callback
+
+ def assert_make_request_called(self):
+ self.assertTrue(self.x.make_request.called)
+ return self.x.make_request.call_args[1]['callback']
+
+ def test_get_list(self):
+ with self.mock_make_request() as callback:
+ self.x.get_list('action', {'p': 3.3}, ['m'], callback=callback)
+ on_ready = self.assert_make_request_called()
+
+ def parser(dest, body, h):
+ dest.append('hi')
+ dest.append('there')
+
+ with self.mock_sax_parse(parser):
+ on_ready(self.Response(200, 'hello'))
+ self.assertTrue(callback.called_with(['hi', 'there']))
+
+ def test_get_list_error(self):
+ with self.mock_make_request() as callback:
+ self.x.get_list('action', {'p': 3.3}, ['m'], callback=callback)
+ on_ready = self.assert_make_request_called()
+
+ with self.assertRaises(self.x.ResponseError):
+ on_ready(self.Response(404, 'Not found'))
+
+ def test_get_object(self):
+ with self.mock_make_request() as callback:
+
+ class Result(object):
+ parent = None
+ value = None
+
+ def __init__(self, parent):
+ self.parent = parent
+
+ self.x.get_object('action', {'p': 3.3}, Result, callback=callback)
+ on_ready = self.assert_make_request_called()
+
+ def parser(dest, body, h):
+ dest.value = 42
+
+ with self.mock_sax_parse(parser):
+ on_ready(self.Response(200, 'hello'))
+
+ self.assertTrue(callback.called)
+ result = callback.call_args[0][0]
+ self.assertEqual(result.value, 42)
+ self.assertTrue(result.parent)
+
+ def test_get_object_error(self):
+ with self.mock_make_request() as callback:
+ self.x.get_object('action', {'p': 3.3}, object, callback=callback)
+ on_ready = self.assert_make_request_called()
+
+ with self.assertRaises(self.x.ResponseError):
+ on_ready(self.Response(404, 'Not found'))
+
+ def test_get_status(self):
+ with self.mock_make_request() as callback:
+ self.x.get_status('action', {'p': 3.3}, callback=callback)
+ on_ready = self.assert_make_request_called()
+ set_status_to = [True]
+
+ def parser(dest, body, b):
+ dest.status = set_status_to[0]
+
+ with self.mock_sax_parse(parser):
+ on_ready(self.Response(200, 'hello'))
+ callback.assert_called_with(True)
+
+ set_status_to[0] = False
+ with self.mock_sax_parse(parser):
+ on_ready(self.Response(200, 'hello'))
+ callback.assert_called_with(False)
+
+ def test_get_status_error(self):
+ with self.mock_make_request() as callback:
+ self.x.get_status('action', {'p': 3.3}, callback=callback)
+ on_ready = self.assert_make_request_called()
+
+ with self.assertRaises(self.x.ResponseError):
+ on_ready(self.Response(404, 'Not found'))
+
+ def test_get_status_error_empty_body(self):
+ with self.mock_make_request() as callback:
+ self.x.get_status('action', {'p': 3.3}, callback=callback)
+ on_ready = self.assert_make_request_called()
+
+ with self.assertRaises(self.x.ResponseError):
+ on_ready(self.Response(200, ''))
diff --git a/kombu/tests/async/test_http.py b/kombu/tests/async/test_http.py
index 42da11a5..4566ab61 100644
--- a/kombu/tests/async/test_http.py
+++ b/kombu/tests/async/test_http.py
@@ -3,22 +3,21 @@ from __future__ import absolute_import
from io import BytesIO
from amqp import promise
-from kombu.async import Hub
from kombu.async import http
-from kombu.async.http.base import BaseClient, normalize_header, header_parser
+from kombu.async.http.base import BaseClient, normalize_header
from kombu.exceptions import HttpError
-from kombu.tests.case import Case, Mock
+from kombu.tests.case import HubCase, Mock, PromiseMock
-class test_Headers(Case):
+class test_Headers(HubCase):
def test_normalize(self):
self.assertEqual(normalize_header('accept-encoding'),
'Accept-Encoding')
-class test_Request(Case):
+class test_Request(HubCase):
def test_init(self):
x = http.Request('http://foo', method='POST')
@@ -35,7 +34,7 @@ class test_Request(Case):
self.assertIsInstance(x.on_ready, promise)
def test_then(self):
- callback = Mock()
+ callback = PromiseMock(name='callback')
x = http.Request('http://foo')
x.then(callback)
@@ -43,7 +42,7 @@ class test_Request(Case):
callback.assert_called_with(1)
-class test_Response(Case):
+class test_Response(HubCase):
def test_init(self):
req = http.Request('http://foo')
@@ -77,7 +76,7 @@ class test_Response(Case):
self.assertEqual(r.body, b'hello')
-class test_BaseClient(Case):
+class test_BaseClient(HubCase):
def test_init(self):
c = BaseClient(Mock(name='hub'))
@@ -138,45 +137,11 @@ class test_BaseClient(Case):
c.close.assert_called_with()
-class test_Client(Case):
-
- def test_get_request(self):
- hub = Hub()
- callback = Mock(name='callback')
-
- def on_ready(response):
- pass #print('{0.effective_url} -> {0.code}'.format(response))
- requests = []
- for i in range(1000):
- requests.extend([
- http.Request(
- 'http://localhost:8000/README.rst',
- on_ready=promise(on_ready, callback=callback),
- ),
- http.Request(
- 'http://localhost:8000/AUTHORS',
- on_ready=promise(on_ready, callback=callback),
- ),
- http.Request(
- 'http://localhost:8000/pavement.py',
- on_ready=promise(on_ready, callback=callback),
- ),
- http.Request(
- 'http://localhost:8000/setup.py',
- on_ready=promise(on_ready, callback=callback),
- ),
- http.Request(
- 'http://localhost:8000/setup.py%s' % (i, ),
- on_ready=promise(on_ready, callback=callback),
- ),
- ])
- client = http.Client(hub)
- for request in requests:
- client.perform(request)
-
- from kombu.five import monotonic
- start_time = monotonic()
- print('START PERFORM')
- while callback.call_count < len(requests):
- hub.run_once()
- print('-END PERFORM: %r' % (monotonic() - start_time))
+class test_Client(HubCase):
+
+ def test_get_client(self):
+ client = http.get_client()
+ self.assertIs(client.hub, self.hub)
+ client2 = http.get_client(self.hub)
+ self.assertIs(client2, client)
+ self.assertIs(client2.hub, self.hub)
diff --git a/kombu/tests/async/test_hub.py b/kombu/tests/async/test_hub.py
index 7d5d81cd..9476ac48 100644
--- a/kombu/tests/async/test_hub.py
+++ b/kombu/tests/async/test_hub.py
@@ -8,10 +8,10 @@ from kombu.tests.case import Case
class test_Utils(Case):
- def setUp(self):
+ def setup(self):
self._prev_loop = get_event_loop()
- def tearDown(self):
+ def teardown(self):
set_event_loop(self._prev_loop)
def test_get_set_event_loop(self):
@@ -26,8 +26,8 @@ class test_Utils(Case):
class test_Hub(Case):
- def setUp(self):
+ def setup(self):
self.hub = Hub()
- def tearDown(self):
+ def teardown(self):
self.hub.close()
diff --git a/kombu/tests/case.py b/kombu/tests/case.py
index e8b6d32b..88a18acb 100644
--- a/kombu/tests/case.py
+++ b/kombu/tests/case.py
@@ -27,10 +27,53 @@ call = mock.call
class Case(unittest.TestCase):
+ def setUp(self):
+ self.setup()
+
+ def tearDown(self):
+ self.teardown()
+
def assertItemsEqual(self, a, b, *args, **kwargs):
return self.assertEqual(sorted(a), sorted(b), *args, **kwargs)
assertSameElements = assertItemsEqual
+ def setup(self):
+ pass
+
+ def teardown(self):
+ pass
+
+
+def PromiseMock(*args, **kwargs):
+ m = Mock(*args, **kwargs)
+
+ def on_throw(exc=None, *args, **kwargs):
+ if exc:
+ raise exc
+ raise
+ m.throw.side_effect = on_throw
+ m.set_error_state.side_effect = on_throw
+ m.throw1.side_effect = on_throw
+ return m
+
+
+class HubCase(Case):
+
+ def setUp(self):
+ from kombu.async import Hub, get_event_loop, set_event_loop
+ self._prev_hub = get_event_loop()
+ self.hub = Hub()
+ set_event_loop(self.hub)
+ super(HubCase, self).setUp()
+
+ def tearDown(self):
+ try:
+ super(HubCase, self).tearDown()
+ finally:
+ from kombu.async import set_event_loop
+ if self._prev_hub is not None:
+ set_event_loop(self._prev_hub)
+
class Mock(mock.Mock):
diff --git a/kombu/tests/test_compat.py b/kombu/tests/test_compat.py
index b081cf0c..1601714e 100644
--- a/kombu/tests/test_compat.py
+++ b/kombu/tests/test_compat.py
@@ -76,7 +76,7 @@ class test_misc(Case):
class test_Publisher(Case):
- def setUp(self):
+ def setup(self):
self.connection = Connection(transport=Transport)
def test_constructor(self):
@@ -127,7 +127,7 @@ class test_Publisher(Case):
class test_Consumer(Case):
- def setUp(self):
+ def setup(self):
self.connection = Connection(transport=Transport)
@patch('kombu.compat._iterconsume')
@@ -261,7 +261,7 @@ class test_Consumer(Case):
class test_ConsumerSet(Case):
- def setUp(self):
+ def setup(self):
self.connection = Connection(transport=Transport)
def test_providing_channel(self):
diff --git a/kombu/tests/test_compression.py b/kombu/tests/test_compression.py
index 7d651ee2..21d4cf1e 100644
--- a/kombu/tests/test_compression.py
+++ b/kombu/tests/test_compression.py
@@ -9,7 +9,7 @@ from .case import Case, SkipTest, mask_modules
class test_compression(Case):
- def setUp(self):
+ def setup(self):
try:
import bz2 # noqa
except ImportError:
diff --git a/kombu/tests/test_connection.py b/kombu/tests/test_connection.py
index 58790db0..fdc78f02 100644
--- a/kombu/tests/test_connection.py
+++ b/kombu/tests/test_connection.py
@@ -15,7 +15,7 @@ from .mocks import Transport
class test_connection_utils(Case):
- def setUp(self):
+ def setup(self):
self.url = 'amqp://user:pass@localhost:5672/my/vhost'
self.nopass = 'amqp://user@localhost:5672/my/vhost'
self.expected = {
@@ -144,7 +144,7 @@ class test_connection_utils(Case):
class test_Connection(Case):
- def setUp(self):
+ def setup(self):
self.conn = Connection(port=5672, transport=Transport)
def test_establish_connection(self):
@@ -261,12 +261,12 @@ class test_Connection(Case):
def test_supports_heartbeats(self):
c = Connection(transport=Mock)
- c.transport.supports_heartbeats = False
+ c.transport.implements.heartbeats = False
self.assertFalse(c.supports_heartbeats)
def test_is_evented(self):
c = Connection(transport=Mock)
- c.transport.supports_ev = False
+ c.transport.implements.async = False
self.assertFalse(c.is_evented)
def test_register_with_event_loop(self):
@@ -491,7 +491,7 @@ class test_Connection_with_transport_options(Case):
transport_options = {'pool_recycler': 3600, 'echo': True}
- def setUp(self):
+ def setup(self):
self.conn = Connection(port=5672, transport=Transport,
transport_options=self.transport_options)
diff --git a/kombu/tests/test_entities.py b/kombu/tests/test_entities.py
index 165160f1..317614c6 100644
--- a/kombu/tests/test_entities.py
+++ b/kombu/tests/test_entities.py
@@ -186,7 +186,7 @@ class test_Exchange(Case):
class test_Queue(Case):
- def setUp(self):
+ def setup(self):
self.exchange = Exchange('foo', 'direct')
def test_hash(self):
diff --git a/kombu/tests/test_log.py b/kombu/tests/test_log.py
index bd1242c0..c3e91730 100644
--- a/kombu/tests/test_log.py
+++ b/kombu/tests/test_log.py
@@ -63,7 +63,7 @@ class test_safe_format(Case):
class test_LogMixin(Case):
- def setUp(self):
+ def setup(self):
self.log = Log('Log', Mock())
self.logger = self.log.logger
diff --git a/kombu/tests/test_messaging.py b/kombu/tests/test_messaging.py
index c9573c2c..4f6c8411 100644
--- a/kombu/tests/test_messaging.py
+++ b/kombu/tests/test_messaging.py
@@ -15,7 +15,7 @@ from .mocks import Transport
class test_Producer(Case):
- def setUp(self):
+ def setup(self):
self.exchange = Exchange('foo', 'direct')
self.connection = Connection(transport=Transport)
self.connection.connect()
@@ -218,7 +218,7 @@ class test_Producer(Case):
class test_Consumer(Case):
- def setUp(self):
+ def setup(self):
self.connection = Connection(transport=Transport)
self.connection.connect()
self.assertTrue(self.connection.connection.connected)
diff --git a/kombu/tests/test_mixins.py b/kombu/tests/test_mixins.py
index b80f0131..6f868b9e 100644
--- a/kombu/tests/test_mixins.py
+++ b/kombu/tests/test_mixins.py
@@ -110,7 +110,7 @@ class test_ConsumerMixin(Case):
class test_ConsumerMixin_interface(Case):
- def setUp(self):
+ def setup(self):
self.c = ConsumerMixin()
def test_get_consumers(self):
diff --git a/kombu/tests/test_pidbox.py b/kombu/tests/test_pidbox.py
index 357de656..d87e9fcd 100644
--- a/kombu/tests/test_pidbox.py
+++ b/kombu/tests/test_pidbox.py
@@ -16,7 +16,7 @@ class test_Mailbox(Case):
def _handler(self, state):
return self.stats['var']
- def setUp(self):
+ def setup(self):
class Mailbox(pidbox.Mailbox):
diff --git a/kombu/tests/test_pools.py b/kombu/tests/test_pools.py
index 920c65a7..041c21c3 100644
--- a/kombu/tests/test_pools.py
+++ b/kombu/tests/test_pools.py
@@ -20,7 +20,7 @@ class test_ProducerPool(Case):
def Producer(self, connection):
return self.instance
- def setUp(self):
+ def setup(self):
self.connections = Mock()
self.pool = self.Pool(self.connections, limit=10)
diff --git a/kombu/tests/test_simple.py b/kombu/tests/test_simple.py
index 53a4ac38..f14b86e4 100644
--- a/kombu/tests/test_simple.py
+++ b/kombu/tests/test_simple.py
@@ -19,14 +19,14 @@ class SimpleBase(Case):
def _Queue(self, *args, **kwargs):
raise NotImplementedError()
- def setUp(self):
+ def setup(self):
if not self.abstract:
self.connection = Connection(transport='memory')
with self.connection.channel() as channel:
channel.exchange_declare('amq.direct')
self.q = self.Queue(None, no_ack=True)
- def tearDown(self):
+ def teardown(self):
if not self.abstract:
self.q.close()
self.connection.close()
diff --git a/kombu/tests/transport/test_SQS.py b/kombu/tests/transport/test_SQS.py
index e4efb53f..0f596efe 100644
--- a/kombu/tests/transport/test_SQS.py
+++ b/kombu/tests/transport/test_SQS.py
@@ -18,7 +18,7 @@ try:
except ImportError:
# Boto must not be installed if the SQS transport fails to import,
# so we skip all unit tests. Set SQS to None here, and it will be
- # checked during the setUp() phase later.
+ # checked during the setup() phase later.
SQS = None
@@ -93,7 +93,7 @@ class test_Channel(Case):
def handleMessageCallback(self, message):
self.callback_message = message
- def setUp(self):
+ def setup(self):
"""Mock the back-end SQS classes"""
# Sanity check... if SQS is None, then it did not import and we
# cannot execute our tests.
@@ -135,7 +135,7 @@ class test_Channel(Case):
# Lastly, make sure that we're set up to 'consume' this queue.
self.channel.basic_consume(self.queue_name,
- no_ack=True,
+ no_ack=False,
callback=self.handleMessageCallback,
consumer_tag='unittest')
@@ -160,15 +160,16 @@ class test_Channel(Case):
# Test getting a single message
message = 'my test message'
self.producer.publish(message)
- results = self.channel._get_from_sqs(self.queue_name)
+ q = self.channel._new_queue(self.queue_name)
+ results = q.get_messages()
self.assertEquals(len(results), 1)
# Now test getting many messages
- for i in xrange(3):
+ for i in range(3):
message = 'message: {0}'.format(i)
self.producer.publish(message)
- results = self.channel._get_from_sqs(self.queue_name, count=3)
+ results = q.get_messages(num_messages=3)
self.assertEquals(len(results), 3)
def test_get_with_empty_list(self):
@@ -186,10 +187,9 @@ class test_Channel(Case):
message = 'message: %s' % i
self.producer.publish(message)
+ q = self.channel._new_queue(self.queue_name)
# Get the messages now
- messages = self.channel._get_from_sqs(
- self.queue_name, count=message_count,
- )
+ messages = q.get_messages(num_messages=message_count)
# Now convert them to payloads
payloads = self.channel._messages_to_python(
@@ -209,15 +209,6 @@ class test_Channel(Case):
results = self.queue(self.channel).get().payload
self.assertEquals(message, results)
- def test_puts_and_gets(self):
- for i in xrange(3):
- message = 'message: %s' % i
- self.producer.publish(message)
-
- for i in xrange(3):
- self.assertEquals('message: %s' % i,
- self.queue(self.channel).get().payload)
-
def test_put_and_get_bulk(self):
# With QoS.prefetch_count = 0
message = 'my test message'
@@ -233,7 +224,7 @@ class test_Channel(Case):
self.channel.qos.prefetch_count = 5
# Now, generate all the messages
- for i in xrange(message_count):
+ for i in range(message_count):
message = 'message: %s' % i
self.producer.publish(message)
@@ -241,10 +232,12 @@ class test_Channel(Case):
# be 5 (message_count).
results = self.channel._get_bulk(self.queue_name)
self.assertEquals(5, len(results))
+ for i, r in enumerate(results):
+ self.channel.qos.append(r, i)
- # Now, do the get again, the number of messages returned should be 3.
+ # Now, do the get again, the number of messages returned should be 1.
results = self.channel._get_bulk(self.queue_name)
- self.assertEquals(3, len(results))
+ self.assertEquals(len(results), 1)
def test_drain_events_with_empty_list(self):
def mock_can_consume():
@@ -262,11 +255,11 @@ class test_Channel(Case):
self.channel.qos.prefetch_count = 5
# Now, generate all the messages
- for i in xrange(message_count):
+ for i in range(message_count):
self.producer.publish('message: %s' % i)
# Now drain all the events
- for i in xrange(message_count):
+ for i in range(message_count):
self.channel.drain_events()
# How many times was the SQSConnectionMock get_message method called?
@@ -283,11 +276,11 @@ class test_Channel(Case):
self.channel.qos.prefetch_count = None
# Now, generate all the messages
- for i in xrange(message_count):
+ for i in range(message_count):
self.producer.publish('message: %s' % i)
# Now drain all the events
- for i in xrange(message_count):
+ for i in range(message_count):
self.channel.drain_events()
# How many times was the SQSConnectionMock get_message method called?
diff --git a/kombu/tests/transport/test_amqplib.py b/kombu/tests/transport/test_amqplib.py
index cf7d6150..39f94d28 100644
--- a/kombu/tests/transport/test_amqplib.py
+++ b/kombu/tests/transport/test_amqplib.py
@@ -37,10 +37,7 @@ class amqplibCase(Case):
def setUp(self):
if amqplib is None:
raise SkipTest('amqplib not installed')
- self.setup()
-
- def setup(self):
- pass
+ super(amqplibCase, self).setUp()
class test_Channel(amqplibCase):
diff --git a/kombu/tests/transport/test_base.py b/kombu/tests/transport/test_base.py
index 5c4a50d5..ac90e51f 100644
--- a/kombu/tests/transport/test_base.py
+++ b/kombu/tests/transport/test_base.py
@@ -10,7 +10,7 @@ from kombu.tests.case import Case, Mock
class test_StdChannel(Case):
- def setUp(self):
+ def setup(self):
self.conn = Connection('memory://')
self.channel = self.conn.channel()
self.channel.queues.clear()
@@ -40,7 +40,7 @@ class test_StdChannel(Case):
class test_Message(Case):
- def setUp(self):
+ def setup(self):
self.conn = Connection('memory://')
self.channel = self.conn.channel()
self.message = Message(self.channel, delivery_tag=313)
@@ -134,7 +134,14 @@ class test_interface(Case):
self.assertTrue(Transport(None).driver_version())
def test_register_with_event_loop(self):
- Transport(None).register_with_event_loop(Mock(name='loop'))
+ Transport(None).register_with_event_loop(
+ Mock(name='connection'), Mock(name='loop'),
+ )
+
+ def test_unregister_from_event_loop(self):
+ Transport(None).unregister_from_event_loop(
+ Mock(name='connection'), Mock(name='loop'),
+ )
def test_manager(self):
self.assertTrue(Transport(None).manager)
diff --git a/kombu/tests/transport/test_filesystem.py b/kombu/tests/transport/test_filesystem.py
index 0649a8d0..1e2a37cd 100644
--- a/kombu/tests/transport/test_filesystem.py
+++ b/kombu/tests/transport/test_filesystem.py
@@ -10,7 +10,7 @@ from kombu.tests.case import Case, SkipTest
class test_FilesystemTransport(Case):
- def setUp(self):
+ def setup(self):
if sys.platform == 'win32':
raise SkipTest('Needs win32con module')
try:
diff --git a/kombu/tests/transport/test_librabbitmq.py b/kombu/tests/transport/test_librabbitmq.py
index a50b2624..7625bbc3 100644
--- a/kombu/tests/transport/test_librabbitmq.py
+++ b/kombu/tests/transport/test_librabbitmq.py
@@ -12,7 +12,7 @@ from kombu.tests.case import Case, Mock, SkipTest, patch
class lrmqCase(Case):
- def setUp(self):
+ def setup(self):
if librabbitmq is None:
raise SkipTest('librabbitmq is not installed')
@@ -61,8 +61,8 @@ class test_Channel(lrmqCase):
class test_Transport(lrmqCase):
- def setUp(self):
- super(test_Transport, self).setUp()
+ def setup(self):
+ super(test_Transport, self).setup()
self.client = Mock(name='client')
self.T = librabbitmq.Transport(self.client)
diff --git a/kombu/tests/transport/test_memory.py b/kombu/tests/transport/test_memory.py
index 605527f4..c052aa7b 100644
--- a/kombu/tests/transport/test_memory.py
+++ b/kombu/tests/transport/test_memory.py
@@ -9,7 +9,7 @@ from kombu.tests.case import Case
class test_MemoryTransport(Case):
- def setUp(self):
+ def setup(self):
self.c = Connection(transport='memory')
self.e = Exchange('test_transport_memory')
self.q = Queue('test_transport_memory',
diff --git a/kombu/tests/transport/test_pyamqp.py b/kombu/tests/transport/test_pyamqp.py
index d6a910b4..da557338 100644
--- a/kombu/tests/transport/test_pyamqp.py
+++ b/kombu/tests/transport/test_pyamqp.py
@@ -24,7 +24,7 @@ class MockConnection(dict):
class test_Channel(Case):
- def setUp(self):
+ def setup(self):
if pyamqp is None:
raise SkipTest('py-amqp not installed')
@@ -80,7 +80,7 @@ class test_Channel(Case):
class test_Transport(Case):
- def setUp(self):
+ def setup(self):
if pyamqp is None:
raise SkipTest('py-amqp not installed')
self.connection = Connection('pyamqp://')
@@ -136,7 +136,7 @@ class test_Transport(Case):
class test_pyamqp(Case):
- def setUp(self):
+ def setup(self):
if pyamqp is None:
raise SkipTest('py-amqp not installed')
diff --git a/kombu/tests/transport/test_redis.py b/kombu/tests/transport/test_redis.py
index 48f7c6be..285f6004 100644
--- a/kombu/tests/transport/test_redis.py
+++ b/kombu/tests/transport/test_redis.py
@@ -220,7 +220,7 @@ class Transport(redis.Transport):
class test_Channel(Case):
- def setUp(self):
+ def setup(self):
self.connection = self.create_connection()
self.channel = self.connection.default_channel
@@ -785,12 +785,12 @@ class test_Channel(Case):
class test_Redis(Case):
- def setUp(self):
+ def setup(self):
self.connection = Connection(transport=Transport)
self.exchange = Exchange('test_Redis', type='direct')
self.queue = Queue('test_Redis', self.exchange, 'test_Redis')
- def tearDown(self):
+ def teardown(self):
self.connection.close()
def test_publish__get(self):
@@ -941,7 +941,7 @@ def _redis_modules():
class test_MultiChannelPoller(Case):
- def setUp(self):
+ def setup(self):
self.Poller = redis.MultiChannelPoller
def test_on_poll_start(self):
diff --git a/kombu/tests/transport/test_sqlalchemy.py b/kombu/tests/transport/test_sqlalchemy.py
index 07055999..13cbd309 100644
--- a/kombu/tests/transport/test_sqlalchemy.py
+++ b/kombu/tests/transport/test_sqlalchemy.py
@@ -6,7 +6,7 @@ from kombu.tests.case import Case, SkipTest, patch
class test_sqlalchemy(Case):
- def setUp(self):
+ def setup(self):
try:
import sqlalchemy # noqa
except ImportError:
diff --git a/kombu/tests/transport/virtual/test_base.py b/kombu/tests/transport/virtual/test_base.py
index d249c4e7..df17d918 100644
--- a/kombu/tests/transport/virtual/test_base.py
+++ b/kombu/tests/transport/virtual/test_base.py
@@ -33,10 +33,10 @@ class test_BrokerState(Case):
class test_QoS(Case):
- def setUp(self):
+ def setup(self):
self.q = virtual.QoS(client().channel(), prefetch_count=10)
- def tearDown(self):
+ def teardown(self):
self.q._on_collect.cancel()
def test_constructor(self):
@@ -174,10 +174,10 @@ class test_AbstractChannel(Case):
class test_Channel(Case):
- def setUp(self):
+ def setup(self):
self.channel = client().channel()
- def tearDown(self):
+ def teardown(self):
if self.channel._qos is not None:
self.channel._qos._on_collect.cancel()
@@ -518,7 +518,7 @@ class test_Channel(Case):
class test_Transport(Case):
- def setUp(self):
+ def setup(self):
self.transport = client().transport
def test_custom_polling_interval(self):
diff --git a/kombu/tests/transport/virtual/test_exchange.py b/kombu/tests/transport/virtual/test_exchange.py
index ad590afc..bb11da48 100644
--- a/kombu/tests/transport/virtual/test_exchange.py
+++ b/kombu/tests/transport/virtual/test_exchange.py
@@ -10,7 +10,7 @@ from kombu.tests.mocks import Transport
class ExchangeCase(Case):
type = None
- def setUp(self):
+ def setup(self):
if self.type:
self.e = self.type(Connection(transport=Transport).channel())
@@ -74,8 +74,8 @@ class test_Topic(ExchangeCase):
('stock.us.*', None, 'rBar'),
]
- def setUp(self):
- super(test_Topic, self).setUp()
+ def setup(self):
+ super(test_Topic, self).setup()
self.table = [(rkey, self.e.key_to_pattern(rkey), queue)
for rkey, _, queue in self.table]
diff --git a/kombu/tests/utils/test_encoding.py b/kombu/tests/utils/test_encoding.py
index fd710c30..cc7e514e 100644
--- a/kombu/tests/utils/test_encoding.py
+++ b/kombu/tests/utils/test_encoding.py
@@ -39,7 +39,7 @@ class test_default_encoding(Case):
class test_encoding_utils(Case):
- def setUp(self):
+ def setup(self):
if sys.version_info >= (3, 0):
raise SkipTest('not relevant on py3k')
@@ -58,12 +58,12 @@ class test_encoding_utils(Case):
class test_safe_str(Case):
- def setUp(self):
+ def setup(self):
self._cencoding = patch('sys.getfilesystemencoding')
self._encoding = self._cencoding.__enter__()
self._encoding.return_value = 'ascii'
- def tearDown(self):
+ def teardown(self):
self._cencoding.__exit__()
def test_when_bytes(self):
diff --git a/kombu/tests/utils/test_utils.py b/kombu/tests/utils/test_utils.py
index 0248a303..e0077b41 100644
--- a/kombu/tests/utils/test_utils.py
+++ b/kombu/tests/utils/test_utils.py
@@ -173,7 +173,7 @@ def insomnia(fun):
class test_retry_over_time(Case):
- def setUp(self):
+ def setup(self):
self.index = 0
class Predicate(Exception):
diff --git a/kombu/transport/SQS.py b/kombu/transport/SQS.py
index 68cb053c..621c60ca 100644
--- a/kombu/transport/SQS.py
+++ b/kombu/transport/SQS.py
@@ -43,20 +43,20 @@ import collections
import socket
import string
+from amqp.promise import transform, ensure_promise, promise
from anyjson import loads, dumps
-import boto
from boto import exception
-from boto import sdb as _sdb
from boto import sqs as _sqs
-from boto.sdb.domain import Domain
-from boto.sdb.connection import SDBConnection
from boto.sqs.connection import SQSConnection
from boto.sqs.message import Message
-from kombu.five import Empty, range, text_t
+from kombu.async import get_event_loop
+from kombu.async.aws import sqs as _asynsqs
+from kombu.async.aws.sqs import AsyncSQSConnection
+from kombu.five import Empty, range, string_t, text_t
from kombu.log import get_logger
-from kombu.utils import cached_property, uuid
+from kombu.utils import cached_property
from kombu.utils.encoding import bytes_to_str, safe_str
from kombu.transport.virtual import scheduling
@@ -76,103 +76,17 @@ def maybe_int(x):
return int(x)
except ValueError:
return x
-BOTO_VERSION = tuple(maybe_int(part) for part in boto.__version__.split('.'))
-W_LONG_POLLING = BOTO_VERSION >= (2, 8)
#: SQS bulk get supports a maximum of 10 messages at a time.
SQS_MAX_MESSAGES = 10
-class Table(Domain):
- """Amazon SimpleDB domain describing the message routing table."""
- # caches queues already bound, so we don't have to declare them again.
- _already_bound = set()
-
- def routes_for(self, exchange):
- """Iterator giving all routes for an exchange."""
- return self.select("""WHERE exchange = '%s'""" % exchange)
-
- def get_queue(self, queue):
- """Get binding for queue."""
- qid = self._get_queue_id(queue)
- if qid:
- return self.get_item(qid)
-
- def create_binding(self, queue):
- """Get binding item for queue.
-
- Creates the item if it doesn't exist.
-
- """
- item = self.get_queue(queue)
- if item:
- return item, item['id']
- id = uuid()
- return self.new_item(id), id
-
- def queue_bind(self, exchange, routing_key, pattern, queue):
- if queue not in self._already_bound:
- binding, id = self.create_binding(queue)
- binding.update(exchange=exchange,
- routing_key=routing_key or '',
- pattern=pattern or '',
- queue=queue or '',
- id=id)
- binding.save()
- self._already_bound.add(queue)
-
- def queue_delete(self, queue):
- """delete queue by name."""
- self._already_bound.discard(queue)
- item = self._get_queue_item(queue)
- if item:
- self.delete_item(item)
-
- def exchange_delete(self, exchange):
- """Delete all routes for `exchange`."""
- for item in self.routes_for(exchange):
- self.delete_item(item['id'])
-
- def get_item(self, item_name):
- """Uses `consistent_read` by default."""
- # Domain is an old-style class, can't use super().
- for consistent_read in (False, True):
- item = Domain.get_item(self, item_name, consistent_read)
- if item:
- return item
-
- def select(self, query='', next_token=None,
- consistent_read=True, max_items=None):
- """Uses `consistent_read` by default."""
- query = """SELECT * FROM `%s` %s""" % (self.name, query)
- return Domain.select(self, query, next_token,
- consistent_read, max_items)
-
- def _try_first(self, query='', **kwargs):
- for c in (False, True):
- for item in self.select(query, consistent_read=c, **kwargs):
- return item
-
- def get_exchanges(self):
- return list(set(i['exchange'] for i in self.select()))
-
- def _get_queue_item(self, queue):
- return self._try_first("""WHERE queue = '%s' limit 1""" % queue)
-
- def _get_queue_id(self, queue):
- item = self._get_queue_item(queue)
- if item:
- return item['id']
-
-
class Channel(virtual.Channel):
- Table = Table
-
default_region = 'us-east-1'
default_visibility_timeout = 1800 # 30 minutes.
- default_wait_time_seconds = 0 # disabled see #198
+ default_wait_time_seconds = 10 # disabled see #198
domain_format = 'kombu%(vhost)s'
- _sdb = None
+ _asynsqs = None
_sqs = None
_queue_cache = {}
_noack_queues = set()
@@ -187,7 +101,6 @@ class Channel(virtual.Channel):
queues = self.sqs.get_all_queues(prefix=self.queue_name_prefix)
for queue in queues:
self._queue_cache[queue.name] = queue
- self._fanout_queues = set()
# The drain_events() method stores extra messages in a local
# Deque object. This allows multiple messages to be requested from
@@ -195,9 +108,13 @@ class Channel(virtual.Channel):
# to the caller of the drain_events() method.
self._queue_message_cache = collections.deque()
+ self.hub = kwargs.get('hub') or get_event_loop()
+
def basic_consume(self, queue, no_ack, *args, **kwargs):
if no_ack:
self._noack_queues.add(queue)
+ if self.hub:
+ self._loop1(queue)
return super(Channel, self).basic_consume(
queue, no_ack, *args, **kwargs
)
@@ -255,6 +172,8 @@ class Channel(virtual.Channel):
def _new_queue(self, queue, **kwargs):
"""Ensure a queue with given name exists in SQS."""
+ if not isinstance(queue, string_t):
+ return queue
# Translate to SQS name for consistency with initial
# _queue_cache population.
queue = self.entity_name(self.queue_name_prefix + queue)
@@ -266,57 +185,11 @@ class Channel(virtual.Channel):
)
return q
- def queue_bind(self, queue, exchange=None, routing_key='',
- arguments=None, **kwargs):
- super(Channel, self).queue_bind(queue, exchange, routing_key,
- arguments, **kwargs)
- if self.typeof(exchange).type == 'fanout':
- self._fanout_queues.add(queue)
-
- def _queue_bind(self, *args):
- """Bind ``queue`` to ``exchange`` with routing key.
-
- Route will be stored in SDB if so enabled.
-
- """
- if self.supports_fanout:
- self.table.queue_bind(*args)
-
- def get_table(self, exchange):
- """Get routing table.
-
- Retrieved from SDB if :attr:`supports_fanout`.
-
- """
- if self.supports_fanout:
- return [(r['routing_key'], r['pattern'], r['queue'])
- for r in self.table.routes_for(exchange)]
- return super(Channel, self).get_table(exchange)
-
- def get_exchanges(self):
- if self.supports_fanout:
- return self.table.get_exchanges()
- return super(Channel, self).get_exchanges()
-
def _delete(self, queue, *args):
"""delete queue by name."""
- if self.supports_fanout:
- self.table.queue_delete(queue)
super(Channel, self)._delete(queue)
self._queue_cache.pop(queue, None)
- def exchange_delete(self, exchange, **kwargs):
- """Delete exchange by name."""
- if self.supports_fanout:
- self.table.exchange_delete(exchange)
- super(Channel, self).exchange_delete(exchange, **kwargs)
-
- def _has_queue(self, queue, **kwargs):
- """Return True if ``queue`` was previously declared."""
- if self.supports_fanout:
- return bool(self.table.get_queue(queue))
- return super(Channel, self)._has_queue(queue)
-
def _put(self, queue, message, **kwargs):
"""Put message onto queue."""
q = self._new_queue(queue)
@@ -324,25 +197,6 @@ class Channel(virtual.Channel):
m.set_body(dumps(message))
q.write(m)
- def _put_fanout(self, exchange, message, routing_key, **kwargs):
- """Deliver fanout message to all queues in ``exchange``."""
- for route in self.table.routes_for(exchange):
- self._put(route['queue'], message, **kwargs)
-
- def _get_from_sqs(self, queue, count=1):
- """Retrieve messages from SQS and returns the raw SQS message objects.
-
- :returns: List of SQS message objects
-
- """
- q = self._new_queue(queue)
- if W_LONG_POLLING and queue not in self._fanout_queues:
- return q.get_messages(
- count, wait_time_seconds=self.wait_time_seconds,
- )
- else: # boto < 2.8
- return q.get_messages(count)
-
def _message_to_python(self, message, queue_name, queue):
payload = loads(bytes_to_str(message.get_body()))
if queue_name in self._noack_queues:
@@ -369,36 +223,33 @@ class Channel(virtual.Channel):
q = self._new_queue(queue)
return [self._message_to_python(m, queue, q) for m in messages]
- def _get_bulk(self, queue, max_if_unlimited=SQS_MAX_MESSAGES):
+ def _get_bulk(self, queue,
+ max_if_unlimited=SQS_MAX_MESSAGES, callback=None):
"""Try to retrieve multiple messages off ``queue``.
- Where _get() returns a single Payload object, this method returns a
- list of Payload objects. The number of objects returned is determined
- by the total number of messages available in the queue and the
- number of messages that the QoS object allows (based on the
+ Where :meth:`_get` returns a single Payload object, this method
+ returns a list of Payload objects. The number of objects returned
+ is determined by the total number of messages available in the queue
+ and the number of messages the QoS object allows (based on the
prefetch_count).
.. note::
+
Ignores QoS limits so caller is responsible for checking
that we are allowed to consume at least one message from the
queue. get_bulk will then ask QoS for an estimate of
the number of extra messages that we can consume.
- args:
- queue: The queue name (string) to pull from
-
- returns:
- payloads: A list of payload objects returned
+ :param queue: The queue name to pull from.
+ :returns list: of message objects.
"""
# drain_events calls `can_consume` first, consuming
# a token, so we know that we are allowed to consume at least
# one message.
- maxcount = self.qos.can_consume_max_estimate()
- maxcount = max_if_unlimited if maxcount is None else max(maxcount, 1)
+ maxcount = self._get_message_estimate()
if maxcount:
- messages = self._get_from_sqs(
- queue, count=min(maxcount, SQS_MAX_MESSAGES),
- )
+ q = self._new_queue(queue)
+ messages = q.get_messages(num_messages=maxcount)
if messages:
return self._messages_to_python(messages, queue)
@@ -406,12 +257,69 @@ class Channel(virtual.Channel):
def _get(self, queue):
"""Try to retrieve a single message off ``queue``."""
- messages = self._get_from_sqs(queue, count=1)
-
+ q = self._new_queue(queue)
+ messages = q.get_messages(num_messages=1)
if messages:
return self._messages_to_python(messages, queue)[0]
raise Empty()
+ def _loop1(self, queue, _=None):
+ self.hub.call_soon(self._schedule_queue, queue)
+
+ def _schedule_queue(self, queue):
+ if queue in self._active_queues:
+ if self.qos.can_consume():
+ self._get_bulk_async(
+ queue, callback=promise(self._loop1, (queue, )),
+ )
+ else:
+ self._loop1(queue)
+
+ def _get_message_estimate(self, max_if_unlimited=SQS_MAX_MESSAGES):
+ maxcount = self.qos.can_consume_max_estimate()
+ return min(
+ max_if_unlimited if maxcount is None else max(maxcount, 1),
+ max_if_unlimited,
+ )
+
+ def _get_bulk_async(self, queue,
+ max_if_unlimited=SQS_MAX_MESSAGES, callback=None):
+ maxcount = self._get_message_estimate()
+ if maxcount:
+ return self._get_async(queue, maxcount, callback=callback)
+ # Not allowed to consume, make sure to notify callback..
+ callback = ensure_promise(callback)
+ callback([])
+ return callback
+
+ def _get_async(self, queue, count=1, callback=None):
+ q = self._new_queue(queue)
+ return self._get_from_sqs(
+ q, count=count, connection=self.asynsqs,
+ callback=transform(self._on_messages_ready, callback, q, queue),
+ )
+
+ def _on_messages_ready(self, queue, qname, messages):
+ if messages:
+ callbacks = self.connection._callbacks
+ for raw_message in messages:
+ message = self._message_to_python(raw_message, qname, queue)
+ callbacks[qname](message)
+
+ def _get_from_sqs(self, queue,
+ count=1, connection=None, callback=None):
+ """Retrieve and handle messages from SQS.
+
+ Uses long polling and returns :class:`~amqp.promise`.
+
+ """
+ connection = connection if connection is not None else queue.connection
+ return connection.receive_message(
+ queue, number_messages=count,
+ wait_time_seconds=self.wait_time_seconds,
+ callback=callback,
+ )
+
def _restore(self, message,
unwanted_delivery_info=('sqs_message', 'sqs_queue')):
for unwanted_key in unwanted_delivery_info:
@@ -448,7 +356,7 @@ class Channel(virtual.Channel):
def close(self):
super(Channel, self).close()
- for conn in (self._sqs, self._sdb):
+ for conn in (self._sqs, self._asynsqs):
if conn:
try:
conn.close()
@@ -477,19 +385,12 @@ class Channel(virtual.Channel):
return self._sqs
@property
- def sdb(self):
- if self._sdb is None:
- self._sdb = self._aws_connect_to(SDBConnection, _sdb.regions())
- return self._sdb
-
- @property
- def table(self):
- name = self.entity_name(
- self.domain_format % {'vhost': self.conninfo.virtual_host})
- d = self.sdb.get_object(
- 'CreateDomain', {'DomainName': name}, self.Table)
- d.name = name
- return d
+ def asynsqs(self):
+ if self._asynsqs is None:
+ self._asynsqs = self._aws_connect_to(
+ AsyncSQSConnection, _asynsqs.regions(),
+ )
+ return self._asynsqs
@property
def conninfo(self):
@@ -510,7 +411,7 @@ class Channel(virtual.Channel):
@cached_property
def supports_fanout(self):
- return self.transport_options.get('sdb_persistence', False)
+ return False
@cached_property
def region(self):
@@ -537,3 +438,4 @@ class Transport(virtual.Transport):
)
driver_type = 'sqs'
driver_name = 'sqs'
+ supports_ev = True
diff --git a/kombu/transport/amqplib.py b/kombu/transport/amqplib.py
index fff82a1f..08c088a4 100644
--- a/kombu/transport/amqplib.py
+++ b/kombu/transport/amqplib.py
@@ -315,7 +315,11 @@ class Transport(base.Transport):
driver_name = 'amqplib'
driver_type = 'amqp'
- supports_ev = True
+
+ implements = base.Transport.implements.extend(
+ async=True,
+ heartbeats=False,
+ )
def __init__(self, client, **kwargs):
self.client = client
diff --git a/kombu/transport/base.py b/kombu/transport/base.py
index c226307e..699a70f0 100644
--- a/kombu/transport/base.py
+++ b/kombu/transport/base.py
@@ -59,6 +59,28 @@ class Management(object):
raise _LeftBlank(self, 'get_bindings')
+class Implements(dict):
+
+ def __getattr__(self, key):
+ try:
+ return self[key]
+ except KeyError:
+ raise AttributeError(key)
+
+ def __setattr__(self, key, value):
+ self[key] = value
+
+ def extend(self, **kwargs):
+ return self.__class__(self, **kwargs)
+
+
+default_transport_capabilities = Implements(
+ async=False,
+ exchange_type=frozenset(['direct', 'topic', 'fanout', 'headers']),
+ heartbeats=False,
+)
+
+
class Transport(object):
"""Base class for transports."""
Management = Management
@@ -87,15 +109,10 @@ class Transport(object):
#: Name of driver library (e.g. 'py-amqp', 'redis', 'beanstalkc').
driver_name = 'N/A'
- #: Whether this transports support heartbeats,
- #: and that the :meth:`heartbeat_check` method has any effect.
- supports_heartbeats = False
-
- #: Set to true if the transport supports the AIO interface.
- supports_ev = False
-
__reader = None
+ implements = default_transport_capabilities.extend()
+
def __init__(self, client, **kwargs):
self.client = client
@@ -123,10 +140,10 @@ class Transport(object):
def get_heartbeat_interval(self, connection):
return 0
- def register_with_event_loop(self, loop):
+ def register_with_event_loop(self, connection, loop):
pass
- def unregister_from_event_loop(self, loop):
+ def unregister_from_event_loop(self, connection, loop):
pass
def verify_connection(self, connection):
@@ -171,3 +188,11 @@ class Transport(object):
@cached_property
def manager(self):
return self.get_manager()
+
+ @property
+ def supports_heartbeats(self):
+ return self.implements.heartbeats
+
+ @property
+ def supports_ev(self):
+ return self.implements.async
diff --git a/kombu/transport/librabbitmq.py b/kombu/transport/librabbitmq.py
index 286bd78e..96786c1a 100644
--- a/kombu/transport/librabbitmq.py
+++ b/kombu/transport/librabbitmq.py
@@ -88,7 +88,10 @@ class Transport(base.Transport):
driver_type = 'amqp'
driver_name = 'librabbitmq'
- supports_ev = True
+ implements = base.Transport.implements.extend(
+ async=True,
+ heartbeats=False,
+ )
def __init__(self, client, **kwargs):
self.client = client
diff --git a/kombu/transport/mongodb.py b/kombu/transport/mongodb.py
index 78af0f9f..eeef8756 100644
--- a/kombu/transport/mongodb.py
+++ b/kombu/transport/mongodb.py
@@ -305,5 +305,9 @@ class Transport(virtual.Transport):
driver_type = 'mongodb'
driver_name = 'pymongo'
+ implements = virtual.Transport.implements.extend(
+ exchange_types=frozenset(['direct', 'topic', 'fanout']),
+ )
+
def driver_version(self):
return pymongo.version
diff --git a/kombu/transport/pyamqp.py b/kombu/transport/pyamqp.py
index 01844305..71f462fe 100644
--- a/kombu/transport/pyamqp.py
+++ b/kombu/transport/pyamqp.py
@@ -74,8 +74,11 @@ class Transport(base.Transport):
driver_name = 'py-amqp'
driver_type = 'amqp'
- supports_heartbeats = True
- supports_ev = True
+
+ implements = base.Transport.implements.extend(
+ async=True,
+ heartbeats=True,
+ )
def __init__(self, client, default_port=None, **kwargs):
self.client = client
diff --git a/kombu/transport/redis.py b/kombu/transport/redis.py
index c3f6decd..9753c19e 100644
--- a/kombu/transport/redis.py
+++ b/kombu/transport/redis.py
@@ -905,10 +905,14 @@ class Transport(virtual.Transport):
polling_interval = None # disable sleep between unsuccessful polls.
default_port = DEFAULT_PORT
- supports_ev = True
driver_type = 'redis'
driver_name = 'redis'
+ implements = virtual.Transport.implements.extend(
+ async=True,
+ exchange_types=frozenset(['direct', 'topic', 'fanout'])
+ )
+
def __init__(self, *args, **kwargs):
super(Transport, self).__init__(*args, **kwargs)
diff --git a/kombu/transport/virtual/__init__.py b/kombu/transport/virtual/__init__.py
index cb844de9..502f8c9f 100644
--- a/kombu/transport/virtual/__init__.py
+++ b/kombu/transport/virtual/__init__.py
@@ -771,6 +771,12 @@ class Transport(base.Transport):
#: Max number of channels
channel_max = 65535
+ implements = base.Transport.implements.extend(
+ async=False,
+ exchange_type=frozenset(['direct', 'topic']),
+ heartbeats=False,
+ )
+
def __init__(self, client, **kwargs):
self.client = client
self.channels = []
@@ -846,6 +852,13 @@ class Transport(base.Transport):
self._callbacks[queue](message)
+ def on_message_ready(self, channel, message, queue):
+ if not queue or queue not in self._callbacks:
+ raise KeyError(
+ 'Message for queue {0!r} without consumers: {1}'.format(
+ queue, message))
+ self._callbacks[queue](message)
+
def _drain_channel(self, channel, timeout=None):
return channel.drain_events(timeout=timeout)
diff --git a/kombu/transport/zmq.py b/kombu/transport/zmq.py
index e6b8a48b..0df89a1f 100644
--- a/kombu/transport/zmq.py
+++ b/kombu/transport/zmq.py
@@ -245,7 +245,9 @@ class Transport(virtual.Transport):
connection_errors = virtual.Transport.connection_errors + (ZMQError, )
- supports_ev = True
+ implements = virtual.Transport.implements.extend(
+ async=True,
+ )
polling_interval = None
def __init__(self, *args, **kwargs):
diff --git a/requirements/extras/sqs.txt b/requirements/extras/sqs.txt
index 66b95834..8cd697e0 100644
--- a/requirements/extras/sqs.txt
+++ b/requirements/extras/sqs.txt
@@ -1 +1 @@
-boto>=2.13.3
+boto>=2.8.0