summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAdam Wentz <Adam Wentz>2012-12-11 00:47:32 -0600
committerAsk Solem <ask@celeryproject.org>2013-01-11 14:20:45 +0000
commita792d13efd61183eef21206022acb6250d880528 (patch)
tree25199142a38641148dd19398f342f6293ada0908
parent3e9fdc5b386011c70e47d934ebaaf9d533ff2805 (diff)
downloadpy-amqp-a792d13efd61183eef21206022acb6250d880528.tar.gz
Add array type to protocol
This adds the array type. It was adapted from a patch on the original py-amqp google code repo http://code.google.com/p/py-amqplib/issues/detail?id=38
-rw-r--r--amqp/serialization.py131
-rwxr-xr-xfuntests/test_serialization.py42
2 files changed, 124 insertions, 49 deletions
diff --git a/amqp/serialization.py b/amqp/serialization.py
index c9da3f3..a9298d0 100644
--- a/amqp/serialization.py
+++ b/amqp/serialization.py
@@ -139,28 +139,45 @@ class AMQPReader(object):
result = {}
while table_data.input.tell() < tlen:
name = table_data.read_shortstr()
- ftype = ord(table_data.input.read(1))
- if ftype == 83: # 'S'
- val = table_data.read_longstr()
- elif ftype == 73: # 'I'
- val = unpack('>i', table_data.input.read(4))[0]
- elif ftype == 68: # 'D'
- d = table_data.read_octet()
- n = unpack('>i', table_data.input.read(4))[0]
- val = Decimal(n) / Decimal(10 ** d)
- elif ftype == 84: # 'T'
- val = table_data.read_timestamp()
- elif ftype == 70: # 'F'
- val = table_data.read_table() # recurse
- elif ftype == 116:
- val = table_data.read_bit()
- elif ftype == 100:
- val = table_data.read_float()
- else:
- raise ValueError('Unknown table item type: %s' % repr(ftype))
+ val = table_data.read_item()
result[name] = val
return result
+ def read_item(self):
+ ftype = ord(self.input.read(1))
+ if ftype == 83: # 'S'
+ val = self.read_longstr()
+ elif ftype == 73: # 'I'
+ val = unpack('>i', self.input.read(4))[0]
+ elif ftype == 68: # 'D'
+ d = self.read_octet()
+ n = unpack('>i', self.input.read(4))[0]
+ val = Decimal(n) / Decimal(10 ** d)
+ elif ftype == 84: # 'T'
+ val = self.read_timestamp()
+ elif ftype == 70: # 'F'
+ val = self.read_table() # recurse
+ elif ftype == 65: # 'A'
+ val = self.read_array()
+ elif ftype == 116:
+ val = self.read_bit()
+ elif ftype == 100:
+ val = self.read_float()
+ else:
+ raise ValueError(
+ 'Unknown value in table: %r (%r)' % (
+ ftype, type(ftype)))
+ return val
+
+ def read_array(self):
+ array_length = unpack('>I', self.input.read(4))[0]
+ array_data = AMQPReader(self.input.read(array_length))
+ result = []
+ while array_data.input.tell() < array_length:
+ val = array_data.read_item()
+ result.append(val)
+ return result
+
def read_timestamp(self):
"""Read and AMQP timestamp, which is a 64-bit integer representing
seconds since the Unix epoch in 1-second resolution.
@@ -287,40 +304,56 @@ class AMQPWriter(object):
table_data = AMQPWriter()
for k, v in d.iteritems():
table_data.write_shortstr(k)
- if isinstance(v, basestring):
- if isinstance(v, unicode):
- v = v.encode('utf-8')
- table_data.write(byte(83)) # 'S'
- table_data.write_longstr(v)
- elif isinstance(v, bool):
- table_data.write(pack('>cB', 't', int(v)))
- elif isinstance(v, float):
- table_data.write(pack('>cd', 'd', v))
- elif isinstance(v, (int, long)):
- table_data.write(pack('>ci', 'I', v))
- elif isinstance(v, Decimal):
- table_data.write(byte(68)) # 'D'
- sign, digits, exponent = v.as_tuple()
- v = 0
- for d in digits:
- v = (v * 10) + d
- if sign:
- v = -v
- table_data.write_octet(-exponent)
- table_data.write(pack('>i', v))
- elif isinstance(v, datetime):
- table_data.write(byte(84)) # 'T'
- table_data.write_timestamp(v)
- ## FIXME: timezone ?
- elif isinstance(v, dict):
- table_data.write(byte(70)) # 'F'
- table_data.write_table(v)
- else:
- raise ValueError('%r not serializable in AMQP' % (v, ))
+ table_data.write_item(v)
table_data = table_data.getvalue()
self.write_long(len(table_data))
self.out.write(table_data)
+ def write_item(self, v):
+ if isinstance(v, basestring):
+ if isinstance(v, unicode):
+ v = v.encode('utf-8')
+ self.write(byte(83)) # 'S'
+ self.write_longstr(v)
+ elif isinstance(v, bool):
+ self.write(pack('>cB', b't', int(v)))
+ elif isinstance(v, float):
+ self.write(pack('>cd', b'd', v))
+ elif isinstance(v, (int, long)):
+ self.write(pack('>ci', b'I', v))
+ elif isinstance(v, Decimal):
+ self.write(byte(68)) # 'D'
+ sign, digits, exponent = v.as_tuple()
+ v = 0
+ for d in digits:
+ v = (v * 10) + d
+ if sign:
+ v = -v
+ self.write_octet(-exponent)
+ self.write(pack('>i', v))
+ elif isinstance(v, datetime):
+ self.write(byte(84)) # 'T'
+ self.write_timestamp(v)
+ ## FIXME: timezone ?
+ elif isinstance(v, dict):
+ self.write(byte(70)) # 'F'
+ self.write_table(v)
+ elif isinstance(v, (list, tuple)):
+ self.write(byte(65)) # 'A'
+ self.write_array(v)
+ else:
+ raise ValueError(
+ 'Table type %r not handled by amqp: %r' % (
+ type(v), v))
+
+ def write_array(self, a):
+ array_data = AMQPWriter()
+ for v in a:
+ array_data.write_item(v)
+ array_data = array_data.getvalue()
+ self.write_long(len(array_data))
+ self.out.write(array_data)
+
def write_timestamp(self, v):
"""Write out a Python datetime.datetime object as a 64-bit integer
representing seconds since the Unix epoch."""
diff --git a/funtests/test_serialization.py b/funtests/test_serialization.py
index 3656a31..573ab95 100755
--- a/funtests/test_serialization.py
+++ b/funtests/test_serialization.py
@@ -347,6 +347,48 @@ class TestSerialization(unittest.TestCase):
self.assertEqual(r.read_table(), val)
#
+ # Array
+ #
+ def test_array_from_list(self):
+ val = [1, 'foo']
+ w = AMQPWriter()
+ w.write_array(val)
+ s = w.getvalue()
+
+ self.assertEqualBinary(s, '\x00\x00\x00\x0DI\x00\x00\x00\x01S\x00\x00\x00\x03foo')
+
+ r = AMQPReader(s)
+ self.assertEqual(r.read_array(), val)
+
+ def test_array_from_tuple(self):
+ val = (1, 'foo')
+ w = AMQPWriter()
+ w.write_array(val)
+ s = w.getvalue()
+
+ self.assertEqualBinary(s, '\x00\x00\x00\x0DI\x00\x00\x00\x01S\x00\x00\x00\x03foo')
+
+ r = AMQPReader(s)
+ self.assertEqual(r.read_array(), list(val))
+
+ def test_table_with_array(self):
+ val = {
+ 'foo': 7,
+ 'bar': Decimal('123345.1234'),
+ 'baz': 'this is some random string I typed',
+ 'blist': [1,2,3],
+ 'nlist': [1, [2,3,4]],
+ 'ndictl': {'nfoo': 8, 'nblist': [5,6,7] }
+ }
+
+ w = AMQPWriter()
+ w.write_table(val)
+ s = w.getvalue()
+
+ r = AMQPReader(s)
+ self.assertEqual(r.read_table(), val)
+
+ #
# GenericContent
#
def test_generic_content_eq(self):