summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorVictor Stinner <vstinner@redhat.com>2015-07-20 17:18:59 +0200
committerVictor Stinner <vstinner@redhat.com>2015-07-20 17:36:41 +0200
commit13b71a16ee9d5af4939a7e214e92fa89cb96f6a3 (patch)
treecb9525ce018206804591e7e255c5745a044018c6
parent9bb67431adc916d9d4b4e23ca257658c980d035d (diff)
downloadtrollius-git-closing.tar.gz
Add closing read-only property to transportsclosing
* Disallow write() on closing transports * Disallow aslo calling pause_writing() and resume_writing() on StreamReaderProtocol if the transport is closing
-rw-r--r--asyncio/proactor_events.py3
-rw-r--r--asyncio/selector_events.py3
-rw-r--r--asyncio/streams.py11
-rw-r--r--asyncio/subprocess.py3
-rw-r--r--asyncio/transports.py6
-rw-r--r--asyncio/unix_events.py2
-rw-r--r--tests/test_proactor_events.py6
-rw-r--r--tests/test_selector_events.py4
-rw-r--r--tests/test_streams.py17
9 files changed, 47 insertions, 8 deletions
diff --git a/asyncio/proactor_events.py b/asyncio/proactor_events.py
index 9c2b8f1..3a0960b 100644
--- a/asyncio/proactor_events.py
+++ b/asyncio/proactor_events.py
@@ -34,7 +34,6 @@ class _ProactorBasePipeTransport(transports._FlowControlMixin,
self._write_fut = None
self._pending_write = 0
self._conn_lost = 0
- self._closing = False # Set when close() called.
self._eof_written = False
if self._server is not None:
self._server._attach()
@@ -225,6 +224,8 @@ class _ProactorBaseWritePipeTransport(_ProactorBasePipeTransport,
type(data))
if self._eof_written:
raise RuntimeError('write_eof() already called')
+ if self._closing:
+ raise RuntimeError('Cannot call write() after close()')
if not data:
return
diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py
index 7c5b9b5..6898138 100644
--- a/asyncio/selector_events.py
+++ b/asyncio/selector_events.py
@@ -521,7 +521,6 @@ class _SelectorTransport(transports._FlowControlMixin,
self._server = server
self._buffer = self._buffer_factory()
self._conn_lost = 0 # Set when call to connection_lost scheduled.
- self._closing = False # Set when close() called.
if self._server is not None:
self._server._attach()
@@ -681,6 +680,8 @@ class _SelectorSocketTransport(_SelectorTransport):
if not isinstance(data, (bytes, bytearray, memoryview)):
raise TypeError('data argument must be byte-ish (%r)',
type(data))
+ if self._closing:
+ raise RuntimeError('Cannot call write() after close()')
if self._eof:
raise RuntimeError('Cannot call write() after write_eof()')
if not data:
diff --git a/asyncio/streams.py b/asyncio/streams.py
index 6484c43..409b2f9 100644
--- a/asyncio/streams.py
+++ b/asyncio/streams.py
@@ -153,15 +153,25 @@ class FlowControlMixin(protocols.Protocol):
self._paused = False
self._drain_waiter = None
self._connection_lost = False
+ self._transport = None
+
+ def connection_made(self, transport):
+ self._transport = transport
def pause_writing(self):
assert not self._paused
+ if self._transport is not None and self._transport.closing:
+ raise RuntimeError('Cannot call pause_writing() '
+ 'on closing or closed transport')
self._paused = True
if self._loop.get_debug():
logger.debug("%r pauses writing", self)
def resume_writing(self):
assert self._paused
+ if self._transport is not None and self._transport.closing:
+ raise RuntimeError('Cannot call resume_writing() '
+ 'on closing or closed transport')
self._paused = False
if self._loop.get_debug():
logger.debug("%r resumes writing", self)
@@ -217,6 +227,7 @@ class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):
self._client_connected_cb = client_connected_cb
def connection_made(self, transport):
+ super().connection_made(transport)
self._stream_reader.set_transport(transport)
if self._client_connected_cb is not None:
self._stream_writer = StreamWriter(transport, self,
diff --git a/asyncio/subprocess.py b/asyncio/subprocess.py
index ead4039..53a8c68 100644
--- a/asyncio/subprocess.py
+++ b/asyncio/subprocess.py
@@ -23,7 +23,6 @@ class SubprocessStreamProtocol(streams.FlowControlMixin,
super().__init__(loop=loop)
self._limit = limit
self.stdin = self.stdout = self.stderr = None
- self._transport = None
def __repr__(self):
info = [self.__class__.__name__]
@@ -36,7 +35,7 @@ class SubprocessStreamProtocol(streams.FlowControlMixin,
return '<%s>' % ' '.join(info)
def connection_made(self, transport):
- self._transport = transport
+ super().connection_made(transport)
stdout_transport = transport.get_pipe_transport(1)
if stdout_transport is not None:
diff --git a/asyncio/transports.py b/asyncio/transports.py
index 70b323f..5540f7e 100644
--- a/asyncio/transports.py
+++ b/asyncio/transports.py
@@ -14,6 +14,7 @@ class BaseTransport:
if extra is None:
extra = {}
self._extra = extra
+ self._closing = False
def get_extra_info(self, name, default=None):
"""Get optional transport information."""
@@ -29,6 +30,11 @@ class BaseTransport:
"""
raise NotImplementedError
+ @property
+ def closing(self):
+ """Is the transport being closed?"""
+ return self._closing
+
class ReadTransport(BaseTransport):
"""Interface for read-only transports."""
diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py
index 75e7c9c..323a27a 100644
--- a/asyncio/unix_events.py
+++ b/asyncio/unix_events.py
@@ -311,7 +311,6 @@ class _UnixReadPipeTransport(transports.ReadTransport):
raise ValueError("Pipe transport is for pipes/sockets only.")
_set_nonblocking(self._fileno)
self._protocol = protocol
- self._closing = False
self._loop.call_soon(self._protocol.connection_made, self)
# only start reading when connection_made() has been called
self._loop.call_soon(self._loop.add_reader,
@@ -424,7 +423,6 @@ class _UnixWritePipeTransport(transports._FlowControlMixin,
self._protocol = protocol
self._buffer = []
self._conn_lost = 0
- self._closing = False # Set when close() or write_eof() called.
self._loop.call_soon(self._protocol.connection_made, self)
diff --git a/tests/test_proactor_events.py b/tests/test_proactor_events.py
index fcd9ab1..0c0d44d 100644
--- a/tests/test_proactor_events.py
+++ b/tests/test_proactor_events.py
@@ -144,6 +144,12 @@ class ProactorSocketTransportTests(test_utils.TestCase):
self.assertEqual(tr._buffer, b'data')
self.assertFalse(tr._loop_writing.called)
+ def test_write_closing(self):
+ transport = self.socket_transport()
+ transport.close()
+ # write() is disallowed after close()
+ self.assertRaises(RuntimeError, transport.write, b'data')
+
def test_loop_writing(self):
tr = self.socket_transport()
tr._buffer = bytearray(b'data')
diff --git a/tests/test_selector_events.py b/tests/test_selector_events.py
index f0fcdd2..80f93ec 100644
--- a/tests/test_selector_events.py
+++ b/tests/test_selector_events.py
@@ -1000,8 +1000,8 @@ class SelectorSocketTransportTests(test_utils.TestCase):
transport = self.socket_transport()
transport.close()
self.assertEqual(transport._conn_lost, 1)
- transport.write(b'data')
- self.assertEqual(transport._conn_lost, 2)
+ # write() is disallowed after close()
+ self.assertRaises(RuntimeError, transport.write, b'data')
def test_write_ready(self):
data = b'data'
diff --git a/tests/test_streams.py b/tests/test_streams.py
index 242b377..c883ba3 100644
--- a/tests/test_streams.py
+++ b/tests/test_streams.py
@@ -636,6 +636,23 @@ os.close(fd)
protocol = asyncio.StreamReaderProtocol(reader)
self.assertIs(protocol._loop, self.loop)
+ def test_pause_writing_closing(self):
+ reader = mock.Mock()
+ transport = asyncio.ReadTransport()
+ protocol = asyncio.StreamReaderProtocol(reader, loop=self.loop)
+ protocol.connection_made(transport)
+ transport._closing = True
+ self.assertRaises(RuntimeError, protocol.pause_writing)
+
+ def test_resume_writing_closing(self):
+ reader = mock.Mock()
+ transport = asyncio.ReadTransport()
+ protocol = asyncio.StreamReaderProtocol(reader, loop=self.loop)
+ protocol.connection_made(transport)
+ protocol.pause_writing()
+ transport._closing = True
+ self.assertRaises(RuntimeError, protocol.resume_writing)
+
if __name__ == '__main__':
unittest.main()