summaryrefslogtreecommitdiff
path: root/morphlib/exts/ssh.configure
blob: 2f3167e76325d955fd724de10f957655931c538a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
#!/usr/bin/python
# Copyright (C) 2013  Codethink Limited
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; version 2 of the License.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

'''A Morph deployment configuration to copy SSH keys.

Keys are copied from the host to the new system.
'''

import cliapp
import os
import sys
import shutil
import glob
import re
import logging

import morphlib

class SshConfigurationExtension(cliapp.Application):

    '''Copy over SSH keys to new system from host.

    The extension requires SSH_KEY_DIR to be set at the command line as it
    will otherwise pass with only a status update. SSH_KEY_DIR should be
    set to the location of the SSH keys to be passed to the new system.

    Files recognised in SSH_KEY_DIR:

    * ssh_host_*_key, ssh_host_*_key.pub: copied unchanged to /etc/ssh.
    * root_*_key, root_*_key.pub: copied to /root/.ssh and renamed, so
      that e.g. root_rsa_key becomes id_rsa.
    * root_authorized_key_*.pub: appended to /root/.ssh/authorized_keys.
    * root_known_host_<hostname>_key.pub: appended to
      /root/.ssh/known_hosts as "<hostname> <key>".

    '''

    def process_args(self, args):
        '''Install the keys found in $SSH_KEY_DIR into the new system.

        args[0] is the root directory of the system being deployed. If
        SSH_KEY_DIR is not set, only a status message is printed.
        '''
        if 'SSH_KEY_DIR' not in os.environ:
            self.status(msg="No SSH key directory found.")
            return

        self._copy_host_keys(os.path.join(args[0], 'etc/ssh/'))
        self._copy_root_keys(os.path.join(args[0], 'root/.ssh/'))

    def _copy_host_keys(self, dest):
        '''Copy ssh_host_*_key and their .pub files into dest.'''
        keys, pubkeys = self.find_keys('ssh_host_*_key')
        if keys or pubkeys:
            self.check_dir(dest, 0o755)
            self.copy_keys(keys, pubkeys, dest)

    def _copy_root_keys(self, dest):
        '''Populate dest (root's .ssh) with identities, authorized_keys
        and known_hosts entries.'''
        mode = 0o700

        keys, pubkeys = self.find_keys('root_*_key')
        # root_known_host_*_key.pub also matches root_*_key.pub; those files
        # belong in known_hosts, not installed as id_known_host_* identities.
        keys = [k for k in keys if not self._is_known_host_key(k)]
        pubkeys = [k for k in pubkeys if not self._is_known_host_key(k)]
        authkeys, _ = self.find_keys('root_authorized_key_*.pub')

        if keys or pubkeys:
            self.check_dir(dest, mode)
            # [5, 4] strips the leading 'root_' and trailing '_key'.
            self.copy_rename_keys(keys, pubkeys, dest, 'id_', [5, 4])
        if authkeys:
            self.check_dir(dest, mode)
            self.comb_auth_key(authkeys, dest)

        known_hosts_keys = glob.glob(os.path.join(
            os.environ['SSH_KEY_DIR'], 'root_known_host_*_key.pub'))
        if known_hosts_keys:
            self.check_dir(dest, mode)
            self._append_known_hosts(known_hosts_keys, dest)

    @staticmethod
    def _is_known_host_key(path):
        '''Return True if path names a root_known_host_* key file.'''
        return os.path.basename(path).startswith('root_known_host_')

    def _append_known_hosts(self, keys, dest):
        '''Append one "<hostname> <public key>" line per file in keys to
        dest/known_hosts, taking the hostname from the file name.'''
        name_re = re.compile(r'^root_known_host_(.+)_key\.pub$')
        known_hosts_path = os.path.join(dest, 'known_hosts')
        with open(known_hosts_path, 'a') as known_hosts_file:
            for filename in keys:
                match = name_re.match(os.path.basename(filename))
                if match is None:
                    # e.g. root_known_host__key.pub: the glob matched but
                    # there is no hostname to put in front of the key.
                    self.status(msg="Skipping %(key)s: no hostname found.",
                                key=filename)
                    continue
                with open(filename, 'r') as f:
                    data = f.read()
                if not data.strip():
                    continue
                known_hosts_file.write(match.group(1) + ' ')
                known_hosts_file.write(self._with_newline(data))

    @staticmethod
    def _with_newline(text):
        '''Return text, adding a trailing newline if it lacks one, so that
        concatenated key files stay one entry per line.'''
        if text and not text.endswith('\n'):
            return text + '\n'
        return text

    def find_keys(self, key_name):
        '''Glob SSH_KEY_DIR for private and public keys.

        key_name is a glob pattern relative to $SSH_KEY_DIR. Returns a
        pair (keys, pubkeys): paths matching the pattern itself, and
        paths matching the pattern with '.pub' appended. A status
        message is printed if both lists are empty.
        '''
        src = os.path.join(os.environ['SSH_KEY_DIR'], key_name)
        keys = glob.glob(src)
        pubkeys = glob.glob(src + '.pub')
        if not (keys or pubkeys):
            self.status(msg="No SSH keys of pattern %(src)s found.", src=src)
        return keys, pubkeys

    def check_dir(self, dest, mode):
        '''Create directory dest with permissions mode if it is missing.

        An existing directory is left untouched, including its mode. The
        parent directory must already exist.
        '''
        if not os.path.exists(dest):
            self.status(msg="Creating SSH key directory: %(dest)s", dest=dest)
            os.mkdir(dest)
            os.chmod(dest, mode)

    def copy_keys(self, keys, pubkeys, dest):
        '''Copy keys into dest keeping their names: private keys get mode
        0600, public keys 0644.'''
        for key in keys:
            shutil.copy(key, dest)
            os.chmod(os.path.join(dest, os.path.basename(key)), 0o600)
        for key in pubkeys:
            shutil.copy(key, dest)
            os.chmod(os.path.join(dest, os.path.basename(key)), 0o644)

    def copy_rename_keys(self, keys, pubkeys, dest, new, snip):
        '''Copy keys into dest, renaming them on the way.

        snip is a (prefix_len, suffix_len) pair. That many characters are
        cut from the start and end of each private key's basename, and the
        remainder is prefixed with new. Public keys additionally have their
        '.pub' extension removed before cutting and restored afterwards.
        For example, with new='id_' and snip=[5, 4], root_rsa_key becomes
        id_rsa (mode 0600) and root_rsa_key.pub becomes id_rsa.pub
        (mode 0644).
        '''
        start, end = snip
        for key in keys:
            base = os.path.basename(key)
            target = os.path.join(dest, new + base[start:len(base) - end])
            shutil.copy(key, target)
            os.chmod(target, 0o600)
        for key in pubkeys:
            base = os.path.basename(key)
            stem = base[start:len(base) - end - len('.pub')]
            target = os.path.join(dest, new + stem + '.pub')
            shutil.copy(key, target)
            os.chmod(target, 0o644)

    def comb_auth_key(self, keys, dest):
        '''Append the contents of every file in keys to dest/authorized_keys
        (creating it if needed) and set its mode to 0600.'''
        dest = os.path.join(dest, 'authorized_keys')
        with open(dest, 'a') as fout:
            for key in keys:
                with open(key, 'r') as fin:
                    # Guarantee one key per line even if a source file
                    # lacks its final newline.
                    fout.write(self._with_newline(fin.read()))
        os.chmod(dest, 0o600)

    def status(self, **kwargs):
        '''Provide status output.

        The ``msg`` keyword argument is the actual message,
        the rest are values for fields in the message as interpolated
        by %.

        '''
        self.output.write('%s\n' % (kwargs['msg'] % kwargs))

# Run only when executed as a script so the module can be imported
# (e.g. by tests) without side effects.
if __name__ == '__main__':
    SshConfigurationExtension().run()