diff options
author | Mischa Spiegelmock <revmischa@cpan.org> | 2017-04-13 22:22:18 -0700 |
---|---|---|
committer | Asif Saifuddin Auvi <auvipy@users.noreply.github.com> | 2017-04-14 11:22:18 +0600 |
commit | 129a9e4ed05bf9a99d12fff9e17c9ffb37b14c4d (patch) | |
tree | c4eea532ea4ca84dee0049ab4304d86f0cc73ab7 | |
parent | bf820b20b022556c72402565f0ae50124017d6fe (diff) | |
download | kombu-129a9e4ed05bf9a99d12fff9e17c9ffb37b14c4d.tar.gz |
Switching to boto3 only (#693)
* Switch Boto2 to Boto3 for SQS messaging
* Fixed region support
* Add SQS FIFO queue support
* Add sensible defaults for message attributes
* Asynchronous support, plus boto3 for region endpoint lookups
* Clean up imports
* Fix Python 2 support
* Fix receive_message tests
* Reformat docstring
* boto3 import changes for CI
* skip tests if boto3 not installed
* skip tests if boto3 not installed
* flake8
* noboto
* ditching boto2. got queue URL fetching, async HTTP request generation and signing working.
* request signing working kinda
* async parsing of SQS message response more or less working
* botocore sqs dep
* ripping out more old boto2 stuff
* removing tests that are no longer valid with boto3/SQS
* fix boto3 dep, min version and no botocore
* no boto2 for test
* cleaning up some SQS tests. fixing header parsing of response to msg
* fixing some sqs tests
* removing response-parsing tests that are no longer necessary as we're using the botocore response parsing machinery instead of implementing SAX parsing in kombu.
* fixing more SQS tests
* wants a region
* trying to fix py2 parsing of sqs message
* lint
* py2/py2 message header parsing stupidness
* forgot
* python 2 sux
* flake8
* Import boto3 from the right place
* Changes
* Update encode fuction
* Fix lint
* remove some unused things
* removing unused stuff
* ugh
* ugh
* ugh
* landscape ignoring
* shut up, landscape
-rw-r--r-- | examples/hello_consumer.py | 2 | ||||
-rw-r--r-- | kombu/async/aws/connection.py | 229 | ||||
-rw-r--r-- | kombu/async/aws/ext.py | 31 | ||||
-rw-r--r-- | kombu/async/aws/sqs/__init__.py | 22 | ||||
-rw-r--r-- | kombu/async/aws/sqs/connection.py | 60 | ||||
-rw-r--r-- | kombu/async/aws/sqs/ext.py | 31 | ||||
-rw-r--r-- | kombu/async/aws/sqs/message.py | 51 | ||||
-rw-r--r-- | kombu/async/aws/sqs/queue.py | 3 | ||||
-rw-r--r-- | kombu/async/http/base.py | 9 | ||||
-rw-r--r-- | kombu/async/http/curl.py | 17 | ||||
-rw-r--r-- | kombu/transport/SQS.py | 197 | ||||
-rw-r--r-- | requirements/extras/sqs.txt | 2 | ||||
-rw-r--r-- | requirements/funtest.txt | 1 | ||||
-rw-r--r-- | requirements/test-ci.txt | 1 | ||||
-rw-r--r-- | t/integration/tests/test_SQS.py | 1 | ||||
-rw-r--r-- | t/unit/async/aws/case.py | 2 | ||||
-rw-r--r-- | t/unit/async/aws/sqs/test_connection.py | 73 | ||||
-rw-r--r-- | t/unit/async/aws/sqs/test_message.py | 37 | ||||
-rw-r--r-- | t/unit/async/aws/sqs/test_sqs.py | 34 | ||||
-rw-r--r-- | t/unit/async/aws/test_aws.py | 2 | ||||
-rw-r--r-- | t/unit/async/aws/test_connection.py | 271 | ||||
-rw-r--r-- | t/unit/transport/test_SQS.py | 218 |
22 files changed, 491 insertions, 803 deletions
diff --git a/examples/hello_consumer.py b/examples/hello_consumer.py index 5722450e..71f40f13 100644 --- a/examples/hello_consumer.py +++ b/examples/hello_consumer.py @@ -1,6 +1,6 @@ from __future__ import absolute_import, unicode_literals, print_function -from kombu import Connection +from kombu import Connection # noqa with Connection('amqp://guest:guest@localhost:5672//') as conn: diff --git a/kombu/async/aws/connection.py b/kombu/async/aws/connection.py index cc907d7b..303f3577 100644 --- a/kombu/async/aws/connection.py +++ b/kombu/async/aws/connection.py @@ -1,37 +1,36 @@ -# -*- coding: utf-8 -*- +# * coding: utf8 * """Amazon AWS Connection.""" from __future__ import absolute_import, unicode_literals -from io import BytesIO - from vine import promise, transform +from kombu.async.aws.ext import AWSRequest, get_response + from kombu.async.http import Headers, Request, get_client from kombu.five import items, python_2_unicode_compatible -from .ext import ( - boto, AWSAuthConnection, AWSQueryConnection, XmlHandler, ResultSet, -) - -try: - from urllib.parse import urlunsplit -except ImportError: - from urlparse import urlunsplit # noqa -from xml.sax import parseString as sax_parse # noqa +import io try: # pragma: no cover - from email import message_from_file + from email import message_from_bytes from email.mime.message import MIMEMessage + + # py3 + def message_from_headers(hdr): # noqa + bs = "\r\n".join("{}: {}".format(*h) for h in hdr) + return message_from_bytes(bs.encode()) + except ImportError: # pragma: no cover from mimetools import Message as MIMEMessage # noqa - def message_from_file(m): # noqa - return m + # py2 + def message_from_headers(hdr): # noqa + return io.BytesIO(b'\r\n'.join( + b'{0}: {1}'.format(*h) for h in hdr + )) __all__ = [ - 'AsyncHTTPConnection', 'AsyncHTTPSConnection', - 'AsyncHTTPResponse', 'AsyncConnection', - 'AsyncAWSAuthConnection', 'AsyncAWSQueryConnection', + 'AsyncHTTPSConnection', 'AsyncConnection', ] @@ -56,11 +55,7 @@ class AsyncHTTPResponse(object): @property def msg(self): if self._msg is None: - self._msg = MIMEMessage(message_from_file( - BytesIO(b'\r\n'.join( - b'{0}: {1}'.format(*h) for h in self.getheaders()) - ) - )) + self._msg = MIMEMessage(message_from_headers(self.getheaders())) return self._msg @property @@ -78,7 +73,7 @@ class AsyncHTTPResponse(object): @python_2_unicode_compatible -class AsyncHTTPConnection(object): +class AsyncHTTPSConnection(object): """Async HTTP Connection.""" Request = Request @@ -87,13 +82,9 @@ class AsyncHTTPConnection(object): method = 'GET' path = '/' body = None - scheme = 'http' default_ports = {'http': 80, 'https': 443} - def __init__(self, host, port=None, - strict=None, timeout=20.0, http_client=None, **kwargs): - self.host = host - self.port = port + def __init__(self, strict=None, timeout=20.0, http_client=None): self.headers = [] self.timeout = timeout self.strict = strict @@ -112,14 +103,9 @@ class AsyncHTTPConnection(object): if headers is not None: self.headers.extend(list(items(headers))) - def getrequest(self, scheme=None): - scheme = scheme if scheme else self.scheme - host = self.host - if self.port and self.port != self.default_ports[scheme]: - host = '{0}:{1}'.format(host, self.port) - url = urlunsplit((scheme, host, self.path, '', '')) + def getrequest(self): headers = Headers(self.headers) - return self.Request(url, method=self.method, headers=headers, + return self.Request(self.path, method=self.method, headers=headers, body=self.body, connect_timeout=self.timeout, request_timeout=self.timeout, validate_cert=False) @@ -137,7 +123,7 @@ class AsyncHTTPConnection(object): def close(self): pass - def putrequest(self, method, path, **kwargs): + def putrequest(self, method, path): self.method = method self.path = path @@ -157,139 +143,120 @@ class AsyncHTTPConnection(object): return '<AsyncHTTPConnection: {0!r}>'.format(self.getrequest()) -class AsyncHTTPSConnection(AsyncHTTPConnection): - """Async HTTPS Connection.""" - - scheme = 'https' - - class AsyncConnection(object): """Async AWS Connection.""" - def __init__(self, http_client=None, **kwargs): - if boto is None: - raise ImportError('boto is not installed') + def __init__(self, sqs_connection, http_client=None, **kwargs): # noqa + self.sqs_connection = sqs_connection self._httpclient = http_client or get_client() - def get_http_connection(self, host, port, is_secure): - return (AsyncHTTPSConnection if is_secure else AsyncHTTPConnection)( - host, port, http_client=self._httpclient, - ) + def get_http_connection(self): + return AsyncHTTPSConnection(http_client=self._httpclient) def _mexe(self, request, sender=None, callback=None): callback = callback or promise() - boto.log.debug( - 'HTTP %s/%s headers=%s body=%s', - request.host, request.path, - request.headers, request.body, - ) - - conn = self.get_http_connection( - request.host, request.port, self.is_secure, - ) - request.authorize(connection=self) + conn = self.get_http_connection() if callable(sender): sender(conn, request.method, request.path, request.body, request.headers, callback) else: - conn.request(request.method, request.path, + conn.request(request.method, request.url, request.body, request.headers) conn.getresponse(callback=callback) return callback -class AsyncAWSAuthConnection(AsyncConnection, AWSAuthConnection): - """Async AWS Authn Connection.""" - - def __init__(self, host, - http_client=None, http_client_params={}, **kwargs): - AsyncConnection.__init__(self, http_client, **http_client_params) - AWSAuthConnection.__init__(self, host, **kwargs) - - def make_request(self, method, path, headers=None, data='', host=None, - auth_path=None, sender=None, callback=None, **kwargs): - req = self.build_base_http_request( - method, path, auth_path, {}, headers, data, host, - ) - return self._mexe(req, sender=sender, callback=callback) - - -class AsyncAWSQueryConnection(AsyncConnection, AWSQueryConnection): +class AsyncAWSQueryConnection(AsyncConnection): """Async AWS Query Connection.""" - def __init__(self, host, - http_client=None, http_client_params={}, **kwargs): - AsyncConnection.__init__(self, http_client, **http_client_params) - AWSAuthConnection.__init__(self, host, **kwargs) - - def make_request(self, action, params, path, verb, callback=None): - request = self.build_base_http_request( - verb, path, None, params, {}, '', self.server_name()) - if action: - request.params['Action'] = action - request.params['Version'] = self.APIVersion - return self._mexe(request, callback=callback) - - def get_list(self, action, params, markers, - path='/', parent=None, verb='GET', callback=None): + def __init__(self, sqs_connection, http_client=None, + http_client_params=None, **kwargs): + if not http_client_params: + http_client_params = {} + AsyncConnection.__init__(self, sqs_connection, http_client, + **http_client_params) + + def make_request(self, operation, params_, path, verb, callback=None): # noqa + params = params_.copy() + if operation: + params['Action'] = operation + signer = self.sqs_connection._request_signer # noqa + + # defaults for non-get + signing_type = 'standard' + param_payload = {'data': params} + if verb.lower() == 'get': + # query-based opts + signing_type = 'presignurl' + param_payload = {'params': params} + + request = AWSRequest(method=verb, url=path, **param_payload) + signer.sign(operation, request, signing_type=signing_type) + prepared_request = request.prepare() + + # print(prepared_request.url) + # print(prepared_request.headers) + # print(prepared_request.body) + return self._mexe(prepared_request, callback=callback) + + def get_list(self, operation, params, markers, path='/', parent=None, verb='POST', callback=None): # noqa return self.make_request( - action, params, path, verb, + operation, params, path, verb, callback=transform( self._on_list_ready, callback, parent or self, markers, + operation ), ) - def get_object(self, action, params, cls, - path='/', parent=None, verb='GET', callback=None): + def get_object(self, operation, params, path='/', parent=None, verb='GET', callback=None): # noqa return self.make_request( - action, params, path, verb, + operation, params, path, verb, callback=transform( - self._on_obj_ready, callback, parent or self, cls, + self._on_obj_ready, callback, parent or self, operation ), ) - def get_status(self, action, params, - path='/', parent=None, verb='GET', callback=None): + def get_status(self, operation, params, path='/', parent=None, verb='GET', callback=None): # noqa return self.make_request( - action, params, path, verb, + operation, params, path, verb, callback=transform( - self._on_status_ready, callback, parent or self, + self._on_status_ready, callback, parent or self, operation ), ) - def _on_list_ready(self, parent, markers, response): - body = response.read() - if response.status == 200 and body: - rs = ResultSet(markers) - h = XmlHandler(rs, parent) - sax_parse(body, h) - return rs + def _on_list_ready(self, parent, markers, operation, response): # noqa + service_model = self.sqs_connection.meta.service_model + if response.status == 200: + _, parsed = get_response( + service_model.operation_model(operation), response.response + ) + return parsed else: - raise self._for_status(response, body) - - def _on_obj_ready(self, parent, cls, response): - body = response.read() - if response.status == 200 and body: - obj = cls(parent) - h = XmlHandler(obj, parent) - sax_parse(body, h) - return obj + raise self._for_status(response, response.read()) + + def _on_obj_ready(self, parent, operation, response): # noqa + service_model = self.sqs_connection.meta.service_model + if response.status == 200: + _, parsed = get_response( + service_model.operation_model(operation), response.response + ) + return parsed else: - raise self._for_status(response, body) - - def _on_status_ready(self, parent, response): - body = response.read() - if response.status == 200 and body: - rs = ResultSet() - h = XmlHandler(rs, parent) - sax_parse(body, h) - return rs.status + raise self._for_status(response, response.read()) + + def _on_status_ready(self, parent, operation, response): # noqa + service_model = self.sqs_connection.meta.service_model + if response.status == 200: + httpres, _ = get_response( + service_model.operation_model(operation), response.response + ) + return httpres.code else: - raise self._for_status(response, body) + raise self._for_status(response, response.read()) def _for_status(self, response, body): context = 'Empty body' if not body else 'HTTP Error' - exc = self.ResponseError(response.status, response.reason, body) - boto.log.error('{0}: %r'.format(context), exc) - return exc + return Exception("Request {} HTTP {} {} ({})".format( + context, response.status, response.reason, body + )) diff --git a/kombu/async/aws/ext.py b/kombu/async/aws/ext.py index b0e497cb..3519ab65 100644 --- a/kombu/async/aws/ext.py +++ b/kombu/async/aws/ext.py @@ -1,29 +1,26 @@ # -*- coding: utf-8 -*- -"""Amazon boto interface.""" +"""Amazon boto3 interface.""" from __future__ import absolute_import, unicode_literals try: - import boto -except ImportError: # pragma: no cover - boto = get_regions = ResultSet = RegionInfo = XmlHandler = None + import boto3 + from botocore import exceptions + from botocore.awsrequest import AWSRequest + from botocore.response import get_response +except ImportError: + boto3 = None class _void(object): pass - AWSAuthConnection = AWSQueryConnection = _void # noqa - class BotoError(Exception): + class BotoCoreError(Exception): pass - exception = _void() - exception.SQSError = BotoError - exception.SQSDecodeError = BotoError -else: - from boto import exception - from boto.connection import AWSAuthConnection, AWSQueryConnection - from boto.handler import XmlHandler - from boto.resultset import ResultSet - from boto.regioninfo import RegionInfo, get_regions + exceptions = _void() + exceptions.BotoCoreError = BotoCoreError + AWSRequest = _void() + get_response = _void() + __all__ = [ - 'exception', 'AWSAuthConnection', 'AWSQueryConnection', - 'XmlHandler', 'ResultSet', 'RegionInfo', 'get_regions', + 'exceptions', 'AWSRequest', 'get_response' ] diff --git a/kombu/async/aws/sqs/__init__.py b/kombu/async/aws/sqs/__init__.py index fe529584..e69de29b 100644 --- a/kombu/async/aws/sqs/__init__.py +++ b/kombu/async/aws/sqs/__init__.py @@ -1,22 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import absolute_import, unicode_literals - -from kombu.async.aws.ext import boto, get_regions - -from .connection import AsyncSQSConnection - -__all__ = ['regions', 'connect_to_region'] - - -def regions(): - """Return list of known AWS regions.""" - if boto is None: - raise ImportError('boto is not installed') - return get_regions('sqs', connection_cls=AsyncSQSConnection) - - -def connect_to_region(region_name, **kwargs): - """Connect to specific AWS region.""" - for region in regions(): - if region.name == region_name: - return region.connect(**kwargs) diff --git a/kombu/async/aws/sqs/connection.py b/kombu/async/aws/sqs/connection.py index 32b9926b..7b0b738e 100644 --- a/kombu/async/aws/sqs/connection.py +++ b/kombu/async/aws/sqs/connection.py @@ -5,9 +5,8 @@ from __future__ import absolute_import, unicode_literals from vine import transform from kombu.async.aws.connection import AsyncAWSQueryConnection -from kombu.async.aws.ext import RegionInfo -from .ext import boto, Attributes, BatchResults, SQSConnection +from .ext import boto3 from .message import AsyncMessage from .queue import AsyncQueue @@ -15,28 +14,17 @@ from .queue import AsyncQueue __all__ = ['AsyncSQSConnection'] -class AsyncSQSConnection(AsyncAWSQueryConnection, SQSConnection): +class AsyncSQSConnection(AsyncAWSQueryConnection): """Async SQS Connection.""" - def __init__(self, aws_access_key_id=None, aws_secret_access_key=None, - is_secure=True, port=None, proxy=None, proxy_port=None, - proxy_user=None, proxy_pass=None, debug=0, - https_connection_factory=None, region=None, *args, **kwargs): - if boto is None: - raise ImportError('boto is not installed') - self.region = region or RegionInfo( - self, self.DefaultRegionName, self.DefaultRegionEndpoint, - connection_cls=type(self), - ) + def __init__(self, sqs_connection, debug=0, region=None, **kwargs): + if boto3 is None: + raise ImportError('boto3 is not installed') AsyncAWSQueryConnection.__init__( self, - aws_access_key_id=aws_access_key_id, - aws_secret_access_key=aws_secret_access_key, - is_secure=is_secure, port=port, - proxy=proxy, proxy_port=proxy_port, - proxy_user=proxy_user, proxy_pass=proxy_pass, - host=self.region.endpoint, debug=debug, - https_connection_factory=https_connection_factory, **kwargs + sqs_connection, + region_name=region, debug=debug, + **kwargs ) def create_queue(self, queue_name, @@ -46,17 +34,21 @@ class AsyncSQSConnection(AsyncAWSQueryConnection, SQSConnection): params['DefaultVisibilityTimeout'] = format( visibility_timeout, 'd', ) - return self.get_object('CreateQueue', params, AsyncQueue, + return self.get_object('CreateQueue', params, callback=callback) def delete_queue(self, queue, force_deletion=False, callback=None): return self.get_status('DeleteQueue', None, queue.id, callback=callback) + def get_queue_url(self, queue): + res = self.sqs_connection.get_queue_url(QueueName=queue) + return res['QueueUrl'] + def get_queue_attributes(self, queue, attribute='All', callback=None): return self.get_object( 'GetQueueAttributes', {'AttributeName': attribute}, - Attributes, queue.id, callback=callback, + queue.id, callback=callback, ) def set_queue_attribute(self, queue, attribute, value, callback=None): @@ -74,17 +66,21 @@ class AsyncSQSConnection(AsyncAWSQueryConnection, SQSConnection): if visibility_timeout: params['VisibilityTimeout'] = visibility_timeout if attributes: - self.build_list_params(params, attributes, 'AttributeName') + attrs = {} + for idx, attr in enumerate(attributes): + attrs['AttributeName.' + str(idx + 1)] = attr + params.update(attrs) if wait_time_seconds is not None: params['WaitTimeSeconds'] = wait_time_seconds + queue_url = self.get_queue_url(queue) return self.get_list( - 'ReceiveMessage', params, [('Message', queue.message_class)], - queue.id, callback=callback, + 'ReceiveMessage', params, [('Message', AsyncMessage)], + queue_url, callback=callback, parent=queue, ) - def delete_message(self, queue, message, callback=None): + def delete_message(self, queue, receipt_handle, callback=None): return self.delete_message_from_handle( - queue, message.receipt_handle, callback, + queue, receipt_handle, callback, ) def delete_message_batch(self, queue, messages, callback=None): @@ -96,7 +92,7 @@ class AsyncSQSConnection(AsyncAWSQueryConnection, SQSConnection): '{0}.ReceiptHandle'.format(prefix): m.receipt_handle, }) return self.get_object( - 'DeleteMessageBatch', params, BatchResults, queue.id, + 'DeleteMessageBatch', params, queue.id, verb='POST', callback=callback, ) @@ -104,7 +100,7 @@ class AsyncSQSConnection(AsyncAWSQueryConnection, SQSConnection): callback=None): return self.get_status( 'DeleteMessage', {'ReceiptHandle': receipt_handle}, - queue.id, callback=callback, + queue, callback=callback, ) def send_message(self, queue, message_content, @@ -113,7 +109,7 @@ class AsyncSQSConnection(AsyncAWSQueryConnection, SQSConnection): if delay_seconds: params['DelaySeconds'] = int(delay_seconds) return self.get_object( - 'SendMessage', params, AsyncMessage, queue.id, + 'SendMessage', params, queue.id, verb='POST', callback=callback, ) @@ -127,7 +123,7 @@ class AsyncSQSConnection(AsyncAWSQueryConnection, SQSConnection): '{0}.DelaySeconds'.format(prefix): msg[2], }) return self.get_object( - 'SendMessageBatch', params, BatchResults, queue.id, + 'SendMessageBatch', params, queue.id, verb='POST', callback=callback, ) @@ -150,7 +146,7 @@ class AsyncSQSConnection(AsyncAWSQueryConnection, SQSConnection): '{0}.VisibilityTimeout'.format(pre): t[1], }) return self.get_object( - 'ChangeMessageVisibilityBatch', params, BatchResults, queue.id, + 'ChangeMessageVisibilityBatch', params, queue.id, verb='POST', callback=callback, ) diff --git a/kombu/async/aws/sqs/ext.py b/kombu/async/aws/sqs/ext.py index eb48f3e9..09fdf1a1 100644 --- a/kombu/async/aws/sqs/ext.py +++ b/kombu/async/aws/sqs/ext.py @@ -1,32 +1,9 @@ # -*- coding: utf-8 -*- -"""Amazon SQS boto interface.""" +"""Amazon SQS boto3 interface.""" from __future__ import absolute_import, unicode_literals try: - import boto -except ImportError: # pragma: no cover - boto = Attributes = BatchResults = None # noqa - - class _void(object): - pass - regions = SQSConnection = Queue = _void - - RawMessage = Message = MHMessage = \ - EncodedMHMessage = JSONMessage = _void -else: - from boto.sqs.attributes import Attributes - from boto.sqs.batchresults import BatchResults - from boto.sqs.message import ( - EncodedMHMessage, Message, MHMessage, RawMessage, - ) - from boto.sqs import regions - from boto.sqs.jsonmessage import JSONMessage - from boto.sqs.connection import SQSConnection - from boto.sqs.queue import Queue - -__all__ = [ - 'Attributes', 'BatchResults', 'EncodedMHMessage', 'MHMessage', - 'Message', 'RawMessage', 'JSONMessage', 'SQSConnection', - 'Queue', 'regions', -] + import boto3 +except ImportError: + boto3 = None diff --git a/kombu/async/aws/sqs/message.py b/kombu/async/aws/sqs/message.py index 36359841..b4b28331 100644 --- a/kombu/async/aws/sqs/message.py +++ b/kombu/async/aws/sqs/message.py @@ -2,45 +2,32 @@ """Amazon SQS message implementation.""" from __future__ import absolute_import, unicode_literals -from .ext import ( - RawMessage, Message, MHMessage, EncodedMHMessage, JSONMessage, -) +from kombu.message import Message +import base64 -__all__ = [ - 'BaseAsyncMessage', 'AsyncRawMessage', 'AsyncMessage', - 'AsyncMHMessage', 'AsyncEncodedMHMessage', 'AsyncJSONMessage', -] - -class BaseAsyncMessage(object): +class BaseAsyncMessage(Message): """Base class for messages received on async client.""" - def delete(self, callback=None): - if self.queue: - return self.queue.delete_message(self, callback) - - def change_visibility(self, visibility_timeout, callback=None): - if self.queue: - return self.queue.connection.change_message_visibility( - self.queue, self.receipt_handle, visibility_timeout, callback, - ) - -class AsyncRawMessage(BaseAsyncMessage, RawMessage): +class AsyncRawMessage(BaseAsyncMessage): """Raw Message.""" -class AsyncMessage(BaseAsyncMessage, Message): +class AsyncMessage(BaseAsyncMessage): """Serialized message.""" - -class AsyncMHMessage(BaseAsyncMessage, MHMessage): - """MHM Message (uhm, look that up later).""" - - -class AsyncEncodedMHMessage(BaseAsyncMessage, EncodedMHMessage): - """Encoded MH Message.""" - - -class AsyncJSONMessage(BaseAsyncMessage, JSONMessage): - """Json serialized message.""" + def encode(self, value): + """Encode/decode the value using Base64 encoding.""" + return base64.b64encode(value).decode('utf-8') + + def __getitem__(self, item): + """Support Boto3-style access on a message.""" + if item == 'ReceiptHandle': + return self.receipt_handle + elif item == 'Body': + return self.get_body() + elif item == 'queue': + return self.queue + else: + raise KeyError(item) diff --git a/kombu/async/aws/sqs/queue.py b/kombu/async/aws/sqs/queue.py index 140ff31a..c0f8c0f3 100644 --- a/kombu/async/aws/sqs/queue.py +++ b/kombu/async/aws/sqs/queue.py @@ -4,7 +4,6 @@ from __future__ import absolute_import, unicode_literals from vine import transform -from .ext import Queue as _Queue from .message import AsyncMessage _all__ = ['AsyncQueue'] @@ -15,7 +14,7 @@ def list_first(rs): return rs[0] if len(rs) == 1 else None -class AsyncQueue(_Queue): +class AsyncQueue(): """Async SQS Queue.""" def __init__(self, connection=None, url=None, message_class=AsyncMessage): diff --git a/kombu/async/http/base.py b/kombu/async/http/base.py index f8a8bc0a..a8eb6edf 100644 --- a/kombu/async/http/base.py +++ b/kombu/async/http/base.py @@ -200,6 +200,15 @@ class Response(object): self._body = self.buffer.getvalue() return self._body + # these are for compatibility with Requests + @property + def status_code(self): + return self.code + + @property + def content(self): + return self.body + @coro def header_parser(keyt=normalize_header): diff --git a/kombu/async/http/curl.py b/kombu/async/http/curl.py index 1c50eef8..d8520ded 100644 --- a/kombu/async/http/curl.py +++ b/kombu/async/http/curl.py @@ -171,15 +171,12 @@ class CurlClient(BaseClient): code = curl.getinfo(_pycurl.HTTP_CODE) effective_url = curl.getinfo(_pycurl.EFFECTIVE_URL) buffer.seek(0) - try: - request = info['request'] - request.on_ready(self.Response( - request=request, code=code, headers=info['headers'], - buffer=buffer, effective_url=effective_url, error=error, - )) - except Exception as exc: - self.hub.on_callback_error(request.on_ready, exc) - raise + # try: + request = info['request'] + request.on_ready(self.Response( + request=request, code=code, headers=info['headers'], + buffer=buffer, effective_url=effective_url, error=error, + )) def _setup_request(self, curl, request, buffer, headers, _pycurl=pycurl): setopt = curl.setopt @@ -243,7 +240,7 @@ class CurlClient(BaseClient): setopt(meth, True) if request.method in ('POST', 'PUT'): - body = request.body or '' + body = request.body.encode('utf-8') if request.body else bytes() reqbuffer = BytesIO(body) setopt(_pycurl.READFUNCTION, reqbuffer.read) if request.method == 'POST': diff --git a/kombu/transport/SQS.py b/kombu/transport/SQS.py index a0fb6e8d..59cfdc87 100644 --- a/kombu/transport/SQS.py +++ b/kombu/transport/SQS.py @@ -39,15 +39,14 @@ from __future__ import absolute_import, unicode_literals import socket import string +import uuid from vine import transform, ensure_promise, promise from kombu.async import get_event_loop -from kombu.async.aws import sqs as _asynsqs -from kombu.async.aws.ext import boto, exception -from kombu.async.aws.sqs.connection import AsyncSQSConnection, SQSConnection -from kombu.async.aws.sqs.ext import regions -from kombu.async.aws.sqs.message import Message +from kombu.async.aws.ext import boto3, exceptions +from kombu.async.aws.sqs.connection import AsyncSQSConnection +from kombu.async.aws.sqs.message import AsyncMessage from kombu.five import Empty, range, string_t, text_t from kombu.log import get_logger from kombu.utils import scheduling @@ -91,8 +90,8 @@ class Channel(virtual.Channel): _noack_queues = set() def __init__(self, *args, **kwargs): - if boto is None: - raise ImportError('boto is not installed') + if boto3 is None: + raise ImportError('boto3 is not installed') super(Channel, self).__init__(*args, **kwargs) # SQS blows up if you try to create a new queue when one already @@ -104,18 +103,10 @@ class Channel(virtual.Channel): self.hub = kwargs.get('hub') or get_event_loop() def _update_queue_cache(self, queue_name_prefix): - try: - queues = self.sqs.get_all_queues(prefix=queue_name_prefix) - except exception.SQSError as exc: - if exc.status == 403: - raise RuntimeError( - 'SQS authorization error, access_key={0}'.format( - self.sqs.access_key)) - raise - else: - self._queue_cache.update({ - queue.name: queue for queue in queues - }) + resp = self.sqs.list_queues(QueueNamePrefix=queue_name_prefix) + for url in resp.get('QueueUrls', []): + queue_name = url.split('/')[-1] + self._queue_cache[queue_name] = url def basic_consume(self, queue, no_ack, *args, **kwargs): if no_ack: @@ -132,7 +123,7 @@ class Channel(virtual.Channel): self._noack_queues.discard(queue) return super(Channel, self).basic_cancel(consumer_tag) - def drain_events(self, timeout=None, **kwargs): + def drain_events(self, timeout=None, callback=None, **kwargs): """Return a single payload message from one of our queues. Raises: @@ -143,7 +134,7 @@ class Channel(virtual.Channel): raise Empty() # At this point, go and get more messages from SQS - self._poll(self.cycle, self.connection._deliver, timeout=timeout) + self._poll(self.cycle, callback, timeout=timeout) def _reset_cycle(self): """Reset the consume cycle. @@ -160,7 +151,15 @@ class Channel(virtual.Channel): def entity_name(self, name, table=CHARS_REPLACE_TABLE): """Format AMQP queue name into a legal SQS queue name.""" - return text_t(safe_str(name)).translate(table) + if name.endswith('.fifo'): + partial = name.rstrip('.fifo') + partial = text_t(safe_str(partial)).translate(table) + return partial + '.fifo' + else: + return text_t(safe_str(name)).translate(table) + + def canonical_queue_name(self, queue_name): + return self.entity_name(self.queue_name_prefix + queue_name) def _new_queue(self, queue, **kwargs): """Ensure a queue with given name exists in SQS.""" @@ -168,7 +167,7 @@ class Channel(virtual.Channel): return queue # Translate to SQS name for consistency with initial # _queue_cache population. - queue = self.entity_name(self.queue_name_prefix + queue) + queue = self.canonical_queue_name(queue) # The SQS ListQueues method only returns 1000 queues. When you have # so many queues, it's possible that the queue you are looking for is @@ -179,10 +178,14 @@ class Channel(virtual.Channel): try: return self._queue_cache[queue] except KeyError: - q = self._queue_cache[queue] = self.sqs.create_queue( - queue, self.visibility_timeout, - ) - return q + attributes = {'VisibilityTimeout': str(self.visibility_timeout)} + if queue.endswith('.fifo'): + attributes['FifoQueue'] = 'true' + + resp = self._queue_cache[queue] = self.sqs.create_queue( + QueueName=queue, Attributes=attributes) + self._queue_cache[queue] = resp['QueueUrl'] + return resp['QueueUrl'] def _delete(self, queue, *args, **kwargs): """Delete queue by name.""" @@ -191,15 +194,27 @@ class Channel(virtual.Channel): def _put(self, queue, message, **kwargs): """Put message onto queue.""" - q = self._new_queue(queue) - m = Message() - m.set_body(dumps(message)) - q.write(m) + q_url = self._new_queue(queue) + kwargs = {'QueueUrl': q_url, + 'MessageBody': AsyncMessage().encode(dumps(message))} + if queue.endswith('.fifo'): + if 'MessageGroupId' in message['properties']: + kwargs['MessageGroupId'] = \ + message['properties']['MessageGroupId'] + else: + kwargs['MessageGroupId'] = 'default' + if 'MessageDeduplicationId' in message['properties']: + kwargs['MessageDeduplicationId'] = \ + message['properties']['MessageDeduplicationId'] + else: + kwargs['MessageDeduplicationId'] = str(uuid.uuid4()) + self.sqs.send_message(**kwargs) def _message_to_python(self, message, queue_name, queue): - payload = loads(bytes_to_str(message.get_body())) + payload = loads(bytes_to_str(message['Body'])) if queue_name in self._noack_queues: - queue.delete_message(message) + queue = self._new_queue(queue_name) + self.asynsqs.delete_message(queue, message['ReceiptHandle']) else: try: properties = payload['properties'] @@ -209,14 +224,14 @@ class Channel(virtual.Channel): delivery_info = {} properties = {'delivery_info': delivery_info} payload.update({ - 'body': bytes_to_str(message.get_body()), + 'body': bytes_to_str(message['Body']), 'properties': properties, }) # set delivery tag to SQS receipt handle delivery_info.update({ 'sqs_message': message, 'sqs_queue': queue, }) - properties['delivery_tag'] = message.receipt_handle + properties['delivery_tag'] = message['ReceiptHandle'] return payload def _messages_to_python(self, messages, queue): @@ -261,23 +276,31 @@ class Channel(virtual.Channel): # drain_events calls `can_consume` first, consuming # a token, so we know that we are allowed to consume at least # one message. - maxcount = self._get_message_estimate() - if maxcount: - q = self._new_queue(queue) - messages = q.get_messages(num_messages=maxcount) - if messages: - for msg in self._messages_to_python(messages, queue): + # Note: ignoring max_messages for SQS with boto3 + max_count = self._get_message_estimate() + if max_count: + q_url = self._new_queue(queue) + resp = self.sqs.receive_message( + QueueUrl=q_url, MaxNumberOfMessages=max_count) + + if resp['Messages']: + for m in resp['Messages']: + m['Body'] = AsyncMessage().decode(m['Body']) + for msg in self._messages_to_python(resp['Messages'], queue): self.connection._deliver(msg, queue) return raise Empty() def _get(self, queue): """Try to retrieve a single message off ``queue``.""" - q = self._new_queue(queue) - messages = q.get_messages(num_messages=1) - if messages: - return self._messages_to_python(messages, queue)[0] + q_url = self._new_queue(queue) + resp = self.sqs.receive_message(q_url) + + if resp['Messages']: + body = AsyncMessage().decode(resp['Messages'][0]['Body']) + resp['Messages'][0]['Body'] = body + return self._messages_to_python(resp['Messages'], queue)[0] raise Empty() def _loop1(self, queue, _=None): @@ -311,17 +334,18 @@ class Channel(virtual.Channel): def _get_async(self, queue, count=1, callback=None): q = self._new_queue(queue) + qname = self.canonical_queue_name(queue) return self._get_from_sqs( - q, count=count, connection=self.asynsqs, + qname, count=count, connection=self.asynsqs, callback=transform(self._on_messages_ready, callback, q, queue), ) def _on_messages_ready(self, queue, qname, messages): - if messages: + if 'Messages' in messages and messages['Messages']: callbacks = self.connection._callbacks - for raw_message in messages: - message = self._message_to_python(raw_message, qname, queue) - callbacks[qname](message) + for msg in messages['Messages']: + msg_parsed = self._message_to_python(msg, qname, queue) + callbacks[qname](msg_parsed) def _get_from_sqs(self, queue, count=1, connection=None, callback=None): @@ -330,6 +354,7 @@ class Channel(virtual.Channel): Uses long polling and returns :class:`~vine.promises.promise`. """ connection = connection if connection is not None else queue.connection + # url = self.get_queue return connection.receive_message( queue, number_messages=count, wait_time_seconds=self.wait_time_seconds, @@ -344,72 +369,68 @@ class Channel(virtual.Channel): return super(Channel, self)._restore(message) def basic_ack(self, delivery_tag, multiple=False): - delivery_info = self.qos.get(delivery_tag).delivery_info try: - queue = delivery_info['sqs_queue'] + message = self.qos.get(delivery_tag).delivery_info + sqs_message = message['sqs_message'] except KeyError: pass else: - queue.delete_message(delivery_info['sqs_message']) + self.asynsqs.delete_message(message['sqs_queue'], + sqs_message['ReceiptHandle']) super(Channel, self).basic_ack(delivery_tag) def _size(self, queue): """Return the number of messages in a queue.""" - return self._new_queue(queue).count() + url = self._new_queue(queue) + resp = self.sqs.get_queue_attributes( + QueueUrl=url, + AttributeNames=['ApproximateNumberOfMessages']) + return int(resp['Attributes']['ApproximateNumberOfMessages']) def _purge(self, queue): """Delete all current messages in a queue.""" q = self._new_queue(queue) # SQS is slow at registering messages, so run for a few - # iterations to ensure messages are deleted. + # iterations to ensure messages are detected and deleted. size = 0 for i in range(10): - size += q.count() + size += int(self._size(queue)) if not size: break - q.clear() + self.sqs.purge_queue(q) return size def close(self): super(Channel, self).close() - for conn in (self._sqs, self._asynsqs): - if conn: - try: - conn.close() - except AttributeError as exc: # FIXME ??? - if "can't set attribute" not in str(exc): - raise - - def _get_regioninfo(self, regions): - if self.regioninfo: - return self.regioninfo - if self.region: - for _r in regions: - if _r.name == self.region: - return _r - - def _aws_connect_to(self, fun, regions): - conninfo = self.conninfo - region = self._get_regioninfo(regions) - is_secure = self.is_secure if self.is_secure is not None else True - port = self.port if self.port is not None else conninfo.port - return fun(region=region, - aws_access_key_id=conninfo.userid, - aws_secret_access_key=conninfo.password, - is_secure=is_secure, - port=port) + # if self._asynsqs: + # try: + # self.asynsqs.close() + # except AttributeError as exc: # FIXME ??? + # if "can't set attribute" not in str(exc): + # raise @property def sqs(self): if self._sqs is None: - self._sqs = self._aws_connect_to(SQSConnection, regions()) + session = boto3.session.Session( + region_name=self.region, + aws_access_key_id=self.conninfo.userid, + aws_secret_access_key=self.conninfo.password, + ) + is_secure = self.is_secure if self.is_secure is not None else True + self._sqs = session.client('sqs', use_ssl=is_secure) return self._sqs @property def asynsqs(self): if self._asynsqs is None: - self._asynsqs = self._aws_connect_to( - AsyncSQSConnection, _asynsqs.regions(), + is_secure = self.is_secure if self.is_secure is not None else True + self._asynsqs = AsyncSQSConnection( + sqs_connection=self.sqs, + aws_access_key_id=self.conninfo.userid, + aws_secret_access_key=self.conninfo.password, + region=self.region, + is_secure=is_secure, ) return self._asynsqs @@ -466,10 +487,10 @@ class Transport(virtual.Transport): default_port = None connection_errors = ( virtual.Transport.connection_errors + - (exception.SQSError, socket.error) + (exceptions.BotoCoreError, socket.error) ) channel_errors = ( - virtual.Transport.channel_errors + (exception.SQSDecodeError,) + virtual.Transport.channel_errors + (exceptions.BotoCoreError,) ) driver_type = 'sqs' driver_name = 'sqs' diff --git a/requirements/extras/sqs.txt b/requirements/extras/sqs.txt index 53851e77..3d6d07a9 100644 --- a/requirements/extras/sqs.txt +++ b/requirements/extras/sqs.txt @@ -1,2 +1,2 @@ -boto>=2.8 +boto3>=1.4.4 pycurl diff --git a/requirements/funtest.txt b/requirements/funtest.txt index 55637309..033db4b8 100644 --- a/requirements/funtest.txt +++ b/requirements/funtest.txt @@ -9,6 +9,7 @@ kazoo # SQS transport boto +boto3 # Qpid transport qpid-python>=0.26 diff --git a/requirements/test-ci.txt b/requirements/test-ci.txt index 0684adfc..f2318d5a 100644 --- a/requirements/test-ci.txt +++ b/requirements/test-ci.txt @@ -3,3 +3,4 @@ codecov redis PyYAML msgpack-python>0.2.0 +-r extras/sqs.txt diff --git a/t/integration/tests/test_SQS.py b/t/integration/tests/test_SQS.py index 676ba994..1c52d648 100644 --- a/t/integration/tests/test_SQS.py +++ b/t/integration/tests/test_SQS.py @@ -8,6 +8,7 @@ from kombu.tests.case import skip @skip.unless_environ('AWS_ACCESS_KEY_ID') @skip.unless_environ('AWS_SECRET_ACCESS_KEY') @skip.unless_module('boto') +@skip.unless_module('boto3') class test_SQS(transport.TransportCase): transport = 'SQS' prefix = 'sqs' diff --git a/t/unit/async/aws/case.py b/t/unit/async/aws/case.py index 985dbc04..70d9565f 100644 --- a/t/unit/async/aws/case.py +++ b/t/unit/async/aws/case.py @@ -7,7 +7,7 @@ from case import skip @skip.if_pypy() -@skip.unless_module('boto') +@skip.unless_module('boto3') @skip.unless_module('pycurl') @pytest.mark.usefixtures('hub') class AWSCase(object): diff --git a/t/unit/async/aws/sqs/test_connection.py b/t/unit/async/aws/sqs/test_connection.py index ecc22623..ade01868 100644 --- a/t/unit/async/aws/sqs/test_connection.py +++ b/t/unit/async/aws/sqs/test_connection.py @@ -1,13 +1,12 @@ # -*- coding: utf-8 -*- from __future__ import absolute_import, unicode_literals -import pytest - -from case import Mock +from case import Mock, MagicMock from kombu.async.aws.sqs.connection import ( - AsyncSQSConnection, Attributes, BatchResults, + AsyncSQSConnection ) +from kombu.async.aws.ext import boto3 from kombu.async.aws.sqs.message import AsyncMessage from kombu.async.aws.sqs.queue import AsyncQueue from kombu.utils.uuid import uuid @@ -20,29 +19,26 @@ from ..case import AWSCase class test_AsyncSQSConnection(AWSCase): def setup(self): - self.x = AsyncSQSConnection('ak', 'sk', http_client=Mock()) + session = boto3.session.Session( + aws_access_key_id='AAA', + aws_secret_access_key='AAAA', + region_name='us-west-2', + ) + sqs_client = session.client('sqs') + self.x = AsyncSQSConnection(sqs_client, 'ak', 'sk', http_client=Mock()) self.x.get_object = Mock(name='X.get_object') self.x.get_status = Mock(name='X.get_status') - self.x.get_list = Mock(nanme='X.get_list') + self.x.get_list = Mock(name='X.get_list') self.callback = PromiseMock(name='callback') - def test_without_boto(self): - from kombu.async.aws.sqs import connection - prev, connection.boto = connection.boto, None - try: - with pytest.raises(ImportError): - AsyncSQSConnection('ak', 'sk', http_client=Mock()) - finally: - connection.boto = prev - - def test_default_region(self): - assert self.x.region - assert issubclass(self.x.region.connection_cls, AsyncSQSConnection) + sqs_client.get_queue_url = MagicMock(return_value={ + 'QueueUrl': 'http://aws.com' + }) def test_create_queue(self): self.x.create_queue('foo', callback=self.callback) self.x.get_object.assert_called_with( - 'CreateQueue', {'QueueName': 'foo'}, AsyncQueue, + 'CreateQueue', {'QueueName': 'foo'}, callback=self.callback, ) @@ -55,7 +51,7 @@ class test_AsyncSQSConnection(AWSCase): 'QueueName': 'foo', 'DefaultVisibilityTimeout': '33' }, - AsyncQueue, callback=self.callback + callback=self.callback ) def test_delete_queue(self): @@ -72,7 +68,7 @@ class test_AsyncSQSConnection(AWSCase): ) self.x.get_object.assert_called_with( 'GetQueueAttributes', {'AttributeName': 'QueueSize'}, - Attributes, queue.id, callback=self.callback, + queue.id, callback=self.callback, ) def test_set_queue_attribute(self): @@ -93,8 +89,9 @@ class test_AsyncSQSConnection(AWSCase): self.x.receive_message(queue, 4, callback=self.callback) self.x.get_list.assert_called_with( 'ReceiveMessage', {'MaxNumberOfMessages': 4}, - [('Message', queue.message_class)], - queue.id, callback=self.callback, + [('Message', AsyncMessage)], + 'http://aws.com', callback=self.callback, + parent=queue, ) def test_receive_message__with_visibility_timeout(self): @@ -105,8 +102,9 @@ class test_AsyncSQSConnection(AWSCase): 'MaxNumberOfMessages': 4, 'VisibilityTimeout': 3666, }, - [('Message', queue.message_class)], - queue.id, callback=self.callback, + [('Message', AsyncMessage)], + 'http://aws.com', callback=self.callback, + parent=queue, ) def test_receive_message__with_wait_time_seconds(self): @@ -119,8 +117,9 @@ class test_AsyncSQSConnection(AWSCase): 'MaxNumberOfMessages': 4, 'WaitTimeSeconds': 303, }, - [('Message', queue.message_class)], - queue.id, callback=self.callback, + [('Message', AsyncMessage)], + 'http://aws.com', callback=self.callback, + parent=queue, ) def test_receive_message__with_attributes(self): @@ -134,8 +133,9 @@ class test_AsyncSQSConnection(AWSCase): 'AttributeName.2': 'bar', 'MaxNumberOfMessages': 4, }, - [('Message', queue.message_class)], - queue.id, callback=self.callback, + [('Message', AsyncMessage)], + 'http://aws.com', callback=self.callback, + parent=queue, ) def MockMessage(self, id=None, receipt_handle=None, body=None): @@ -157,10 +157,11 @@ class test_AsyncSQSConnection(AWSCase): def test_delete_message(self): queue = Mock(name='queue') message = self.MockMessage() - self.x.delete_message(queue, message, callback=self.callback) + self.x.delete_message(queue, message.receipt_handle, + callback=self.callback) self.x.get_status.assert_called_with( 'DeleteMessage', {'ReceiptHandle': message.receipt_handle}, - queue.id, callback=self.callback, + queue, callback=self.callback, ) def test_delete_message_batch(self): @@ -175,7 +176,7 @@ class test_AsyncSQSConnection(AWSCase): 'DeleteMessageBatchRequestEntry.2.Id': '2', 'DeleteMessageBatchRequestEntry.2.ReceiptHandle': 'r2', }, - BatchResults, queue.id, verb='POST', callback=self.callback, + queue.id, verb='POST', callback=self.callback, ) def test_send_message(self): @@ -183,7 +184,7 @@ class test_AsyncSQSConnection(AWSCase): self.x.send_message(queue, 'hello', callback=self.callback) self.x.get_object.assert_called_with( 'SendMessage', {'MessageBody': 'hello'}, - AsyncMessage, queue.id, verb='POST', callback=self.callback, + queue.id, verb='POST', callback=self.callback, ) def test_send_message__with_delay_seconds(self): @@ -193,7 +194,7 @@ class test_AsyncSQSConnection(AWSCase): ) self.x.get_object.assert_called_with( 'SendMessage', {'MessageBody': 'hello', 'DelaySeconds': 303}, - AsyncMessage, queue.id, verb='POST', callback=self.callback, + queue.id, verb='POST', callback=self.callback, ) def test_send_message_batch(self): @@ -213,7 +214,7 @@ class test_AsyncSQSConnection(AWSCase): 'SendMessageBatchRequestEntry.2.MessageBody': 'B', 'SendMessageBatchRequestEntry.2.DelaySeconds': 303, }, - BatchResults, queue.id, verb='POST', callback=self.callback, + queue.id, verb='POST', callback=self.callback, ) def test_change_message_visibility(self): @@ -251,7 +252,7 @@ class test_AsyncSQSConnection(AWSCase): preamble('2.ReceiptHandle'): 'r2', preamble('2.VisibilityTimeout'): 909, }, - BatchResults, queue.id, verb='POST', callback=self.callback, + queue.id, verb='POST', callback=self.callback, ) def test_get_all_queues(self): diff --git a/t/unit/async/aws/sqs/test_message.py b/t/unit/async/aws/sqs/test_message.py deleted file mode 100644 index 44a0ac32..00000000 --- a/t/unit/async/aws/sqs/test_message.py +++ /dev/null @@ -1,37 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import absolute_import, unicode_literals - -from case import Mock - -from kombu.async.aws.sqs.message import AsyncMessage -from kombu.utils.uuid import uuid - -from t.mocks import PromiseMock - -from ..case import AWSCase - - -class test_AsyncMessage(AWSCase): - - def setup(self): - self.queue = Mock(name='queue') - self.callback = PromiseMock(name='callback') - self.x = AsyncMessage(self.queue, 'body') - self.x.receipt_handle = uuid() - - def test_delete(self): - assert self.x.delete(callback=self.callback) - self.x.queue.delete_message.assert_called_with( - self.x, self.callback, - ) - - self.x.queue = None - assert self.x.delete(callback=self.callback) is None - - def test_change_visibility(self): - assert self.x.change_visibility(303, callback=self.callback) - self.x.queue.connection.change_message_visibility.assert_called_with( - self.x.queue, self.x.receipt_handle, 303, self.callback, - ) - self.x.queue = None - assert self.x.change_visibility(303, callback=self.callback) is None diff --git a/t/unit/async/aws/sqs/test_sqs.py b/t/unit/async/aws/sqs/test_sqs.py deleted file mode 100644 index ea58596d..00000000 --- a/t/unit/async/aws/sqs/test_sqs.py +++ /dev/null @@ -1,34 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import absolute_import, unicode_literals - -import pytest - -from case import Mock, patch - -from kombu.async.aws.sqs import regions, connect_to_region -from kombu.async.aws.sqs.connection import AsyncSQSConnection - -from ..case import AWSCase - - -class test_connect_to_region(AWSCase): - - def test_when_no_boto_installed(self, patching): - patching('kombu.async.aws.sqs.boto', None) - with pytest.raises(ImportError): - regions() - - def test_using_async_connection(self): - for region in regions(): - assert region.connection_cls is AsyncSQSConnection - - def test_connect_to_region(self): - with patch('kombu.async.aws.sqs.regions') as regions: - region = Mock(name='region') - region.name = 'us-west-1' - regions.return_value = [region] - conn = connect_to_region('us-west-1', kw=3.33) - assert conn is region.connect.return_value - region.connect.assert_called_with(kw=3.33) - - assert connect_to_region('foo') is None diff --git a/t/unit/async/aws/test_aws.py b/t/unit/async/aws/test_aws.py index f5ed9aef..29c535ab 100644 --- a/t/unit/async/aws/test_aws.py +++ b/t/unit/async/aws/test_aws.py @@ -13,4 +13,4 @@ class test_connect_sqs(AWSCase): def test_connection(self): x = connect_sqs('AAKI', 'ASAK', http_client=Mock()) assert x - assert x.connection + assert x.sqs_connection diff --git a/t/unit/async/aws/test_connection.py b/t/unit/async/aws/test_connection.py index 5de76aa5..cba953b3 100644 --- a/t/unit/async/aws/test_connection.py +++ b/t/unit/async/aws/test_connection.py @@ -1,11 +1,9 @@ # -*- coding: utf-8 -*- from __future__ import absolute_import, unicode_literals -import pytest - from contextlib import contextmanager -from case import Mock, patch +from case import Mock from vine.abstract import Thenable from kombu.exceptions import HttpError @@ -13,18 +11,22 @@ from kombu.five import WhateverIO from kombu.async import http from kombu.async.aws.connection import ( - AsyncHTTPConnection, AsyncHTTPSConnection, AsyncHTTPResponse, AsyncConnection, - AsyncAWSAuthConnection, AsyncAWSQueryConnection, ) +from kombu.async.aws.ext import boto3 from .case import AWSCase from t.mocks import PromiseMock +try: + from urllib.parse import urlparse, parse_qs +except ImportError: + from urlparse import urlparse, parse_qs # noqa + # Not currently working VALIDATES_CERT = False @@ -38,37 +40,30 @@ def passthrough(*args, **kwargs): return m -class test_AsyncHTTPConnection(AWSCase): - - def test_AsyncHTTPSConnection(self): - x = AsyncHTTPSConnection('aws.vandelay.com') - assert x.scheme == 'https' +class test_AsyncHTTPSConnection(AWSCase): def test_http_client(self): - x = AsyncHTTPConnection('aws.vandelay.com') + x = AsyncHTTPSConnection() assert x.http_client is http.get_client() client = Mock(name='http_client') - y = AsyncHTTPConnection('aws.vandelay.com', http_client=client) + y = AsyncHTTPSConnection(http_client=client) assert y.http_client is client def test_args(self): - x = AsyncHTTPConnection( - 'aws.vandelay.com', 8083, strict=True, timeout=33.3, + x = AsyncHTTPSConnection( + strict=True, timeout=33.3, ) - assert x.host == 'aws.vandelay.com' - assert x.port == 8083 assert x.strict assert x.timeout == 33.3 - assert x.scheme == 'http' def test_request(self): - x = AsyncHTTPConnection('aws.vandelay.com') + x = AsyncHTTPSConnection('aws.vandelay.com') x.request('PUT', '/importer-exporter') assert x.path == '/importer-exporter' assert x.method == 'PUT' def test_request_with_body_buffer(self): - x = AsyncHTTPConnection('aws.vandelay.com') + x = AsyncHTTPSConnection('aws.vandelay.com') body = Mock(name='body') body.read.return_value = 'Vandelay Industries' x.request('PUT', '/importer-exporter', body) @@ -78,14 +73,14 @@ class test_AsyncHTTPConnection(AWSCase): body.read.assert_called_with() def test_request_with_body_text(self): - x = AsyncHTTPConnection('aws.vandelay.com') + x = AsyncHTTPSConnection('aws.vandelay.com') x.request('PUT', '/importer-exporter', 'Vandelay Industries') assert x.method == 'PUT' assert x.path == '/importer-exporter' assert x.body == 'Vandelay Industries' def test_request_with_headers(self): - x = AsyncHTTPConnection('aws.vandelay.com') + x = AsyncHTTPSConnection() headers = {'Proxy': 'proxy.vandelay.com'} x.request('PUT', '/importer-exporter', None, headers) assert 'Proxy' in dict(x.headers) @@ -99,27 +94,10 @@ class test_AsyncHTTPConnection(AWSCase): validate_cert=VALIDATES_CERT, ) - def test_getrequest_AsyncHTTPSConnection(self): - x = AsyncHTTPSConnection('aws.vandelay.com') - x.Request = Mock(name='Request') - x.getrequest() - self.assert_request_created_with('https://aws.vandelay.com/', x) - - def test_getrequest_nondefault_port(self): - x = AsyncHTTPConnection('aws.vandelay.com', port=8080) - x.Request = Mock(name='Request') - x.getrequest() - self.assert_request_created_with('http://aws.vandelay.com:8080/', x) - - y = AsyncHTTPSConnection('aws.vandelay.com', port=8443) - y.Request = Mock(name='Request') - y.getrequest() - self.assert_request_created_with('https://aws.vandelay.com:8443/', y) - def test_getresponse(self): client = Mock(name='client') client.add_request = passthrough(name='client.add_request') - x = AsyncHTTPConnection('aws.vandelay.com', http_client=client) + x = AsyncHTTPSConnection(http_client=client) x.Response = Mock(name='x.Response') request = x.getresponse() x.http_client.add_request.assert_called_with(request) @@ -134,7 +112,7 @@ class test_AsyncHTTPConnection(AWSCase): client = Mock(name='client') client.add_request = passthrough(name='client.add_request') callback = PromiseMock(name='callback') - x = AsyncHTTPConnection('aws.vandelay.com', http_client=client) + x = AsyncHTTPSConnection(http_client=client) request = x.getresponse(callback) x.http_client.add_request.assert_called_with(request) @@ -151,22 +129,22 @@ class test_AsyncHTTPConnection(AWSCase): assert wresponse.read() == 'The quick brown fox jumps' assert wresponse.status == 200 assert wresponse.getheader('X-Foo') == 'Hello' - assert dict(wresponse.getheaders()) == headers - assert wresponse.msg + headers_dict = wresponse.getheaders() + assert dict(headers_dict) == headers assert wresponse.msg assert repr(wresponse) def test_repr(self): - assert repr(AsyncHTTPConnection('aws.vandelay.com')) + assert repr(AsyncHTTPSConnection()) def test_putrequest(self): - x = AsyncHTTPConnection('aws.vandelay.com') + x = AsyncHTTPSConnection() x.putrequest('UPLOAD', '/new') assert x.method == 'UPLOAD' assert x.path == '/new' def test_putheader(self): - x = AsyncHTTPConnection('aws.vandelay.com') + x = AsyncHTTPSConnection() x.putheader('X-Foo', 'bar') assert x.headers == [('X-Foo', 'bar')] x.putheader('X-Bar', 'baz') @@ -176,14 +154,14 @@ class test_AsyncHTTPConnection(AWSCase): ] def test_send(self): - x = AsyncHTTPConnection('aws.vandelay.com') + x = AsyncHTTPSConnection() x.send('foo') assert x.body == 'foo' x.send('bar') assert x.body == 'foobar' def test_interface(self): - x = AsyncHTTPConnection('aws.vandelay.com') + x = AsyncHTTPSConnection() assert x.set_debuglevel(3) is None assert x.connect() is None assert x.close() is None @@ -204,102 +182,49 @@ class test_AsyncHTTPResponse(AWSCase): class test_AsyncConnection(AWSCase): - def test_when_boto_missing(self, patching): - patching('kombu.async.aws.connection.boto', None) - with pytest.raises(ImportError): - AsyncConnection(Mock(name='client')) - def test_client(self): - x = AsyncConnection() + sqs = Mock(name='sqs') + x = AsyncConnection(sqs) assert x._httpclient is http.get_client() client = Mock(name='client') - y = AsyncConnection(http_client=client) + y = AsyncConnection(sqs, http_client=client) assert y._httpclient is client def test_get_http_connection(self): - x = AsyncConnection(client=Mock(name='client')) - assert isinstance( - x.get_http_connection('aws.vandelay.com', 80, False), - AsyncHTTPConnection, - ) + sqs = Mock(name='sqs') + x = AsyncConnection(sqs) assert isinstance( - x.get_http_connection('aws.vandelay.com', 443, True), + x.get_http_connection(), AsyncHTTPSConnection, ) - - conn = x.get_http_connection('aws.vandelay.com', 80, False) + conn = x.get_http_connection() assert conn.http_client is x._httpclient - assert conn.host == 'aws.vandelay.com' - assert conn.port == 80 - - -class test_AsyncAWSAuthConnection(AWSCase): - - @patch('boto.log', create=True) - def test_make_request(self, _): - x = AsyncAWSAuthConnection('aws.vandelay.com', - http_client=Mock(name='client')) - Conn = x.get_http_connection = Mock(name='get_http_connection') - callback = PromiseMock(name='callback') - ret = x.make_request('GET', '/foo', callback=callback) - assert ret is callback - Conn.return_value.request.assert_called() - Conn.return_value.getresponse.assert_called_with( - callback=callback, - ) - - @patch('boto.log', create=True) - def test_mexe(self, _): - x = AsyncAWSAuthConnection('aws.vandelay.com', - http_client=Mock(name='client')) - Conn = x.get_http_connection = Mock(name='get_http_connection') - request = x.build_base_http_request('GET', 'foo', '/auth') - callback = PromiseMock(name='callback') - x._mexe(request, callback=callback) - Conn.return_value.request.assert_called_with( - request.method, request.path, request.body, request.headers, - ) - Conn.return_value.getresponse.assert_called_with( - callback=callback, - ) - - no_callback_ret = x._mexe(request) - # _mexe always returns promise - assert isinstance(no_callback_ret, Thenable) - - @patch('boto.log', create=True) - def test_mexe__with_sender(self, _): - x = AsyncAWSAuthConnection('aws.vandelay.com', - http_client=Mock(name='client')) - Conn = x.get_http_connection = Mock(name='get_http_connection') - request = x.build_base_http_request('GET', 'foo', '/auth') - sender = Mock(name='sender') - callback = PromiseMock(name='callback') - x._mexe(request, sender=sender, callback=callback) - sender.assert_called_with( - Conn.return_value, request.method, request.path, - request.body, request.headers, callback, - ) class test_AsyncAWSQueryConnection(AWSCase): def setup(self): - self.x = AsyncAWSQueryConnection('aws.vandelay.com', + session = boto3.session.Session( + aws_access_key_id='AAA', + aws_secret_access_key='AAAA', + region_name='us-west-2', + ) + sqs_client = session.client('sqs') + self.x = AsyncAWSQueryConnection(sqs_client, http_client=Mock(name='client')) - @patch('boto.log', create=True) - def test_make_request(self, _): + def test_make_request(self): _mexe, self.x._mexe = self.x._mexe, Mock(name='_mexe') Conn = self.x.get_http_connection = Mock(name='get_http_connection') callback = PromiseMock(name='callback') self.x.make_request( - 'action', {'foo': 1}, '/', 'GET', callback=callback, + 'action', {'foo': 1}, 'https://foo.com/', 'GET', callback=callback, ) self.x._mexe.assert_called() request = self.x._mexe.call_args[0][0] - assert request.params['Action'] == 'action' - assert request.params['Version'] == self.x.APIVersion + parsed = urlparse(request.url) + params = parse_qs(parsed.query) + assert params['Action'][0] == 'action' ret = _mexe(request, callback=callback) assert ret is callback @@ -308,29 +233,18 @@ class test_AsyncAWSQueryConnection(AWSCase): callback=callback, ) - @patch('boto.log', create=True) - def test_make_request__no_action(self, _): + def test_make_request__no_action(self): self.x._mexe = Mock(name='_mexe') self.x.get_http_connection = Mock(name='get_http_connection') callback = PromiseMock(name='callback') self.x.make_request( - None, {'foo': 1}, '/', 'GET', callback=callback, + None, {'foo': 1}, 'http://foo.com/', 'GET', callback=callback, ) self.x._mexe.assert_called() request = self.x._mexe.call_args[0][0] - assert 'Action' not in request.params - assert request.params['Version'] == self.x.APIVersion - - @contextmanager - def mock_sax_parse(self, parser): - with patch('kombu.async.aws.connection.sax_parse') as sax_parse: - with patch('kombu.async.aws.connection.XmlHandler') as xh: - - def effect(body, h): - return parser(xh.call_args[0][0], body, h) - sax_parse.side_effect = effect - yield (sax_parse, xh) - sax_parse.assert_called() + parsed = urlparse(request.url) + params = parse_qs(parsed.query) + assert 'Action' not in params def Response(self, status, body): r = Mock(name='response') @@ -347,90 +261,3 @@ class test_AsyncAWSQueryConnection(AWSCase): def assert_make_request_called(self): self.x.make_request.assert_called() return self.x.make_request.call_args[1]['callback'] - - def test_get_list(self): - with self.mock_make_request() as callback: - self.x.get_list('action', {'p': 3.3}, ['m'], callback=callback) - on_ready = self.assert_make_request_called() - - def parser(dest, body, h): - dest.append('hi') - dest.append('there') - - with self.mock_sax_parse(parser): - on_ready(self.Response(200, 'hello')) - callback.assert_called_with(['hi', 'there']) - - def test_get_list_error(self): - with self.mock_make_request() as callback: - self.x.get_list('action', {'p': 3.3}, ['m'], callback=callback) - on_ready = self.assert_make_request_called() - - with pytest.raises(self.x.ResponseError): - on_ready(self.Response(404, 'Not found')) - - def test_get_object(self): - with self.mock_make_request() as callback: - - class Result(object): - parent = None - value = None - - def __init__(self, parent): - self.parent = parent - - self.x.get_object('action', {'p': 3.3}, Result, callback=callback) - on_ready = self.assert_make_request_called() - - def parser(dest, body, h): - dest.value = 42 - - with self.mock_sax_parse(parser): - on_ready(self.Response(200, 'hello')) - - callback.assert_called() - result = callback.call_args[0][0] - assert result.value == 42 - assert result.parent - - def test_get_object_error(self): - with self.mock_make_request() as callback: - self.x.get_object('action', {'p': 3.3}, object, callback=callback) - on_ready = self.assert_make_request_called() - - with pytest.raises(self.x.ResponseError): - on_ready(self.Response(404, 'Not found')) - - def test_get_status(self): - with self.mock_make_request() as callback: - self.x.get_status('action', {'p': 3.3}, callback=callback) - on_ready = self.assert_make_request_called() - set_status_to = [True] - - def parser(dest, body, b): - dest.status = set_status_to[0] - - with self.mock_sax_parse(parser): - on_ready(self.Response(200, 'hello')) - callback.assert_called_with(True) - - set_status_to[0] = False - with self.mock_sax_parse(parser): - on_ready(self.Response(200, 'hello')) - callback.assert_called_with(False) - - def test_get_status_error(self): - with self.mock_make_request() as callback: - self.x.get_status('action', {'p': 3.3}, callback=callback) - on_ready = self.assert_make_request_called() - - with pytest.raises(self.x.ResponseError): - on_ready(self.Response(404, 'Not found')) - - def test_get_status_error_empty_body(self): - with self.mock_make_request() as callback: - self.x.get_status('action', {'p': 3.3}, callback=callback) - on_ready = self.assert_make_request_called() - - with pytest.raises(self.x.ResponseError): - on_ready(self.Response(200, '')) diff --git a/t/unit/transport/test_SQS.py b/t/unit/transport/test_SQS.py index f245fdab..180b33bf 100644 --- a/t/unit/transport/test_SQS.py +++ b/t/unit/transport/test_SQS.py @@ -8,97 +8,107 @@ slightly. from __future__ import absolute_import, unicode_literals import pytest +import random +import string from case import Mock, skip from kombu import messaging from kombu import Connection, Exchange, Queue -from kombu.async.aws.ext import exception from kombu.five import Empty from kombu.transport import SQS -class SQSQueueMock(object): - - def __init__(self, name): - self.name = name - self.messages = [] - self._get_message_calls = 0 - - def clear(self, page_size=10, vtimeout=10): - empty, self.messages[:] = not self.messages, [] - return not empty - - def count(self, page_size=10, vtimeout=10): - return len(self.messages) - count_slow = count +class SQSMessageMock(object): + def __init__(self): + """ + Imitate the SQS Message from boto3. + """ + self.body = "" + self.receipt_handle = "receipt_handle_xyz" - def delete(self): - self.messages[:] = [] - return True - def delete_message(self, message): - try: - self.messages.remove(message) - except ValueError: - return False - return True +class QueueMock(object): + """ Hold information about a queue. """ - def get_messages(self, num_messages=1, visibility_timeout=None, - attributes=None, *args, **kwargs): - self._get_message_calls += 1 - messages, self.messages[:num_messages] = ( - self.messages[:num_messages], []) - return messages + def __init__(self, url): + self.url = url + self.attributes = {'ApproximateNumberOfMessages': '0'} - def read(self, visibility_timeout=None): - return self.messages.pop(0) + self.messages = [] - def write(self, message): - self.messages.append(message) - return True + def __repr__(self): + return 'QueueMock: {} {} messages'.format(self.url, len(self.messages)) -class SQSConnectionMock(object): +class SQSClientMock(object): def __init__(self): - self.queues = { - 'q_%s' % n: SQSQueueMock('q_%s' % n) for n in range(1500) - } - q = SQSQueueMock('unittest_queue') - q.write('hello') - self.queues['unittest_queue'] = q - - def get_queue(self, queue): - return self.queues.get(queue) - - def get_all_queues(self, prefix=""): - if not prefix: - keys = sorted(self.queues.keys())[:1000] - else: - keys = list(filter( - lambda k: k.startswith(prefix), sorted(self.queues.keys()) - ))[:1000] - return [self.queues[key] for key in keys] + """ + Imitate the SQS Client from boto3. + """ + self._receive_messages_calls = 0 + # _queues doesn't exist on the real client, here for testing. + self._queues = {} + for n in range(1): + name = 'q_{}'.format(n) + url = 'sqs://q_{}'.format(n) + self.create_queue(QueueName=name) + + url = self.create_queue(QueueName='unittest_queue')['QueueUrl'] + self.send_message(QueueUrl=url, MessageBody='hello') + + def _get_q(self, url): + """ Helper method to quickly get a queue. """ + for q in self._queues.values(): + if q.url == url: + return q + raise Exception("Queue url {} not found".format(url)) + + def create_queue(self, QueueName=None, Attributes=None): + q = self._queues[QueueName] = QueueMock('sqs://' + QueueName) + return {'QueueUrl': q.url} + + def list_queues(self, QueueNamePrefix=None): + """ Return a list of queue urls """ + urls = (val.url for key, val in self._queues.items() + if key.startswith(QueueNamePrefix)) + return {'QueueUrls': urls} + + def get_queue_url(self, QueueName=None): + return self._queues[QueueName] + + def send_message(self, QueueUrl=None, MessageBody=None): + for q in self._queues.values(): + if q.url == QueueUrl: + handle = ''.join(random.choice(string.ascii_lowercase) for + x in range(10)) + q.messages.append({'Body': MessageBody, + 'ReceiptHandle': handle}) + break - def delete_queue(self, queue, force_deletion=False): - q = self.get_queue(queue) - if q: - if q.count(): - return False - q.clear() - self.queues.pop(queue, None) + def receive_message(self, QueueUrl=None, MaxNumberOfMessages=1): + self._receive_messages_calls += 1 + for q in self._queues.values(): + if q.url == QueueUrl: + msgs = q.messages[:MaxNumberOfMessages] + q.messages = q.messages[MaxNumberOfMessages:] + return {'Messages': msgs} - def delete_message(self, queue, message): - return queue.delete_message(message) + def get_queue_attributes(self, QueueUrl=None, AttributeNames=None): + if 'ApproximateNumberOfMessages' in AttributeNames: + count = len(self._get_q(QueueUrl).messages) + return {'Attributes': {'ApproximateNumberOfMessages': count}} - def create_queue(self, name, *args, **kwargs): - q = self.queues[name] = SQSQueueMock(name) - return q + def purge_queue(self, QueueUrl=None): + for q in self._queues.values(): + if q.url == QueueUrl: + q.messages = [] @skip.unless_module('boto') +@skip.unless_module('boto3') class test_Channel: def handleMessageCallback(self, message): @@ -115,7 +125,7 @@ class test_Channel: # Mock the sqs() method that returns an SQSConnection object and # instead return an SQSConnectionMock() object. - self.sqs_conn_mock = SQSConnectionMock() + self.sqs_conn_mock = SQSClientMock() def mock_sqs(): return self.sqs_conn_mock @@ -125,9 +135,9 @@ class test_Channel: self.exchange = Exchange('test_SQS', type='direct') self.queue = Queue(self.queue_name, self.exchange, self.queue_name) - # Mock up a test SQS Queue with the SQSQueueMock class (and always + # Mock up a test SQS Queue with the QueueMock class (and always # make sure its a clean empty queue) - self.sqs_queue_mock = SQSQueueMock(self.queue_name) + self.sqs_queue_mock = QueueMock('sqs://' + self.queue_name) # Now, create our Connection object with the SQS Transport and store # the connection/channel objects as references for use in these tests. @@ -160,33 +170,10 @@ class test_Channel: """kombu.SQS.Channel instantiates correctly with mocked queues""" assert self.queue_name in self.channel._queue_cache - def test_auth_fail(self): - normal_func = SQS.Channel.sqs.get_all_queues - - def get_all_queues_fail_403(prefix=''): - # mock auth error - raise exception.SQSError(403, None, None) - - def get_all_queues_fail_not_403(prefix=''): - # mock non-auth error - raise exception.SQSError(500, None, None) - - try: - SQS.Channel.sqs.access_key = '1234' - SQS.Channel.sqs.get_all_queues = get_all_queues_fail_403 - with pytest.raises(RuntimeError) as excinfo: - self.channel = self.connection.channel() - assert 'access_key=1234' in str(excinfo.value) - SQS.Channel.sqs.get_all_queues = get_all_queues_fail_not_403 - with pytest.raises(exception.SQSError): - self.channel = self.connection.channel() - finally: - SQS.Channel.sqs.get_all_queues = normal_func - def test_new_queue(self): queue_name = 'new_unittest_queue' self.channel._new_queue(queue_name) - assert queue_name in self.sqs_conn_mock.queues + assert queue_name in self.sqs_conn_mock._queues.keys() # For cleanup purposes, delete the queue and the queue file self.channel._delete(queue_name) @@ -195,11 +182,13 @@ class test_Channel: # which is definitely out of cache when get_all_queues returns the # first 1000 queues sorted by name. queue_name = 'unittest_queue' + # This should not create a new queue. self.channel._new_queue(queue_name) - assert queue_name in self.sqs_conn_mock.queues - q = self.sqs_conn_mock.get_queue(queue_name) - assert 1 == q.count() - assert 'hello' == q.read() + assert queue_name in self.sqs_conn_mock._queues.keys() + queue = self.sqs_conn_mock._queues[queue_name] + # The queue originally had 1 message in it. + assert 1 == len(queue.messages) + assert 'hello' == queue.messages[0]['Body'] def test_delete(self): queue_name = 'new_unittest_queue' @@ -211,17 +200,16 @@ class test_Channel: # Test getting a single message message = 'my test message' self.producer.publish(message) - q = self.channel._new_queue(self.queue_name) - results = q.get_messages() - assert len(results) == 1 + result = self.channel._get(self.queue_name) + assert 'body' in result.keys() # Now test getting many messages for i in range(3): message = 'message: {0}'.format(i) self.producer.publish(message) - results = q.get_messages(num_messages=3) - assert len(results) == 3 + self.channel._get_bulk(self.queue_name, max_if_unlimited=3) + assert len(self.sqs_conn_mock._queues[self.queue_name].messages) == 0 def test_get_with_empty_list(self): with pytest.raises(Empty): @@ -244,10 +232,21 @@ class test_Channel: message = {'foo': 'bar'} self.channel._put(self.producer.routing_key, message) - q = self.channel._new_queue(self.queue_name) + q_url = self.channel._new_queue(self.queue_name) # Get the messages now - kombu_messages = q.get_messages(num_messages=kombu_message_count) - json_messages = q.get_messages(num_messages=json_message_count) + kombu_messages = [] + from kombu.async.aws.sqs.ext import Message + for m in self.sqs_conn_mock.receive_message( + QueueUrl=q_url, + MaxNumberOfMessages=kombu_message_count)['Messages']: + m['Body'] = Message().decode(m['Body']) + kombu_messages.append(m) + json_messages = [] + for m in self.sqs_conn_mock.receive_message( + QueueUrl=q_url, + MaxNumberOfMessages=json_message_count)['Messages']: + m['Body'] = Message().decode(m['Body']) + json_messages.append(m) # Now convert them to payloads kombu_payloads = self.channel._messages_to_python( @@ -350,7 +349,7 @@ class test_Channel: def test_drain_events_with_prefetch_none(self): # Generate 20 messages message_count = 20 - expected_get_message_count = 3 + expected_receive_messages_count = 3 current_delivery_tag = [1] @@ -378,6 +377,7 @@ class test_Channel: assert self.channel.connection._deliver.call_count == message_count - # How many times was the SQSConnectionMock get_message method called? - assert (expected_get_message_count == - self.channel._queue_cache[self.queue_name]._get_message_calls) + # How many times was the SQSConnectionMock receive_message method + # called? + assert (expected_receive_messages_count == + self.sqs_conn_mock._receive_messages_calls) |