summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAsk Solem <ask@celeryproject.org>2014-06-04 18:50:02 +0100
committerAsk Solem <ask@celeryproject.org>2014-06-04 18:50:02 +0100
commitf075ae618c19e3c0fb155b3397c13d1c297ebe56 (patch)
treebc07c56dd1147149e7bf1312aa90cc9b56dc55dd
parentb8b3def2f5f6c6bf5cae30d44a90c7998b3a3654 (diff)
downloadpy-amqp-f075ae618c19e3c0fb155b3397c13d1c297ebe56.tar.gz
Refactors AMQPReader
-rw-r--r--amqp/abstract_channel.py36
-rw-r--r--amqp/channel.py144
-rw-r--r--amqp/connection.py120
-rw-r--r--amqp/method_framing.py13
-rw-r--r--amqp/serialization.py214
5 files changed, 356 insertions, 171 deletions
diff --git a/amqp/abstract_channel.py b/amqp/abstract_channel.py
index dfc1ef8..bce82ff 100644
--- a/amqp/abstract_channel.py
+++ b/amqp/abstract_channel.py
@@ -18,11 +18,21 @@ from __future__ import absolute_import
from .exceptions import AMQPNotImplementedError, RecoverableConnectionError
from .promise import promise
-from .serialization import dumps
+from .serialization import dumps, loads
__all__ = ['AbstractChannel']
+def inbound(method, spec=None, content=False):
+
+ def _inner(fun):
+ fun.__amqp_method__ = method
+ fun.__amqp_argspec__ = spec
+ fun.__amqp_content__ = content
+ return fun
+ return _inner
+
+
class AbstractChannel(object):
"""Superclass for both the Connection, which is treated
as channel 0, and other user-created Channel objects.
@@ -79,12 +89,12 @@ class AbstractChannel(object):
def wait(self, allowed_methods=None):
"""Wait for a method that matches our allowed_methods parameter (the
default value of None means match any method), and dispatch to it."""
- method_sig, args, content = self.connection._wait_method(
+ method_sig, payload, content = self.connection._wait_method(
self.channel_id, allowed_methods)
- return self.dispatch_method(method_sig, args, content)
+ return self.dispatch_method(method_sig, payload, content)
- def dispatch_method(self, method_sig, args, content):
+ def dispatch_method(self, method_sig, payload, content):
if content and \
self.auto_decode and \
hasattr(content, 'content_encoding'):
@@ -99,10 +109,22 @@ class AbstractChannel(object):
raise AMQPNotImplementedError(
'Unknown AMQP method {0!r}'.format(method_sig))
- if content is None:
- return amqp_method(self, args)
+ try:
+ expects_content = amqp_method.__amqp_content__
+ except AttributeError:
+ expects_content = False
+
+ try:
+ argspec = amqp_method.__amqp_argspec__
+ except AttributeError:
+ args = []
+ else:
+ args = loads(argspec, payload, 4) if argspec else []
+
+ if expects_content:
+ return amqp_method(self, *args + [content])
else:
- return amqp_method(self, args, content)
+ return amqp_method(self, *args)
#: Placeholder, the concrete implementations will have to
#: supply their own versions of _METHOD_MAP
diff --git a/amqp/channel.py b/amqp/channel.py
index 2ebb3fa..4195d87 100644
--- a/amqp/channel.py
+++ b/amqp/channel.py
@@ -21,10 +21,10 @@ import logging
from collections import defaultdict
from warnings import warn
-from .abstract_channel import AbstractChannel
+from .abstract_channel import AbstractChannel, inbound
from .exceptions import ChannelError, ConsumerCancelled, error_for_code
from .five import Queue
-from .protocol import basic_return_t, queue_declare_ok_t
+from .protocol import queue_declare_ok_t
__all__ = ['Channel']
@@ -226,7 +226,7 @@ class Channel(AbstractChannel):
finally:
self.connection = None
- def _close(self, args):
+ def _close(self, reply_code, reply_text, class_id, method_id):
"""Request a channel close
This method indicates that the sender wants to close the
@@ -274,11 +274,6 @@ class Channel(AbstractChannel):
"""
- reply_code = args.read_short()
- reply_text = args.read_shortstr()
- class_id = args.read_short()
- method_id = args.read_short()
-
self._send_method(Channel_CloseOk)
self._do_revive()
@@ -286,7 +281,8 @@ class Channel(AbstractChannel):
reply_code, reply_text, (class_id, method_id), ChannelError,
)
- def _close_ok(self, args):
+ @inbound(Channel_CloseOk)
+ def _close_ok(self):
"""Confirm a channel close
This method confirms a Channel.Close method and tells the
@@ -352,7 +348,8 @@ class Channel(AbstractChannel):
Channel_Flow, 'b', (active, ), wait=[Channel_FlowOk],
)
- def _flow(self, args):
+ @inbound(Channel_Flow, 'b')
+ def _flow(self, active):
"""Enable/disable flow from peer
This method asks the peer to pause or restart the flow of
@@ -398,7 +395,7 @@ class Channel(AbstractChannel):
False, the peer stops sending content frames.
"""
- self.active = args.read_bit()
+ self.active = active
self._x_flow_ok(self.active)
def _x_flow_ok(self, active):
@@ -419,7 +416,8 @@ class Channel(AbstractChannel):
"""
return self.send_method(Channel_FlowOk, 'b', (active, ))
- def _flow_ok(self, args):
+ @inbound(Channel_FlowOk, 'b')
+ def _flow_ok(self, active):
"""Confirm a flow method
Confirms to the peer that a flow command was received and
@@ -435,7 +433,7 @@ class Channel(AbstractChannel):
to send content frames; False means it will not.
"""
- return args.read_bit()
+ return active
def _x_open(self):
"""Open a channel for use
@@ -464,7 +462,8 @@ class Channel(AbstractChannel):
Channel_Open, 's', ('', ), wait=[Channel_OpenOk],
)
- def _open_ok(self, args):
+ @inbound(Channel_OpenOk)
+ def _open_ok(self):
"""Signal that the channel is ready
This method signals to the client that the channel is ready
@@ -647,7 +646,8 @@ class Channel(AbstractChannel):
wait=None if nowait else [Exchange_DeclareOk],
)
- def _exchange_declare_ok(self, args):
+ @inbound(Exchange_DeclareOk)
+ def _exchange_declare_ok(self):
"""Confirms an exchange declaration
This method confirms a Declare method and confirms the name of
@@ -705,7 +705,8 @@ class Channel(AbstractChannel):
wait=None if nowait else [Exchange_DeleteOk],
)
- def _exchange_delete_ok(self, args):
+ @inbound(Exchange_DeleteOk)
+ def _exchange_delete_ok(self):
"""Confirm deletion of an exchange
This method confirms the deletion of an exchange.
@@ -850,7 +851,8 @@ class Channel(AbstractChannel):
wait=None if nowait else [Exchange_UnbindOk],
)
- def _exchange_bind_ok(self, args):
+ @inbound(Exchange_BindOk)
+ def _exchange_bind_ok(self):
"""Confirm bind successful
This method confirms that the bind was successful.
@@ -858,7 +860,8 @@ class Channel(AbstractChannel):
"""
pass
- def _exchange_unbind_ok(self, args):
+ @inbound(Exchange_UnbindOk)
+ def _exchange_unbind_ok(self):
"""Confirm unbind successful
This method confirms that the unbind was successful.
@@ -997,7 +1000,8 @@ class Channel(AbstractChannel):
wait=None if nowait else [Queue_BindOk],
)
- def _queue_bind_ok(self, args):
+ @inbound(Queue_BindOk)
+ def _queue_bind_ok(self):
"""Confirm bind successful
This method confirms that the bind was successful.
@@ -1063,7 +1067,8 @@ class Channel(AbstractChannel):
wait=None if nowait else [Queue_UnbindOk],
)
- def _queue_unbind_ok(self, args):
+ @inbound(Queue_UnbindOk)
+ def _queue_unbind_ok(self):
"""Confirm unbind successful
This method confirms that the unbind was successful.
@@ -1233,7 +1238,8 @@ class Channel(AbstractChannel):
wait=None if nowait else [Queue_DeclareOk],
)
- def _queue_declare_ok(self, args):
+ @inbound(Queue_DeclareOk, 'sll')
+ def _queue_declare_ok(self, queue, message_count, consumer_count):
"""Confirms a queue definition
This method confirms a Declare method and confirms the name of
@@ -1262,11 +1268,7 @@ class Channel(AbstractChannel):
this count.
"""
- return queue_declare_ok_t(
- args.read_shortstr(),
- args.read_long(),
- args.read_long(),
- )
+ return queue_declare_ok_t(queue, message_count, consumer_count)
def queue_delete(self, queue='',
if_unused=False, if_empty=False, nowait=False,
@@ -1342,7 +1344,8 @@ class Channel(AbstractChannel):
wait=None if nowait else [Queue_DeleteOk],
)
- def _queue_delete_ok(self, args):
+ @inbound(Queue_DeleteOk, 'l')
+ def _queue_delete_ok(self, message_count):
"""Confirm deletion of a queue
This method confirms the deletion of a queue.
@@ -1355,7 +1358,7 @@ class Channel(AbstractChannel):
Reports the number of messages purged.
"""
- return args.read_long()
+ return message_count
def queue_purge(self, queue='', nowait=False, argsig='Bsb'):
"""Purge a queue
@@ -1418,7 +1421,8 @@ class Channel(AbstractChannel):
wait=None if nowait else [Queue_PurgeOk],
)
- def _queue_purge_ok(self, args):
+ @inbound(Queue_PurgeOk, 'l')
+ def _queue_purge_ok(self, message_count):
"""Confirms a queue purge
This method confirms the purge of a queue.
@@ -1431,7 +1435,7 @@ class Channel(AbstractChannel):
Reports the number of messages purged.
"""
- return args.read_long()
+ return message_count
#############
#
@@ -1589,20 +1593,21 @@ class Channel(AbstractChannel):
wait=None if nowait else [Basic_CancelOk],
)
- def _basic_cancel_notify(self, args):
+ @inbound(Basic_Cancel, 's')
+ def _basic_cancel_notify(self, consumer_tag):
"""Consumer cancelled by server.
Most likely the queue was deleted.
"""
- consumer_tag = args.read_shortstr()
callback = self._on_cancel(consumer_tag)
if callback:
callback(consumer_tag)
else:
raise ConsumerCancelled(consumer_tag, Basic_Cancel)
- def _basic_cancel_ok(self, args):
+ @inbound(Basic_CancelOk, 's')
+ def _basic_cancel_ok(self, consumer_tag):
"""Confirm a cancelled consumer
This method confirms that the cancellation was completed.
@@ -1623,7 +1628,6 @@ class Channel(AbstractChannel):
use it in another.
"""
- consumer_tag = args.read_shortstr()
self._on_cancel(consumer_tag)
def _on_cancel(self, consumer_tag):
@@ -1749,7 +1753,8 @@ class Channel(AbstractChannel):
self.no_ack_consumers.add(consumer_tag)
return p
- def _basic_consume_ok(self, args):
+ @inbound(Basic_ConsumeOk, 's')
+ def _basic_consume_ok(self, consumer_tag):
"""Confirm a new consumer
The server provides the client with a consumer tag, which is
@@ -1763,9 +1768,11 @@ class Channel(AbstractChannel):
provided by the server.
"""
- return args.read_shortstr()
+ return consumer_tag
- def _basic_deliver(self, args, msg):
+ @inbound(Basic_Deliver, 'sLbss', content=True)
+ def _basic_deliver(self, consumer_tag, delivery_tag, redelivered,
+ exchange, routing_key, msg):
"""Notify the client of a consumer message
This method delivers a message to the client, via a consumer.
@@ -1838,12 +1845,6 @@ class Channel(AbstractChannel):
message was published.
"""
- consumer_tag = args.read_shortstr()
- delivery_tag = args.read_longlong()
- redelivered = args.read_bit()
- exchange = args.read_shortstr()
- routing_key = args.read_shortstr()
-
msg.channel = self
msg.delivery_info = {
'consumer_tag': consumer_tag,
@@ -1902,7 +1903,8 @@ class Channel(AbstractChannel):
wait=[Basic_GetOk, Basic_GetEmpty],
)
- def _basic_get_empty(self, args):
+ @inbound(Basic_GetEmpty, 's')
+ def _basic_get_empty(self, cluster_id):
"""Indicate no messages available
This method tells the client that the queue has no messages
@@ -1917,9 +1919,11 @@ class Channel(AbstractChannel):
client applications.
"""
- cluster_id = args.read_shortstr() # noqa
+ pass
- def _basic_get_ok(self, args, msg):
+ @inbound(Basic_GetOk, 'Lbssl', content=True)
+ def _basic_get_ok(self, delivery_tag, redelivered, exchange, routing_key,
+ message_count, msg):
"""Provide client with a message
This method delivers a message to the client following a get
@@ -1977,12 +1981,6 @@ class Channel(AbstractChannel):
queue and removed by other clients.
"""
- delivery_tag = args.read_longlong()
- redelivered = args.read_bit()
- exchange = args.read_shortstr()
- routing_key = args.read_shortstr()
- message_count = args.read_long()
-
msg.channel = self
msg.delivery_info = {
'delivery_tag': delivery_tag,
@@ -2144,7 +2142,8 @@ class Channel(AbstractChannel):
wait=[Basic_QosOk],
)
- def _basic_qos_ok(self, args):
+ @inbound(Basic_QosOk)
+ def _basic_qos_ok(self):
"""Confirm the requested qos
This method tells the client that the requested QoS levels
@@ -2189,7 +2188,8 @@ class Channel(AbstractChannel):
def basic_recover_async(self, requeue=False):
return self.send_method(Basic_RecoverAsync, 'b', (requeue, ))
- def _basic_recover_ok(self, args):
+ @inbound(Basic_RecoverOk)
+ def _basic_recover_ok(self):
"""In 0-9-1 the deprecated recover solicits a response."""
pass
@@ -2265,7 +2265,9 @@ class Channel(AbstractChannel):
"""
return self.send_method(Basic_Reject, argsig, (delivery_tag, requeue))
- def _basic_return(self, args, msg):
+ @inbound(Basic_Return, 'Bsss', content=True)
+ def _basic_return(self, reply_code, reply_text,
+ exchange, routing_key, message):
"""Return a failed message
This method returns an undeliverable message that was
@@ -2298,13 +2300,14 @@ class Channel(AbstractChannel):
message was published.
"""
- self.returned_messages.put(basic_return_t(
- args.read_short(),
- args.read_shortstr(),
- args.read_shortstr(),
- args.read_shortstr(),
- msg,
- ))
+ exc = error_for_code(
+ reply_code, reply_text, Basic_Return, ChannelError,
+ )
+ handlers = self.events.get('basic_return')
+ if not handlers:
+ raise exc
+ for callback in handlers:
+ callback(exc, exchange, routing_key, message)
#############
#
@@ -2344,7 +2347,8 @@ class Channel(AbstractChannel):
"""
return self.send_method(Tx_Commit, wait=[Tx_CommitOk])
- def _tx_commit_ok(self, args):
+ @inbound(Tx_CommitOk)
+ def _tx_commit_ok(self):
"""Confirm a successful commit
This method confirms to the client that the commit succeeded.
@@ -2364,7 +2368,8 @@ class Channel(AbstractChannel):
"""
return self.send_method(Tx_Rollback, wait=[Tx_RollbackOk])
- def _tx_rollback_ok(self, args):
+ @inbound(Tx_RollbackOk)
+ def _tx_rollback_ok(self):
"""Confirm a successful rollback
This method confirms to the client that the rollback
@@ -2384,7 +2389,8 @@ class Channel(AbstractChannel):
"""
return self.send_method(Tx_Select, wait=[Tx_SelectOk])
- def _tx_select_ok(self, args):
+ @inbound(Tx_SelectOk)
+ def _tx_select_ok(self):
"""Confirm transaction mode
This method confirms to the client that the channel was
@@ -2411,14 +2417,14 @@ class Channel(AbstractChannel):
wait=None if nowait else [Confirm_SelectOk],
)
- def _confirm_select_ok(self, args):
+ @inbound(Confirm_SelectOk)
+ def _confirm_select_ok(self):
"""With this method the broker confirms to the client that
the channel is now using publisher confirms."""
pass
- def _basic_ack_recv(self, args):
- delivery_tag = args.read_longlong()
- multiple = args.read_bit()
+ @inbound(Basic_Ack, 'Lb')
+ def _basic_ack_recv(self, delivery_tag, multiple):
for callback in self.events['basic_ack']:
callback(delivery_tag, multiple)
diff --git a/amqp/connection.py b/amqp/connection.py
index 982df8e..0c9cef3 100644
--- a/amqp/connection.py
+++ b/amqp/connection.py
@@ -28,10 +28,10 @@ except ImportError:
pass
from . import __version__
-from .abstract_channel import AbstractChannel
+from .abstract_channel import AbstractChannel, inbound
from .channel import Channel
from .exceptions import (
- AMQPNotImplementedError, ChannelError, ResourceError,
+ ChannelError, ResourceError,
ConnectionForced, ConnectionError, error_for_code,
RecoverableConnectionError, RecoverableChannelError,
)
@@ -146,7 +146,6 @@ class Connection(AbstractChannel):
login_response = login_response.getvalue()[4:]
d = dict(LIBRARY_PROPERTIES, **client_properties or {})
- self._method_override = {(60, 50): self._dispatch_basic_return}
self.channels = {}
# The connection object itself is treated as channel 0
@@ -248,14 +247,14 @@ class Connection(AbstractChannel):
# Nothing queued, need to wait for a method from the peer
#
while 1:
- channel, method_sig, args, content = \
+ channel, method_sig, payload, content = \
self.method_reader.read_method()
if channel == channel_id and (
allowed_methods is None or
method_sig in allowed_methods or
method_sig == (20, 40)):
- return method_sig, args, content
+ return method_sig, payload, content
#
# Certain methods like basic_return should be dispatched
@@ -264,7 +263,7 @@ class Connection(AbstractChannel):
#
if channel and method_sig in self.Channel._IMMEDIATE_METHODS:
self.channels[channel].dispatch_method(
- method_sig, args, content,
+ method_sig, payload, content,
)
continue
@@ -273,7 +272,7 @@ class Connection(AbstractChannel):
# this method for later
#
self.channels[channel].method_queue.append(
- (method_sig, args, content),
+ (method_sig, payload, content),
)
#
@@ -298,31 +297,12 @@ class Connection(AbstractChannel):
def drain_events(self, timeout=None):
"""Wait for an event on a channel."""
chanmap = self.channels
- chanid, method_sig, args, content = self._wait_multiple(
+ chanid, method_sig, payload, content = self._wait_multiple(
chanmap, None, timeout=timeout,
)
channel = chanmap[chanid]
-
- if (content and
- channel.auto_decode and
- hasattr(content, 'content_encoding')):
- try:
- content.body = content.body.decode(content.content_encoding)
- except Exception:
- pass
-
- amqp_method = (self._method_override.get(method_sig) or
- channel._METHOD_MAP.get(method_sig, None))
-
- if amqp_method is None:
- raise AMQPNotImplementedError(
- 'Unknown AMQP method {0!r}'.format(method_sig))
-
- if content is None:
- return amqp_method(channel, args)
- else:
- return amqp_method(channel, args, content)
+ return channel.dispatch_method(method_sig, payload, content)
def read_timeout(self, timeout=None):
if timeout is None:
@@ -355,24 +335,26 @@ class Connection(AbstractChannel):
method_sig in allowed_methods or
method_sig == (20, 40)):
method_queue.remove(queued_method)
- method_sig, args, content = queued_method
- return channel_id, method_sig, args, content
+ method_sig, payload, content = queued_method
+ return channel_id, method_sig, payload, content
# Nothing queued, need to wait for a method from the peer
read_timeout = self.read_timeout
wait = self.wait
while 1:
- channel, method_sig, args, content = read_timeout(timeout)
+ channel, method_sig, payload, content = read_timeout(timeout)
if channel in channels and (
allowed_methods is None or
method_sig in allowed_methods or
method_sig == (20, 40)):
- return channel, method_sig, args, content
+ return channel, method_sig, payload, content
# Not the channel and/or method we were looking for. Queue
# this method for later
- channels[channel].method_queue.append((method_sig, args, content))
+ channels[channel].method_queue.append(
+ (method_sig, payload, content),
+ )
#
# If we just queued up a method for channel 0 (the Connection
@@ -382,19 +364,6 @@ class Connection(AbstractChannel):
if channel == 0:
wait()
- def _dispatch_basic_return(self, channel, args, msg):
- reply_code = args.read_short()
- reply_text = args.read_shortstr()
- exchange = args.read_shortstr()
- routing_key = args.read_shortstr()
-
- exc = error_for_code(reply_code, reply_text, (50, 60), ChannelError)
- handlers = channel.events.get('basic_return')
- if not handlers:
- raise exc
- for callback in handlers:
- callback(exc, exchange, routing_key, msg)
-
def close(self, reply_code=0, reply_text='', method_sig=(0, 0),
argsig='BssBB'):
@@ -461,7 +430,8 @@ class Connection(AbstractChannel):
wait=[Connection_Close, Connection_CloseOk],
)
- def _close(self, args):
+ @inbound(Connection_Close, 'BsBB')
+ def _close(self, reply_code, reply_text, class_id, method_id):
"""Request a connection close
This method indicates that the sender wants to close the
@@ -515,30 +485,19 @@ class Connection(AbstractChannel):
is the ID of the method.
"""
- reply_code = args.read_short()
- reply_text = args.read_shortstr()
- class_id = args.read_short()
- method_id = args.read_short()
-
self._x_close_ok()
-
raise error_for_code(reply_code, reply_text,
(class_id, method_id), ConnectionError)
- def _blocked(self, args):
+ @inbound(Connection_Blocked)
+ def _blocked(self):
"""RabbitMQ Extension."""
- try:
- reason = args.read_shortstr()
- except UnicodeDecodeError:
- # XXX Spec say this is a shortstr, but amqplib seems to
- # except strings to be in utf-8, even though the spec does
- # not dictate any special encoding.
- # (see amqp.serialization:AMQPReader.read_shortstr)
- reason = 'connection blocked, see broker logs'
+ reason = 'connection blocked, see broker logs'
if self.on_blocked:
return self.on_blocked(reason)
- def _unblocked(self, *args):
+ @inbound(Connection_Unblocked)
+ def _unblocked(self):
if self.on_unblocked:
return self.on_unblocked()
@@ -558,7 +517,8 @@ class Connection(AbstractChannel):
self._send_method(Connection_CloseOk)
self._do_close()
- def _close_ok(self, args):
+ @inbound(Connection_CloseOk)
+ def _close_ok(self):
"""Confirm a connection close
This method confirms a Connection.Close method and tells the
@@ -627,7 +587,8 @@ class Connection(AbstractChannel):
wait=[Connection_OpenOk],
)
- def _open_ok(self, args):
+ @inbound(Connection_OpenOk)
+ def _open_ok(self):
"""Signal that the connection is ready
This method signals to the client that the connection is ready
@@ -639,7 +600,8 @@ class Connection(AbstractChannel):
"""
AMQP_LOGGER.debug('Open OK!')
- def _secure(self, args):
+ @inbound(Connection_Secure, 's')
+ def _secure(self, challenge):
"""Security mechanism challenge
The SASL protocol works by exchanging challenges and responses
@@ -656,7 +618,7 @@ class Connection(AbstractChannel):
passed to the security mechanism.
"""
- challenge = args.read_longstr() # noqa
+ pass
def _x_secure_ok(self, response):
"""Security mechanism response
@@ -676,7 +638,9 @@ class Connection(AbstractChannel):
"""
return self.send_method(Connection_SecureOk, 'S', (response, ))
- def _start(self, args):
+ @inbound(Connection_Start, 'ooFSS')
+ def _start(self, version_major, version_minor, server_properties,
+ mechanisms, locales):
"""Start connection negotiation
This method starts the connection negotiation process by
@@ -740,12 +704,11 @@ class Connection(AbstractChannel):
locale.
"""
- self.version_major = args.read_octet()
- self.version_minor = args.read_octet()
- self.server_properties = args.read_table()
- self.mechanisms = args.read_longstr().split(' ')
- self.locales = args.read_longstr().split(' ')
-
+ self.version_major = version_major
+ self.version_minor = version_minor
+ self.server_properties = server_properties
+ self.mechanisms = mechanisms.split(' ')
+ self.locales = locales.split(' ')
AMQP_LOGGER.debug(
START_DEBUG_FMT,
self.version_major, self.version_minor,
@@ -815,7 +778,8 @@ class Connection(AbstractChannel):
(client_properties, mechanism, response, locale),
)
- def _tune(self, args):
+ @inbound(Connection_Tune, 'BlB')
+ def _tune(self, channel_max, frame_max, server_heartbeat):
"""Propose connection tuning parameters
This method proposes a set of connection configuration values
@@ -858,10 +822,10 @@ class Connection(AbstractChannel):
"""
client_heartbeat = self.client_heartbeat or 0
- self.channel_max = args.read_short() or self.channel_max
- self.frame_max = args.read_long() or self.frame_max
+ self.channel_max = channel_max or self.channel_max
+ self.frame_max = frame_max or self.frame_max
self.method_writer.frame_max = self.frame_max
- self.server_heartbeat = args.read_short() or 0
+ self.server_heartbeat = server_heartbeat or 0
# negotiate the heartbeat interval to the smaller of the
# specified values
diff --git a/amqp/method_framing.py b/amqp/method_framing.py
index b454524..cedda44 100644
--- a/amqp/method_framing.py
+++ b/amqp/method_framing.py
@@ -17,12 +17,11 @@
from __future__ import absolute_import
from collections import defaultdict, deque
-from struct import pack, unpack
+from struct import pack, unpack, unpack_from
from .basic_message import Message
from .exceptions import AMQPError, UnexpectedFrame
from .five import range, string
-from .serialization import AMQPReader
__all__ = ['MethodReader']
@@ -132,21 +131,21 @@ class MethodReader(object):
def _process_heartbeat(self, channel, payload):
self.heartbeats += 1
- def _process_method_frame(self, channel, payload):
+ def _process_method_frame(self, channel, payload,
+ unpack_from=unpack_from):
"""Process Method frames"""
- method_sig = unpack('>HH', payload[:4])
- args = AMQPReader(payload[4:])
+ method_sig = unpack_from('>HH', payload, 0)
if method_sig in _CONTENT_METHODS:
#
# Save what we've got so far and wait for the content-header
#
self.partial_messages[channel] = _PartialMessage(
- method_sig, args, channel,
+ method_sig, payload, channel,
)
self.expected_types[channel] = 2
else:
- self._quick_put((channel, method_sig, args, None))
+ self._quick_put((channel, method_sig, payload, None))
def _process_content_header(self, channel, payload):
"""Process Content Header frames"""
diff --git a/amqp/serialization.py b/amqp/serialization.py
index 8366a1e..cecd522 100644
--- a/amqp/serialization.py
+++ b/amqp/serialization.py
@@ -26,7 +26,7 @@ import sys
from datetime import datetime
from decimal import Decimal
from io import BytesIO
-from struct import pack, unpack
+from struct import pack, unpack, unpack_from
from time import mktime
from .exceptions import FrameSyntaxError
@@ -50,6 +50,200 @@ ILLEGAL_TABLE_TYPE = """\
"""
+def _read_item(buf, offset=0, unpack=unpack_from):
+ ftype = buf[offset]
+ offset += 1
+
+ # 'S': long string
+ if ftype == 'S':
+ slen, = unpack('>I', buf, offset)
+ offset += 4
+ val = buf[offset:offset + slen]
+ offset += slen
+ # 's': short string
+ elif ftype == 's':
+ slen, = unpack('>B', buf, offset)
+ offset += 1
+ val = buf[offset:offset + slen]
+ offset += slen
+ # 'b': short-short int
+ elif ftype == 'b':
+ val, = unpack('>B', buf, offset)
+ offset += 1
+ # 'B': short-short unsigned int
+ elif ftype == 'B':
+ val, = unpack('>b', buf, offset)
+ offset += 1
+ # 'U': short int
+ elif ftype == 'U':
+ val, = unpack('>h', buf, offset)
+ offset += 2
+ # 'u': short unsigned int
+ elif ftype == 'u':
+ val, = unpack('>H', buf, offset)
+ offset += 2
+ # 'I': long int
+ elif ftype == 'I':
+ val, = unpack('>i', buf, offset)
+ offset += 4
+ # 'i': long unsigned int
+ elif ftype == 'i':
+ val, = unpack('>I', buf, offset)
+ offset += 4
+ # 'L': long long int
+ elif ftype == 'L':
+ val, = unpack('>q', buf, offset)
+ offset += 8
+ # 'l': long long unsigned int
+ elif ftype == 'l':
+ val, = unpack('>Q', buf, offset)
+ offset += 8
+ # 'f': float
+ elif ftype == 'f':
+ val, = unpack('>f', buf, offset)
+ offset += 4
+ # 'd': double
+ elif ftype == 'd':
+ val, = unpack('>d', buf, offset)
+ offset += 8
+ # 'D': decimal
+ elif ftype == 'D':
+ d, = unpack('>B', buf, offset)
+ offset += 1
+ n, = unpack('>i', buf, offset)
+ offset += 4
+ val = Decimal(n) / Decimal(10 ** d)
+ # 'F': table
+ elif ftype == 'F':
+ tlen, = unpack('>I', buf, offset)
+ offset += 4
+ limit = offset + tlen
+ val = {}
+ while offset < limit:
+ keylen, = unpack('>B', buf, offset)
+ offset += 1
+ key = buf[offset:offset + keylen]
+ offset += keylen
+ val[key], offset = _read_item(buf, offset)
+ # 'A': array
+ elif ftype == 'A':
+ alen, = unpack('>I', buf, offset)
+ offset += 4
+ limit = offset + alen
+ val = []
+ while offset < limit:
+ v, offset = _read_item(buf, offset)
+ val.append(v)
+ # 't' (bool)
+ elif ftype == 't':
+ val, = unpack('>B', buf, offset)
+ val = bool(val)
+ offset += 1
+ # 'T': timestamp
+ elif ftype == 'T':
+ val, = unpack('>Q', buf, offset)
+ offset += 8
+ val = datetime.utcfromtimestamp(val)
+ # 'V': void
+ elif ftype == 'V':
+ val = None
+ else:
+ raise FrameSyntaxError(
+ 'Unknown value in table: {0!r} ({1!r})'.format(
+ ftype, type(ftype)))
+ return val, offset
+
+
+def loads(format, buf, offset=0,
+ ord=ord, unpack=unpack_from, _read_item=_read_item):
+ """
+ bit = b
+ octet = o
+ short = B
+ long = l
+ long long = L
+ float = f
+ shortstr = s
+ longstr = S
+ table = F
+ array = A
+ """
+ bitcount = bits = 0
+
+ values = []
+ append = values.append
+
+ for p in format:
+ if p == 'b':
+ if not bitcount:
+ bits = ord(buf[offset:offset + 1])
+ bitcount = 8
+ val = (bits & 1) == 1
+ bits >>= 1
+ bitcount -= 1
+ offset += 1
+ elif p == 'o':
+ bitcount = bits = 0
+ val, = unpack('>B', buf, offset)
+ offset += 1
+ elif p == 'B':
+ bitcount = bits = 0
+ val, = unpack('>H', buf, offset)
+ offset += 2
+ elif p == 'l':
+ bitcount = bits = 0
+ val, = unpack('>I', buf, offset)
+ offset += 4
+ elif p == 'L':
+ bitcount = bits = 0
+ val, = unpack('>Q', buf, offset)
+ offset += 8
+ elif p == 'f':
+ bitcount = bits = 0
+ val, = unpack('>d', buf, offset)
+ offset += 8
+ elif p == 's':
+ bitcount = bits = 0
+ slen, = unpack('B', buf, offset)
+ offset += 1
+ val = buf[offset:offset + slen].decode('utf-8')
+ offset += slen
+ elif p == 'S':
+ bitcount = bits = 0
+ slen, = unpack('>I', buf, offset)
+ offset += 4
+ val = buf[offset:offset + slen].decode('utf-8')
+ offset += slen
+ elif p == 'F':
+ bitcount = bits = 0
+ tlen, = unpack('>I', buf, offset)
+ offset += 4
+ limit = offset + tlen
+ val = {}
+ while offset < limit:
+ keylen, = unpack('>B', buf, offset)
+ offset += 1
+ key = buf[offset:offset + keylen]
+ offset += keylen
+ val[key], offset = _read_item(buf, offset)
+ elif p == 'A':
+ bitcount = bits = 0
+ alen, = unpack('>I', buf, offset)
+ offset += 4
+ limit = offset + alen
+ val = []
+ while offset < limit:
+ aval, offset = _read_item(buf, offset)
+ val.append(aval)
+ elif p == 'T':
+ bitcount = bits = 0
+ val, = unpack('>Q', buf, offset)
+ offset += 8
+ val = datetime.fromtimestamp(val)
+ append(val)
+ return values
+
+
class AMQPReader(object):
"""Read higher-level AMQP types from a bytestream."""
def __init__(self, source):
@@ -261,39 +455,39 @@ def dumps(format, values):
bits.append(0)
bits[-1] |= (val << shift)
bitcount += 1
- if p == 'o':
+ elif p == 'o':
bitcount = _flushbits(bits, write)
write(pack('B', val))
- if p == 'B':
+ elif p == 'B':
bitcount = _flushbits(bits, write)
write(pack('>H', int(val)))
- if p == 'l':
+ elif p == 'l':
bitcount = _flushbits(bits, write)
write(pack('>I', val))
- if p == 'L':
+ elif p == 'L':
bitcount = _flushbits(bits, write)
write(pack('>Q', val))
- if p == 's':
+ elif p == 's':
val = val or ''
bitcount = _flushbits(bits, write)
if isinstance(val, string):
val = val.encode('utf-8')
write(pack('B', len(val)))
write(val)
- if p == 'S':
+ elif p == 'S':
val = val or ''
bitcount = _flushbits(bits, write)
if isinstance(val, string):
val = val.encode('utf-8')
write(pack('>I', len(val)))
write(val)
- if p == 'F':
+ elif p == 'F':
bitcount = _flushbits(bits, write)
_write_table(val or {}, write, bits)
- if p == 'A':
+ elif p == 'A':
bitcount = _flushbits(bits, write)
_write_array(val or [], write, bits)
- if p == 'T':
+ elif p == 'T':
write(pack('>q', long_t(mktime(val.timetuple()))))
_flushbits(bits, write)