diff options
author | Ask Solem <ask@celeryproject.org> | 2016-06-17 10:49:35 -0700 |
---|---|---|
committer | Ask Solem <ask@celeryproject.org> | 2016-06-17 10:49:35 -0700 |
commit | deb91fec4a6ecdfa4daf655e34dba3d40649e98b (patch) | |
tree | df415a98a05cd81ec4bc84bf546a3d94f1a61c8e | |
parent | 8b5789948873215f80fbf3d0f29a963b7a76912b (diff) | |
download | py-amqp-typing.tar.gz |
Typing experimentstyping
-rw-r--r-- | amqp/abstract_channel.py | 2 | ||||
-rw-r--r-- | amqp/connection.py | 111 | ||||
-rw-r--r-- | amqp/serialization.py | 2 | ||||
-rw-r--r-- | amqp/tests/test_serialization.py | 2 | ||||
-rw-r--r-- | amqp/tests/test_transport.py | 2 | ||||
-rw-r--r-- | amqp/transport.py | 103 | ||||
-rw-r--r-- | amqp/types.py | 5 |
7 files changed, 116 insertions, 111 deletions
diff --git a/amqp/abstract_channel.py b/amqp/abstract_channel.py index da60769..ede206b 100644 --- a/amqp/abstract_channel.py +++ b/amqp/abstract_channel.py @@ -25,7 +25,7 @@ from .serialization import dumps, loads __all__ = ['AbstractChannel'] -class AbstractChannel(object): +class AbstractChannel: """Superclass for both the Connection, which is treated as channel 0, and other user-created Channel objects. diff --git a/amqp/connection.py b/amqp/connection.py index a000bfc..000b52d 100644 --- a/amqp/connection.py +++ b/amqp/connection.py @@ -38,7 +38,7 @@ from .exceptions import ( from .five import array, range, values, monotonic from .method_framing import frame_handler, frame_writer from .serialization import _write_table -from .transport import Transport +from .transport import BaseTransport, Transport from .types import SSLArg, Timeout, MaybeDict from typing import Mapping, Generator, Callable, List, Any @@ -78,16 +78,17 @@ AMQP_LOGGER = logging.getLogger('amqp') ConnectionBlockedCallback = typing.Callable[[str], None] ConnectionUnblockedCallback = typing.Callable[[], None] ConnectionInboundMethodHandler = typing.Callable[ - [int, spec.method_sig_t, typing.ByteString, typing.ByteString], typing.Any, + [int, typing.Tuple[Any], typing.ByteString, typing.ByteString], typing.Any, ] ConnectionFrameHandler = typing.Callable[ ['Connection', ConnectionInboundMethodHandler], - typing.Generator[typing.Any], + typing.Generator[typing.Any, typing.Any, typing.Any], ] ConnectionFrameWriter = typing.Callable[ - ['Connection', BaseTransport], typing.Generator[typing.Any], + ['Connection', BaseTransport], typing.Generator[typing.Any, typing.Any, + typing.Any], ] -MethodSigMethodMapping = Mapping[spec.method_sig_t, spec.method] +MethodSigMethodMapping = Mapping[typing.Tuple[Any], typing.Tuple[Any]] class Connection(AbstractChannel): @@ -168,16 +169,26 @@ class Connection(AbstractChannel): RecoverableChannelError, ) - def __init__(self, host='localhost:5672', userid='guest', password='guest', - login_method='AMQPLAIN', login_response=None, - virtual_host='/', locale='en_US', client_properties=None, - ssl=False, connect_timeout=None, channel_max=None, - frame_max=None, heartbeat=0, on_open=None, on_blocked=None, - on_unblocked=None, confirm_publish=False, - on_tune_ok=None, read_timeout=None, write_timeout=None, - socket_settings=None, frame_handler=frame_handler, - frame_writer=frame_writer, **kwargs): - # type: (str, str, str, str, Any, str, str, Mapping, SSLArg, Timeout, int, int, Timeout, Thenable, ConnectionBlockedCallback, ConnectionUnblockedCallback, bool, ConnectionTuneOkCallback, Timeout, Timeout, MaybeDict, ConnectionFrameHandler, ConnectionFrameWriter, **Any) -> None + def __init__(self, host: str='localhost:5672', + userid: str='guest', password: str='guest', + login_method: str='AMQPLAIN', login_response: Any=None, + virtual_host: str='/', locale: str='en_US', + client_properties: MaybeDict=None, + ssl: Optional[SSLArg]=False, connect_timeout: Timeout=None, + channel_max: Optional[int]=None, + frame_max: Optional[int]=None, + heartbeat: Timeout=0, + on_open: Thenable=None, + on_blocked: Optional[ConnectionBlockedCallback]=None, + on_unblocked: Optional[ConnectionUnblockedCallback]=None, + bool: confirm_publish=False, + on_tune_ok: Optional[ConnectionTuneOkCallback]=None, + read_timeout: Timeout=None, + write_timeout: Timeout=None, + socket_settings: MaybeDict=None, + frame_handler: ConnectionFrameHandler=frame_handler, + frame_writer: ConnectionFrameWriter=frame_writer, + **kwargs): """Create a connection to the specified host, which should be a 'host[:port]', such as 'localhost', or '1.2.3.4:5672' (defaults to 'localhost', if a port is not specified then @@ -257,21 +268,17 @@ class Connection(AbstractChannel): self.connect_timeout = connect_timeout # type: Timeout - def __enter__(self): - # type: () -> Connection + def __enter__(self) -> Any: self.connect() return self - def __exit__(self, *eargs): - # type: (*Any) -> None + def __exit__(self, *eargs) -> None: self.close() - def then(self, on_success, on_error=None): - # type: (Thenable, Thenable) -> Thenable + def then(self, on_success: Thenable, on_error: Thenable=None) -> Thenable: return self.on_open.then(on_success, on_error) - def _setup_listeners(self): - # type: () -> None + def _setup_listeners(self) -> None: self._callbacks.update({ spec.Connection.Start: self._on_start, spec.Connection.OpenOk: self._on_open_ok, @@ -283,8 +290,7 @@ class Connection(AbstractChannel): spec.Connection.CloseOk: self._on_close_ok, }) - def connect(self, callback=None): - # type: (Callable[[], None]) -> None + def connect(self, callback: Optional[Callable[[], None]=None) -> None: # Let the transport.py module setup the actual # socket connection to the broker. # @@ -303,14 +309,12 @@ class Connection(AbstractChannel): while not self._handshake_complete: self.drain_events(timeout=self.connect_timeout) - def _warn_force_connect(self, attr): - # type: (str) -> None + def _warn_force_connect(self, attr: str) -> None: warnings.warn(AMQPDeprecationWarning( W_FORCE_CONNECT.format(attr=attr))) @property - def transport(self): - # type: () -> BaseTransport + def transport(self) -> BaseTransport: if self._transport is None: self._warn_force_connect('transport') self.connect() @@ -321,8 +325,7 @@ class Connection(AbstractChannel): self._transport = transport @property - def on_inbound_frame(self): - # type: () -> ConnectionInboundFrameHandler + def on_inbound_frame(self) -> ConnectionInboundFrameHandler: if self._on_inbound_frame is None: self._warn_force_connect('on_inbound_frame') self.connect() @@ -333,8 +336,7 @@ class Connection(AbstractChannel): self._on_inbound_frame = on_inbound_frame @property - def frame_writer(self): - # type: () -> Generator[Any] + def frame_writer(self) -> Generator[Any, Any, Any]: if self._frame_writer is None: self._warn_force_connect('frame_writer') self.connect() @@ -344,8 +346,9 @@ class Connection(AbstractChannel): def frame_writer(self, frame_writer): self._frame_writer = frame_writer - def _on_start(self, version_major, version_minor, server_properties, - mechanisms, locales, argsig='FsSs'): + def _on_start(self, version_major: int, version_minor: int, + server_properties: Mapping[Any, Any], + mechanisms: str, locales: str, argsig: str='FsSs') -> None: # type: (int, int, Mapping[Any, Any], str, str, str) -> None client_properties = self.client_properties self.version_major = version_major @@ -374,12 +377,11 @@ class Connection(AbstractChannel): self.login_response, self.locale), ) - def _on_secure(self, challenge): - # type: (str) -> None + def _on_secure(self, challenge: str) -> None: pass - def _on_tune(self, channel_max, frame_max, server_heartbeat, argsig='BlB'): - # type: (int, int, Timeout, str) -> None + def _on_tune(self, channel_max: int, frame_max: int, + server_heartbeat: Timeout, argsig: str='BlB'): client_heartbeat = self.client_heartbeat or 0 self.channel_max = channel_max or self.channel_max self.frame_max = frame_max or self.frame_max @@ -402,28 +404,31 @@ class Connection(AbstractChannel): callback=self._on_tune_sent, ) - def _on_tune_sent(self, argsig='ssb'): + def _on_tune_sent(self, argsig: str='ssb'): self.send_method( spec.Connection.Open, argsig, (self.virtual_host, '', False), ) - def _on_open_ok(self): + def _on_open_ok(self) -> None: self._handshake_complete = True self.on_open(self) - def Transport(self, host, connect_timeout, - ssl=False, read_timeout=None, write_timeout=None, - socket_settings=None, **kwargs): + def Transport(self, host: str, connect_timeout: Timeout, + ssl: Optional[SSLArg]=False, + read_timeout: Timeout=None, + write_timeout: Timeout=None, + socket_settings: MaybeDict=None, + **kwargs) -> BaseTransport: return Transport( host, connect_timeout=connect_timeout, ssl=ssl, read_timeout=read_timeout, write_timeout=write_timeout, socket_settings=socket_settings, **kwargs) @property - def connected(self): + def connected(self) -> bool: return self._transport and self._transport.connected - def collect(self): + def collect(self) -> None: try: self.transport.close() @@ -435,7 +440,7 @@ class Connection(AbstractChannel): finally: self._transport = self.connection = self.channels = None - def _get_free_channel_id(self): + def _get_free_channel_id(self) -> int: try: return self._avail_channel_ids.pop() except IndexError: @@ -443,14 +448,14 @@ class Connection(AbstractChannel): 'No free channel ids, current={0}, channel_max={1}'.format( len(self.channels), self.channel_max), spec.Channel.Open) - def _claim_channel_id(self, channel_id): + def _claim_channel_id(self, channel_id) -> None: try: - return self._avail_channel_ids.remove(channel_id) + self._avail_channel_ids.remove(channel_id) except ValueError: raise ConnectionError( 'Channel %r already open' % (channel_id,)) - def channel(self, channel_id=None, callback=None): + def channel(self, channel_id: int=None, callback=None): """Fetch a Channel object identified by the numeric channel_id, or create that object if it doesn't already exist.""" if self.channels is not None: @@ -462,13 +467,13 @@ class Connection(AbstractChannel): return channel raise RecoverableConnectionError('Connection already closed.') - def is_alive(self): + def is_alive(self) -> bool: raise NotImplementedError('Use AMQP heartbeats') - def drain_events(self, timeout=None): + def drain_events(self, timeout: Timeout=None) -> None: return self.blocking_read(timeout) - def blocking_read(self, timeout=None): + def blocking_read(self, timeout: Timeout=None) -> Frame: with self.transport.having_timeout(timeout): frame = self.transport.read_frame() return self.on_inbound_frame(frame) diff --git a/amqp/serialization.py b/amqp/serialization.py index 69cc23a..02b575c 100644 --- a/amqp/serialization.py +++ b/amqp/serialization.py @@ -474,7 +474,7 @@ PROPERTY_CLASSES = { } -class GenericContent(object): +class GenericContent: """Abstract base class for AMQP content. Subclasses should override the PROPERTIES attribute. diff --git a/amqp/tests/test_serialization.py b/amqp/tests/test_serialization.py index ea912f3..69061c1 100644 --- a/amqp/tests/test_serialization.py +++ b/amqp/tests/test_serialization.py @@ -12,7 +12,7 @@ from amqp.serialization import GenericContent, _read_item, dumps, loads from .case import Case -class ANY(object): +class ANY: def __eq__(self, other): return other is not None diff --git a/amqp/tests/test_transport.py b/amqp/tests/test_transport.py index dd09568..279208b 100644 --- a/amqp/tests/test_transport.py +++ b/amqp/tests/test_transport.py @@ -11,7 +11,7 @@ from amqp.exceptions import UnexpectedFrame from .case import Case, Mock, patch -class MockSocket(object): +class MockSocket: options = {} def setsockopt(self, family, key, value): diff --git a/amqp/transport.py b/amqp/transport.py index 3203ee7..c2e50c1 100644 --- a/amqp/transport.py +++ b/amqp/transport.py @@ -24,9 +24,13 @@ import ssl from collections import namedtuple from contextlib import contextmanager from struct import unpack +from typing import ( + Any, AnyStr, ByteString, Callable, IO, Mapping, Optional, Set, Tuple, +) from .exceptions import UnexpectedFrame from .five import items +from .types import SSLArg, MaybeDict, Timeout from .utils import get_errno, set_cloexec # Jython does not have this attribute @@ -67,7 +71,9 @@ TCP_OPTS = [getattr(socket, opt) Frame = namedtuple('Frame', ('tuple', 'channel', 'data')) -from .types import SSLArg, MaybeDict, Timeout +StructUnpackT = Callable[ + [AnyStr, ByteString], Tuple[Any] +] def to_host_port(host, default=AMQP_PORT): @@ -85,14 +91,14 @@ def to_host_port(host, default=AMQP_PORT): return host, port -class BaseTransport(object): +class BaseTransport: """Common superclass for TCP and SSL transports""" connected = False - def __init__(self, host, connect_timeout=None, - read_timeout=None, write_timeout=None, - socket_settings=None, raise_on_initial_eintr=True, **kwargs): - # type: (str, Timeout, Timeout, Timeout, MaybeDict, bool, **Any) -> None + def __init__(self, host: str, connect_timeout: Timeout=None, + read_timeout: Timeout=None, write_timeout: Timeout=None, + socket_settings: MaybeDict=None, + raise_on_initial_eintr: bool=True, **kwargs) -> None: self.connected = True # type: bool self.sock = None # type: Socket self._read_buffer = EMPTY_BUFFER # type: ByteString @@ -103,16 +109,14 @@ class BaseTransport(object): self.socket_settings = socket_settings # type: MaybeDict self.raise_on_initial_eintr = raise_on_initial_eintr # type: bool - def connect(self): - # type: () -> None + def connect(self) -> None: self._connect(self.host, self.port, self.connect_timeout) self._init_socket( self.socket_settings, self.read_timeout, self.write_timeout, ) @contextmanager - def having_timeout(self, timeout): - # type: (Timeout) -> ContextManager + def having_timeout(self, timeout: Timeout) -> Any: if timeout is None: yield self.sock else: @@ -134,7 +138,7 @@ class BaseTransport(object): if timeout != prev: sock.settimeout(prev) - def __del__(self): + def __del__(self) -> None: try: # socket module may have been collected by gc # if this is called by a thread at shutdown. @@ -146,8 +150,7 @@ class BaseTransport(object): finally: self.sock = None - def _connect(self, host, port, timeout): - # type: (str, int, Timeout) -> None + def _connect(self, host: str, port: int, timeout: Timeout) -> None: entries = socket.getaddrinfo( host, port, 0, socket.SOCK_STREAM, SOL_TCP, ) @@ -169,8 +172,8 @@ class BaseTransport(object): else: break - def _init_socket(self, socket_settings, read_timeout, write_timeout): - # type: (MaybeDict, Timeout, Timeout) -> None + def _init_socket(self, socket_settings: MaybeDict, + read_timeout: Timeout, write_timeout: Timeout) -> None: try: self.sock.settimeout(None) # set socket back to blocking mode self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) @@ -192,14 +195,13 @@ class BaseTransport(object): self.connected = False raise - def _get_tcp_socket_defaults(self, sock): - # type: (Socket) -> Mapping[AnyStr, Any] + def _get_tcp_socket_defaults( + self, sock: socket.socket) -> Mapping[AnyStr, Any]: return { opt: sock.getsockopt(SOL_TCP, opt) for opt in TCP_OPTS } - def _set_socket_options(self, socket_settings): - # type: (MaybeDict) -> None + def _set_socket_options(self, socket_settings: MaybeDict) -> None: if not socket_settings: self.sock.setsockopt(SOL_TCP, socket.TCP_NODELAY, 1) return @@ -211,29 +213,24 @@ class BaseTransport(object): for opt, val in items(tcp_opts): self.sock.setsockopt(SOL_TCP, opt, val) - def _read(self, n, initial=False): - # type: (int, bool) -> ByteString + def _read(self, n: int, initial: bool=False) -> ByteString: """Read exactly n BytesString from the peer""" raise NotImplementedError('Must be overriden in subclass') - def _setup_transport(self): - # type: () -> None + def _setup_transport(self) -> None: """Do any additional initialization of the class (used by the subclasses).""" pass - def _shutdown_transport(self): - # type: () -> None + def _shutdown_transport(self) -> None: """Do any preliminary work in shutting down the connection.""" pass - def _write(self, s): - # type: (ByteString) -> int + def _write(self, s: ByteString) -> int: """Completely write a string to the peer.""" raise NotImplementedError('Must be overriden in subclass') - def close(self): - # type: () -> None + def close(self) -> None: if self.sock is not None: self._shutdown_transport() # Call shutdown first to make sure that pending messages @@ -244,8 +241,7 @@ class BaseTransport(object): self.sock = None self.connected = False - def read_frame(self, unpack=unpack): - # type: (Callable[AnyStr, ByteString]) -> Frame + def read_frame(self, unpack: StructUnpackT=unpack) -> Frame: read = self._read read_frame_buffer = EMPTY_BUFFER try: @@ -279,8 +275,7 @@ class BaseTransport(object): raise UnexpectedFrame( 'Received {0:#04x} while expecting 0xce'.format(ch)) - def write(self, s): - # type: (ByteString) -> None + def write(self, s: ByteString) -> None: try: self._write(s) except socket.timeout: @@ -294,35 +289,34 @@ class BaseTransport(object): class SSLTransport(BaseTransport): """Transport that works over SSL""" - def __init__(self, host, connect_timeout=None, ssl=None, **kwargs): - # type: (str, float, Optional[SSLArg], **Any) -> None + def __init__(self, host: str, connect_timeout: Timeout=None, + ssl: Optional[SSLArg]=None, **kwargs) -> None: if isinstance(ssl, dict): self.sslopts = ssl # type: Dict[AnyStr, Any] self._read_buffer = EMPTY_BUFFER # type: ByteString super(SSLTransport, self).__init__( host, connect_timeout=connect_timeout, **kwargs) - def _setup_transport(self): - # type: () -> None + def _setup_transport(self) -> None: """Wrap the socket in an SSL object.""" self.sock = self._wrap_socket(self.sock, **self.sslopts or {}) self.sock.do_handshake() self._quick_recv = self.sock.read - def _wrap_socket(self, sock, context=None, **sslopts): - # type: (Socket, MaybeDict, **Any) -> Socket + def _wrap_socket(self, sock: socket.socket, + context: MaybeDict=None, **sslopts) -> socket.socket: if context: return self._wrap_context(sock, sslopts, **context) return ssl.wrap_socket(sock, **sslopts) - def _wrap_context(self, sock, sslopts, check_hostname=None, **ctx_options): - # type: (Socket, Dict[AnyStr, Any], bool, **Any) -> Socket + def _wrap_context(self, sock: socket.socket, sslopts: Mapping[AnyStr, Any], + check_hostname: bool=None, + **ctx_options) -> socket.socket: ctx = ssl.create_default_context(**ctx_options) ctx.check_hostname = check_hostname return ctx.wrap_socket(sock, **sslopts) - def _shutdown_transport(self): - # type: () -> None + def _shutdown_transport(self) -> None: """Unwrap a Python 2.6 SSL socket, so we can call shutdown()""" if self.sock is not None: try: @@ -331,9 +325,10 @@ class SSLTransport(BaseTransport): return self.sock = unwrap() - def _read(self, n, initial=False, - _errnos=(errno.ENOENT, errno.EAGAIN, errno.EINTR)): - # type: (int, bool, Set[int]) -> ByteString + def _read(self, n: int, + initial: bool=False, + _errnos: Set[int]=(errno.ENOENT, errno.EAGAIN, errno.EINTR) + ) -> ByteString: # According to SSL_read(3), it can at most return 16kb of data. # Thus, we use an internal read buffer like TCPTransport._read @@ -361,8 +356,7 @@ class SSLTransport(BaseTransport): result, self._read_buffer = rbuf[:n], rbuf[n:] return result - def _write(self, s): - # type: (str) -> None + def _write(self, s: str) -> None: """Write a string out to the SSL socket fully.""" write = self.sock.write # type: Callable[[ByteString], int] while s: @@ -385,16 +379,16 @@ class SSLTransport(BaseTransport): class TCPTransport(BaseTransport): """Transport that deals directly with TCP socket.""" - def _setup_transport(self): - # type: () -> None + def _setup_transport(self) -> None: """Setup to _write() directly to the socket, and do our own buffered reads.""" self._write = self.sock.sendall # type: Callable[[ByteString], int] self._read_buffer = EMPTY_BUFFER # type: ByteString self._quick_recv = self.sock.recv # type: Callable[[int], ByteString] - def _read(self, n, initial=False, _errnos={errno.EAGAIN, errno.EINTR}): - # type: (int, bool, Set[int]) -> ByteString + def _read(self, n: int, + initial: bool=False, + _errnos: Set[int]={errno.EAGAIN, errno.EINTR}) -> ByteString: """Read exactly n bytes from the socket""" recv = self._quick_recv rbuf = self._read_buffer @@ -419,8 +413,9 @@ class TCPTransport(BaseTransport): return result -def Transport(host, connect_timeout=None, ssl=False, **kwargs): - # type: (str, Timeout, Optional[SSLArg], **Any) -> BaseTransport +def Transport(host: str, + connect_timeout: Timeout=None, + ssl: Optional[SSLArg]=False, **kwargs) -> BaseTransport: """Given a few parameters from the Connection constructor, select and create a subclass of BaseTransport.""" transport = SSLTransport if ssl else TCPTransport diff --git a/amqp/types.py b/amqp/types.py new file mode 100644 index 0000000..c811f92 --- /dev/null +++ b/amqp/types.py @@ -0,0 +1,5 @@ +from typing import Any, AnyStr, Mapping, Optional, Union + +SSLArg = Union[Mapping[AnyStr, Any], bool] +MaybeDict = Optional[Mapping[AnyStr, Any]] +Timeout = Optional[float] |