summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTerry Cain <terrycain@users.noreply.github.com>2021-01-04 13:14:39 +0000
committerGitHub <noreply@github.com>2021-01-04 14:14:39 +0100
commit3d41ab138960214fcbd1110037e2c2ed2473f22d (patch)
tree2eaf3a6216b6b344f53e906bf91aaaa8b0b6cebb
parenta37a05616f5c85fe690e1a8f0f4395e371d25943 (diff)
downloadkombu-3d41ab138960214fcbd1110037e2c2ed2473f22d.tar.gz
Support for Azure Service Bus 7.0.0 (#1284)
* Started servicebus refactor * Cleaned up, handle service bus SAS token parsing
-rw-r--r--kombu/transport/azureservicebus.py295
-rw-r--r--requirements/extras/azureservicebus.txt2
-rw-r--r--t/unit/transport/test_azureservicebus.py475
3 files changed, 476 insertions, 296 deletions
diff --git a/kombu/transport/azureservicebus.py b/kombu/transport/azureservicebus.py
index ff251027..b117ef8b 100644
--- a/kombu/transport/azureservicebus.py
+++ b/kombu/transport/azureservicebus.py
@@ -4,9 +4,6 @@ Note that the Shared Access Policy used to connect to Azure Service Bus
requires Manage, Send and Listen claims since the broker will create new
queues and delete old queues as required.
-Note that if the SAS key for the Service Bus account contains a slash, it will
-have to be regenerated before it can be used in the connection URL.
-
More information about Azure Service Bus:
https://azure.microsoft.com/en-us/services/service-bus/
@@ -31,31 +28,28 @@ Connection string has the following format:
Transport Options
=================
-* ``visibility_timeout``
-* ``queue_name_prefix``
-* ``wait_time_seconds``
-* ``peek_lock``
+* ``queue_name_prefix`` - String prefix to prepend to queue names in a service bus namespace
+* ``wait_time_seconds`` - Number of seconds to wait to receive messages. Default ``5``
+* ``peek_lock_seconds`` - Number of seconds the message is visible for before it is requeued
+and sent to another consumer. Default ``60``
"""
import string
from queue import Empty
+from typing import Dict, Any, Optional, Union, Set
from kombu.utils.encoding import bytes_to_str, safe_str
from kombu.utils.json import loads, dumps
from kombu.utils.objects import cached_property
-from . import virtual
+import azure.core.exceptions
+import azure.servicebus.exceptions
+import isodate
+from azure.servicebus import ServiceBusClient, ServiceBusMessage, ServiceBusReceiver, ServiceBusSender, \
+ ServiceBusReceiveMode
+from azure.servicebus.management import ServiceBusAdministrationClient
-try:
- # azure-servicebus version <= 0.21.1
- from azure.servicebus import ServiceBusService, Message, Queue
-except ImportError:
- try:
- # azure-servicebus version >= 0.50.0
- from azure.servicebus.control_client import \
- ServiceBusService, Message, Queue
- except ImportError:
- ServiceBusService = Message = Queue = None
+from . import virtual
# dots are replaced by dash, all other punctuation replaced by underscore.
CHARS_REPLACE_TABLE = {
@@ -63,95 +57,251 @@ CHARS_REPLACE_TABLE = {
}
+class SendReceive:
+ def __init__(self, receiver: Optional[ServiceBusReceiver] = None, sender: Optional[ServiceBusSender] = None):
+ self.receiver = receiver # type: ServiceBusReceiver
+ self.sender = sender # type: ServiceBusSender
+
+ def close(self) -> None:
+ if self.receiver:
+ self.receiver.close()
+ self.receiver = None
+ if self.sender:
+ self.sender.close()
+ self.sender = None
+
+
class Channel(virtual.Channel):
"""Azure Service Bus channel."""
- default_visibility_timeout = 1800 # 30 minutes.
default_wait_time_seconds = 5 # in seconds
- default_peek_lock = False
+ default_peek_lock_seconds = 60 # in seconds (default 60, max 300)
domain_format = 'kombu%(vhost)s'
- _queue_service = None
- _queue_cache = {}
+ _queue_service = None # type: ServiceBusClient
+ _queue_mgmt_service = None # type: ServiceBusAdministrationClient
+ _queue_cache = {} # type: Dict[str, SendReceive]
+ _noack_queues = set() # type: Set[str]
def __init__(self, *args, **kwargs):
- if ServiceBusService is None:
- raise ImportError('Azure Service Bus transport requires the '
- 'azure-servicebus library')
-
super().__init__(*args, **kwargs)
- for queue in self.queue_service.list_queues():
- self._queue_cache[queue] = queue
+ self._namespace = None
+ self._policy = None
+ self._sas_key = None
+ self._connection_string = None
+
+ self._try_parse_connection_string()
+
+ self.qos.restore_at_shutdown = False
+
+ def _try_parse_connection_string(self) -> None:
+ # URL like azureservicebus://{SAS policy name}:{SAS key}@{ServiceBus Namespace}
+ # urllib parse does not work as the sas key could contain a slash
+ # e.g. azureservicebus://rootpolicy:some/key@somenamespace
+ uri = self.conninfo.hostname.replace('azureservicebus://', '') # > 'rootpolicy:some/key@somenamespace'
+ policykeypair, self._namespace = uri.rsplit('@', 1) # > 'rootpolicy:some/key', 'somenamespace'
+ self._policy, self._sas_key = policykeypair.split(':', 1) # > 'rootpolicy', 'some/key'
+
+ # Validate ASB connection string
+ if not all([self._namespace, self._policy, self._sas_key]):
+ raise ValueError('Need an URI like azureservicebus://{SAS policy name}:{SAS key}@{ServiceBus Namespace}')
+
+ # Convert
+ endpoint = 'sb://' + self._namespace
+ if not endpoint.endswith('.net'):
+ endpoint += '.servicebus.windows.net'
+
+ conn_dict = {
+ 'Endpoint': endpoint,
+ 'SharedAccessKeyName': self._policy,
+ 'SharedAccessKey': self._sas_key,
+ }
+ self._connection_string = ';'.join([key + '=' + value for key, value in conn_dict.items()])
+
+ def basic_consume(self, queue, no_ack, *args, **kwargs):
+ if no_ack:
+ self._noack_queues.add(queue)
+ return super().basic_consume(
+ queue, no_ack, *args, **kwargs
+ )
- def entity_name(self, name, table=CHARS_REPLACE_TABLE):
+ def basic_cancel(self, consumer_tag):
+ if consumer_tag in self._consumers:
+ queue = self._tag_to_queue[consumer_tag]
+ self._noack_queues.discard(queue)
+ return super().basic_cancel(consumer_tag)
+
+ def _add_queue_to_cache(self,
+ name: str,
+ receiver: Optional[ServiceBusReceiver] = None,
+ sender: Optional[ServiceBusSender] = None) -> SendReceive:
+ if name in self._queue_cache:
+ obj = self._queue_cache[name]
+ obj.sender = obj.sender or sender
+ obj.receiver = obj.receiver or receiver
+ else:
+ obj = SendReceive(receiver, sender)
+ self._queue_cache[name] = obj
+ return obj
+
+ def _get_asb_sender(self, queue: str) -> SendReceive:
+ queue_obj = self._queue_cache.get(queue, None)
+ if queue_obj is None or queue_obj.sender is None:
+ sender = self.queue_service.get_queue_sender(queue)
+ queue_obj = self._add_queue_to_cache(queue, sender=sender)
+ return queue_obj
+
+ def _get_asb_receiver(self, queue: str,
+ recv_mode: ServiceBusReceiveMode = ServiceBusReceiveMode.PEEK_LOCK,
+ queue_cache_key: Optional[str] = None) -> SendReceive:
+ cache_key = queue_cache_key or queue
+ queue_obj = self._queue_cache.get(cache_key, None)
+ if queue_obj is None or queue_obj.receiver is None:
+ receiver = self.queue_service.get_queue_receiver(queue_name=queue, receive_mode=recv_mode)
+ queue_obj = self._add_queue_to_cache(cache_key, receiver=receiver)
+ return queue_obj
+
+ def entity_name(self, name: str, table: Optional[Dict[int, int]] = None) -> str:
"""Format AMQP queue name into a valid ServiceBus queue name."""
- return str(safe_str(name)).translate(table)
+ return str(safe_str(name)).translate(table or CHARS_REPLACE_TABLE)
+
+ def _restore(self, message: virtual.base.Message) -> None:
+ # Not be needed as ASB handles unacked messages
+ # Remove 'azure_message' as its not JSON serializable
+ # message.delivery_info.pop('azure_message', None)
+ # super()._restore(message)
+ pass
- def _new_queue(self, queue, **kwargs):
+ def _new_queue(self, queue: str, **kwargs) -> SendReceive:
"""Ensure a queue exists in ServiceBus."""
queue = self.entity_name(self.queue_name_prefix + queue)
+
try:
return self._queue_cache[queue]
except KeyError:
- self.queue_service.create_queue(queue, fail_on_exist=False)
- q = self._queue_cache[queue] = self.queue_service.get_queue(queue)
- return q
-
- def _delete(self, queue, *args, **kwargs):
+ # Converts seconds into ISO8601 duration format ie 66seconds = P1M6S
+ lock_duration = isodate.duration_isoformat(isodate.Duration(seconds=self.peek_lock_seconds))
+ try:
+ self.queue_mgmt_service.create_queue(queue_name=queue, lock_duration=lock_duration)
+ except azure.core.exceptions.ResourceExistsError:
+ pass
+ return self._add_queue_to_cache(queue)
+
+ def _delete(self, queue: str, *args, **kwargs) -> None:
"""Delete queue by name."""
- queue_name = self.entity_name(queue)
- self._queue_cache.pop(queue_name, None)
- self.queue_service.delete_queue(queue_name)
- super()._delete(queue_name)
+ queue = self.entity_name(self.queue_name_prefix + queue)
- def _put(self, queue, message, **kwargs):
+ self._queue_mgmt_service.delete_queue(queue)
+ send_receive_obj = self._queue_cache.pop(queue, None)
+ if send_receive_obj:
+ send_receive_obj.close()
+
+ def _put(self, queue: str, message, **kwargs) -> None:
"""Put message onto queue."""
- msg = Message(dumps(message))
- self.queue_service.send_queue_message(self.entity_name(queue), msg)
+ queue = self.entity_name(self.queue_name_prefix + queue)
+ msg = ServiceBusMessage(dumps(message))
- def _get(self, queue, timeout=None):
+ queue_obj = self._get_asb_sender(queue)
+ queue_obj.sender.send_messages(msg)
+
+ def _get(self, queue: str, timeout: Optional[Union[float, int]] = None) -> Dict[str, Any]:
"""Try to retrieve a single message off ``queue``."""
- message = self.queue_service.receive_queue_message(
- self.entity_name(queue),
- timeout=timeout or self.wait_time_seconds,
- peek_lock=self.peek_lock
- )
+ # If we're not ack'ing for this queue, just change receive_mode
+ recv_mode = ServiceBusReceiveMode.RECEIVE_AND_DELETE if queue in self._noack_queues else \
+ ServiceBusReceiveMode.PEEK_LOCK
+
+ queue = self.entity_name(self.queue_name_prefix + queue)
- if message.body is None:
+ queue_obj = self._get_asb_receiver(queue, recv_mode)
+ messages = queue_obj.receiver.receive_messages(max_message_count=1,
+ max_wait_time=timeout or self.wait_time_seconds)
+
+ if not messages:
raise Empty()
- return loads(bytes_to_str(message.body))
+ # message.body is either byte or generator[bytes]
+ message = messages[0]
+ if not isinstance(message.body, bytes):
+ body = b''.join(message.body)
+ else:
+ body = message.body
+
+ msg = loads(bytes_to_str(body))
+ msg['properties']['delivery_info']['azure_message'] = message
+
+ return msg
+
+ def basic_ack(self, delivery_tag: str, multiple: bool = False) -> None:
+ delivery_info = self.qos.get(delivery_tag).delivery_info
+
+ if delivery_info['exchange'] in self._noack_queues:
+ return super().basic_ack(delivery_tag)
- def _size(self, queue):
+ queue = self.entity_name(self.queue_name_prefix + delivery_info['exchange'])
+ queue_obj = self._get_asb_receiver(queue) # recv_mode is PEEK_LOCK when ack'ing messages
+
+ try:
+ queue_obj.receiver.complete_message(delivery_info['azure_message'])
+ except azure.servicebus.exceptions.MessageAlreadySettled:
+ super().basic_ack(delivery_tag)
+ except Exception:
+ super().basic_reject(delivery_tag)
+ else:
+ super().basic_ack(delivery_tag)
+
+ def _size(self, queue: str) -> int:
"""Return the number of messages in a queue."""
- return self._new_queue(queue).message_count
+ queue = self.entity_name(self.queue_name_prefix + queue)
+ props = self.queue_mgmt_service.get_queue_runtime_properties(queue)
+
+ return props.total_message_count
def _purge(self, queue):
"""Delete all current messages in a queue."""
+ # Azure doesn't provide a purge api yet
n = 0
+ max_purge_count = 10
+ queue = self.entity_name(self.queue_name_prefix + queue)
+
+ # By default all the receivers will be in PEEK_LOCK receive mode
+ queue_obj = self._queue_cache.get(queue, None)
+ if queue not in self._noack_queues or queue_obj is None or queue_obj.receiver is None:
+ queue_obj = self._get_asb_receiver(queue, ServiceBusReceiveMode.RECEIVE_AND_DELETE, 'purge_' + queue)
while True:
- message = self.queue_service.read_delete_queue_message(
- self.entity_name(queue), timeout=0.1)
+ messages = queue_obj.receiver.receive_messages(max_message_count=max_purge_count,
+ max_wait_time=0.2)
+ n += len(messages)
- if not message.body:
+ if len(messages) < max_purge_count:
break
- else:
- n += 1
return n
+ def close(self) -> None:
+ # receivers and senders spawn threads so clean them up
+ if not self.closed:
+ self.closed = True
+ for queue_obj in self._queue_cache.values():
+ queue_obj.close()
+ self._queue_cache.clear()
+
+ if self.connection is not None:
+ self.connection.close_channel(self)
+
@property
- def queue_service(self):
+ def queue_service(self) -> ServiceBusClient:
if self._queue_service is None:
- self._queue_service = ServiceBusService(
- service_namespace=self.conninfo.hostname,
- shared_access_key_name=self.conninfo.userid,
- shared_access_key_value=self.conninfo.password)
-
+ self._queue_service = ServiceBusClient.from_connection_string(self._connection_string)
return self._queue_service
@property
+ def queue_mgmt_service(self) -> ServiceBusAdministrationClient:
+ if self._queue_mgmt_service is None:
+ self._queue_mgmt_service = ServiceBusAdministrationClient.from_connection_string(self._connection_string)
+ return self._queue_mgmt_service
+
+ @property
def conninfo(self):
return self.connection.client
@@ -160,23 +310,19 @@ class Channel(virtual.Channel):
return self.connection.client.transport_options
@cached_property
- def visibility_timeout(self):
- return (self.transport_options.get('visibility_timeout') or
- self.default_visibility_timeout)
-
- @cached_property
- def queue_name_prefix(self):
+ def queue_name_prefix(self) -> str:
return self.transport_options.get('queue_name_prefix', '')
@cached_property
- def wait_time_seconds(self):
+ def wait_time_seconds(self) -> int:
return self.transport_options.get('wait_time_seconds',
self.default_wait_time_seconds)
@cached_property
- def peek_lock(self):
- return self.transport_options.get('peek_lock',
- self.default_peek_lock)
+ def peek_lock_seconds(self) -> int:
+ return min(self.transport_options.get('peek_lock_seconds',
+ self.default_peek_lock_seconds),
+ 300) # Limit upper bounds to 300
class Transport(virtual.Transport):
@@ -186,3 +332,4 @@ class Transport(virtual.Transport):
polling_interval = 1
default_port = None
+ can_parse_url = True
diff --git a/requirements/extras/azureservicebus.txt b/requirements/extras/azureservicebus.txt
index 8f6f15ce..35b96b35 100644
--- a/requirements/extras/azureservicebus.txt
+++ b/requirements/extras/azureservicebus.txt
@@ -1 +1 @@
-azure-servicebus>=0.21.1
+azure-servicebus>=7.0.0
diff --git a/t/unit/transport/test_azureservicebus.py b/t/unit/transport/test_azureservicebus.py
index fdccaadf..08069f8b 100644
--- a/t/unit/transport/test_azureservicebus.py
+++ b/t/unit/transport/test_azureservicebus.py
@@ -1,244 +1,277 @@
+import json
import pytest
+import base64
+import random
from queue import Empty
-from unittest.mock import patch
+from unittest.mock import patch, MagicMock
+from collections import namedtuple
from kombu import messaging
from kombu import Connection, Exchange, Queue
from kombu.transport import azureservicebus
-
+import azure.servicebus.exceptions
+import azure.core.exceptions
pytest.importorskip('azure.servicebus')
-try:
- # azure-servicebus version >= 0.50.0
- from azure.servicebus.control_client import Message, ServiceBusService
-except ImportError:
- try:
- # azure-servicebus version <= 0.21.1
- from azure.servicebus import Message, ServiceBusService
- except ImportError:
- ServiceBusService = Message = None
+from azure.servicebus import ServiceBusMessage, ServiceBusReceiveMode
+
+
+class ASBQueue:
+ def __init__(self, kwargs):
+ self.options = kwargs
+ self.items = []
+ self.waiting_ack = []
+ self.send_calls = []
+ self.recv_calls = []
+
+ def get_receiver(self, kwargs):
+ receive_mode = kwargs.get('receive_mode', ServiceBusReceiveMode.PEEK_LOCK)
+
+ class Receiver:
+ def close(self):
+ pass
+
+ def receive_messages(_self, **kwargs2):
+ max_message_count = kwargs2.get('max_message_count', 1)
+ result = []
+ if self.items:
+ while self.items or len(result) > max_message_count:
+ item = self.items.pop(0)
+ if receive_mode is ServiceBusReceiveMode.PEEK_LOCK:
+ self.waiting_ack.append(item)
+ result.append(item)
+
+ self.recv_calls.append({
+ 'receiver_options': kwargs,
+ 'receive_messages_options': kwargs2,
+ 'messages': result
+ })
+ return result
+ return Receiver()
+
+ def get_sender(self):
+ class Sender:
+ def close(self):
+ pass
+
+ def send_messages(_self, msg):
+ self.send_calls.append(msg)
+ self.items.append(msg)
+ return Sender()
+
+
+class ASBMock:
+ def __init__(self):
+ self.queues = {}
+
+ def get_queue_receiver(self, queue_name, **kwargs):
+ return self.queues[queue_name].get_receiver(kwargs)
+ def get_queue_sender(self, queue_name):
+ return self.queues[queue_name].get_sender()
-class QueueMock:
- """ Hold information about a queue. """
- def __init__(self, name):
- self.name = name
- self.messages = []
- self.message_count = len(self.messages)
+class ASBMgmtMock:
+ def __init__(self, queues):
+ self.queues = queues
- def __repr__(self):
- return 'QueueMock: {} messages'.format(len(self.messages))
+ def create_queue(self, queue_name, **kwargs):
+ if queue_name in self.queues:
+ raise azure.core.exceptions.ResourceExistsError()
+ self.queues[queue_name] = ASBQueue(kwargs)
+ def delete_queue(self, queue_name):
+ self.queues.pop(queue_name, None)
-def _create_mock_connection(url='', **kwargs):
+ def get_queue_runtime_properties(self, queue_name):
+ count = len(self.queues[queue_name].items)
+ mock = MagicMock()
+ mock.total_message_count = count
+ return mock
- class _Channel(azureservicebus.Channel):
- # reset _fanout_queues for each instance
- queues = []
- _queue_service = None
- def list_queues(self):
- return self.queues
+URL_NOCREDS = 'azureservicebus://'
+URL_CREDS = 'azureservicebus://policyname:ke/y@hostname'
- @property
- def queue_service(self):
- if self._queue_service is None:
- self._queue_service = AzureServiceBusClientMock()
- return self._queue_service
- class Transport(azureservicebus.Transport):
- Channel = _Channel
+def test_queue_service_nocredentials():
+ conn = Connection(URL_NOCREDS, transport=azureservicebus.Transport)
+ with pytest.raises(ValueError) as exc:
+ conn.channel()
+ assert exc == 'Need an URI like azureservicebus://{SAS policy name}:{SAS key}@{ServiceBus Namespace}'
- return Connection(url, transport=Transport, **kwargs)
+def test_queue_service():
+ # Test gettings queue service without credentials
+ conn = Connection(URL_CREDS, transport=azureservicebus.Transport)
+ with patch('kombu.transport.azureservicebus.ServiceBusClient') as m:
+ channel = conn.channel()
-class AzureServiceBusClientMock:
+ # Check the SAS token "ke/y" has been parsed from the url correctly
+ assert channel._sas_key == 'ke/y'
- def __init__(self):
- """
- Imitate the ServiceBus Client.
- """
- # queues doesn't exist on the real client, here for testing.
- self.queues = []
- self._queue_cache = {}
- self.queues.append(self.create_queue(queue_name='unittest_queue'))
-
- def create_queue(self, queue_name, queue=None, fail_on_exist=False):
- queue = QueueMock(name=queue_name)
- self.queues.append(queue)
- self._queue_cache[queue_name] = queue
- return queue
-
- def get_queue(self, queue_name=None):
- for queue in self.queues:
- if queue.name == queue_name:
- return queue
-
- def list_queues(self):
- return self.queues
-
- def send_queue_message(self, queue_name=None, message=None):
- queue = self.get_queue(queue_name)
- queue.messages.append(message)
-
- def receive_queue_message(self, queue_name, peek_lock=True, timeout=60):
- queue = self.get_queue(queue_name)
- if queue and len(queue.messages):
- return queue.messages.pop(0)
- return Message()
-
- def read_delete_queue_message(self, queue_name, timeout='60'):
- return self.receive_queue_message(queue_name, timeout=timeout)
-
- def delete_queue(self, queue_name=None):
- queue = self.get_queue(queue_name)
- if queue:
- del queue
-
-
-class test_Channel:
-
- def handleMessageCallback(self, message):
- self.callback_message = message
-
- def setup(self):
- self.url = 'azureservicebus://'
- self.queue_name = 'unittest_queue'
-
- self.exchange = Exchange('test_servicebus', type='direct')
- self.queue = Queue(self.queue_name, self.exchange, self.queue_name)
- self.connection = _create_mock_connection(self.url)
- self.channel = self.connection.default_channel
- self.queue(self.channel).declare()
-
- self.producer = messaging.Producer(self.channel,
- self.exchange,
- routing_key=self.queue_name)
-
- self.channel.basic_consume(self.queue_name,
- no_ack=False,
- callback=self.handleMessageCallback,
- consumer_tag='unittest')
-
- def teardown(self):
- # Removes QoS reserved messages so we don't restore msgs on shutdown.
- try:
- qos = self.channel._qos
- except AttributeError:
- pass
- else:
- if qos:
- qos._dirty.clear()
- qos._delivered.clear()
-
- def test_queue_service(self):
- # Test gettings queue service without credentials
- conn = Connection(self.url, transport=azureservicebus.Transport)
- with pytest.raises(ValueError) as exc:
- conn.channel()
- assert exc == 'You need to provide servicebus namespace'
-
- # Test getting queue service when queue_service is not setted
- with patch('kombu.transport.azureservicebus.ServiceBusService') as m:
- channel = conn.channel()
-
- # Remove queue service to get from service bus again
- channel._queue_service = None
- channel.queue_service
-
- assert m.call_count == 2
-
- # Calling queue_service again needs to reuse ServiceBus instance
- channel.queue_service
- assert m.call_count == 2
-
- def test_conninfo(self):
- conninfo = self.channel.conninfo
- assert conninfo is self.connection
-
- def test_transport_type(self):
- transport_options = self.channel.transport_options
- assert transport_options == {}
-
- def test_visibility_timeout(self):
- # Test getting default visibility timeout
- assert (
- self.channel.visibility_timeout ==
- azureservicebus.Channel.default_visibility_timeout
- )
-
- # Test getting value setted in transport options
- del self.channel.visibility_timeout
- self.channel.transport_options['visibility_timeout'] = 10
- assert self.channel.visibility_timeout == 10
-
- def test_wait_timeout_seconds(self):
- # Test getting default wait timeout seconds
- assert (
- self.channel.wait_time_seconds ==
- azureservicebus.Channel.default_wait_time_seconds
- )
-
- # Test getting value setted in transport options
- del self.channel.wait_time_seconds
- self.channel.transport_options['wait_time_seconds'] = 10
- assert self.channel.wait_time_seconds == 10
-
- def test_peek_lock(self):
- # Test getting default peek lock
- assert (
- self.channel.peek_lock ==
- azureservicebus.Channel.default_peek_lock
- )
-
- # Test getting value setted in transport options
- del self.channel.peek_lock
- self.channel.transport_options['peek_lock'] = True
- assert self.channel.peek_lock is True
-
- def test_get_from_azure(self):
- # Test getting a single message
- message = 'my test message'
- self.producer.publish(message)
- result = self.channel._get(self.queue_name)
- assert 'body' in result.keys()
-
- # Test getting multiple messages
- for i in range(3):
- message = f'message: {i}'
- self.producer.publish(message)
-
- queue_service = self.channel.queue_service
- assert len(queue_service.get_queue(self.queue_name).messages) == 3
-
- for i in range(3):
- result = self.channel._get(self.queue_name)
-
- assert len(queue_service.get_queue(self.queue_name).messages) == 0
-
- def test_get_with_empty_list(self):
- with pytest.raises(Empty):
- self.channel._get(self.queue_name)
-
- def test_put_and_get(self):
- message = 'my test message'
- self.producer.publish(message)
- results = self.queue(self.channel).get().payload
- assert message == results
-
- def test_delete_queue(self):
- # Test deleting queue without message
- queue_name = 'new_unittest_queue'
- self.channel._new_queue(queue_name)
-
- assert queue_name in self.channel._queue_cache
- self.channel._delete(queue_name)
- assert queue_name not in self.channel._queue_cache
-
- # Test deleting queue with message
- message = 'my test message'
- self.producer.publish(message)
- self.channel._delete(self.queue_name)
- assert queue_name not in self.channel._queue_cache
+ m.from_connection_string.return_value = 'test'
+
+ # Remove queue service to get from service bus again
+ channel._queue_service = None
+ assert channel.queue_service == 'test'
+ assert m.from_connection_string.call_count == 1
+
+ # Ensure that queue_service is cached
+ assert channel.queue_service == 'test'
+ assert m.from_connection_string.call_count == 1
+
+
+def test_conninfo():
+ conn = Connection(URL_CREDS, transport=azureservicebus.Transport)
+ channel = conn.channel()
+ assert channel.conninfo is conn
+
+
+def test_transport_type():
+ conn = Connection(URL_CREDS, transport=azureservicebus.Transport)
+ channel = conn.channel()
+ assert not channel.transport_options
+
+
+def test_default_wait_timeout_seconds():
+ conn = Connection(URL_CREDS, transport=azureservicebus.Transport)
+ channel = conn.channel()
+
+ assert channel.wait_time_seconds == azureservicebus.Channel.default_wait_time_seconds
+
+
+def test_custom_wait_timeout_seconds():
+ conn = Connection(URL_CREDS, transport=azureservicebus.Transport, transport_options={'wait_time_seconds': 10})
+ channel = conn.channel()
+
+ assert channel.wait_time_seconds == 10
+
+
+def test_default_peek_lock_seconds():
+ conn = Connection(URL_CREDS, transport=azureservicebus.Transport)
+ channel = conn.channel()
+
+ assert channel.peek_lock_seconds == azureservicebus.Channel.default_peek_lock_seconds
+
+
+def test_custom_peek_lock_seconds():
+ conn = Connection(URL_CREDS, transport=azureservicebus.Transport,
+ transport_options={'peek_lock_seconds': 65})
+ channel = conn.channel()
+
+ assert channel.peek_lock_seconds == 65
+
+
+def test_invalid_peek_lock_seconds():
+ # Max is 300
+ conn = Connection(URL_CREDS, transport=azureservicebus.Transport,
+ transport_options={'peek_lock_seconds': 900})
+ channel = conn.channel()
+
+ assert channel.peek_lock_seconds == 300
+
+
+@pytest.fixture
+def random_queue():
+ return 'azureservicebus_queue_{0}'.format(random.randint(1000,9999))
+
+
+@pytest.fixture
+def mock_asb():
+ return ASBMock()
+
+
+@pytest.fixture
+def mock_asb_management(mock_asb):
+ return ASBMgmtMock(queues=mock_asb.queues)
+
+
+MockQueue = namedtuple('MockQueue', ['queue_name', 'asb', 'asb_mgmt', 'conn', 'channel', 'producer', 'queue'])
+
+
+@pytest.fixture
+def mock_queue(mock_asb, mock_asb_management, random_queue) -> MockQueue:
+ exchange = Exchange('test_servicebus', type='direct')
+ queue = Queue(random_queue, exchange, random_queue)
+ conn = Connection(URL_CREDS, transport=azureservicebus.Transport)
+ channel = conn.channel()
+ channel._queue_service = mock_asb
+ channel._queue_mgmt_service = mock_asb_management
+
+ queue(channel).declare()
+ producer = messaging.Producer(channel, exchange, routing_key=random_queue)
+
+ return MockQueue(
+ random_queue,
+ mock_asb,
+ mock_asb_management,
+ conn,
+ channel,
+ producer,
+ queue
+ )
+
+
+def test_basic_put_get(mock_queue: MockQueue):
+ text_message = "test message"
+
+ # This ends up hitting channel._put
+ mock_queue.producer.publish(text_message)
+
+ assert len(mock_queue.asb.queues[mock_queue.queue_name].items) == 1
+ azure_msg = mock_queue.asb.queues[mock_queue.queue_name].items[0]
+ assert isinstance(azure_msg, ServiceBusMessage)
+
+ message = mock_queue.channel._get(mock_queue.queue_name)
+ azure_msg_decoded = json.loads(str(azure_msg))
+
+ assert message['body'] == azure_msg_decoded['body']
+
+ # Check the message has been annotated with the azure message object
+ # which is used to ack later
+ assert message['properties']['delivery_info']['azure_message'] is azure_msg
+
+ assert base64.b64decode(message['body']).decode() == text_message
+
+ # Ack is on by default, check an ack is waiting
+ assert len(mock_queue.asb.queues[mock_queue.queue_name].waiting_ack) == 1
+
+
+def test_empty_queue_get(mock_queue: MockQueue):
+ with pytest.raises(Empty):
+ mock_queue.channel._get(mock_queue.queue_name)
+
+
+def test_delete_empty_queue(mock_queue: MockQueue):
+ chan = mock_queue.channel
+ queue_name = 'random_queue_{0}'.format(random.randint(1000, 9999))
+
+ chan._new_queue(queue_name)
+ assert queue_name in chan._queue_cache
+ chan._delete(queue_name)
+ assert queue_name not in chan._queue_cache
+
+
+def test_delete_populated_queue(mock_queue: MockQueue):
+ mock_queue.producer.publish('test1234')
+
+ mock_queue.channel._delete(mock_queue.queue_name)
+ assert mock_queue.queue_name not in mock_queue.channel._queue_cache
+
+
+def test_purge(mock_queue: MockQueue):
+ mock_queue.producer.publish('test1234')
+ mock_queue.producer.publish('test1234')
+ mock_queue.producer.publish('test1234')
+ mock_queue.producer.publish('test1234')
+
+ size = mock_queue.channel._size(mock_queue.queue_name)
+ assert size == 4
+
+ assert mock_queue.channel._purge(mock_queue.queue_name) == 4
+
+ size = mock_queue.channel._size(mock_queue.queue_name)
+ assert size == 0
+ assert len(mock_queue.asb.queues[mock_queue.queue_name].waiting_ack) == 0