diff options
author | Ben Bangert <ben@groovie.org> | 2014-01-08 15:11:15 -0800 |
---|---|---|
committer | Ben Bangert <ben@groovie.org> | 2014-01-08 15:11:15 -0800 |
commit | 379c29dc9c9531157e561b53efb8261b5af38efc (patch) | |
tree | 1d5c2accd4a3c2e9a058b6421dc53bf9b5482319 | |
parent | 6f6ad316b829e380bb0bbdd25faf6385fb643611 (diff) | |
parent | ab1006ca51ca5035c23859806feda1342ce55fb4 (diff) | |
download | kazoo-379c29dc9c9531157e561b53efb8261b5af38efc.tar.gz |
Merge pull request #139 from nailor/fix-create-closed-client
client: Raise exception when calling create on closed client
-rw-r--r-- | kazoo/client.py | 29 | ||||
-rw-r--r-- | kazoo/tests/test_client.py | 24 |
2 files changed, 46 insertions, 7 deletions
diff --git a/kazoo/client.py b/kazoo/client.py index da2d624..287de30 100644 --- a/kazoo/client.py +++ b/kazoo/client.py @@ -477,19 +477,24 @@ class KazooClient(object): def _call(self, request, async_object): """Ensure there's an active connection and put the request in - the queue if there is.""" + the queue if there is. + + Returns False if the call short circuits due to AUTH_FAILED, + CLOSED, EXPIRED_SESSION or CONNECTING state. + + """ if self._state == KeeperState.AUTH_FAILED: async_object.set_exception(AuthFailedError()) - return + return False elif self._state == KeeperState.CLOSED: async_object.set_exception(ConnectionClosedError( "Connection has been closed")) - return + return False elif self._state in (KeeperState.EXPIRED_SESSION, KeeperState.CONNECTING): async_object.set_exception(SessionExpiredError()) - return + return False self._queue.append((request, async_object)) @@ -806,8 +811,10 @@ class KazooClient(object): async_result = self.handler.async_result() + @capture_exceptions(async_result) def do_create(): - self._create_async_inner(path, value, acl, flags, trailing=sequence).rawlink(create_completion) + result = self._create_async_inner(path, value, acl, flags, trailing=sequence) + result.rawlink(create_completion) @capture_exceptions(async_result) def retry_completion(result): @@ -832,8 +839,16 @@ class KazooClient(object): def _create_async_inner(self, path, value, acl, flags, trailing=False): async_result = self.handler.async_result() - self._call(Create(_prefix_root(self.chroot, path, trailing=trailing), value, acl, flags), - async_result) + call_result = self._call( + Create(_prefix_root(self.chroot, path, trailing=trailing), + value, acl, flags), async_result) + if call_result is False: + # We hit a short-circuit exit on the _call. Because we are + # not using the original async_result here, we bubble the + # exception upwards to the do_create function in + # KazooClient.create so that it gets set on the correct + # async_result object + raise async_result.exception return async_result def ensure_path(self, path, acl=None): diff --git a/kazoo/tests/test_client.py b/kazoo/tests/test_client.py index 6cae5a1..4d09a4b 100644 --- a/kazoo/tests/test_client.py +++ b/kazoo/tests/test_client.py @@ -11,6 +11,7 @@ from nose.tools import raises from kazoo.testing import KazooTestCase from kazoo.exceptions import ( + AuthFailedError, BadArgumentsError, ConfigurationError, ConnectionClosedError, @@ -19,7 +20,9 @@ from kazoo.exceptions import ( NoAuthError, NoNodeError, NodeExistsError, + SessionExpiredError, ) +from kazoo.protocol.states import KeeperState if sys.version_info > (3, ): # pragma: nocover @@ -366,6 +369,27 @@ class TestClient(KazooTestCase): eq_(path, "/1") self.assertTrue(client.exists("/1")) + def test_create_on_broken_connection(self): + client = self.client + client.start() + + client._state = KeeperState.EXPIRED_SESSION + self.assertRaises(SessionExpiredError, client.create, + '/closedpath', b'bar') + + client._state = KeeperState.AUTH_FAILED + self.assertRaises(AuthFailedError, client.create, + '/closedpath', b'bar') + + client._state = KeeperState.CONNECTING + self.assertRaises(SessionExpiredError, client.create, + '/closedpath', b'bar') + client.stop() + client.close() + + self.assertRaises(ConnectionClosedError, client.create, + '/closedpath', b'bar') + def test_create_unicode_path(self): client = self.client path = client.create(u("/ascii")) |