From 9f5ff0847af168721b18f3b1050b8f673e4c40d9 Mon Sep 17 00:00:00 2001 From: Andrey Pavlov Date: Sat, 21 Mar 2015 11:29:41 +0300 Subject: Migrate ssh.py from tempest This commit migrates the ssh common module from tempest. This module provides a wrapper around paramiko to provide ssh connectivity to a remote server for testing. At the time of this migration commit the latest Change-Id for the module was: * tempest/common/ssh.py: Ia0de957b681cb924a57af98d99a9389ee234ed5b Change-Id: Iad5ffbfe18690f1b6f178dde31a8ec0d0666c815 --- requirements.txt | 1 + tempest_lib/common/ssh.py | 152 ++++++++++++++++++++++++++++++++++ tempest_lib/exceptions.py | 11 +++ tempest_lib/tests/test_ssh.py | 188 ++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 352 insertions(+) create mode 100644 tempest_lib/common/ssh.py create mode 100644 tempest_lib/tests/test_ssh.py diff --git a/requirements.txt b/requirements.txt index dd1377e..7162de6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,5 +7,6 @@ fixtures>=0.3.14 iso8601>=0.1.9 jsonschema>=2.0.0,<3.0.0 httplib2>=0.7.5 +paramiko>=1.13.0 six>=1.9.0 oslo.log>=1.0.0,<1.1.0 # Apache-2.0 diff --git a/tempest_lib/common/ssh.py b/tempest_lib/common/ssh.py new file mode 100644 index 0000000..6ee0daf --- /dev/null +++ b/tempest_lib/common/ssh.py @@ -0,0 +1,152 @@ +# Copyright 2012 OpenStack Foundation +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + + +import select +import socket +import time +import warnings + +from oslo_log import log as logging +import six + +from tempest_lib import exceptions + + +with warnings.catch_warnings(): + warnings.simplefilter("ignore") + import paramiko + + +LOG = logging.getLogger(__name__) + + +class Client(object): + + def __init__(self, host, username, password=None, timeout=300, pkey=None, + channel_timeout=10, look_for_keys=False, key_filename=None): + self.host = host + self.username = username + self.password = password + if isinstance(pkey, six.string_types): + pkey = paramiko.RSAKey.from_private_key( + six.StringIO(str(pkey))) + self.pkey = pkey + self.look_for_keys = look_for_keys + self.key_filename = key_filename + self.timeout = int(timeout) + self.channel_timeout = float(channel_timeout) + self.buf_size = 1024 + + def _get_ssh_connection(self, sleep=1.5, backoff=1): + """Returns an ssh connection to the specified host.""" + bsleep = sleep + ssh = paramiko.SSHClient() + ssh.set_missing_host_key_policy( + paramiko.AutoAddPolicy()) + _start_time = time.time() + if self.pkey is not None: + LOG.info("Creating ssh connection to '%s' as '%s'" + " with public key authentication", + self.host, self.username) + else: + LOG.info("Creating ssh connection to '%s' as '%s'" + " with password %s", + self.host, self.username, str(self.password)) + attempts = 0 + while True: + try: + ssh.connect(self.host, username=self.username, + password=self.password, + look_for_keys=self.look_for_keys, + key_filename=self.key_filename, + timeout=self.channel_timeout, pkey=self.pkey) + LOG.info("ssh connection to %s@%s successfuly created", + self.username, self.host) + return ssh + except (socket.error, + paramiko.SSHException) as e: + if self._is_timed_out(_start_time): + LOG.exception("Failed to establish authenticated ssh" + " connection to %s@%s after %d attempts", + self.username, self.host, attempts) + raise exceptions.SSHTimeout(host=self.host, + user=self.username, + password=self.password) + bsleep += backoff + attempts += 1 + LOG.warning("Failed to establish authenticated ssh" + " connection to %s@%s (%s). Number attempts: %s." + " Retry after %d seconds.", + self.username, self.host, e, attempts, bsleep) + time.sleep(bsleep) + + def _is_timed_out(self, start_time): + return (time.time() - self.timeout) > start_time + + def exec_command(self, cmd): + """Execute the specified command on the server + + Note that this method is reading whole command outputs to memory, thus + shouldn't be used for large outputs. + + :param str cmd: Command to run at remote server. + :returns: data read from standard output of the command. + :raises: SSHExecCommandFailed if command returns nonzero + status. The exception contains command status stderr content. + :raises: TimeoutException if cmd doesn't end when timeout expires. + """ + ssh = self._get_ssh_connection() + transport = ssh.get_transport() + channel = transport.open_session() + channel.fileno() # Register event pipe + channel.exec_command(cmd) + channel.shutdown_write() + out_data = [] + err_data = [] + poll = select.poll() + poll.register(channel, select.POLLIN) + start_time = time.time() + + while True: + ready = poll.poll(self.channel_timeout) + if not any(ready): + if not self._is_timed_out(start_time): + continue + raise exceptions.TimeoutException( + "Command: '{0}' executed on host '{1}'.".format( + cmd, self.host)) + if not ready[0]: # If there is nothing to read. + continue + out_chunk = err_chunk = None + if channel.recv_ready(): + out_chunk = channel.recv(self.buf_size) + out_data += out_chunk, + if channel.recv_stderr_ready(): + err_chunk = channel.recv_stderr(self.buf_size) + err_data += err_chunk, + if channel.closed and not err_chunk and not out_chunk: + break + exit_status = channel.recv_exit_status() + if 0 != exit_status: + raise exceptions.SSHExecCommandFailed( + command=cmd, exit_status=exit_status, + strerror=''.join(err_data)) + return ''.join(out_data) + + def test_connection_auth(self): + """Raises an exception when we can not connect to server via ssh.""" + connection = self._get_ssh_connection() + connection.close() diff --git a/tempest_lib/exceptions.py b/tempest_lib/exceptions.py index 07952b6..90b0aca 100644 --- a/tempest_lib/exceptions.py +++ b/tempest_lib/exceptions.py @@ -155,3 +155,14 @@ class EndpointNotFound(TempestException): class InvalidCredentials(TempestException): message = "Invalid Credentials" + + +class SSHTimeout(TempestException): + message = ("Connection to the %(host)s via SSH timed out.\n" + "User: %(user)s, Password: %(password)s") + + +class SSHExecCommandFailed(TempestException): + """Raised when remotely executed command returns nonzero status.""" + message = ("Command '%(command)s', exit status: %(exit_status)d, " + "Error:\n%(strerror)s") diff --git a/tempest_lib/tests/test_ssh.py b/tempest_lib/tests/test_ssh.py new file mode 100644 index 0000000..ab0a198 --- /dev/null +++ b/tempest_lib/tests/test_ssh.py @@ -0,0 +1,188 @@ +# Copyright 2014 OpenStack Foundation +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import socket +import time + +import mock +import testtools + +from tempest_lib.common import ssh +from tempest_lib import exceptions +from tempest_lib.tests import base + + +class TestSshClient(base.TestCase): + + @mock.patch('paramiko.RSAKey.from_private_key') + @mock.patch('six.StringIO') + def test_pkey_calls_paramiko_RSAKey(self, cs_mock, rsa_mock): + cs_mock.return_value = mock.sentinel.csio + pkey = 'mykey' + ssh.Client('localhost', 'root', pkey=pkey) + rsa_mock.assert_called_once_with(mock.sentinel.csio) + cs_mock.assert_called_once_with('mykey') + rsa_mock.reset_mock() + cs_mock.reset_mock() + pkey = mock.sentinel.pkey + # Shouldn't call out to load a file from RSAKey, since + # a sentinel isn't a basestring... + ssh.Client('localhost', 'root', pkey=pkey) + self.assertEqual(0, rsa_mock.call_count) + self.assertEqual(0, cs_mock.call_count) + + def _set_ssh_connection_mocks(self): + client_mock = mock.MagicMock() + client_mock.connect.return_value = True + return (self.patch('paramiko.SSHClient'), + self.patch('paramiko.AutoAddPolicy'), + client_mock) + + def test_get_ssh_connection(self): + c_mock, aa_mock, client_mock = self._set_ssh_connection_mocks() + s_mock = self.patch('time.sleep') + + c_mock.return_value = client_mock + aa_mock.return_value = mock.sentinel.aa + + # Test normal case for successful connection on first try + client = ssh.Client('localhost', 'root', timeout=2) + client._get_ssh_connection(sleep=1) + + aa_mock.assert_called_once_with() + client_mock.set_missing_host_key_policy.assert_called_once_with( + mock.sentinel.aa) + expected_connect = [mock.call( + 'localhost', + username='root', + pkey=None, + key_filename=None, + look_for_keys=False, + timeout=10.0, + password=None + )] + self.assertEqual(expected_connect, client_mock.connect.mock_calls) + self.assertEqual(0, s_mock.call_count) + + def test_get_ssh_connection_two_attemps(self): + c_mock, aa_mock, client_mock = self._set_ssh_connection_mocks() + + c_mock.return_value = client_mock + client_mock.connect.side_effect = [ + socket.error, + mock.MagicMock() + ] + + client = ssh.Client('localhost', 'root', timeout=1) + start_time = int(time.time()) + client._get_ssh_connection(sleep=1) + end_time = int(time.time()) + self.assertLess((end_time - start_time), 4) + self.assertGreater((end_time - start_time), 1) + + def test_get_ssh_connection_timeout(self): + c_mock, aa_mock, client_mock = self._set_ssh_connection_mocks() + + c_mock.return_value = client_mock + client_mock.connect.side_effect = [ + socket.error, + socket.error, + socket.error, + ] + + client = ssh.Client('localhost', 'root', timeout=2) + start_time = int(time.time()) + with testtools.ExpectedException(exceptions.SSHTimeout): + client._get_ssh_connection() + end_time = int(time.time()) + self.assertLess((end_time - start_time), 5) + self.assertGreaterEqual((end_time - start_time), 2) + + def test_exec_command(self): + gsc_mock = self.patch('tempest_lib.common.ssh.Client.' + '_get_ssh_connection') + ito_mock = self.patch('tempest_lib.common.ssh.Client._is_timed_out') + select_mock = self.patch('select.poll') + + client_mock = mock.MagicMock() + tran_mock = mock.MagicMock() + chan_mock = mock.MagicMock() + poll_mock = mock.MagicMock() + + def reset_mocks(): + gsc_mock.reset_mock() + ito_mock.reset_mock() + select_mock.reset_mock() + poll_mock.reset_mock() + client_mock.reset_mock() + tran_mock.reset_mock() + chan_mock.reset_mock() + + select_mock.return_value = poll_mock + gsc_mock.return_value = client_mock + ito_mock.return_value = True + client_mock.get_transport.return_value = tran_mock + tran_mock.open_session.return_value = chan_mock + poll_mock.poll.side_effect = [ + [0, 0, 0] + ] + + # Test for a timeout condition immediately raised + client = ssh.Client('localhost', 'root', timeout=2) + with testtools.ExpectedException(exceptions.TimeoutException): + client.exec_command("test") + + chan_mock.fileno.assert_called_once_with() + chan_mock.exec_command.assert_called_once_with("test") + chan_mock.shutdown_write.assert_called_once_with() + + SELECT_POLLIN = 1 + poll_mock.register.assert_called_once_with(chan_mock, SELECT_POLLIN) + poll_mock.poll.assert_called_once_with(10) + + # Test for proper reading of STDOUT and STDERROR and closing + # of all file descriptors. + + reset_mocks() + + select_mock.return_value = poll_mock + gsc_mock.return_value = client_mock + ito_mock.return_value = False + client_mock.get_transport.return_value = tran_mock + tran_mock.open_session.return_value = chan_mock + poll_mock.poll.side_effect = [ + [1, 0, 0] + ] + closed_prop = mock.PropertyMock(return_value=True) + type(chan_mock).closed = closed_prop + chan_mock.recv_exit_status.return_value = 0 + chan_mock.recv.return_value = '' + chan_mock.recv_stderr.return_value = '' + + client = ssh.Client('localhost', 'root', timeout=2) + client.exec_command("test") + + chan_mock.fileno.assert_called_once_with() + chan_mock.exec_command.assert_called_once_with("test") + chan_mock.shutdown_write.assert_called_once_with() + + SELECT_POLLIN = 1 + poll_mock.register.assert_called_once_with(chan_mock, SELECT_POLLIN) + poll_mock.poll.assert_called_once_with(10) + chan_mock.recv_ready.assert_called_once_with() + chan_mock.recv.assert_called_once_with(1024) + chan_mock.recv_stderr_ready.assert_called_once_with() + chan_mock.recv_stderr.assert_called_once_with(1024) + chan_mock.recv_exit_status.assert_called_once_with() + closed_prop.assert_called_once_with() -- cgit v1.2.1