diff options
author | Adam Wentz <Adam Wentz> | 2012-12-11 00:47:32 -0600 |
---|---|---|
committer | Ask Solem <ask@celeryproject.org> | 2013-01-11 14:20:45 +0000 |
commit | a792d13efd61183eef21206022acb6250d880528 (patch) | |
tree | 25199142a38641148dd19398f342f6293ada0908 | |
parent | 3e9fdc5b386011c70e47d934ebaaf9d533ff2805 (diff) | |
download | py-amqp-a792d13efd61183eef21206022acb6250d880528.tar.gz |
Add array type to protocol
This adds the array type. It was adapted from a patch on the original
py-amqp google code repo
http://code.google.com/p/py-amqplib/issues/detail?id=38
-rw-r--r-- | amqp/serialization.py | 131 | ||||
-rwxr-xr-x | funtests/test_serialization.py | 42 |
2 files changed, 124 insertions, 49 deletions
diff --git a/amqp/serialization.py b/amqp/serialization.py index c9da3f3..a9298d0 100644 --- a/amqp/serialization.py +++ b/amqp/serialization.py @@ -139,28 +139,45 @@ class AMQPReader(object): result = {} while table_data.input.tell() < tlen: name = table_data.read_shortstr() - ftype = ord(table_data.input.read(1)) - if ftype == 83: # 'S' - val = table_data.read_longstr() - elif ftype == 73: # 'I' - val = unpack('>i', table_data.input.read(4))[0] - elif ftype == 68: # 'D' - d = table_data.read_octet() - n = unpack('>i', table_data.input.read(4))[0] - val = Decimal(n) / Decimal(10 ** d) - elif ftype == 84: # 'T' - val = table_data.read_timestamp() - elif ftype == 70: # 'F' - val = table_data.read_table() # recurse - elif ftype == 116: - val = table_data.read_bit() - elif ftype == 100: - val = table_data.read_float() - else: - raise ValueError('Unknown table item type: %s' % repr(ftype)) + val = table_data.read_item() result[name] = val return result + def read_item(self): + ftype = ord(self.input.read(1)) + if ftype == 83: # 'S' + val = self.read_longstr() + elif ftype == 73: # 'I' + val = unpack('>i', self.input.read(4))[0] + elif ftype == 68: # 'D' + d = self.read_octet() + n = unpack('>i', self.input.read(4))[0] + val = Decimal(n) / Decimal(10 ** d) + elif ftype == 84: # 'T' + val = self.read_timestamp() + elif ftype == 70: # 'F' + val = self.read_table() # recurse + elif ftype == 65: # 'A' + val = self.read_array() + elif ftype == 116: + val = self.read_bit() + elif ftype == 100: + val = self.read_float() + else: + raise ValueError( + 'Unknown value in table: %r (%r)' % ( + ftype, type(ftype))) + return val + + def read_array(self): + array_length = unpack('>I', self.input.read(4))[0] + array_data = AMQPReader(self.input.read(array_length)) + result = [] + while array_data.input.tell() < array_length: + val = array_data.read_item() + result.append(val) + return result + def read_timestamp(self): """Read and AMQP timestamp, which is a 64-bit integer representing seconds since the Unix epoch in 1-second resolution. @@ -287,40 +304,56 @@ class AMQPWriter(object): table_data = AMQPWriter() for k, v in d.iteritems(): table_data.write_shortstr(k) - if isinstance(v, basestring): - if isinstance(v, unicode): - v = v.encode('utf-8') - table_data.write(byte(83)) # 'S' - table_data.write_longstr(v) - elif isinstance(v, bool): - table_data.write(pack('>cB', 't', int(v))) - elif isinstance(v, float): - table_data.write(pack('>cd', 'd', v)) - elif isinstance(v, (int, long)): - table_data.write(pack('>ci', 'I', v)) - elif isinstance(v, Decimal): - table_data.write(byte(68)) # 'D' - sign, digits, exponent = v.as_tuple() - v = 0 - for d in digits: - v = (v * 10) + d - if sign: - v = -v - table_data.write_octet(-exponent) - table_data.write(pack('>i', v)) - elif isinstance(v, datetime): - table_data.write(byte(84)) # 'T' - table_data.write_timestamp(v) - ## FIXME: timezone ? - elif isinstance(v, dict): - table_data.write(byte(70)) # 'F' - table_data.write_table(v) - else: - raise ValueError('%r not serializable in AMQP' % (v, )) + table_data.write_item(v) table_data = table_data.getvalue() self.write_long(len(table_data)) self.out.write(table_data) + def write_item(self, v): + if isinstance(v, basestring): + if isinstance(v, unicode): + v = v.encode('utf-8') + self.write(byte(83)) # 'S' + self.write_longstr(v) + elif isinstance(v, bool): + self.write(pack('>cB', b't', int(v))) + elif isinstance(v, float): + self.write(pack('>cd', b'd', v)) + elif isinstance(v, (int, long)): + self.write(pack('>ci', b'I', v)) + elif isinstance(v, Decimal): + self.write(byte(68)) # 'D' + sign, digits, exponent = v.as_tuple() + v = 0 + for d in digits: + v = (v * 10) + d + if sign: + v = -v + self.write_octet(-exponent) + self.write(pack('>i', v)) + elif isinstance(v, datetime): + self.write(byte(84)) # 'T' + self.write_timestamp(v) + ## FIXME: timezone ? + elif isinstance(v, dict): + self.write(byte(70)) # 'F' + self.write_table(v) + elif isinstance(v, (list, tuple)): + self.write(byte(65)) # 'A' + self.write_array(v) + else: + raise ValueError( + 'Table type %r not handled by amqp: %r' % ( + type(v), v)) + + def write_array(self, a): + array_data = AMQPWriter() + for v in a: + array_data.write_item(v) + array_data = array_data.getvalue() + self.write_long(len(array_data)) + self.out.write(array_data) + def write_timestamp(self, v): """Write out a Python datetime.datetime object as a 64-bit integer representing seconds since the Unix epoch.""" diff --git a/funtests/test_serialization.py b/funtests/test_serialization.py index 3656a31..573ab95 100755 --- a/funtests/test_serialization.py +++ b/funtests/test_serialization.py @@ -347,6 +347,48 @@ class TestSerialization(unittest.TestCase): self.assertEqual(r.read_table(), val) # + # Array + # + def test_array_from_list(self): + val = [1, 'foo'] + w = AMQPWriter() + w.write_array(val) + s = w.getvalue() + + self.assertEqualBinary(s, '\x00\x00\x00\x0DI\x00\x00\x00\x01S\x00\x00\x00\x03foo') + + r = AMQPReader(s) + self.assertEqual(r.read_array(), val) + + def test_array_from_tuple(self): + val = (1, 'foo') + w = AMQPWriter() + w.write_array(val) + s = w.getvalue() + + self.assertEqualBinary(s, '\x00\x00\x00\x0DI\x00\x00\x00\x01S\x00\x00\x00\x03foo') + + r = AMQPReader(s) + self.assertEqual(r.read_array(), list(val)) + + def test_table_with_array(self): + val = { + 'foo': 7, + 'bar': Decimal('123345.1234'), + 'baz': 'this is some random string I typed', + 'blist': [1,2,3], + 'nlist': [1, [2,3,4]], + 'ndictl': {'nfoo': 8, 'nblist': [5,6,7] } + } + + w = AMQPWriter() + w.write_table(val) + s = w.getvalue() + + r = AMQPReader(s) + self.assertEqual(r.read_table(), val) + + # # GenericContent # def test_generic_content_eq(self): |