summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatus Valo <matusvalo@gmail.com>2020-05-17 23:34:02 +0200
committerAsif Saif Uddin <auvipy@gmail.com>2020-06-01 12:56:32 +0600
commit25a026665174c5330e6e3dfbf4361cf8e2a100f4 (patch)
tree284979e12f88f8fa0c3b39ec51670cabfaac9b4d
parent5cc4a7688201f4d2e53f468734cc2e7cf177fc40 (diff)
downloadkombu-25a026665174c5330e6e3dfbf4361cf8e2a100f4.tar.gz
Ensure connection when connecting to broker
-rw-r--r--kombu/connection.py56
-rw-r--r--t/integration/common.py17
-rw-r--r--t/integration/test_py_amqp.py29
-rw-r--r--t/unit/test_connection.py52
-rw-r--r--t/unit/transport/test_redis.py16
5 files changed, 108 insertions, 62 deletions
diff --git a/kombu/connection.py b/kombu/connection.py
index bbd7281f..64a49842 100644
--- a/kombu/connection.py
+++ b/kombu/connection.py
@@ -279,8 +279,8 @@ class Connection(object):
def connect(self):
"""Establish connection to server immediately."""
- self._closed = False
- return self.connection
+ conn_opts = self._extract_failover_opts()
+ return self.ensure_connection(**conn_opts)
def channel(self):
"""Create and return a new channel."""
@@ -424,11 +424,12 @@ class Connection(object):
if not reraise_as_library_errors:
ctx = self._dummy_context
with ctx():
- retry_over_time(self.connect, self.recoverable_connection_errors,
- (), {}, on_error, max_retries,
- interval_start, interval_step, interval_max,
- callback, timeout=timeout)
- return self
+ return retry_over_time(
+ self._connection_factory, self.recoverable_connection_errors,
+ (), {}, on_error, max_retries,
+ interval_start, interval_step, interval_max,
+ callback, timeout=timeout
+ )
@contextmanager
def _reraise_as_library_errors(
@@ -817,6 +818,20 @@ class Connection(object):
def qos_semantics_matches_spec(self):
return self.transport.qos_semantics_matches_spec(self.connection)
+ def _extract_failover_opts(self):
+ conn_opts = {}
+ transport_opts = self.transport_options
+ if transport_opts:
+ if 'max_retries' in transport_opts:
+ conn_opts['max_retries'] = transport_opts['max_retries']
+ if 'interval_start' in transport_opts:
+ conn_opts['interval_start'] = transport_opts['interval_start']
+ if 'interval_step' in transport_opts:
+ conn_opts['interval_step'] = transport_opts['interval_step']
+ if 'interval_max' in transport_opts:
+ conn_opts['interval_max'] = transport_opts['interval_max']
+ return conn_opts
+
@property
def connected(self):
"""Return true if the connection has been established."""
@@ -834,12 +849,17 @@ class Connection(object):
"""
if not self._closed:
if not self.connected:
- self.declared_entities.clear()
- self._default_channel = None
- self._connection = self._establish_connection()
- self._closed = False
+ conn_opts = self._extract_failover_opts()
+ self._connection = self.ensure_connection(**conn_opts)
return self._connection
+ def _connection_factory(self):
+ self.declared_entities.clear()
+ self._default_channel = None
+ connection = self._establish_connection()
+ self._closed = False
+ return connection
+
@property
def default_channel(self):
"""Default channel.
@@ -852,20 +872,6 @@ class Connection(object):
a connection is passed instead of a channel, to functions that
require a channel.
"""
- conn_opts = {}
- transport_opts = self.transport_options
- if transport_opts:
- if 'max_retries' in transport_opts:
- conn_opts['max_retries'] = transport_opts['max_retries']
- if 'interval_start' in transport_opts:
- conn_opts['interval_start'] = transport_opts['interval_start']
- if 'interval_step' in transport_opts:
- conn_opts['interval_step'] = transport_opts['interval_step']
- if 'interval_max' in transport_opts:
- conn_opts['interval_max'] = transport_opts['interval_max']
-
- # make sure we're still connected, and if not refresh.
- self.ensure_connection(**conn_opts)
if self._default_channel is None:
self._default_channel = self.channel()
return self._default_channel
diff --git a/t/integration/common.py b/t/integration/common.py
index d32fc936..c62837d6 100644
--- a/t/integration/common.py
+++ b/t/integration/common.py
@@ -337,3 +337,20 @@ class BasePriority(object):
msg = buf.get(timeout=1)
msg.ack()
assert msg.payload == data
+
+
+class BaseFailover(BasicFunctionality):
+
+ def test_connect(self, failover_connection):
+ super(BaseFailover, self).test_connect(failover_connection)
+
+ def test_publish_consume(self, failover_connection):
+ super(BaseFailover, self).test_publish_consume(failover_connection)
+
+ def test_consume_empty_queue(self, failover_connection):
+ super(BaseFailover, self).test_consume_empty_queue(failover_connection)
+
+ def test_simple_buffer_publish_consume(self, failover_connection):
+ super(BaseFailover, self).test_simple_buffer_publish_consume(
+ failover_connection
+ )
diff --git a/t/integration/test_py_amqp.py b/t/integration/test_py_amqp.py
index 670e7009..1a32cce4 100644
--- a/t/integration/test_py_amqp.py
+++ b/t/integration/test_py_amqp.py
@@ -7,18 +7,22 @@ import kombu
from .common import (
BasicFunctionality, BaseExchangeTypes,
- BaseTimeToLive, BasePriority
+ BaseTimeToLive, BasePriority, BaseFailover
)
-def get_connection(
- hostname, port, vhost):
+def get_connection(hostname, port, vhost):
return kombu.Connection('pyamqp://{}:{}'.format(hostname, port))
+def get_failover_connection(hostname, port, vhost):
+ return kombu.Connection(
+ 'pyamqp://localhost:12345;pyamqp://{}:{}'.format(hostname, port)
+ )
+
+
@pytest.fixture()
def connection(request):
- # this fixture yields plain connections to broker and TLS encrypted
return get_connection(
hostname=os.environ.get('RABBITMQ_HOST', 'localhost'),
port=os.environ.get('RABBITMQ_5672_TCP', '5672'),
@@ -28,6 +32,17 @@ def connection(request):
)
+@pytest.fixture()
+def failover_connection(request):
+ return get_failover_connection(
+ hostname=os.environ.get('RABBITMQ_HOST', 'localhost'),
+ port=os.environ.get('RABBITMQ_5672_TCP', '5672'),
+ vhost=getattr(
+ request.config, "slaveinput", {}
+ ).get("slaveid", None),
+ )
+
+
@pytest.mark.env('py-amqp')
@pytest.mark.flaky(reruns=5, reruns_delay=2)
class test_PyAMQPBasicFunctionality(BasicFunctionality):
@@ -50,3 +65,9 @@ class test_PyAMQPTimeToLive(BaseTimeToLive):
@pytest.mark.flaky(reruns=5, reruns_delay=2)
class test_PyAMQPPriority(BasePriority):
pass
+
+
+@pytest.mark.env('py-amqp')
+@pytest.mark.flaky(reruns=5, reruns_delay=2)
+class test_PyAMQPFailover(BaseFailover):
+ pass
diff --git a/t/unit/test_connection.py b/t/unit/test_connection.py
index 9bec0eb2..c4e3764c 100644
--- a/t/unit/test_connection.py
+++ b/t/unit/test_connection.py
@@ -142,6 +142,32 @@ class test_Connection:
assert not _connection.connected
assert isinstance(conn.transport, Transport)
+ def test_connect_no_transport_options(self):
+ conn = self.conn
+ conn.ensure_connection = Mock()
+
+ conn.connect()
+ conn.ensure_connection.assert_called_with()
+
+ def test_connect_transport_options(self):
+ conn = self.conn
+ conn.transport_options = options = {
+ 'max_retries': 1,
+ 'interval_start': 2,
+ 'interval_step': 3,
+ 'interval_max': 4,
+ 'ignore_this': True
+ }
+ conn.ensure_connection = Mock()
+
+ conn.connect()
+ conn.ensure_connection.assert_called_with(**{
+ k: v for k, v in options.items()
+ if k in ['max_retries',
+ 'interval_start',
+ 'interval_step',
+ 'interval_max']})
+
def test_multiple_urls(self):
conn1 = Connection('amqp://foo;amqp://bar')
assert conn1.hostname == 'foo'
@@ -406,32 +432,6 @@ class test_Connection:
defchan.close.assert_called_with()
assert conn._default_channel is None
- def test_default_channel_no_transport_options(self):
- conn = self.conn
- conn.ensure_connection = Mock()
-
- assert conn.default_channel
- conn.ensure_connection.assert_called_with()
-
- def test_default_channel_transport_options(self):
- conn = self.conn
- conn.transport_options = options = {
- 'max_retries': 1,
- 'interval_start': 2,
- 'interval_step': 3,
- 'interval_max': 4,
- 'ignore_this': True
- }
- conn.ensure_connection = Mock()
-
- assert conn.default_channel
- conn.ensure_connection.assert_called_with(**{
- k: v for k, v in options.items()
- if k in ['max_retries',
- 'interval_start',
- 'interval_step',
- 'interval_max']})
-
def test_ensure_connection(self):
assert self.conn.ensure_connection()
diff --git a/t/unit/transport/test_redis.py b/t/unit/transport/test_redis.py
index a36dfc2b..5e911929 100644
--- a/t/unit/transport/test_redis.py
+++ b/t/unit/transport/test_redis.py
@@ -10,7 +10,9 @@ from itertools import count
from case import ANY, ContextMock, Mock, call, mock, skip, patch
from kombu import Connection, Exchange, Queue, Consumer, Producer
-from kombu.exceptions import InconsistencyError, VersionMismatch
+from kombu.exceptions import (
+ InconsistencyError, VersionMismatch, OperationalError
+)
from kombu.five import Empty, Queue as _Queue, bytes_if_py2
from kombu.transport import virtual
from kombu.utils import eventio # patch poll
@@ -1043,10 +1045,11 @@ class test_Redis:
assert conn.transport.channel_errors
def test_check_at_least_we_try_to_connect_and_fail(self):
- import redis
- connection = Connection('redis://localhost:65534/')
+ connection = Connection(
+ 'redis://localhost:65534/', transport_options={'max_retries': 1}
+ )
- with pytest.raises(redis.exceptions.ConnectionError):
+ with pytest.raises(OperationalError):
chan = connection.channel()
chan._size('some_queue')
@@ -1465,13 +1468,12 @@ class test_RedisSentinel:
master_for().connection_pool.get_connection.assert_called()
def test_can_create_connection(self):
- from redis.exceptions import ConnectionError
-
connection = Connection(
'sentinel://localhost:65534/',
transport_options={
'master_name': 'not_important',
+ 'max_retries': 1
},
)
- with pytest.raises(ConnectionError):
+ with pytest.raises(OperationalError):
connection.channel()