author     Sam Thursfield <sam.thursfield@codethink.co.uk>   2015-06-24 13:21:06 +0100
committer  Sam Thursfield <sam.thursfield@codethink.co.uk>   2015-06-24 13:21:06 +0100
commit     f4ece423ebf135992fc7b85b120ee5f20de24bc2 (patch)
tree       7105fde44023867ce576613e52b3bd5b95a09b85
parent     6dfd1b9c945f47969052a20095d667400befa076 (diff)
download   sandboxlib-f4ece423ebf135992fc7b85b120ee5f20de24bc2.tar.gz
Add utils.duplicate_streams() helper function
This is based on code written by Richard Maw in git://git.baserock.org/baserock/baserock/morph.git.
-rw-r--r--   sandboxlib/utils.py    110
-rw-r--r--   tests/test_utils.py     45
2 files changed, 155 insertions, 0 deletions
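A minimal sketch of how duplicate_streams() might be used, assuming Python 3
(for sys.stdout.buffer / sys.stderr.buffer) and that each output object only
needs a binary-mode write() method; the 'build.log' filename is illustrative:

    import subprocess
    import sys

    import sandboxlib.utils

    # Run a child process and 'tee' its stdout and stderr to the parent's
    # own streams as well as to a shared log file.
    with open('build.log', 'wb') as log:
        proc = subprocess.Popen(['echo', 'hello'],
                                stdout=subprocess.PIPE,
                                stderr=subprocess.PIPE)
        sandboxlib.utils.duplicate_streams({
            proc.stdout.fileno(): [sys.stdout.buffer, log],
            proc.stderr.fileno(): [sys.stderr.buffer, log],
        })
        proc.wait()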
diff --git a/sandboxlib/utils.py b/sandboxlib/utils.py
index b3ec867..5092e8e 100644
--- a/sandboxlib/utils.py
+++ b/sandboxlib/utils.py
@@ -13,6 +13,8 @@
# with this program. If not, see <http://www.gnu.org/licenses/>.
+import asynchat
+import asyncore
import logging
import os
import shutil
@@ -51,3 +53,111 @@ def find_program(program_name):
             program_name, search_path))
     return program_path
+
+
+class AsyncoreFileWrapperWithEOFHandler(asyncore.file_wrapper):
+    '''File wrapper that reports when it hits the end-of-file marker.
+
+    The asyncore.file_wrapper class wraps a file in a way that makes it
+    act like a socket, so that it can easily be used in an 'asyncore' main
+    loop. We use this with the asynchat.async_chat class, which provides a
+    channel that parses the output of a stream and calls a callback function.
+    The one hitch is that asynchat.async_chat doesn't notice when the stream
+    has hit the end-of-file delimiter. This class is a workaround which causes
+    the AsyncoreStreamProcessingChannel instance to close itself once the end
+    of the stream is reached.
+
+    '''
+    def __init__(self, dispatcher, fd):
+        self._dispatcher = dispatcher
+        asyncore.file_wrapper.__init__(self, fd)
+
+    def recv(self, *args):
+        data = asyncore.file_wrapper.recv(self, *args)
+        if not data:
+            self._dispatcher.close_when_done()
+            # Return the terminator to ensure any unterminated data is flushed.
+            return self._dispatcher.get_terminator()
+        return data
+
+
+class AsyncoreStreamProcessingChannel(asynchat.async_chat,
+                                      asyncore.file_dispatcher):
+    '''Channel to read from a stream and pass the data to a handler function.
+
+    The 'asyncore' module provides a select()-based main loop. We use this in
+    duplicate_streams() to multiplex reading from various streams and
+    duplicating their output. This class provides a channel that can be added
+    to the 'asyncore' main loop; it will read from a given file descriptor
+    and call a given callback function at the end of each line.
+
+    '''
+    def __init__(self, fd, line_handler, map=None):
+        asynchat.async_chat.__init__(self, sock=None, map=map)
+        asyncore.file_dispatcher.__init__(self, fd=fd, map=map)
+        self.set_terminator(b'\n')
+        self._line_handler = line_handler
+
+    collect_incoming_data = asynchat.async_chat._collect_incoming_data
+
+    def set_file(self, fd):
+        # Called on initialisation.
+        self.socket = AsyncoreFileWrapperWithEOFHandler(self, fd)
+        self._fileno = self.socket.fileno()
+        self.add_channel()
+
+    def found_terminator(self):
+        # Called when the '\n' terminator is found in the input data.
+        self._line_handler(b''.join(self.incoming) + self.terminator)
+        self.incoming = []
+
+
+def duplicate_streams(stream_map, flush_interval=0.0):
+    '''Copy data from one or more input streams to multiple output streams.
+
+    Similar to the `tee` commandline utility, this can be used for echoing
+    'stdout' and 'stderr' of a subprocess to the parent's 'stdout' and
+    'stderr', whilst also saving it all to a log file.
+
+    The 'flush_interval' value is used as the timeout for each pass of the
+    'asyncore' main loop.
+
+    This function will block until end-of-file has been reached on all of
+    the input streams.
+
+    '''
+    # The AsyncoreStreamProcessingChannel instances are tracked in socket_map.
+    socket_map = {}
+
+    for input_fd, output_fds in stream_map.items():
+        # Bind 'output_fds' as a default argument so that each channel writes
+        # to its own set of outputs rather than those of the final iteration.
+        def write_line(line, output_fds=output_fds):
+            for fd in output_fds:
+                fd.write(line)
+
+        AsyncoreStreamProcessingChannel(
+            line_handler=write_line, fd=input_fd, map=socket_map)
+
+    while socket_map:
+        asyncore.loop(timeout=flush_interval, use_poll=True, map=socket_map,
+                      count=1)
diff --git a/tests/test_utils.py b/tests/test_utils.py
new file mode 100644
index 0000000..aa93480
--- /dev/null
+++ b/tests/test_utils.py
@@ -0,0 +1,45 @@
+# Copyright (C) 2015 Codethink Limited
+#
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation; version 2 of the License.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License along
+# with this program. If not, see <http://www.gnu.org/licenses/>.
+
+
+'''Unit tests for some utility code.'''
+
+
+import os
+import tempfile
+import threading
+
+import sandboxlib.utils
+
+
+def test_duplicate_streams():
+    read_fd, write_fd = os.pipe()
+
+    def write_data(write_fd):
+        write_f = os.fdopen(write_fd, 'wb')
+        write_f.write('hello\n'.encode('utf-8'))
+        write_f.close()
+
+    write_thread = threading.Thread(target=write_data, args=[write_fd])
+    write_thread.start()
+
+    with tempfile.TemporaryFile() as f_1:
+        with tempfile.TemporaryFile() as f_2:
+            sandboxlib.utils.duplicate_streams({read_fd: [f_1, f_2]})
+            f_1.seek(0)
+            assert f_1.read().decode('utf-8') == 'hello\n'
+            f_2.seek(0)
+            assert f_2.read().decode('utf-8') == 'hello\n'
+
+    write_thread.join()
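A sketch of a further test that could exercise the multiplexing behaviour
described in the AsyncoreStreamProcessingChannel docstring, using two input
streams in a single duplicate_streams() call. The test name and sample data
are illustrative, and the same imports as tests/test_utils.py are assumed:

    def test_duplicate_streams_multiple_inputs():
        # Two independent pipes, each duplicated to its own output file.
        read_1, write_1 = os.pipe()
        read_2, write_2 = os.pipe()

        def write_data(fd, data):
            with os.fdopen(fd, 'wb') as f:
                f.write(data)

        threads = [
            threading.Thread(target=write_data, args=[write_1, b'one\n']),
            threading.Thread(target=write_data, args=[write_2, b'two\n']),
        ]
        for thread in threads:
            thread.start()

        with tempfile.TemporaryFile() as out_1:
            with tempfile.TemporaryFile() as out_2:
                sandboxlib.utils.duplicate_streams(
                    {read_1: [out_1], read_2: [out_2]})
                out_1.seek(0)
                assert out_1.read() == b'one\n'
                out_2.seek(0)
                assert out_2.read() == b'two\n'

        for thread in threads:
            thread.join()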