summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAsk Solem <ask@celeryproject.org>2016-06-17 10:49:35 -0700
committerAsk Solem <ask@celeryproject.org>2016-06-17 10:49:35 -0700
commitdeb91fec4a6ecdfa4daf655e34dba3d40649e98b (patch)
treedf415a98a05cd81ec4bc84bf546a3d94f1a61c8e
parent8b5789948873215f80fbf3d0f29a963b7a76912b (diff)
downloadpy-amqp-typing.tar.gz
Typing experimentstyping
-rw-r--r--amqp/abstract_channel.py2
-rw-r--r--amqp/connection.py111
-rw-r--r--amqp/serialization.py2
-rw-r--r--amqp/tests/test_serialization.py2
-rw-r--r--amqp/tests/test_transport.py2
-rw-r--r--amqp/transport.py103
-rw-r--r--amqp/types.py5
7 files changed, 116 insertions, 111 deletions
diff --git a/amqp/abstract_channel.py b/amqp/abstract_channel.py
index da60769..ede206b 100644
--- a/amqp/abstract_channel.py
+++ b/amqp/abstract_channel.py
@@ -25,7 +25,7 @@ from .serialization import dumps, loads
__all__ = ['AbstractChannel']
-class AbstractChannel(object):
+class AbstractChannel:
"""Superclass for both the Connection, which is treated
as channel 0, and other user-created Channel objects.
diff --git a/amqp/connection.py b/amqp/connection.py
index a000bfc..000b52d 100644
--- a/amqp/connection.py
+++ b/amqp/connection.py
@@ -38,7 +38,7 @@ from .exceptions import (
from .five import array, range, values, monotonic
from .method_framing import frame_handler, frame_writer
from .serialization import _write_table
-from .transport import Transport
+from .transport import BaseTransport, Transport
from .types import SSLArg, Timeout, MaybeDict
from typing import Mapping, Generator, Callable, List, Any
@@ -78,16 +78,17 @@ AMQP_LOGGER = logging.getLogger('amqp')
ConnectionBlockedCallback = typing.Callable[[str], None]
ConnectionUnblockedCallback = typing.Callable[[], None]
ConnectionInboundMethodHandler = typing.Callable[
- [int, spec.method_sig_t, typing.ByteString, typing.ByteString], typing.Any,
+ [int, typing.Tuple[Any], typing.ByteString, typing.ByteString], typing.Any,
]
ConnectionFrameHandler = typing.Callable[
['Connection', ConnectionInboundMethodHandler],
- typing.Generator[typing.Any],
+ typing.Generator[typing.Any, typing.Any, typing.Any],
]
ConnectionFrameWriter = typing.Callable[
- ['Connection', BaseTransport], typing.Generator[typing.Any],
+ ['Connection', BaseTransport], typing.Generator[typing.Any, typing.Any,
+ typing.Any],
]
-MethodSigMethodMapping = Mapping[spec.method_sig_t, spec.method]
+MethodSigMethodMapping = Mapping[typing.Tuple[Any], typing.Tuple[Any]]
class Connection(AbstractChannel):
@@ -168,16 +169,26 @@ class Connection(AbstractChannel):
RecoverableChannelError,
)
- def __init__(self, host='localhost:5672', userid='guest', password='guest',
- login_method='AMQPLAIN', login_response=None,
- virtual_host='/', locale='en_US', client_properties=None,
- ssl=False, connect_timeout=None, channel_max=None,
- frame_max=None, heartbeat=0, on_open=None, on_blocked=None,
- on_unblocked=None, confirm_publish=False,
- on_tune_ok=None, read_timeout=None, write_timeout=None,
- socket_settings=None, frame_handler=frame_handler,
- frame_writer=frame_writer, **kwargs):
- # type: (str, str, str, str, Any, str, str, Mapping, SSLArg, Timeout, int, int, Timeout, Thenable, ConnectionBlockedCallback, ConnectionUnblockedCallback, bool, ConnectionTuneOkCallback, Timeout, Timeout, MaybeDict, ConnectionFrameHandler, ConnectionFrameWriter, **Any) -> None
+ def __init__(self, host: str='localhost:5672',
+ userid: str='guest', password: str='guest',
+ login_method: str='AMQPLAIN', login_response: Any=None,
+ virtual_host: str='/', locale: str='en_US',
+ client_properties: MaybeDict=None,
+ ssl: Optional[SSLArg]=False, connect_timeout: Timeout=None,
+ channel_max: Optional[int]=None,
+ frame_max: Optional[int]=None,
+ heartbeat: Timeout=0,
+ on_open: Thenable=None,
+ on_blocked: Optional[ConnectionBlockedCallback]=None,
+ on_unblocked: Optional[ConnectionUnblockedCallback]=None,
+ bool: confirm_publish=False,
+ on_tune_ok: Optional[ConnectionTuneOkCallback]=None,
+ read_timeout: Timeout=None,
+ write_timeout: Timeout=None,
+ socket_settings: MaybeDict=None,
+ frame_handler: ConnectionFrameHandler=frame_handler,
+ frame_writer: ConnectionFrameWriter=frame_writer,
+ **kwargs):
"""Create a connection to the specified host, which should be
a 'host[:port]', such as 'localhost', or '1.2.3.4:5672'
(defaults to 'localhost', if a port is not specified then
@@ -257,21 +268,17 @@ class Connection(AbstractChannel):
self.connect_timeout = connect_timeout # type: Timeout
- def __enter__(self):
- # type: () -> Connection
+ def __enter__(self) -> Any:
self.connect()
return self
- def __exit__(self, *eargs):
- # type: (*Any) -> None
+ def __exit__(self, *eargs) -> None:
self.close()
- def then(self, on_success, on_error=None):
- # type: (Thenable, Thenable) -> Thenable
+ def then(self, on_success: Thenable, on_error: Thenable=None) -> Thenable:
return self.on_open.then(on_success, on_error)
- def _setup_listeners(self):
- # type: () -> None
+ def _setup_listeners(self) -> None:
self._callbacks.update({
spec.Connection.Start: self._on_start,
spec.Connection.OpenOk: self._on_open_ok,
@@ -283,8 +290,7 @@ class Connection(AbstractChannel):
spec.Connection.CloseOk: self._on_close_ok,
})
- def connect(self, callback=None):
- # type: (Callable[[], None]) -> None
+ def connect(self, callback: Optional[Callable[[], None]=None) -> None:
# Let the transport.py module setup the actual
# socket connection to the broker.
#
@@ -303,14 +309,12 @@ class Connection(AbstractChannel):
while not self._handshake_complete:
self.drain_events(timeout=self.connect_timeout)
- def _warn_force_connect(self, attr):
- # type: (str) -> None
+ def _warn_force_connect(self, attr: str) -> None:
warnings.warn(AMQPDeprecationWarning(
W_FORCE_CONNECT.format(attr=attr)))
@property
- def transport(self):
- # type: () -> BaseTransport
+ def transport(self) -> BaseTransport:
if self._transport is None:
self._warn_force_connect('transport')
self.connect()
@@ -321,8 +325,7 @@ class Connection(AbstractChannel):
self._transport = transport
@property
- def on_inbound_frame(self):
- # type: () -> ConnectionInboundFrameHandler
+ def on_inbound_frame(self) -> ConnectionInboundFrameHandler:
if self._on_inbound_frame is None:
self._warn_force_connect('on_inbound_frame')
self.connect()
@@ -333,8 +336,7 @@ class Connection(AbstractChannel):
self._on_inbound_frame = on_inbound_frame
@property
- def frame_writer(self):
- # type: () -> Generator[Any]
+ def frame_writer(self) -> Generator[Any, Any, Any]:
if self._frame_writer is None:
self._warn_force_connect('frame_writer')
self.connect()
@@ -344,8 +346,9 @@ class Connection(AbstractChannel):
def frame_writer(self, frame_writer):
self._frame_writer = frame_writer
- def _on_start(self, version_major, version_minor, server_properties,
- mechanisms, locales, argsig='FsSs'):
+ def _on_start(self, version_major: int, version_minor: int,
+ server_properties: Mapping[Any, Any],
+ mechanisms: str, locales: str, argsig: str='FsSs') -> None:
# type: (int, int, Mapping[Any, Any], str, str, str) -> None
client_properties = self.client_properties
self.version_major = version_major
@@ -374,12 +377,11 @@ class Connection(AbstractChannel):
self.login_response, self.locale),
)
- def _on_secure(self, challenge):
- # type: (str) -> None
+ def _on_secure(self, challenge: str) -> None:
pass
- def _on_tune(self, channel_max, frame_max, server_heartbeat, argsig='BlB'):
- # type: (int, int, Timeout, str) -> None
+ def _on_tune(self, channel_max: int, frame_max: int,
+ server_heartbeat: Timeout, argsig: str='BlB'):
client_heartbeat = self.client_heartbeat or 0
self.channel_max = channel_max or self.channel_max
self.frame_max = frame_max or self.frame_max
@@ -402,28 +404,31 @@ class Connection(AbstractChannel):
callback=self._on_tune_sent,
)
- def _on_tune_sent(self, argsig='ssb'):
+ def _on_tune_sent(self, argsig: str='ssb'):
self.send_method(
spec.Connection.Open, argsig, (self.virtual_host, '', False),
)
- def _on_open_ok(self):
+ def _on_open_ok(self) -> None:
self._handshake_complete = True
self.on_open(self)
- def Transport(self, host, connect_timeout,
- ssl=False, read_timeout=None, write_timeout=None,
- socket_settings=None, **kwargs):
+ def Transport(self, host: str, connect_timeout: Timeout,
+ ssl: Optional[SSLArg]=False,
+ read_timeout: Timeout=None,
+ write_timeout: Timeout=None,
+ socket_settings: MaybeDict=None,
+ **kwargs) -> BaseTransport:
return Transport(
host, connect_timeout=connect_timeout, ssl=ssl,
read_timeout=read_timeout, write_timeout=write_timeout,
socket_settings=socket_settings, **kwargs)
@property
- def connected(self):
+ def connected(self) -> bool:
return self._transport and self._transport.connected
- def collect(self):
+ def collect(self) -> None:
try:
self.transport.close()
@@ -435,7 +440,7 @@ class Connection(AbstractChannel):
finally:
self._transport = self.connection = self.channels = None
- def _get_free_channel_id(self):
+ def _get_free_channel_id(self) -> int:
try:
return self._avail_channel_ids.pop()
except IndexError:
@@ -443,14 +448,14 @@ class Connection(AbstractChannel):
'No free channel ids, current={0}, channel_max={1}'.format(
len(self.channels), self.channel_max), spec.Channel.Open)
- def _claim_channel_id(self, channel_id):
+ def _claim_channel_id(self, channel_id) -> None:
try:
- return self._avail_channel_ids.remove(channel_id)
+ self._avail_channel_ids.remove(channel_id)
except ValueError:
raise ConnectionError(
'Channel %r already open' % (channel_id,))
- def channel(self, channel_id=None, callback=None):
+ def channel(self, channel_id: int=None, callback=None):
"""Fetch a Channel object identified by the numeric channel_id, or
create that object if it doesn't already exist."""
if self.channels is not None:
@@ -462,13 +467,13 @@ class Connection(AbstractChannel):
return channel
raise RecoverableConnectionError('Connection already closed.')
- def is_alive(self):
+ def is_alive(self) -> bool:
raise NotImplementedError('Use AMQP heartbeats')
- def drain_events(self, timeout=None):
+ def drain_events(self, timeout: Timeout=None) -> None:
return self.blocking_read(timeout)
- def blocking_read(self, timeout=None):
+ def blocking_read(self, timeout: Timeout=None) -> Frame:
with self.transport.having_timeout(timeout):
frame = self.transport.read_frame()
return self.on_inbound_frame(frame)
diff --git a/amqp/serialization.py b/amqp/serialization.py
index 69cc23a..02b575c 100644
--- a/amqp/serialization.py
+++ b/amqp/serialization.py
@@ -474,7 +474,7 @@ PROPERTY_CLASSES = {
}
-class GenericContent(object):
+class GenericContent:
"""Abstract base class for AMQP content.
Subclasses should override the PROPERTIES attribute.
diff --git a/amqp/tests/test_serialization.py b/amqp/tests/test_serialization.py
index ea912f3..69061c1 100644
--- a/amqp/tests/test_serialization.py
+++ b/amqp/tests/test_serialization.py
@@ -12,7 +12,7 @@ from amqp.serialization import GenericContent, _read_item, dumps, loads
from .case import Case
-class ANY(object):
+class ANY:
def __eq__(self, other):
return other is not None
diff --git a/amqp/tests/test_transport.py b/amqp/tests/test_transport.py
index dd09568..279208b 100644
--- a/amqp/tests/test_transport.py
+++ b/amqp/tests/test_transport.py
@@ -11,7 +11,7 @@ from amqp.exceptions import UnexpectedFrame
from .case import Case, Mock, patch
-class MockSocket(object):
+class MockSocket:
options = {}
def setsockopt(self, family, key, value):
diff --git a/amqp/transport.py b/amqp/transport.py
index 3203ee7..c2e50c1 100644
--- a/amqp/transport.py
+++ b/amqp/transport.py
@@ -24,9 +24,13 @@ import ssl
from collections import namedtuple
from contextlib import contextmanager
from struct import unpack
+from typing import (
+ Any, AnyStr, ByteString, Callable, IO, Mapping, Optional, Set, Tuple,
+)
from .exceptions import UnexpectedFrame
from .five import items
+from .types import SSLArg, MaybeDict, Timeout
from .utils import get_errno, set_cloexec
# Jython does not have this attribute
@@ -67,7 +71,9 @@ TCP_OPTS = [getattr(socket, opt)
Frame = namedtuple('Frame', ('tuple', 'channel', 'data'))
-from .types import SSLArg, MaybeDict, Timeout
+StructUnpackT = Callable[
+ [AnyStr, ByteString], Tuple[Any]
+]
def to_host_port(host, default=AMQP_PORT):
@@ -85,14 +91,14 @@ def to_host_port(host, default=AMQP_PORT):
return host, port
-class BaseTransport(object):
+class BaseTransport:
"""Common superclass for TCP and SSL transports"""
connected = False
- def __init__(self, host, connect_timeout=None,
- read_timeout=None, write_timeout=None,
- socket_settings=None, raise_on_initial_eintr=True, **kwargs):
- # type: (str, Timeout, Timeout, Timeout, MaybeDict, bool, **Any) -> None
+ def __init__(self, host: str, connect_timeout: Timeout=None,
+ read_timeout: Timeout=None, write_timeout: Timeout=None,
+ socket_settings: MaybeDict=None,
+ raise_on_initial_eintr: bool=True, **kwargs) -> None:
self.connected = True # type: bool
self.sock = None # type: Socket
self._read_buffer = EMPTY_BUFFER # type: ByteString
@@ -103,16 +109,14 @@ class BaseTransport(object):
self.socket_settings = socket_settings # type: MaybeDict
self.raise_on_initial_eintr = raise_on_initial_eintr # type: bool
- def connect(self):
- # type: () -> None
+ def connect(self) -> None:
self._connect(self.host, self.port, self.connect_timeout)
self._init_socket(
self.socket_settings, self.read_timeout, self.write_timeout,
)
@contextmanager
- def having_timeout(self, timeout):
- # type: (Timeout) -> ContextManager
+ def having_timeout(self, timeout: Timeout) -> Any:
if timeout is None:
yield self.sock
else:
@@ -134,7 +138,7 @@ class BaseTransport(object):
if timeout != prev:
sock.settimeout(prev)
- def __del__(self):
+ def __del__(self) -> None:
try:
# socket module may have been collected by gc
# if this is called by a thread at shutdown.
@@ -146,8 +150,7 @@ class BaseTransport(object):
finally:
self.sock = None
- def _connect(self, host, port, timeout):
- # type: (str, int, Timeout) -> None
+ def _connect(self, host: str, port: int, timeout: Timeout) -> None:
entries = socket.getaddrinfo(
host, port, 0, socket.SOCK_STREAM, SOL_TCP,
)
@@ -169,8 +172,8 @@ class BaseTransport(object):
else:
break
- def _init_socket(self, socket_settings, read_timeout, write_timeout):
- # type: (MaybeDict, Timeout, Timeout) -> None
+ def _init_socket(self, socket_settings: MaybeDict,
+ read_timeout: Timeout, write_timeout: Timeout) -> None:
try:
self.sock.settimeout(None) # set socket back to blocking mode
self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
@@ -192,14 +195,13 @@ class BaseTransport(object):
self.connected = False
raise
- def _get_tcp_socket_defaults(self, sock):
- # type: (Socket) -> Mapping[AnyStr, Any]
+ def _get_tcp_socket_defaults(
+ self, sock: socket.socket) -> Mapping[AnyStr, Any]:
return {
opt: sock.getsockopt(SOL_TCP, opt) for opt in TCP_OPTS
}
- def _set_socket_options(self, socket_settings):
- # type: (MaybeDict) -> None
+ def _set_socket_options(self, socket_settings: MaybeDict) -> None:
if not socket_settings:
self.sock.setsockopt(SOL_TCP, socket.TCP_NODELAY, 1)
return
@@ -211,29 +213,24 @@ class BaseTransport(object):
for opt, val in items(tcp_opts):
self.sock.setsockopt(SOL_TCP, opt, val)
- def _read(self, n, initial=False):
- # type: (int, bool) -> ByteString
+ def _read(self, n: int, initial: bool=False) -> ByteString:
"""Read exactly n BytesString from the peer"""
raise NotImplementedError('Must be overriden in subclass')
- def _setup_transport(self):
- # type: () -> None
+ def _setup_transport(self) -> None:
"""Do any additional initialization of the class (used
by the subclasses)."""
pass
- def _shutdown_transport(self):
- # type: () -> None
+ def _shutdown_transport(self) -> None:
"""Do any preliminary work in shutting down the connection."""
pass
- def _write(self, s):
- # type: (ByteString) -> int
+ def _write(self, s: ByteString) -> int:
"""Completely write a string to the peer."""
raise NotImplementedError('Must be overriden in subclass')
- def close(self):
- # type: () -> None
+ def close(self) -> None:
if self.sock is not None:
self._shutdown_transport()
# Call shutdown first to make sure that pending messages
@@ -244,8 +241,7 @@ class BaseTransport(object):
self.sock = None
self.connected = False
- def read_frame(self, unpack=unpack):
- # type: (Callable[AnyStr, ByteString]) -> Frame
+ def read_frame(self, unpack: StructUnpackT=unpack) -> Frame:
read = self._read
read_frame_buffer = EMPTY_BUFFER
try:
@@ -279,8 +275,7 @@ class BaseTransport(object):
raise UnexpectedFrame(
'Received {0:#04x} while expecting 0xce'.format(ch))
- def write(self, s):
- # type: (ByteString) -> None
+ def write(self, s: ByteString) -> None:
try:
self._write(s)
except socket.timeout:
@@ -294,35 +289,34 @@ class BaseTransport(object):
class SSLTransport(BaseTransport):
"""Transport that works over SSL"""
- def __init__(self, host, connect_timeout=None, ssl=None, **kwargs):
- # type: (str, float, Optional[SSLArg], **Any) -> None
+ def __init__(self, host: str, connect_timeout: Timeout=None,
+ ssl: Optional[SSLArg]=None, **kwargs) -> None:
if isinstance(ssl, dict):
self.sslopts = ssl # type: Dict[AnyStr, Any]
self._read_buffer = EMPTY_BUFFER # type: ByteString
super(SSLTransport, self).__init__(
host, connect_timeout=connect_timeout, **kwargs)
- def _setup_transport(self):
- # type: () -> None
+ def _setup_transport(self) -> None:
"""Wrap the socket in an SSL object."""
self.sock = self._wrap_socket(self.sock, **self.sslopts or {})
self.sock.do_handshake()
self._quick_recv = self.sock.read
- def _wrap_socket(self, sock, context=None, **sslopts):
- # type: (Socket, MaybeDict, **Any) -> Socket
+ def _wrap_socket(self, sock: socket.socket,
+ context: MaybeDict=None, **sslopts) -> socket.socket:
if context:
return self._wrap_context(sock, sslopts, **context)
return ssl.wrap_socket(sock, **sslopts)
- def _wrap_context(self, sock, sslopts, check_hostname=None, **ctx_options):
- # type: (Socket, Dict[AnyStr, Any], bool, **Any) -> Socket
+ def _wrap_context(self, sock: socket.socket, sslopts: Mapping[AnyStr, Any],
+ check_hostname: bool=None,
+ **ctx_options) -> socket.socket:
ctx = ssl.create_default_context(**ctx_options)
ctx.check_hostname = check_hostname
return ctx.wrap_socket(sock, **sslopts)
- def _shutdown_transport(self):
- # type: () -> None
+ def _shutdown_transport(self) -> None:
"""Unwrap a Python 2.6 SSL socket, so we can call shutdown()"""
if self.sock is not None:
try:
@@ -331,9 +325,10 @@ class SSLTransport(BaseTransport):
return
self.sock = unwrap()
- def _read(self, n, initial=False,
- _errnos=(errno.ENOENT, errno.EAGAIN, errno.EINTR)):
- # type: (int, bool, Set[int]) -> ByteString
+ def _read(self, n: int,
+ initial: bool=False,
+ _errnos: Set[int]=(errno.ENOENT, errno.EAGAIN, errno.EINTR)
+ ) -> ByteString:
# According to SSL_read(3), it can at most return 16kb of data.
# Thus, we use an internal read buffer like TCPTransport._read
@@ -361,8 +356,7 @@ class SSLTransport(BaseTransport):
result, self._read_buffer = rbuf[:n], rbuf[n:]
return result
- def _write(self, s):
- # type: (str) -> None
+ def _write(self, s: str) -> None:
"""Write a string out to the SSL socket fully."""
write = self.sock.write # type: Callable[[ByteString], int]
while s:
@@ -385,16 +379,16 @@ class SSLTransport(BaseTransport):
class TCPTransport(BaseTransport):
"""Transport that deals directly with TCP socket."""
- def _setup_transport(self):
- # type: () -> None
+ def _setup_transport(self) -> None:
"""Setup to _write() directly to the socket, and
do our own buffered reads."""
self._write = self.sock.sendall # type: Callable[[ByteString], int]
self._read_buffer = EMPTY_BUFFER # type: ByteString
self._quick_recv = self.sock.recv # type: Callable[[int], ByteString]
- def _read(self, n, initial=False, _errnos={errno.EAGAIN, errno.EINTR}):
- # type: (int, bool, Set[int]) -> ByteString
+ def _read(self, n: int,
+ initial: bool=False,
+ _errnos: Set[int]={errno.EAGAIN, errno.EINTR}) -> ByteString:
"""Read exactly n bytes from the socket"""
recv = self._quick_recv
rbuf = self._read_buffer
@@ -419,8 +413,9 @@ class TCPTransport(BaseTransport):
return result
-def Transport(host, connect_timeout=None, ssl=False, **kwargs):
- # type: (str, Timeout, Optional[SSLArg], **Any) -> BaseTransport
+def Transport(host: str,
+ connect_timeout: Timeout=None,
+ ssl: Optional[SSLArg]=False, **kwargs) -> BaseTransport:
"""Given a few parameters from the Connection constructor,
select and create a subclass of BaseTransport."""
transport = SSLTransport if ssl else TCPTransport
diff --git a/amqp/types.py b/amqp/types.py
new file mode 100644
index 0000000..c811f92
--- /dev/null
+++ b/amqp/types.py
@@ -0,0 +1,5 @@
+from typing import Any, AnyStr, Mapping, Optional, Union
+
+SSLArg = Union[Mapping[AnyStr, Any], bool]
+MaybeDict = Optional[Mapping[AnyStr, Any]]
+Timeout = Optional[float]