diff options
author | Ask Solem <ask@celeryproject.org> | 2015-04-29 16:39:25 +0100 |
---|---|---|
committer | Ask Solem <ask@celeryproject.org> | 2015-04-29 16:39:25 +0100 |
commit | 96e1c8504799cf51a3c7cbec5d7ceccb7f78112d (patch) | |
tree | b1d56c7dec0f7dbe2c22684c07d02f2179ae30b4 | |
parent | f80a72400a4bd27a09694ec43dbcceaabdb0d22b (diff) | |
download | py-amqp-96e1c8504799cf51a3c7cbec5d7ceccb7f78112d.tar.gz |
basic_publish now attempts to read error frames from socket (Issue celery/celery#2595
-rw-r--r-- | amqp/channel.py | 12 | ||||
-rw-r--r-- | amqp/connection.py | 77 | ||||
-rw-r--r-- | amqp/exceptions.py | 2 |
3 files changed, 80 insertions, 11 deletions
diff --git a/amqp/channel.py b/amqp/channel.py index 05eb09a..896588e 100644 --- a/amqp/channel.py +++ b/amqp/channel.py @@ -22,7 +22,9 @@ from collections import defaultdict from warnings import warn from .abstract_channel import AbstractChannel -from .exceptions import ChannelError, ConsumerCancelled, error_for_code +from .exceptions import ( + ChannelError, ConnectionError, ConsumerCancelled, error_for_code, +) from .five import Queue from .protocol import basic_return_t, queue_declare_ok_t from .serialization import AMQPWriter @@ -2120,7 +2122,13 @@ class Channel(AbstractChannel): args.write_bit(immediate) self._send_method((60, 40), args, msg) - basic_publish = _basic_publish + + def basic_publish(self, *args, **kwargs): + if self.connection is None: + raise ConnectionError('Channel already closed.') + if self.connection._readable(): + self.connection._maybe_read_error(self.channel_id) + return self._basic_publish(*args, **kwargs) def basic_publish_confirm(self, *args, **kwargs): if not self._confirm_selected: diff --git a/amqp/connection.py b/amqp/connection.py index f1781c0..8435fde 100644 --- a/amqp/connection.py +++ b/amqp/connection.py @@ -16,7 +16,9 @@ # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 from __future__ import absolute_import +import errno import logging +import select import socket from array import array @@ -41,6 +43,11 @@ from .transport import create_transport HAS_MSG_PEEK = hasattr(socket, 'MSG_PEEK') +try: + SELECT_BAD_FD = set((errno.EBADF, errno.WSAENOTSOCK)) +except AttributeError: + SELECT_BAD_FD = set((errno.EBADF,)) + START_DEBUG_FMT = """ Start from server, version: %d.%d, properties: %s, mechanisms: %s, locales: %s """.strip() @@ -59,6 +66,40 @@ LIBRARY_PROPERTIES = { AMQP_LOGGER = logging.getLogger('amqp') +if hasattr(select, 'poll'): + def _select(readers, writers, err, timeout=0, + _poll=select.poll, POLLIN=select.POLLIN, + POLLOUT=select.POLLOUT, POLLERR=select.POLLERR): + poller = _poll() + + register = poller.register + if readers: + [register(fd, POLLIN) for fd in readers] + if writers: + [register(fd, POLLOUT) for fd in writers] + if err: + [register(fd, POLLERR) for fd in err] + + R, W = set(), set() + timeout = 0 if timeout and timeout < 0 else round(timeout * 1e3) + events = poller.poll(timeout) + for fd, event in events: + if event & POLLIN: + R.add(fd) + if event & POLLOUT: + R.add(fd) + if event & POLLERR: + R.add(fd) + return R, W +else: + def _select(readers, writers, err, timeout=0, + _poll=select.select): + r, w, e = _poll(readers, writers, err, timeout) + if e: + r = list(set(r) | set(e)) + return r, w + + class Connection(AbstractChannel): """The connection class provides methods for a client to establish a network connection to a server, and for both peers to operate the @@ -182,6 +223,20 @@ class Connection(AbstractChannel): return self._x_open(virtual_host) + def _select(self, readers, writers, err, timeout=0): + try: + return _select(readers, writers, [], timeout) + except (select.error, socket.error) as exc: + _errno = getattr(exc, 'errno', None) + if _errno == errno.EINTR: + return [], [] + elif _errno in SELECT_BAD_FD: + raise self.ConnectionError('Socket closed: {0!r}'.format(exc)) + raise self.ConnectionError('Socket error: {0!r}'.format(exc)) + + def _readable(self): + return self._select([self.sock], [], [self.sock])[0] + def Transport(self, host, connect_timeout, ssl=False): return create_transport(host, connect_timeout, ssl) @@ -295,14 +350,8 @@ class Connection(AbstractChannel): sock.settimeout(prev) return True - def drain_events(self, timeout=None): - """Wait for an event on a channel.""" - chanmap = self.channels - chanid, method_sig, args, content = self._wait_multiple( - chanmap, None, timeout=timeout, - ) - - channel = chanmap[chanid] + def _dispatch_method(self, chanid, method_sig, args, content): + channel = self.channels[chanid] if (content and channel.auto_decode and @@ -324,6 +373,12 @@ class Connection(AbstractChannel): else: return amqp_method(channel, args, content) + def drain_events(self, timeout=None): + """Wait for an event on a channel.""" + self._dispatch_method(*self._wait_multiple( + self.channels, None, timeout=timeout, + )) + def read_timeout(self, timeout=None): if timeout is None: return self.method_reader.read_method() @@ -346,6 +401,12 @@ class Connection(AbstractChannel): if prev != timeout: sock.settimeout(prev) + def _maybe_read_error(self, wanted_channel): + channel, method_sig, args, content = self.method_reader.read_method() + if channel == wanted_channel and method_sig == (20, 40): + return self._dispatch_method(channel, method_sig, args, content) + self.channels[channel].method_queue.append((method_sig, args, content)) + def _wait_multiple(self, channels, allowed_methods, timeout=None): for channel_id, channel in items(channels): method_queue = channel.method_queue diff --git a/amqp/exceptions.py b/amqp/exceptions.py index e3e144a..50c12c6 100644 --- a/amqp/exceptions.py +++ b/amqp/exceptions.py @@ -108,7 +108,7 @@ class AccessRefused(IrrecoverableChannelError): code = 403 -class NotFound(IrrecoverableChannelError): +class NotFound(RecoverableChannelError): code = 404 |