summaryrefslogtreecommitdiff
path: root/asyncio/subprocess_stream.py
diff options
context:
space:
mode:
Diffstat (limited to 'asyncio/subprocess_stream.py')
-rw-r--r--asyncio/subprocess_stream.py193
1 files changed, 193 insertions, 0 deletions
diff --git a/asyncio/subprocess_stream.py b/asyncio/subprocess_stream.py
new file mode 100644
index 0000000..d065c16
--- /dev/null
+++ b/asyncio/subprocess_stream.py
@@ -0,0 +1,193 @@
+__all__ = ['subprocess_shell', 'subprocess_exec']
+
+from . import base_subprocess
+from . import events
+from . import protocols
+from . import streams
+from . import tasks
+
+class WriteSubprocessPipeStreamProto(base_subprocess.WriteSubprocessPipeProto):
+ def __init__(self, process_transport, fd):
+ base_subprocess.WriteSubprocessPipeProto.__init__(self, process_transport, fd)
+ self._drain_waiter = None
+ self._paused = False
+ self.writer = WritePipeStream(None, self, None)
+
+ def connection_made(self, transport):
+ super().connection_made(transport)
+ self.writer._transport = transport
+ self.writer._loop = transport._loop
+
+ def connection_lost(self, exc):
+ # Also wake up the writing side.
+ if self._paused:
+ waiter = self._drain_waiter
+ if waiter is not None:
+ self._drain_waiter = None
+ if not waiter.done():
+ if exc is None:
+ waiter.set_result(None)
+ else:
+ waiter.set_exception(exc)
+
+ def pause_writing(self):
+ assert not self._paused
+ self._paused = True
+
+ def resume_writing(self):
+ assert self._paused
+ self._paused = False
+ waiter = self._drain_waiter
+ if waiter is not None:
+ self._drain_waiter = None
+ if not waiter.done():
+ waiter.set_result(None)
+
+
+class WritePipeStream:
+ """Wraps a Transport.
+
+ This exposes write(), writelines(), [can_]write_eof(),
+ get_extra_info() and close(). It adds drain() which returns an
+ optional Future on which you can wait for flow control. It also
+ adds a transport property which references the Transport
+ directly.
+ """
+
+ def __init__(self, transport, protocol, loop):
+ self._transport = transport
+ self._protocol = protocol
+ self._loop = loop
+
+ @property
+ def transport(self):
+ return self._transport
+
+ def write(self, data):
+ self._transport.write(data)
+
+ def writelines(self, data):
+ self._transport.writelines(data)
+
+ def write_eof(self):
+ return self._transport.write_eof()
+
+ def can_write_eof(self):
+ return self._transport.can_write_eof()
+
+ def close(self):
+ return self._transport.close()
+
+ def get_extra_info(self, name, default=None):
+ return self._transport.get_extra_info(name, default)
+
+ def drain(self):
+ """This method has an unusual return value.
+
+ The intended use is to write
+
+ w.write(data)
+ yield from w.drain()
+
+ When there's nothing to wait for, drain() returns (), and the
+ yield-from continues immediately. When the transport buffer
+ is full (the protocol is paused), drain() creates and returns
+ a Future and the yield-from will block until that Future is
+ completed, which will happen when the buffer is (partially)
+ drained and the protocol is resumed.
+ """
+ if self._transport._conn_lost: # Uses private variable.
+ raise ConnectionResetError('Connection lost')
+ if not self._protocol._paused:
+ return ()
+ waiter = self._protocol._drain_waiter
+ assert waiter is None or waiter.cancelled()
+ waiter = futures.Future(loop=self._loop)
+ self._protocol._drain_waiter = waiter
+ return waiter
+
+
+class SubprocessStreamProtocol(protocols.SubprocessProtocol):
+ def __init__(self, limit=streams._DEFAULT_LIMIT):
+ self._pipes = {}
+ self.limit = limit
+ self.stdin = None
+ self.stdout = None
+ self.stderr = None
+ self._waiters = []
+ self._returncode = None
+ self._loop = None
+
+ def connection_made(self, transport):
+ self._loop = transport._loop
+ proc = transport._proc
+ if proc.stdout is not None:
+ self.stdout = self._get_protocol(1)._stream_reader
+ if proc.stderr is not None:
+ self.stderr = self._get_protocol(2)._stream_reader
+
+ def get_pipe_reader(self, fd):
+ if fd in self._pipes:
+ return self._pipes[fd]._stream_reader
+ else:
+ return None
+
+ def _get_protocol(self, fd):
+ try:
+ return self._pipes[fd]
+ except KeyError:
+ reader = streams.StreamReader(limit=self.limit)
+ protocol = streams.StreamReaderProtocol(reader, loop=self._loop)
+ self._pipes[fd] = protocol
+ return protocol
+
+ def pipe_data_received(self, fd, data):
+ protocol = self._get_protocol(fd)
+ protocol.data_received(data)
+
+ def pipe_connection_lost(self, fd, exc):
+ protocol = self._get_protocol(fd)
+ protocol.connection_lost(exc)
+
+ @tasks.coroutine
+ def wait(self):
+ """
+ Wait until the process exit and return the process return code.
+ """
+ if self._returncode:
+ return self._returncode
+
+ fut = tasks.Future()
+ self._waiters.append(fut)
+ yield from fut
+ return fut.result()
+
+ def process_exited(self, returncode):
+ self._returncode = returncode
+ # FIXME: not thread safe
+ waiters = self._waiters.copy()
+ self._waiters.clear()
+ for waiter in waiters:
+ waiter.set_result(returncode)
+
+ def pipe_connection_made(self, fd, pipe):
+ if fd == 0:
+ self.stdin = pipe.writer
+
+@tasks.coroutine
+def subprocess_exec(*args, **kwargs):
+ loop = kwargs.pop('loop', None)
+ if loop is None:
+ loop = events.get_event_loop()
+ kwargs['write_pipe_proto_factory'] = WriteSubprocessPipeStreamProto
+ yield from loop.subprocess_exec(SubprocessStreamProtocol, *args, **kwargs)
+
+
+@tasks.coroutine
+def subprocess_shell(*args, **kwargs):
+ loop = kwargs.pop('loop', None)
+ if loop is None:
+ loop = events.get_event_loop()
+ kwargs['write_pipe_protocol_factory'] = WriteSubprocessPipeStreamProto
+ return (yield from loop.subprocess_shell(SubprocessStreamProtocol, *args, **kwargs))
+