summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAsk Solem <ask@celeryproject.org>2012-09-11 17:32:00 +0100
committerAsk Solem <ask@celeryproject.org>2012-09-11 17:33:46 +0100
commit2fa467ac2dac12fed3680bf99d2efa97eb479fa0 (patch)
tree25438fa7638630451237791c66a93988284f68ba
parentae51d2f654a6bad1d48434ea9c964a27d605f166 (diff)
downloadkombu-2fa467ac2dac12fed3680bf99d2efa97eb479fa0.tar.gz
Connection: Support for multiple URLs and failover.
Either:: Connection(['amqp://foo', 'amqp://bar']) or: Connection('amqp://foo;amqp://bar') Fixes celery/celery#616
-rw-r--r--kombu/connection.py68
-rw-r--r--kombu/tests/test_utils.py5
-rw-r--r--kombu/utils/__init__.py26
-rw-r--r--kombu/utils/compat.py12
4 files changed, 91 insertions, 20 deletions
diff --git a/kombu/connection.py b/kombu/connection.py
index ac90fd2d..4a2bb290 100644
--- a/kombu/connection.py
+++ b/kombu/connection.py
@@ -17,16 +17,17 @@ import socket
from contextlib import contextmanager
from functools import partial
-from itertools import count
+from itertools import count, cycle
from urllib import quote
from Queue import Empty
+
# jython breaks on relative import for .exceptions for some reason
# (Issue #112)
from kombu import exceptions
from .log import get_logger
from .transport import get_transport_cls, supports_librabbitmq
-from .utils import cached_property, retry_over_time
-from .utils.compat import OrderedDict, LifoQueue as _LifoQueue
+from .utils import cached_property, retry_over_time, RetryNow, shufflecycle
+from .utils.compat import OrderedDict, LifoQueue as _LifoQueue, next
from .utils.url import parse_url
RESOLVE_ALIASES = {'amqplib': 'amqp',
@@ -39,6 +40,12 @@ __all__ = ['Connection', 'ConnectionPool', 'ChannelPool']
URI_PASSTHROUGH = frozenset(['sqla', 'sqlalchemy', 'zeromq', 'zmq'])
logger = get_logger(__name__)
+roundrobin_failover = cycle
+
+failover_strategies = {
+ 'round-robin': cycle,
+ 'shuffle': shufflecycle,
+}
class Connection(object):
@@ -102,14 +109,24 @@ class Connection(object):
password=None, virtual_host=None, port=None, insist=False,
ssl=False, transport=None, connect_timeout=5,
transport_options=None, login_method=None, uri_prefix=None,
- heartbeat=0, **kwargs):
+ heartbeat=0, failover_strategy='round-robin', **kwargs):
+ alt = []
# have to spell the args out, just to get nice docstrings :(
- params = {'hostname': hostname, 'userid': userid,
- 'password': password, 'virtual_host': virtual_host,
- 'port': port, 'insist': insist, 'ssl': ssl,
- 'transport': transport, 'connect_timeout': connect_timeout,
- 'login_method': login_method, 'heartbeat': heartbeat}
+ params = self._initial_params = {
+ 'hostname': hostname, 'userid': userid,
+ 'password': password, 'virtual_host': virtual_host,
+ 'port': port, 'insist': insist, 'ssl': ssl,
+ 'transport': transport, 'connect_timeout': connect_timeout,
+ 'login_method': login_method, 'heartbeat': heartbeat
+ }
+
+ if hostname and not isinstance(hostname, basestring):
+ alt.extend(hostname)
+ hostname = alt[0]
if hostname and '://' in hostname:
+ if ';' in hostname:
+ alt.extend(hostname.split(';'))
+ hostname = alt[0]
if '+' in hostname[:hostname.index('://')]:
# e.g. sqla+mysql://root:masterkey@localhost/
params['transport'], params['hostname'] = hostname.split('+')
@@ -119,6 +136,14 @@ class Connection(object):
params.update(parse_url(hostname))
self._init_params(**params)
+ # fallback hosts
+ self.alt = alt
+ self.failover_strategy = failover_strategies.get(
+ failover_strategy or 'round-robin') or failover_strategy
+ if self.alt:
+ self.cycle = self.failover_strategy(self.alt)
+ next(self.cycle) # skip first entry
+
# backend_cls argument will be removed shortly.
self.transport_cls = self.transport_cls or kwargs.get('backend_cls')
@@ -134,6 +159,14 @@ class Connection(object):
self.declared_entities = set()
+ def switch(self, url):
+ self.close()
+ self._closed = False
+ self._init_params(**dict(self._initial_params, **parse_url(url)))
+
+ def switch_next(self):
+ self.switch(next(self.cycle))
+
def _init_params(self, hostname, userid, password, virtual_host, port,
insist, ssl, transport, connect_timeout, login_method, heartbeat):
transport = transport or 'amqp'
@@ -263,15 +296,26 @@ class Connection(object):
each retry.
:keyword callback: Optional callback that is called for every
internal iteration (1 s)
- :keyword callback: Optional callback that is called for every
- internal iteration (1 s).
"""
+ def on_error(exc, intervals, retries, interval=0):
+ round = self.completes_cycle(retries)
+ if round:
+ interval = next(intervals)
+ if errback:
+ errback(exc, interval)
+ self.switch_next() # select next host
+
+ return interval if round else RetryNow
+
retry_over_time(self.connect, self.connection_errors, (), {},
- errback, max_retries,
+ on_error, max_retries,
interval_start, interval_step, interval_max, callback)
return self
+ def completes_cycle(self, retries):
+ return not (retries + 1) % len(self.alt) if self.alt else True
+
def revive(self, new_channel):
if self._default_channel:
self.maybe_close_channel(self._default_channel)
diff --git a/kombu/tests/test_utils.py b/kombu/tests/test_utils.py
index 9bf7f5d8..457658d8 100644
--- a/kombu/tests/test_utils.py
+++ b/kombu/tests/test_utils.py
@@ -11,6 +11,7 @@ else:
from StringIO import StringIO, StringIO as BytesIO # noqa
from kombu import utils
+from kombu.utils.compat import next
from .utils import redirect_stdouts, mask_modules, skip_if_module
from .utils import TestCase
@@ -167,10 +168,12 @@ class test_retry_over_time(TestCase):
raise self.Predicate()
return 42
- def errback(self, exc, interval):
+ def errback(self, exc, intervals, retries):
+ interval = next(intervals)
sleepvals = (None, 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0, 16.0)
self.index += 1
self.assertEqual(interval, sleepvals[self.index])
+ return interval
@insomnia
def test_simple(self):
diff --git a/kombu/utils/__init__.py b/kombu/utils/__init__.py
index b402ff6c..6a1a3646 100644
--- a/kombu/utils/__init__.py
+++ b/kombu/utils/__init__.py
@@ -10,9 +10,11 @@ Internal utilities.
"""
from __future__ import absolute_import
+import random
import sys
from contextlib import contextmanager
+from itertools import count, repeat
from time import sleep
from uuid import UUID, uuid4 as _uuid4, _uuid_generate_random
@@ -28,6 +30,8 @@ __all__ = ['EqualityDict', 'say', 'uuid', 'kwdict', 'maybe_list',
'emergency_dump_state', 'cached_property',
'reprkwargs', 'reprcall', 'nested']
+RetryNow = object()
+
def eqhash(o):
try:
@@ -153,7 +157,7 @@ def retry_over_time(fun, catch, args=[], kwargs={}, errback=None,
interval_range = fxrange(interval_start,
interval_max + interval_start,
interval_step, repeatlast=True)
- for retries, interval in enumerate(interval_range): # for infinity
+ for retries in count():
try:
return fun(*args, **kwargs)
except catch, exc:
@@ -161,12 +165,12 @@ def retry_over_time(fun, catch, args=[], kwargs={}, errback=None,
raise
if callback:
callback()
- if errback:
- errback(exc, interval)
- for i in fxrange(stop=interval or 1.0):
- if i and callback:
- callback()
- sleep(i)
+ tts = errback(exc, interval_range, retries) if errback else None
+ if tts is not RetryNow:
+ for i in fxrange(stop=tts or 1.0):
+ if i and callback:
+ callback()
+ sleep(i)
def emergency_dump_state(state, open_file=open, dump=None):
@@ -301,3 +305,11 @@ def nested(*managers): # pragma: no cover
raise exc[0], exc[1], exc[2]
finally:
del(exc)
+
+
+def shufflecycle(it):
+ it = list(it) # don't modify callers list
+ shuffle = random.shuffle
+ for _ in repeat(None):
+ shuffle(it)
+ yield it[0]
diff --git a/kombu/utils/compat.py b/kombu/utils/compat.py
index c1b3cec4..d695728c 100644
--- a/kombu/utils/compat.py
+++ b/kombu/utils/compat.py
@@ -10,6 +10,18 @@ Helps compatibility with older Python versions.
"""
import sys
+############## __builtins__.next #############################################
+try:
+ next = next
+except NameError:
+ def next(it, *args): # noqa
+ try:
+ return it.__next__()
+ except StopIteration:
+ if not args:
+ raise
+ return args[0]
+
############## collections.OrderedDict #######################################
import weakref