diff options
author | Ask Solem <ask@celeryproject.org> | 2014-06-04 18:50:02 +0100 |
---|---|---|
committer | Ask Solem <ask@celeryproject.org> | 2014-06-04 18:50:02 +0100 |
commit | f075ae618c19e3c0fb155b3397c13d1c297ebe56 (patch) | |
tree | bc07c56dd1147149e7bf1312aa90cc9b56dc55dd | |
parent | b8b3def2f5f6c6bf5cae30d44a90c7998b3a3654 (diff) | |
download | py-amqp-f075ae618c19e3c0fb155b3397c13d1c297ebe56.tar.gz |
Refactors AMQPReader
-rw-r--r-- | amqp/abstract_channel.py | 36 | ||||
-rw-r--r-- | amqp/channel.py | 144 | ||||
-rw-r--r-- | amqp/connection.py | 120 | ||||
-rw-r--r-- | amqp/method_framing.py | 13 | ||||
-rw-r--r-- | amqp/serialization.py | 214 |
5 files changed, 356 insertions, 171 deletions
diff --git a/amqp/abstract_channel.py b/amqp/abstract_channel.py index dfc1ef8..bce82ff 100644 --- a/amqp/abstract_channel.py +++ b/amqp/abstract_channel.py @@ -18,11 +18,21 @@ from __future__ import absolute_import from .exceptions import AMQPNotImplementedError, RecoverableConnectionError from .promise import promise -from .serialization import dumps +from .serialization import dumps, loads __all__ = ['AbstractChannel'] +def inbound(method, spec=None, content=False): + + def _inner(fun): + fun.__amqp_method__ = method + fun.__amqp_argspec__ = spec + fun.__amqp_content__ = content + return fun + return _inner + + class AbstractChannel(object): """Superclass for both the Connection, which is treated as channel 0, and other user-created Channel objects. @@ -79,12 +89,12 @@ class AbstractChannel(object): def wait(self, allowed_methods=None): """Wait for a method that matches our allowed_methods parameter (the default value of None means match any method), and dispatch to it.""" - method_sig, args, content = self.connection._wait_method( + method_sig, payload, content = self.connection._wait_method( self.channel_id, allowed_methods) - return self.dispatch_method(method_sig, args, content) + return self.dispatch_method(method_sig, payload, content) - def dispatch_method(self, method_sig, args, content): + def dispatch_method(self, method_sig, payload, content): if content and \ self.auto_decode and \ hasattr(content, 'content_encoding'): @@ -99,10 +109,22 @@ class AbstractChannel(object): raise AMQPNotImplementedError( 'Unknown AMQP method {0!r}'.format(method_sig)) - if content is None: - return amqp_method(self, args) + try: + expects_content = amqp_method.__amqp_content__ + except AttributeError: + expects_content = False + + try: + argspec = amqp_method.__amqp_argspec__ + except AttributeError: + args = [] + else: + args = loads(argspec, payload, 4) if argspec else [] + + if expects_content: + return amqp_method(self, *args + [content]) else: - return amqp_method(self, args, content) + return amqp_method(self, *args) #: Placeholder, the concrete implementations will have to #: supply their own versions of _METHOD_MAP diff --git a/amqp/channel.py b/amqp/channel.py index 2ebb3fa..4195d87 100644 --- a/amqp/channel.py +++ b/amqp/channel.py @@ -21,10 +21,10 @@ import logging from collections import defaultdict from warnings import warn -from .abstract_channel import AbstractChannel +from .abstract_channel import AbstractChannel, inbound from .exceptions import ChannelError, ConsumerCancelled, error_for_code from .five import Queue -from .protocol import basic_return_t, queue_declare_ok_t +from .protocol import queue_declare_ok_t __all__ = ['Channel'] @@ -226,7 +226,7 @@ class Channel(AbstractChannel): finally: self.connection = None - def _close(self, args): + def _close(self, reply_code, reply_text, class_id, method_id): """Request a channel close This method indicates that the sender wants to close the @@ -274,11 +274,6 @@ class Channel(AbstractChannel): """ - reply_code = args.read_short() - reply_text = args.read_shortstr() - class_id = args.read_short() - method_id = args.read_short() - self._send_method(Channel_CloseOk) self._do_revive() @@ -286,7 +281,8 @@ class Channel(AbstractChannel): reply_code, reply_text, (class_id, method_id), ChannelError, ) - def _close_ok(self, args): + @inbound(Channel_CloseOk) + def _close_ok(self): """Confirm a channel close This method confirms a Channel.Close method and tells the @@ -352,7 +348,8 @@ class Channel(AbstractChannel): Channel_Flow, 'b', (active, ), wait=[Channel_FlowOk], ) - def _flow(self, args): + @inbound(Channel_Flow, 'b') + def _flow(self, active): """Enable/disable flow from peer This method asks the peer to pause or restart the flow of @@ -398,7 +395,7 @@ class Channel(AbstractChannel): False, the peer stops sending content frames. """ - self.active = args.read_bit() + self.active = active self._x_flow_ok(self.active) def _x_flow_ok(self, active): @@ -419,7 +416,8 @@ class Channel(AbstractChannel): """ return self.send_method(Channel_FlowOk, 'b', (active, )) - def _flow_ok(self, args): + @inbound(Channel_FlowOk, 'b') + def _flow_ok(self, active): """Confirm a flow method Confirms to the peer that a flow command was received and @@ -435,7 +433,7 @@ class Channel(AbstractChannel): to send content frames; False means it will not. """ - return args.read_bit() + return active def _x_open(self): """Open a channel for use @@ -464,7 +462,8 @@ class Channel(AbstractChannel): Channel_Open, 's', ('', ), wait=[Channel_OpenOk], ) - def _open_ok(self, args): + @inbound(Channel_OpenOk) + def _open_ok(self): """Signal that the channel is ready This method signals to the client that the channel is ready @@ -647,7 +646,8 @@ class Channel(AbstractChannel): wait=None if nowait else [Exchange_DeclareOk], ) - def _exchange_declare_ok(self, args): + @inbound(Exchange_DeclareOk) + def _exchange_declare_ok(self): """Confirms an exchange declaration This method confirms a Declare method and confirms the name of @@ -705,7 +705,8 @@ class Channel(AbstractChannel): wait=None if nowait else [Exchange_DeleteOk], ) - def _exchange_delete_ok(self, args): + @inbound(Exchange_DeleteOk) + def _exchange_delete_ok(self): """Confirm deletion of an exchange This method confirms the deletion of an exchange. @@ -850,7 +851,8 @@ class Channel(AbstractChannel): wait=None if nowait else [Exchange_UnbindOk], ) - def _exchange_bind_ok(self, args): + @inbound(Exchange_BindOk) + def _exchange_bind_ok(self): """Confirm bind successful This method confirms that the bind was successful. @@ -858,7 +860,8 @@ class Channel(AbstractChannel): """ pass - def _exchange_unbind_ok(self, args): + @inbound(Exchange_UnbindOk) + def _exchange_unbind_ok(self): """Confirm unbind successful This method confirms that the unbind was successful. @@ -997,7 +1000,8 @@ class Channel(AbstractChannel): wait=None if nowait else [Queue_BindOk], ) - def _queue_bind_ok(self, args): + @inbound(Queue_BindOk) + def _queue_bind_ok(self): """Confirm bind successful This method confirms that the bind was successful. @@ -1063,7 +1067,8 @@ class Channel(AbstractChannel): wait=None if nowait else [Queue_UnbindOk], ) - def _queue_unbind_ok(self, args): + @inbound(Queue_UnbindOk) + def _queue_unbind_ok(self): """Confirm unbind successful This method confirms that the unbind was successful. @@ -1233,7 +1238,8 @@ class Channel(AbstractChannel): wait=None if nowait else [Queue_DeclareOk], ) - def _queue_declare_ok(self, args): + @inbound(Queue_DeclareOk, 'sll') + def _queue_declare_ok(self, queue, message_count, consumer_count): """Confirms a queue definition This method confirms a Declare method and confirms the name of @@ -1262,11 +1268,7 @@ class Channel(AbstractChannel): this count. """ - return queue_declare_ok_t( - args.read_shortstr(), - args.read_long(), - args.read_long(), - ) + return queue_declare_ok_t(queue, message_count, consumer_count) def queue_delete(self, queue='', if_unused=False, if_empty=False, nowait=False, @@ -1342,7 +1344,8 @@ class Channel(AbstractChannel): wait=None if nowait else [Queue_DeleteOk], ) - def _queue_delete_ok(self, args): + @inbound(Queue_DeleteOk, 'l') + def _queue_delete_ok(self, message_count): """Confirm deletion of a queue This method confirms the deletion of a queue. @@ -1355,7 +1358,7 @@ class Channel(AbstractChannel): Reports the number of messages purged. """ - return args.read_long() + return message_count def queue_purge(self, queue='', nowait=False, argsig='Bsb'): """Purge a queue @@ -1418,7 +1421,8 @@ class Channel(AbstractChannel): wait=None if nowait else [Queue_PurgeOk], ) - def _queue_purge_ok(self, args): + @inbound(Queue_PurgeOk, 'l') + def _queue_purge_ok(self, message_count): """Confirms a queue purge This method confirms the purge of a queue. @@ -1431,7 +1435,7 @@ class Channel(AbstractChannel): Reports the number of messages purged. """ - return args.read_long() + return message_count ############# # @@ -1589,20 +1593,21 @@ class Channel(AbstractChannel): wait=None if nowait else [Basic_CancelOk], ) - def _basic_cancel_notify(self, args): + @inbound(Basic_Cancel, 's') + def _basic_cancel_notify(self, consumer_tag): """Consumer cancelled by server. Most likely the queue was deleted. """ - consumer_tag = args.read_shortstr() callback = self._on_cancel(consumer_tag) if callback: callback(consumer_tag) else: raise ConsumerCancelled(consumer_tag, Basic_Cancel) - def _basic_cancel_ok(self, args): + @inbound(Basic_CancelOk, 's') + def _basic_cancel_ok(self, consumer_tag): """Confirm a cancelled consumer This method confirms that the cancellation was completed. @@ -1623,7 +1628,6 @@ class Channel(AbstractChannel): use it in another. """ - consumer_tag = args.read_shortstr() self._on_cancel(consumer_tag) def _on_cancel(self, consumer_tag): @@ -1749,7 +1753,8 @@ class Channel(AbstractChannel): self.no_ack_consumers.add(consumer_tag) return p - def _basic_consume_ok(self, args): + @inbound(Basic_ConsumeOk, 's') + def _basic_consume_ok(self, consumer_tag): """Confirm a new consumer The server provides the client with a consumer tag, which is @@ -1763,9 +1768,11 @@ class Channel(AbstractChannel): provided by the server. """ - return args.read_shortstr() + return consumer_tag - def _basic_deliver(self, args, msg): + @inbound(Basic_Deliver, 'sLbss', content=True) + def _basic_deliver(self, consumer_tag, delivery_tag, redelivered, + exchange, routing_key, msg): """Notify the client of a consumer message This method delivers a message to the client, via a consumer. @@ -1838,12 +1845,6 @@ class Channel(AbstractChannel): message was published. """ - consumer_tag = args.read_shortstr() - delivery_tag = args.read_longlong() - redelivered = args.read_bit() - exchange = args.read_shortstr() - routing_key = args.read_shortstr() - msg.channel = self msg.delivery_info = { 'consumer_tag': consumer_tag, @@ -1902,7 +1903,8 @@ class Channel(AbstractChannel): wait=[Basic_GetOk, Basic_GetEmpty], ) - def _basic_get_empty(self, args): + @inbound(Basic_GetEmpty, 's') + def _basic_get_empty(self, cluster_id): """Indicate no messages available This method tells the client that the queue has no messages @@ -1917,9 +1919,11 @@ class Channel(AbstractChannel): client applications. """ - cluster_id = args.read_shortstr() # noqa + pass - def _basic_get_ok(self, args, msg): + @inbound(Basic_GetOk, 'Lbssl', content=True) + def _basic_get_ok(self, delivery_tag, redelivered, exchange, routing_key, + message_count, msg): """Provide client with a message This method delivers a message to the client following a get @@ -1977,12 +1981,6 @@ class Channel(AbstractChannel): queue and removed by other clients. """ - delivery_tag = args.read_longlong() - redelivered = args.read_bit() - exchange = args.read_shortstr() - routing_key = args.read_shortstr() - message_count = args.read_long() - msg.channel = self msg.delivery_info = { 'delivery_tag': delivery_tag, @@ -2144,7 +2142,8 @@ class Channel(AbstractChannel): wait=[Basic_QosOk], ) - def _basic_qos_ok(self, args): + @inbound(Basic_QosOk) + def _basic_qos_ok(self): """Confirm the requested qos This method tells the client that the requested QoS levels @@ -2189,7 +2188,8 @@ class Channel(AbstractChannel): def basic_recover_async(self, requeue=False): return self.send_method(Basic_RecoverAsync, 'b', (requeue, )) - def _basic_recover_ok(self, args): + @inbound(Basic_RecoverOk) + def _basic_recover_ok(self): """In 0-9-1 the deprecated recover solicits a response.""" pass @@ -2265,7 +2265,9 @@ class Channel(AbstractChannel): """ return self.send_method(Basic_Reject, argsig, (delivery_tag, requeue)) - def _basic_return(self, args, msg): + @inbound(Basic_Return, 'Bsss', content=True) + def _basic_return(self, reply_code, reply_text, + exchange, routing_key, message): """Return a failed message This method returns an undeliverable message that was @@ -2298,13 +2300,14 @@ class Channel(AbstractChannel): message was published. """ - self.returned_messages.put(basic_return_t( - args.read_short(), - args.read_shortstr(), - args.read_shortstr(), - args.read_shortstr(), - msg, - )) + exc = error_for_code( + reply_code, reply_text, Basic_Return, ChannelError, + ) + handlers = self.events.get('basic_return') + if not handlers: + raise exc + for callback in handlers: + callback(exc, exchange, routing_key, message) ############# # @@ -2344,7 +2347,8 @@ class Channel(AbstractChannel): """ return self.send_method(Tx_Commit, wait=[Tx_CommitOk]) - def _tx_commit_ok(self, args): + @inbound(Tx_CommitOk) + def _tx_commit_ok(self): """Confirm a successful commit This method confirms to the client that the commit succeeded. @@ -2364,7 +2368,8 @@ class Channel(AbstractChannel): """ return self.send_method(Tx_Rollback, wait=[Tx_RollbackOk]) - def _tx_rollback_ok(self, args): + @inbound(Tx_RollbackOk) + def _tx_rollback_ok(self): """Confirm a successful rollback This method confirms to the client that the rollback @@ -2384,7 +2389,8 @@ class Channel(AbstractChannel): """ return self.send_method(Tx_Select, wait=[Tx_SelectOk]) - def _tx_select_ok(self, args): + @inbound(Tx_SelectOk) + def _tx_select_ok(self): """Confirm transaction mode This method confirms to the client that the channel was @@ -2411,14 +2417,14 @@ class Channel(AbstractChannel): wait=None if nowait else [Confirm_SelectOk], ) - def _confirm_select_ok(self, args): + @inbound(Confirm_SelectOk) + def _confirm_select_ok(self): """With this method the broker confirms to the client that the channel is now using publisher confirms.""" pass - def _basic_ack_recv(self, args): - delivery_tag = args.read_longlong() - multiple = args.read_bit() + @inbound(Basic_Ack, 'Lb') + def _basic_ack_recv(self, delivery_tag, multiple): for callback in self.events['basic_ack']: callback(delivery_tag, multiple) diff --git a/amqp/connection.py b/amqp/connection.py index 982df8e..0c9cef3 100644 --- a/amqp/connection.py +++ b/amqp/connection.py @@ -28,10 +28,10 @@ except ImportError: pass from . import __version__ -from .abstract_channel import AbstractChannel +from .abstract_channel import AbstractChannel, inbound from .channel import Channel from .exceptions import ( - AMQPNotImplementedError, ChannelError, ResourceError, + ChannelError, ResourceError, ConnectionForced, ConnectionError, error_for_code, RecoverableConnectionError, RecoverableChannelError, ) @@ -146,7 +146,6 @@ class Connection(AbstractChannel): login_response = login_response.getvalue()[4:] d = dict(LIBRARY_PROPERTIES, **client_properties or {}) - self._method_override = {(60, 50): self._dispatch_basic_return} self.channels = {} # The connection object itself is treated as channel 0 @@ -248,14 +247,14 @@ class Connection(AbstractChannel): # Nothing queued, need to wait for a method from the peer # while 1: - channel, method_sig, args, content = \ + channel, method_sig, payload, content = \ self.method_reader.read_method() if channel == channel_id and ( allowed_methods is None or method_sig in allowed_methods or method_sig == (20, 40)): - return method_sig, args, content + return method_sig, payload, content # # Certain methods like basic_return should be dispatched @@ -264,7 +263,7 @@ class Connection(AbstractChannel): # if channel and method_sig in self.Channel._IMMEDIATE_METHODS: self.channels[channel].dispatch_method( - method_sig, args, content, + method_sig, payload, content, ) continue @@ -273,7 +272,7 @@ class Connection(AbstractChannel): # this method for later # self.channels[channel].method_queue.append( - (method_sig, args, content), + (method_sig, payload, content), ) # @@ -298,31 +297,12 @@ class Connection(AbstractChannel): def drain_events(self, timeout=None): """Wait for an event on a channel.""" chanmap = self.channels - chanid, method_sig, args, content = self._wait_multiple( + chanid, method_sig, payload, content = self._wait_multiple( chanmap, None, timeout=timeout, ) channel = chanmap[chanid] - - if (content and - channel.auto_decode and - hasattr(content, 'content_encoding')): - try: - content.body = content.body.decode(content.content_encoding) - except Exception: - pass - - amqp_method = (self._method_override.get(method_sig) or - channel._METHOD_MAP.get(method_sig, None)) - - if amqp_method is None: - raise AMQPNotImplementedError( - 'Unknown AMQP method {0!r}'.format(method_sig)) - - if content is None: - return amqp_method(channel, args) - else: - return amqp_method(channel, args, content) + return channel.dispatch_method(method_sig, payload, content) def read_timeout(self, timeout=None): if timeout is None: @@ -355,24 +335,26 @@ class Connection(AbstractChannel): method_sig in allowed_methods or method_sig == (20, 40)): method_queue.remove(queued_method) - method_sig, args, content = queued_method - return channel_id, method_sig, args, content + method_sig, payload, content = queued_method + return channel_id, method_sig, payload, content # Nothing queued, need to wait for a method from the peer read_timeout = self.read_timeout wait = self.wait while 1: - channel, method_sig, args, content = read_timeout(timeout) + channel, method_sig, payload, content = read_timeout(timeout) if channel in channels and ( allowed_methods is None or method_sig in allowed_methods or method_sig == (20, 40)): - return channel, method_sig, args, content + return channel, method_sig, payload, content # Not the channel and/or method we were looking for. Queue # this method for later - channels[channel].method_queue.append((method_sig, args, content)) + channels[channel].method_queue.append( + (method_sig, payload, content), + ) # # If we just queued up a method for channel 0 (the Connection @@ -382,19 +364,6 @@ class Connection(AbstractChannel): if channel == 0: wait() - def _dispatch_basic_return(self, channel, args, msg): - reply_code = args.read_short() - reply_text = args.read_shortstr() - exchange = args.read_shortstr() - routing_key = args.read_shortstr() - - exc = error_for_code(reply_code, reply_text, (50, 60), ChannelError) - handlers = channel.events.get('basic_return') - if not handlers: - raise exc - for callback in handlers: - callback(exc, exchange, routing_key, msg) - def close(self, reply_code=0, reply_text='', method_sig=(0, 0), argsig='BssBB'): @@ -461,7 +430,8 @@ class Connection(AbstractChannel): wait=[Connection_Close, Connection_CloseOk], ) - def _close(self, args): + @inbound(Connection_Close, 'BsBB') + def _close(self, reply_code, reply_text, class_id, method_id): """Request a connection close This method indicates that the sender wants to close the @@ -515,30 +485,19 @@ class Connection(AbstractChannel): is the ID of the method. """ - reply_code = args.read_short() - reply_text = args.read_shortstr() - class_id = args.read_short() - method_id = args.read_short() - self._x_close_ok() - raise error_for_code(reply_code, reply_text, (class_id, method_id), ConnectionError) - def _blocked(self, args): + @inbound(Connection_Blocked) + def _blocked(self): """RabbitMQ Extension.""" - try: - reason = args.read_shortstr() - except UnicodeDecodeError: - # XXX Spec say this is a shortstr, but amqplib seems to - # except strings to be in utf-8, even though the spec does - # not dictate any special encoding. - # (see amqp.serialization:AMQPReader.read_shortstr) - reason = 'connection blocked, see broker logs' + reason = 'connection blocked, see broker logs' if self.on_blocked: return self.on_blocked(reason) - def _unblocked(self, *args): + @inbound(Connection_Unblocked) + def _unblocked(self): if self.on_unblocked: return self.on_unblocked() @@ -558,7 +517,8 @@ class Connection(AbstractChannel): self._send_method(Connection_CloseOk) self._do_close() - def _close_ok(self, args): + @inbound(Connection_CloseOk) + def _close_ok(self): """Confirm a connection close This method confirms a Connection.Close method and tells the @@ -627,7 +587,8 @@ class Connection(AbstractChannel): wait=[Connection_OpenOk], ) - def _open_ok(self, args): + @inbound(Connection_OpenOk) + def _open_ok(self): """Signal that the connection is ready This method signals to the client that the connection is ready @@ -639,7 +600,8 @@ class Connection(AbstractChannel): """ AMQP_LOGGER.debug('Open OK!') - def _secure(self, args): + @inbound(Connection_Secure, 's') + def _secure(self, challenge): """Security mechanism challenge The SASL protocol works by exchanging challenges and responses @@ -656,7 +618,7 @@ class Connection(AbstractChannel): passed to the security mechanism. """ - challenge = args.read_longstr() # noqa + pass def _x_secure_ok(self, response): """Security mechanism response @@ -676,7 +638,9 @@ class Connection(AbstractChannel): """ return self.send_method(Connection_SecureOk, 'S', (response, )) - def _start(self, args): + @inbound(Connection_Start, 'ooFSS') + def _start(self, version_major, version_minor, server_properties, + mechanisms, locales): """Start connection negotiation This method starts the connection negotiation process by @@ -740,12 +704,11 @@ class Connection(AbstractChannel): locale. """ - self.version_major = args.read_octet() - self.version_minor = args.read_octet() - self.server_properties = args.read_table() - self.mechanisms = args.read_longstr().split(' ') - self.locales = args.read_longstr().split(' ') - + self.version_major = version_major + self.version_minor = version_minor + self.server_properties = server_properties + self.mechanisms = mechanisms.split(' ') + self.locales = locales.split(' ') AMQP_LOGGER.debug( START_DEBUG_FMT, self.version_major, self.version_minor, @@ -815,7 +778,8 @@ class Connection(AbstractChannel): (client_properties, mechanism, response, locale), ) - def _tune(self, args): + @inbound(Connection_Tune, 'BlB') + def _tune(self, channel_max, frame_max, server_heartbeat): """Propose connection tuning parameters This method proposes a set of connection configuration values @@ -858,10 +822,10 @@ class Connection(AbstractChannel): """ client_heartbeat = self.client_heartbeat or 0 - self.channel_max = args.read_short() or self.channel_max - self.frame_max = args.read_long() or self.frame_max + self.channel_max = channel_max or self.channel_max + self.frame_max = frame_max or self.frame_max self.method_writer.frame_max = self.frame_max - self.server_heartbeat = args.read_short() or 0 + self.server_heartbeat = server_heartbeat or 0 # negotiate the heartbeat interval to the smaller of the # specified values diff --git a/amqp/method_framing.py b/amqp/method_framing.py index b454524..cedda44 100644 --- a/amqp/method_framing.py +++ b/amqp/method_framing.py @@ -17,12 +17,11 @@ from __future__ import absolute_import from collections import defaultdict, deque -from struct import pack, unpack +from struct import pack, unpack, unpack_from from .basic_message import Message from .exceptions import AMQPError, UnexpectedFrame from .five import range, string -from .serialization import AMQPReader __all__ = ['MethodReader'] @@ -132,21 +131,21 @@ class MethodReader(object): def _process_heartbeat(self, channel, payload): self.heartbeats += 1 - def _process_method_frame(self, channel, payload): + def _process_method_frame(self, channel, payload, + unpack_from=unpack_from): """Process Method frames""" - method_sig = unpack('>HH', payload[:4]) - args = AMQPReader(payload[4:]) + method_sig = unpack_from('>HH', payload, 0) if method_sig in _CONTENT_METHODS: # # Save what we've got so far and wait for the content-header # self.partial_messages[channel] = _PartialMessage( - method_sig, args, channel, + method_sig, payload, channel, ) self.expected_types[channel] = 2 else: - self._quick_put((channel, method_sig, args, None)) + self._quick_put((channel, method_sig, payload, None)) def _process_content_header(self, channel, payload): """Process Content Header frames""" diff --git a/amqp/serialization.py b/amqp/serialization.py index 8366a1e..cecd522 100644 --- a/amqp/serialization.py +++ b/amqp/serialization.py @@ -26,7 +26,7 @@ import sys from datetime import datetime from decimal import Decimal from io import BytesIO -from struct import pack, unpack +from struct import pack, unpack, unpack_from from time import mktime from .exceptions import FrameSyntaxError @@ -50,6 +50,200 @@ ILLEGAL_TABLE_TYPE = """\ """ +def _read_item(buf, offset=0, unpack=unpack_from): + ftype = buf[offset] + offset += 1 + + # 'S': long string + if ftype == 'S': + slen, = unpack('>I', buf, offset) + offset += 4 + val = buf[offset:offset + slen] + offset += slen + # 's': short string + elif ftype == 's': + slen, = unpack('>B', buf, offset) + offset += 1 + val = buf[offset:offset + slen] + offset += slen + # 'b': short-short int + elif ftype == 'b': + val, = unpack('>B', buf, offset) + offset += 1 + # 'B': short-short unsigned int + elif ftype == 'B': + val, = unpack('>b', buf, offset) + offset += 1 + # 'U': short int + elif ftype == 'U': + val, = unpack('>h', buf, offset) + offset += 2 + # 'u': short unsigned int + elif ftype == 'u': + val, = unpack('>H', buf, offset) + offset += 2 + # 'I': long int + elif ftype == 'I': + val, = unpack('>i', buf, offset) + offset += 4 + # 'i': long unsigned int + elif ftype == 'i': + val, = unpack('>I', buf, offset) + offset += 4 + # 'L': long long int + elif ftype == 'L': + val, = unpack('>q', buf, offset) + offset += 8 + # 'l': long long unsigned int + elif ftype == 'l': + val, = unpack('>Q', buf, offset) + offset += 8 + # 'f': float + elif ftype == 'f': + val, = unpack('>f', buf, offset) + offset += 4 + # 'd': double + elif ftype == 'd': + val, = unpack('>d', buf, offset) + offset += 8 + # 'D': decimal + elif ftype == 'D': + d, = unpack('>B', buf, offset) + offset += 1 + n, = unpack('>i', buf, offset) + offset += 4 + val = Decimal(n) / Decimal(10 ** d) + # 'F': table + elif ftype == 'F': + tlen, = unpack('>I', buf, offset) + offset += 4 + limit = offset + tlen + val = {} + while offset < limit: + keylen, = unpack('>B', buf, offset) + offset += 1 + key = buf[offset:offset + keylen] + offset += keylen + val[key], offset = _read_item(buf, offset) + # 'A': array + elif ftype == 'A': + alen, = unpack('>I', buf, offset) + offset += 4 + limit = offset + alen + val = [] + while offset < limit: + v, offset = _read_item(buf, offset) + val.append(v) + # 't' (bool) + elif ftype == 't': + val, = unpack('>B', buf, offset) + val = bool(val) + offset += 1 + # 'T': timestamp + elif ftype == 'T': + val, = unpack('>Q', buf, offset) + offset += 8 + val = datetime.utcfromtimestamp(val) + # 'V': void + elif ftype == 'V': + val = None + else: + raise FrameSyntaxError( + 'Unknown value in table: {0!r} ({1!r})'.format( + ftype, type(ftype))) + return val, offset + + +def loads(format, buf, offset=0, + ord=ord, unpack=unpack_from, _read_item=_read_item): + """ + bit = b + octet = o + short = B + long = l + long long = L + float = f + shortstr = s + longstr = S + table = F + array = A + """ + bitcount = bits = 0 + + values = [] + append = values.append + + for p in format: + if p == 'b': + if not bitcount: + bits = ord(buf[offset:offset + 1]) + bitcount = 8 + val = (bits & 1) == 1 + bits >>= 1 + bitcount -= 1 + offset += 1 + elif p == 'o': + bitcount = bits = 0 + val, = unpack('>B', buf, offset) + offset += 1 + elif p == 'B': + bitcount = bits = 0 + val, = unpack('>H', buf, offset) + offset += 2 + elif p == 'l': + bitcount = bits = 0 + val, = unpack('>I', buf, offset) + offset += 4 + elif p == 'L': + bitcount = bits = 0 + val, = unpack('>Q', buf, offset) + offset += 8 + elif p == 'f': + bitcount = bits = 0 + val, = unpack('>d', buf, offset) + offset += 8 + elif p == 's': + bitcount = bits = 0 + slen, = unpack('B', buf, offset) + offset += 1 + val = buf[offset:offset + slen].decode('utf-8') + offset += slen + elif p == 'S': + bitcount = bits = 0 + slen, = unpack('>I', buf, offset) + offset += 4 + val = buf[offset:offset + slen].decode('utf-8') + offset += slen + elif p == 'F': + bitcount = bits = 0 + tlen, = unpack('>I', buf, offset) + offset += 4 + limit = offset + tlen + val = {} + while offset < limit: + keylen, = unpack('>B', buf, offset) + offset += 1 + key = buf[offset:offset + keylen] + offset += keylen + val[key], offset = _read_item(buf, offset) + elif p == 'A': + bitcount = bits = 0 + alen, = unpack('>I', buf, offset) + offset += 4 + limit = offset + alen + val = [] + while offset < limit: + aval, offset = _read_item(buf, offset) + val.append(aval) + elif p == 'T': + bitcount = bits = 0 + val, = unpack('>Q', buf, offset) + offset += 8 + val = datetime.fromtimestamp(val) + append(val) + return values + + class AMQPReader(object): """Read higher-level AMQP types from a bytestream.""" def __init__(self, source): @@ -261,39 +455,39 @@ def dumps(format, values): bits.append(0) bits[-1] |= (val << shift) bitcount += 1 - if p == 'o': + elif p == 'o': bitcount = _flushbits(bits, write) write(pack('B', val)) - if p == 'B': + elif p == 'B': bitcount = _flushbits(bits, write) write(pack('>H', int(val))) - if p == 'l': + elif p == 'l': bitcount = _flushbits(bits, write) write(pack('>I', val)) - if p == 'L': + elif p == 'L': bitcount = _flushbits(bits, write) write(pack('>Q', val)) - if p == 's': + elif p == 's': val = val or '' bitcount = _flushbits(bits, write) if isinstance(val, string): val = val.encode('utf-8') write(pack('B', len(val))) write(val) - if p == 'S': + elif p == 'S': val = val or '' bitcount = _flushbits(bits, write) if isinstance(val, string): val = val.encode('utf-8') write(pack('>I', len(val))) write(val) - if p == 'F': + elif p == 'F': bitcount = _flushbits(bits, write) _write_table(val or {}, write, bits) - if p == 'A': + elif p == 'A': bitcount = _flushbits(bits, write) _write_array(val or [], write, bits) - if p == 'T': + elif p == 'T': write(pack('>q', long_t(mktime(val.timetuple())))) _flushbits(bits, write) |