diff options
| author | Ben Bangert <ben@groovie.org> | 2012-09-07 17:00:29 -0700 |
|---|---|---|
| committer | Ben Bangert <ben@groovie.org> | 2012-09-07 17:00:29 -0700 |
| commit | 2f73dab9d0cfec68d94de18b2910e8d97aefc724 (patch) | |
| tree | ddff63bc4ba52a2235c713d5c77c1fb736206d8e | |
| parent | 598eb80cafe5fa0c6a474f9f4ac157d79afcaa39 (diff) | |
| download | kazoo-2f73dab9d0cfec68d94de18b2910e8d97aefc724.tar.gz | |
Add transactions and some tests for them.
| -rw-r--r-- | docs/api/client.rst | 4 | ||||
| -rw-r--r-- | kazoo/client.py | 148 | ||||
| -rw-r--r-- | kazoo/exceptions.py | 2 | ||||
| -rw-r--r-- | kazoo/protocol/connection.py | 6 | ||||
| -rw-r--r-- | kazoo/protocol/serialization.py | 87 | ||||
| -rw-r--r-- | kazoo/tests/test_client.py | 53 |
6 files changed, 291 insertions, 9 deletions
diff --git a/docs/api/client.rst b/docs/api/client.rst index a6a0ab0..8efe955 100644 --- a/docs/api/client.rst +++ b/docs/api/client.rst @@ -29,3 +29,7 @@ Public API A :class:`~kazoo.protocol.states.KazooState` attribute indicating the current higher-level connection state. + + .. autoclass:: TransactionRequest + :members: + :member-order: bysource diff --git a/kazoo/client.py b/kazoo/client.py index 39444af..49d9d77 100644 --- a/kazoo/client.py +++ b/kazoo/client.py @@ -22,6 +22,7 @@ from kazoo.protocol.paths import normpath from kazoo.protocol.paths import _prefix_root from kazoo.protocol.serialization import ( Auth, + CheckVersion, Close, Create, Delete, @@ -32,7 +33,8 @@ from kazoo.protocol.serialization import ( SetACL, GetData, SetData, - Sync + Sync, + Transaction ) from kazoo.protocol.states import KazooState from kazoo.protocol.states import KeeperState @@ -936,6 +938,20 @@ class KazooClient(object): async_result) return async_result + def transaction(self): + """Create and return a :class:`TransactionRequest` object + + Creates a :class:`TransactionRequest` object. A Transaction can + consist of multiple operations which can be committed as a + single atomic unit. Either all of the operations will succeed + or none of them. + + :returns: A TransactionRequest. + :rtype: :class:`TransactionRequest` + + """ + return TransactionRequest(self) + def delete(self, path, version=-1, recursive=False): """Delete a node. @@ -1009,3 +1025,133 @@ class KazooClient(object): self.delete(path) except NoNodeError: # pragma: nocover pass + + +class TransactionRequest(object): + """A Zookeeper Transaction Request + + A Transaction provides a builder object that can be used to + construct and commit an atomic set of operations. The transaction + must be committed before its sent. + + Transactions are not thread-safe and should not be accessed from + multiple threads at once. + + """ + def __init__(self, client): + self.client = client + self.operations = [] + self.committed = False + + def create(self, path, value="", acl=None, ephemeral=False, + sequence=False): + """Add a create ZNode to the transaction. Takes the same + arguments as :meth:`KazooClient.create`, with the exception + of `makepath`. + + :returns: None + + """ + if acl is None and self.client.default_acl: + acl = self.client.default_acl + + if not isinstance(path, basestring): + raise TypeError("path must be a string") + if acl and not isinstance(acl, (tuple, list)): + raise TypeError("acl must be a tuple/list of ACL's") + if not isinstance(value, str): + raise TypeError("value must be a byte string") + if not isinstance(ephemeral, bool): + raise TypeError("ephemeral must be a bool") + if not isinstance(sequence, bool): + raise TypeError("sequence must be a bool") + + flags = 0 + if ephemeral: + flags |= 1 + if sequence: + flags |= 2 + if acl is None: + acl = OPEN_ACL_UNSAFE + + self._add(Create(_prefix_root(self.client.chroot, path), value, acl, + flags), None) + + def delete(self, path, version=-1): + """Add a delete ZNode to the transaction. Takes the same + arguments as :meth:`KazooClient.delete`, with the exception of + `recursive`. + + """ + if not isinstance(path, basestring): + raise TypeError("path must be a string") + if not isinstance(version, int): + raise TypeError("version must be an int") + self._add(Delete(_prefix_root(self.client.chroot, path), version)) + + def set_data(self, path, data, version=-1): + """Add a set ZNode value to the transaction. Takes the same + arguments as :meth:`KazooClient.set`. + + """ + if not isinstance(path, basestring): + raise TypeError("path must be a string") + if not isinstance(data, basestring): + raise TypeError("data must be a string") + if not isinstance(version, int): + raise TypeError("version must be an int") + self._add(SetData(_prefix_root(self.client.chroot, path), data, + version)) + + def check(self, path, version): + """Add a Check Version to the transaction. + + This command will fail and abort a transaction if the path + does not match the specified version. + + """ + if not isinstance(path, basestring): + raise TypeError("path must be a string") + if not isinstance(version, int): + raise TypeError("version must be an int") + self._add(CheckVersion(_prefix_root(self.client.chroot, path), + version)) + + def commit_async(self): + """Commit the transaction asynchronously + + :rtype: :class:`~kazoo.interfaces.IAsyncResult` + + """ + self._check_tx_state() + self.committed = True + async_object = self.client.handler.async_result() + self.client._call(Transaction(self.operations), async_object) + return async_object + + def commit(self): + """Commit the transaction + + :returns: A list of the results for each operation in the + transaction. + + """ + return self.commit_async().get() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, exc_tb): + """Commit and cleanup accumulated transaction data""" + if not exc_type: + self.commit() + + def _check_tx_state(self): + if self.committed: + raise ValueError('Transaction already committed') + + def _add(self, request, post_processor=None): + self._check_tx_state() + if self.client.log_debug: + log.debug('Added %r to %r', request, self) + self.operations.append(request) diff --git a/kazoo/exceptions.py b/kazoo/exceptions.py index 6ab092f..66366aa 100644 --- a/kazoo/exceptions.py +++ b/kazoo/exceptions.py @@ -13,7 +13,7 @@ class ZookeeperError(KazooException): class CancelledError(KazooException): - """Raised when a process is cancelled by another thread""" + """Raised when a process is canceled by another thread""" class ConfigurationError(KazooException): diff --git a/kazoo/protocol/connection.py b/kazoo/protocol/connection.py index db776a8..2225de9 100644 --- a/kazoo/protocol/connection.py +++ b/kazoo/protocol/connection.py @@ -22,6 +22,7 @@ from kazoo.protocol.serialization import ( GetChildren, Ping, ReplyHeader, + Transaction, Watch, int_struct ) @@ -297,6 +298,11 @@ class ConnectionHandler(object): async_object.set_exception(exc) return log.debug('Received response: %r', response) + + # We special case a Transaction as we have to unchroot things + if request.type == Transaction.type: + response = Transaction.unchroot(client, response) + async_object.set(response) # Determine if watchers should be registered diff --git a/kazoo/protocol/serialization.py b/kazoo/protocol/serialization.py index 474ba15..8324f3c 100644 --- a/kazoo/protocol/serialization.py +++ b/kazoo/protocol/serialization.py @@ -2,6 +2,7 @@ from collections import namedtuple import struct +from kazoo.exceptions import EXCEPTIONS from kazoo.protocol.states import ZnodeStat from kazoo.security import ACL from kazoo.security import Id @@ -13,6 +14,7 @@ int_int_struct = struct.Struct('!ii') int_int_long_struct = struct.Struct('!iiq') int_long_int_long_struct = struct.Struct('!iqiq') +multiheader_struct = struct.Struct('!iBi') reply_header_struct = struct.Struct('!iqi') stat_struct = struct.Struct('!qqqqiiiqiiq') @@ -287,6 +289,62 @@ class GetChildren2(namedtuple('GetChildren2', 'path watcher')): return children, stat +class CheckVersion(namedtuple('CheckVersion', 'path version')): + type = 13 + + def serialize(self): + b = bytearray() + b.extend(write_string(self.path)) + b.extend(int_struct.pack(self.version)) + return b + + +class Transaction(namedtuple('Transaction', 'operations')): + type = 14 + + def serialize(self): + b = bytearray() + for op in self.operations: + b.extend(MultiHeader(op.type, False, -1).serialize() + + op.serialize()) + return b + multiheader_struct.pack(-1, True, -1) + + @classmethod + def deserialize(cls, bytes, offset): + header = MultiHeader(None, False, None) + results = [] + response = None + while not header.done: + if header.type == Create.type: + response, offset = read_string(bytes, offset) + elif header.type == Delete.type: + response = True + elif header.type == SetData.type: + response = ZnodeStat._make( + stat_struct.unpack_from(bytes, offset)) + offset += stat_struct.size + elif header.type == CheckVersion.type: + response = True + elif header.type == -1: + err = int_struct.unpack_from(bytes, offset)[0] + offset += int_struct.size + response = EXCEPTIONS[err]() + if response: + results.append(response) + header, offset = MultiHeader.deserialize(bytes, offset) + return results + + @staticmethod + def unchroot(client, response): + resp = [] + for result in response: + if isinstance(result, unicode): + resp.append(client.unchroot(result)) + else: + resp.append(result) + return resp + + class Auth(namedtuple('Auth', 'auth_type scheme auth')): type = 100 @@ -297,20 +355,35 @@ class Auth(namedtuple('Auth', 'auth_type scheme auth')): class Watch(namedtuple('Watch', 'type state path')): @classmethod - def deserialize(cls, buffer, offset): - """Given a buffer and the current buffer offset, return the + def deserialize(cls, bytes, offset): + """Given bytes and the current bytes offset, return the type, state, path, and new offset""" - type, state = int_int_struct.unpack_from(buffer, offset) + type, state = int_int_struct.unpack_from(bytes, offset) offset += int_int_struct.size - path, offset = read_string(buffer, offset) + path, offset = read_string(bytes, offset) return cls(type, state, path), offset class ReplyHeader(namedtuple('ReplyHeader', 'xid, zxid, err')): @classmethod - def deserialize(cls, buffer, offset): - """Given a buffer and the current buffer offset, return a + def deserialize(cls, bytes, offset): + """Given bytes and the current bytes offset, return a :class:`ReplyHeader` instance and the new offset""" new_offset = offset + reply_header_struct.size return cls._make( - reply_header_struct.unpack_from(buffer, offset)), new_offset + reply_header_struct.unpack_from(bytes, offset)), new_offset + + +class MultiHeader(namedtuple('MultiHeader', 'type done err')): + def serialize(self): + b = bytearray() + b.extend(int_struct.pack(self.type)) + b.extend([1 if self.done else 0]) + b.extend(int_struct.pack(self.err)) + return b + + @classmethod + def deserialize(cls, bytes, offset): + t, done, err = multiheader_struct.unpack_from(bytes, offset) + offset += multiheader_struct.size + return cls(t, done is 1, err), offset diff --git a/kazoo/tests/test_client.py b/kazoo/tests/test_client.py index 48fa295..e165daf 100644 --- a/kazoo/tests/test_client.py +++ b/kazoo/tests/test_client.py @@ -617,6 +617,11 @@ class TestClient(KazooTestCase): client._safe_close() testit() + def test_client_state(self): + from kazoo.protocol.states import KeeperState + eq_(self.client.client_state, KeeperState.CONNECTED) + + dummy_dict = { 'aversion': 1, 'ctime': 0, 'cversion': 1, 'czxid': 110, 'dataLength': 1, 'ephemeralOwner': 'ben', @@ -624,6 +629,54 @@ dummy_dict = { } +class TestClientTransactions(KazooTestCase): + def test_basic_create(self): + t = self.client.transaction() + t.create('/freddy') + t.create('/fred', ephemeral=True) + t.create('/smith', sequence=True) + results = t.commit() + eq_(results[0], '/freddy') + eq_(len(results), 3) + self.assertTrue(results[2].startswith('/smith0')) + + def test_bad_creates(self): + args_list = [(True,), ('/smith', 0), ('/smith', '', 'bleh'), + ('/smith', '', None, 'fred'), + ('/smith', '', None, True, 'fred')] + + @raises(TypeError) + def testit(args): + t = self.client.transaction() + t.create(*args) + + for args in args_list: + testit(args) + + def test_default_acl(self): + from kazoo.security import make_digest_acl + username = uuid.uuid4().hex + password = uuid.uuid4().hex + + digest_auth = "%s:%s" % (username, password) + acl = make_digest_acl(username, password, all=True) + + self.client.add_auth("digest", digest_auth) + self.client.default_acl = (acl,) + + t = self.client.transaction() + t.create('/freddy') + results = t.commit() + eq_(results[0], '/freddy') + + def test_basic_delete(self): + self.client.create('/fred') + t = self.client.transaction() + t.delete('/fred') + results = t.commit() + eq_(results[0], True) + + class TestCallbacks(unittest.TestCase): def test_session_callback_states(self): from kazoo.protocol.states import KazooState, KeeperState |
