summaryrefslogtreecommitdiff
path: root/asyncio/streams.py
diff options
context:
space:
mode:
authorGuido van Rossum <guido@python.org>2014-01-28 09:54:06 -0800
committerGuido van Rossum <guido@python.org>2014-01-28 09:54:06 -0800
commit7e6d76887676303977c3e576f66868795b3992e6 (patch)
treedf066df07b7d205b0879a3d9f23b4f62c419a140 /asyncio/streams.py
parent200e887a1430f6bce7f1a9ea324966ca58004eb9 (diff)
downloadtrollius-7e6d76887676303977c3e576f66868795b3992e6.tar.gz
Refactor drain logic in streams.py to be reusable.
Diffstat (limited to 'asyncio/streams.py')
-rw-r--r--asyncio/streams.py97
1 files changed, 61 insertions, 36 deletions
diff --git a/asyncio/streams.py b/asyncio/streams.py
index 10d3591..bd77cab 100644
--- a/asyncio/streams.py
+++ b/asyncio/streams.py
@@ -94,8 +94,63 @@ def start_server(client_connected_cb, host=None, port=None, *,
return (yield from loop.create_server(factory, host, port, **kwds))
-class StreamReaderProtocol(protocols.Protocol):
- """Trivial helper class to adapt between Protocol and StreamReader.
+class FlowControlMixin(protocols.Protocol):
+ """Reusable flow control logic for StreamWriter.drain().
+
+ This implements the protocol methods pause_writing(),
+ resume_reading() and connection_lost(). If the subclass overrides
+ these it must call the super methods.
+
+ StreamWriter.drain() must check for error conditions and then call
+ _make_drain_waiter(), which will return either () or a Future
+ depending on the paused state.
+ """
+
+ def __init__(self, loop=None):
+ self._loop = loop # May be None; we may never need it.
+ self._paused = False
+ self._drain_waiter = None
+
+ def pause_writing(self):
+ assert not self._paused
+ self._paused = True
+
+ def resume_writing(self):
+ assert self._paused
+ self._paused = False
+ waiter = self._drain_waiter
+ if waiter is not None:
+ self._drain_waiter = None
+ if not waiter.done():
+ waiter.set_result(None)
+
+ def connection_lost(self, exc):
+ # Wake up the writer if currently paused.
+ if not self._paused:
+ return
+ waiter = self._drain_waiter
+ if waiter is None:
+ return
+ self._drain_waiter = None
+ if waiter.done():
+ return
+ if exc is None:
+ waiter.set_result(None)
+ else:
+ waiter.set_exception(exc)
+
+ def _make_drain_waiter(self):
+ if not self._paused:
+ return ()
+ waiter = self._drain_waiter
+ assert waiter is None or waiter.cancelled()
+ waiter = futures.Future(loop=self._loop)
+ self._drain_waiter = waiter
+ return waiter
+
+
+class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):
+ """Helper class to adapt between Protocol and StreamReader.
(This is a helper class instead of making StreamReader itself a
Protocol subclass, because the StreamReader has other potential
@@ -104,12 +159,10 @@ class StreamReaderProtocol(protocols.Protocol):
"""
def __init__(self, stream_reader, client_connected_cb=None, loop=None):
+ super().__init__(loop=loop)
self._stream_reader = stream_reader
self._stream_writer = None
- self._drain_waiter = None
- self._paused = False
self._client_connected_cb = client_connected_cb
- self._loop = loop # May be None; we may never need it.
def connection_made(self, transport):
self._stream_reader.set_transport(transport)
@@ -127,16 +180,7 @@ class StreamReaderProtocol(protocols.Protocol):
self._stream_reader.feed_eof()
else:
self._stream_reader.set_exception(exc)
- # Also wake up the writing side.
- if self._paused:
- waiter = self._drain_waiter
- if waiter is not None:
- self._drain_waiter = None
- if not waiter.done():
- if exc is None:
- waiter.set_result(None)
- else:
- waiter.set_exception(exc)
+ super().connection_lost(exc)
def data_received(self, data):
self._stream_reader.feed_data(data)
@@ -144,19 +188,6 @@ class StreamReaderProtocol(protocols.Protocol):
def eof_received(self):
self._stream_reader.feed_eof()
- def pause_writing(self):
- assert not self._paused
- self._paused = True
-
- def resume_writing(self):
- assert self._paused
- self._paused = False
- waiter = self._drain_waiter
- if waiter is not None:
- self._drain_waiter = None
- if not waiter.done():
- waiter.set_result(None)
-
class StreamWriter:
"""Wraps a Transport.
@@ -211,17 +242,11 @@ class StreamWriter:
completed, which will happen when the buffer is (partially)
drained and the protocol is resumed.
"""
- if self._reader._exception is not None:
+ if self._reader is not None and self._reader._exception is not None:
raise self._reader._exception
if self._transport._conn_lost: # Uses private variable.
raise ConnectionResetError('Connection lost')
- if not self._protocol._paused:
- return ()
- waiter = self._protocol._drain_waiter
- assert waiter is None or waiter.cancelled()
- waiter = futures.Future(loop=self._loop)
- self._protocol._drain_waiter = waiter
- return waiter
+ return self._protocol._make_drain_waiter()
class StreamReader: