summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBen Bangert <ben@groovie.org>2014-01-08 15:11:15 -0800
committerBen Bangert <ben@groovie.org>2014-01-08 15:11:15 -0800
commit379c29dc9c9531157e561b53efb8261b5af38efc (patch)
tree1d5c2accd4a3c2e9a058b6421dc53bf9b5482319
parent6f6ad316b829e380bb0bbdd25faf6385fb643611 (diff)
parentab1006ca51ca5035c23859806feda1342ce55fb4 (diff)
downloadkazoo-379c29dc9c9531157e561b53efb8261b5af38efc.tar.gz
Merge pull request #139 from nailor/fix-create-closed-client
client: Raise exception when calling create on closed client
-rw-r--r--kazoo/client.py29
-rw-r--r--kazoo/tests/test_client.py24
2 files changed, 46 insertions, 7 deletions
diff --git a/kazoo/client.py b/kazoo/client.py
index da2d624..287de30 100644
--- a/kazoo/client.py
+++ b/kazoo/client.py
@@ -477,19 +477,24 @@ class KazooClient(object):
def _call(self, request, async_object):
"""Ensure there's an active connection and put the request in
- the queue if there is."""
+ the queue if there is.
+
+ Returns False if the call short circuits due to AUTH_FAILED,
+ CLOSED, EXPIRED_SESSION or CONNECTING state.
+
+ """
if self._state == KeeperState.AUTH_FAILED:
async_object.set_exception(AuthFailedError())
- return
+ return False
elif self._state == KeeperState.CLOSED:
async_object.set_exception(ConnectionClosedError(
"Connection has been closed"))
- return
+ return False
elif self._state in (KeeperState.EXPIRED_SESSION,
KeeperState.CONNECTING):
async_object.set_exception(SessionExpiredError())
- return
+ return False
self._queue.append((request, async_object))
@@ -806,8 +811,10 @@ class KazooClient(object):
async_result = self.handler.async_result()
+ @capture_exceptions(async_result)
def do_create():
- self._create_async_inner(path, value, acl, flags, trailing=sequence).rawlink(create_completion)
+ result = self._create_async_inner(path, value, acl, flags, trailing=sequence)
+ result.rawlink(create_completion)
@capture_exceptions(async_result)
def retry_completion(result):
@@ -832,8 +839,16 @@ class KazooClient(object):
def _create_async_inner(self, path, value, acl, flags, trailing=False):
async_result = self.handler.async_result()
- self._call(Create(_prefix_root(self.chroot, path, trailing=trailing), value, acl, flags),
- async_result)
+ call_result = self._call(
+ Create(_prefix_root(self.chroot, path, trailing=trailing),
+ value, acl, flags), async_result)
+ if call_result is False:
+ # We hit a short-circuit exit on the _call. Because we are
+ # not using the original async_result here, we bubble the
+ # exception upwards to the do_create function in
+ # KazooClient.create so that it gets set on the correct
+ # async_result object
+ raise async_result.exception
return async_result
def ensure_path(self, path, acl=None):
diff --git a/kazoo/tests/test_client.py b/kazoo/tests/test_client.py
index 6cae5a1..4d09a4b 100644
--- a/kazoo/tests/test_client.py
+++ b/kazoo/tests/test_client.py
@@ -11,6 +11,7 @@ from nose.tools import raises
from kazoo.testing import KazooTestCase
from kazoo.exceptions import (
+ AuthFailedError,
BadArgumentsError,
ConfigurationError,
ConnectionClosedError,
@@ -19,7 +20,9 @@ from kazoo.exceptions import (
NoAuthError,
NoNodeError,
NodeExistsError,
+ SessionExpiredError,
)
+from kazoo.protocol.states import KeeperState
if sys.version_info > (3, ): # pragma: nocover
@@ -366,6 +369,27 @@ class TestClient(KazooTestCase):
eq_(path, "/1")
self.assertTrue(client.exists("/1"))
+ def test_create_on_broken_connection(self):
+ client = self.client
+ client.start()
+
+ client._state = KeeperState.EXPIRED_SESSION
+ self.assertRaises(SessionExpiredError, client.create,
+ '/closedpath', b'bar')
+
+ client._state = KeeperState.AUTH_FAILED
+ self.assertRaises(AuthFailedError, client.create,
+ '/closedpath', b'bar')
+
+ client._state = KeeperState.CONNECTING
+ self.assertRaises(SessionExpiredError, client.create,
+ '/closedpath', b'bar')
+ client.stop()
+ client.close()
+
+ self.assertRaises(ConnectionClosedError, client.create,
+ '/closedpath', b'bar')
+
def test_create_unicode_path(self):
client = self.client
path = client.create(u("/ascii"))