summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAsk Solem <ask@celeryproject.org>2015-04-29 16:39:25 +0100
committerAsk Solem <ask@celeryproject.org>2015-04-29 16:39:25 +0100
commit96e1c8504799cf51a3c7cbec5d7ceccb7f78112d (patch)
treeb1d56c7dec0f7dbe2c22684c07d02f2179ae30b4
parentf80a72400a4bd27a09694ec43dbcceaabdb0d22b (diff)
downloadpy-amqp-96e1c8504799cf51a3c7cbec5d7ceccb7f78112d.tar.gz
basic_publish now attempts to read error frames from socket (Issue celery/celery#2595
-rw-r--r--amqp/channel.py12
-rw-r--r--amqp/connection.py77
-rw-r--r--amqp/exceptions.py2
3 files changed, 80 insertions, 11 deletions
diff --git a/amqp/channel.py b/amqp/channel.py
index 05eb09a..896588e 100644
--- a/amqp/channel.py
+++ b/amqp/channel.py
@@ -22,7 +22,9 @@ from collections import defaultdict
from warnings import warn
from .abstract_channel import AbstractChannel
-from .exceptions import ChannelError, ConsumerCancelled, error_for_code
+from .exceptions import (
+ ChannelError, ConnectionError, ConsumerCancelled, error_for_code,
+)
from .five import Queue
from .protocol import basic_return_t, queue_declare_ok_t
from .serialization import AMQPWriter
@@ -2120,7 +2122,13 @@ class Channel(AbstractChannel):
args.write_bit(immediate)
self._send_method((60, 40), args, msg)
- basic_publish = _basic_publish
+
+ def basic_publish(self, *args, **kwargs):
+ if self.connection is None:
+ raise ConnectionError('Channel already closed.')
+ if self.connection._readable():
+ self.connection._maybe_read_error(self.channel_id)
+ return self._basic_publish(*args, **kwargs)
def basic_publish_confirm(self, *args, **kwargs):
if not self._confirm_selected:
diff --git a/amqp/connection.py b/amqp/connection.py
index f1781c0..8435fde 100644
--- a/amqp/connection.py
+++ b/amqp/connection.py
@@ -16,7 +16,9 @@
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301
from __future__ import absolute_import
+import errno
import logging
+import select
import socket
from array import array
@@ -41,6 +43,11 @@ from .transport import create_transport
HAS_MSG_PEEK = hasattr(socket, 'MSG_PEEK')
+try:
+ SELECT_BAD_FD = set((errno.EBADF, errno.WSAENOTSOCK))
+except AttributeError:
+ SELECT_BAD_FD = set((errno.EBADF,))
+
START_DEBUG_FMT = """
Start from server, version: %d.%d, properties: %s, mechanisms: %s, locales: %s
""".strip()
@@ -59,6 +66,40 @@ LIBRARY_PROPERTIES = {
AMQP_LOGGER = logging.getLogger('amqp')
+if hasattr(select, 'poll'):
+ def _select(readers, writers, err, timeout=0,
+ _poll=select.poll, POLLIN=select.POLLIN,
+ POLLOUT=select.POLLOUT, POLLERR=select.POLLERR):
+ poller = _poll()
+
+ register = poller.register
+ if readers:
+ [register(fd, POLLIN) for fd in readers]
+ if writers:
+ [register(fd, POLLOUT) for fd in writers]
+ if err:
+ [register(fd, POLLERR) for fd in err]
+
+ R, W = set(), set()
+ timeout = 0 if timeout and timeout < 0 else round(timeout * 1e3)
+ events = poller.poll(timeout)
+ for fd, event in events:
+ if event & POLLIN:
+ R.add(fd)
+ if event & POLLOUT:
+ R.add(fd)
+ if event & POLLERR:
+ R.add(fd)
+ return R, W
+else:
+ def _select(readers, writers, err, timeout=0,
+ _poll=select.select):
+ r, w, e = _poll(readers, writers, err, timeout)
+ if e:
+ r = list(set(r) | set(e))
+ return r, w
+
+
class Connection(AbstractChannel):
"""The connection class provides methods for a client to establish a
network connection to a server, and for both peers to operate the
@@ -182,6 +223,20 @@ class Connection(AbstractChannel):
return self._x_open(virtual_host)
+ def _select(self, readers, writers, err, timeout=0):
+ try:
+ return _select(readers, writers, [], timeout)
+ except (select.error, socket.error) as exc:
+ _errno = getattr(exc, 'errno', None)
+ if _errno == errno.EINTR:
+ return [], []
+ elif _errno in SELECT_BAD_FD:
+ raise self.ConnectionError('Socket closed: {0!r}'.format(exc))
+ raise self.ConnectionError('Socket error: {0!r}'.format(exc))
+
+ def _readable(self):
+ return self._select([self.sock], [], [self.sock])[0]
+
def Transport(self, host, connect_timeout, ssl=False):
return create_transport(host, connect_timeout, ssl)
@@ -295,14 +350,8 @@ class Connection(AbstractChannel):
sock.settimeout(prev)
return True
- def drain_events(self, timeout=None):
- """Wait for an event on a channel."""
- chanmap = self.channels
- chanid, method_sig, args, content = self._wait_multiple(
- chanmap, None, timeout=timeout,
- )
-
- channel = chanmap[chanid]
+ def _dispatch_method(self, chanid, method_sig, args, content):
+ channel = self.channels[chanid]
if (content and
channel.auto_decode and
@@ -324,6 +373,12 @@ class Connection(AbstractChannel):
else:
return amqp_method(channel, args, content)
+ def drain_events(self, timeout=None):
+ """Wait for an event on a channel."""
+ self._dispatch_method(*self._wait_multiple(
+ self.channels, None, timeout=timeout,
+ ))
+
def read_timeout(self, timeout=None):
if timeout is None:
return self.method_reader.read_method()
@@ -346,6 +401,12 @@ class Connection(AbstractChannel):
if prev != timeout:
sock.settimeout(prev)
+ def _maybe_read_error(self, wanted_channel):
+ channel, method_sig, args, content = self.method_reader.read_method()
+ if channel == wanted_channel and method_sig == (20, 40):
+ return self._dispatch_method(channel, method_sig, args, content)
+ self.channels[channel].method_queue.append((method_sig, args, content))
+
def _wait_multiple(self, channels, allowed_methods, timeout=None):
for channel_id, channel in items(channels):
method_queue = channel.method_queue
diff --git a/amqp/exceptions.py b/amqp/exceptions.py
index e3e144a..50c12c6 100644
--- a/amqp/exceptions.py
+++ b/amqp/exceptions.py
@@ -108,7 +108,7 @@ class AccessRefused(IrrecoverableChannelError):
code = 403
-class NotFound(IrrecoverableChannelError):
+class NotFound(RecoverableChannelError):
code = 404