From 7e6d76887676303977c3e576f66868795b3992e6 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 28 Jan 2014 09:54:06 -0800 Subject: Refactor drain logic in streams.py to be reusable. --- asyncio/streams.py | 97 ++++++++++++++++++++++++++++++++++-------------------- 1 file changed, 61 insertions(+), 36 deletions(-) (limited to 'asyncio/streams.py') diff --git a/asyncio/streams.py b/asyncio/streams.py index 10d3591..bd77cab 100644 --- a/asyncio/streams.py +++ b/asyncio/streams.py @@ -94,8 +94,63 @@ def start_server(client_connected_cb, host=None, port=None, *, return (yield from loop.create_server(factory, host, port, **kwds)) -class StreamReaderProtocol(protocols.Protocol): - """Trivial helper class to adapt between Protocol and StreamReader. +class FlowControlMixin(protocols.Protocol): + """Reusable flow control logic for StreamWriter.drain(). + + This implements the protocol methods pause_writing(), + resume_reading() and connection_lost(). If the subclass overrides + these it must call the super methods. + + StreamWriter.drain() must check for error conditions and then call + _make_drain_waiter(), which will return either () or a Future + depending on the paused state. + """ + + def __init__(self, loop=None): + self._loop = loop # May be None; we may never need it. + self._paused = False + self._drain_waiter = None + + def pause_writing(self): + assert not self._paused + self._paused = True + + def resume_writing(self): + assert self._paused + self._paused = False + waiter = self._drain_waiter + if waiter is not None: + self._drain_waiter = None + if not waiter.done(): + waiter.set_result(None) + + def connection_lost(self, exc): + # Wake up the writer if currently paused. + if not self._paused: + return + waiter = self._drain_waiter + if waiter is None: + return + self._drain_waiter = None + if waiter.done(): + return + if exc is None: + waiter.set_result(None) + else: + waiter.set_exception(exc) + + def _make_drain_waiter(self): + if not self._paused: + return () + waiter = self._drain_waiter + assert waiter is None or waiter.cancelled() + waiter = futures.Future(loop=self._loop) + self._drain_waiter = waiter + return waiter + + +class StreamReaderProtocol(FlowControlMixin, protocols.Protocol): + """Helper class to adapt between Protocol and StreamReader. (This is a helper class instead of making StreamReader itself a Protocol subclass, because the StreamReader has other potential @@ -104,12 +159,10 @@ class StreamReaderProtocol(protocols.Protocol): """ def __init__(self, stream_reader, client_connected_cb=None, loop=None): + super().__init__(loop=loop) self._stream_reader = stream_reader self._stream_writer = None - self._drain_waiter = None - self._paused = False self._client_connected_cb = client_connected_cb - self._loop = loop # May be None; we may never need it. def connection_made(self, transport): self._stream_reader.set_transport(transport) @@ -127,16 +180,7 @@ class StreamReaderProtocol(protocols.Protocol): self._stream_reader.feed_eof() else: self._stream_reader.set_exception(exc) - # Also wake up the writing side. - if self._paused: - waiter = self._drain_waiter - if waiter is not None: - self._drain_waiter = None - if not waiter.done(): - if exc is None: - waiter.set_result(None) - else: - waiter.set_exception(exc) + super().connection_lost(exc) def data_received(self, data): self._stream_reader.feed_data(data) @@ -144,19 +188,6 @@ class StreamReaderProtocol(protocols.Protocol): def eof_received(self): self._stream_reader.feed_eof() - def pause_writing(self): - assert not self._paused - self._paused = True - - def resume_writing(self): - assert self._paused - self._paused = False - waiter = self._drain_waiter - if waiter is not None: - self._drain_waiter = None - if not waiter.done(): - waiter.set_result(None) - class StreamWriter: """Wraps a Transport. @@ -211,17 +242,11 @@ class StreamWriter: completed, which will happen when the buffer is (partially) drained and the protocol is resumed. """ - if self._reader._exception is not None: + if self._reader is not None and self._reader._exception is not None: raise self._reader._exception if self._transport._conn_lost: # Uses private variable. raise ConnectionResetError('Connection lost') - if not self._protocol._paused: - return () - waiter = self._protocol._drain_waiter - assert waiter is None or waiter.cancelled() - waiter = futures.Future(loop=self._loop) - self._protocol._drain_waiter = waiter - return waiter + return self._protocol._make_drain_waiter() class StreamReader: -- cgit v1.2.1