summaryrefslogtreecommitdiff
path: root/extensions/writeexts.py
diff options
context:
space:
mode:
Diffstat (limited to 'extensions/writeexts.py')
-rw-r--r--extensions/writeexts.py330
1 files changed, 313 insertions, 17 deletions
diff --git a/extensions/writeexts.py b/extensions/writeexts.py
index 61d40789..f84f2288 100644
--- a/extensions/writeexts.py
+++ b/extensions/writeexts.py
@@ -1,4 +1,5 @@
# Copyright (C) 2012-2015 Codethink Limited
+# Copyright (C) 2011, 2012 Lars Wirzenius
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
@@ -15,9 +16,11 @@
import contextlib
import errno
+import fcntl
import logging
import os
import re
+import select
import shutil
import stat
import subprocess
@@ -26,8 +29,221 @@ import tempfile
import time
-def shell_quote(string):
- '''Return a shell-quoted version of `string`.'''
+def get_data_path(relative_path):
+ '''Return path to a data file in the morphlib Python package.
+
+ ``relative_path`` is the name of the data file, relative to the
+ `extensions/` directory.
+
+ '''
+
+ extensions_dir = os.path.dirname(__file__)
+ return os.path.join(extensions_dir, relative_path)
+
+
+def get_data(relative_path): # pragma: no cover
+ '''Return contents of a data file from the morphlib Python package.
+
+ ``relative_path`` is the name of the data file, relative to the
+ `extensions/` directory.
+
+ '''
+
+ with open(get_data_path(relative_path)) as f:
+ return f.read()
+
+
+def runcmd(argv, *args, **kwargs):
+ '''Run external command or pipeline.
+
+ Example: ``runcmd(['grep', 'foo'], ['wc', '-l'],
+ feed_stdin='foo\nbar\n')``
+
+ Return the standard output of the command.
+
+ Raise ``ExtensionError`` if external command returns
+ non-zero exit code. ``*args`` and ``**kwargs`` are passed
+ onto ``subprocess.Popen``.
+
+ '''
+
+ our_options = (
+ ('ignore_fail', False),
+ ('log_error', True),
+ )
+ opts = {}
+ for name, default in our_options:
+ opts[name] = default
+ if name in kwargs:
+ opts[name] = kwargs[name]
+ del kwargs[name]
+
+ exit, out, err = runcmd_unchecked(argv, *args, **kwargs)
+ if exit != 0:
+ msg = 'Command failed: %s\n%s' % (' '.join(argv), err)
+ if opts['ignore_fail']:
+ if opts['log_error']:
+ logging.info(msg)
+ else:
+ if opts['log_error']:
+ logging.error(msg)
+ raise ExtensionError(msg)
+ return out
+
+def runcmd_unchecked(argv, *argvs, **kwargs):
+ '''Run external command or pipeline.
+
+ Return the exit code, and contents of standard output and error
+ of the command.
+
+ See also ``runcmd``.
+
+ '''
+
+ argvs = [argv] + list(argvs)
+ logging.debug('run external command: %s' % repr(argvs))
+
+ def pop_kwarg(name, default):
+ if name in kwargs:
+ value = kwargs[name]
+ del kwargs[name]
+ return value
+ else:
+ return default
+
+ feed_stdin = pop_kwarg('feed_stdin', '')
+ pipe_stdin = pop_kwarg('stdin', subprocess.PIPE)
+ pipe_stdout = pop_kwarg('stdout', subprocess.PIPE)
+ pipe_stderr = pop_kwarg('stderr', subprocess.PIPE)
+
+ try:
+ pipeline = _build_pipeline(argvs,
+ pipe_stdin,
+ pipe_stdout,
+ pipe_stderr,
+ kwargs)
+ return _run_pipeline(pipeline, feed_stdin, pipe_stdin,
+ pipe_stdout, pipe_stderr)
+ except OSError, e: # pragma: no cover
+ if e.errno == errno.ENOENT and e.filename is None:
+ e.filename = argv[0]
+ raise e
+ else:
+ raise
+
+def _build_pipeline(argvs, pipe_stdin, pipe_stdout, pipe_stderr, kwargs):
+ procs = []
+ for i, argv in enumerate(argvs):
+ if i == 0 and i == len(argvs) - 1:
+ stdin = pipe_stdin
+ stdout = pipe_stdout
+ stderr = pipe_stderr
+ elif i == 0:
+ stdin = pipe_stdin
+ stdout = subprocess.PIPE
+ stderr = pipe_stderr
+ elif i == len(argvs) - 1:
+ stdin = procs[-1].stdout
+ stdout = pipe_stdout
+ stderr = pipe_stderr
+ else:
+ stdin = procs[-1].stdout
+ stdout = subprocess.PIPE
+ stderr = pipe_stderr
+ p = subprocess.Popen(argv, stdin=stdin, stdout=stdout,
+ stderr=stderr, close_fds=True, **kwargs)
+ procs.append(p)
+
+ return procs
+
+def _run_pipeline(procs, feed_stdin, pipe_stdin, pipe_stdout, pipe_stderr):
+
+ stdout_eof = False
+ stderr_eof = False
+ out = []
+ err = []
+ pos = 0
+ io_size = 1024
+
+ def set_nonblocking(fd):
+ flags = fcntl.fcntl(fd, fcntl.F_GETFL, 0)
+ flags = flags | os.O_NONBLOCK
+ fcntl.fcntl(fd, fcntl.F_SETFL, flags)
+
+ if feed_stdin and pipe_stdin == subprocess.PIPE:
+ set_nonblocking(procs[0].stdin.fileno())
+ if pipe_stdout == subprocess.PIPE:
+ set_nonblocking(procs[-1].stdout.fileno())
+ if pipe_stderr == subprocess.PIPE:
+ set_nonblocking(procs[-1].stderr.fileno())
+
+ def still_running():
+ for p in procs:
+ p.poll()
+ for p in procs:
+ if p.returncode is None:
+ return True
+ if pipe_stdout == subprocess.PIPE and not stdout_eof:
+ return True
+ if pipe_stderr == subprocess.PIPE and not stderr_eof:
+ return True # pragma: no cover
+ return False
+
+ while still_running():
+ rlist = []
+ if not stdout_eof and pipe_stdout == subprocess.PIPE:
+ rlist.append(procs[-1].stdout)
+ if not stderr_eof and pipe_stderr == subprocess.PIPE:
+ rlist.append(procs[-1].stderr)
+
+ wlist = []
+ if pipe_stdin == subprocess.PIPE and pos < len(feed_stdin):
+ wlist.append(procs[0].stdin)
+
+ if rlist or wlist:
+ try:
+ r, w, x = select.select(rlist, wlist, [])
+ except select.error, e:
+ err, msg = e.args
+ if err == errno.EINTR:
+ break
+ raise
+ else:
+ break # Let's not busywait waiting for processes to die.
+
+ if procs[0].stdin in w and pos < len(feed_stdin):
+ data = feed_stdin[pos : pos+io_size]
+ procs[0].stdin.write(data)
+ pos += len(data)
+ if pos >= len(feed_stdin):
+ procs[0].stdin.close()
+
+ if procs[-1].stdout in r:
+ data = procs[-1].stdout.read(io_size)
+ if data:
+ out.append(data)
+ else:
+ stdout_eof = True
+
+ if procs[-1].stderr in r:
+ data = procs[-1].stderr.read(io_size)
+ if data:
+ err.append(data)
+ else:
+ stderr_eof = True
+
+ while still_running():
+ for p in procs:
+ if p.returncode is None:
+ p.wait()
+
+ errorcodes = [p.returncode for p in procs if p.returncode != 0] or [0]
+ return errorcodes[-1], ''.join(out), ''.join(err)
+
+
+def shell_quote(s):
+ '''Return a shell-quoted version of s.'''
+
lower_ascii = 'abcdefghijklmnopqrstuvwxyz'
upper_ascii = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
digits = '0123456789'
@@ -35,25 +251,65 @@ def shell_quote(string):
safe = set(lower_ascii + upper_ascii + digits + punctuation)
quoted = []
- for character in string:
- if character in safe:
- quoted.append(character)
- elif character == "'":
+ for c in s:
+ if c in safe:
+ quoted.append(c)
+ elif c == "'":
quoted.append('"\'"')
else:
- quoted.append("'%c'" % character)
+ quoted.append("'%c'" % c)
return ''.join(quoted)
-def run_ssh_command(host, command):
- '''Run `command` over SSH on `host`.'''
- ssh_cmd = ['ssh', host, '--'] + [shell_quote(arg) for arg in command]
- return subprocess.check_output(ssh_cmd)
+def ssh_runcmd(target, argv, **kwargs):
+ '''Run command in argv on remote host target.
+
+ This is similar to runcmd, but the command is run on the remote
+ machine. The command is given as an argv array; elements in the
+ array are automatically quoted so they get passed to the other
+ side correctly.
+ An optional ``tty=`` parameter can be passed to ``ssh_runcmd`` in
+ order to force or disable pseudo-tty allocation. This is often
+ required to run ``sudo`` on another machine and might be useful
+ in other situations as well. Supported values are ``tty=True`` for
+ forcing tty allocation, ``tty=False`` for disabling it and
+ ``tty=None`` for not passing anything tty related to ssh.
+
+ With the ``tty`` option,
+ ``cliapp.runcmd(['ssh', '-tt', 'user@host', '--', 'sudo', 'ls'])``
+ can be written as
+ ``cliapp.ssh_runcmd('user@host', ['sudo', 'ls'], tty=True)``
+ which is more intuitive.
+
+ The target is given as-is to ssh, and may use any syntax ssh
+ accepts.
+
+ Environment variables may or may not be passed to the remote
+ machine: this is dependent on the ssh and sshd configurations.
+ Invoke env(1) explicitly to pass in the variables you need to
+ exist on the other end.
+
+ Pipelines are not supported.
+
+ '''
+
+ tty = kwargs.get('tty', None)
+ if tty:
+ ssh_cmd = ['ssh', '-tt', target, '--']
+ elif tty is False:
+ ssh_cmd = ['ssh', '-T', target, '--']
+ else:
+ ssh_cmd = ['ssh', target, '--']
+ if 'tty' in kwargs:
+ del kwargs['tty']
+
+ local_argv = ssh_cmd + map(shell_quote, argv)
+ return runcmd(local_argv, **kwargs)
def write_from_dict(filepath, d, validate=lambda x, y: True):
- '''Takes a dictionary and appends the contents to a file
+ """Takes a dictionary and appends the contents to a file
An optional validation callback can be passed to perform validation on
each value in the dictionary.
@@ -66,7 +322,8 @@ def write_from_dict(filepath, d, validate=lambda x, y: True):
Any callback supplied to this function should raise an exception
if validation fails.
- '''
+
+ """
# Sort items asciibetically
# the output of the deployment should not depend
@@ -84,6 +341,31 @@ def write_from_dict(filepath, d, validate=lambda x, y: True):
os.fchmod(f.fileno(), 0644)
+def parse_environment_pairs(env, pairs):
+ '''Add key=value pairs to the environment dict.
+
+ Given a dict and a list of strings of the form key=value,
+ set dict[key] = value, unless key is already set in the
+ environment, at which point raise an exception.
+
+ This does not modify the passed in dict.
+
+ Returns the extended dict.
+
+ '''
+
+ extra_env = dict(p.split('=', 1) for p in pairs)
+ conflicting = [k for k in extra_env if k in env]
+ if conflicting:
+ raise EnvironmentAlreadySetError(conflicting)
+
+ # Return a dict that is the union of the two
+ # This is not the most performant, since it creates
+ # 3 unnecessary lists, but I felt this was the most
+ # easy to read. Using itertools.chain may be more efficicent
+ return dict(env.items() + extra_env.items())
+
+
class ExtensionError(Exception):
def __init__(self, msg):
@@ -146,9 +428,9 @@ class Fstab(object):
shutil.move(os.path.abspath(tmp), os.path.abspath(self.filepath))
-class WriteExtension(object):
+class Extension(object):
- '''A base class for deployment write extensions.
+ '''A base class for deployment extensions.
A subclass should subclass this class, and add a
``process_args`` method.
@@ -190,7 +472,7 @@ class WriteExtension(object):
self.setup_logging()
self.process_args(args)
except ExtensionError as e:
- sys.stdout.write('ERROR: %s' % e)
+ sys.stdout.write('ERROR: %s\n' % e)
sys.exit(1)
def status(self, **kwargs):
@@ -204,6 +486,20 @@ class WriteExtension(object):
sys.stdout.write('%s\n' % (kwargs['msg'] % kwargs))
sys.stdout.flush()
+
+class WriteExtension(Extension):
+
+ '''A base class for deployment write extensions.
+
+ A subclass should subclass this class, and add a
+ ``process_args`` method.
+
+ Note that it is not necessary to subclass this class for write
+ extensions. This class is here just to collect common code for
+ write extensions.
+
+ '''
+
def check_for_btrfs_in_deployment_host_kernel(self):
with open('/proc/filesystems') as f:
text = f.read()
@@ -676,7 +972,7 @@ class WriteExtension(object):
def check_ssh_connectivity(self, ssh_host):
try:
- output = run_ssh_command(ssh_host, ['echo', 'test'])
+ output = ssh_runcmd(ssh_host, ['echo', 'test'])
except subprocess.CalledProcessError as e:
logging.error("Error checking SSH connectivity: %s", str(e))
raise ExtensionError(