diff options
author | Adrien Guinet <aguinet@quarkslab.com> | 2013-07-12 12:03:23 +0200 |
---|---|---|
committer | Ask Solem <ask@celeryproject.org> | 2013-07-31 14:18:11 +0100 |
commit | 325fbb2449636eac1952679fe61e4f160d33b29e (patch) | |
tree | 98b349d157a3352916a4af10863ec87447a89ab3 | |
parent | 57b11515a8756e665523046293b75cca6edc5b3b (diff) | |
download | py-amqp-325fbb2449636eac1952679fe61e4f160d33b29e.tar.gz |
Fix AMQP SSL support
This patch changes SSLTransport so that it returns the SSL "wrapped"
socket for the "sock" property, and not the original socket.
-rw-r--r-- | amqp/transport.py | 54 |
1 files changed, 33 insertions, 21 deletions
diff --git a/amqp/transport.py b/amqp/transport.py index 5ad42f0..5dde9a2 100644 --- a/amqp/transport.py +++ b/amqp/transport.py @@ -78,29 +78,29 @@ class _AbstractTransport(object): host, port = host.rsplit(':', 1) port = int(port) - self.sock = None + self._sock = None last_err = None for res in socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM, SOL_TCP): af, socktype, proto, canonname, sa = res try: - self.sock = socket.socket(af, socktype, proto) - self.sock.settimeout(connect_timeout) - self.sock.connect(sa) + self._sock = socket.socket(af, socktype, proto) + self._sock.settimeout(connect_timeout) + self._sock.connect(sa) except socket.error, msg: - self.sock.close() - self.sock = None + self._sock.close() + self._sock = None last_err = msg continue break - if not self.sock: + if not self._sock: # Didn't connect, return the most recent error message raise socket.error(last_err) - self.sock.settimeout(None) - self.sock.setsockopt(SOL_TCP, socket.TCP_NODELAY, 1) - self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) + self._sock.settimeout(None) + self._sock.setsockopt(SOL_TCP, socket.TCP_NODELAY, 1) + self._sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) self._setup_transport() @@ -112,7 +112,7 @@ class _AbstractTransport(object): except socket.error: pass finally: - self.sock = None + self._sock = None def _read(self, n, initial=False): """Read exactly n bytes from the peer""" @@ -132,14 +132,14 @@ class _AbstractTransport(object): raise NotImplementedError('Must be overriden in subclass') def close(self): - if self.sock is not None: + if self._sock is not None: self._shutdown_transport() # Call shutdown first to make sure that pending messages # reach the AMQP broker if the program exits after # calling this method. - self.sock.shutdown(socket.SHUT_RDWR) - self.sock.close() - self.sock = None + self._sock.shutdown(socket.SHUT_RDWR) + self._sock.close() + self._sock = None def read_frame(self): """Read an AMQP frame.""" @@ -159,6 +159,14 @@ class _AbstractTransport(object): pack('>BHI%dsB' % size, frame_type, channel, size, payload, 0xce), ) + @property + def sock(self): + return self._sock + + @sock.setter + def sock(self, v): + self._sock = v + class SSLTransport(_AbstractTransport): """Transport that works over SSL""" @@ -175,17 +183,17 @@ class SSLTransport(_AbstractTransport): lower version.""" if HAVE_PY26_SSL: if hasattr(self, 'sslopts'): - self.sslobj = ssl.wrap_socket(self.sock, **self.sslopts) + self.sslobj = ssl.wrap_socket(self._sock, **self.sslopts) else: - self.sslobj = ssl.wrap_socket(self.sock) + self.sslobj = ssl.wrap_socket(self._sock) self.sslobj.do_handshake() else: - self.sslobj = socket.ssl(self.sock) + self.sslobj = socket.ssl(self._sock) def _shutdown_transport(self): """Unwrap a Python 2.6 SSL socket, so we can call shutdown()""" if HAVE_PY26_SSL and (self.sslobj is not None): - self.sock = self.sslobj.unwrap() + self._sock = self.sslobj.unwrap() self.sslobj = None def _read(self, n, initial=False): @@ -216,6 +224,10 @@ class SSLTransport(_AbstractTransport): raise IOError('Socket closed') s = s[n:] + @property + def sock(self): + return self.sslobj + class TCPTransport(_AbstractTransport): """Transport that deals directly with TCP socket.""" @@ -223,14 +235,14 @@ class TCPTransport(_AbstractTransport): def _setup_transport(self): """Setup to _write() directly to the socket, and do our own buffered reads.""" - self._write = self.sock.sendall + self._write = self._sock.sendall self._read_buffer = bytes() def _read(self, n, initial=False): """Read exactly n bytes from the socket""" while len(self._read_buffer) < n: try: - s = self.sock.recv(65536) + s = self._sock.recv(65536) except socket.error, exc: if not initial and exc.errno in (errno.EAGAIN, errno.EINTR): continue |