summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDaniel Nephin <dnephin@yelp.com>2014-08-22 07:31:21 -0700
committerDaniel Nephin <dnephin@yelp.com>2014-08-22 07:31:21 -0700
commit8b73aff04ee320ce27570c447028c30817f23d1a (patch)
tree50eb19a37aebe4c114a3a26e040feecc08de8577
parent6d8ef5bd03365836910edcd3398331ab22b17c00 (diff)
downloadretrying-8b73aff04ee320ce27570c447028c30817f23d1a.tar.gz
Support custom wait and stop functions.
-rw-r--r--retrying.py14
-rw-r--r--test_retrying.py14
2 files changed, 25 insertions, 3 deletions
diff --git a/retrying.py b/retrying.py
index ccf9668..50fc439 100644
--- a/retrying.py
+++ b/retrying.py
@@ -108,7 +108,9 @@ class Retrying(object):
wait_exponential_multiplier=None, wait_exponential_max=None,
retry_on_exception=None,
retry_on_result=None,
- wrap_exception=False):
+ wrap_exception=False,
+ stop_func=None,
+ wait_func=None):
self._stop_max_attempt_number = 5 if stop_max_attempt_number is None else stop_max_attempt_number
self._stop_max_delay = 100 if stop_max_delay is None else stop_max_delay
@@ -129,7 +131,10 @@ class Retrying(object):
if stop_max_delay is not None:
stop_funcs.append(self.stop_after_delay)
- if stop is None:
+ if stop_func is not None:
+ self.stop = stop_func
+
+ elif stop is None:
self.stop = lambda attempts, delay: any(f(attempts, delay) for f in stop_funcs)
else:
@@ -150,7 +155,10 @@ class Retrying(object):
if wait_exponential_multiplier is not None or wait_exponential_max is not None:
wait_funcs.append(self.exponential_sleep)
- if wait is None:
+ if wait_func is not None:
+ self.wait = wait_func
+
+ elif wait is None:
self.wait = lambda attempts, delay: max(f(attempts, delay) for f in wait_funcs)
else:
diff --git a/test_retrying.py b/test_retrying.py
index c163c41..da440b0 100644
--- a/test_retrying.py
+++ b/test_retrying.py
@@ -40,6 +40,13 @@ class TestStopConditions(unittest.TestCase):
def test_legacy_explicit_stop_type(self):
r = Retrying(stop="stop_after_attempt")
+ def test_stop_func(self):
+ r = Retrying(stop_func=lambda attempt, delay: attempt == delay)
+ self.assertFalse(r.stop(1, 3))
+ self.assertFalse(r.stop(100, 99))
+ self.assertTrue(r.stop(101, 101))
+
+
class TestWaitConditions(unittest.TestCase):
def test_no_sleep(self):
@@ -114,6 +121,13 @@ class TestWaitConditions(unittest.TestCase):
def test_legacy_explicit_wait_type(self):
r = Retrying(wait="exponential_sleep")
+ def test_wait_func(self):
+ r = Retrying(wait_func=lambda attempt, delay: attempt * delay)
+ self.assertEqual(r.wait(1, 5), 5)
+ self.assertEqual(r.wait(2, 11), 22)
+ self.assertEqual(r.wait(10, 100), 1000)
+
+
class NoneReturnUntilAfterCount:
"""
This class holds counter state for invoking a method several times in a row.