diff options
author | Victor Stinner <vstinner@redhat.com> | 2015-07-20 17:18:59 +0200 |
---|---|---|
committer | Victor Stinner <vstinner@redhat.com> | 2015-07-20 17:36:41 +0200 |
commit | 13b71a16ee9d5af4939a7e214e92fa89cb96f6a3 (patch) | |
tree | cb9525ce018206804591e7e255c5745a044018c6 | |
parent | 9bb67431adc916d9d4b4e23ca257658c980d035d (diff) | |
download | trollius-git-closing.tar.gz |
Add closing read-only property to transportsclosing
* Disallow write() on closing transports
* Disallow aslo calling pause_writing() and resume_writing() on
StreamReaderProtocol if the transport is closing
-rw-r--r-- | asyncio/proactor_events.py | 3 | ||||
-rw-r--r-- | asyncio/selector_events.py | 3 | ||||
-rw-r--r-- | asyncio/streams.py | 11 | ||||
-rw-r--r-- | asyncio/subprocess.py | 3 | ||||
-rw-r--r-- | asyncio/transports.py | 6 | ||||
-rw-r--r-- | asyncio/unix_events.py | 2 | ||||
-rw-r--r-- | tests/test_proactor_events.py | 6 | ||||
-rw-r--r-- | tests/test_selector_events.py | 4 | ||||
-rw-r--r-- | tests/test_streams.py | 17 |
9 files changed, 47 insertions, 8 deletions
diff --git a/asyncio/proactor_events.py b/asyncio/proactor_events.py index 9c2b8f1..3a0960b 100644 --- a/asyncio/proactor_events.py +++ b/asyncio/proactor_events.py @@ -34,7 +34,6 @@ class _ProactorBasePipeTransport(transports._FlowControlMixin, self._write_fut = None self._pending_write = 0 self._conn_lost = 0 - self._closing = False # Set when close() called. self._eof_written = False if self._server is not None: self._server._attach() @@ -225,6 +224,8 @@ class _ProactorBaseWritePipeTransport(_ProactorBasePipeTransport, type(data)) if self._eof_written: raise RuntimeError('write_eof() already called') + if self._closing: + raise RuntimeError('Cannot call write() after close()') if not data: return diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index 7c5b9b5..6898138 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -521,7 +521,6 @@ class _SelectorTransport(transports._FlowControlMixin, self._server = server self._buffer = self._buffer_factory() self._conn_lost = 0 # Set when call to connection_lost scheduled. - self._closing = False # Set when close() called. if self._server is not None: self._server._attach() @@ -681,6 +680,8 @@ class _SelectorSocketTransport(_SelectorTransport): if not isinstance(data, (bytes, bytearray, memoryview)): raise TypeError('data argument must be byte-ish (%r)', type(data)) + if self._closing: + raise RuntimeError('Cannot call write() after close()') if self._eof: raise RuntimeError('Cannot call write() after write_eof()') if not data: diff --git a/asyncio/streams.py b/asyncio/streams.py index 6484c43..409b2f9 100644 --- a/asyncio/streams.py +++ b/asyncio/streams.py @@ -153,15 +153,25 @@ class FlowControlMixin(protocols.Protocol): self._paused = False self._drain_waiter = None self._connection_lost = False + self._transport = None + + def connection_made(self, transport): + self._transport = transport def pause_writing(self): assert not self._paused + if self._transport is not None and self._transport.closing: + raise RuntimeError('Cannot call pause_writing() ' + 'on closing or closed transport') self._paused = True if self._loop.get_debug(): logger.debug("%r pauses writing", self) def resume_writing(self): assert self._paused + if self._transport is not None and self._transport.closing: + raise RuntimeError('Cannot call resume_writing() ' + 'on closing or closed transport') self._paused = False if self._loop.get_debug(): logger.debug("%r resumes writing", self) @@ -217,6 +227,7 @@ class StreamReaderProtocol(FlowControlMixin, protocols.Protocol): self._client_connected_cb = client_connected_cb def connection_made(self, transport): + super().connection_made(transport) self._stream_reader.set_transport(transport) if self._client_connected_cb is not None: self._stream_writer = StreamWriter(transport, self, diff --git a/asyncio/subprocess.py b/asyncio/subprocess.py index ead4039..53a8c68 100644 --- a/asyncio/subprocess.py +++ b/asyncio/subprocess.py @@ -23,7 +23,6 @@ class SubprocessStreamProtocol(streams.FlowControlMixin, super().__init__(loop=loop) self._limit = limit self.stdin = self.stdout = self.stderr = None - self._transport = None def __repr__(self): info = [self.__class__.__name__] @@ -36,7 +35,7 @@ class SubprocessStreamProtocol(streams.FlowControlMixin, return '<%s>' % ' '.join(info) def connection_made(self, transport): - self._transport = transport + super().connection_made(transport) stdout_transport = transport.get_pipe_transport(1) if stdout_transport is not None: diff --git a/asyncio/transports.py b/asyncio/transports.py index 70b323f..5540f7e 100644 --- a/asyncio/transports.py +++ b/asyncio/transports.py @@ -14,6 +14,7 @@ class BaseTransport: if extra is None: extra = {} self._extra = extra + self._closing = False def get_extra_info(self, name, default=None): """Get optional transport information.""" @@ -29,6 +30,11 @@ class BaseTransport: """ raise NotImplementedError + @property + def closing(self): + """Is the transport being closed?""" + return self._closing + class ReadTransport(BaseTransport): """Interface for read-only transports.""" diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index 75e7c9c..323a27a 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -311,7 +311,6 @@ class _UnixReadPipeTransport(transports.ReadTransport): raise ValueError("Pipe transport is for pipes/sockets only.") _set_nonblocking(self._fileno) self._protocol = protocol - self._closing = False self._loop.call_soon(self._protocol.connection_made, self) # only start reading when connection_made() has been called self._loop.call_soon(self._loop.add_reader, @@ -424,7 +423,6 @@ class _UnixWritePipeTransport(transports._FlowControlMixin, self._protocol = protocol self._buffer = [] self._conn_lost = 0 - self._closing = False # Set when close() or write_eof() called. self._loop.call_soon(self._protocol.connection_made, self) diff --git a/tests/test_proactor_events.py b/tests/test_proactor_events.py index fcd9ab1..0c0d44d 100644 --- a/tests/test_proactor_events.py +++ b/tests/test_proactor_events.py @@ -144,6 +144,12 @@ class ProactorSocketTransportTests(test_utils.TestCase): self.assertEqual(tr._buffer, b'data') self.assertFalse(tr._loop_writing.called) + def test_write_closing(self): + transport = self.socket_transport() + transport.close() + # write() is disallowed after close() + self.assertRaises(RuntimeError, transport.write, b'data') + def test_loop_writing(self): tr = self.socket_transport() tr._buffer = bytearray(b'data') diff --git a/tests/test_selector_events.py b/tests/test_selector_events.py index f0fcdd2..80f93ec 100644 --- a/tests/test_selector_events.py +++ b/tests/test_selector_events.py @@ -1000,8 +1000,8 @@ class SelectorSocketTransportTests(test_utils.TestCase): transport = self.socket_transport() transport.close() self.assertEqual(transport._conn_lost, 1) - transport.write(b'data') - self.assertEqual(transport._conn_lost, 2) + # write() is disallowed after close() + self.assertRaises(RuntimeError, transport.write, b'data') def test_write_ready(self): data = b'data' diff --git a/tests/test_streams.py b/tests/test_streams.py index 242b377..c883ba3 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -636,6 +636,23 @@ os.close(fd) protocol = asyncio.StreamReaderProtocol(reader) self.assertIs(protocol._loop, self.loop) + def test_pause_writing_closing(self): + reader = mock.Mock() + transport = asyncio.ReadTransport() + protocol = asyncio.StreamReaderProtocol(reader, loop=self.loop) + protocol.connection_made(transport) + transport._closing = True + self.assertRaises(RuntimeError, protocol.pause_writing) + + def test_resume_writing_closing(self): + reader = mock.Mock() + transport = asyncio.ReadTransport() + protocol = asyncio.StreamReaderProtocol(reader, loop=self.loop) + protocol.connection_made(transport) + protocol.pause_writing() + transport._closing = True + self.assertRaises(RuntimeError, protocol.resume_writing) + if __name__ == '__main__': unittest.main() |