summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMischa Spiegelmock <revmischa@cpan.org>2017-04-13 22:22:18 -0700
committerAsif Saifuddin Auvi <auvipy@users.noreply.github.com>2017-04-14 11:22:18 +0600
commit129a9e4ed05bf9a99d12fff9e17c9ffb37b14c4d (patch)
treec4eea532ea4ca84dee0049ab4304d86f0cc73ab7
parentbf820b20b022556c72402565f0ae50124017d6fe (diff)
downloadkombu-129a9e4ed05bf9a99d12fff9e17c9ffb37b14c4d.tar.gz
Switching to boto3 only (#693)
* Switch Boto2 to Boto3 for SQS messaging * Fixed region support * Add SQS FIFO queue support * Add sensible defaults for message attributes * Asynchronous support, plus boto3 for region endpoint lookups * Clean up imports * Fix Python 2 support * Fix receive_message tests * Reformat docstring * boto3 import changes for CI * skip tests if boto3 not installed * skip tests if boto3 not installed * flake8 * noboto * ditching boto2. got queue URL fetching, async HTTP request generation and signing working. * request signing working kinda * async parsing of SQS message response more or less working * botocore sqs dep * ripping out more old boto2 stuff * removing tests that are no longer valid with boto3/SQS * fix boto3 dep, min version and no botocore * no boto2 for test * cleaning up some SQS tests. fixing header parsing of response to msg * fixing some sqs tests * removing response-parsing tests that are no longer necessary as we're using the botocore response parsing machinery instead of implementing SAX parsing in kombu. * fixing more SQS tests * wants a region * trying to fix py2 parsing of sqs message * lint * py2/py2 message header parsing stupidness * forgot * python 2 sux * flake8 * Import boto3 from the right place * Changes * Update encode fuction * Fix lint * remove some unused things * removing unused stuff * ugh * ugh * ugh * landscape ignoring * shut up, landscape
-rw-r--r--examples/hello_consumer.py2
-rw-r--r--kombu/async/aws/connection.py229
-rw-r--r--kombu/async/aws/ext.py31
-rw-r--r--kombu/async/aws/sqs/__init__.py22
-rw-r--r--kombu/async/aws/sqs/connection.py60
-rw-r--r--kombu/async/aws/sqs/ext.py31
-rw-r--r--kombu/async/aws/sqs/message.py51
-rw-r--r--kombu/async/aws/sqs/queue.py3
-rw-r--r--kombu/async/http/base.py9
-rw-r--r--kombu/async/http/curl.py17
-rw-r--r--kombu/transport/SQS.py197
-rw-r--r--requirements/extras/sqs.txt2
-rw-r--r--requirements/funtest.txt1
-rw-r--r--requirements/test-ci.txt1
-rw-r--r--t/integration/tests/test_SQS.py1
-rw-r--r--t/unit/async/aws/case.py2
-rw-r--r--t/unit/async/aws/sqs/test_connection.py73
-rw-r--r--t/unit/async/aws/sqs/test_message.py37
-rw-r--r--t/unit/async/aws/sqs/test_sqs.py34
-rw-r--r--t/unit/async/aws/test_aws.py2
-rw-r--r--t/unit/async/aws/test_connection.py271
-rw-r--r--t/unit/transport/test_SQS.py218
22 files changed, 491 insertions, 803 deletions
diff --git a/examples/hello_consumer.py b/examples/hello_consumer.py
index 5722450e..71f40f13 100644
--- a/examples/hello_consumer.py
+++ b/examples/hello_consumer.py
@@ -1,6 +1,6 @@
from __future__ import absolute_import, unicode_literals, print_function
-from kombu import Connection
+from kombu import Connection # noqa
with Connection('amqp://guest:guest@localhost:5672//') as conn:
diff --git a/kombu/async/aws/connection.py b/kombu/async/aws/connection.py
index cc907d7b..303f3577 100644
--- a/kombu/async/aws/connection.py
+++ b/kombu/async/aws/connection.py
@@ -1,37 +1,36 @@
-# -*- coding: utf-8 -*-
+# * coding: utf8 *
"""Amazon AWS Connection."""
from __future__ import absolute_import, unicode_literals
-from io import BytesIO
-
from vine import promise, transform
+from kombu.async.aws.ext import AWSRequest, get_response
+
from kombu.async.http import Headers, Request, get_client
from kombu.five import items, python_2_unicode_compatible
-from .ext import (
- boto, AWSAuthConnection, AWSQueryConnection, XmlHandler, ResultSet,
-)
-
-try:
- from urllib.parse import urlunsplit
-except ImportError:
- from urlparse import urlunsplit # noqa
-from xml.sax import parseString as sax_parse # noqa
+import io
try: # pragma: no cover
- from email import message_from_file
+ from email import message_from_bytes
from email.mime.message import MIMEMessage
+
+ # py3
+ def message_from_headers(hdr): # noqa
+ bs = "\r\n".join("{}: {}".format(*h) for h in hdr)
+ return message_from_bytes(bs.encode())
+
except ImportError: # pragma: no cover
from mimetools import Message as MIMEMessage # noqa
- def message_from_file(m): # noqa
- return m
+ # py2
+ def message_from_headers(hdr): # noqa
+ return io.BytesIO(b'\r\n'.join(
+ b'{0}: {1}'.format(*h) for h in hdr
+ ))
__all__ = [
- 'AsyncHTTPConnection', 'AsyncHTTPSConnection',
- 'AsyncHTTPResponse', 'AsyncConnection',
- 'AsyncAWSAuthConnection', 'AsyncAWSQueryConnection',
+ 'AsyncHTTPSConnection', 'AsyncConnection',
]
@@ -56,11 +55,7 @@ class AsyncHTTPResponse(object):
@property
def msg(self):
if self._msg is None:
- self._msg = MIMEMessage(message_from_file(
- BytesIO(b'\r\n'.join(
- b'{0}: {1}'.format(*h) for h in self.getheaders())
- )
- ))
+ self._msg = MIMEMessage(message_from_headers(self.getheaders()))
return self._msg
@property
@@ -78,7 +73,7 @@ class AsyncHTTPResponse(object):
@python_2_unicode_compatible
-class AsyncHTTPConnection(object):
+class AsyncHTTPSConnection(object):
"""Async HTTP Connection."""
Request = Request
@@ -87,13 +82,9 @@ class AsyncHTTPConnection(object):
method = 'GET'
path = '/'
body = None
- scheme = 'http'
default_ports = {'http': 80, 'https': 443}
- def __init__(self, host, port=None,
- strict=None, timeout=20.0, http_client=None, **kwargs):
- self.host = host
- self.port = port
+ def __init__(self, strict=None, timeout=20.0, http_client=None):
self.headers = []
self.timeout = timeout
self.strict = strict
@@ -112,14 +103,9 @@ class AsyncHTTPConnection(object):
if headers is not None:
self.headers.extend(list(items(headers)))
- def getrequest(self, scheme=None):
- scheme = scheme if scheme else self.scheme
- host = self.host
- if self.port and self.port != self.default_ports[scheme]:
- host = '{0}:{1}'.format(host, self.port)
- url = urlunsplit((scheme, host, self.path, '', ''))
+ def getrequest(self):
headers = Headers(self.headers)
- return self.Request(url, method=self.method, headers=headers,
+ return self.Request(self.path, method=self.method, headers=headers,
body=self.body, connect_timeout=self.timeout,
request_timeout=self.timeout, validate_cert=False)
@@ -137,7 +123,7 @@ class AsyncHTTPConnection(object):
def close(self):
pass
- def putrequest(self, method, path, **kwargs):
+ def putrequest(self, method, path):
self.method = method
self.path = path
@@ -157,139 +143,120 @@ class AsyncHTTPConnection(object):
return '<AsyncHTTPConnection: {0!r}>'.format(self.getrequest())
-class AsyncHTTPSConnection(AsyncHTTPConnection):
- """Async HTTPS Connection."""
-
- scheme = 'https'
-
-
class AsyncConnection(object):
"""Async AWS Connection."""
- def __init__(self, http_client=None, **kwargs):
- if boto is None:
- raise ImportError('boto is not installed')
+ def __init__(self, sqs_connection, http_client=None, **kwargs): # noqa
+ self.sqs_connection = sqs_connection
self._httpclient = http_client or get_client()
- def get_http_connection(self, host, port, is_secure):
- return (AsyncHTTPSConnection if is_secure else AsyncHTTPConnection)(
- host, port, http_client=self._httpclient,
- )
+ def get_http_connection(self):
+ return AsyncHTTPSConnection(http_client=self._httpclient)
def _mexe(self, request, sender=None, callback=None):
callback = callback or promise()
- boto.log.debug(
- 'HTTP %s/%s headers=%s body=%s',
- request.host, request.path,
- request.headers, request.body,
- )
-
- conn = self.get_http_connection(
- request.host, request.port, self.is_secure,
- )
- request.authorize(connection=self)
+ conn = self.get_http_connection()
if callable(sender):
sender(conn, request.method, request.path, request.body,
request.headers, callback)
else:
- conn.request(request.method, request.path,
+ conn.request(request.method, request.url,
request.body, request.headers)
conn.getresponse(callback=callback)
return callback
-class AsyncAWSAuthConnection(AsyncConnection, AWSAuthConnection):
- """Async AWS Authn Connection."""
-
- def __init__(self, host,
- http_client=None, http_client_params={}, **kwargs):
- AsyncConnection.__init__(self, http_client, **http_client_params)
- AWSAuthConnection.__init__(self, host, **kwargs)
-
- def make_request(self, method, path, headers=None, data='', host=None,
- auth_path=None, sender=None, callback=None, **kwargs):
- req = self.build_base_http_request(
- method, path, auth_path, {}, headers, data, host,
- )
- return self._mexe(req, sender=sender, callback=callback)
-
-
-class AsyncAWSQueryConnection(AsyncConnection, AWSQueryConnection):
+class AsyncAWSQueryConnection(AsyncConnection):
"""Async AWS Query Connection."""
- def __init__(self, host,
- http_client=None, http_client_params={}, **kwargs):
- AsyncConnection.__init__(self, http_client, **http_client_params)
- AWSAuthConnection.__init__(self, host, **kwargs)
-
- def make_request(self, action, params, path, verb, callback=None):
- request = self.build_base_http_request(
- verb, path, None, params, {}, '', self.server_name())
- if action:
- request.params['Action'] = action
- request.params['Version'] = self.APIVersion
- return self._mexe(request, callback=callback)
-
- def get_list(self, action, params, markers,
- path='/', parent=None, verb='GET', callback=None):
+ def __init__(self, sqs_connection, http_client=None,
+ http_client_params=None, **kwargs):
+ if not http_client_params:
+ http_client_params = {}
+ AsyncConnection.__init__(self, sqs_connection, http_client,
+ **http_client_params)
+
+ def make_request(self, operation, params_, path, verb, callback=None): # noqa
+ params = params_.copy()
+ if operation:
+ params['Action'] = operation
+ signer = self.sqs_connection._request_signer # noqa
+
+ # defaults for non-get
+ signing_type = 'standard'
+ param_payload = {'data': params}
+ if verb.lower() == 'get':
+ # query-based opts
+ signing_type = 'presignurl'
+ param_payload = {'params': params}
+
+ request = AWSRequest(method=verb, url=path, **param_payload)
+ signer.sign(operation, request, signing_type=signing_type)
+ prepared_request = request.prepare()
+
+ # print(prepared_request.url)
+ # print(prepared_request.headers)
+ # print(prepared_request.body)
+ return self._mexe(prepared_request, callback=callback)
+
+ def get_list(self, operation, params, markers, path='/', parent=None, verb='POST', callback=None): # noqa
return self.make_request(
- action, params, path, verb,
+ operation, params, path, verb,
callback=transform(
self._on_list_ready, callback, parent or self, markers,
+ operation
),
)
- def get_object(self, action, params, cls,
- path='/', parent=None, verb='GET', callback=None):
+ def get_object(self, operation, params, path='/', parent=None, verb='GET', callback=None): # noqa
return self.make_request(
- action, params, path, verb,
+ operation, params, path, verb,
callback=transform(
- self._on_obj_ready, callback, parent or self, cls,
+ self._on_obj_ready, callback, parent or self, operation
),
)
- def get_status(self, action, params,
- path='/', parent=None, verb='GET', callback=None):
+ def get_status(self, operation, params, path='/', parent=None, verb='GET', callback=None): # noqa
return self.make_request(
- action, params, path, verb,
+ operation, params, path, verb,
callback=transform(
- self._on_status_ready, callback, parent or self,
+ self._on_status_ready, callback, parent or self, operation
),
)
- def _on_list_ready(self, parent, markers, response):
- body = response.read()
- if response.status == 200 and body:
- rs = ResultSet(markers)
- h = XmlHandler(rs, parent)
- sax_parse(body, h)
- return rs
+ def _on_list_ready(self, parent, markers, operation, response): # noqa
+ service_model = self.sqs_connection.meta.service_model
+ if response.status == 200:
+ _, parsed = get_response(
+ service_model.operation_model(operation), response.response
+ )
+ return parsed
else:
- raise self._for_status(response, body)
-
- def _on_obj_ready(self, parent, cls, response):
- body = response.read()
- if response.status == 200 and body:
- obj = cls(parent)
- h = XmlHandler(obj, parent)
- sax_parse(body, h)
- return obj
+ raise self._for_status(response, response.read())
+
+ def _on_obj_ready(self, parent, operation, response): # noqa
+ service_model = self.sqs_connection.meta.service_model
+ if response.status == 200:
+ _, parsed = get_response(
+ service_model.operation_model(operation), response.response
+ )
+ return parsed
else:
- raise self._for_status(response, body)
-
- def _on_status_ready(self, parent, response):
- body = response.read()
- if response.status == 200 and body:
- rs = ResultSet()
- h = XmlHandler(rs, parent)
- sax_parse(body, h)
- return rs.status
+ raise self._for_status(response, response.read())
+
+ def _on_status_ready(self, parent, operation, response): # noqa
+ service_model = self.sqs_connection.meta.service_model
+ if response.status == 200:
+ httpres, _ = get_response(
+ service_model.operation_model(operation), response.response
+ )
+ return httpres.code
else:
- raise self._for_status(response, body)
+ raise self._for_status(response, response.read())
def _for_status(self, response, body):
context = 'Empty body' if not body else 'HTTP Error'
- exc = self.ResponseError(response.status, response.reason, body)
- boto.log.error('{0}: %r'.format(context), exc)
- return exc
+ return Exception("Request {} HTTP {} {} ({})".format(
+ context, response.status, response.reason, body
+ ))
diff --git a/kombu/async/aws/ext.py b/kombu/async/aws/ext.py
index b0e497cb..3519ab65 100644
--- a/kombu/async/aws/ext.py
+++ b/kombu/async/aws/ext.py
@@ -1,29 +1,26 @@
# -*- coding: utf-8 -*-
-"""Amazon boto interface."""
+"""Amazon boto3 interface."""
from __future__ import absolute_import, unicode_literals
try:
- import boto
-except ImportError: # pragma: no cover
- boto = get_regions = ResultSet = RegionInfo = XmlHandler = None
+ import boto3
+ from botocore import exceptions
+ from botocore.awsrequest import AWSRequest
+ from botocore.response import get_response
+except ImportError:
+ boto3 = None
class _void(object):
pass
- AWSAuthConnection = AWSQueryConnection = _void # noqa
- class BotoError(Exception):
+ class BotoCoreError(Exception):
pass
- exception = _void()
- exception.SQSError = BotoError
- exception.SQSDecodeError = BotoError
-else:
- from boto import exception
- from boto.connection import AWSAuthConnection, AWSQueryConnection
- from boto.handler import XmlHandler
- from boto.resultset import ResultSet
- from boto.regioninfo import RegionInfo, get_regions
+ exceptions = _void()
+ exceptions.BotoCoreError = BotoCoreError
+ AWSRequest = _void()
+ get_response = _void()
+
__all__ = [
- 'exception', 'AWSAuthConnection', 'AWSQueryConnection',
- 'XmlHandler', 'ResultSet', 'RegionInfo', 'get_regions',
+ 'exceptions', 'AWSRequest', 'get_response'
]
diff --git a/kombu/async/aws/sqs/__init__.py b/kombu/async/aws/sqs/__init__.py
index fe529584..e69de29b 100644
--- a/kombu/async/aws/sqs/__init__.py
+++ b/kombu/async/aws/sqs/__init__.py
@@ -1,22 +0,0 @@
-# -*- coding: utf-8 -*-
-from __future__ import absolute_import, unicode_literals
-
-from kombu.async.aws.ext import boto, get_regions
-
-from .connection import AsyncSQSConnection
-
-__all__ = ['regions', 'connect_to_region']
-
-
-def regions():
- """Return list of known AWS regions."""
- if boto is None:
- raise ImportError('boto is not installed')
- return get_regions('sqs', connection_cls=AsyncSQSConnection)
-
-
-def connect_to_region(region_name, **kwargs):
- """Connect to specific AWS region."""
- for region in regions():
- if region.name == region_name:
- return region.connect(**kwargs)
diff --git a/kombu/async/aws/sqs/connection.py b/kombu/async/aws/sqs/connection.py
index 32b9926b..7b0b738e 100644
--- a/kombu/async/aws/sqs/connection.py
+++ b/kombu/async/aws/sqs/connection.py
@@ -5,9 +5,8 @@ from __future__ import absolute_import, unicode_literals
from vine import transform
from kombu.async.aws.connection import AsyncAWSQueryConnection
-from kombu.async.aws.ext import RegionInfo
-from .ext import boto, Attributes, BatchResults, SQSConnection
+from .ext import boto3
from .message import AsyncMessage
from .queue import AsyncQueue
@@ -15,28 +14,17 @@ from .queue import AsyncQueue
__all__ = ['AsyncSQSConnection']
-class AsyncSQSConnection(AsyncAWSQueryConnection, SQSConnection):
+class AsyncSQSConnection(AsyncAWSQueryConnection):
"""Async SQS Connection."""
- def __init__(self, aws_access_key_id=None, aws_secret_access_key=None,
- is_secure=True, port=None, proxy=None, proxy_port=None,
- proxy_user=None, proxy_pass=None, debug=0,
- https_connection_factory=None, region=None, *args, **kwargs):
- if boto is None:
- raise ImportError('boto is not installed')
- self.region = region or RegionInfo(
- self, self.DefaultRegionName, self.DefaultRegionEndpoint,
- connection_cls=type(self),
- )
+ def __init__(self, sqs_connection, debug=0, region=None, **kwargs):
+ if boto3 is None:
+ raise ImportError('boto3 is not installed')
AsyncAWSQueryConnection.__init__(
self,
- aws_access_key_id=aws_access_key_id,
- aws_secret_access_key=aws_secret_access_key,
- is_secure=is_secure, port=port,
- proxy=proxy, proxy_port=proxy_port,
- proxy_user=proxy_user, proxy_pass=proxy_pass,
- host=self.region.endpoint, debug=debug,
- https_connection_factory=https_connection_factory, **kwargs
+ sqs_connection,
+ region_name=region, debug=debug,
+ **kwargs
)
def create_queue(self, queue_name,
@@ -46,17 +34,21 @@ class AsyncSQSConnection(AsyncAWSQueryConnection, SQSConnection):
params['DefaultVisibilityTimeout'] = format(
visibility_timeout, 'd',
)
- return self.get_object('CreateQueue', params, AsyncQueue,
+ return self.get_object('CreateQueue', params,
callback=callback)
def delete_queue(self, queue, force_deletion=False, callback=None):
return self.get_status('DeleteQueue', None, queue.id,
callback=callback)
+ def get_queue_url(self, queue):
+ res = self.sqs_connection.get_queue_url(QueueName=queue)
+ return res['QueueUrl']
+
def get_queue_attributes(self, queue, attribute='All', callback=None):
return self.get_object(
'GetQueueAttributes', {'AttributeName': attribute},
- Attributes, queue.id, callback=callback,
+ queue.id, callback=callback,
)
def set_queue_attribute(self, queue, attribute, value, callback=None):
@@ -74,17 +66,21 @@ class AsyncSQSConnection(AsyncAWSQueryConnection, SQSConnection):
if visibility_timeout:
params['VisibilityTimeout'] = visibility_timeout
if attributes:
- self.build_list_params(params, attributes, 'AttributeName')
+ attrs = {}
+ for idx, attr in enumerate(attributes):
+ attrs['AttributeName.' + str(idx + 1)] = attr
+ params.update(attrs)
if wait_time_seconds is not None:
params['WaitTimeSeconds'] = wait_time_seconds
+ queue_url = self.get_queue_url(queue)
return self.get_list(
- 'ReceiveMessage', params, [('Message', queue.message_class)],
- queue.id, callback=callback,
+ 'ReceiveMessage', params, [('Message', AsyncMessage)],
+ queue_url, callback=callback, parent=queue,
)
- def delete_message(self, queue, message, callback=None):
+ def delete_message(self, queue, receipt_handle, callback=None):
return self.delete_message_from_handle(
- queue, message.receipt_handle, callback,
+ queue, receipt_handle, callback,
)
def delete_message_batch(self, queue, messages, callback=None):
@@ -96,7 +92,7 @@ class AsyncSQSConnection(AsyncAWSQueryConnection, SQSConnection):
'{0}.ReceiptHandle'.format(prefix): m.receipt_handle,
})
return self.get_object(
- 'DeleteMessageBatch', params, BatchResults, queue.id,
+ 'DeleteMessageBatch', params, queue.id,
verb='POST', callback=callback,
)
@@ -104,7 +100,7 @@ class AsyncSQSConnection(AsyncAWSQueryConnection, SQSConnection):
callback=None):
return self.get_status(
'DeleteMessage', {'ReceiptHandle': receipt_handle},
- queue.id, callback=callback,
+ queue, callback=callback,
)
def send_message(self, queue, message_content,
@@ -113,7 +109,7 @@ class AsyncSQSConnection(AsyncAWSQueryConnection, SQSConnection):
if delay_seconds:
params['DelaySeconds'] = int(delay_seconds)
return self.get_object(
- 'SendMessage', params, AsyncMessage, queue.id,
+ 'SendMessage', params, queue.id,
verb='POST', callback=callback,
)
@@ -127,7 +123,7 @@ class AsyncSQSConnection(AsyncAWSQueryConnection, SQSConnection):
'{0}.DelaySeconds'.format(prefix): msg[2],
})
return self.get_object(
- 'SendMessageBatch', params, BatchResults, queue.id,
+ 'SendMessageBatch', params, queue.id,
verb='POST', callback=callback,
)
@@ -150,7 +146,7 @@ class AsyncSQSConnection(AsyncAWSQueryConnection, SQSConnection):
'{0}.VisibilityTimeout'.format(pre): t[1],
})
return self.get_object(
- 'ChangeMessageVisibilityBatch', params, BatchResults, queue.id,
+ 'ChangeMessageVisibilityBatch', params, queue.id,
verb='POST', callback=callback,
)
diff --git a/kombu/async/aws/sqs/ext.py b/kombu/async/aws/sqs/ext.py
index eb48f3e9..09fdf1a1 100644
--- a/kombu/async/aws/sqs/ext.py
+++ b/kombu/async/aws/sqs/ext.py
@@ -1,32 +1,9 @@
# -*- coding: utf-8 -*-
-"""Amazon SQS boto interface."""
+"""Amazon SQS boto3 interface."""
from __future__ import absolute_import, unicode_literals
try:
- import boto
-except ImportError: # pragma: no cover
- boto = Attributes = BatchResults = None # noqa
-
- class _void(object):
- pass
- regions = SQSConnection = Queue = _void
-
- RawMessage = Message = MHMessage = \
- EncodedMHMessage = JSONMessage = _void
-else:
- from boto.sqs.attributes import Attributes
- from boto.sqs.batchresults import BatchResults
- from boto.sqs.message import (
- EncodedMHMessage, Message, MHMessage, RawMessage,
- )
- from boto.sqs import regions
- from boto.sqs.jsonmessage import JSONMessage
- from boto.sqs.connection import SQSConnection
- from boto.sqs.queue import Queue
-
-__all__ = [
- 'Attributes', 'BatchResults', 'EncodedMHMessage', 'MHMessage',
- 'Message', 'RawMessage', 'JSONMessage', 'SQSConnection',
- 'Queue', 'regions',
-]
+ import boto3
+except ImportError:
+ boto3 = None
diff --git a/kombu/async/aws/sqs/message.py b/kombu/async/aws/sqs/message.py
index 36359841..b4b28331 100644
--- a/kombu/async/aws/sqs/message.py
+++ b/kombu/async/aws/sqs/message.py
@@ -2,45 +2,32 @@
"""Amazon SQS message implementation."""
from __future__ import absolute_import, unicode_literals
-from .ext import (
- RawMessage, Message, MHMessage, EncodedMHMessage, JSONMessage,
-)
+from kombu.message import Message
+import base64
-__all__ = [
- 'BaseAsyncMessage', 'AsyncRawMessage', 'AsyncMessage',
- 'AsyncMHMessage', 'AsyncEncodedMHMessage', 'AsyncJSONMessage',
-]
-
-class BaseAsyncMessage(object):
+class BaseAsyncMessage(Message):
"""Base class for messages received on async client."""
- def delete(self, callback=None):
- if self.queue:
- return self.queue.delete_message(self, callback)
-
- def change_visibility(self, visibility_timeout, callback=None):
- if self.queue:
- return self.queue.connection.change_message_visibility(
- self.queue, self.receipt_handle, visibility_timeout, callback,
- )
-
-class AsyncRawMessage(BaseAsyncMessage, RawMessage):
+class AsyncRawMessage(BaseAsyncMessage):
"""Raw Message."""
-class AsyncMessage(BaseAsyncMessage, Message):
+class AsyncMessage(BaseAsyncMessage):
"""Serialized message."""
-
-class AsyncMHMessage(BaseAsyncMessage, MHMessage):
- """MHM Message (uhm, look that up later)."""
-
-
-class AsyncEncodedMHMessage(BaseAsyncMessage, EncodedMHMessage):
- """Encoded MH Message."""
-
-
-class AsyncJSONMessage(BaseAsyncMessage, JSONMessage):
- """Json serialized message."""
+ def encode(self, value):
+ """Encode/decode the value using Base64 encoding."""
+ return base64.b64encode(value).decode('utf-8')
+
+ def __getitem__(self, item):
+ """Support Boto3-style access on a message."""
+ if item == 'ReceiptHandle':
+ return self.receipt_handle
+ elif item == 'Body':
+ return self.get_body()
+ elif item == 'queue':
+ return self.queue
+ else:
+ raise KeyError(item)
diff --git a/kombu/async/aws/sqs/queue.py b/kombu/async/aws/sqs/queue.py
index 140ff31a..c0f8c0f3 100644
--- a/kombu/async/aws/sqs/queue.py
+++ b/kombu/async/aws/sqs/queue.py
@@ -4,7 +4,6 @@ from __future__ import absolute_import, unicode_literals
from vine import transform
-from .ext import Queue as _Queue
from .message import AsyncMessage
_all__ = ['AsyncQueue']
@@ -15,7 +14,7 @@ def list_first(rs):
return rs[0] if len(rs) == 1 else None
-class AsyncQueue(_Queue):
+class AsyncQueue():
"""Async SQS Queue."""
def __init__(self, connection=None, url=None, message_class=AsyncMessage):
diff --git a/kombu/async/http/base.py b/kombu/async/http/base.py
index f8a8bc0a..a8eb6edf 100644
--- a/kombu/async/http/base.py
+++ b/kombu/async/http/base.py
@@ -200,6 +200,15 @@ class Response(object):
self._body = self.buffer.getvalue()
return self._body
+ # these are for compatibility with Requests
+ @property
+ def status_code(self):
+ return self.code
+
+ @property
+ def content(self):
+ return self.body
+
@coro
def header_parser(keyt=normalize_header):
diff --git a/kombu/async/http/curl.py b/kombu/async/http/curl.py
index 1c50eef8..d8520ded 100644
--- a/kombu/async/http/curl.py
+++ b/kombu/async/http/curl.py
@@ -171,15 +171,12 @@ class CurlClient(BaseClient):
code = curl.getinfo(_pycurl.HTTP_CODE)
effective_url = curl.getinfo(_pycurl.EFFECTIVE_URL)
buffer.seek(0)
- try:
- request = info['request']
- request.on_ready(self.Response(
- request=request, code=code, headers=info['headers'],
- buffer=buffer, effective_url=effective_url, error=error,
- ))
- except Exception as exc:
- self.hub.on_callback_error(request.on_ready, exc)
- raise
+ # try:
+ request = info['request']
+ request.on_ready(self.Response(
+ request=request, code=code, headers=info['headers'],
+ buffer=buffer, effective_url=effective_url, error=error,
+ ))
def _setup_request(self, curl, request, buffer, headers, _pycurl=pycurl):
setopt = curl.setopt
@@ -243,7 +240,7 @@ class CurlClient(BaseClient):
setopt(meth, True)
if request.method in ('POST', 'PUT'):
- body = request.body or ''
+ body = request.body.encode('utf-8') if request.body else bytes()
reqbuffer = BytesIO(body)
setopt(_pycurl.READFUNCTION, reqbuffer.read)
if request.method == 'POST':
diff --git a/kombu/transport/SQS.py b/kombu/transport/SQS.py
index a0fb6e8d..59cfdc87 100644
--- a/kombu/transport/SQS.py
+++ b/kombu/transport/SQS.py
@@ -39,15 +39,14 @@ from __future__ import absolute_import, unicode_literals
import socket
import string
+import uuid
from vine import transform, ensure_promise, promise
from kombu.async import get_event_loop
-from kombu.async.aws import sqs as _asynsqs
-from kombu.async.aws.ext import boto, exception
-from kombu.async.aws.sqs.connection import AsyncSQSConnection, SQSConnection
-from kombu.async.aws.sqs.ext import regions
-from kombu.async.aws.sqs.message import Message
+from kombu.async.aws.ext import boto3, exceptions
+from kombu.async.aws.sqs.connection import AsyncSQSConnection
+from kombu.async.aws.sqs.message import AsyncMessage
from kombu.five import Empty, range, string_t, text_t
from kombu.log import get_logger
from kombu.utils import scheduling
@@ -91,8 +90,8 @@ class Channel(virtual.Channel):
_noack_queues = set()
def __init__(self, *args, **kwargs):
- if boto is None:
- raise ImportError('boto is not installed')
+ if boto3 is None:
+ raise ImportError('boto3 is not installed')
super(Channel, self).__init__(*args, **kwargs)
# SQS blows up if you try to create a new queue when one already
@@ -104,18 +103,10 @@ class Channel(virtual.Channel):
self.hub = kwargs.get('hub') or get_event_loop()
def _update_queue_cache(self, queue_name_prefix):
- try:
- queues = self.sqs.get_all_queues(prefix=queue_name_prefix)
- except exception.SQSError as exc:
- if exc.status == 403:
- raise RuntimeError(
- 'SQS authorization error, access_key={0}'.format(
- self.sqs.access_key))
- raise
- else:
- self._queue_cache.update({
- queue.name: queue for queue in queues
- })
+ resp = self.sqs.list_queues(QueueNamePrefix=queue_name_prefix)
+ for url in resp.get('QueueUrls', []):
+ queue_name = url.split('/')[-1]
+ self._queue_cache[queue_name] = url
def basic_consume(self, queue, no_ack, *args, **kwargs):
if no_ack:
@@ -132,7 +123,7 @@ class Channel(virtual.Channel):
self._noack_queues.discard(queue)
return super(Channel, self).basic_cancel(consumer_tag)
- def drain_events(self, timeout=None, **kwargs):
+ def drain_events(self, timeout=None, callback=None, **kwargs):
"""Return a single payload message from one of our queues.
Raises:
@@ -143,7 +134,7 @@ class Channel(virtual.Channel):
raise Empty()
# At this point, go and get more messages from SQS
- self._poll(self.cycle, self.connection._deliver, timeout=timeout)
+ self._poll(self.cycle, callback, timeout=timeout)
def _reset_cycle(self):
"""Reset the consume cycle.
@@ -160,7 +151,15 @@ class Channel(virtual.Channel):
def entity_name(self, name, table=CHARS_REPLACE_TABLE):
"""Format AMQP queue name into a legal SQS queue name."""
- return text_t(safe_str(name)).translate(table)
+ if name.endswith('.fifo'):
+ partial = name.rstrip('.fifo')
+ partial = text_t(safe_str(partial)).translate(table)
+ return partial + '.fifo'
+ else:
+ return text_t(safe_str(name)).translate(table)
+
+ def canonical_queue_name(self, queue_name):
+ return self.entity_name(self.queue_name_prefix + queue_name)
def _new_queue(self, queue, **kwargs):
"""Ensure a queue with given name exists in SQS."""
@@ -168,7 +167,7 @@ class Channel(virtual.Channel):
return queue
# Translate to SQS name for consistency with initial
# _queue_cache population.
- queue = self.entity_name(self.queue_name_prefix + queue)
+ queue = self.canonical_queue_name(queue)
# The SQS ListQueues method only returns 1000 queues. When you have
# so many queues, it's possible that the queue you are looking for is
@@ -179,10 +178,14 @@ class Channel(virtual.Channel):
try:
return self._queue_cache[queue]
except KeyError:
- q = self._queue_cache[queue] = self.sqs.create_queue(
- queue, self.visibility_timeout,
- )
- return q
+ attributes = {'VisibilityTimeout': str(self.visibility_timeout)}
+ if queue.endswith('.fifo'):
+ attributes['FifoQueue'] = 'true'
+
+ resp = self._queue_cache[queue] = self.sqs.create_queue(
+ QueueName=queue, Attributes=attributes)
+ self._queue_cache[queue] = resp['QueueUrl']
+ return resp['QueueUrl']
def _delete(self, queue, *args, **kwargs):
"""Delete queue by name."""
@@ -191,15 +194,27 @@ class Channel(virtual.Channel):
def _put(self, queue, message, **kwargs):
"""Put message onto queue."""
- q = self._new_queue(queue)
- m = Message()
- m.set_body(dumps(message))
- q.write(m)
+ q_url = self._new_queue(queue)
+ kwargs = {'QueueUrl': q_url,
+ 'MessageBody': AsyncMessage().encode(dumps(message))}
+ if queue.endswith('.fifo'):
+ if 'MessageGroupId' in message['properties']:
+ kwargs['MessageGroupId'] = \
+ message['properties']['MessageGroupId']
+ else:
+ kwargs['MessageGroupId'] = 'default'
+ if 'MessageDeduplicationId' in message['properties']:
+ kwargs['MessageDeduplicationId'] = \
+ message['properties']['MessageDeduplicationId']
+ else:
+ kwargs['MessageDeduplicationId'] = str(uuid.uuid4())
+ self.sqs.send_message(**kwargs)
def _message_to_python(self, message, queue_name, queue):
- payload = loads(bytes_to_str(message.get_body()))
+ payload = loads(bytes_to_str(message['Body']))
if queue_name in self._noack_queues:
- queue.delete_message(message)
+ queue = self._new_queue(queue_name)
+ self.asynsqs.delete_message(queue, message['ReceiptHandle'])
else:
try:
properties = payload['properties']
@@ -209,14 +224,14 @@ class Channel(virtual.Channel):
delivery_info = {}
properties = {'delivery_info': delivery_info}
payload.update({
- 'body': bytes_to_str(message.get_body()),
+ 'body': bytes_to_str(message['Body']),
'properties': properties,
})
# set delivery tag to SQS receipt handle
delivery_info.update({
'sqs_message': message, 'sqs_queue': queue,
})
- properties['delivery_tag'] = message.receipt_handle
+ properties['delivery_tag'] = message['ReceiptHandle']
return payload
def _messages_to_python(self, messages, queue):
@@ -261,23 +276,31 @@ class Channel(virtual.Channel):
# drain_events calls `can_consume` first, consuming
# a token, so we know that we are allowed to consume at least
# one message.
- maxcount = self._get_message_estimate()
- if maxcount:
- q = self._new_queue(queue)
- messages = q.get_messages(num_messages=maxcount)
- if messages:
- for msg in self._messages_to_python(messages, queue):
+ # Note: ignoring max_messages for SQS with boto3
+ max_count = self._get_message_estimate()
+ if max_count:
+ q_url = self._new_queue(queue)
+ resp = self.sqs.receive_message(
+ QueueUrl=q_url, MaxNumberOfMessages=max_count)
+
+ if resp['Messages']:
+ for m in resp['Messages']:
+ m['Body'] = AsyncMessage().decode(m['Body'])
+ for msg in self._messages_to_python(resp['Messages'], queue):
self.connection._deliver(msg, queue)
return
raise Empty()
def _get(self, queue):
"""Try to retrieve a single message off ``queue``."""
- q = self._new_queue(queue)
- messages = q.get_messages(num_messages=1)
- if messages:
- return self._messages_to_python(messages, queue)[0]
+ q_url = self._new_queue(queue)
+ resp = self.sqs.receive_message(q_url)
+
+ if resp['Messages']:
+ body = AsyncMessage().decode(resp['Messages'][0]['Body'])
+ resp['Messages'][0]['Body'] = body
+ return self._messages_to_python(resp['Messages'], queue)[0]
raise Empty()
def _loop1(self, queue, _=None):
@@ -311,17 +334,18 @@ class Channel(virtual.Channel):
def _get_async(self, queue, count=1, callback=None):
q = self._new_queue(queue)
+ qname = self.canonical_queue_name(queue)
return self._get_from_sqs(
- q, count=count, connection=self.asynsqs,
+ qname, count=count, connection=self.asynsqs,
callback=transform(self._on_messages_ready, callback, q, queue),
)
def _on_messages_ready(self, queue, qname, messages):
- if messages:
+ if 'Messages' in messages and messages['Messages']:
callbacks = self.connection._callbacks
- for raw_message in messages:
- message = self._message_to_python(raw_message, qname, queue)
- callbacks[qname](message)
+ for msg in messages['Messages']:
+ msg_parsed = self._message_to_python(msg, qname, queue)
+ callbacks[qname](msg_parsed)
def _get_from_sqs(self, queue,
count=1, connection=None, callback=None):
@@ -330,6 +354,7 @@ class Channel(virtual.Channel):
Uses long polling and returns :class:`~vine.promises.promise`.
"""
connection = connection if connection is not None else queue.connection
+ # url = self.get_queue
return connection.receive_message(
queue, number_messages=count,
wait_time_seconds=self.wait_time_seconds,
@@ -344,72 +369,68 @@ class Channel(virtual.Channel):
return super(Channel, self)._restore(message)
def basic_ack(self, delivery_tag, multiple=False):
- delivery_info = self.qos.get(delivery_tag).delivery_info
try:
- queue = delivery_info['sqs_queue']
+ message = self.qos.get(delivery_tag).delivery_info
+ sqs_message = message['sqs_message']
except KeyError:
pass
else:
- queue.delete_message(delivery_info['sqs_message'])
+ self.asynsqs.delete_message(message['sqs_queue'],
+ sqs_message['ReceiptHandle'])
super(Channel, self).basic_ack(delivery_tag)
def _size(self, queue):
"""Return the number of messages in a queue."""
- return self._new_queue(queue).count()
+ url = self._new_queue(queue)
+ resp = self.sqs.get_queue_attributes(
+ QueueUrl=url,
+ AttributeNames=['ApproximateNumberOfMessages'])
+ return int(resp['Attributes']['ApproximateNumberOfMessages'])
def _purge(self, queue):
"""Delete all current messages in a queue."""
q = self._new_queue(queue)
# SQS is slow at registering messages, so run for a few
- # iterations to ensure messages are deleted.
+ # iterations to ensure messages are detected and deleted.
size = 0
for i in range(10):
- size += q.count()
+ size += int(self._size(queue))
if not size:
break
- q.clear()
+ self.sqs.purge_queue(q)
return size
def close(self):
super(Channel, self).close()
- for conn in (self._sqs, self._asynsqs):
- if conn:
- try:
- conn.close()
- except AttributeError as exc: # FIXME ???
- if "can't set attribute" not in str(exc):
- raise
-
- def _get_regioninfo(self, regions):
- if self.regioninfo:
- return self.regioninfo
- if self.region:
- for _r in regions:
- if _r.name == self.region:
- return _r
-
- def _aws_connect_to(self, fun, regions):
- conninfo = self.conninfo
- region = self._get_regioninfo(regions)
- is_secure = self.is_secure if self.is_secure is not None else True
- port = self.port if self.port is not None else conninfo.port
- return fun(region=region,
- aws_access_key_id=conninfo.userid,
- aws_secret_access_key=conninfo.password,
- is_secure=is_secure,
- port=port)
+ # if self._asynsqs:
+ # try:
+ # self.asynsqs.close()
+ # except AttributeError as exc: # FIXME ???
+ # if "can't set attribute" not in str(exc):
+ # raise
@property
def sqs(self):
if self._sqs is None:
- self._sqs = self._aws_connect_to(SQSConnection, regions())
+ session = boto3.session.Session(
+ region_name=self.region,
+ aws_access_key_id=self.conninfo.userid,
+ aws_secret_access_key=self.conninfo.password,
+ )
+ is_secure = self.is_secure if self.is_secure is not None else True
+ self._sqs = session.client('sqs', use_ssl=is_secure)
return self._sqs
@property
def asynsqs(self):
if self._asynsqs is None:
- self._asynsqs = self._aws_connect_to(
- AsyncSQSConnection, _asynsqs.regions(),
+ is_secure = self.is_secure if self.is_secure is not None else True
+ self._asynsqs = AsyncSQSConnection(
+ sqs_connection=self.sqs,
+ aws_access_key_id=self.conninfo.userid,
+ aws_secret_access_key=self.conninfo.password,
+ region=self.region,
+ is_secure=is_secure,
)
return self._asynsqs
@@ -466,10 +487,10 @@ class Transport(virtual.Transport):
default_port = None
connection_errors = (
virtual.Transport.connection_errors +
- (exception.SQSError, socket.error)
+ (exceptions.BotoCoreError, socket.error)
)
channel_errors = (
- virtual.Transport.channel_errors + (exception.SQSDecodeError,)
+ virtual.Transport.channel_errors + (exceptions.BotoCoreError,)
)
driver_type = 'sqs'
driver_name = 'sqs'
diff --git a/requirements/extras/sqs.txt b/requirements/extras/sqs.txt
index 53851e77..3d6d07a9 100644
--- a/requirements/extras/sqs.txt
+++ b/requirements/extras/sqs.txt
@@ -1,2 +1,2 @@
-boto>=2.8
+boto3>=1.4.4
pycurl
diff --git a/requirements/funtest.txt b/requirements/funtest.txt
index 55637309..033db4b8 100644
--- a/requirements/funtest.txt
+++ b/requirements/funtest.txt
@@ -9,6 +9,7 @@ kazoo
# SQS transport
boto
+boto3
# Qpid transport
qpid-python>=0.26
diff --git a/requirements/test-ci.txt b/requirements/test-ci.txt
index 0684adfc..f2318d5a 100644
--- a/requirements/test-ci.txt
+++ b/requirements/test-ci.txt
@@ -3,3 +3,4 @@ codecov
redis
PyYAML
msgpack-python>0.2.0
+-r extras/sqs.txt
diff --git a/t/integration/tests/test_SQS.py b/t/integration/tests/test_SQS.py
index 676ba994..1c52d648 100644
--- a/t/integration/tests/test_SQS.py
+++ b/t/integration/tests/test_SQS.py
@@ -8,6 +8,7 @@ from kombu.tests.case import skip
@skip.unless_environ('AWS_ACCESS_KEY_ID')
@skip.unless_environ('AWS_SECRET_ACCESS_KEY')
@skip.unless_module('boto')
+@skip.unless_module('boto3')
class test_SQS(transport.TransportCase):
transport = 'SQS'
prefix = 'sqs'
diff --git a/t/unit/async/aws/case.py b/t/unit/async/aws/case.py
index 985dbc04..70d9565f 100644
--- a/t/unit/async/aws/case.py
+++ b/t/unit/async/aws/case.py
@@ -7,7 +7,7 @@ from case import skip
@skip.if_pypy()
-@skip.unless_module('boto')
+@skip.unless_module('boto3')
@skip.unless_module('pycurl')
@pytest.mark.usefixtures('hub')
class AWSCase(object):
diff --git a/t/unit/async/aws/sqs/test_connection.py b/t/unit/async/aws/sqs/test_connection.py
index ecc22623..ade01868 100644
--- a/t/unit/async/aws/sqs/test_connection.py
+++ b/t/unit/async/aws/sqs/test_connection.py
@@ -1,13 +1,12 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, unicode_literals
-import pytest
-
-from case import Mock
+from case import Mock, MagicMock
from kombu.async.aws.sqs.connection import (
- AsyncSQSConnection, Attributes, BatchResults,
+ AsyncSQSConnection
)
+from kombu.async.aws.ext import boto3
from kombu.async.aws.sqs.message import AsyncMessage
from kombu.async.aws.sqs.queue import AsyncQueue
from kombu.utils.uuid import uuid
@@ -20,29 +19,26 @@ from ..case import AWSCase
class test_AsyncSQSConnection(AWSCase):
def setup(self):
- self.x = AsyncSQSConnection('ak', 'sk', http_client=Mock())
+ session = boto3.session.Session(
+ aws_access_key_id='AAA',
+ aws_secret_access_key='AAAA',
+ region_name='us-west-2',
+ )
+ sqs_client = session.client('sqs')
+ self.x = AsyncSQSConnection(sqs_client, 'ak', 'sk', http_client=Mock())
self.x.get_object = Mock(name='X.get_object')
self.x.get_status = Mock(name='X.get_status')
- self.x.get_list = Mock(nanme='X.get_list')
+ self.x.get_list = Mock(name='X.get_list')
self.callback = PromiseMock(name='callback')
- def test_without_boto(self):
- from kombu.async.aws.sqs import connection
- prev, connection.boto = connection.boto, None
- try:
- with pytest.raises(ImportError):
- AsyncSQSConnection('ak', 'sk', http_client=Mock())
- finally:
- connection.boto = prev
-
- def test_default_region(self):
- assert self.x.region
- assert issubclass(self.x.region.connection_cls, AsyncSQSConnection)
+ sqs_client.get_queue_url = MagicMock(return_value={
+ 'QueueUrl': 'http://aws.com'
+ })
def test_create_queue(self):
self.x.create_queue('foo', callback=self.callback)
self.x.get_object.assert_called_with(
- 'CreateQueue', {'QueueName': 'foo'}, AsyncQueue,
+ 'CreateQueue', {'QueueName': 'foo'},
callback=self.callback,
)
@@ -55,7 +51,7 @@ class test_AsyncSQSConnection(AWSCase):
'QueueName': 'foo',
'DefaultVisibilityTimeout': '33'
},
- AsyncQueue, callback=self.callback
+ callback=self.callback
)
def test_delete_queue(self):
@@ -72,7 +68,7 @@ class test_AsyncSQSConnection(AWSCase):
)
self.x.get_object.assert_called_with(
'GetQueueAttributes', {'AttributeName': 'QueueSize'},
- Attributes, queue.id, callback=self.callback,
+ queue.id, callback=self.callback,
)
def test_set_queue_attribute(self):
@@ -93,8 +89,9 @@ class test_AsyncSQSConnection(AWSCase):
self.x.receive_message(queue, 4, callback=self.callback)
self.x.get_list.assert_called_with(
'ReceiveMessage', {'MaxNumberOfMessages': 4},
- [('Message', queue.message_class)],
- queue.id, callback=self.callback,
+ [('Message', AsyncMessage)],
+ 'http://aws.com', callback=self.callback,
+ parent=queue,
)
def test_receive_message__with_visibility_timeout(self):
@@ -105,8 +102,9 @@ class test_AsyncSQSConnection(AWSCase):
'MaxNumberOfMessages': 4,
'VisibilityTimeout': 3666,
},
- [('Message', queue.message_class)],
- queue.id, callback=self.callback,
+ [('Message', AsyncMessage)],
+ 'http://aws.com', callback=self.callback,
+ parent=queue,
)
def test_receive_message__with_wait_time_seconds(self):
@@ -119,8 +117,9 @@ class test_AsyncSQSConnection(AWSCase):
'MaxNumberOfMessages': 4,
'WaitTimeSeconds': 303,
},
- [('Message', queue.message_class)],
- queue.id, callback=self.callback,
+ [('Message', AsyncMessage)],
+ 'http://aws.com', callback=self.callback,
+ parent=queue,
)
def test_receive_message__with_attributes(self):
@@ -134,8 +133,9 @@ class test_AsyncSQSConnection(AWSCase):
'AttributeName.2': 'bar',
'MaxNumberOfMessages': 4,
},
- [('Message', queue.message_class)],
- queue.id, callback=self.callback,
+ [('Message', AsyncMessage)],
+ 'http://aws.com', callback=self.callback,
+ parent=queue,
)
def MockMessage(self, id=None, receipt_handle=None, body=None):
@@ -157,10 +157,11 @@ class test_AsyncSQSConnection(AWSCase):
def test_delete_message(self):
queue = Mock(name='queue')
message = self.MockMessage()
- self.x.delete_message(queue, message, callback=self.callback)
+ self.x.delete_message(queue, message.receipt_handle,
+ callback=self.callback)
self.x.get_status.assert_called_with(
'DeleteMessage', {'ReceiptHandle': message.receipt_handle},
- queue.id, callback=self.callback,
+ queue, callback=self.callback,
)
def test_delete_message_batch(self):
@@ -175,7 +176,7 @@ class test_AsyncSQSConnection(AWSCase):
'DeleteMessageBatchRequestEntry.2.Id': '2',
'DeleteMessageBatchRequestEntry.2.ReceiptHandle': 'r2',
},
- BatchResults, queue.id, verb='POST', callback=self.callback,
+ queue.id, verb='POST', callback=self.callback,
)
def test_send_message(self):
@@ -183,7 +184,7 @@ class test_AsyncSQSConnection(AWSCase):
self.x.send_message(queue, 'hello', callback=self.callback)
self.x.get_object.assert_called_with(
'SendMessage', {'MessageBody': 'hello'},
- AsyncMessage, queue.id, verb='POST', callback=self.callback,
+ queue.id, verb='POST', callback=self.callback,
)
def test_send_message__with_delay_seconds(self):
@@ -193,7 +194,7 @@ class test_AsyncSQSConnection(AWSCase):
)
self.x.get_object.assert_called_with(
'SendMessage', {'MessageBody': 'hello', 'DelaySeconds': 303},
- AsyncMessage, queue.id, verb='POST', callback=self.callback,
+ queue.id, verb='POST', callback=self.callback,
)
def test_send_message_batch(self):
@@ -213,7 +214,7 @@ class test_AsyncSQSConnection(AWSCase):
'SendMessageBatchRequestEntry.2.MessageBody': 'B',
'SendMessageBatchRequestEntry.2.DelaySeconds': 303,
},
- BatchResults, queue.id, verb='POST', callback=self.callback,
+ queue.id, verb='POST', callback=self.callback,
)
def test_change_message_visibility(self):
@@ -251,7 +252,7 @@ class test_AsyncSQSConnection(AWSCase):
preamble('2.ReceiptHandle'): 'r2',
preamble('2.VisibilityTimeout'): 909,
},
- BatchResults, queue.id, verb='POST', callback=self.callback,
+ queue.id, verb='POST', callback=self.callback,
)
def test_get_all_queues(self):
diff --git a/t/unit/async/aws/sqs/test_message.py b/t/unit/async/aws/sqs/test_message.py
deleted file mode 100644
index 44a0ac32..00000000
--- a/t/unit/async/aws/sqs/test_message.py
+++ /dev/null
@@ -1,37 +0,0 @@
-# -*- coding: utf-8 -*-
-from __future__ import absolute_import, unicode_literals
-
-from case import Mock
-
-from kombu.async.aws.sqs.message import AsyncMessage
-from kombu.utils.uuid import uuid
-
-from t.mocks import PromiseMock
-
-from ..case import AWSCase
-
-
-class test_AsyncMessage(AWSCase):
-
- def setup(self):
- self.queue = Mock(name='queue')
- self.callback = PromiseMock(name='callback')
- self.x = AsyncMessage(self.queue, 'body')
- self.x.receipt_handle = uuid()
-
- def test_delete(self):
- assert self.x.delete(callback=self.callback)
- self.x.queue.delete_message.assert_called_with(
- self.x, self.callback,
- )
-
- self.x.queue = None
- assert self.x.delete(callback=self.callback) is None
-
- def test_change_visibility(self):
- assert self.x.change_visibility(303, callback=self.callback)
- self.x.queue.connection.change_message_visibility.assert_called_with(
- self.x.queue, self.x.receipt_handle, 303, self.callback,
- )
- self.x.queue = None
- assert self.x.change_visibility(303, callback=self.callback) is None
diff --git a/t/unit/async/aws/sqs/test_sqs.py b/t/unit/async/aws/sqs/test_sqs.py
deleted file mode 100644
index ea58596d..00000000
--- a/t/unit/async/aws/sqs/test_sqs.py
+++ /dev/null
@@ -1,34 +0,0 @@
-# -*- coding: utf-8 -*-
-from __future__ import absolute_import, unicode_literals
-
-import pytest
-
-from case import Mock, patch
-
-from kombu.async.aws.sqs import regions, connect_to_region
-from kombu.async.aws.sqs.connection import AsyncSQSConnection
-
-from ..case import AWSCase
-
-
-class test_connect_to_region(AWSCase):
-
- def test_when_no_boto_installed(self, patching):
- patching('kombu.async.aws.sqs.boto', None)
- with pytest.raises(ImportError):
- regions()
-
- def test_using_async_connection(self):
- for region in regions():
- assert region.connection_cls is AsyncSQSConnection
-
- def test_connect_to_region(self):
- with patch('kombu.async.aws.sqs.regions') as regions:
- region = Mock(name='region')
- region.name = 'us-west-1'
- regions.return_value = [region]
- conn = connect_to_region('us-west-1', kw=3.33)
- assert conn is region.connect.return_value
- region.connect.assert_called_with(kw=3.33)
-
- assert connect_to_region('foo') is None
diff --git a/t/unit/async/aws/test_aws.py b/t/unit/async/aws/test_aws.py
index f5ed9aef..29c535ab 100644
--- a/t/unit/async/aws/test_aws.py
+++ b/t/unit/async/aws/test_aws.py
@@ -13,4 +13,4 @@ class test_connect_sqs(AWSCase):
def test_connection(self):
x = connect_sqs('AAKI', 'ASAK', http_client=Mock())
assert x
- assert x.connection
+ assert x.sqs_connection
diff --git a/t/unit/async/aws/test_connection.py b/t/unit/async/aws/test_connection.py
index 5de76aa5..cba953b3 100644
--- a/t/unit/async/aws/test_connection.py
+++ b/t/unit/async/aws/test_connection.py
@@ -1,11 +1,9 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, unicode_literals
-import pytest
-
from contextlib import contextmanager
-from case import Mock, patch
+from case import Mock
from vine.abstract import Thenable
from kombu.exceptions import HttpError
@@ -13,18 +11,22 @@ from kombu.five import WhateverIO
from kombu.async import http
from kombu.async.aws.connection import (
- AsyncHTTPConnection,
AsyncHTTPSConnection,
AsyncHTTPResponse,
AsyncConnection,
- AsyncAWSAuthConnection,
AsyncAWSQueryConnection,
)
+from kombu.async.aws.ext import boto3
from .case import AWSCase
from t.mocks import PromiseMock
+try:
+ from urllib.parse import urlparse, parse_qs
+except ImportError:
+ from urlparse import urlparse, parse_qs # noqa
+
# Not currently working
VALIDATES_CERT = False
@@ -38,37 +40,30 @@ def passthrough(*args, **kwargs):
return m
-class test_AsyncHTTPConnection(AWSCase):
-
- def test_AsyncHTTPSConnection(self):
- x = AsyncHTTPSConnection('aws.vandelay.com')
- assert x.scheme == 'https'
+class test_AsyncHTTPSConnection(AWSCase):
def test_http_client(self):
- x = AsyncHTTPConnection('aws.vandelay.com')
+ x = AsyncHTTPSConnection()
assert x.http_client is http.get_client()
client = Mock(name='http_client')
- y = AsyncHTTPConnection('aws.vandelay.com', http_client=client)
+ y = AsyncHTTPSConnection(http_client=client)
assert y.http_client is client
def test_args(self):
- x = AsyncHTTPConnection(
- 'aws.vandelay.com', 8083, strict=True, timeout=33.3,
+ x = AsyncHTTPSConnection(
+ strict=True, timeout=33.3,
)
- assert x.host == 'aws.vandelay.com'
- assert x.port == 8083
assert x.strict
assert x.timeout == 33.3
- assert x.scheme == 'http'
def test_request(self):
- x = AsyncHTTPConnection('aws.vandelay.com')
+ x = AsyncHTTPSConnection('aws.vandelay.com')
x.request('PUT', '/importer-exporter')
assert x.path == '/importer-exporter'
assert x.method == 'PUT'
def test_request_with_body_buffer(self):
- x = AsyncHTTPConnection('aws.vandelay.com')
+ x = AsyncHTTPSConnection('aws.vandelay.com')
body = Mock(name='body')
body.read.return_value = 'Vandelay Industries'
x.request('PUT', '/importer-exporter', body)
@@ -78,14 +73,14 @@ class test_AsyncHTTPConnection(AWSCase):
body.read.assert_called_with()
def test_request_with_body_text(self):
- x = AsyncHTTPConnection('aws.vandelay.com')
+ x = AsyncHTTPSConnection('aws.vandelay.com')
x.request('PUT', '/importer-exporter', 'Vandelay Industries')
assert x.method == 'PUT'
assert x.path == '/importer-exporter'
assert x.body == 'Vandelay Industries'
def test_request_with_headers(self):
- x = AsyncHTTPConnection('aws.vandelay.com')
+ x = AsyncHTTPSConnection()
headers = {'Proxy': 'proxy.vandelay.com'}
x.request('PUT', '/importer-exporter', None, headers)
assert 'Proxy' in dict(x.headers)
@@ -99,27 +94,10 @@ class test_AsyncHTTPConnection(AWSCase):
validate_cert=VALIDATES_CERT,
)
- def test_getrequest_AsyncHTTPSConnection(self):
- x = AsyncHTTPSConnection('aws.vandelay.com')
- x.Request = Mock(name='Request')
- x.getrequest()
- self.assert_request_created_with('https://aws.vandelay.com/', x)
-
- def test_getrequest_nondefault_port(self):
- x = AsyncHTTPConnection('aws.vandelay.com', port=8080)
- x.Request = Mock(name='Request')
- x.getrequest()
- self.assert_request_created_with('http://aws.vandelay.com:8080/', x)
-
- y = AsyncHTTPSConnection('aws.vandelay.com', port=8443)
- y.Request = Mock(name='Request')
- y.getrequest()
- self.assert_request_created_with('https://aws.vandelay.com:8443/', y)
-
def test_getresponse(self):
client = Mock(name='client')
client.add_request = passthrough(name='client.add_request')
- x = AsyncHTTPConnection('aws.vandelay.com', http_client=client)
+ x = AsyncHTTPSConnection(http_client=client)
x.Response = Mock(name='x.Response')
request = x.getresponse()
x.http_client.add_request.assert_called_with(request)
@@ -134,7 +112,7 @@ class test_AsyncHTTPConnection(AWSCase):
client = Mock(name='client')
client.add_request = passthrough(name='client.add_request')
callback = PromiseMock(name='callback')
- x = AsyncHTTPConnection('aws.vandelay.com', http_client=client)
+ x = AsyncHTTPSConnection(http_client=client)
request = x.getresponse(callback)
x.http_client.add_request.assert_called_with(request)
@@ -151,22 +129,22 @@ class test_AsyncHTTPConnection(AWSCase):
assert wresponse.read() == 'The quick brown fox jumps'
assert wresponse.status == 200
assert wresponse.getheader('X-Foo') == 'Hello'
- assert dict(wresponse.getheaders()) == headers
- assert wresponse.msg
+ headers_dict = wresponse.getheaders()
+ assert dict(headers_dict) == headers
assert wresponse.msg
assert repr(wresponse)
def test_repr(self):
- assert repr(AsyncHTTPConnection('aws.vandelay.com'))
+ assert repr(AsyncHTTPSConnection())
def test_putrequest(self):
- x = AsyncHTTPConnection('aws.vandelay.com')
+ x = AsyncHTTPSConnection()
x.putrequest('UPLOAD', '/new')
assert x.method == 'UPLOAD'
assert x.path == '/new'
def test_putheader(self):
- x = AsyncHTTPConnection('aws.vandelay.com')
+ x = AsyncHTTPSConnection()
x.putheader('X-Foo', 'bar')
assert x.headers == [('X-Foo', 'bar')]
x.putheader('X-Bar', 'baz')
@@ -176,14 +154,14 @@ class test_AsyncHTTPConnection(AWSCase):
]
def test_send(self):
- x = AsyncHTTPConnection('aws.vandelay.com')
+ x = AsyncHTTPSConnection()
x.send('foo')
assert x.body == 'foo'
x.send('bar')
assert x.body == 'foobar'
def test_interface(self):
- x = AsyncHTTPConnection('aws.vandelay.com')
+ x = AsyncHTTPSConnection()
assert x.set_debuglevel(3) is None
assert x.connect() is None
assert x.close() is None
@@ -204,102 +182,49 @@ class test_AsyncHTTPResponse(AWSCase):
class test_AsyncConnection(AWSCase):
- def test_when_boto_missing(self, patching):
- patching('kombu.async.aws.connection.boto', None)
- with pytest.raises(ImportError):
- AsyncConnection(Mock(name='client'))
-
def test_client(self):
- x = AsyncConnection()
+ sqs = Mock(name='sqs')
+ x = AsyncConnection(sqs)
assert x._httpclient is http.get_client()
client = Mock(name='client')
- y = AsyncConnection(http_client=client)
+ y = AsyncConnection(sqs, http_client=client)
assert y._httpclient is client
def test_get_http_connection(self):
- x = AsyncConnection(client=Mock(name='client'))
- assert isinstance(
- x.get_http_connection('aws.vandelay.com', 80, False),
- AsyncHTTPConnection,
- )
+ sqs = Mock(name='sqs')
+ x = AsyncConnection(sqs)
assert isinstance(
- x.get_http_connection('aws.vandelay.com', 443, True),
+ x.get_http_connection(),
AsyncHTTPSConnection,
)
-
- conn = x.get_http_connection('aws.vandelay.com', 80, False)
+ conn = x.get_http_connection()
assert conn.http_client is x._httpclient
- assert conn.host == 'aws.vandelay.com'
- assert conn.port == 80
-
-
-class test_AsyncAWSAuthConnection(AWSCase):
-
- @patch('boto.log', create=True)
- def test_make_request(self, _):
- x = AsyncAWSAuthConnection('aws.vandelay.com',
- http_client=Mock(name='client'))
- Conn = x.get_http_connection = Mock(name='get_http_connection')
- callback = PromiseMock(name='callback')
- ret = x.make_request('GET', '/foo', callback=callback)
- assert ret is callback
- Conn.return_value.request.assert_called()
- Conn.return_value.getresponse.assert_called_with(
- callback=callback,
- )
-
- @patch('boto.log', create=True)
- def test_mexe(self, _):
- x = AsyncAWSAuthConnection('aws.vandelay.com',
- http_client=Mock(name='client'))
- Conn = x.get_http_connection = Mock(name='get_http_connection')
- request = x.build_base_http_request('GET', 'foo', '/auth')
- callback = PromiseMock(name='callback')
- x._mexe(request, callback=callback)
- Conn.return_value.request.assert_called_with(
- request.method, request.path, request.body, request.headers,
- )
- Conn.return_value.getresponse.assert_called_with(
- callback=callback,
- )
-
- no_callback_ret = x._mexe(request)
- # _mexe always returns promise
- assert isinstance(no_callback_ret, Thenable)
-
- @patch('boto.log', create=True)
- def test_mexe__with_sender(self, _):
- x = AsyncAWSAuthConnection('aws.vandelay.com',
- http_client=Mock(name='client'))
- Conn = x.get_http_connection = Mock(name='get_http_connection')
- request = x.build_base_http_request('GET', 'foo', '/auth')
- sender = Mock(name='sender')
- callback = PromiseMock(name='callback')
- x._mexe(request, sender=sender, callback=callback)
- sender.assert_called_with(
- Conn.return_value, request.method, request.path,
- request.body, request.headers, callback,
- )
class test_AsyncAWSQueryConnection(AWSCase):
def setup(self):
- self.x = AsyncAWSQueryConnection('aws.vandelay.com',
+ session = boto3.session.Session(
+ aws_access_key_id='AAA',
+ aws_secret_access_key='AAAA',
+ region_name='us-west-2',
+ )
+ sqs_client = session.client('sqs')
+ self.x = AsyncAWSQueryConnection(sqs_client,
http_client=Mock(name='client'))
- @patch('boto.log', create=True)
- def test_make_request(self, _):
+ def test_make_request(self):
_mexe, self.x._mexe = self.x._mexe, Mock(name='_mexe')
Conn = self.x.get_http_connection = Mock(name='get_http_connection')
callback = PromiseMock(name='callback')
self.x.make_request(
- 'action', {'foo': 1}, '/', 'GET', callback=callback,
+ 'action', {'foo': 1}, 'https://foo.com/', 'GET', callback=callback,
)
self.x._mexe.assert_called()
request = self.x._mexe.call_args[0][0]
- assert request.params['Action'] == 'action'
- assert request.params['Version'] == self.x.APIVersion
+ parsed = urlparse(request.url)
+ params = parse_qs(parsed.query)
+ assert params['Action'][0] == 'action'
ret = _mexe(request, callback=callback)
assert ret is callback
@@ -308,29 +233,18 @@ class test_AsyncAWSQueryConnection(AWSCase):
callback=callback,
)
- @patch('boto.log', create=True)
- def test_make_request__no_action(self, _):
+ def test_make_request__no_action(self):
self.x._mexe = Mock(name='_mexe')
self.x.get_http_connection = Mock(name='get_http_connection')
callback = PromiseMock(name='callback')
self.x.make_request(
- None, {'foo': 1}, '/', 'GET', callback=callback,
+ None, {'foo': 1}, 'http://foo.com/', 'GET', callback=callback,
)
self.x._mexe.assert_called()
request = self.x._mexe.call_args[0][0]
- assert 'Action' not in request.params
- assert request.params['Version'] == self.x.APIVersion
-
- @contextmanager
- def mock_sax_parse(self, parser):
- with patch('kombu.async.aws.connection.sax_parse') as sax_parse:
- with patch('kombu.async.aws.connection.XmlHandler') as xh:
-
- def effect(body, h):
- return parser(xh.call_args[0][0], body, h)
- sax_parse.side_effect = effect
- yield (sax_parse, xh)
- sax_parse.assert_called()
+ parsed = urlparse(request.url)
+ params = parse_qs(parsed.query)
+ assert 'Action' not in params
def Response(self, status, body):
r = Mock(name='response')
@@ -347,90 +261,3 @@ class test_AsyncAWSQueryConnection(AWSCase):
def assert_make_request_called(self):
self.x.make_request.assert_called()
return self.x.make_request.call_args[1]['callback']
-
- def test_get_list(self):
- with self.mock_make_request() as callback:
- self.x.get_list('action', {'p': 3.3}, ['m'], callback=callback)
- on_ready = self.assert_make_request_called()
-
- def parser(dest, body, h):
- dest.append('hi')
- dest.append('there')
-
- with self.mock_sax_parse(parser):
- on_ready(self.Response(200, 'hello'))
- callback.assert_called_with(['hi', 'there'])
-
- def test_get_list_error(self):
- with self.mock_make_request() as callback:
- self.x.get_list('action', {'p': 3.3}, ['m'], callback=callback)
- on_ready = self.assert_make_request_called()
-
- with pytest.raises(self.x.ResponseError):
- on_ready(self.Response(404, 'Not found'))
-
- def test_get_object(self):
- with self.mock_make_request() as callback:
-
- class Result(object):
- parent = None
- value = None
-
- def __init__(self, parent):
- self.parent = parent
-
- self.x.get_object('action', {'p': 3.3}, Result, callback=callback)
- on_ready = self.assert_make_request_called()
-
- def parser(dest, body, h):
- dest.value = 42
-
- with self.mock_sax_parse(parser):
- on_ready(self.Response(200, 'hello'))
-
- callback.assert_called()
- result = callback.call_args[0][0]
- assert result.value == 42
- assert result.parent
-
- def test_get_object_error(self):
- with self.mock_make_request() as callback:
- self.x.get_object('action', {'p': 3.3}, object, callback=callback)
- on_ready = self.assert_make_request_called()
-
- with pytest.raises(self.x.ResponseError):
- on_ready(self.Response(404, 'Not found'))
-
- def test_get_status(self):
- with self.mock_make_request() as callback:
- self.x.get_status('action', {'p': 3.3}, callback=callback)
- on_ready = self.assert_make_request_called()
- set_status_to = [True]
-
- def parser(dest, body, b):
- dest.status = set_status_to[0]
-
- with self.mock_sax_parse(parser):
- on_ready(self.Response(200, 'hello'))
- callback.assert_called_with(True)
-
- set_status_to[0] = False
- with self.mock_sax_parse(parser):
- on_ready(self.Response(200, 'hello'))
- callback.assert_called_with(False)
-
- def test_get_status_error(self):
- with self.mock_make_request() as callback:
- self.x.get_status('action', {'p': 3.3}, callback=callback)
- on_ready = self.assert_make_request_called()
-
- with pytest.raises(self.x.ResponseError):
- on_ready(self.Response(404, 'Not found'))
-
- def test_get_status_error_empty_body(self):
- with self.mock_make_request() as callback:
- self.x.get_status('action', {'p': 3.3}, callback=callback)
- on_ready = self.assert_make_request_called()
-
- with pytest.raises(self.x.ResponseError):
- on_ready(self.Response(200, ''))
diff --git a/t/unit/transport/test_SQS.py b/t/unit/transport/test_SQS.py
index f245fdab..180b33bf 100644
--- a/t/unit/transport/test_SQS.py
+++ b/t/unit/transport/test_SQS.py
@@ -8,97 +8,107 @@ slightly.
from __future__ import absolute_import, unicode_literals
import pytest
+import random
+import string
from case import Mock, skip
from kombu import messaging
from kombu import Connection, Exchange, Queue
-from kombu.async.aws.ext import exception
from kombu.five import Empty
from kombu.transport import SQS
-class SQSQueueMock(object):
-
- def __init__(self, name):
- self.name = name
- self.messages = []
- self._get_message_calls = 0
-
- def clear(self, page_size=10, vtimeout=10):
- empty, self.messages[:] = not self.messages, []
- return not empty
-
- def count(self, page_size=10, vtimeout=10):
- return len(self.messages)
- count_slow = count
+class SQSMessageMock(object):
+ def __init__(self):
+ """
+ Imitate the SQS Message from boto3.
+ """
+ self.body = ""
+ self.receipt_handle = "receipt_handle_xyz"
- def delete(self):
- self.messages[:] = []
- return True
- def delete_message(self, message):
- try:
- self.messages.remove(message)
- except ValueError:
- return False
- return True
+class QueueMock(object):
+ """ Hold information about a queue. """
- def get_messages(self, num_messages=1, visibility_timeout=None,
- attributes=None, *args, **kwargs):
- self._get_message_calls += 1
- messages, self.messages[:num_messages] = (
- self.messages[:num_messages], [])
- return messages
+ def __init__(self, url):
+ self.url = url
+ self.attributes = {'ApproximateNumberOfMessages': '0'}
- def read(self, visibility_timeout=None):
- return self.messages.pop(0)
+ self.messages = []
- def write(self, message):
- self.messages.append(message)
- return True
+ def __repr__(self):
+ return 'QueueMock: {} {} messages'.format(self.url, len(self.messages))
-class SQSConnectionMock(object):
+class SQSClientMock(object):
def __init__(self):
- self.queues = {
- 'q_%s' % n: SQSQueueMock('q_%s' % n) for n in range(1500)
- }
- q = SQSQueueMock('unittest_queue')
- q.write('hello')
- self.queues['unittest_queue'] = q
-
- def get_queue(self, queue):
- return self.queues.get(queue)
-
- def get_all_queues(self, prefix=""):
- if not prefix:
- keys = sorted(self.queues.keys())[:1000]
- else:
- keys = list(filter(
- lambda k: k.startswith(prefix), sorted(self.queues.keys())
- ))[:1000]
- return [self.queues[key] for key in keys]
+ """
+ Imitate the SQS Client from boto3.
+ """
+ self._receive_messages_calls = 0
+ # _queues doesn't exist on the real client, here for testing.
+ self._queues = {}
+ for n in range(1):
+ name = 'q_{}'.format(n)
+ url = 'sqs://q_{}'.format(n)
+ self.create_queue(QueueName=name)
+
+ url = self.create_queue(QueueName='unittest_queue')['QueueUrl']
+ self.send_message(QueueUrl=url, MessageBody='hello')
+
+ def _get_q(self, url):
+ """ Helper method to quickly get a queue. """
+ for q in self._queues.values():
+ if q.url == url:
+ return q
+ raise Exception("Queue url {} not found".format(url))
+
+ def create_queue(self, QueueName=None, Attributes=None):
+ q = self._queues[QueueName] = QueueMock('sqs://' + QueueName)
+ return {'QueueUrl': q.url}
+
+ def list_queues(self, QueueNamePrefix=None):
+ """ Return a list of queue urls """
+ urls = (val.url for key, val in self._queues.items()
+ if key.startswith(QueueNamePrefix))
+ return {'QueueUrls': urls}
+
+ def get_queue_url(self, QueueName=None):
+ return self._queues[QueueName]
+
+ def send_message(self, QueueUrl=None, MessageBody=None):
+ for q in self._queues.values():
+ if q.url == QueueUrl:
+ handle = ''.join(random.choice(string.ascii_lowercase) for
+ x in range(10))
+ q.messages.append({'Body': MessageBody,
+ 'ReceiptHandle': handle})
+ break
- def delete_queue(self, queue, force_deletion=False):
- q = self.get_queue(queue)
- if q:
- if q.count():
- return False
- q.clear()
- self.queues.pop(queue, None)
+ def receive_message(self, QueueUrl=None, MaxNumberOfMessages=1):
+ self._receive_messages_calls += 1
+ for q in self._queues.values():
+ if q.url == QueueUrl:
+ msgs = q.messages[:MaxNumberOfMessages]
+ q.messages = q.messages[MaxNumberOfMessages:]
+ return {'Messages': msgs}
- def delete_message(self, queue, message):
- return queue.delete_message(message)
+ def get_queue_attributes(self, QueueUrl=None, AttributeNames=None):
+ if 'ApproximateNumberOfMessages' in AttributeNames:
+ count = len(self._get_q(QueueUrl).messages)
+ return {'Attributes': {'ApproximateNumberOfMessages': count}}
- def create_queue(self, name, *args, **kwargs):
- q = self.queues[name] = SQSQueueMock(name)
- return q
+ def purge_queue(self, QueueUrl=None):
+ for q in self._queues.values():
+ if q.url == QueueUrl:
+ q.messages = []
@skip.unless_module('boto')
+@skip.unless_module('boto3')
class test_Channel:
def handleMessageCallback(self, message):
@@ -115,7 +125,7 @@ class test_Channel:
# Mock the sqs() method that returns an SQSConnection object and
# instead return an SQSConnectionMock() object.
- self.sqs_conn_mock = SQSConnectionMock()
+ self.sqs_conn_mock = SQSClientMock()
def mock_sqs():
return self.sqs_conn_mock
@@ -125,9 +135,9 @@ class test_Channel:
self.exchange = Exchange('test_SQS', type='direct')
self.queue = Queue(self.queue_name, self.exchange, self.queue_name)
- # Mock up a test SQS Queue with the SQSQueueMock class (and always
+ # Mock up a test SQS Queue with the QueueMock class (and always
# make sure its a clean empty queue)
- self.sqs_queue_mock = SQSQueueMock(self.queue_name)
+ self.sqs_queue_mock = QueueMock('sqs://' + self.queue_name)
# Now, create our Connection object with the SQS Transport and store
# the connection/channel objects as references for use in these tests.
@@ -160,33 +170,10 @@ class test_Channel:
"""kombu.SQS.Channel instantiates correctly with mocked queues"""
assert self.queue_name in self.channel._queue_cache
- def test_auth_fail(self):
- normal_func = SQS.Channel.sqs.get_all_queues
-
- def get_all_queues_fail_403(prefix=''):
- # mock auth error
- raise exception.SQSError(403, None, None)
-
- def get_all_queues_fail_not_403(prefix=''):
- # mock non-auth error
- raise exception.SQSError(500, None, None)
-
- try:
- SQS.Channel.sqs.access_key = '1234'
- SQS.Channel.sqs.get_all_queues = get_all_queues_fail_403
- with pytest.raises(RuntimeError) as excinfo:
- self.channel = self.connection.channel()
- assert 'access_key=1234' in str(excinfo.value)
- SQS.Channel.sqs.get_all_queues = get_all_queues_fail_not_403
- with pytest.raises(exception.SQSError):
- self.channel = self.connection.channel()
- finally:
- SQS.Channel.sqs.get_all_queues = normal_func
-
def test_new_queue(self):
queue_name = 'new_unittest_queue'
self.channel._new_queue(queue_name)
- assert queue_name in self.sqs_conn_mock.queues
+ assert queue_name in self.sqs_conn_mock._queues.keys()
# For cleanup purposes, delete the queue and the queue file
self.channel._delete(queue_name)
@@ -195,11 +182,13 @@ class test_Channel:
# which is definitely out of cache when get_all_queues returns the
# first 1000 queues sorted by name.
queue_name = 'unittest_queue'
+ # This should not create a new queue.
self.channel._new_queue(queue_name)
- assert queue_name in self.sqs_conn_mock.queues
- q = self.sqs_conn_mock.get_queue(queue_name)
- assert 1 == q.count()
- assert 'hello' == q.read()
+ assert queue_name in self.sqs_conn_mock._queues.keys()
+ queue = self.sqs_conn_mock._queues[queue_name]
+ # The queue originally had 1 message in it.
+ assert 1 == len(queue.messages)
+ assert 'hello' == queue.messages[0]['Body']
def test_delete(self):
queue_name = 'new_unittest_queue'
@@ -211,17 +200,16 @@ class test_Channel:
# Test getting a single message
message = 'my test message'
self.producer.publish(message)
- q = self.channel._new_queue(self.queue_name)
- results = q.get_messages()
- assert len(results) == 1
+ result = self.channel._get(self.queue_name)
+ assert 'body' in result.keys()
# Now test getting many messages
for i in range(3):
message = 'message: {0}'.format(i)
self.producer.publish(message)
- results = q.get_messages(num_messages=3)
- assert len(results) == 3
+ self.channel._get_bulk(self.queue_name, max_if_unlimited=3)
+ assert len(self.sqs_conn_mock._queues[self.queue_name].messages) == 0
def test_get_with_empty_list(self):
with pytest.raises(Empty):
@@ -244,10 +232,21 @@ class test_Channel:
message = {'foo': 'bar'}
self.channel._put(self.producer.routing_key, message)
- q = self.channel._new_queue(self.queue_name)
+ q_url = self.channel._new_queue(self.queue_name)
# Get the messages now
- kombu_messages = q.get_messages(num_messages=kombu_message_count)
- json_messages = q.get_messages(num_messages=json_message_count)
+ kombu_messages = []
+ from kombu.async.aws.sqs.ext import Message
+ for m in self.sqs_conn_mock.receive_message(
+ QueueUrl=q_url,
+ MaxNumberOfMessages=kombu_message_count)['Messages']:
+ m['Body'] = Message().decode(m['Body'])
+ kombu_messages.append(m)
+ json_messages = []
+ for m in self.sqs_conn_mock.receive_message(
+ QueueUrl=q_url,
+ MaxNumberOfMessages=json_message_count)['Messages']:
+ m['Body'] = Message().decode(m['Body'])
+ json_messages.append(m)
# Now convert them to payloads
kombu_payloads = self.channel._messages_to_python(
@@ -350,7 +349,7 @@ class test_Channel:
def test_drain_events_with_prefetch_none(self):
# Generate 20 messages
message_count = 20
- expected_get_message_count = 3
+ expected_receive_messages_count = 3
current_delivery_tag = [1]
@@ -378,6 +377,7 @@ class test_Channel:
assert self.channel.connection._deliver.call_count == message_count
- # How many times was the SQSConnectionMock get_message method called?
- assert (expected_get_message_count ==
- self.channel._queue_cache[self.queue_name]._get_message_calls)
+ # How many times was the SQSConnectionMock receive_message method
+ # called?
+ assert (expected_receive_messages_count ==
+ self.sqs_conn_mock._receive_messages_calls)