diff options
author | Ask Solem <ask@celeryproject.org> | 2014-04-23 23:00:03 +0100 |
---|---|---|
committer | Ask Solem <ask@celeryproject.org> | 2014-05-03 22:28:43 +0100 |
commit | 6a1abb7e946085befb3c7ea4a3e6d703160356e4 (patch) | |
tree | 4bbcf5ee7de3053ee2644011bdc1ed2cb518e5ef | |
parent | 01779ce7e64df48f98b051e5183b5fb38529c591 (diff) | |
download | kombu-6a1abb7e946085befb3c7ea4a3e6d703160356e4.tar.gz |
Initial import of kombu.async.aws
Tests: Use setup/teardown instead of setUp/tearDown
Tests
async.http: Auth support
Tests for kombu.async.aws.connection
All tests for SQS and implemented message batch operations
SQS: Removes poor fanout support using SDB
Async SQS working somewhat
More fixes
flakes
SQS: Improves async queue/channel scheduling
Adds Transport.implements for introspection of transport capabilities
56 files changed, 2041 insertions, 357 deletions
diff --git a/kombu/async/aws/__init__.py b/kombu/async/aws/__init__.py new file mode 100644 index 00000000..43a01abd --- /dev/null +++ b/kombu/async/aws/__init__.py @@ -0,0 +1,9 @@ +# -*- coding: utf-8 -*- +from __future__ import absolute_import + + +def connect_sqs(aws_access_key_id=None, aws_secret_access_key=None, **kwargs): + from .sqs.connection import AsyncSQSConnection + return AsyncSQSConnection( + aws_access_key_id, aws_secret_access_key, **kwargs + ) diff --git a/kombu/async/aws/connection.py b/kombu/async/aws/connection.py new file mode 100644 index 00000000..532c8a7e --- /dev/null +++ b/kombu/async/aws/connection.py @@ -0,0 +1,276 @@ +# -*- coding: utf-8 -*- +from __future__ import absolute_import + +import mimetools + +from urlparse import urlunsplit +from xml.sax import parseString as sax_parse + +from amqp.promise import promise, transform + +from kombu.async.http import Headers, Request, get_client +from kombu.five import StringIO, items + +try: + import boto + from boto.connection import AWSAuthConnection, AWSQueryConnection + from boto.handler import XmlHandler + from boto.resultset import ResultSet +except ImportError: # pragma: no cover + boto = ResultSet = XmlHandler = None # noqa + + class _void(object): + pass + AWSAuthConnection, AWSQueryConnection = _void # noqa + +__all__ = ['AsyncHTTPConnection', 'AsyncHTTPSConnection', + 'AsyncHTTPResponse', 'AsyncConnection', + 'AsyncAWSAuthConnection', 'AsyncAWSQueryConnection'] + + +class AsyncHTTPResponse(object): + + def __init__(self, response): + self.response = response + self._msg = None + self.version = 10 + + def read(self, *args, **kwargs): + return self.response.body + + def getheader(self, name, default=None): + return self.response.headers.get(name, default) + + def getheaders(self): + return list(items(self.response.headers)) + + @property + def msg(self): + if self._msg is None: + self._msg = mimetools.Message( + StringIO('\r\n'.join( + '{0}: {1}'.format(*h) for h in self.getheaders()) + ) + ) + return self._msg + + @property + def status(self): + return self.response.code + + @property + def reason(self): + if self.response.error: + return self.response.error.message + return '' + + def __repr__(self): + return repr(self.response) + + +class AsyncHTTPConnection(object): + Request = Request + Response = AsyncHTTPResponse + + method = 'GET' + path = '/' + body = None + scheme = 'http' + default_ports = {'http': 80, 'https': 443} + + def __init__(self, host, port=None, + strict=None, timeout=20.0, http_client=None, **kwargs): + self.host = host + self.port = port + self.headers = [] + self.timeout = timeout + self.strict = strict + self.http_client = http_client or get_client() + + def request(self, method, path, body=None, headers=None): + self.path = path + self.method = method + if body is not None: + try: + read = body.read + except AttributeError: + self.body = body + else: + self.body = read() + if headers is not None: + self.headers.extend(list(items(headers))) + + def getrequest(self, scheme=None): + scheme = scheme if scheme else self.scheme + host = self.host + if self.port and self.port != self.default_ports[scheme]: + host = '{0}:{1}'.format(host, self.port) + url = urlunsplit((scheme, host, self.path, '', '')) + headers = Headers(self.headers) + return self.Request(url, method=self.method, headers=headers, + body=self.body, connect_timeout=self.timeout, + request_timeout=self.timeout, validate_cert=False) + + def getresponse(self, callback=None): + request = self.getrequest() + request.then(transform(self.Response, callback)) + return self.http_client.add_request(request) + + def set_debuglevel(self, level): + pass + + def connect(self): + pass + + def close(self): + pass + + def putrequest(self, method, path, **kwargs): + self.method = method + self.path = path + + def putheader(self, header, value): + self.headers.append((header, value)) + + def endheaders(self): + pass + + def send(self, data): + if self.body: + self.body += data + else: + self.body = data + + def __repr__(self): + return '<AsyncHTTPConnection: {0!r}>'.format(self.getrequest()) + + +class AsyncHTTPSConnection(AsyncHTTPConnection): + scheme = 'https' + + +class AsyncConnection(object): + + def __init__(self, http_client=None, **kwargs): + self._httpclient = http_client or get_client() + + def get_http_connection(self, host, port, is_secure): + return (AsyncHTTPSConnection if is_secure else AsyncHTTPConnection)( + host, port, http_client=self._httpclient, + ) + + def _mexe(self, request, sender=None, callback=None): + callback = callback or promise() + boto.log.debug( + 'HTTP %s %s/%s headers=%s body=%s', + request.host, request.path, + request.headers, request.body, + ) + + conn = self.get_http_connection( + request.host, request.port, self.is_secure, + ) + request.authorize(connection=self) + + if callable(sender): + sender(conn, request.method, request.path, request.body, + request.headers, callback) + else: + conn.request(request.method, request.path, + request.body, request.headers) + conn.getresponse(callback=callback) + return callback + + +class AsyncAWSAuthConnection(AsyncConnection, AWSAuthConnection): + + def __init__(self, host, + http_client=None, http_client_params={}, **kwargs): + AsyncConnection.__init__(self, http_client, **http_client_params) + AWSAuthConnection.__init__(self, host, **kwargs) + + def make_request(self, method, path, headers=None, data='', host=None, + auth_path=None, sender=None, callback=None, **kwargs): + req = self.build_base_http_request( + method, path, auth_path, {}, headers, data, host, + ) + return self._mexe(req, sender=sender, callback=callback) + + +class AsyncAWSQueryConnection(AsyncConnection, AWSQueryConnection): + + def __init__(self, host, + http_client=None, http_client_params={}, **kwargs): + AsyncConnection.__init__(self, http_client, **http_client_params) + AWSAuthConnection.__init__(self, host, **kwargs) + + def make_request(self, action, params, path, verb, callback=None): + request = self.build_base_http_request( + verb, path, None, params, {}, '', self.server_name()) + if action: + request.params['Action'] = action + request.params['Version'] = self.APIVersion + return self._mexe(request, callback=callback) + + def get_list(self, action, params, markers, + path='/', parent=None, verb='GET', callback=None): + return self.make_request( + action, params, path, verb, + callback=transform( + self._on_list_ready, callback, parent or self, markers, + ), + ) + + def get_object(self, action, params, cls, + path='/', parent=None, verb='GET', callback=None): + return self.make_request( + action, params, path, verb, + callback=transform( + self._on_obj_ready, callback, parent or self, cls, + ), + ) + + def get_status(self, action, params, + path='/', parent=None, verb='GET', callback=None): + return self.make_request( + action, params, path, verb, + callback=transform( + self._on_status_ready, callback, parent or self, + ), + ) + + def _on_list_ready(self, parent, markers, response): + body = response.read() + if response.status == 200 and body: + rs = ResultSet(markers) + h = XmlHandler(rs, parent) + sax_parse(body, h) + return rs + else: + raise self._for_status(response, body) + + def _on_obj_ready(self, parent, cls, response): + body = response.read() + if response.status == 200 and body: + obj = cls(parent) + h = XmlHandler(obj, parent) + sax_parse(body, h) + return obj + else: + raise self._for_status(response, body) + + def _on_status_ready(self, parent, response): + body = response.read() + if response.status == 200 and body: + rs = ResultSet() + h = XmlHandler(rs, parent) + sax_parse(body, h) + return rs.status + else: + raise self._for_status(response, body) + + def _for_status(self, response, body): + context = 'Empty body' if not body else 'HTTP Error' + exc = self.ResponseError(response.status, response.reason, body) + boto.log.error('{0}: %r'.format(context), exc) + return exc diff --git a/kombu/async/aws/sqs/__init__.py b/kombu/async/aws/sqs/__init__.py new file mode 100644 index 00000000..d5fb1bbb --- /dev/null +++ b/kombu/async/aws/sqs/__init__.py @@ -0,0 +1,18 @@ +# -*- coding: utf-8 -*- +from __future__ import absolute_import + +from boto.regioninfo import get_regions + +from .connection import AsyncSQSConnection + +__all__ = ['regions', 'connect_to_region'] + + +def regions(): + return get_regions('sqs', connection_cls=AsyncSQSConnection) + + +def connect_to_region(region_name, **kwargs): + for region in regions(): + if region.name == region_name: + return region.connect(**kwargs) diff --git a/kombu/async/aws/sqs/connection.py b/kombu/async/aws/sqs/connection.py new file mode 100644 index 00000000..4df9fbd3 --- /dev/null +++ b/kombu/async/aws/sqs/connection.py @@ -0,0 +1,197 @@ +# -*- coding: utf-8 -*- +from __future__ import absolute_import + +from amqp.promise import transform + +from boto.regioninfo import RegionInfo +from boto.sqs import connection as _connection +from boto.sqs.attributes import Attributes +from boto.sqs.batchresults import BatchResults + +from kombu.async.aws.connection import AsyncAWSQueryConnection + +from .message import AsyncMessage +from .queue import AsyncQueue + +__all__ = ['AsyncSQSConnection'] + + +class AsyncSQSConnection(AsyncAWSQueryConnection, _connection.SQSConnection): + + def __init__(self, aws_access_key_id=None, aws_secret_access_key=None, + is_secure=True, port=None, proxy=None, proxy_port=None, + proxy_user=None, proxy_pass=None, debug=0, + https_connection_factory=None, region=None, *args, **kwargs): + self.region = region or RegionInfo( + self, self.DefaultRegionName, self.DefaultRegionEndpoint, + connection_cls=type(self), + ) + AsyncAWSQueryConnection.__init__( + self, + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + is_secure=is_secure, port=port, + proxy=proxy, proxy_port=proxy_port, + proxy_user=proxy_user, proxy_pass=proxy_pass, + host=self.region.endpoint, debug=debug, + https_connection_factory=https_connection_factory, **kwargs + ) + + def create_queue(self, queue_name, + visibility_timeout=None, callback=None): + params = {'QueueName': queue_name} + if visibility_timeout: + params['DefaultVisibilityTimeout'] = format( + visibility_timeout, 'd', + ) + return self.get_object('CreateQueue', params, AsyncQueue, + callback=callback) + + def delete_queue(self, queue, force_deletion=False, callback=None): + return self.get_status('DeleteQueue', None, queue.id, + callback=callback) + + def get_queue_attributes(self, queue, attribute='All', callback=None): + return self.get_object( + 'GetQueueAttributes', {'AttributeName': attribute}, + Attributes, queue.id, callback=callback, + ) + + def set_queue_attribute(self, queue, attribute, value, callback=None): + return self.get_status( + 'SetQueueAttribute', + {'Attribute.Name': attribute, 'Attribute.Value': value}, + queue.id, callback=callback, + ) + + def receive_message(self, queue, + number_messages=1, visibility_timeout=None, + attributes=None, wait_time_seconds=None, + callback=None): + params = {'MaxNumberOfMessages': number_messages} + if visibility_timeout: + params['VisibilityTimeout'] = visibility_timeout + if attributes: + self.build_list_params(params, attributes, 'AttributeName') + if wait_time_seconds is not None: + params['WaitTimeSeconds'] = wait_time_seconds + return self.get_list( + 'ReceiveMessage', params, [('Message', queue.message_class)], + queue.id, callback=callback, + ) + + def delete_message(self, queue, message, callback=None): + return self.delete_message_from_handle( + queue, message.receipt_handle, callback, + ) + + def delete_message_batch(self, queue, messages, callback=None): + params = {} + for i, m in enumerate(messages): + prefix = 'DeleteMessageBatchRequestEntry.{0}'.format(i + 1) + params.update({ + '{0}.Id'.format(prefix): m.id, + '{0}.ReceiptHandle'.format(prefix): m.receipt_handle, + }) + return self.get_object( + 'DeleteMessageBatch', params, BatchResults, queue.id, + verb='POST', callback=callback, + ) + + def delete_message_from_handle(self, queue, receipt_handle, + callback=None): + return self.get_status( + 'DeleteMessage', {'ReceiptHandle': receipt_handle}, + queue.id, callback=callback, + ) + + def send_message(self, queue, message_content, + delay_seconds=None, callback=None): + params = {'MessageBody': message_content} + if delay_seconds: + params['DelaySeconds'] = int(delay_seconds) + return self.get_object( + 'SendMessage', params, AsyncMessage, queue.id, + verb='POST', callback=callback, + ) + + def send_message_batch(self, queue, messages, callback=None): + params = {} + for i, msg in enumerate(messages): + prefix = 'SendMessageBatchRequestEntry.{0}'.format(i + 1) + params.update({ + '{0}.Id'.format(prefix): msg[0], + '{0}.MessageBody'.format(prefix): msg[1], + '{0}.DelaySeconds'.format(prefix): msg[2], + }) + return self.get_object( + 'SendMessageBatch', params, BatchResults, queue.id, + verb='POST', callback=callback, + ) + + def change_message_visibility(self, queue, receipt_handle, + visibility_timeout, callback=None): + return self.get_status( + 'ChangeMessageVisibility', + {'ReceiptHandle': receipt_handle, + 'VisibilityTimeout': visibility_timeout}, + queue.id, callback=callback, + ) + + def change_message_visibility_batch(self, queue, messages, callback=None): + params = {} + for i, t in enumerate(messages): + pre = 'ChangeMessageVisibilityBatchRequestEntry.{0}'.format(i + 1) + params.update({ + '{0}.Id'.format(pre): t[0].id, + '{0}.ReceiptHandle'.format(pre): t[0].receipt_handle, + '{0}.VisibilityTimeout'.format(pre): t[1], + }) + return self.get_object( + 'ChangeMessageVisibilityBatch', params, BatchResults, queue.id, + verb='POST', callback=callback, + ) + + def get_all_queues(self, prefix='', callback=None): + params = {} + if prefix: + params['QueueNamePrefix'] = prefix + return self.get_list( + 'ListQueues', params, [('QueueUrl', AsyncQueue)], + callback=callback, + ) + + def get_queue(self, queue_name, callback=None): + # TODO Does not support owner_acct_id argument + return self.get_all_queues( + queue_name, + transform(self._on_queue_ready, callback, queue_name), + ) + lookup = get_queue + + def _on_queue_ready(self, name, queues): + return next( + (q for q in queues if q.url.endswith(name)), None, + ) + + def get_dead_letter_source_queues(self, queue, callback=None): + return self.get_list( + 'ListDeadLetterSourceQueues', {'QueueUrl': queue.url}, + [('QueueUrl', AsyncQueue)], + callback=callback, + ) + + def add_permission(self, queue, label, aws_account_id, action_name, + callback=None): + return self.get_status( + 'AddPermission', + {'Label': label, + 'AWSAccountId': aws_account_id, + 'ActionName': action_name}, + queue.id, callback=callback, + ) + + def remove_permission(self, queue, label, callback=None): + return self.get_status( + 'RemovePermission', {'Label': label}, queue.id, callback=callback, + ) diff --git a/kombu/async/aws/sqs/jsonmessage.py b/kombu/async/aws/sqs/jsonmessage.py new file mode 100644 index 00000000..e5ed0e7f --- /dev/null +++ b/kombu/async/aws/sqs/jsonmessage.py @@ -0,0 +1,10 @@ +# -*- coding: utf-8 -*- +from __future__ import absolute_import + +from boto.sqs import jsonmessage as _jsonmessage + +from .message import BaseAsyncMessage + + +class AsyncJSONMessage(BaseAsyncMessage, _jsonmessage.JSONMessage): + pass diff --git a/kombu/async/aws/sqs/message.py b/kombu/async/aws/sqs/message.py new file mode 100644 index 00000000..8141b6d2 --- /dev/null +++ b/kombu/async/aws/sqs/message.py @@ -0,0 +1,36 @@ +# -*- coding: utf-8 -*- +from __future__ import absolute_import + +from boto.sqs import message as _message + +__all__ = ['BaseAsyncMessage', 'AsyncRawMessage', 'AsyncMessage', + 'AsyncMHMessage', 'AsyncEncodedMHMessage'] + + +class BaseAsyncMessage(object): + + def delete(self, callback=None): + if self.queue: + return self.queue.delete_message(self, callback) + + def change_visibility(self, visibility_timeout, callback=None): + if self.queue: + return self.queue.connection.change_message_visibility( + self.queue, self.receipt_handle, visibility_timeout, callback, + ) + + +class AsyncRawMessage(BaseAsyncMessage, _message.RawMessage): + pass + + +class AsyncMessage(BaseAsyncMessage, _message.Message): + pass + + +class AsyncMHMessage(BaseAsyncMessage, _message.MHMessage): + pass + + +class AsyncEncodedMHMessage(BaseAsyncMessage, _message.EncodedMHMessage): + pass diff --git a/kombu/async/aws/sqs/queue.py b/kombu/async/aws/sqs/queue.py new file mode 100644 index 00000000..038c6a01 --- /dev/null +++ b/kombu/async/aws/sqs/queue.py @@ -0,0 +1,130 @@ +# -*- coding: utf-8 -*- +from __future__ import absolute_import + +from amqp.promise import transform + +from boto.exception import BotoClientError +from boto.sqs import queue as _queue + +from .message import AsyncMessage + +_all__ = ['AsyncQueue'] + + +def list_first(rs): + return rs[0] if len(rs) == 1 else None + + +class AsyncQueue(_queue.Queue): + + def __init__(self, connection=None, url=None, message_class=AsyncMessage): + self.connection = connection + self.url = url + self.message_class = message_class + self.visibility_timeout = None + + def _NA(self, *args, **kwargs): + raise BotoClientError('Not implemented') + count_slow = dump = save_to_file = save_to_filename = save = \ + save_to_s3 = load_from_s3 = load_from_file = load_from_filename = \ + load = clear = _NA + + def get_attributes(self, attributes='All', callback=None): + return self.connection.get_queue_attributes( + self, attributes, callback, + ) + + def set_attribute(self, attribute, value, callback=None): + return self.connection.set_queue_attribute( + self, attribute, value, callback, + ) + + def get_timeout(self, callback=None, _attr='VisibilityTimeout'): + return self.get_attributes( + _attr, transform( + self._coerce_field_value, callback, _attr, int, + ), + ) + + def _coerce_field_value(self, key, type, response): + return type(response[key]) + + def set_timeout(self, visibility_timeout, callback=None): + return self.set_attribute( + 'VisibilityTimeout', visibility_timeout, + transform( + self._on_timeout_set, callback, + ) + ) + + def _on_timeout_set(self, visibility_timeout): + if visibility_timeout: + self.visibility_timeout = visibility_timeout + return self.visibility_timeout + + def add_permission(self, label, aws_account_id, action_name, + callback=None): + return self.connection.add_permission( + self, label, aws_account_id, action_name, callback, + ) + + def remove_permission(self, label, callback=None): + return self.connection.remove_permission(self, label, callback) + + def read(self, visibility_timeout=None, wait_time_seconds=None, + callback=None): + return self.get_messages( + 1, visibility_timeout, + wait_time_seconds=wait_time_seconds, + callback=transform(list_first, callback), + ) + + def write(self, message, delay_seconds=None, callback=None): + return self.connection.send_message( + self, message.get_body_encoded(), delay_seconds, + callback=transform(self._on_message_sent, callback, message), + ) + + def write_batch(self, messages, callback=None): + return self.connection.send_message_batch( + self, messages, callback=callback, + ) + + def _on_message_sent(self, orig_message, new_message): + orig_message.id = new_message.id + orig_message.md5 = new_message.md5 + return new_message + + def get_messages(self, num_messages=1, visibility_timeout=None, + attributes=None, wait_time_seconds=None, callback=None): + return self.connection.receive_message( + self, number_messages=num_messages, + visibility_timeout=visibility_timeout, + attributes=attributes, + wait_time_seconds=wait_time_seconds, + callback=callback, + ) + + def delete_message(self, message, callback=None): + return self.connection.delete_message(self, message, callback) + + def delete_message_batch(self, messages, callback=None): + return self.connection.delete_message_batch( + self, messages, callback=callback, + ) + + def change_message_visibility_batch(self, messages, callback=None): + return self.connection.change_message_visibility_batch( + self, messages, callback=callback, + ) + + def delete(self, callback=None): + return self.connection.delete_queue(self, callback=callback) + + def count(self, page_size=10, vtimeout=10, callback=None, + _attr='ApproximateNumberOfMessages'): + return self.get_attributes( + _attr, callback=transform( + self._coerce_field_value, callback, _attr, int, + ), + ) diff --git a/kombu/async/http/__init__.py b/kombu/async/http/__init__.py index 8c18dfa9..d1b3deda 100644 --- a/kombu/async/http/__init__.py +++ b/kombu/async/http/__init__.py @@ -1,10 +1,21 @@ from __future__ import absolute_import +from kombu.async import get_event_loop + from .base import Request, Headers, Response __all__ = ['Client', 'Headers', 'Response', 'Request'] -def Client(hub, **kwargs): +def Client(hub=None, **kwargs): from .curl import CurlClient return CurlClient(hub, **kwargs) + + +def get_client(hub=None, **kwargs): + hub = hub or get_event_loop() + try: + return hub._current_http_client + except AttributeError: + client = hub._current_http_client = Client(hub, **kwargs) + return client diff --git a/kombu/async/http/base.py b/kombu/async/http/base.py index 7a3dfd28..35fbd831 100644 --- a/kombu/async/http/base.py +++ b/kombu/async/http/base.py @@ -6,7 +6,7 @@ from amqp.promise import Thenable, promise, maybe_promise from kombu.exceptions import HttpError from kombu.five import items -from kombu.utils import cached_property, coro +from kombu.utils import coro from kombu.utils.encoding import bytes_to_str from kombu.utils.functional import maybe_list, memoize @@ -55,6 +55,9 @@ class Request(object): default). :keyword validate_cert: Set to true if the server certificate should be verified when performing ``https://`` requests (enabled by default). + :keyword auth_username: Username for HTTP authentication. + :keyword auth_password: Password for HTTP authentication. + :keyword auth_mode: Type of HTTP authentication (``basic`` or ``digest``). :keyword user_agent: Custom user agent for this request. :keyword network_interace: Network interface to use for this request. :keyword on_ready: Callback to be called when the response has been @@ -84,6 +87,7 @@ class Request(object): """ body = user_agent = network_interface = \ + auth_username = auth_password = auth_mode = \ proxy_host = proxy_port = proxy_username = proxy_password = \ ca_certs = client_key = client_cert = None @@ -118,6 +122,9 @@ class Request(object): def then(self, callback, errback=None): self.on_ready.then(callback, errback) + + def __repr__(self): + return '<Request: {0.method} {0.url} {0.body}>'.format(self) Thenable.register(Request) diff --git a/kombu/async/http/curl.py b/kombu/async/http/curl.py index b4c516a5..d76c8e75 100644 --- a/kombu/async/http/curl.py +++ b/kombu/async/http/curl.py @@ -7,7 +7,7 @@ from functools import partial from io import BytesIO from time import time -from kombu.async.hub import READ, WRITE +from kombu.async.hub import READ, WRITE, get_event_loop from kombu.exceptions import HttpError from kombu.five import items from kombu.utils.encoding import bytes_to_str @@ -37,9 +37,10 @@ EXTRA_METHODS = frozenset(['DELETE', 'OPTIONS', 'PATCH']) class CurlClient(BaseClient): Curl = Curl - def __init__(self, hub, max_clients=10): + def __init__(self, hub=None, max_clients=10): if pycurl is None: raise ImportError('The curl client requires the pycurl library.') + hub = hub or get_event_loop() super(CurlClient, self).__init__(hub) self.max_clients = max_clients @@ -71,6 +72,7 @@ class CurlClient(BaseClient): self._pending.append(request) self._process_queue() self._set_timeout(0) + return request def _handle_socket(self, event, fd, multi, data, _pycurl=pycurl): if event == _pycurl.POLL_REMOVE: @@ -186,9 +188,9 @@ class CurlClient(BaseClient): request.headers.setdefault('Expect', '') request.headers.setdefault('Pragma', '') - curl.setopt( + setopt( _pycurl.HTTPHEADER, - ['%s %s'.format(h) for h in items(request.headers)], + ['{0}: {1}'.format(*h) for h in items(request.headers)], ) setopt( @@ -240,8 +242,8 @@ class CurlClient(BaseClient): setopt(meth, True) if request.method in ('POST', 'PUT'): - assert request.body is not None - reqbuffer = BytesIO(request.body) + body = request.body or '' + reqbuffer = BytesIO(body) setopt(_pycurl.READFUNCTION, reqbuffer.read) if request.method == 'POST': @@ -249,14 +251,24 @@ class CurlClient(BaseClient): if cmd == _pycurl.IOCMD_RESTARTREAD: reqbuffer.seek(0) setopt(_pycurl.IOCTLFUNCTION, ioctl) - setopt(_pycurl.POSTFIELDSIZE, len(request.body)) + setopt(_pycurl.POSTFIELDSIZE, len(body)) else: - setopt(_pycurl.INFILESIZE, len(request.body)) + setopt(_pycurl.INFILESIZE, len(body)) elif request.method == 'GET': - assert request.body is None - - # TODO Does not support Basic AUTH - curl.unsetopt(_pycurl.USERPWD) + assert not request.body + + if request.auth_username is not None: + auth_mode = { + 'basic': _pycurl.HTTPAUTH_BASIC, + 'digest': _pycurl.HTTPAUTH_DIGEST + }[request.auth_mode or 'basic'] + setopt(_pycurl.HTTPAUTH, auth_mode) + userpwd = '{0}:{1}'.format( + request.auth_username, request.auth_password or '', + ) + setopt(_pycurl.USERPWD, userpwd) + else: + curl.unsetopt(_pycurl.USERPWD) if request.client_cert is not None: setopt(_pycurl.SSLCERT, request.client_cert) diff --git a/kombu/async/hub.py b/kombu/async/hub.py index 264f8188..25c71cdb 100644 --- a/kombu/async/hub.py +++ b/kombu/async/hub.py @@ -15,7 +15,7 @@ from contextlib import contextmanager from time import sleep from types import GeneratorType as generator -from amqp import promise +from amqp.promise import promise, Thenable from kombu.five import Empty, range from kombu.log import get_logger @@ -184,9 +184,10 @@ class Hub(object): self._loop = None def call_soon(self, callback, *args): - handle = promise(callback, args) - self._ready.append(handle) - return handle + if not isinstance(callback, Thenable): + callback = promise(callback, args) + self._ready.append(callback) + return callback def call_later(self, delay, callback, *args): return self.timer.call_after(delay, callback, args) diff --git a/kombu/connection.py b/kombu/connection.py index 85b8f5e9..873b422b 100644 --- a/kombu/connection.py +++ b/kombu/connection.py @@ -829,11 +829,11 @@ class Connection(object): @property def supports_heartbeats(self): - return self.transport.supports_heartbeats + return self.transport.implements.heartbeats @property def is_evented(self): - return self.transport.supports_ev + return self.transport.implements.async BrokerConnection = Connection diff --git a/kombu/tests/async/aws/__init__.py b/kombu/tests/async/aws/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/kombu/tests/async/aws/__init__.py diff --git a/kombu/tests/async/aws/sqs/__init__.py b/kombu/tests/async/aws/sqs/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/kombu/tests/async/aws/sqs/__init__.py diff --git a/kombu/tests/async/aws/sqs/test_connection.py b/kombu/tests/async/aws/sqs/test_connection.py new file mode 100644 index 00000000..1ea6ba09 --- /dev/null +++ b/kombu/tests/async/aws/sqs/test_connection.py @@ -0,0 +1,310 @@ +# -*- coding: utf-8 -*- +from __future__ import absolute_import + +from kombu.async.aws.sqs.connection import ( + AsyncSQSConnection, Attributes, BatchResults, +) +from kombu.async.aws.sqs.message import AsyncMessage +from kombu.async.aws.sqs.queue import AsyncQueue +from kombu.utils import uuid + +from kombu.tests.case import HubCase, PromiseMock, Mock + + +class test_AsyncSQSConnection(HubCase): + + def setup(self): + self.x = AsyncSQSConnection('ak', 'sk') + self.x.get_object = Mock(name='X.get_object') + self.x.get_status = Mock(name='X.get_status') + self.x.get_list = Mock(nanme='X.get_list') + self.callback = PromiseMock(name='callback') + + def test_default_region(self): + self.assertTrue(self.x.region) + self.assertTrue(issubclass( + self.x.region.connection_cls, AsyncSQSConnection, + )) + + def test_create_queue(self): + self.x.create_queue('foo', callback=self.callback) + self.x.get_object.assert_called_with( + 'CreateQueue', {'QueueName': 'foo'}, AsyncQueue, + callback=self.callback, + ) + + def test_create_queue__with_visibility_timeout(self): + self.x.create_queue( + 'foo', visibility_timeout=33, callback=self.callback, + ) + self.x.get_object.assert_called_with( + 'CreateQueue', { + 'QueueName': 'foo', + 'DefaultVisibilityTimeout': '33' + }, + AsyncQueue, callback=self.callback + ) + + def test_delete_queue(self): + queue = Mock(name='queue') + self.x.delete_queue(queue, callback=self.callback) + self.x.get_status.assert_called_with( + 'DeleteQueue', None, queue.id, callback=self.callback, + ) + + def test_get_queue_attributes(self): + queue = Mock(name='queue') + self.x.get_queue_attributes( + queue, attribute='QueueSize', callback=self.callback, + ) + self.x.get_object.assert_called_with( + 'GetQueueAttributes', {'AttributeName': 'QueueSize'}, + Attributes, queue.id, callback=self.callback, + ) + + def test_set_queue_attribute(self): + queue = Mock(name='queue') + self.x.set_queue_attribute( + queue, 'Expires', '3600', callback=self.callback, + ) + self.x.get_status.assert_called_with( + 'SetQueueAttribute', { + 'Attribute.Name': 'Expires', + 'Attribute.Value': '3600', + }, + queue.id, callback=self.callback, + ) + + def test_receive_message(self): + queue = Mock(name='queue') + self.x.receive_message(queue, 4, callback=self.callback) + self.x.get_list.assert_called_with( + 'ReceiveMessage', {'MaxNumberOfMessages': 4}, + [('Message', queue.message_class)], + queue.id, callback=self.callback, + ) + + def test_receive_message__with_visibility_timeout(self): + queue = Mock(name='queue') + self.x.receive_message(queue, 4, 3666, callback=self.callback) + self.x.get_list.assert_called_with( + 'ReceiveMessage', { + 'MaxNumberOfMessages': 4, + 'VisibilityTimeout': 3666, + }, + [('Message', queue.message_class)], + queue.id, callback=self.callback, + ) + + def test_receive_message__with_wait_time_seconds(self): + queue = Mock(name='queue') + self.x.receive_message( + queue, 4, wait_time_seconds=303, callback=self.callback, + ) + self.x.get_list.assert_called_with( + 'ReceiveMessage', { + 'MaxNumberOfMessages': 4, + 'WaitTimeSeconds': 303, + }, + [('Message', queue.message_class)], + queue.id, callback=self.callback, + ) + + def test_receive_message__with_attributes(self): + queue = Mock(name='queue') + self.x.receive_message( + queue, 4, attributes=['foo', 'bar'], callback=self.callback, + ) + self.x.get_list.assert_called_with( + 'ReceiveMessage', { + 'AttributeName.1': 'foo', + 'AttributeName.2': 'bar', + 'MaxNumberOfMessages': 4, + }, + [('Message', queue.message_class)], + queue.id, callback=self.callback, + ) + + def MockMessage(self, id=None, receipt_handle=None, body=None): + m = Mock(name='message') + m.id = id or uuid() + m.receipt_handle = receipt_handle or uuid() + m._body = body + + def _get_body(): + return m._body + m.get_body.side_effect = _get_body + + def _set_body(value): + m._body = value + m.set_body.side_effect = _set_body + + return m + + def test_delete_message(self): + queue = Mock(name='queue') + message = self.MockMessage() + self.x.delete_message(queue, message, callback=self.callback) + self.x.get_status.assert_called_with( + 'DeleteMessage', {'ReceiptHandle': message.receipt_handle}, + queue.id, callback=self.callback, + ) + + def test_delete_message_batch(self): + queue = Mock(name='queue') + messages = [self.MockMessage('1', 'r1'), + self.MockMessage('2', 'r2')] + self.x.delete_message_batch(queue, messages, callback=self.callback) + self.x.get_object.assert_called_with( + 'DeleteMessageBatch', { + 'DeleteMessageBatchRequestEntry.1.Id': '1', + 'DeleteMessageBatchRequestEntry.1.ReceiptHandle': 'r1', + 'DeleteMessageBatchRequestEntry.2.Id': '2', + 'DeleteMessageBatchRequestEntry.2.ReceiptHandle': 'r2', + }, + BatchResults, queue.id, verb='POST', callback=self.callback, + ) + + def test_send_message(self): + queue = Mock(name='queue') + self.x.send_message(queue, 'hello', callback=self.callback) + self.x.get_object.assert_called_with( + 'SendMessage', {'MessageBody': 'hello'}, + AsyncMessage, queue.id, verb='POST', callback=self.callback, + ) + + def test_send_message__with_delay_seconds(self): + queue = Mock(name='queue') + self.x.send_message( + queue, 'hello', delay_seconds='303', callback=self.callback, + ) + self.x.get_object.assert_called_with( + 'SendMessage', {'MessageBody': 'hello', 'DelaySeconds': 303}, + AsyncMessage, queue.id, verb='POST', callback=self.callback, + ) + + def test_send_message_batch(self): + queue = Mock(name='queue') + messages = [self.MockMessage('1', 'r1', 'A'), + self.MockMessage('2', 'r2', 'B')] + self.x.send_message_batch( + queue, [(m.id, m.get_body(), 303) for m in messages], + callback=self.callback + ) + self.x.get_object.assert_called_with( + 'SendMessageBatch', { + 'SendMessageBatchRequestEntry.1.Id': '1', + 'SendMessageBatchRequestEntry.1.MessageBody': 'A', + 'SendMessageBatchRequestEntry.1.DelaySeconds': 303, + 'SendMessageBatchRequestEntry.2.Id': '2', + 'SendMessageBatchRequestEntry.2.MessageBody': 'B', + 'SendMessageBatchRequestEntry.2.DelaySeconds': 303, + }, + BatchResults, queue.id, verb='POST', callback=self.callback, + ) + + def test_change_message_visibility(self): + queue = Mock(name='queue') + self.x.change_message_visibility( + queue, 'rcpt', 33, callback=self.callback, + ) + self.x.get_status.assert_called_with( + 'ChangeMessageVisibility', { + 'ReceiptHandle': 'rcpt', + 'VisibilityTimeout': 33, + }, + queue.id, callback=self.callback, + ) + + def test_change_message_visibility_batch(self): + queue = Mock(name='queue') + messages = [ + (self.MockMessage('1', 'r1'), 303), + (self.MockMessage('2', 'r2'), 909), + ] + self.x.change_message_visibility_batch( + queue, messages, callback=self.callback, + ) + + def preamble(n): + return '.'.join(['ChangeMessageVisibilityBatchRequestEntry', n]) + + self.x.get_object.assert_called_with( + 'ChangeMessageVisibilityBatch', { + preamble('1.Id'): '1', + preamble('1.ReceiptHandle'): 'r1', + preamble('1.VisibilityTimeout'): 303, + preamble('2.Id'): '2', + preamble('2.ReceiptHandle'): 'r2', + preamble('2.VisibilityTimeout'): 909, + }, + BatchResults, queue.id, verb='POST', callback=self.callback, + ) + + def test_get_all_queues(self): + self.x.get_all_queues(callback=self.callback) + self.x.get_list.assert_called_with( + 'ListQueues', {}, [('QueueUrl', AsyncQueue)], + callback=self.callback, + ) + + def test_get_all_queues__with_prefix(self): + self.x.get_all_queues(prefix='kombu.', callback=self.callback) + self.x.get_list.assert_called_with( + 'ListQueues', {'QueueNamePrefix': 'kombu.'}, + [('QueueUrl', AsyncQueue)], + callback=self.callback, + ) + + def MockQueue(self, url): + q = Mock(name='Queue') + q.url = url + return q + + def test_get_queue(self): + self.x.get_queue('foo', callback=self.callback) + self.assertTrue(self.x.get_list.called) + on_ready = self.x.get_list.call_args[1]['callback'] + queues = [ + self.MockQueue('/queues/bar'), + self.MockQueue('/queues/baz'), + self.MockQueue('/queues/foo'), + ] + on_ready(queues) + self.callback.assert_called_with(queues[-1]) + + self.x.get_list.assert_called_with( + 'ListQueues', {'QueueNamePrefix': 'foo'}, + [('QueueUrl', AsyncQueue)], + callback=on_ready, + ) + + def test_get_dead_letter_source_queues(self): + queue = Mock(name='queue') + self.x.get_dead_letter_source_queues(queue, callback=self.callback) + self.x.get_list.assert_called_with( + 'ListDeadLetterSourceQueues', {'QueueUrl': queue.url}, + [('QueueUrl', AsyncQueue)], callback=self.callback, + ) + + def test_add_permission(self): + queue = Mock(name='queue') + self.x.add_permission( + queue, 'label', 'accid', 'action', callback=self.callback, + ) + self.x.get_status.assert_called_with( + 'AddPermission', { + 'Label': 'label', + 'AWSAccountId': 'accid', + 'ActionName': 'action', + }, + queue.id, callback=self.callback, + ) + + def test_remove_permission(self): + queue = Mock(name='queue') + self.x.remove_permission(queue, 'label', callback=self.callback) + self.x.get_status.assert_called_with( + 'RemovePermission', {'Label': 'label'}, queue.id, + callback=self.callback, + ) diff --git a/kombu/tests/async/aws/sqs/test_message.py b/kombu/tests/async/aws/sqs/test_message.py new file mode 100644 index 00000000..e7f32f02 --- /dev/null +++ b/kombu/tests/async/aws/sqs/test_message.py @@ -0,0 +1,35 @@ +# -*- coding: utf-8 -*- +from __future__ import absolute_import + +from kombu.async.aws.sqs.message import AsyncMessage + +from kombu.tests.case import HubCase, PromiseMock, Mock +from kombu.utils import uuid + + +class test_AsyncMessage(HubCase): + + def setup(self): + self.queue = Mock(name='queue') + self.callback = PromiseMock(name='callback') + self.x = AsyncMessage(self.queue, 'body') + self.x.receipt_handle = uuid() + + def test_delete(self): + self.assertTrue(self.x.delete(callback=self.callback)) + self.x.queue.delete_message.assert_called_with( + self.x, self.callback, + ) + + self.x.queue = None + self.assertIsNone(self.x.delete(callback=self.callback)) + + def test_change_visibility(self): + self.assertTrue(self.x.change_visibility(303, callback=self.callback)) + self.x.queue.connection.change_message_visibility.assert_called_with( + self.x.queue, self.x.receipt_handle, 303, self.callback, + ) + self.x.queue = None + self.assertIsNone(self.x.change_visibility( + 303, callback=self.callback, + )) diff --git a/kombu/tests/async/aws/sqs/test_queue.py b/kombu/tests/async/aws/sqs/test_queue.py new file mode 100644 index 00000000..417b09d3 --- /dev/null +++ b/kombu/tests/async/aws/sqs/test_queue.py @@ -0,0 +1,201 @@ +# -*- coding: utf-8 -*- +from __future__ import absolute_import + +from kombu.async.aws.sqs.message import AsyncMessage +from kombu.async.aws.sqs.queue import AsyncQueue, BotoClientError + +from kombu.tests.case import HubCase, PromiseMock, Mock + + +class test_AsyncQueue(HubCase): + + def setup(self): + self.conn = Mock(name='connection') + self.x = AsyncQueue(self.conn, '/url') + self.callback = PromiseMock(name='callback') + + def test_message_class(self): + self.assertTrue(issubclass(self.x.message_class, AsyncMessage)) + + def test_get_attributes(self): + self.x.get_attributes(attributes='QueueSize', callback=self.callback) + self.x.connection.get_queue_attributes.assert_called_with( + self.x, 'QueueSize', self.callback, + ) + + def test_set_attribute(self): + self.x.set_attribute('key', 'value', callback=self.callback) + self.x.connection.set_queue_attribute.assert_called_with( + self.x, 'key', 'value', self.callback, + ) + + def test_get_timeout(self): + self.x.get_timeout(callback=self.callback) + self.assertTrue(self.x.connection.get_queue_attributes.called) + on_ready = self.x.connection.get_queue_attributes.call_args[0][2] + self.x.connection.get_queue_attributes.assert_called_with( + self.x, 'VisibilityTimeout', on_ready, + ) + + on_ready({'VisibilityTimeout': '303'}) + self.callback.assert_called_with(303) + + def test_set_timeout(self): + self.x.set_timeout(808, callback=self.callback) + self.assertTrue(self.x.connection.set_queue_attribute.called) + on_ready = self.x.connection.set_queue_attribute.call_args[0][3] + self.x.connection.set_queue_attribute.assert_called_with( + self.x, 'VisibilityTimeout', 808, on_ready, + ) + on_ready(808) + self.callback.assert_called_with(808) + self.assertEqual(self.x.visibility_timeout, 808) + + on_ready(None) + self.assertEqual(self.x.visibility_timeout, 808) + + def test_add_permission(self): + self.x.add_permission( + 'label', 'accid', 'action', callback=self.callback, + ) + self.x.connection.add_permission.assert_called_with( + self.x, 'label', 'accid', 'action', self.callback, + ) + + def test_remove_permission(self): + self.x.remove_permission('label', callback=self.callback) + self.x.connection.remove_permission.assert_called_with( + self.x, 'label', self.callback, + ) + + def test_read(self): + self.x.read(visibility_timeout=909, callback=self.callback) + self.assertTrue(self.x.connection.receive_message.called) + on_ready = self.x.connection.receive_message.call_args[1]['callback'] + self.x.connection.receive_message.assert_called_with( + self.x, number_messages=1, visibility_timeout=909, + attributes=None, wait_time_seconds=None, callback=on_ready, + ) + + messages = [Mock(name='message1')] + on_ready(messages) + + self.callback.assert_called_with(messages[0]) + + def MockMessage(self, id, md5): + m = Mock(name='Message-{0}'.format(id)) + m.id = id + m.md5 = md5 + return m + + def test_write(self): + message = self.MockMessage('id1', 'digest1') + self.x.write(message, delay_seconds=303, callback=self.callback) + self.assertTrue(self.x.connection.send_message.called) + on_ready = self.x.connection.send_message.call_args[1]['callback'] + self.x.connection.send_message.assert_called_with( + self.x, message.get_body_encoded(), 303, + callback=on_ready, + ) + + new_message = self.MockMessage('id2', 'digest2') + on_ready(new_message) + self.assertEqual(message.id, 'id2') + self.assertEqual(message.md5, 'digest2') + + def test_write_batch(self): + messages = [('id1', 'A', 0), ('id2', 'B', 303)] + self.x.write_batch(messages, callback=self.callback) + self.x.connection.send_message_batch.assert_called_with( + self.x, messages, callback=self.callback, + ) + + def test_delete_message(self): + message = self.MockMessage('id1', 'digest1') + self.x.delete_message(message, callback=self.callback) + self.x.connection.delete_message.assert_called_with( + self.x, message, self.callback, + ) + + def test_delete_message_batch(self): + messages = [ + self.MockMessage('id1', 'r1'), + self.MockMessage('id2', 'r2'), + ] + self.x.delete_message_batch(messages, callback=self.callback) + self.x.connection.delete_message_batch.assert_called_with( + self.x, messages, callback=self.callback, + ) + + def test_change_message_visibility_batch(self): + messages = [ + (self.MockMessage('id1', 'r1'), 303), + (self.MockMessage('id2', 'r2'), 909), + ] + self.x.change_message_visibility_batch( + messages, callback=self.callback, + ) + self.x.connection.change_message_visibility_batch.assert_called_with( + self.x, messages, callback=self.callback, + ) + + def test_delete(self): + self.x.delete(callback=self.callback) + self.x.connection.delete_queue.assert_called_with( + self.x, callback=self.callback, + ) + + def test_count(self): + self.x.count(callback=self.callback) + self.assertTrue(self.x.connection.get_queue_attributes.called) + on_ready = self.x.connection.get_queue_attributes.call_args[0][2] + self.x.connection.get_queue_attributes.assert_called_with( + self.x, 'ApproximateNumberOfMessages', on_ready, + ) + + on_ready({'ApproximateNumberOfMessages': '909'}) + self.callback.assert_called_with(909) + + def test_interface__count_slow(self): + with self.assertRaises(BotoClientError): + self.x.count_slow() + + def test_interface__dump(self): + with self.assertRaises(BotoClientError): + self.x.dump() + + def test_interface__save_to_file(self): + with self.assertRaises(BotoClientError): + self.x.save_to_file() + + def test_interface__save_to_filename(self): + with self.assertRaises(BotoClientError): + self.x.save_to_filename() + + def test_interface__save(self): + with self.assertRaises(BotoClientError): + self.x.save() + + def test_interface__save_to_s3(self): + with self.assertRaises(BotoClientError): + self.x.save_to_s3() + + def test_interface__load_from_s3(self): + with self.assertRaises(BotoClientError): + self.x.load_from_s3() + + def test_interface__load_from_file(self): + with self.assertRaises(BotoClientError): + self.x.load_from_file() + + def test_interface__load_from_filename(self): + with self.assertRaises(BotoClientError): + self.x.load_from_filename() + + def test_interface__load(self): + with self.assertRaises(BotoClientError): + self.x.load() + + def test_interface__clear(self): + with self.assertRaises(BotoClientError): + self.x.clear() diff --git a/kombu/tests/async/aws/sqs/test_sqs.py b/kombu/tests/async/aws/sqs/test_sqs.py new file mode 100644 index 00000000..7625c402 --- /dev/null +++ b/kombu/tests/async/aws/sqs/test_sqs.py @@ -0,0 +1,25 @@ +# -*- coding: utf-8 -*- +from __future__ import absolute_import + +from kombu.async.aws.sqs import regions, connect_to_region +from kombu.async.aws.sqs.connection import AsyncSQSConnection + +from kombu.tests.case import HubCase, Mock, patch + + +class test_connect_to_region(HubCase): + + def test_using_async_connection(self): + for region in regions(): + self.assertIs(region.connection_cls, AsyncSQSConnection) + + def test_connect_to_region(self): + with patch('kombu.async.aws.sqs.regions') as regions: + region = Mock(name='region') + region.name = 'us-west-1' + regions.return_value = [region] + conn = connect_to_region('us-west-1', kw=3.33) + self.assertIs(conn, region.connect.return_value) + region.connect.assert_called_with(kw=3.33) + + self.assertIsNone(connect_to_region('foo')) diff --git a/kombu/tests/async/aws/test_aws.py b/kombu/tests/async/aws/test_aws.py new file mode 100644 index 00000000..4d7b4163 --- /dev/null +++ b/kombu/tests/async/aws/test_aws.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +from __future__ import absolute_import + +from kombu.async.aws import connect_sqs + +from kombu.tests.case import HubCase + + +class test_connect_sqs(HubCase): + + def test_connection(self): + x = connect_sqs('AAKI', 'ASAK') + self.assertTrue(x) + self.assertTrue(x.connection) diff --git a/kombu/tests/async/aws/test_connection.py b/kombu/tests/async/aws/test_connection.py new file mode 100644 index 00000000..efca9695 --- /dev/null +++ b/kombu/tests/async/aws/test_connection.py @@ -0,0 +1,427 @@ +# -*- coding: utf-8 -*- +from __future__ import absolute_import + +from contextlib import contextmanager + +from amqp.promise import Thenable + +from kombu.exceptions import HttpError +from kombu.five import StringIO + +from kombu.async import http +from kombu.async.aws.connection import ( + AsyncHTTPConnection, + AsyncHTTPSConnection, + AsyncHTTPResponse, + AsyncConnection, + AsyncAWSAuthConnection, + AsyncAWSQueryConnection, +) + +from kombu.tests.case import HubCase, PromiseMock, Mock, patch + +# Not currently working +VALIDATES_CERT = False + + +def passthrough(*args, **kwargs): + m = Mock(*args, **kwargs) + + def side_effect(ret): + return ret + m.side_effect = side_effect + return m + + +class test_AsyncHTTPConnection(HubCase): + + def test_AsyncHTTPSConnection(self): + x = AsyncHTTPSConnection('aws.vandelay.com') + self.assertEqual(x.scheme, 'https') + + def test_http_client(self): + x = AsyncHTTPConnection('aws.vandelay.com') + self.assertIs(x.http_client, http.get_client()) + client = Mock(name='http_client') + y = AsyncHTTPConnection('aws.vandelay.com', http_client=client) + self.assertIs(y.http_client, client) + + def test_args(self): + x = AsyncHTTPConnection( + 'aws.vandelay.com', 8083, strict=True, timeout=33.3, + ) + self.assertEqual(x.host, 'aws.vandelay.com') + self.assertEqual(x.port, 8083) + self.assertTrue(x.strict) + self.assertEqual(x.timeout, 33.3) + self.assertEqual(x.scheme, 'http') + + def test_request(self): + x = AsyncHTTPConnection('aws.vandelay.com') + x.request('PUT', '/importer-exporter') + self.assertEqual(x.path, '/importer-exporter') + self.assertEqual(x.method, 'PUT') + + def test_request_with_body_buffer(self): + x = AsyncHTTPConnection('aws.vandelay.com') + body = Mock(name='body') + body.read.return_value = 'Vandelay Industries' + x.request('PUT', '/importer-exporter', body) + self.assertEqual(x.method, 'PUT') + self.assertEqual(x.path, '/importer-exporter') + self.assertEqual(x.body, 'Vandelay Industries') + body.read.assert_called_with() + + def test_request_with_body_text(self): + x = AsyncHTTPConnection('aws.vandelay.com') + x.request('PUT', '/importer-exporter', 'Vandelay Industries') + self.assertEqual(x.method, 'PUT') + self.assertEqual(x.path, '/importer-exporter') + self.assertEqual(x.body, 'Vandelay Industries') + + def test_request_with_headers(self): + x = AsyncHTTPConnection('aws.vandelay.com') + headers = {'Proxy': 'proxy.vandelay.com'} + x.request('PUT', '/importer-exporter', None, headers) + self.assertIn('Proxy', dict(x.headers)) + self.assertEqual(dict(x.headers)['Proxy'], 'proxy.vandelay.com') + + def assertRequestCreatedWith(self, url, conn): + conn.Request.assert_called_with( + url, method=conn.method, + headers=http.Headers(conn.headers), body=conn.body, + connect_timeout=conn.timeout, request_timeout=conn.timeout, + validate_cert=VALIDATES_CERT, + ) + + def test_getrequest_AsyncHTTPSConnection(self): + x = AsyncHTTPSConnection('aws.vandelay.com') + x.Request = Mock(name='Request') + x.getrequest() + self.assertRequestCreatedWith('https://aws.vandelay.com/', x) + + def test_getrequest_nondefault_port(self): + x = AsyncHTTPConnection('aws.vandelay.com', port=8080) + x.Request = Mock(name='Request') + x.getrequest() + self.assertRequestCreatedWith('http://aws.vandelay.com:8080/', x) + + y = AsyncHTTPSConnection('aws.vandelay.com', port=8443) + y.Request = Mock(name='Request') + y.getrequest() + self.assertRequestCreatedWith('https://aws.vandelay.com:8443/', y) + + def test_getresponse(self): + client = Mock(name='client') + client.add_request = passthrough(name='client.add_request') + x = AsyncHTTPConnection('aws.vandelay.com', http_client=client) + x.Response = Mock(name='x.Response') + request = x.getresponse() + x.http_client.add_request.assert_called_with(request) + self.assertIsInstance(request, Thenable) + self.assertIsInstance(request.on_ready, Thenable) + + response = Mock(name='Response') + request.on_ready(response) + x.Response.assert_called_with(response) + + def test_getresponse__real_response(self): + client = Mock(name='client') + client.add_request = passthrough(name='client.add_request') + callback = PromiseMock(name='callback') + x = AsyncHTTPConnection('aws.vandelay.com', http_client=client) + request = x.getresponse(callback) + x.http_client.add_request.assert_called_with(request) + + buf = StringIO() + buf.write('The quick brown fox jumps') + + headers = http.Headers({'X-Foo': 'Hello', 'X-Bar': 'World'}) + + response = http.Response(request, 200, headers, buf) + request.on_ready(response) + self.assertTrue(callback.called) + wresponse = callback.call_args[0][0] + + self.assertEqual(wresponse.read(), 'The quick brown fox jumps') + self.assertEqual(wresponse.status, 200) + self.assertEqual(wresponse.getheader('X-Foo'), 'Hello') + self.assertDictEqual(dict(wresponse.getheaders()), headers) + self.assertTrue(wresponse.msg) + self.assertTrue(wresponse.msg) + self.assertTrue(repr(wresponse)) + + def test_repr(self): + self.assertTrue(repr(AsyncHTTPConnection('aws.vandelay.com'))) + + def test_putrequest(self): + x = AsyncHTTPConnection('aws.vandelay.com') + x.putrequest('UPLOAD', '/new') + self.assertEqual(x.method, 'UPLOAD') + self.assertEqual(x.path, '/new') + + def test_putheader(self): + x = AsyncHTTPConnection('aws.vandelay.com') + x.putheader('X-Foo', 'bar') + self.assertListEqual(x.headers, [('X-Foo', 'bar')]) + x.putheader('X-Bar', 'baz') + self.assertListEqual(x.headers, [ + ('X-Foo', 'bar'), + ('X-Bar', 'baz'), + ]) + + def test_send(self): + x = AsyncHTTPConnection('aws.vandelay.com') + x.send('foo') + self.assertEqual(x.body, 'foo') + x.send('bar') + self.assertEqual(x.body, 'foobar') + + def test_interface(self): + x = AsyncHTTPConnection('aws.vandelay.com') + self.assertIsNone(x.set_debuglevel(3)) + self.assertIsNone(x.connect()) + self.assertIsNone(x.close()) + self.assertIsNone(x.endheaders()) + + +class test_AsyncHTTPResponse(HubCase): + + def test_with_error(self): + r = Mock(name='response') + r.error = HttpError(404, 'NotFound') + x = AsyncHTTPResponse(r) + self.assertEqual(x.reason, 'NotFound') + + r.error = None + self.assertFalse(x.reason) + + +class test_AsyncConnection(HubCase): + + def test_client(self): + x = AsyncConnection() + self.assertIs(x._httpclient, http.get_client()) + client = Mock(name='client') + y = AsyncConnection(http_client=client) + self.assertIs(y._httpclient, client) + + def test_get_http_connection(self): + x = AsyncConnection(client=Mock(name='client')) + self.assertIsInstance( + x.get_http_connection('aws.vandelay.com', 80, False), + AsyncHTTPConnection, + ) + self.assertIsInstance( + x.get_http_connection('aws.vandelay.com', 443, True), + AsyncHTTPSConnection, + ) + + conn = x.get_http_connection('aws.vandelay.com', 80, False) + self.assertIs(conn.http_client, x._httpclient) + self.assertEqual(conn.host, 'aws.vandelay.com') + self.assertEqual(conn.port, 80) + + +class test_AsyncAWSAuthConnection(HubCase): + + @patch('boto.log', create=True) + def test_make_request(self, _): + x = AsyncAWSAuthConnection('aws.vandelay.com', + http_client=Mock(name='client')) + Conn = x.get_http_connection = Mock(name='get_http_connection') + callback = PromiseMock(name='callback') + ret = x.make_request('GET', '/foo', callback=callback) + self.assertIs(ret, callback) + self.assertTrue(Conn.return_value.request.called) + Conn.return_value.getresponse.assert_called_with( + callback=callback, + ) + + @patch('boto.log', create=True) + def test_mexe(self, _): + x = AsyncAWSAuthConnection('aws.vandelay.com', + http_client=Mock(name='client')) + Conn = x.get_http_connection = Mock(name='get_http_connection') + request = x.build_base_http_request('GET', 'foo', '/auth') + callback = PromiseMock(name='callback') + x._mexe(request, callback=callback) + Conn.return_value.request.assert_called_with( + request.method, request.path, request.body, request.headers, + ) + Conn.return_value.getresponse.assert_called_with( + callback=callback, + ) + + no_callback_ret = x._mexe(request) + self.assertIsInstance( + no_callback_ret, Thenable, '_mexe always returns promise', + ) + + @patch('boto.log', create=True) + def test_mexe__with_sender(self, _): + x = AsyncAWSAuthConnection('aws.vandelay.com', + http_client=Mock(name='client')) + Conn = x.get_http_connection = Mock(name='get_http_connection') + request = x.build_base_http_request('GET', 'foo', '/auth') + sender = Mock(name='sender') + callback = PromiseMock(name='callback') + x._mexe(request, sender=sender, callback=callback) + sender.assert_called_with( + Conn.return_value, request.method, request.path, + request.body, request.headers, callback, + ) + + +class test_AsyncAWSQueryConnection(HubCase): + + def setup(self): + self.x = AsyncAWSQueryConnection('aws.vandelay.com', + http_client=Mock(name='client')) + + @patch('boto.log', create=True) + def test_make_request(self, _): + _mexe, self.x._mexe = self.x._mexe, Mock(name='_mexe') + Conn = self.x.get_http_connection = Mock(name='get_http_connection') + callback = PromiseMock(name='callback') + self.x.make_request( + 'action', {'foo': 1}, '/', 'GET', callback=callback, + ) + self.assertTrue(self.x._mexe.called) + request = self.x._mexe.call_args[0][0] + self.assertEqual(request.params['Action'], 'action') + self.assertEqual(request.params['Version'], self.x.APIVersion) + + ret = _mexe(request, callback=callback) + self.assertIs(ret, callback) + self.assertTrue(Conn.return_value.request.called) + Conn.return_value.getresponse.assert_called_with( + callback=callback, + ) + + @patch('boto.log', create=True) + def test_make_request__no_action(self, _): + self.x._mexe = Mock(name='_mexe') + self.x.get_http_connection = Mock(name='get_http_connection') + callback = PromiseMock(name='callback') + self.x.make_request( + None, {'foo': 1}, '/', 'GET', callback=callback, + ) + self.assertTrue(self.x._mexe.called) + request = self.x._mexe.call_args[0][0] + self.assertNotIn('Action', request.params) + self.assertEqual(request.params['Version'], self.x.APIVersion) + + @contextmanager + def mock_sax_parse(self, parser): + with patch('kombu.async.aws.connection.sax_parse') as sax_parse: + with patch('kombu.async.aws.connection.XmlHandler') as xh: + + def effect(body, h): + return parser(xh.call_args[0][0], body, h) + sax_parse.side_effect = effect + yield (sax_parse, xh) + self.assertTrue(sax_parse.called) + + def Response(self, status, body): + r = Mock(name='response') + r.status = status + r.read.return_value = body + return r + + @contextmanager + def mock_make_request(self): + self.x.make_request = Mock(name='make_request') + callback = PromiseMock(name='callback') + yield callback + + def assert_make_request_called(self): + self.assertTrue(self.x.make_request.called) + return self.x.make_request.call_args[1]['callback'] + + def test_get_list(self): + with self.mock_make_request() as callback: + self.x.get_list('action', {'p': 3.3}, ['m'], callback=callback) + on_ready = self.assert_make_request_called() + + def parser(dest, body, h): + dest.append('hi') + dest.append('there') + + with self.mock_sax_parse(parser): + on_ready(self.Response(200, 'hello')) + self.assertTrue(callback.called_with(['hi', 'there'])) + + def test_get_list_error(self): + with self.mock_make_request() as callback: + self.x.get_list('action', {'p': 3.3}, ['m'], callback=callback) + on_ready = self.assert_make_request_called() + + with self.assertRaises(self.x.ResponseError): + on_ready(self.Response(404, 'Not found')) + + def test_get_object(self): + with self.mock_make_request() as callback: + + class Result(object): + parent = None + value = None + + def __init__(self, parent): + self.parent = parent + + self.x.get_object('action', {'p': 3.3}, Result, callback=callback) + on_ready = self.assert_make_request_called() + + def parser(dest, body, h): + dest.value = 42 + + with self.mock_sax_parse(parser): + on_ready(self.Response(200, 'hello')) + + self.assertTrue(callback.called) + result = callback.call_args[0][0] + self.assertEqual(result.value, 42) + self.assertTrue(result.parent) + + def test_get_object_error(self): + with self.mock_make_request() as callback: + self.x.get_object('action', {'p': 3.3}, object, callback=callback) + on_ready = self.assert_make_request_called() + + with self.assertRaises(self.x.ResponseError): + on_ready(self.Response(404, 'Not found')) + + def test_get_status(self): + with self.mock_make_request() as callback: + self.x.get_status('action', {'p': 3.3}, callback=callback) + on_ready = self.assert_make_request_called() + set_status_to = [True] + + def parser(dest, body, b): + dest.status = set_status_to[0] + + with self.mock_sax_parse(parser): + on_ready(self.Response(200, 'hello')) + callback.assert_called_with(True) + + set_status_to[0] = False + with self.mock_sax_parse(parser): + on_ready(self.Response(200, 'hello')) + callback.assert_called_with(False) + + def test_get_status_error(self): + with self.mock_make_request() as callback: + self.x.get_status('action', {'p': 3.3}, callback=callback) + on_ready = self.assert_make_request_called() + + with self.assertRaises(self.x.ResponseError): + on_ready(self.Response(404, 'Not found')) + + def test_get_status_error_empty_body(self): + with self.mock_make_request() as callback: + self.x.get_status('action', {'p': 3.3}, callback=callback) + on_ready = self.assert_make_request_called() + + with self.assertRaises(self.x.ResponseError): + on_ready(self.Response(200, '')) diff --git a/kombu/tests/async/test_http.py b/kombu/tests/async/test_http.py index 42da11a5..4566ab61 100644 --- a/kombu/tests/async/test_http.py +++ b/kombu/tests/async/test_http.py @@ -3,22 +3,21 @@ from __future__ import absolute_import from io import BytesIO from amqp import promise -from kombu.async import Hub from kombu.async import http -from kombu.async.http.base import BaseClient, normalize_header, header_parser +from kombu.async.http.base import BaseClient, normalize_header from kombu.exceptions import HttpError -from kombu.tests.case import Case, Mock +from kombu.tests.case import HubCase, Mock, PromiseMock -class test_Headers(Case): +class test_Headers(HubCase): def test_normalize(self): self.assertEqual(normalize_header('accept-encoding'), 'Accept-Encoding') -class test_Request(Case): +class test_Request(HubCase): def test_init(self): x = http.Request('http://foo', method='POST') @@ -35,7 +34,7 @@ class test_Request(Case): self.assertIsInstance(x.on_ready, promise) def test_then(self): - callback = Mock() + callback = PromiseMock(name='callback') x = http.Request('http://foo') x.then(callback) @@ -43,7 +42,7 @@ class test_Request(Case): callback.assert_called_with(1) -class test_Response(Case): +class test_Response(HubCase): def test_init(self): req = http.Request('http://foo') @@ -77,7 +76,7 @@ class test_Response(Case): self.assertEqual(r.body, b'hello') -class test_BaseClient(Case): +class test_BaseClient(HubCase): def test_init(self): c = BaseClient(Mock(name='hub')) @@ -138,45 +137,11 @@ class test_BaseClient(Case): c.close.assert_called_with() -class test_Client(Case): - - def test_get_request(self): - hub = Hub() - callback = Mock(name='callback') - - def on_ready(response): - pass #print('{0.effective_url} -> {0.code}'.format(response)) - requests = [] - for i in range(1000): - requests.extend([ - http.Request( - 'http://localhost:8000/README.rst', - on_ready=promise(on_ready, callback=callback), - ), - http.Request( - 'http://localhost:8000/AUTHORS', - on_ready=promise(on_ready, callback=callback), - ), - http.Request( - 'http://localhost:8000/pavement.py', - on_ready=promise(on_ready, callback=callback), - ), - http.Request( - 'http://localhost:8000/setup.py', - on_ready=promise(on_ready, callback=callback), - ), - http.Request( - 'http://localhost:8000/setup.py%s' % (i, ), - on_ready=promise(on_ready, callback=callback), - ), - ]) - client = http.Client(hub) - for request in requests: - client.perform(request) - - from kombu.five import monotonic - start_time = monotonic() - print('START PERFORM') - while callback.call_count < len(requests): - hub.run_once() - print('-END PERFORM: %r' % (monotonic() - start_time)) +class test_Client(HubCase): + + def test_get_client(self): + client = http.get_client() + self.assertIs(client.hub, self.hub) + client2 = http.get_client(self.hub) + self.assertIs(client2, client) + self.assertIs(client2.hub, self.hub) diff --git a/kombu/tests/async/test_hub.py b/kombu/tests/async/test_hub.py index 7d5d81cd..9476ac48 100644 --- a/kombu/tests/async/test_hub.py +++ b/kombu/tests/async/test_hub.py @@ -8,10 +8,10 @@ from kombu.tests.case import Case class test_Utils(Case): - def setUp(self): + def setup(self): self._prev_loop = get_event_loop() - def tearDown(self): + def teardown(self): set_event_loop(self._prev_loop) def test_get_set_event_loop(self): @@ -26,8 +26,8 @@ class test_Utils(Case): class test_Hub(Case): - def setUp(self): + def setup(self): self.hub = Hub() - def tearDown(self): + def teardown(self): self.hub.close() diff --git a/kombu/tests/case.py b/kombu/tests/case.py index e8b6d32b..88a18acb 100644 --- a/kombu/tests/case.py +++ b/kombu/tests/case.py @@ -27,10 +27,53 @@ call = mock.call class Case(unittest.TestCase): + def setUp(self): + self.setup() + + def tearDown(self): + self.teardown() + def assertItemsEqual(self, a, b, *args, **kwargs): return self.assertEqual(sorted(a), sorted(b), *args, **kwargs) assertSameElements = assertItemsEqual + def setup(self): + pass + + def teardown(self): + pass + + +def PromiseMock(*args, **kwargs): + m = Mock(*args, **kwargs) + + def on_throw(exc=None, *args, **kwargs): + if exc: + raise exc + raise + m.throw.side_effect = on_throw + m.set_error_state.side_effect = on_throw + m.throw1.side_effect = on_throw + return m + + +class HubCase(Case): + + def setUp(self): + from kombu.async import Hub, get_event_loop, set_event_loop + self._prev_hub = get_event_loop() + self.hub = Hub() + set_event_loop(self.hub) + super(HubCase, self).setUp() + + def tearDown(self): + try: + super(HubCase, self).tearDown() + finally: + from kombu.async import set_event_loop + if self._prev_hub is not None: + set_event_loop(self._prev_hub) + class Mock(mock.Mock): diff --git a/kombu/tests/test_compat.py b/kombu/tests/test_compat.py index b081cf0c..1601714e 100644 --- a/kombu/tests/test_compat.py +++ b/kombu/tests/test_compat.py @@ -76,7 +76,7 @@ class test_misc(Case): class test_Publisher(Case): - def setUp(self): + def setup(self): self.connection = Connection(transport=Transport) def test_constructor(self): @@ -127,7 +127,7 @@ class test_Publisher(Case): class test_Consumer(Case): - def setUp(self): + def setup(self): self.connection = Connection(transport=Transport) @patch('kombu.compat._iterconsume') @@ -261,7 +261,7 @@ class test_Consumer(Case): class test_ConsumerSet(Case): - def setUp(self): + def setup(self): self.connection = Connection(transport=Transport) def test_providing_channel(self): diff --git a/kombu/tests/test_compression.py b/kombu/tests/test_compression.py index 7d651ee2..21d4cf1e 100644 --- a/kombu/tests/test_compression.py +++ b/kombu/tests/test_compression.py @@ -9,7 +9,7 @@ from .case import Case, SkipTest, mask_modules class test_compression(Case): - def setUp(self): + def setup(self): try: import bz2 # noqa except ImportError: diff --git a/kombu/tests/test_connection.py b/kombu/tests/test_connection.py index 58790db0..fdc78f02 100644 --- a/kombu/tests/test_connection.py +++ b/kombu/tests/test_connection.py @@ -15,7 +15,7 @@ from .mocks import Transport class test_connection_utils(Case): - def setUp(self): + def setup(self): self.url = 'amqp://user:pass@localhost:5672/my/vhost' self.nopass = 'amqp://user@localhost:5672/my/vhost' self.expected = { @@ -144,7 +144,7 @@ class test_connection_utils(Case): class test_Connection(Case): - def setUp(self): + def setup(self): self.conn = Connection(port=5672, transport=Transport) def test_establish_connection(self): @@ -261,12 +261,12 @@ class test_Connection(Case): def test_supports_heartbeats(self): c = Connection(transport=Mock) - c.transport.supports_heartbeats = False + c.transport.implements.heartbeats = False self.assertFalse(c.supports_heartbeats) def test_is_evented(self): c = Connection(transport=Mock) - c.transport.supports_ev = False + c.transport.implements.async = False self.assertFalse(c.is_evented) def test_register_with_event_loop(self): @@ -491,7 +491,7 @@ class test_Connection_with_transport_options(Case): transport_options = {'pool_recycler': 3600, 'echo': True} - def setUp(self): + def setup(self): self.conn = Connection(port=5672, transport=Transport, transport_options=self.transport_options) diff --git a/kombu/tests/test_entities.py b/kombu/tests/test_entities.py index 165160f1..317614c6 100644 --- a/kombu/tests/test_entities.py +++ b/kombu/tests/test_entities.py @@ -186,7 +186,7 @@ class test_Exchange(Case): class test_Queue(Case): - def setUp(self): + def setup(self): self.exchange = Exchange('foo', 'direct') def test_hash(self): diff --git a/kombu/tests/test_log.py b/kombu/tests/test_log.py index bd1242c0..c3e91730 100644 --- a/kombu/tests/test_log.py +++ b/kombu/tests/test_log.py @@ -63,7 +63,7 @@ class test_safe_format(Case): class test_LogMixin(Case): - def setUp(self): + def setup(self): self.log = Log('Log', Mock()) self.logger = self.log.logger diff --git a/kombu/tests/test_messaging.py b/kombu/tests/test_messaging.py index c9573c2c..4f6c8411 100644 --- a/kombu/tests/test_messaging.py +++ b/kombu/tests/test_messaging.py @@ -15,7 +15,7 @@ from .mocks import Transport class test_Producer(Case): - def setUp(self): + def setup(self): self.exchange = Exchange('foo', 'direct') self.connection = Connection(transport=Transport) self.connection.connect() @@ -218,7 +218,7 @@ class test_Producer(Case): class test_Consumer(Case): - def setUp(self): + def setup(self): self.connection = Connection(transport=Transport) self.connection.connect() self.assertTrue(self.connection.connection.connected) diff --git a/kombu/tests/test_mixins.py b/kombu/tests/test_mixins.py index b80f0131..6f868b9e 100644 --- a/kombu/tests/test_mixins.py +++ b/kombu/tests/test_mixins.py @@ -110,7 +110,7 @@ class test_ConsumerMixin(Case): class test_ConsumerMixin_interface(Case): - def setUp(self): + def setup(self): self.c = ConsumerMixin() def test_get_consumers(self): diff --git a/kombu/tests/test_pidbox.py b/kombu/tests/test_pidbox.py index 357de656..d87e9fcd 100644 --- a/kombu/tests/test_pidbox.py +++ b/kombu/tests/test_pidbox.py @@ -16,7 +16,7 @@ class test_Mailbox(Case): def _handler(self, state): return self.stats['var'] - def setUp(self): + def setup(self): class Mailbox(pidbox.Mailbox): diff --git a/kombu/tests/test_pools.py b/kombu/tests/test_pools.py index 920c65a7..041c21c3 100644 --- a/kombu/tests/test_pools.py +++ b/kombu/tests/test_pools.py @@ -20,7 +20,7 @@ class test_ProducerPool(Case): def Producer(self, connection): return self.instance - def setUp(self): + def setup(self): self.connections = Mock() self.pool = self.Pool(self.connections, limit=10) diff --git a/kombu/tests/test_simple.py b/kombu/tests/test_simple.py index 53a4ac38..f14b86e4 100644 --- a/kombu/tests/test_simple.py +++ b/kombu/tests/test_simple.py @@ -19,14 +19,14 @@ class SimpleBase(Case): def _Queue(self, *args, **kwargs): raise NotImplementedError() - def setUp(self): + def setup(self): if not self.abstract: self.connection = Connection(transport='memory') with self.connection.channel() as channel: channel.exchange_declare('amq.direct') self.q = self.Queue(None, no_ack=True) - def tearDown(self): + def teardown(self): if not self.abstract: self.q.close() self.connection.close() diff --git a/kombu/tests/transport/test_SQS.py b/kombu/tests/transport/test_SQS.py index e4efb53f..0f596efe 100644 --- a/kombu/tests/transport/test_SQS.py +++ b/kombu/tests/transport/test_SQS.py @@ -18,7 +18,7 @@ try: except ImportError: # Boto must not be installed if the SQS transport fails to import, # so we skip all unit tests. Set SQS to None here, and it will be - # checked during the setUp() phase later. + # checked during the setup() phase later. SQS = None @@ -93,7 +93,7 @@ class test_Channel(Case): def handleMessageCallback(self, message): self.callback_message = message - def setUp(self): + def setup(self): """Mock the back-end SQS classes""" # Sanity check... if SQS is None, then it did not import and we # cannot execute our tests. @@ -135,7 +135,7 @@ class test_Channel(Case): # Lastly, make sure that we're set up to 'consume' this queue. self.channel.basic_consume(self.queue_name, - no_ack=True, + no_ack=False, callback=self.handleMessageCallback, consumer_tag='unittest') @@ -160,15 +160,16 @@ class test_Channel(Case): # Test getting a single message message = 'my test message' self.producer.publish(message) - results = self.channel._get_from_sqs(self.queue_name) + q = self.channel._new_queue(self.queue_name) + results = q.get_messages() self.assertEquals(len(results), 1) # Now test getting many messages - for i in xrange(3): + for i in range(3): message = 'message: {0}'.format(i) self.producer.publish(message) - results = self.channel._get_from_sqs(self.queue_name, count=3) + results = q.get_messages(num_messages=3) self.assertEquals(len(results), 3) def test_get_with_empty_list(self): @@ -186,10 +187,9 @@ class test_Channel(Case): message = 'message: %s' % i self.producer.publish(message) + q = self.channel._new_queue(self.queue_name) # Get the messages now - messages = self.channel._get_from_sqs( - self.queue_name, count=message_count, - ) + messages = q.get_messages(num_messages=message_count) # Now convert them to payloads payloads = self.channel._messages_to_python( @@ -209,15 +209,6 @@ class test_Channel(Case): results = self.queue(self.channel).get().payload self.assertEquals(message, results) - def test_puts_and_gets(self): - for i in xrange(3): - message = 'message: %s' % i - self.producer.publish(message) - - for i in xrange(3): - self.assertEquals('message: %s' % i, - self.queue(self.channel).get().payload) - def test_put_and_get_bulk(self): # With QoS.prefetch_count = 0 message = 'my test message' @@ -233,7 +224,7 @@ class test_Channel(Case): self.channel.qos.prefetch_count = 5 # Now, generate all the messages - for i in xrange(message_count): + for i in range(message_count): message = 'message: %s' % i self.producer.publish(message) @@ -241,10 +232,12 @@ class test_Channel(Case): # be 5 (message_count). results = self.channel._get_bulk(self.queue_name) self.assertEquals(5, len(results)) + for i, r in enumerate(results): + self.channel.qos.append(r, i) - # Now, do the get again, the number of messages returned should be 3. + # Now, do the get again, the number of messages returned should be 1. results = self.channel._get_bulk(self.queue_name) - self.assertEquals(3, len(results)) + self.assertEquals(len(results), 1) def test_drain_events_with_empty_list(self): def mock_can_consume(): @@ -262,11 +255,11 @@ class test_Channel(Case): self.channel.qos.prefetch_count = 5 # Now, generate all the messages - for i in xrange(message_count): + for i in range(message_count): self.producer.publish('message: %s' % i) # Now drain all the events - for i in xrange(message_count): + for i in range(message_count): self.channel.drain_events() # How many times was the SQSConnectionMock get_message method called? @@ -283,11 +276,11 @@ class test_Channel(Case): self.channel.qos.prefetch_count = None # Now, generate all the messages - for i in xrange(message_count): + for i in range(message_count): self.producer.publish('message: %s' % i) # Now drain all the events - for i in xrange(message_count): + for i in range(message_count): self.channel.drain_events() # How many times was the SQSConnectionMock get_message method called? diff --git a/kombu/tests/transport/test_amqplib.py b/kombu/tests/transport/test_amqplib.py index cf7d6150..39f94d28 100644 --- a/kombu/tests/transport/test_amqplib.py +++ b/kombu/tests/transport/test_amqplib.py @@ -37,10 +37,7 @@ class amqplibCase(Case): def setUp(self): if amqplib is None: raise SkipTest('amqplib not installed') - self.setup() - - def setup(self): - pass + super(amqplibCase, self).setUp() class test_Channel(amqplibCase): diff --git a/kombu/tests/transport/test_base.py b/kombu/tests/transport/test_base.py index 5c4a50d5..ac90e51f 100644 --- a/kombu/tests/transport/test_base.py +++ b/kombu/tests/transport/test_base.py @@ -10,7 +10,7 @@ from kombu.tests.case import Case, Mock class test_StdChannel(Case): - def setUp(self): + def setup(self): self.conn = Connection('memory://') self.channel = self.conn.channel() self.channel.queues.clear() @@ -40,7 +40,7 @@ class test_StdChannel(Case): class test_Message(Case): - def setUp(self): + def setup(self): self.conn = Connection('memory://') self.channel = self.conn.channel() self.message = Message(self.channel, delivery_tag=313) @@ -134,7 +134,14 @@ class test_interface(Case): self.assertTrue(Transport(None).driver_version()) def test_register_with_event_loop(self): - Transport(None).register_with_event_loop(Mock(name='loop')) + Transport(None).register_with_event_loop( + Mock(name='connection'), Mock(name='loop'), + ) + + def test_unregister_from_event_loop(self): + Transport(None).unregister_from_event_loop( + Mock(name='connection'), Mock(name='loop'), + ) def test_manager(self): self.assertTrue(Transport(None).manager) diff --git a/kombu/tests/transport/test_filesystem.py b/kombu/tests/transport/test_filesystem.py index 0649a8d0..1e2a37cd 100644 --- a/kombu/tests/transport/test_filesystem.py +++ b/kombu/tests/transport/test_filesystem.py @@ -10,7 +10,7 @@ from kombu.tests.case import Case, SkipTest class test_FilesystemTransport(Case): - def setUp(self): + def setup(self): if sys.platform == 'win32': raise SkipTest('Needs win32con module') try: diff --git a/kombu/tests/transport/test_librabbitmq.py b/kombu/tests/transport/test_librabbitmq.py index a50b2624..7625bbc3 100644 --- a/kombu/tests/transport/test_librabbitmq.py +++ b/kombu/tests/transport/test_librabbitmq.py @@ -12,7 +12,7 @@ from kombu.tests.case import Case, Mock, SkipTest, patch class lrmqCase(Case): - def setUp(self): + def setup(self): if librabbitmq is None: raise SkipTest('librabbitmq is not installed') @@ -61,8 +61,8 @@ class test_Channel(lrmqCase): class test_Transport(lrmqCase): - def setUp(self): - super(test_Transport, self).setUp() + def setup(self): + super(test_Transport, self).setup() self.client = Mock(name='client') self.T = librabbitmq.Transport(self.client) diff --git a/kombu/tests/transport/test_memory.py b/kombu/tests/transport/test_memory.py index 605527f4..c052aa7b 100644 --- a/kombu/tests/transport/test_memory.py +++ b/kombu/tests/transport/test_memory.py @@ -9,7 +9,7 @@ from kombu.tests.case import Case class test_MemoryTransport(Case): - def setUp(self): + def setup(self): self.c = Connection(transport='memory') self.e = Exchange('test_transport_memory') self.q = Queue('test_transport_memory', diff --git a/kombu/tests/transport/test_pyamqp.py b/kombu/tests/transport/test_pyamqp.py index d6a910b4..da557338 100644 --- a/kombu/tests/transport/test_pyamqp.py +++ b/kombu/tests/transport/test_pyamqp.py @@ -24,7 +24,7 @@ class MockConnection(dict): class test_Channel(Case): - def setUp(self): + def setup(self): if pyamqp is None: raise SkipTest('py-amqp not installed') @@ -80,7 +80,7 @@ class test_Channel(Case): class test_Transport(Case): - def setUp(self): + def setup(self): if pyamqp is None: raise SkipTest('py-amqp not installed') self.connection = Connection('pyamqp://') @@ -136,7 +136,7 @@ class test_Transport(Case): class test_pyamqp(Case): - def setUp(self): + def setup(self): if pyamqp is None: raise SkipTest('py-amqp not installed') diff --git a/kombu/tests/transport/test_redis.py b/kombu/tests/transport/test_redis.py index 48f7c6be..285f6004 100644 --- a/kombu/tests/transport/test_redis.py +++ b/kombu/tests/transport/test_redis.py @@ -220,7 +220,7 @@ class Transport(redis.Transport): class test_Channel(Case): - def setUp(self): + def setup(self): self.connection = self.create_connection() self.channel = self.connection.default_channel @@ -785,12 +785,12 @@ class test_Channel(Case): class test_Redis(Case): - def setUp(self): + def setup(self): self.connection = Connection(transport=Transport) self.exchange = Exchange('test_Redis', type='direct') self.queue = Queue('test_Redis', self.exchange, 'test_Redis') - def tearDown(self): + def teardown(self): self.connection.close() def test_publish__get(self): @@ -941,7 +941,7 @@ def _redis_modules(): class test_MultiChannelPoller(Case): - def setUp(self): + def setup(self): self.Poller = redis.MultiChannelPoller def test_on_poll_start(self): diff --git a/kombu/tests/transport/test_sqlalchemy.py b/kombu/tests/transport/test_sqlalchemy.py index 07055999..13cbd309 100644 --- a/kombu/tests/transport/test_sqlalchemy.py +++ b/kombu/tests/transport/test_sqlalchemy.py @@ -6,7 +6,7 @@ from kombu.tests.case import Case, SkipTest, patch class test_sqlalchemy(Case): - def setUp(self): + def setup(self): try: import sqlalchemy # noqa except ImportError: diff --git a/kombu/tests/transport/virtual/test_base.py b/kombu/tests/transport/virtual/test_base.py index d249c4e7..df17d918 100644 --- a/kombu/tests/transport/virtual/test_base.py +++ b/kombu/tests/transport/virtual/test_base.py @@ -33,10 +33,10 @@ class test_BrokerState(Case): class test_QoS(Case): - def setUp(self): + def setup(self): self.q = virtual.QoS(client().channel(), prefetch_count=10) - def tearDown(self): + def teardown(self): self.q._on_collect.cancel() def test_constructor(self): @@ -174,10 +174,10 @@ class test_AbstractChannel(Case): class test_Channel(Case): - def setUp(self): + def setup(self): self.channel = client().channel() - def tearDown(self): + def teardown(self): if self.channel._qos is not None: self.channel._qos._on_collect.cancel() @@ -518,7 +518,7 @@ class test_Channel(Case): class test_Transport(Case): - def setUp(self): + def setup(self): self.transport = client().transport def test_custom_polling_interval(self): diff --git a/kombu/tests/transport/virtual/test_exchange.py b/kombu/tests/transport/virtual/test_exchange.py index ad590afc..bb11da48 100644 --- a/kombu/tests/transport/virtual/test_exchange.py +++ b/kombu/tests/transport/virtual/test_exchange.py @@ -10,7 +10,7 @@ from kombu.tests.mocks import Transport class ExchangeCase(Case): type = None - def setUp(self): + def setup(self): if self.type: self.e = self.type(Connection(transport=Transport).channel()) @@ -74,8 +74,8 @@ class test_Topic(ExchangeCase): ('stock.us.*', None, 'rBar'), ] - def setUp(self): - super(test_Topic, self).setUp() + def setup(self): + super(test_Topic, self).setup() self.table = [(rkey, self.e.key_to_pattern(rkey), queue) for rkey, _, queue in self.table] diff --git a/kombu/tests/utils/test_encoding.py b/kombu/tests/utils/test_encoding.py index fd710c30..cc7e514e 100644 --- a/kombu/tests/utils/test_encoding.py +++ b/kombu/tests/utils/test_encoding.py @@ -39,7 +39,7 @@ class test_default_encoding(Case): class test_encoding_utils(Case): - def setUp(self): + def setup(self): if sys.version_info >= (3, 0): raise SkipTest('not relevant on py3k') @@ -58,12 +58,12 @@ class test_encoding_utils(Case): class test_safe_str(Case): - def setUp(self): + def setup(self): self._cencoding = patch('sys.getfilesystemencoding') self._encoding = self._cencoding.__enter__() self._encoding.return_value = 'ascii' - def tearDown(self): + def teardown(self): self._cencoding.__exit__() def test_when_bytes(self): diff --git a/kombu/tests/utils/test_utils.py b/kombu/tests/utils/test_utils.py index 0248a303..e0077b41 100644 --- a/kombu/tests/utils/test_utils.py +++ b/kombu/tests/utils/test_utils.py @@ -173,7 +173,7 @@ def insomnia(fun): class test_retry_over_time(Case): - def setUp(self): + def setup(self): self.index = 0 class Predicate(Exception): diff --git a/kombu/transport/SQS.py b/kombu/transport/SQS.py index 68cb053c..621c60ca 100644 --- a/kombu/transport/SQS.py +++ b/kombu/transport/SQS.py @@ -43,20 +43,20 @@ import collections import socket import string +from amqp.promise import transform, ensure_promise, promise from anyjson import loads, dumps -import boto from boto import exception -from boto import sdb as _sdb from boto import sqs as _sqs -from boto.sdb.domain import Domain -from boto.sdb.connection import SDBConnection from boto.sqs.connection import SQSConnection from boto.sqs.message import Message -from kombu.five import Empty, range, text_t +from kombu.async import get_event_loop +from kombu.async.aws import sqs as _asynsqs +from kombu.async.aws.sqs import AsyncSQSConnection +from kombu.five import Empty, range, string_t, text_t from kombu.log import get_logger -from kombu.utils import cached_property, uuid +from kombu.utils import cached_property from kombu.utils.encoding import bytes_to_str, safe_str from kombu.transport.virtual import scheduling @@ -76,103 +76,17 @@ def maybe_int(x): return int(x) except ValueError: return x -BOTO_VERSION = tuple(maybe_int(part) for part in boto.__version__.split('.')) -W_LONG_POLLING = BOTO_VERSION >= (2, 8) #: SQS bulk get supports a maximum of 10 messages at a time. SQS_MAX_MESSAGES = 10 -class Table(Domain): - """Amazon SimpleDB domain describing the message routing table.""" - # caches queues already bound, so we don't have to declare them again. - _already_bound = set() - - def routes_for(self, exchange): - """Iterator giving all routes for an exchange.""" - return self.select("""WHERE exchange = '%s'""" % exchange) - - def get_queue(self, queue): - """Get binding for queue.""" - qid = self._get_queue_id(queue) - if qid: - return self.get_item(qid) - - def create_binding(self, queue): - """Get binding item for queue. - - Creates the item if it doesn't exist. - - """ - item = self.get_queue(queue) - if item: - return item, item['id'] - id = uuid() - return self.new_item(id), id - - def queue_bind(self, exchange, routing_key, pattern, queue): - if queue not in self._already_bound: - binding, id = self.create_binding(queue) - binding.update(exchange=exchange, - routing_key=routing_key or '', - pattern=pattern or '', - queue=queue or '', - id=id) - binding.save() - self._already_bound.add(queue) - - def queue_delete(self, queue): - """delete queue by name.""" - self._already_bound.discard(queue) - item = self._get_queue_item(queue) - if item: - self.delete_item(item) - - def exchange_delete(self, exchange): - """Delete all routes for `exchange`.""" - for item in self.routes_for(exchange): - self.delete_item(item['id']) - - def get_item(self, item_name): - """Uses `consistent_read` by default.""" - # Domain is an old-style class, can't use super(). - for consistent_read in (False, True): - item = Domain.get_item(self, item_name, consistent_read) - if item: - return item - - def select(self, query='', next_token=None, - consistent_read=True, max_items=None): - """Uses `consistent_read` by default.""" - query = """SELECT * FROM `%s` %s""" % (self.name, query) - return Domain.select(self, query, next_token, - consistent_read, max_items) - - def _try_first(self, query='', **kwargs): - for c in (False, True): - for item in self.select(query, consistent_read=c, **kwargs): - return item - - def get_exchanges(self): - return list(set(i['exchange'] for i in self.select())) - - def _get_queue_item(self, queue): - return self._try_first("""WHERE queue = '%s' limit 1""" % queue) - - def _get_queue_id(self, queue): - item = self._get_queue_item(queue) - if item: - return item['id'] - - class Channel(virtual.Channel): - Table = Table - default_region = 'us-east-1' default_visibility_timeout = 1800 # 30 minutes. - default_wait_time_seconds = 0 # disabled see #198 + default_wait_time_seconds = 10 # disabled see #198 domain_format = 'kombu%(vhost)s' - _sdb = None + _asynsqs = None _sqs = None _queue_cache = {} _noack_queues = set() @@ -187,7 +101,6 @@ class Channel(virtual.Channel): queues = self.sqs.get_all_queues(prefix=self.queue_name_prefix) for queue in queues: self._queue_cache[queue.name] = queue - self._fanout_queues = set() # The drain_events() method stores extra messages in a local # Deque object. This allows multiple messages to be requested from @@ -195,9 +108,13 @@ class Channel(virtual.Channel): # to the caller of the drain_events() method. self._queue_message_cache = collections.deque() + self.hub = kwargs.get('hub') or get_event_loop() + def basic_consume(self, queue, no_ack, *args, **kwargs): if no_ack: self._noack_queues.add(queue) + if self.hub: + self._loop1(queue) return super(Channel, self).basic_consume( queue, no_ack, *args, **kwargs ) @@ -255,6 +172,8 @@ class Channel(virtual.Channel): def _new_queue(self, queue, **kwargs): """Ensure a queue with given name exists in SQS.""" + if not isinstance(queue, string_t): + return queue # Translate to SQS name for consistency with initial # _queue_cache population. queue = self.entity_name(self.queue_name_prefix + queue) @@ -266,57 +185,11 @@ class Channel(virtual.Channel): ) return q - def queue_bind(self, queue, exchange=None, routing_key='', - arguments=None, **kwargs): - super(Channel, self).queue_bind(queue, exchange, routing_key, - arguments, **kwargs) - if self.typeof(exchange).type == 'fanout': - self._fanout_queues.add(queue) - - def _queue_bind(self, *args): - """Bind ``queue`` to ``exchange`` with routing key. - - Route will be stored in SDB if so enabled. - - """ - if self.supports_fanout: - self.table.queue_bind(*args) - - def get_table(self, exchange): - """Get routing table. - - Retrieved from SDB if :attr:`supports_fanout`. - - """ - if self.supports_fanout: - return [(r['routing_key'], r['pattern'], r['queue']) - for r in self.table.routes_for(exchange)] - return super(Channel, self).get_table(exchange) - - def get_exchanges(self): - if self.supports_fanout: - return self.table.get_exchanges() - return super(Channel, self).get_exchanges() - def _delete(self, queue, *args): """delete queue by name.""" - if self.supports_fanout: - self.table.queue_delete(queue) super(Channel, self)._delete(queue) self._queue_cache.pop(queue, None) - def exchange_delete(self, exchange, **kwargs): - """Delete exchange by name.""" - if self.supports_fanout: - self.table.exchange_delete(exchange) - super(Channel, self).exchange_delete(exchange, **kwargs) - - def _has_queue(self, queue, **kwargs): - """Return True if ``queue`` was previously declared.""" - if self.supports_fanout: - return bool(self.table.get_queue(queue)) - return super(Channel, self)._has_queue(queue) - def _put(self, queue, message, **kwargs): """Put message onto queue.""" q = self._new_queue(queue) @@ -324,25 +197,6 @@ class Channel(virtual.Channel): m.set_body(dumps(message)) q.write(m) - def _put_fanout(self, exchange, message, routing_key, **kwargs): - """Deliver fanout message to all queues in ``exchange``.""" - for route in self.table.routes_for(exchange): - self._put(route['queue'], message, **kwargs) - - def _get_from_sqs(self, queue, count=1): - """Retrieve messages from SQS and returns the raw SQS message objects. - - :returns: List of SQS message objects - - """ - q = self._new_queue(queue) - if W_LONG_POLLING and queue not in self._fanout_queues: - return q.get_messages( - count, wait_time_seconds=self.wait_time_seconds, - ) - else: # boto < 2.8 - return q.get_messages(count) - def _message_to_python(self, message, queue_name, queue): payload = loads(bytes_to_str(message.get_body())) if queue_name in self._noack_queues: @@ -369,36 +223,33 @@ class Channel(virtual.Channel): q = self._new_queue(queue) return [self._message_to_python(m, queue, q) for m in messages] - def _get_bulk(self, queue, max_if_unlimited=SQS_MAX_MESSAGES): + def _get_bulk(self, queue, + max_if_unlimited=SQS_MAX_MESSAGES, callback=None): """Try to retrieve multiple messages off ``queue``. - Where _get() returns a single Payload object, this method returns a - list of Payload objects. The number of objects returned is determined - by the total number of messages available in the queue and the - number of messages that the QoS object allows (based on the + Where :meth:`_get` returns a single Payload object, this method + returns a list of Payload objects. The number of objects returned + is determined by the total number of messages available in the queue + and the number of messages the QoS object allows (based on the prefetch_count). .. note:: + Ignores QoS limits so caller is responsible for checking that we are allowed to consume at least one message from the queue. get_bulk will then ask QoS for an estimate of the number of extra messages that we can consume. - args: - queue: The queue name (string) to pull from - - returns: - payloads: A list of payload objects returned + :param queue: The queue name to pull from. + :returns list: of message objects. """ # drain_events calls `can_consume` first, consuming # a token, so we know that we are allowed to consume at least # one message. - maxcount = self.qos.can_consume_max_estimate() - maxcount = max_if_unlimited if maxcount is None else max(maxcount, 1) + maxcount = self._get_message_estimate() if maxcount: - messages = self._get_from_sqs( - queue, count=min(maxcount, SQS_MAX_MESSAGES), - ) + q = self._new_queue(queue) + messages = q.get_messages(num_messages=maxcount) if messages: return self._messages_to_python(messages, queue) @@ -406,12 +257,69 @@ class Channel(virtual.Channel): def _get(self, queue): """Try to retrieve a single message off ``queue``.""" - messages = self._get_from_sqs(queue, count=1) - + q = self._new_queue(queue) + messages = q.get_messages(num_messages=1) if messages: return self._messages_to_python(messages, queue)[0] raise Empty() + def _loop1(self, queue, _=None): + self.hub.call_soon(self._schedule_queue, queue) + + def _schedule_queue(self, queue): + if queue in self._active_queues: + if self.qos.can_consume(): + self._get_bulk_async( + queue, callback=promise(self._loop1, (queue, )), + ) + else: + self._loop1(queue) + + def _get_message_estimate(self, max_if_unlimited=SQS_MAX_MESSAGES): + maxcount = self.qos.can_consume_max_estimate() + return min( + max_if_unlimited if maxcount is None else max(maxcount, 1), + max_if_unlimited, + ) + + def _get_bulk_async(self, queue, + max_if_unlimited=SQS_MAX_MESSAGES, callback=None): + maxcount = self._get_message_estimate() + if maxcount: + return self._get_async(queue, maxcount, callback=callback) + # Not allowed to consume, make sure to notify callback.. + callback = ensure_promise(callback) + callback([]) + return callback + + def _get_async(self, queue, count=1, callback=None): + q = self._new_queue(queue) + return self._get_from_sqs( + q, count=count, connection=self.asynsqs, + callback=transform(self._on_messages_ready, callback, q, queue), + ) + + def _on_messages_ready(self, queue, qname, messages): + if messages: + callbacks = self.connection._callbacks + for raw_message in messages: + message = self._message_to_python(raw_message, qname, queue) + callbacks[qname](message) + + def _get_from_sqs(self, queue, + count=1, connection=None, callback=None): + """Retrieve and handle messages from SQS. + + Uses long polling and returns :class:`~amqp.promise`. + + """ + connection = connection if connection is not None else queue.connection + return connection.receive_message( + queue, number_messages=count, + wait_time_seconds=self.wait_time_seconds, + callback=callback, + ) + def _restore(self, message, unwanted_delivery_info=('sqs_message', 'sqs_queue')): for unwanted_key in unwanted_delivery_info: @@ -448,7 +356,7 @@ class Channel(virtual.Channel): def close(self): super(Channel, self).close() - for conn in (self._sqs, self._sdb): + for conn in (self._sqs, self._asynsqs): if conn: try: conn.close() @@ -477,19 +385,12 @@ class Channel(virtual.Channel): return self._sqs @property - def sdb(self): - if self._sdb is None: - self._sdb = self._aws_connect_to(SDBConnection, _sdb.regions()) - return self._sdb - - @property - def table(self): - name = self.entity_name( - self.domain_format % {'vhost': self.conninfo.virtual_host}) - d = self.sdb.get_object( - 'CreateDomain', {'DomainName': name}, self.Table) - d.name = name - return d + def asynsqs(self): + if self._asynsqs is None: + self._asynsqs = self._aws_connect_to( + AsyncSQSConnection, _asynsqs.regions(), + ) + return self._asynsqs @property def conninfo(self): @@ -510,7 +411,7 @@ class Channel(virtual.Channel): @cached_property def supports_fanout(self): - return self.transport_options.get('sdb_persistence', False) + return False @cached_property def region(self): @@ -537,3 +438,4 @@ class Transport(virtual.Transport): ) driver_type = 'sqs' driver_name = 'sqs' + supports_ev = True diff --git a/kombu/transport/amqplib.py b/kombu/transport/amqplib.py index fff82a1f..08c088a4 100644 --- a/kombu/transport/amqplib.py +++ b/kombu/transport/amqplib.py @@ -315,7 +315,11 @@ class Transport(base.Transport): driver_name = 'amqplib' driver_type = 'amqp' - supports_ev = True + + implements = base.Transport.implements.extend( + async=True, + heartbeats=False, + ) def __init__(self, client, **kwargs): self.client = client diff --git a/kombu/transport/base.py b/kombu/transport/base.py index c226307e..699a70f0 100644 --- a/kombu/transport/base.py +++ b/kombu/transport/base.py @@ -59,6 +59,28 @@ class Management(object): raise _LeftBlank(self, 'get_bindings') +class Implements(dict): + + def __getattr__(self, key): + try: + return self[key] + except KeyError: + raise AttributeError(key) + + def __setattr__(self, key, value): + self[key] = value + + def extend(self, **kwargs): + return self.__class__(self, **kwargs) + + +default_transport_capabilities = Implements( + async=False, + exchange_type=frozenset(['direct', 'topic', 'fanout', 'headers']), + heartbeats=False, +) + + class Transport(object): """Base class for transports.""" Management = Management @@ -87,15 +109,10 @@ class Transport(object): #: Name of driver library (e.g. 'py-amqp', 'redis', 'beanstalkc'). driver_name = 'N/A' - #: Whether this transports support heartbeats, - #: and that the :meth:`heartbeat_check` method has any effect. - supports_heartbeats = False - - #: Set to true if the transport supports the AIO interface. - supports_ev = False - __reader = None + implements = default_transport_capabilities.extend() + def __init__(self, client, **kwargs): self.client = client @@ -123,10 +140,10 @@ class Transport(object): def get_heartbeat_interval(self, connection): return 0 - def register_with_event_loop(self, loop): + def register_with_event_loop(self, connection, loop): pass - def unregister_from_event_loop(self, loop): + def unregister_from_event_loop(self, connection, loop): pass def verify_connection(self, connection): @@ -171,3 +188,11 @@ class Transport(object): @cached_property def manager(self): return self.get_manager() + + @property + def supports_heartbeats(self): + return self.implements.heartbeats + + @property + def supports_ev(self): + return self.implements.async diff --git a/kombu/transport/librabbitmq.py b/kombu/transport/librabbitmq.py index 286bd78e..96786c1a 100644 --- a/kombu/transport/librabbitmq.py +++ b/kombu/transport/librabbitmq.py @@ -88,7 +88,10 @@ class Transport(base.Transport): driver_type = 'amqp' driver_name = 'librabbitmq' - supports_ev = True + implements = base.Transport.implements.extend( + async=True, + heartbeats=False, + ) def __init__(self, client, **kwargs): self.client = client diff --git a/kombu/transport/mongodb.py b/kombu/transport/mongodb.py index 78af0f9f..eeef8756 100644 --- a/kombu/transport/mongodb.py +++ b/kombu/transport/mongodb.py @@ -305,5 +305,9 @@ class Transport(virtual.Transport): driver_type = 'mongodb' driver_name = 'pymongo' + implements = virtual.Transport.implements.extend( + exchange_types=frozenset(['direct', 'topic', 'fanout']), + ) + def driver_version(self): return pymongo.version diff --git a/kombu/transport/pyamqp.py b/kombu/transport/pyamqp.py index 01844305..71f462fe 100644 --- a/kombu/transport/pyamqp.py +++ b/kombu/transport/pyamqp.py @@ -74,8 +74,11 @@ class Transport(base.Transport): driver_name = 'py-amqp' driver_type = 'amqp' - supports_heartbeats = True - supports_ev = True + + implements = base.Transport.implements.extend( + async=True, + heartbeats=True, + ) def __init__(self, client, default_port=None, **kwargs): self.client = client diff --git a/kombu/transport/redis.py b/kombu/transport/redis.py index c3f6decd..9753c19e 100644 --- a/kombu/transport/redis.py +++ b/kombu/transport/redis.py @@ -905,10 +905,14 @@ class Transport(virtual.Transport): polling_interval = None # disable sleep between unsuccessful polls. default_port = DEFAULT_PORT - supports_ev = True driver_type = 'redis' driver_name = 'redis' + implements = virtual.Transport.implements.extend( + async=True, + exchange_types=frozenset(['direct', 'topic', 'fanout']) + ) + def __init__(self, *args, **kwargs): super(Transport, self).__init__(*args, **kwargs) diff --git a/kombu/transport/virtual/__init__.py b/kombu/transport/virtual/__init__.py index cb844de9..502f8c9f 100644 --- a/kombu/transport/virtual/__init__.py +++ b/kombu/transport/virtual/__init__.py @@ -771,6 +771,12 @@ class Transport(base.Transport): #: Max number of channels channel_max = 65535 + implements = base.Transport.implements.extend( + async=False, + exchange_type=frozenset(['direct', 'topic']), + heartbeats=False, + ) + def __init__(self, client, **kwargs): self.client = client self.channels = [] @@ -846,6 +852,13 @@ class Transport(base.Transport): self._callbacks[queue](message) + def on_message_ready(self, channel, message, queue): + if not queue or queue not in self._callbacks: + raise KeyError( + 'Message for queue {0!r} without consumers: {1}'.format( + queue, message)) + self._callbacks[queue](message) + def _drain_channel(self, channel, timeout=None): return channel.drain_events(timeout=timeout) diff --git a/kombu/transport/zmq.py b/kombu/transport/zmq.py index e6b8a48b..0df89a1f 100644 --- a/kombu/transport/zmq.py +++ b/kombu/transport/zmq.py @@ -245,7 +245,9 @@ class Transport(virtual.Transport): connection_errors = virtual.Transport.connection_errors + (ZMQError, ) - supports_ev = True + implements = virtual.Transport.implements.extend( + async=True, + ) polling_interval = None def __init__(self, *args, **kwargs): diff --git a/requirements/extras/sqs.txt b/requirements/extras/sqs.txt index 66b95834..8cd697e0 100644 --- a/requirements/extras/sqs.txt +++ b/requirements/extras/sqs.txt @@ -1 +1 @@ -boto>=2.13.3 +boto>=2.8.0 |