diff options
author | Matus Valo <matusvalo@gmail.com> | 2020-05-17 23:34:02 +0200 |
---|---|---|
committer | Asif Saif Uddin <auvipy@gmail.com> | 2020-06-01 12:56:32 +0600 |
commit | 25a026665174c5330e6e3dfbf4361cf8e2a100f4 (patch) | |
tree | 284979e12f88f8fa0c3b39ec51670cabfaac9b4d | |
parent | 5cc4a7688201f4d2e53f468734cc2e7cf177fc40 (diff) | |
download | kombu-25a026665174c5330e6e3dfbf4361cf8e2a100f4.tar.gz |
Ensure connection when connecting to broker
-rw-r--r-- | kombu/connection.py | 56 | ||||
-rw-r--r-- | t/integration/common.py | 17 | ||||
-rw-r--r-- | t/integration/test_py_amqp.py | 29 | ||||
-rw-r--r-- | t/unit/test_connection.py | 52 | ||||
-rw-r--r-- | t/unit/transport/test_redis.py | 16 |
5 files changed, 108 insertions, 62 deletions
diff --git a/kombu/connection.py b/kombu/connection.py index bbd7281f..64a49842 100644 --- a/kombu/connection.py +++ b/kombu/connection.py @@ -279,8 +279,8 @@ class Connection(object): def connect(self): """Establish connection to server immediately.""" - self._closed = False - return self.connection + conn_opts = self._extract_failover_opts() + return self.ensure_connection(**conn_opts) def channel(self): """Create and return a new channel.""" @@ -424,11 +424,12 @@ class Connection(object): if not reraise_as_library_errors: ctx = self._dummy_context with ctx(): - retry_over_time(self.connect, self.recoverable_connection_errors, - (), {}, on_error, max_retries, - interval_start, interval_step, interval_max, - callback, timeout=timeout) - return self + return retry_over_time( + self._connection_factory, self.recoverable_connection_errors, + (), {}, on_error, max_retries, + interval_start, interval_step, interval_max, + callback, timeout=timeout + ) @contextmanager def _reraise_as_library_errors( @@ -817,6 +818,20 @@ class Connection(object): def qos_semantics_matches_spec(self): return self.transport.qos_semantics_matches_spec(self.connection) + def _extract_failover_opts(self): + conn_opts = {} + transport_opts = self.transport_options + if transport_opts: + if 'max_retries' in transport_opts: + conn_opts['max_retries'] = transport_opts['max_retries'] + if 'interval_start' in transport_opts: + conn_opts['interval_start'] = transport_opts['interval_start'] + if 'interval_step' in transport_opts: + conn_opts['interval_step'] = transport_opts['interval_step'] + if 'interval_max' in transport_opts: + conn_opts['interval_max'] = transport_opts['interval_max'] + return conn_opts + @property def connected(self): """Return true if the connection has been established.""" @@ -834,12 +849,17 @@ class Connection(object): """ if not self._closed: if not self.connected: - self.declared_entities.clear() - self._default_channel = None - self._connection = self._establish_connection() - self._closed = False + conn_opts = self._extract_failover_opts() + self._connection = self.ensure_connection(**conn_opts) return self._connection + def _connection_factory(self): + self.declared_entities.clear() + self._default_channel = None + connection = self._establish_connection() + self._closed = False + return connection + @property def default_channel(self): """Default channel. @@ -852,20 +872,6 @@ class Connection(object): a connection is passed instead of a channel, to functions that require a channel. """ - conn_opts = {} - transport_opts = self.transport_options - if transport_opts: - if 'max_retries' in transport_opts: - conn_opts['max_retries'] = transport_opts['max_retries'] - if 'interval_start' in transport_opts: - conn_opts['interval_start'] = transport_opts['interval_start'] - if 'interval_step' in transport_opts: - conn_opts['interval_step'] = transport_opts['interval_step'] - if 'interval_max' in transport_opts: - conn_opts['interval_max'] = transport_opts['interval_max'] - - # make sure we're still connected, and if not refresh. - self.ensure_connection(**conn_opts) if self._default_channel is None: self._default_channel = self.channel() return self._default_channel diff --git a/t/integration/common.py b/t/integration/common.py index d32fc936..c62837d6 100644 --- a/t/integration/common.py +++ b/t/integration/common.py @@ -337,3 +337,20 @@ class BasePriority(object): msg = buf.get(timeout=1) msg.ack() assert msg.payload == data + + +class BaseFailover(BasicFunctionality): + + def test_connect(self, failover_connection): + super(BaseFailover, self).test_connect(failover_connection) + + def test_publish_consume(self, failover_connection): + super(BaseFailover, self).test_publish_consume(failover_connection) + + def test_consume_empty_queue(self, failover_connection): + super(BaseFailover, self).test_consume_empty_queue(failover_connection) + + def test_simple_buffer_publish_consume(self, failover_connection): + super(BaseFailover, self).test_simple_buffer_publish_consume( + failover_connection + ) diff --git a/t/integration/test_py_amqp.py b/t/integration/test_py_amqp.py index 670e7009..1a32cce4 100644 --- a/t/integration/test_py_amqp.py +++ b/t/integration/test_py_amqp.py @@ -7,18 +7,22 @@ import kombu from .common import ( BasicFunctionality, BaseExchangeTypes, - BaseTimeToLive, BasePriority + BaseTimeToLive, BasePriority, BaseFailover ) -def get_connection( - hostname, port, vhost): +def get_connection(hostname, port, vhost): return kombu.Connection('pyamqp://{}:{}'.format(hostname, port)) +def get_failover_connection(hostname, port, vhost): + return kombu.Connection( + 'pyamqp://localhost:12345;pyamqp://{}:{}'.format(hostname, port) + ) + + @pytest.fixture() def connection(request): - # this fixture yields plain connections to broker and TLS encrypted return get_connection( hostname=os.environ.get('RABBITMQ_HOST', 'localhost'), port=os.environ.get('RABBITMQ_5672_TCP', '5672'), @@ -28,6 +32,17 @@ def connection(request): ) +@pytest.fixture() +def failover_connection(request): + return get_failover_connection( + hostname=os.environ.get('RABBITMQ_HOST', 'localhost'), + port=os.environ.get('RABBITMQ_5672_TCP', '5672'), + vhost=getattr( + request.config, "slaveinput", {} + ).get("slaveid", None), + ) + + @pytest.mark.env('py-amqp') @pytest.mark.flaky(reruns=5, reruns_delay=2) class test_PyAMQPBasicFunctionality(BasicFunctionality): @@ -50,3 +65,9 @@ class test_PyAMQPTimeToLive(BaseTimeToLive): @pytest.mark.flaky(reruns=5, reruns_delay=2) class test_PyAMQPPriority(BasePriority): pass + + +@pytest.mark.env('py-amqp') +@pytest.mark.flaky(reruns=5, reruns_delay=2) +class test_PyAMQPFailover(BaseFailover): + pass diff --git a/t/unit/test_connection.py b/t/unit/test_connection.py index 9bec0eb2..c4e3764c 100644 --- a/t/unit/test_connection.py +++ b/t/unit/test_connection.py @@ -142,6 +142,32 @@ class test_Connection: assert not _connection.connected assert isinstance(conn.transport, Transport) + def test_connect_no_transport_options(self): + conn = self.conn + conn.ensure_connection = Mock() + + conn.connect() + conn.ensure_connection.assert_called_with() + + def test_connect_transport_options(self): + conn = self.conn + conn.transport_options = options = { + 'max_retries': 1, + 'interval_start': 2, + 'interval_step': 3, + 'interval_max': 4, + 'ignore_this': True + } + conn.ensure_connection = Mock() + + conn.connect() + conn.ensure_connection.assert_called_with(**{ + k: v for k, v in options.items() + if k in ['max_retries', + 'interval_start', + 'interval_step', + 'interval_max']}) + def test_multiple_urls(self): conn1 = Connection('amqp://foo;amqp://bar') assert conn1.hostname == 'foo' @@ -406,32 +432,6 @@ class test_Connection: defchan.close.assert_called_with() assert conn._default_channel is None - def test_default_channel_no_transport_options(self): - conn = self.conn - conn.ensure_connection = Mock() - - assert conn.default_channel - conn.ensure_connection.assert_called_with() - - def test_default_channel_transport_options(self): - conn = self.conn - conn.transport_options = options = { - 'max_retries': 1, - 'interval_start': 2, - 'interval_step': 3, - 'interval_max': 4, - 'ignore_this': True - } - conn.ensure_connection = Mock() - - assert conn.default_channel - conn.ensure_connection.assert_called_with(**{ - k: v for k, v in options.items() - if k in ['max_retries', - 'interval_start', - 'interval_step', - 'interval_max']}) - def test_ensure_connection(self): assert self.conn.ensure_connection() diff --git a/t/unit/transport/test_redis.py b/t/unit/transport/test_redis.py index a36dfc2b..5e911929 100644 --- a/t/unit/transport/test_redis.py +++ b/t/unit/transport/test_redis.py @@ -10,7 +10,9 @@ from itertools import count from case import ANY, ContextMock, Mock, call, mock, skip, patch from kombu import Connection, Exchange, Queue, Consumer, Producer -from kombu.exceptions import InconsistencyError, VersionMismatch +from kombu.exceptions import ( + InconsistencyError, VersionMismatch, OperationalError +) from kombu.five import Empty, Queue as _Queue, bytes_if_py2 from kombu.transport import virtual from kombu.utils import eventio # patch poll @@ -1043,10 +1045,11 @@ class test_Redis: assert conn.transport.channel_errors def test_check_at_least_we_try_to_connect_and_fail(self): - import redis - connection = Connection('redis://localhost:65534/') + connection = Connection( + 'redis://localhost:65534/', transport_options={'max_retries': 1} + ) - with pytest.raises(redis.exceptions.ConnectionError): + with pytest.raises(OperationalError): chan = connection.channel() chan._size('some_queue') @@ -1465,13 +1468,12 @@ class test_RedisSentinel: master_for().connection_pool.get_connection.assert_called() def test_can_create_connection(self): - from redis.exceptions import ConnectionError - connection = Connection( 'sentinel://localhost:65534/', transport_options={ 'master_name': 'not_important', + 'max_retries': 1 }, ) - with pytest.raises(ConnectionError): + with pytest.raises(OperationalError): connection.channel() |