summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBen Bangert <ben@groovie.org>2012-09-07 17:00:29 -0700
committerBen Bangert <ben@groovie.org>2012-09-07 17:00:29 -0700
commit2f73dab9d0cfec68d94de18b2910e8d97aefc724 (patch)
treeddff63bc4ba52a2235c713d5c77c1fb736206d8e
parent598eb80cafe5fa0c6a474f9f4ac157d79afcaa39 (diff)
downloadkazoo-2f73dab9d0cfec68d94de18b2910e8d97aefc724.tar.gz
Add transactions and some tests for them.
-rw-r--r--docs/api/client.rst4
-rw-r--r--kazoo/client.py148
-rw-r--r--kazoo/exceptions.py2
-rw-r--r--kazoo/protocol/connection.py6
-rw-r--r--kazoo/protocol/serialization.py87
-rw-r--r--kazoo/tests/test_client.py53
6 files changed, 291 insertions, 9 deletions
diff --git a/docs/api/client.rst b/docs/api/client.rst
index a6a0ab0..8efe955 100644
--- a/docs/api/client.rst
+++ b/docs/api/client.rst
@@ -29,3 +29,7 @@ Public API
A :class:`~kazoo.protocol.states.KazooState` attribute indicating
the current higher-level connection state.
+
+ .. autoclass:: TransactionRequest
+ :members:
+ :member-order: bysource
diff --git a/kazoo/client.py b/kazoo/client.py
index 39444af..49d9d77 100644
--- a/kazoo/client.py
+++ b/kazoo/client.py
@@ -22,6 +22,7 @@ from kazoo.protocol.paths import normpath
from kazoo.protocol.paths import _prefix_root
from kazoo.protocol.serialization import (
Auth,
+ CheckVersion,
Close,
Create,
Delete,
@@ -32,7 +33,8 @@ from kazoo.protocol.serialization import (
SetACL,
GetData,
SetData,
- Sync
+ Sync,
+ Transaction
)
from kazoo.protocol.states import KazooState
from kazoo.protocol.states import KeeperState
@@ -936,6 +938,20 @@ class KazooClient(object):
async_result)
return async_result
+ def transaction(self):
+ """Create and return a :class:`TransactionRequest` object
+
+ Creates a :class:`TransactionRequest` object. A Transaction can
+ consist of multiple operations which can be committed as a
+ single atomic unit. Either all of the operations will succeed
+ or none of them.
+
+ :returns: A TransactionRequest.
+ :rtype: :class:`TransactionRequest`
+
+ """
+ return TransactionRequest(self)
+
def delete(self, path, version=-1, recursive=False):
"""Delete a node.
@@ -1009,3 +1025,133 @@ class KazooClient(object):
self.delete(path)
except NoNodeError: # pragma: nocover
pass
+
+
+class TransactionRequest(object):
+ """A Zookeeper Transaction Request
+
+ A Transaction provides a builder object that can be used to
+ construct and commit an atomic set of operations. The transaction
+ must be committed before its sent.
+
+ Transactions are not thread-safe and should not be accessed from
+ multiple threads at once.
+
+ """
+ def __init__(self, client):
+ self.client = client
+ self.operations = []
+ self.committed = False
+
+ def create(self, path, value="", acl=None, ephemeral=False,
+ sequence=False):
+ """Add a create ZNode to the transaction. Takes the same
+ arguments as :meth:`KazooClient.create`, with the exception
+ of `makepath`.
+
+ :returns: None
+
+ """
+ if acl is None and self.client.default_acl:
+ acl = self.client.default_acl
+
+ if not isinstance(path, basestring):
+ raise TypeError("path must be a string")
+ if acl and not isinstance(acl, (tuple, list)):
+ raise TypeError("acl must be a tuple/list of ACL's")
+ if not isinstance(value, str):
+ raise TypeError("value must be a byte string")
+ if not isinstance(ephemeral, bool):
+ raise TypeError("ephemeral must be a bool")
+ if not isinstance(sequence, bool):
+ raise TypeError("sequence must be a bool")
+
+ flags = 0
+ if ephemeral:
+ flags |= 1
+ if sequence:
+ flags |= 2
+ if acl is None:
+ acl = OPEN_ACL_UNSAFE
+
+ self._add(Create(_prefix_root(self.client.chroot, path), value, acl,
+ flags), None)
+
+ def delete(self, path, version=-1):
+ """Add a delete ZNode to the transaction. Takes the same
+ arguments as :meth:`KazooClient.delete`, with the exception of
+ `recursive`.
+
+ """
+ if not isinstance(path, basestring):
+ raise TypeError("path must be a string")
+ if not isinstance(version, int):
+ raise TypeError("version must be an int")
+ self._add(Delete(_prefix_root(self.client.chroot, path), version))
+
+ def set_data(self, path, data, version=-1):
+ """Add a set ZNode value to the transaction. Takes the same
+ arguments as :meth:`KazooClient.set`.
+
+ """
+ if not isinstance(path, basestring):
+ raise TypeError("path must be a string")
+ if not isinstance(data, basestring):
+ raise TypeError("data must be a string")
+ if not isinstance(version, int):
+ raise TypeError("version must be an int")
+ self._add(SetData(_prefix_root(self.client.chroot, path), data,
+ version))
+
+ def check(self, path, version):
+ """Add a Check Version to the transaction.
+
+ This command will fail and abort a transaction if the path
+ does not match the specified version.
+
+ """
+ if not isinstance(path, basestring):
+ raise TypeError("path must be a string")
+ if not isinstance(version, int):
+ raise TypeError("version must be an int")
+ self._add(CheckVersion(_prefix_root(self.client.chroot, path),
+ version))
+
+ def commit_async(self):
+ """Commit the transaction asynchronously
+
+ :rtype: :class:`~kazoo.interfaces.IAsyncResult`
+
+ """
+ self._check_tx_state()
+ self.committed = True
+ async_object = self.client.handler.async_result()
+ self.client._call(Transaction(self.operations), async_object)
+ return async_object
+
+ def commit(self):
+ """Commit the transaction
+
+ :returns: A list of the results for each operation in the
+ transaction.
+
+ """
+ return self.commit_async().get()
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_value, exc_tb):
+ """Commit and cleanup accumulated transaction data"""
+ if not exc_type:
+ self.commit()
+
+ def _check_tx_state(self):
+ if self.committed:
+ raise ValueError('Transaction already committed')
+
+ def _add(self, request, post_processor=None):
+ self._check_tx_state()
+ if self.client.log_debug:
+ log.debug('Added %r to %r', request, self)
+ self.operations.append(request)
diff --git a/kazoo/exceptions.py b/kazoo/exceptions.py
index 6ab092f..66366aa 100644
--- a/kazoo/exceptions.py
+++ b/kazoo/exceptions.py
@@ -13,7 +13,7 @@ class ZookeeperError(KazooException):
class CancelledError(KazooException):
- """Raised when a process is cancelled by another thread"""
+ """Raised when a process is canceled by another thread"""
class ConfigurationError(KazooException):
diff --git a/kazoo/protocol/connection.py b/kazoo/protocol/connection.py
index db776a8..2225de9 100644
--- a/kazoo/protocol/connection.py
+++ b/kazoo/protocol/connection.py
@@ -22,6 +22,7 @@ from kazoo.protocol.serialization import (
GetChildren,
Ping,
ReplyHeader,
+ Transaction,
Watch,
int_struct
)
@@ -297,6 +298,11 @@ class ConnectionHandler(object):
async_object.set_exception(exc)
return
log.debug('Received response: %r', response)
+
+ # We special case a Transaction as we have to unchroot things
+ if request.type == Transaction.type:
+ response = Transaction.unchroot(client, response)
+
async_object.set(response)
# Determine if watchers should be registered
diff --git a/kazoo/protocol/serialization.py b/kazoo/protocol/serialization.py
index 474ba15..8324f3c 100644
--- a/kazoo/protocol/serialization.py
+++ b/kazoo/protocol/serialization.py
@@ -2,6 +2,7 @@
from collections import namedtuple
import struct
+from kazoo.exceptions import EXCEPTIONS
from kazoo.protocol.states import ZnodeStat
from kazoo.security import ACL
from kazoo.security import Id
@@ -13,6 +14,7 @@ int_int_struct = struct.Struct('!ii')
int_int_long_struct = struct.Struct('!iiq')
int_long_int_long_struct = struct.Struct('!iqiq')
+multiheader_struct = struct.Struct('!iBi')
reply_header_struct = struct.Struct('!iqi')
stat_struct = struct.Struct('!qqqqiiiqiiq')
@@ -287,6 +289,62 @@ class GetChildren2(namedtuple('GetChildren2', 'path watcher')):
return children, stat
+class CheckVersion(namedtuple('CheckVersion', 'path version')):
+ type = 13
+
+ def serialize(self):
+ b = bytearray()
+ b.extend(write_string(self.path))
+ b.extend(int_struct.pack(self.version))
+ return b
+
+
+class Transaction(namedtuple('Transaction', 'operations')):
+ type = 14
+
+ def serialize(self):
+ b = bytearray()
+ for op in self.operations:
+ b.extend(MultiHeader(op.type, False, -1).serialize() +
+ op.serialize())
+ return b + multiheader_struct.pack(-1, True, -1)
+
+ @classmethod
+ def deserialize(cls, bytes, offset):
+ header = MultiHeader(None, False, None)
+ results = []
+ response = None
+ while not header.done:
+ if header.type == Create.type:
+ response, offset = read_string(bytes, offset)
+ elif header.type == Delete.type:
+ response = True
+ elif header.type == SetData.type:
+ response = ZnodeStat._make(
+ stat_struct.unpack_from(bytes, offset))
+ offset += stat_struct.size
+ elif header.type == CheckVersion.type:
+ response = True
+ elif header.type == -1:
+ err = int_struct.unpack_from(bytes, offset)[0]
+ offset += int_struct.size
+ response = EXCEPTIONS[err]()
+ if response:
+ results.append(response)
+ header, offset = MultiHeader.deserialize(bytes, offset)
+ return results
+
+ @staticmethod
+ def unchroot(client, response):
+ resp = []
+ for result in response:
+ if isinstance(result, unicode):
+ resp.append(client.unchroot(result))
+ else:
+ resp.append(result)
+ return resp
+
+
class Auth(namedtuple('Auth', 'auth_type scheme auth')):
type = 100
@@ -297,20 +355,35 @@ class Auth(namedtuple('Auth', 'auth_type scheme auth')):
class Watch(namedtuple('Watch', 'type state path')):
@classmethod
- def deserialize(cls, buffer, offset):
- """Given a buffer and the current buffer offset, return the
+ def deserialize(cls, bytes, offset):
+ """Given bytes and the current bytes offset, return the
type, state, path, and new offset"""
- type, state = int_int_struct.unpack_from(buffer, offset)
+ type, state = int_int_struct.unpack_from(bytes, offset)
offset += int_int_struct.size
- path, offset = read_string(buffer, offset)
+ path, offset = read_string(bytes, offset)
return cls(type, state, path), offset
class ReplyHeader(namedtuple('ReplyHeader', 'xid, zxid, err')):
@classmethod
- def deserialize(cls, buffer, offset):
- """Given a buffer and the current buffer offset, return a
+ def deserialize(cls, bytes, offset):
+ """Given bytes and the current bytes offset, return a
:class:`ReplyHeader` instance and the new offset"""
new_offset = offset + reply_header_struct.size
return cls._make(
- reply_header_struct.unpack_from(buffer, offset)), new_offset
+ reply_header_struct.unpack_from(bytes, offset)), new_offset
+
+
+class MultiHeader(namedtuple('MultiHeader', 'type done err')):
+ def serialize(self):
+ b = bytearray()
+ b.extend(int_struct.pack(self.type))
+ b.extend([1 if self.done else 0])
+ b.extend(int_struct.pack(self.err))
+ return b
+
+ @classmethod
+ def deserialize(cls, bytes, offset):
+ t, done, err = multiheader_struct.unpack_from(bytes, offset)
+ offset += multiheader_struct.size
+ return cls(t, done is 1, err), offset
diff --git a/kazoo/tests/test_client.py b/kazoo/tests/test_client.py
index 48fa295..e165daf 100644
--- a/kazoo/tests/test_client.py
+++ b/kazoo/tests/test_client.py
@@ -617,6 +617,11 @@ class TestClient(KazooTestCase):
client._safe_close()
testit()
+ def test_client_state(self):
+ from kazoo.protocol.states import KeeperState
+ eq_(self.client.client_state, KeeperState.CONNECTED)
+
+
dummy_dict = {
'aversion': 1, 'ctime': 0, 'cversion': 1,
'czxid': 110, 'dataLength': 1, 'ephemeralOwner': 'ben',
@@ -624,6 +629,54 @@ dummy_dict = {
}
+class TestClientTransactions(KazooTestCase):
+ def test_basic_create(self):
+ t = self.client.transaction()
+ t.create('/freddy')
+ t.create('/fred', ephemeral=True)
+ t.create('/smith', sequence=True)
+ results = t.commit()
+ eq_(results[0], '/freddy')
+ eq_(len(results), 3)
+ self.assertTrue(results[2].startswith('/smith0'))
+
+ def test_bad_creates(self):
+ args_list = [(True,), ('/smith', 0), ('/smith', '', 'bleh'),
+ ('/smith', '', None, 'fred'),
+ ('/smith', '', None, True, 'fred')]
+
+ @raises(TypeError)
+ def testit(args):
+ t = self.client.transaction()
+ t.create(*args)
+
+ for args in args_list:
+ testit(args)
+
+ def test_default_acl(self):
+ from kazoo.security import make_digest_acl
+ username = uuid.uuid4().hex
+ password = uuid.uuid4().hex
+
+ digest_auth = "%s:%s" % (username, password)
+ acl = make_digest_acl(username, password, all=True)
+
+ self.client.add_auth("digest", digest_auth)
+ self.client.default_acl = (acl,)
+
+ t = self.client.transaction()
+ t.create('/freddy')
+ results = t.commit()
+ eq_(results[0], '/freddy')
+
+ def test_basic_delete(self):
+ self.client.create('/fred')
+ t = self.client.transaction()
+ t.delete('/fred')
+ results = t.commit()
+ eq_(results[0], True)
+
+
class TestCallbacks(unittest.TestCase):
def test_session_callback_states(self):
from kazoo.protocol.states import KazooState, KeeperState